This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9a1c742098c2 [SPARK-53751][SDP] Explicit Versioned Checkpoint Location
9a1c742098c2 is described below
commit 9a1c742098c24c7031aa6f494827b13a41b458fa
Author: Jacky Wang <[email protected]>
AuthorDate: Wed Oct 8 12:27:52 2025 -0700
[SPARK-53751][SDP] Explicit Versioned Checkpoint Location
### What changes were proposed in this pull request?
Add a `storage` field to the pipeline spec to allow users to specify the location of
metadata such as streaming checkpoints.
Below is the structure of the storage directory, which supports multiple flows and
versioned checkpoint directories; the version number is incremented after a
full refresh.
```
storage-root/
└── _checkpoints/ # checkpoint root
├── myst/
│ ├── flow1/
│ │ ├── 0/ # version 0
│ │ │ ├── commits/
│ │ │ ├── offsets/
│ │ │ └── sources/
│ │ └── 1/ # version 1
│ │
│ └── flow2/
│ ├── 0/
│ │ ├── commits/
│ │ ├── offsets/
│ │ └── sources/
│ └── 1/
│
└── mysink/
└── flowA/
├── 0/
│ ├── commits/
│ ├── offsets/
│ └── sources/
└── 1/
```
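For illustration only, here is a minimal sketch (not part of this commit) of how the versioned layout above can be resolved: pick the highest-numbered subdirectory under `<storage>/_checkpoints/<destination>/<flow>/`, or start at `0` if none exists. The name `latest_checkpoint_dir`, the local-filesystem I/O, and the `create_new_version` flag are assumptions made for the example; the server-side logic in this commit lives in `SystemMetadata.getLatestCheckpointDir` and uses Hadoop's `FileSystem` API instead.
```python
# Hypothetical sketch only; not the commit's implementation.
from pathlib import Path


def latest_checkpoint_dir(
    storage_root: str, destination: str, flow: str, create_new_version: bool = False
) -> Path:
    """Return <storage_root>/_checkpoints/<destination>/<flow>/<version>."""
    flow_dir = Path(storage_root) / "_checkpoints" / destination / flow
    versions = sorted(
        int(p.name) for p in flow_dir.iterdir() if p.is_dir() and p.name.isdigit()
    ) if flow_dir.exists() else []
    if not versions:
        version = 0                      # first run starts at version 0
    elif create_new_version:
        version = versions[-1] + 1       # a full refresh bumps the version
    else:
        version = versions[-1]           # a normal run reuses the latest version
    target = flow_dir / str(version)
    target.mkdir(parents=True, exist_ok=True)
    return target


# e.g. storage-root/_checkpoints/myst/flow1/0, then .../1 after a full refresh
print(latest_checkpoint_dir("storage-root", "myst", "flow1"))
print(latest_checkpoint_dir("storage-root", "myst", "flow1", create_new_version=True))
```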
### Why are the changes needed?
Currently, SDP stores streaming flow checkpoints in the table path. This does not
support versioned checkpoints and does not work for sinks.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
New and existing tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52487 from JiaqiWang18/SPARK-53751-explicit-versioned-multiflow-checkpoint.
Authored-by: Jacky Wang <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
python/pyspark/pipelines/cli.py | 16 +-
python/pyspark/pipelines/init_cli.py | 1 +
python/pyspark/pipelines/spark_connect_pipeline.py | 4 +
python/pyspark/pipelines/tests/test_cli.py | 42 +++-
.../pyspark/pipelines/tests/test_spark_connect.py | 2 +
python/pyspark/sql/connect/proto/pipelines_pb2.py | 54 ++---
python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 16 ++
.../main/protobuf/spark/connect/pipelines.proto | 3 +
.../sql/connect/pipelines/PipelinesHandler.scala | 8 +-
.../spark/sql/connect/SparkConnectServerTest.scala | 4 +-
.../sql/connect/pipelines/EndToEndAPISuite.scala | 1 +
.../pipelines/PipelineEventStreamSuite.scala | 8 +-
.../pipelines/PipelineRefreshFunctionalSuite.scala | 8 +
.../connect/pipelines/PythonPipelineSuite.scala | 32 +--
.../SparkDeclarativePipelinesServerTest.scala | 10 +-
.../service/SparkConnectSessionHolderSuite.scala | 3 +-
.../service/SparkConnectSessionManagerSuite.scala | 3 +-
.../spark/sql/pipelines/graph/FlowExecution.scala | 9 +
.../spark/sql/pipelines/graph/FlowPlanner.scala | 3 +-
.../sql/pipelines/graph/PipelineExecution.scala | 4 +
.../pipelines/graph/PipelineUpdateContext.scala | 5 +
.../graph/PipelineUpdateContextImpl.scala | 3 +-
.../apache/spark/sql/pipelines/graph/State.scala | 104 +++++++++
.../spark/sql/pipelines/graph/SystemMetadata.scala | 126 ++++++++++
.../pipelines/graph/MaterializeTablesSuite.scala | 94 +++++---
.../sql/pipelines/graph/SystemMetadataSuite.scala | 260 +++++++++++++++++++++
.../graph/TriggeredGraphExecutionSuite.scala | 32 +--
.../spark/sql/pipelines/graph/ViewSuite.scala | 33 ++-
.../pipelines/utils/BaseCoreExecutionTest.scala | 5 +-
.../spark/sql/pipelines/utils/ExecutionTest.scala | 1 +
.../spark/sql/pipelines/utils/PipelineTest.scala | 17 +-
.../sql/pipelines/utils/StorageRootMixin.scala | 48 ++++
32 files changed, 822 insertions(+), 137 deletions(-)
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 5342b7fb430c..ca198f1c3aff 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -90,6 +90,7 @@ class PipelineSpec:
"""Spec for a pipeline.
:param name: The name of the pipeline.
+ :param storage: The root directory for storing metadata, such as streaming checkpoints.
:param catalog: The default catalog to use for the pipeline.
:param database: The default database to use for the pipeline.
:param configuration: A dictionary of Spark configuration properties to
set for the pipeline.
@@ -97,6 +98,7 @@ class PipelineSpec:
"""
name: str
+ storage: str
catalog: Optional[str]
database: Optional[str]
configuration: Mapping[str, str]
@@ -150,8 +152,16 @@ def load_pipeline_spec(spec_path: Path) -> PipelineSpec:
def unpack_pipeline_spec(spec_data: Mapping[str, Any]) -> PipelineSpec:
- ALLOWED_FIELDS = {"name", "catalog", "database", "schema", "configuration", "libraries"}
- REQUIRED_FIELDS = ["name"]
+ ALLOWED_FIELDS = {
+ "name",
+ "storage",
+ "catalog",
+ "database",
+ "schema",
+ "configuration",
+ "libraries",
+ }
+ REQUIRED_FIELDS = ["name", "storage"]
for key in spec_data.keys():
if key not in ALLOWED_FIELDS:
raise PySparkException(
@@ -167,6 +177,7 @@ def unpack_pipeline_spec(spec_data: Mapping[str, Any]) -> PipelineSpec:
return PipelineSpec(
name=spec_data["name"],
+ storage=spec_data["storage"],
catalog=spec_data.get("catalog"),
database=spec_data.get("database", spec_data.get("schema")),
configuration=validate_str_dict(spec_data.get("configuration", {}), "configuration"),
@@ -323,6 +334,7 @@ def run(
full_refresh_all=full_refresh_all,
refresh=refresh,
dry=dry,
+ storage=spec.storage,
)
try:
handle_pipeline_events(result_iter)
diff --git a/python/pyspark/pipelines/init_cli.py
b/python/pyspark/pipelines/init_cli.py
index 47be703f7795..ffe5d3c12b63 100644
--- a/python/pyspark/pipelines/init_cli.py
+++ b/python/pyspark/pipelines/init_cli.py
@@ -19,6 +19,7 @@ from pathlib import Path
SPEC = """
name: {{ name }}
+storage: storage-root
libraries:
- glob:
include: transformations/**
diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py
b/python/pyspark/pipelines/spark_connect_pipeline.py
index 61b72956e5cc..e3c1184cea39 100644
--- a/python/pyspark/pipelines/spark_connect_pipeline.py
+++ b/python/pyspark/pipelines/spark_connect_pipeline.py
@@ -72,6 +72,7 @@ def start_run(
full_refresh_all: bool,
refresh: Optional[Sequence[str]],
dry: bool,
+ storage: str,
) -> Iterator[Dict[str, Any]]:
"""Start a run of the dataflow graph in the Spark Connect server.
@@ -79,6 +80,8 @@ def start_run(
:param full_refresh: List of datasets to reset and recompute.
:param full_refresh_all: Perform a full graph reset and recompute.
:param refresh: List of datasets to update.
+ :param dry: If true, the run will not actually execute any flows, but only validate the graph.
+ :param storage: The storage location to store metadata such as streaming checkpoints.
"""
inner_command = pb2.PipelineCommand.StartRun(
dataflow_graph_id=dataflow_graph_id,
@@ -86,6 +89,7 @@ def start_run(
full_refresh_all=full_refresh_all,
refresh_selection=refresh or [],
dry=dry,
+ storage=storage,
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
diff --git a/python/pyspark/pipelines/tests/test_cli.py
b/python/pyspark/pipelines/tests/test_cli.py
index fbc6d3a90ac8..894b8be283f0 100644
--- a/python/pyspark/pipelines/tests/test_cli.py
+++ b/python/pyspark/pipelines/tests/test_cli.py
@@ -60,7 +60,8 @@ class CLIUtilityTests(unittest.TestCase):
},
"libraries": [
{"glob": {"include": "test_include"}}
- ]
+ ],
+ "storage": "storage_path",
}
"""
)
@@ -72,6 +73,7 @@ class CLIUtilityTests(unittest.TestCase):
assert spec.configuration == {"key1": "value1", "key2": "value2"}
assert len(spec.libraries) == 1
assert spec.libraries[0].include == "test_include"
+ assert spec.storage == "storage_path"
def test_load_pipeline_spec_name_is_required(self):
with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
@@ -86,7 +88,8 @@ class CLIUtilityTests(unittest.TestCase):
},
"libraries": [
{"glob": {"include": "test_include"}}
- ]
+ ],
+ "storage": "storage_path",
}
"""
)
@@ -112,7 +115,8 @@ class CLIUtilityTests(unittest.TestCase):
},
"libraries": [
{"glob": {"include": "test_include"}}
- ]
+ ],
+ "storage": "storage_path",
}
"""
)
@@ -150,21 +154,38 @@ class CLIUtilityTests(unittest.TestCase):
def test_unpack_empty_pipeline_spec(self):
empty_spec = PipelineSpec(
- name="test_pipeline", catalog=None, database=None,
configuration={}, libraries=[]
+ name="test_pipeline",
+ storage="storage_path",
+ catalog=None,
+ database=None,
+ configuration={},
+ libraries=[],
+ )
+ self.assertEqual(
+ unpack_pipeline_spec({"name": "test_pipeline", "storage":
"storage_path"}), empty_spec
)
- self.assertEqual(unpack_pipeline_spec({"name": "test_pipeline"}),
empty_spec)
def test_unpack_pipeline_spec_bad_configuration(self):
with self.assertRaises(TypeError) as context:
- unpack_pipeline_spec({"name": "test_pipeline", "configuration":
"not_a_dict"})
+ unpack_pipeline_spec(
+ {"name": "test_pipeline", "storage": "storage_path",
"configuration": "not_a_dict"}
+ )
self.assertIn("should be a dict", str(context.exception))
with self.assertRaises(TypeError) as context:
- unpack_pipeline_spec({"name": "test_pipeline", "configuration":
{"key": {}}})
+ unpack_pipeline_spec(
+ {"name": "test_pipeline", "storage": "storage_path",
"configuration": {"key": {}}}
+ )
self.assertIn("key", str(context.exception))
with self.assertRaises(TypeError) as context:
- unpack_pipeline_spec({"name": "test_pipeline", "configuration":
{1: "something"}})
+ unpack_pipeline_spec(
+ {
+ "name": "test_pipeline",
+ "storage": "storage_path",
+ "configuration": {1: "something"},
+ }
+ )
self.assertIn("int", str(context.exception))
def test_find_pipeline_spec_in_current_directory(self):
@@ -239,6 +260,7 @@ class CLIUtilityTests(unittest.TestCase):
name="test_pipeline",
catalog=None,
database=None,
+ storage="storage_path",
configuration={},
libraries=[LibrariesGlob(include="subdir1/**")],
)
@@ -280,6 +302,7 @@ class CLIUtilityTests(unittest.TestCase):
"""Errors raised while executing definitions code should make it to
the outer context."""
spec = PipelineSpec(
name="test_pipeline",
+ storage="storage_path",
catalog=None,
database=None,
configuration={},
@@ -298,6 +321,7 @@ class CLIUtilityTests(unittest.TestCase):
def
test_register_definitions_unsupported_file_extension_matches_glob(self):
spec = PipelineSpec(
name="test_pipeline",
+ storage="storage_path",
catalog=None,
database=None,
configuration={},
@@ -352,6 +376,7 @@ class CLIUtilityTests(unittest.TestCase):
registry,
PipelineSpec(
name="test_pipeline",
+ storage="storage_path",
catalog=None,
database=None,
configuration={},
@@ -500,6 +525,7 @@ class CLIUtilityTests(unittest.TestCase):
"""
{
"name": "test_pipeline",
+ "storage": "storage_path",
"libraries": [
{"glob": {"include": "transformations/**/*.py"}}
]
diff --git a/python/pyspark/pipelines/tests/test_spark_connect.py
b/python/pyspark/pipelines/tests/test_spark_connect.py
index 6d81a98c8c44..0b54c0906f9a 100644
--- a/python/pyspark/pipelines/tests/test_spark_connect.py
+++ b/python/pyspark/pipelines/tests/test_spark_connect.py
@@ -58,6 +58,7 @@ class SparkConnectPipelinesTest(ReusedConnectTestCase):
refresh=None,
full_refresh_all=False,
dry=True,
+ storage="storage_path",
)
handle_pipeline_events(result_iter)
@@ -79,6 +80,7 @@ class SparkConnectPipelinesTest(ReusedConnectTestCase):
refresh=None,
full_refresh_all=False,
dry=True,
+ storage="storage_path",
)
with self.assertRaises(AnalysisException) as context:
handle_pipeline_events(result_iter)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py
b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index ba3d2f634a48..f53f91452454 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -41,7 +41,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xfb\x1b\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
\x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
+
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xa6\x1c\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
\x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
)
_globals = globals()
@@ -60,10 +60,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options
= b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options =
b"8\001"
- _globals["_DATASETTYPE"]._serialized_start = 4896
- _globals["_DATASETTYPE"]._serialized_end = 4993
+ _globals["_DATASETTYPE"]._serialized_start = 4939
+ _globals["_DATASETTYPE"]._serialized_end = 5036
_globals["_PIPELINECOMMAND"]._serialized_start = 168
- _globals["_PIPELINECOMMAND"]._serialized_end = 3747
+ _globals["_PIPELINECOMMAND"]._serialized_end = 3790
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1050
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1358
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 1259
@@ -81,27 +81,27 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2692
_globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2750
_globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2865
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3144
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
3147
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3346
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 3349
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 3507
-
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
3510
- _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 3731
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3750
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4506
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 4122
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
4220
- _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start =
4223
- _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end =
4357
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
4360
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4491
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4508
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4581
- _globals["_PIPELINEEVENT"]._serialized_start = 4583
- _globals["_PIPELINEEVENT"]._serialized_end = 4699
- _globals["_SOURCECODELOCATION"]._serialized_start = 4701
- _globals["_SOURCECODELOCATION"]._serialized_end = 4823
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4825
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4894
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3187
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
3190
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3389
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 3392
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 3550
+
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
3553
+ _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 3774
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3793
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4549
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 4165
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
4263
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start =
4266
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end =
4400
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
4403
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4534
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4551
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4624
+ _globals["_PIPELINEEVENT"]._serialized_start = 4626
+ _globals["_PIPELINEEVENT"]._serialized_end = 4742
+ _globals["_SOURCECODELOCATION"]._serialized_start = 4744
+ _globals["_SOURCECODELOCATION"]._serialized_end = 4866
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4868
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4937
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index 575e15e0cb58..c42b3eefe2d4 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -578,6 +578,7 @@ class PipelineCommand(google.protobuf.message.Message):
FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int
REFRESH_SELECTION_FIELD_NUMBER: builtins.int
DRY_FIELD_NUMBER: builtins.int
+ STORAGE_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to start."""
@property
@@ -596,6 +597,8 @@ class PipelineCommand(google.protobuf.message.Message):
"""If true, the run will not actually execute any flows, but will only
validate the graph and
check for any errors. This is useful for testing and validation
purposes.
"""
+ storage: builtins.str
+ """storage location for pipeline checkpoints and metadata."""
def __init__(
self,
*,
@@ -604,6 +607,7 @@ class PipelineCommand(google.protobuf.message.Message):
full_refresh_all: builtins.bool | None = ...,
refresh_selection: collections.abc.Iterable[builtins.str] | None =
...,
dry: builtins.bool | None = ...,
+ storage: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
@@ -614,12 +618,16 @@ class PipelineCommand(google.protobuf.message.Message):
b"_dry",
"_full_refresh_all",
b"_full_refresh_all",
+ "_storage",
+ b"_storage",
"dataflow_graph_id",
b"dataflow_graph_id",
"dry",
b"dry",
"full_refresh_all",
b"full_refresh_all",
+ "storage",
+ b"storage",
],
) -> builtins.bool: ...
def ClearField(
@@ -631,6 +639,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_dry",
"_full_refresh_all",
b"_full_refresh_all",
+ "_storage",
+ b"_storage",
"dataflow_graph_id",
b"dataflow_graph_id",
"dry",
@@ -641,6 +651,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"full_refresh_selection",
"refresh_selection",
b"refresh_selection",
+ "storage",
+ b"storage",
],
) -> None: ...
@typing.overload
@@ -656,6 +668,10 @@ class PipelineCommand(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_full_refresh_all",
b"_full_refresh_all"]
) -> typing_extensions.Literal["full_refresh_all"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_storage",
b"_storage"]
+ ) -> typing_extensions.Literal["storage"] | None: ...
class DefineSqlGraphElements(google.protobuf.message.Message):
"""Parses the SQL file and registers all datasets and flows."""
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index 130dcd54e59b..ef24899a5f74 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -143,6 +143,9 @@ message PipelineCommand {
// If true, the run will not actually execute any flows, but will only validate the graph and
// check for any errors. This is useful for testing and validation purposes.
optional bool dry = 5;
+
+ // storage location for pipeline checkpoints and metadata.
+ optional string storage = 6;
}
// Parses the SQL file and registers all datasets and flows.
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 21f6c9492f68..358821289c85 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -338,11 +338,17 @@ private[connect] object PipelinesHandler extends Logging {
}
}
+ if (cmd.getStorage.isEmpty) {
+ // server-side validation to ensure that storage is always specified
+ throw new IllegalArgumentException("Storage must be specified to start
a run.")
+ }
+
val pipelineUpdateContext = new PipelineUpdateContextImpl(
graphElementRegistry.toDataflowGraph,
eventCallback,
tableFiltersResult.refresh,
- tableFiltersResult.fullRefresh)
+ tableFiltersResult.fullRefresh,
+ cmd.getStorage)
sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
if (cmd.getDry) {
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 92f64875d337..91e728d73e13 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -70,12 +70,12 @@ trait SparkConnectServerTest extends SharedSparkSession {
super.afterAll()
}
- override def beforeEach(): Unit = {
+ protected override def beforeEach(): Unit = {
super.beforeEach()
clearAllExecutions()
}
- override def afterEach(): Unit = {
+ protected override def afterEach(): Unit = {
clearAllExecutions()
super.afterEach()
}
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
index 923a85cb36f1..0a0cca488f1f 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
@@ -160,6 +160,7 @@ class EndToEndAPISuite extends PipelineTest with APITest
with SparkConnectServer
|name: test-pipeline
|${spec.catalog.map(catalog => s"""catalog: "$catalog"""").getOrElse("")}
|${spec.database.map(database => s"""database:
"$database"""").getOrElse("")}
+ |storage: "${projectDir.resolve("storage").toAbsolutePath}"
|configuration:
| "spark.remote": "sc://localhost:$serverPort"
|libraries:
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
index 83862545a723..214d3b86fb78 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
@@ -42,7 +42,11 @@ class PipelineEventStreamSuite extends
SparkDeclarativePipelinesServerTest {
val capturedEvents = new ArrayBuffer[PipelineEvent]()
withClient { client =>
val startRunRequest = buildStartRunPlan(
-
proto.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())
+ proto.PipelineCommand.StartRun
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
+ .build())
val responseIterator = client.execute(startRunRequest)
while (responseIterator.hasNext) {
val response = responseIterator.next()
@@ -94,6 +98,7 @@ class PipelineEventStreamSuite extends
SparkDeclarativePipelinesServerTest {
proto.PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.setDry(dry)
.build())
val ex = intercept[AnalysisException] {
@@ -144,6 +149,7 @@ class PipelineEventStreamSuite extends
SparkDeclarativePipelinesServerTest {
proto.PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.setDry(true)
.build())
val responseIterator = client.execute(startRunRequest)
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
index 794932544d5f..0e08b39abd19 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
@@ -128,6 +128,7 @@ class PipelineRefreshFunctionalSuite
PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.addAllFullRefreshSelection(List("a").asJava)
.build())
},
@@ -168,6 +169,7 @@ class PipelineRefreshFunctionalSuite
.setDataflowGraphId(graphId)
.addAllFullRefreshSelection(Seq("a", "mv").asJava)
.addRefreshSelection("b")
+ .setStorage(storageRoot)
.build())
},
expectedContentAfterRefresh = Map(
@@ -211,6 +213,7 @@ class PipelineRefreshFunctionalSuite
PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.setFullRefreshAll(true)
.build())
},
@@ -231,6 +234,7 @@ class PipelineRefreshFunctionalSuite
val startRun = PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.setFullRefreshAll(true)
.addRefreshSelection("a")
.build()
@@ -253,6 +257,7 @@ class PipelineRefreshFunctionalSuite
val startRun = PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.setFullRefreshAll(true)
.addFullRefreshSelection("a")
.build()
@@ -275,6 +280,7 @@ class PipelineRefreshFunctionalSuite
val startRun = PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.addRefreshSelection("a")
.addFullRefreshSelection("a")
.build()
@@ -298,6 +304,7 @@ class PipelineRefreshFunctionalSuite
val startRun = PipelineCommand.StartRun
.newBuilder()
.setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
.addRefreshSelection("a")
.addRefreshSelection("b")
.addFullRefreshSelection("a")
@@ -321,6 +328,7 @@ class PipelineRefreshFunctionalSuite
val startRun = PipelineCommand.StartRun
.newBuilder()
+ .setStorage(storageRoot)
.setDataflowGraphId(graphId)
.addRefreshSelection("spark_catalog.default.a")
.addFullRefreshSelection("a") // This should be treated as the same
table
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index b5461c415b0d..c4ed554ef978 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -129,7 +129,7 @@ class PythonPipelineSuite
| return df.select("name")
|""".stripMargin)
- val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+ val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph,
storageRoot)
updateContext.pipelineExecution.runPipeline()
assertFlowProgressEvent(
@@ -174,7 +174,7 @@ class PythonPipelineSuite
| return spark.readStream.table('mv2')
|""".stripMargin)
- val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+ val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph,
storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -602,7 +602,8 @@ class PythonPipelineSuite
)
""")
- val updateContext = new PipelineUpdateContextImpl(graph, _ => ())
+ val updateContext =
+ new PipelineUpdateContextImpl(graph, _ => (), storageRoot = storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -628,18 +629,19 @@ class PythonPipelineSuite
test("MV/ST with partition columns works") {
withTable("mv", "st") {
val graph = buildGraph("""
- |from pyspark.sql.functions import col
- |
- |@dp.materialized_view(partition_cols = ["id_mod"])
- |def mv():
- | return spark.range(5).withColumn("id_mod", col("id") % 2)
- |
- |@dp.table(partition_cols = ["id_mod"])
- |def st():
- | return spark.readStream.table("mv")
- |""".stripMargin)
-
- val updateContext = new PipelineUpdateContextImpl(graph, eventCallback =
_ => ())
+ |from pyspark.sql.functions import col
+ |
+ |@dp.materialized_view(partition_cols = ["id_mod"])
+ |def mv():
+ | return spark.range(5).withColumn("id_mod", col("id") % 2)
+ |
+ |@dp.table(partition_cols = ["id_mod"])
+ |def st():
+ | return spark.readStream.table("mv")
+ |""".stripMargin)
+
+ val updateContext =
+ new PipelineUpdateContextImpl(graph, eventCallback = _ => (),
storageRoot = storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
index a31883677f92..32c728cb6a64 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
@@ -24,9 +24,9 @@ import org.apache.spark.connect.proto.{PipelineCommand,
PipelineEvent}
import org.apache.spark.sql.connect.{SparkConnectServerTest,
SparkConnectTestUtils}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.{SessionHolder, SessionKey,
SparkConnectService}
-import org.apache.spark.sql.pipelines.utils.PipelineTest
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, StorageRootMixin}
-class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {
+class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest with
StorageRootMixin {
override def afterEach(): Unit = {
SparkConnectService.sessionManager
@@ -138,7 +138,11 @@ class SparkDeclarativePipelinesServerTest extends
SparkConnectServerTest {
def startPipelineAndWaitForCompletion(graphId: String):
ArrayBuffer[PipelineEvent] = {
val defaultStartRunCommand =
- PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build()
+ PipelineCommand.StartRun
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setStorage(storageRoot)
+ .build()
startPipelineAndWaitForCompletion(defaultStartRunCommand)
}
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index a110b0164f19..36c7e9bb80ec 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -434,7 +434,8 @@ class SparkConnectSessionHolderSuite extends
SharedSparkSession {
val graphId = "test_graph"
val pipelineUpdateContext = new PipelineUpdateContextImpl(
new DataflowGraph(Seq(), Seq(), Seq()),
- (_: PipelineEvent) => None)
+ (_: PipelineEvent) => None,
+ storageRoot = "test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
assert(
sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index f8b387039a7a..ee227168dc23 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -160,7 +160,8 @@ class SparkConnectSessionManagerSuite extends
SharedSparkSession with BeforeAndA
val graphId = "test_graph"
val pipelineUpdateContext = new PipelineUpdateContextImpl(
new DataflowGraph(Seq(), Seq(), Seq()),
- (_: PipelineEvent) => None)
+ (_: PipelineEvent) => None,
+ storageRoot = "test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
assert(
sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
index c907d603199f..30ee77a2315e 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
@@ -191,6 +191,14 @@ trait StreamingFlowExecution extends FlowExecution with
Logging {
/** Starts a stream and returns its streaming query. */
protected def startStream(): StreamingQuery
+ private var _streamingQuery: Option[StreamingQuery] = None
+
+ /** Visible for testing */
+ def getStreamingQuery: StreamingQuery =
+ _streamingQuery.getOrElse(
+ throw new IllegalStateException(s"StreamingPhysicalFlow has not been
started")
+ )
+
/**
* Executes this `StreamingFlowExecution` by starting its stream with the correct scheduling pool
* and confs.
@@ -201,6 +209,7 @@ trait StreamingFlowExecution extends FlowExecution with
Logging {
log"checkpoint location ${MDC(LogKeys.CHECKPOINT_PATH, checkpointPath)}"
)
val streamingQuery = SparkSessionUtils.withSqlConf(spark, sqlConf.toList: _*)(startStream())
+ _streamingQuery = Option(streamingQuery)
Future(streamingQuery.awaitTermination())
}
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
index bb154a0081da..c4790e872951 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
@@ -50,6 +50,7 @@ class FlowPlanner(
updateContext = updateContext
)
case sf: StreamingFlow =>
+ val flowMetadata = FlowSystemMetadata(updateContext, sf, graph)
output match {
case o: Table =>
new StreamingTableWrite(
@@ -60,7 +61,7 @@ class FlowPlanner(
updateContext = updateContext,
sqlConf = sf.sqlConf,
trigger = triggerFor(sf),
- checkpointPath = output.path
+ checkpointPath = flowMetadata.latestCheckpointLocation
)
case _ =>
throw new UnsupportedOperationException(
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
index 5bb6e25eaf45..6090b764d543 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
@@ -46,6 +46,10 @@ class PipelineExecution(context: PipelineUpdateContext) {
def startPipeline(): Unit = synchronized {
// Initialize the graph.
val resolvedGraph = resolveGraph()
+ if (context.fullRefreshTables.nonEmpty) {
+ State.reset(resolvedGraph, context)
+ }
+
val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph,
context)
// Execute the graph.
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
index d6f202080933..d72c895180b5 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
@@ -33,6 +33,11 @@ trait PipelineUpdateContext {
def resetCheckpointFlows: FlowFilter
+ /**
+ * The root storage location for pipeline metadata, including checkpoints
for streaming flows.
+ */
+ def storageRoot: String
+
/**
* Filter for which flows should be refreshed when performing this update.
Should be a superset of
* fullRefreshFlows.
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
index 5a298c2f17d9..bb2009b25912 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
@@ -32,7 +32,8 @@ class PipelineUpdateContextImpl(
override val unresolvedGraph: DataflowGraph,
override val eventCallback: PipelineEvent => Unit,
override val refreshTables: TableFilter = AllTables,
- override val fullRefreshTables: TableFilter = NoTables
+ override val fullRefreshTables: TableFilter = NoTables,
+ override val storageRoot: String
) extends PipelineUpdateContext {
override val spark: SparkSession = SparkSession.getActiveSession.getOrElse(
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
new file mode 100644
index 000000000000..31fc065cadc8
--- /dev/null
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.AnalysisException
+
+object State extends Logging {
+
+ /**
+ * Find the graph elements to reset given the current update context.
+ * @param graph The graph to reset.
+ * @param env The current update context.
+ */
+ private def findElementsToReset(graph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = {
+ // If tableFilter is an instance of SomeTables, this is a refresh selection and all tables
+ // to reset should be resettable; Otherwise, this is a full graph update, and we reset all
+ // tables that are resettable.
+ val specifiedTablesToReset = {
+ val specifiedTables = env.fullRefreshTables.filter(graph.tables)
+ env.fullRefreshTables match {
+ case SomeTables(_) =>
+ specifiedTables.foreach { t =>
+ if (!PipelinesTableProperties.resetAllowed.fromMap(t.properties)) {
+ throw new AnalysisException(
+ "TABLE_NOT_RESETTABLE",
+ Map("tableName" -> t.displayName)
+ )
+ }
+ }
+ specifiedTables
+ case AllTables =>
+ specifiedTables.filter(t => PipelinesTableProperties.resetAllowed.fromMap(t.properties))
+ case NoTables => Seq.empty
+ }
+ }
+
+ specifiedTablesToReset.flatMap(t => t +: graph.resolvedFlowsTo(t.identifier))
+ }
+
+ /**
+ * Performs the following on targets selected for full refresh:
+ * - Clearing checkpoint data
+ * - Truncating table data
+ */
+ def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = {
+ val elementsToReset: Seq[Input] = findElementsToReset(resolvedGraph, env)
+
+ elementsToReset.foreach {
+ case f: ResolvedFlow => reset(f, env, resolvedGraph)
+ case _ => // tables is handled in materializeTables since hive metastore does not support
+ // removing all columns from a table.
+ }
+
+ elementsToReset
+ }
+
+ /**
+ * Resets the checkpoint for the given flow by creating the next consecutive directory.
+ */
+ private def reset(flow: ResolvedFlow, env: PipelineUpdateContext, graph: DataflowGraph): Unit = {
+ logInfo(log"Clearing out state for flow ${MDC(LogKeys.FLOW_NAME, flow.displayName)}")
+ val flowMetadata = FlowSystemMetadata(env, flow, graph)
+ flow match {
+ case f if flowMetadata.latestCheckpointLocationOpt().isEmpty =>
+ logInfo(
+ s"Skipping resetting flow ${f.identifier} since its destination not
been previously" +
+ s"materialized and we can't find the checkpoint location."
+ )
+ case _ =>
+ val hadoopConf = env.spark.sessionState.newHadoopConf()
+
+ // Write a new checkpoint folder if needed
+ val checkpointDir = new Path(flowMetadata.latestCheckpointLocation)
+ val fs1 = checkpointDir.getFileSystem(hadoopConf)
+ if (fs1.exists(checkpointDir)) {
+ val nextVersion = checkpointDir.getName.toInt + 1
+ val nextPath = new Path(checkpointDir.getParent, nextVersion.toString)
+ fs1.mkdirs(nextPath)
+ logInfo(
+ log"Created new checkpoint for stream ${MDC(LogKeys.FLOW_NAME,
flow.displayName)} " +
+ log"at ${MDC(LogKeys.CHECKPOINT_PATH, nextPath.toString)}."
+ )
+ }
+ }
+ }
+}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
new file mode 100644
index 000000000000..d805f4a689ec
--- /dev/null
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.Try
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.classic.SparkSession
+
+sealed trait SystemMetadata {}
+
+/**
+ * Represents the system metadata associated with a [[Flow]].
+ */
+case class FlowSystemMetadata(
+ context: PipelineUpdateContext,
+ flow: Flow,
+ graph: DataflowGraph
+) extends SystemMetadata with Logging {
+
+ /**
+ * Returns the checkpoint root directory for a given flow
+ * which is storage/_checkpoints/flow_destination_table/flow_name.
+ * @return the checkpoint root directory for `flow`
+ */
+ private def flowCheckpointsDirOpt(): Option[Path] = {
+ Option(if (graph.table.contains(flow.destinationIdentifier)) {
+ val checkpointRoot = new Path(context.storageRoot, "_checkpoints")
+ val flowTableName = flow.destinationIdentifier.table
+ val flowName = flow.identifier.table
+ val checkpointDir = new Path(
+ new Path(checkpointRoot, flowTableName),
+ flowName
+ )
+ logInfo(
+ log"Flow ${MDC(LogKeys.FLOW_NAME, flowName)} using checkpoint " +
+ log"directory: ${MDC(LogKeys.CHECKPOINT_PATH, checkpointDir)}"
+ )
+ checkpointDir
+ } else {
+ throw new IllegalArgumentException(
+ s"Flow ${flow.identifier} does not have a valid destination for
checkpoints."
+ )
+ })
+ }
+
+ /** Returns the location for the most recent checkpoint of a given flow. */
+ def latestCheckpointLocation: String = {
+ val checkpointsDir = flowCheckpointsDirOpt().get
+ SystemMetadata.getLatestCheckpointDir(checkpointsDir)
+ }
+
+ /**
+ * Same as [[latestCheckpointLocation]] but returns None if the flow checkpoints directory
+ * does not exist.
+ */
+ def latestCheckpointLocationOpt(): Option[String] = {
+ flowCheckpointsDirOpt().map { flowCheckpointsDir =>
+ SystemMetadata.getLatestCheckpointDir(flowCheckpointsDir)
+ }
+ }
+}
+
+object SystemMetadata {
+ private def spark = SparkSession.getActiveSession.get
+
+ /**
+ * Finds the largest checkpoint version subdirectory path within a checkpoint directory, or
+ * creates and returns a version 0 subdirectory path if no versions exist.
+ * @param rootDir The root/parent directory where all the numbered checkpoint subdirectories are
+ * stored
+ * @param createNewCheckpointDir If true, a new latest numbered checkpoint directory should be
+ * created and returned
+ * @return The string URI path to the latest checkpoint directory
+ */
+ def getLatestCheckpointDir(
+ rootDir: Path,
+ createNewCheckpointDir: Boolean = false
+ ): String = {
+ val fs = rootDir.getFileSystem(spark.sessionState.newHadoopConf())
+ val defaultDir = new Path(rootDir, "0")
+ val checkpoint = if (fs.exists(rootDir)) {
+ val availableCheckpoints =
+ fs.listStatus(rootDir)
+ .toSeq
+ .sortBy(fs => Try(fs.getPath.getName.toInt).getOrElse(-1))
+ availableCheckpoints.lastOption
+ .filter(fs => Try(fs.getPath.getName.toInt).isSuccess)
+ .map(
+ latestCheckpoint =>
+ if (createNewCheckpointDir) {
+ val incrementedLatestCheckpointDir =
+ new Path(rootDir, Math.max(latestCheckpoint.getPath.getName.toInt + 1, 0).toString)
+ fs.mkdirs(incrementedLatestCheckpointDir)
+ incrementedLatestCheckpointDir
+ } else {
+ latestCheckpoint.getPath
+ }
+ )
+ .getOrElse {
+ fs.mkdirs(defaultDir)
+ defaultDir
+ }
+ } else {
+ defaultDir
+ }
+ checkpoint.toUri.toString
+ }
+}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 37c32a349866..afe769f7b204 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -52,7 +52,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
comment = Option("p-comment"),
query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val identifier =
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
@@ -80,7 +81,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
comment = Option("p-comment"),
query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val catalogTable2 = catalog.loadTable(identifier)
assert(
@@ -104,7 +106,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
comment = Option("p-comment"),
query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val catalogTable3 = catalog.loadTable(identifier)
@@ -136,7 +139,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
)
registerTable("t1")
registerTable("t2")
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val identifier1 =
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t1")
@@ -171,7 +175,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
"t1",
dfFlowFunc(Seq(1, 2, 3).toDF("x"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -201,7 +206,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
query = sqlFlowFunc(spark, "SELECT value AS timestamp FROM a")
)
}
- materializeGraph(new P1().resolveToDataflowGraph())
+ materializeGraph(new P1().resolveToDataflowGraph(), storageRoot =
storageRoot)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val b =
@@ -222,7 +227,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
query = sqlFlowFunc(spark, "SELECT timestamp FROM a")
)
}
- materializeGraph(new P2().resolveToDataflowGraph())
+ materializeGraph(new P2().resolveToDataflowGraph(), storageRoot =
storageRoot)
val b2 =
catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE),
"b"))
assert(
@@ -249,7 +254,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
new TestGraphRegistrationContext(spark) {
registerFlow("t2", "t2", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
registerTable("t2")
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val table2 = catalog.loadTable(identifier)
@@ -271,7 +277,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
new TestGraphRegistrationContext(spark) {
registerView("a", query = dfFlowFunc(streamInts.toDF()))
registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS
x FROM STREAM a")))
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val streamStrings = MemoryStream[String]
@@ -282,7 +289,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
}.resolveToDataflowGraph()
val ex = intercept[TableMaterializationException] {
- materializeGraph(graph2)
+ materializeGraph(graph2, storageRoot = storageRoot)
}
val cause = ex.cause
val exStr = exceptionString(cause)
@@ -314,7 +321,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
),
query = dfFlowFunc(Seq[Short](1, 2).toDF("x"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val table2 = catalog.loadTable(identifier)
@@ -352,7 +360,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
query = Option(dfFlowFunc(source.toDF().select($"value" as "x")))
)
- }.resolveToDataflowGraph())
+ }.resolveToDataflowGraph(), storageRoot = storageRoot)
}
val cause = ex.cause
val exStr = exceptionString(cause)
@@ -365,7 +373,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
specifiedSchema = Option(new StructType().add("x", IntegerType)),
query = dfFlowFunc(Seq(1, 2).toDF("x"))
)
- }.resolveToDataflowGraph())
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot)
val table2 = catalog.loadTable(identifier)
assert(
table2.columns() sameElements CatalogV2Util
@@ -389,7 +398,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
),
partitionCols = Option(Seq("x2"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val identifier =
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
@@ -434,7 +444,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
"t7",
partitionCols = Option(Seq("x"))
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val table2 = catalog.loadTable(identifier)
@@ -454,7 +465,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
query = dfFlowFunc(Seq((true, 1), (false, 3)).toDF("x", "y"))
)
registerTable("t7")
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
}
assert(ex.cause.asInstanceOf[SparkThrowable].getCondition ==
"CANNOT_UPDATE_PARTITION_COLUMNS")
@@ -490,7 +502,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
}.resolveToDataflowGraph()
val ex = intercept[TableMaterializationException] {
- materializeGraph(graph)
+ materializeGraph(graph, storageRoot = storageRoot)
}
assert(ex.cause.asInstanceOf[SparkThrowable].getCondition ==
"CANNOT_UPDATE_PARTITION_COLUMNS")
val table = catalog.loadTable(identifier)
@@ -513,7 +525,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
query = Option(sqlFlowFunc(spark, "SELECT * FROM STREAM a")),
properties = Map("pipelines.reset.alloweD" -> "true", "some.prop" ->
"foo")
)
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -546,7 +559,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
}.resolveToDataflowGraph()
val ex1 =
intercept[TableMaterializationException] {
- materializeGraph(graph1)
+ materializeGraph(graph1, storageRoot = storageRoot)
}
assert(ex1.cause.isInstanceOf[IllegalArgumentException])
@@ -565,7 +578,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
registerTable("a", query =
Option(dfFlowFunc(spark.readStream.format("rate").load())))
}.resolveToDataflowGraph().validate()
- materializeGraph(graph1)
+ materializeGraph(graph1, storageRoot = storageRoot)
}
for (isFullRefresh <- Seq(true, false)) {
@@ -581,7 +594,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
registerMaterializedView("b", query = sqlFlowFunc(spark, "SELECT x
FROM a"))
}.resolveToDataflowGraph()
- val graph = materializeGraph(rawGraph)
+ val graph = materializeGraph(rawGraph, storageRoot = storageRoot)
val (refreshSelection, fullRefreshSelection) = if (isFullRefresh) {
(NoTables, AllTables)
} else {
@@ -595,9 +608,11 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
spark = spark,
unresolvedGraph = graph,
refreshTables = refreshSelection,
- fullRefreshTables = fullRefreshSelection
+ fullRefreshTables = fullRefreshSelection,
+ storageRoot = storageRoot
)
- )
+ ),
+ storageRoot = storageRoot
)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -613,7 +628,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
new TestGraphRegistrationContext(spark) {
registerView("a", query = dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x",
"y")))
registerMaterializedView("b", query = sqlFlowFunc(spark, "SELECT y
FROM a"))
- }.resolveToDataflowGraph()
+ }.resolveToDataflowGraph(),
+ storageRoot = storageRoot
)
val table2 = catalog.loadTable(identifier)
assert(
@@ -650,10 +666,11 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
spark = spark,
unresolvedGraph = graph,
refreshTables = refreshSelection,
- fullRefreshTables = fullRefreshSelection
+ fullRefreshTables = fullRefreshSelection,
+ storageRoot = storageRoot
)
)
- materializeGraph(graph, contextOpt = updateContextOpt)
+ materializeGraph(graph, contextOpt = updateContextOpt, storageRoot =
storageRoot)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val identifier =
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b")
@@ -668,7 +685,8 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
registerView("a", query = dfFlowFunc(streamInts.toDF()))
registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value
AS y FROM STREAM a")))
}.resolveToDataflowGraph().validate(),
- contextOpt = updateContextOpt
+ contextOpt = updateContextOpt,
+ storageRoot = storageRoot
)
val table2 = catalog.loadTable(identifier)
@@ -709,9 +727,11 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
spark = spark,
unresolvedGraph = graph,
refreshTables = SomeTables(Set(fullyQualifiedIdentifier("a"))),
- fullRefreshTables = SomeTables(Set(fullyQualifiedIdentifier("c")))
+ fullRefreshTables = SomeTables(Set(fullyQualifiedIdentifier("c"))),
+ storageRoot = storageRoot
)
- )
+ ),
+ storageRoot = storageRoot
)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -756,9 +776,9 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
)
)
}.resolveToDataflowGraph()
- materializeGraph(rawGraph)
+ materializeGraph(rawGraph, storageRoot = storageRoot)
// Materialize twice because some logic compares the incoming schema with
the previous one.
- materializeGraph(rawGraph)
+ materializeGraph(rawGraph, storageRoot = storageRoot)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableA =
@@ -805,9 +825,9 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
)
}.resolveToDataflowGraph()
- materializeGraph(rawGraph)
+ materializeGraph(rawGraph, storageRoot = storageRoot)
// Materialize twice because some logic compares the incoming schema with
the previous one.
- materializeGraph(rawGraph)
+ materializeGraph(rawGraph, storageRoot = storageRoot)
val catalog =
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableA =
catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE),
"a"))
@@ -849,7 +869,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
registerTable("a")
}.resolveToDataflowGraph()
- materializeGraph(graph1)
+ materializeGraph(graph1, storageRoot = storageRoot)
materializeGraph(
graph2,
contextOpt = Option(
@@ -857,9 +877,11 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
spark = spark,
unresolvedGraph = graph2,
refreshTables = NoTables,
- fullRefreshTables = NoTables
+ fullRefreshTables = NoTables,
+ storageRoot = storageRoot
)
- )
+ ),
+ storageRoot = storageRoot
)
}
}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
new file mode 100644
index 000000000000..15dfa6576171
--- /dev/null
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper}
+import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SystemMetadataSuite
+ extends ExecutionTest
+ with SystemMetadataTestHelpers
+ with SharedSparkSession {
+
+  gridTest(
+    "flow checkpoint for ST is written to the expected location when fullRefresh ="
+  )(Seq(false, true)) { fullRefresh =>
+ val session = spark
+ import session.implicits._
+
+ // create a pipeline with only a single ST
+ val graph = new TestGraphRegistrationContext(spark) {
+ val mem: MemoryStream[Int] = MemoryStream[Int]
+ mem.addData(1, 2, 3)
+ registerView("a", query = dfFlowFunc(mem.toDF()))
+ registerTable("st")
+ registerFlow("st", "st", query = readStreamFlowFunc("a"))
+ }.toDataflowGraph
+
+ val updateContext1 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot,
+ failOnErrorEvent = true
+ )
+
+ // run first update
+ updateContext1.pipelineExecution.startPipeline()
+ updateContext1.pipelineExecution.awaitCompletion()
+
+ val graphExecution1 = updateContext1.pipelineExecution.graphExecution.get
+ val executionGraph1 = graphExecution1.graphForExecution
+
+ val stIdentifier = fullyQualifiedIdentifier("st")
+
+ // assert checkpoint v0
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph1.table(stIdentifier),
+ flowElement = executionGraph1.flow(stIdentifier),
+ expectedCheckpointVersion = 0,
+ graphExecution = graphExecution1,
+ updateContext = updateContext1
+ )
+
+ // run second update, either refresh (reuse) or full refresh (increment)
+ val updateContext2 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot,
+ refreshTables = if (!fullRefresh) AllTables else NoTables,
+ fullRefreshTables = if (fullRefresh) AllTables else NoTables,
+ failOnErrorEvent = true
+ )
+
+ updateContext2.pipelineExecution.startPipeline()
+ updateContext2.pipelineExecution.awaitCompletion()
+
+ val graphExecution2 = updateContext2.pipelineExecution.graphExecution.get
+ val executionGraph2 = graphExecution2.graphForExecution
+
+ // checkpoint reused in refresh, incremented in full refresh
+ val expectedVersion = if (fullRefresh) 1 else 0
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph2.table(stIdentifier),
+ flowElement = executionGraph1.flow(stIdentifier),
+ expectedCheckpointVersion = expectedVersion,
+ graphExecution = graphExecution2,
+ updateContext2
+ )
+ }
+
+  test(
+    "flow checkpoint for ST (with flow name different from table name) is written to the " +
+      "expected location"
+  ) {
+ val session = spark
+ import session.implicits._
+
+ val graph = new TestGraphRegistrationContext(spark) {
+ val mem: MemoryStream[Int] = MemoryStream[Int]
+ mem.addData(1, 2, 3)
+ registerView("a", query = dfFlowFunc(mem.toDF()))
+ registerTable("st")
+ registerFlow("st", "st_flow", query = readStreamFlowFunc("a"))
+ }.toDataflowGraph
+
+ val updateContext1 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot
+ )
+
+    // start an update in continuous mode; checkpoints are only created for streaming queries
+ updateContext1.pipelineExecution.startPipeline()
+ updateContext1.pipelineExecution.awaitCompletion()
+
+ val graphExecution1 = updateContext1.pipelineExecution.graphExecution.get
+ val executionGraph1 = graphExecution1.graphForExecution
+
+ val stIdentifier = fullyQualifiedIdentifier("st")
+ val stFlowIdentifier = fullyQualifiedIdentifier("st_flow")
+
+ // assert that the checkpoint dir for the ST is created as expected
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph1.table(stIdentifier),
+ flowElement = executionGraph1.flow(stFlowIdentifier),
+ // the default checkpoint version is 0
+ expectedCheckpointVersion = 0,
+ graphExecution = graphExecution1,
+ updateContext = updateContext1
+ )
+
+    // start another update in full refresh mode, expecting a new checkpoint dir to be
+    // created with the version number incremented to 1
+ val updateContext2 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot,
+ fullRefreshTables = AllTables
+ )
+
+ updateContext2.pipelineExecution.startPipeline()
+ updateContext2.pipelineExecution.awaitCompletion()
+ val graphExecution2 = updateContext2.pipelineExecution.graphExecution.get
+ val executionGraph2 = graphExecution2.graphForExecution
+
+    // due to the full refresh, assert that a new checkpoint dir is created with the
+    // version number incremented to 1
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph2.table(stIdentifier),
+ flowElement = executionGraph1.flow(stFlowIdentifier),
+ // new checkpoint directory is created
+ expectedCheckpointVersion = 1,
+ graphExecution = graphExecution2,
+ updateContext2
+ )
+ }
+}
+
+trait SystemMetadataTestHelpers {
+ this: ExecutionTest =>
+
+  /** Return the expected checkpoint directory location for a table or sink. These
+   * directories have "<expectedStorageName>/<expectedCheckpointVersion>" as their suffix.
+   */
+ private def getExpectedFlowCheckpointDirForTableOrSink(
+ tableOrSinkElement: GraphElement,
+ flowElement: Flow,
+ expectedCheckpointVersion: Int,
+ updateContext: PipelineUpdateContext
+ ): Path = {
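+    // Expected layout: <storageRoot>/_checkpoints/<table name>/<flow name>/<checkpoint version>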
+ val expectedRawCheckPointDir = tableOrSinkElement match {
+ case t: Table => new Path(updateContext.storageRoot)
+        .suffix(s"/_checkpoints/${t.identifier.table}/${flowElement.identifier.table}")
+        .toString
+ case _ =>
+        fail(
+          s"unexpected table element type for assertFlowCheckpointDirForTableOrSink: " +
+            tableOrSinkElement.getClass.getSimpleName
+        )
+ }
+
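+    // prefix with "file://" so the expected path matches the fully qualified URI reported
+    // as the streaming query's resolved checkpoint root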
+ new Path("file://",
expectedRawCheckPointDir).suffix(s"/$expectedCheckpointVersion")
+ }
+
+  /** Return the actual checkpoint directory location used by the table or sink.
+   * We use a Flow object as input since the checkpoints for a table or sink are associated
+   * with each of their incoming flows.
+   */
+ private def getActualFlowCheckpointDirForTableOrSink(
+ flowElement: Flow,
+ graphExecution: GraphExecution
+ ): Path = {
+ // spark flow stream takes a while to be created, so we need to poll for it
+ val flowStream = graphExecution
+ .flowExecutions(flowElement.identifier)
+ .asInstanceOf[StreamingFlowExecution]
+ .getStreamingQuery
+
+    // we grab the checkpoint location from the actual Spark streaming query executed in
+    // the update, which is the best source of truth
+    new Path(
+      flowStream.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot
+    )
+ }
+
+ /** Assert the flow checkpoint directory exists in the filesystem. */
+ protected def assertFlowCheckpointDirExists(
+ tableOrSinkElement: GraphElement,
+ flowElement: Flow,
+ expectedCheckpointVersion: Int,
+ graphExecution: GraphExecution,
+ updateContext: PipelineUpdateContext
+ ): Path = {
+ val expectedCheckPointDir = getExpectedFlowCheckpointDirForTableOrSink(
+ tableOrSinkElement = tableOrSinkElement,
+ flowElement = flowElement,
+ expectedCheckpointVersion = expectedCheckpointVersion,
+ updateContext
+ )
+ val actualCheckpointDir = getActualFlowCheckpointDirForTableOrSink(
+ flowElement = flowElement,
+ graphExecution = graphExecution
+ )
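+    // the checkpoint dir resolved by the running streaming query must match the expected
+    // location and must exist on the filesystem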
+    val fs = expectedCheckPointDir.getFileSystem(spark.sessionState.newHadoopConf())
+ assert(actualCheckpointDir == expectedCheckPointDir)
+ assert(fs.exists(actualCheckpointDir))
+ actualCheckpointDir
+ }
+
+ /** Assert the flow checkpoint directory does not exist in the filesystem. */
+ protected def assertFlowCheckpointDirNotExists(
+ tableOrSinkElement: GraphElement,
+ flowElement: Flow,
+ expectedCheckpointVersion: Int,
+ graphExecution: GraphExecution,
+ updateContext: PipelineUpdateContext
+ ): Path = {
+ val expectedCheckPointDir = getExpectedFlowCheckpointDirForTableOrSink(
+ tableOrSinkElement = tableOrSinkElement,
+ flowElement = flowElement,
+ expectedCheckpointVersion = expectedCheckpointVersion,
+ updateContext
+ )
+ val actualCheckpointDir = getActualFlowCheckpointDirForTableOrSink(
+ flowElement = flowElement,
+ graphExecution = graphExecution
+ )
+    val fs = expectedCheckPointDir.getFileSystem(spark.sessionState.newHadoopConf())
+ assert(actualCheckpointDir != expectedCheckPointDir)
+ assert(!fs.exists(expectedCheckPointDir))
+ actualCheckpointDir
+ }
+}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
index 4fcd9dad93fe..57baf4c2d5b1 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
@@ -68,7 +68,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with
SharedSparkSession
resolvedGraph.resolvedFlows.filter(_.identifier ==
fullyQualifiedIdentifier("b")).head
assert(bFlow.inputs == Set(fullyQualifiedIdentifier("a")))
- val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -128,7 +128,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
resolvedGraph.resolvedFlows.filter(_.identifier ==
fullyQualifiedIdentifier("d")).head
assert(dFlow.inputs == Set(fullyQualifiedIdentifier("c", isTemporaryView =
true)))
- val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -201,7 +201,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -268,7 +268,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -329,7 +329,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
registerTable("z", query = Option(readStreamFlowFunc("x")))
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -437,7 +437,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph1 = pipelineDef1.toDataflowGraph
- val updateContext1 = TestPipelineUpdateContext(spark, graph1)
+ val updateContext1 = TestPipelineUpdateContext(spark, graph1, storageRoot)
updateContext1.pipelineExecution.runPipeline()
updateContext1.pipelineExecution.awaitCompletion()
assertFlowProgressEvent(
@@ -458,7 +458,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
)
}
val graph2 = pipelineDef2.toDataflowGraph
- val updateContext2 = TestPipelineUpdateContext(spark, graph2)
+ val updateContext2 = TestPipelineUpdateContext(spark, graph2, storageRoot)
updateContext2.pipelineExecution.runPipeline()
updateContext2.pipelineExecution.awaitCompletion()
@@ -506,6 +506,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
val ctx = TestPipelineUpdateContext(
spark,
pipelineDef.toDataflowGraph,
+ storageRoot,
fullRefreshTables = AllTables,
resetCheckpointFlows = AllFlows
)
@@ -562,7 +563,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.startPipeline()
val graphExecution = updateContext.pipelineExecution.graphExecution.get
@@ -626,7 +627,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -688,7 +689,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -767,7 +768,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -826,7 +827,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
)
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -883,6 +884,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
val updateContext1 = TestPipelineUpdateContext(
spark = spark,
unresolvedGraph = graph1,
+ storageRoot = storageRoot,
refreshTables = SomeTables(
Set(fullyQualifiedIdentifier("source"),
fullyQualifiedIdentifier("all"))
),
@@ -932,6 +934,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
val updateContext2 = TestPipelineUpdateContext(
spark = spark,
unresolvedGraph = graph1,
+ storageRoot = storageRoot,
refreshTables = SomeTables(
Set(fullyQualifiedIdentifier("source"),
fullyQualifiedIdentifier("max_evens"))
),
@@ -986,7 +989,8 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
registerTable("table3", query = Option(sqlFlowFunc(spark, "SELECT * FROM
table1")))
}.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark = spark,
unresolvedGraph = graph)
+ val updateContext = TestPipelineUpdateContext(spark = spark,
+ unresolvedGraph = graph, storageRoot = storageRoot)
updateContext.pipelineExecution.runPipeline()
assertFlowProgressEvent(
@@ -1041,7 +1045,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest
with SharedSparkSession
}
val graph = pipelineDef.toDataflowGraph
- val updateContext = TestPipelineUpdateContext(spark, graph)
+ val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
index 3b40f887fe08..452c66844e9c 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
@@ -48,7 +48,8 @@ class ViewSuite extends ExecutionTest with SharedSparkSession
{
val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
sqlText = s"CREATE VIEW $viewName AS SELECT * FROM range(1, 4);"
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -81,7 +82,8 @@ class ViewSuite extends ExecutionTest with SharedSparkSession
{
val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
sqlText = s"CREATE VIEW $viewName AS SELECT * FROM $source;"
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -108,7 +110,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
sqlText = s"""CREATE TEMPORARY VIEW temp_view AS SELECT * FROM range(1,
4);
|CREATE VIEW $viewName AS SELECT * FROM
temp_view;""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
val ex = intercept[AnalysisException] {
updateContext.pipelineExecution.startPipeline()
@@ -131,7 +134,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
sqlText = s"CREATE VIEW myview AS SELECT * FROM nonexistent_view;"
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
val ex = intercept[Exception] {
updateContext.pipelineExecution.startPipeline()
@@ -150,7 +154,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
sqlText = s"""CREATE STREAMING TABLE source AS SELECT * FROM
STREAM($externalTable1Ident);
|CREATE VIEW $viewName AS SELECT * FROM
STREAM(source);""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
val ex = intercept[AnalysisException] {
updateContext.pipelineExecution.startPipeline()
@@ -168,7 +173,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
sqlText = s"CREATE VIEW $viewName AS SELECT * FROM
STREAM($externalTable1Ident);"
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
val ex = intercept[AnalysisException] {
updateContext.pipelineExecution.startPipeline()
@@ -190,7 +196,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
|CREATE VIEW pv2 AS SELECT * FROM pv1;
|CREATE VIEW pv1 AS SELECT * FROM range(1,
4);""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -217,7 +224,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
|CREATE VIEW pv2 AS SELECT * FROM pv1;
|CREATE VIEW pv1 AS SELECT 1 + 1;""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -253,7 +261,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
sqlText = s"""CREATE MATERIALIZED VIEW mymv AS SELECT * FROM RANGE(1, 4);
|CREATE VIEW $viewName AS SELECT * FROM
$source;""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -286,7 +295,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
sqlText = s"""CREATE STREAMING TABLE myst AS SELECT * FROM
STREAM($externalTable1Ident);
|CREATE VIEW $viewName AS SELECT * FROM
$source;""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
@@ -317,7 +327,8 @@ class ViewSuite extends ExecutionTest with
SharedSparkSession {
|CREATE VIEW $viewName AS SELECT * FROM range(1, 4);
|CREATE MATERIALIZED VIEW myviewreader AS SELECT * FROM
$viewName;""".stripMargin
)
- val updateContext = TestPipelineUpdateContext(spark,
unresolvedDataflowGraph)
+ val updateContext =
+ TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
updateContext.pipelineExecution.startPipeline()
updateContext.pipelineExecution.awaitCompletion()
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
index 3ee73f9394e9..317dba3dd67c 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
@@ -27,10 +27,11 @@ trait BaseCoreExecutionTest extends ExecutionTest {
*/
protected def materializeGraph(
graph: DataflowGraph,
- contextOpt: Option[PipelineUpdateContext] = None
+ contextOpt: Option[PipelineUpdateContext] = None,
+ storageRoot: String
): DataflowGraph = {
val contextToUse = contextOpt.getOrElse(
- TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph)
+      TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph, storageRoot = storageRoot)
)
DatasetManager.materializeDatasets(graph, contextToUse)
}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
index 6c2c07e57498..2fba90d3a0d0 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
@@ -58,6 +58,7 @@ trait TestPipelineUpdateContextMixin {
case class TestPipelineUpdateContext(
spark: SparkSession,
unresolvedGraph: DataflowGraph,
+ storageRoot: String,
fullRefreshTables: TableFilter = NoTables,
refreshTables: TableFilter = AllTables,
resetCheckpointFlows: FlowFilter = AllFlows,
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
index 54b324c182f2..f9d6aba9e22d 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.pipelines.utils
import java.io.{BufferedReader, FileNotFoundException, InputStreamReader}
-import java.nio.file.Files
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Try}
@@ -34,22 +33,23 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.execution._
import org.apache.spark.sql.pipelines.graph.{DataflowGraph,
PipelineUpdateContextImpl, SqlGraphRegistrationContext}
-import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore,
createTempDir}
+import org.apache.spark.sql.pipelines.utils.PipelineTest.cleanupMetastore
import org.apache.spark.sql.test.SQLTestUtils
abstract class PipelineTest
extends QueryTest
+ with StorageRootMixin
with SQLTestUtils
with SparkErrorTestMixin
with TargetCatalogAndDatabaseMixin
with Logging
with Eventually {
- final protected val storageRoot = createTempDir()
-
- protected def startPipelineAndWaitForCompletion(unresolvedDataflowGraph:
DataflowGraph): Unit = {
+ protected def startPipelineAndWaitForCompletion(
+ unresolvedDataflowGraph: DataflowGraph): Unit = {
val updateContext = new PipelineUpdateContextImpl(
- unresolvedDataflowGraph, eventCallback = _ => ())
+ unresolvedDataflowGraph, eventCallback = _ => (),
+ storageRoot = storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()
}
@@ -299,11 +299,6 @@ object PipelineTest extends Logging {
"main"
)
- /** Creates a temporary directory. */
- protected def createTempDir(): String = {
- Files.createTempDirectory(getClass.getSimpleName).normalize.toString
- }
-
/**
* Try to drop the schema in the catalog and return whether it is
successfully dropped.
*/
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
new file mode 100644
index 000000000000..420e2c6ad0e9
--- /dev/null
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.utils
+
+
+import java.io.File
+import java.nio.file.Files
+
+import org.scalatest.{BeforeAndAfterEach, Suite}
+
+import org.apache.spark.util.Utils
+
+/**
+ * A mixin trait for tests that need a temporary directory as the storage root for pipelines.
+ * This trait creates a temporary directory before each test and deletes it after each test.
+ * The path to the temporary directory is available via the `storageRoot` variable.
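+ *
+ * Illustrative usage (hypothetical suite name); pipeline suites typically pick this up via
+ * [[PipelineTest]]:
+ * {{{
+ *   class MySuite extends ExecutionTest with SharedSparkSession {
+ *     test("storage root is set per test") {
+ *       assert(storageRoot.nonEmpty)
+ *     }
+ *   }
+ * }}}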
+ */
+trait StorageRootMixin extends BeforeAndAfterEach { self: Suite =>
+
+  /** A temporary directory created as the pipeline storage root for each test. */
+ protected var storageRoot: String = _
+
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+    storageRoot = Files.createTempDirectory(getClass.getSimpleName).normalize.toString
+ }
+
+ override protected def afterEach(): Unit = {
+ super.afterEach()
+ Utils.deleteRecursively(new File(storageRoot))
+ }
+}