This is an automated email from the ASF dual-hosted git repository.

sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a1c742098c2 [SPARK-53751][SDP] Explicit Versioned Checkpoint Location
9a1c742098c2 is described below

commit 9a1c742098c24c7031aa6f494827b13a41b458fa
Author: Jacky Wang <[email protected]>
AuthorDate: Wed Oct 8 12:27:52 2025 -0700

    [SPARK-53751][SDP] Explicit Versioned Checkpoint Location
    
    ### What changes were proposed in this pull request?
    
    Add a `storage` field to the pipeline spec to allow users to specify the location of metadata, such as streaming checkpoints.
    Below is the resulting directory structure. It supports multiple flows per table or sink, and versioned checkpoint directories whose version number is incremented after each full refresh.
    
    ```
    storage-root/
    └── _checkpoints/ # checkpoint root
          ├── myst/
          │    ├── flow1/
          │    │    ├── 0/ # version 0
          │    │    │    ├── commits/
          │    │    │    ├── offsets/
          │    │    │    └── sources/
          │    │    └── 1/ # version 1
          │    │
          │    └── flow2/
          │         ├── 0/
          │         │    ├── commits/
          │         │    ├── offsets/
          │         │    └── sources/
          │         └── 1/
          │
          └── mysink/
                └── flowA/
                     ├── 0/
                     │    ├── commits/
                     │    ├── offsets/
                     │    └── sources/
                     └── 1/
    
    ```
    
    ### Why are the changes needed?
    
    Currently, SDP stores streaming flow checkpoints in the table path. This does not support versioned checkpoints and does not work for sinks.
    
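    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The pipeline spec now requires a `storage` field, and the server rejects a run that does not specify a storage location.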
    
    ### How was this patch tested?
    
    New and existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52487 from 
JiaqiWang18/SPARK-53751-explicit-versioned-multiflow-checkpoint.
    
    Authored-by: Jacky Wang <[email protected]>
    Signed-off-by: Sandy Ryza <[email protected]>
---
 python/pyspark/pipelines/cli.py                    |  16 +-
 python/pyspark/pipelines/init_cli.py               |   1 +
 python/pyspark/pipelines/spark_connect_pipeline.py |   4 +
 python/pyspark/pipelines/tests/test_cli.py         |  42 +++-
 .../pyspark/pipelines/tests/test_spark_connect.py  |   2 +
 python/pyspark/sql/connect/proto/pipelines_pb2.py  |  54 ++---
 python/pyspark/sql/connect/proto/pipelines_pb2.pyi |  16 ++
 .../main/protobuf/spark/connect/pipelines.proto    |   3 +
 .../sql/connect/pipelines/PipelinesHandler.scala   |   8 +-
 .../spark/sql/connect/SparkConnectServerTest.scala |   4 +-
 .../sql/connect/pipelines/EndToEndAPISuite.scala   |   1 +
 .../pipelines/PipelineEventStreamSuite.scala       |   8 +-
 .../pipelines/PipelineRefreshFunctionalSuite.scala |   8 +
 .../connect/pipelines/PythonPipelineSuite.scala    |  32 +--
 .../SparkDeclarativePipelinesServerTest.scala      |  10 +-
 .../service/SparkConnectSessionHolderSuite.scala   |   3 +-
 .../service/SparkConnectSessionManagerSuite.scala  |   3 +-
 .../spark/sql/pipelines/graph/FlowExecution.scala  |   9 +
 .../spark/sql/pipelines/graph/FlowPlanner.scala    |   3 +-
 .../sql/pipelines/graph/PipelineExecution.scala    |   4 +
 .../pipelines/graph/PipelineUpdateContext.scala    |   5 +
 .../graph/PipelineUpdateContextImpl.scala          |   3 +-
 .../apache/spark/sql/pipelines/graph/State.scala   | 104 +++++++++
 .../spark/sql/pipelines/graph/SystemMetadata.scala | 126 ++++++++++
 .../pipelines/graph/MaterializeTablesSuite.scala   |  94 +++++---
 .../sql/pipelines/graph/SystemMetadataSuite.scala  | 260 +++++++++++++++++++++
 .../graph/TriggeredGraphExecutionSuite.scala       |  32 +--
 .../spark/sql/pipelines/graph/ViewSuite.scala      |  33 ++-
 .../pipelines/utils/BaseCoreExecutionTest.scala    |   5 +-
 .../spark/sql/pipelines/utils/ExecutionTest.scala  |   1 +
 .../spark/sql/pipelines/utils/PipelineTest.scala   |  17 +-
 .../sql/pipelines/utils/StorageRootMixin.scala     |  48 ++++
 32 files changed, 822 insertions(+), 137 deletions(-)

diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 5342b7fb430c..ca198f1c3aff 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -90,6 +90,7 @@ class PipelineSpec:
     """Spec for a pipeline.
 
     :param name: The name of the pipeline.
+    :param storage: The root directory for storing metadata, such as streaming 
checkpoints.
     :param catalog: The default catalog to use for the pipeline.
     :param database: The default database to use for the pipeline.
     :param configuration: A dictionary of Spark configuration properties to 
set for the pipeline.
@@ -97,6 +98,7 @@ class PipelineSpec:
     """
 
     name: str
+    storage: str
     catalog: Optional[str]
     database: Optional[str]
     configuration: Mapping[str, str]
@@ -150,8 +152,16 @@ def load_pipeline_spec(spec_path: Path) -> PipelineSpec:
 
 
 def unpack_pipeline_spec(spec_data: Mapping[str, Any]) -> PipelineSpec:
-    ALLOWED_FIELDS = {"name", "catalog", "database", "schema", 
"configuration", "libraries"}
-    REQUIRED_FIELDS = ["name"]
+    ALLOWED_FIELDS = {
+        "name",
+        "storage",
+        "catalog",
+        "database",
+        "schema",
+        "configuration",
+        "libraries",
+    }
+    REQUIRED_FIELDS = ["name", "storage"]
     for key in spec_data.keys():
         if key not in ALLOWED_FIELDS:
             raise PySparkException(
@@ -167,6 +177,7 @@ def unpack_pipeline_spec(spec_data: Mapping[str, Any]) -> 
PipelineSpec:
 
     return PipelineSpec(
         name=spec_data["name"],
+        storage=spec_data["storage"],
         catalog=spec_data.get("catalog"),
         database=spec_data.get("database", spec_data.get("schema")),
         configuration=validate_str_dict(spec_data.get("configuration", {}), 
"configuration"),
@@ -323,6 +334,7 @@ def run(
         full_refresh_all=full_refresh_all,
         refresh=refresh,
         dry=dry,
+        storage=spec.storage,
     )
     try:
         handle_pipeline_events(result_iter)
diff --git a/python/pyspark/pipelines/init_cli.py 
b/python/pyspark/pipelines/init_cli.py
index 47be703f7795..ffe5d3c12b63 100644
--- a/python/pyspark/pipelines/init_cli.py
+++ b/python/pyspark/pipelines/init_cli.py
@@ -19,6 +19,7 @@ from pathlib import Path
 
 SPEC = """
 name: {{ name }}
+storage: storage-root
 libraries:
   - glob:
       include: transformations/**
diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py 
b/python/pyspark/pipelines/spark_connect_pipeline.py
index 61b72956e5cc..e3c1184cea39 100644
--- a/python/pyspark/pipelines/spark_connect_pipeline.py
+++ b/python/pyspark/pipelines/spark_connect_pipeline.py
@@ -72,6 +72,7 @@ def start_run(
     full_refresh_all: bool,
     refresh: Optional[Sequence[str]],
     dry: bool,
+    storage: str,
 ) -> Iterator[Dict[str, Any]]:
     """Start a run of the dataflow graph in the Spark Connect server.
 
@@ -79,6 +80,8 @@ def start_run(
     :param full_refresh: List of datasets to reset and recompute.
     :param full_refresh_all: Perform a full graph reset and recompute.
     :param refresh: List of datasets to update.
+    :param dry: If true, the run will not actually execute any flows, but only 
validate the graph.
+    :param storage: The storage location to store metadata such as streaming 
checkpoints.
     """
     inner_command = pb2.PipelineCommand.StartRun(
         dataflow_graph_id=dataflow_graph_id,
@@ -86,6 +89,7 @@ def start_run(
         full_refresh_all=full_refresh_all,
         refresh_selection=refresh or [],
         dry=dry,
+        storage=storage,
     )
     command = pb2.Command()
     command.pipeline_command.start_run.CopyFrom(inner_command)
diff --git a/python/pyspark/pipelines/tests/test_cli.py 
b/python/pyspark/pipelines/tests/test_cli.py
index fbc6d3a90ac8..894b8be283f0 100644
--- a/python/pyspark/pipelines/tests/test_cli.py
+++ b/python/pyspark/pipelines/tests/test_cli.py
@@ -60,7 +60,8 @@ class CLIUtilityTests(unittest.TestCase):
                     },
                     "libraries": [
                         {"glob": {"include": "test_include"}}
-                    ]
+                    ],
+                    "storage": "storage_path",
                 }
                 """
             )
@@ -72,6 +73,7 @@ class CLIUtilityTests(unittest.TestCase):
             assert spec.configuration == {"key1": "value1", "key2": "value2"}
             assert len(spec.libraries) == 1
             assert spec.libraries[0].include == "test_include"
+            assert spec.storage == "storage_path"
 
     def test_load_pipeline_spec_name_is_required(self):
         with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
@@ -86,7 +88,8 @@ class CLIUtilityTests(unittest.TestCase):
                     },
                     "libraries": [
                         {"glob": {"include": "test_include"}}
-                    ]
+                    ],
+                    "storage": "storage_path",
                 }
                 """
             )
@@ -112,7 +115,8 @@ class CLIUtilityTests(unittest.TestCase):
                     },
                     "libraries": [
                         {"glob": {"include": "test_include"}}
-                    ]
+                    ],
+                    "storage": "storage_path",
                 }
                 """
             )
@@ -150,21 +154,38 @@ class CLIUtilityTests(unittest.TestCase):
 
     def test_unpack_empty_pipeline_spec(self):
         empty_spec = PipelineSpec(
-            name="test_pipeline", catalog=None, database=None, 
configuration={}, libraries=[]
+            name="test_pipeline",
+            storage="storage_path",
+            catalog=None,
+            database=None,
+            configuration={},
+            libraries=[],
+        )
+        self.assertEqual(
+            unpack_pipeline_spec({"name": "test_pipeline", "storage": 
"storage_path"}), empty_spec
         )
-        self.assertEqual(unpack_pipeline_spec({"name": "test_pipeline"}), 
empty_spec)
 
     def test_unpack_pipeline_spec_bad_configuration(self):
         with self.assertRaises(TypeError) as context:
-            unpack_pipeline_spec({"name": "test_pipeline", "configuration": 
"not_a_dict"})
+            unpack_pipeline_spec(
+                {"name": "test_pipeline", "storage": "storage_path", 
"configuration": "not_a_dict"}
+            )
         self.assertIn("should be a dict", str(context.exception))
 
         with self.assertRaises(TypeError) as context:
-            unpack_pipeline_spec({"name": "test_pipeline", "configuration": 
{"key": {}}})
+            unpack_pipeline_spec(
+                {"name": "test_pipeline", "storage": "storage_path", 
"configuration": {"key": {}}}
+            )
         self.assertIn("key", str(context.exception))
 
         with self.assertRaises(TypeError) as context:
-            unpack_pipeline_spec({"name": "test_pipeline", "configuration": 
{1: "something"}})
+            unpack_pipeline_spec(
+                {
+                    "name": "test_pipeline",
+                    "storage": "storage_path",
+                    "configuration": {1: "something"},
+                }
+            )
         self.assertIn("int", str(context.exception))
 
     def test_find_pipeline_spec_in_current_directory(self):
@@ -239,6 +260,7 @@ class CLIUtilityTests(unittest.TestCase):
             name="test_pipeline",
             catalog=None,
             database=None,
+            storage="storage_path",
             configuration={},
             libraries=[LibrariesGlob(include="subdir1/**")],
         )
@@ -280,6 +302,7 @@ class CLIUtilityTests(unittest.TestCase):
         """Errors raised while executing definitions code should make it to 
the outer context."""
         spec = PipelineSpec(
             name="test_pipeline",
+            storage="storage_path",
             catalog=None,
             database=None,
             configuration={},
@@ -298,6 +321,7 @@ class CLIUtilityTests(unittest.TestCase):
     def 
test_register_definitions_unsupported_file_extension_matches_glob(self):
         spec = PipelineSpec(
             name="test_pipeline",
+            storage="storage_path",
             catalog=None,
             database=None,
             configuration={},
@@ -352,6 +376,7 @@ class CLIUtilityTests(unittest.TestCase):
                     registry,
                     PipelineSpec(
                         name="test_pipeline",
+                        storage="storage_path",
                         catalog=None,
                         database=None,
                         configuration={},
@@ -500,6 +525,7 @@ class CLIUtilityTests(unittest.TestCase):
                 """
                 {
                     "name": "test_pipeline",
+                    "storage": "storage_path",
                     "libraries": [
                         {"glob": {"include": "transformations/**/*.py"}}
                     ]
diff --git a/python/pyspark/pipelines/tests/test_spark_connect.py 
b/python/pyspark/pipelines/tests/test_spark_connect.py
index 6d81a98c8c44..0b54c0906f9a 100644
--- a/python/pyspark/pipelines/tests/test_spark_connect.py
+++ b/python/pyspark/pipelines/tests/test_spark_connect.py
@@ -58,6 +58,7 @@ class SparkConnectPipelinesTest(ReusedConnectTestCase):
             refresh=None,
             full_refresh_all=False,
             dry=True,
+            storage="storage_path",
         )
         handle_pipeline_events(result_iter)
 
@@ -79,6 +80,7 @@ class SparkConnectPipelinesTest(ReusedConnectTestCase):
             refresh=None,
             full_refresh_all=False,
             dry=True,
+            storage="storage_path",
         )
         with self.assertRaises(AnalysisException) as context:
             handle_pipeline_events(result_iter)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py 
b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index ba3d2f634a48..f53f91452454 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -41,7 +41,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xfb\x1b\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
 
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
+    
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xa6\x1c\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
 
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
 )
 
 _globals = globals()
@@ -60,10 +60,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
     
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options
 = b"8\001"
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = 
b"8\001"
-    _globals["_DATASETTYPE"]._serialized_start = 4896
-    _globals["_DATASETTYPE"]._serialized_end = 4993
+    _globals["_DATASETTYPE"]._serialized_start = 4939
+    _globals["_DATASETTYPE"]._serialized_end = 5036
     _globals["_PIPELINECOMMAND"]._serialized_start = 168
-    _globals["_PIPELINECOMMAND"]._serialized_end = 3747
+    _globals["_PIPELINECOMMAND"]._serialized_end = 3790
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1050
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1358
     
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start 
= 1259
@@ -81,27 +81,27 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2692
     _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2750
     _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2865
-    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3144
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 
3147
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3346
-    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
 = 3349
-    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
 = 3507
-    
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 
3510
-    _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end 
= 3731
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3750
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4506
-    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start 
= 4122
-    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 
4220
-    _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start = 
4223
-    _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end = 
4357
-    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 
4360
-    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4491
-    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4508
-    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4581
-    _globals["_PIPELINEEVENT"]._serialized_start = 4583
-    _globals["_PIPELINEEVENT"]._serialized_end = 4699
-    _globals["_SOURCECODELOCATION"]._serialized_start = 4701
-    _globals["_SOURCECODELOCATION"]._serialized_end = 4823
-    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4825
-    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4894
+    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3187
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 
3190
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3389
+    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
 = 3392
+    
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
 = 3550
+    
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 
3553
+    _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end 
= 3774
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3793
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4549
+    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start 
= 4165
+    
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 
4263
+    _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start = 
4266
+    _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end = 
4400
+    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 
4403
+    _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4534
+    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4551
+    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4624
+    _globals["_PIPELINEEVENT"]._serialized_start = 4626
+    _globals["_PIPELINEEVENT"]._serialized_end = 4742
+    _globals["_SOURCECODELOCATION"]._serialized_start = 4744
+    _globals["_SOURCECODELOCATION"]._serialized_end = 4866
+    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4868
+    _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4937
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi 
b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index 575e15e0cb58..c42b3eefe2d4 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -578,6 +578,7 @@ class PipelineCommand(google.protobuf.message.Message):
         FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int
         REFRESH_SELECTION_FIELD_NUMBER: builtins.int
         DRY_FIELD_NUMBER: builtins.int
+        STORAGE_FIELD_NUMBER: builtins.int
         dataflow_graph_id: builtins.str
         """The graph to start."""
         @property
@@ -596,6 +597,8 @@ class PipelineCommand(google.protobuf.message.Message):
         """If true, the run will not actually execute any flows, but will only 
validate the graph and
         check for any errors. This is useful for testing and validation 
purposes.
         """
+        storage: builtins.str
+        """storage location for pipeline checkpoints and metadata."""
         def __init__(
             self,
             *,
@@ -604,6 +607,7 @@ class PipelineCommand(google.protobuf.message.Message):
             full_refresh_all: builtins.bool | None = ...,
             refresh_selection: collections.abc.Iterable[builtins.str] | None = 
...,
             dry: builtins.bool | None = ...,
+            storage: builtins.str | None = ...,
         ) -> None: ...
         def HasField(
             self,
@@ -614,12 +618,16 @@ class PipelineCommand(google.protobuf.message.Message):
                 b"_dry",
                 "_full_refresh_all",
                 b"_full_refresh_all",
+                "_storage",
+                b"_storage",
                 "dataflow_graph_id",
                 b"dataflow_graph_id",
                 "dry",
                 b"dry",
                 "full_refresh_all",
                 b"full_refresh_all",
+                "storage",
+                b"storage",
             ],
         ) -> builtins.bool: ...
         def ClearField(
@@ -631,6 +639,8 @@ class PipelineCommand(google.protobuf.message.Message):
                 b"_dry",
                 "_full_refresh_all",
                 b"_full_refresh_all",
+                "_storage",
+                b"_storage",
                 "dataflow_graph_id",
                 b"dataflow_graph_id",
                 "dry",
@@ -641,6 +651,8 @@ class PipelineCommand(google.protobuf.message.Message):
                 b"full_refresh_selection",
                 "refresh_selection",
                 b"refresh_selection",
+                "storage",
+                b"storage",
             ],
         ) -> None: ...
         @typing.overload
@@ -656,6 +668,10 @@ class PipelineCommand(google.protobuf.message.Message):
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_full_refresh_all", 
b"_full_refresh_all"]
         ) -> typing_extensions.Literal["full_refresh_all"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_storage", 
b"_storage"]
+        ) -> typing_extensions.Literal["storage"] | None: ...
 
     class DefineSqlGraphElements(google.protobuf.message.Message):
         """Parses the SQL file and registers all datasets and flows."""
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto 
b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index 130dcd54e59b..ef24899a5f74 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -143,6 +143,9 @@ message PipelineCommand {
     // If true, the run will not actually execute any flows, but will only 
validate the graph and
     // check for any errors. This is useful for testing and validation 
purposes.
     optional bool dry = 5;
+
+    // storage location for pipeline checkpoints and metadata.
+    optional string storage = 6;
   }
 
   // Parses the SQL file and registers all datasets and flows.
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 21f6c9492f68..358821289c85 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -338,11 +338,17 @@ private[connect] object PipelinesHandler extends Logging {
         }
       }
 
+      if (cmd.getStorage.isEmpty) {
+        // server-side validation to ensure that storage is always specified
+        throw new IllegalArgumentException("Storage must be specified to start 
a run.")
+      }
+
       val pipelineUpdateContext = new PipelineUpdateContextImpl(
         graphElementRegistry.toDataflowGraph,
         eventCallback,
         tableFiltersResult.refresh,
-        tableFiltersResult.fullRefresh)
+        tableFiltersResult.fullRefresh,
+        cmd.getStorage)
       sessionHolder.cachePipelineExecution(dataflowGraphId, 
pipelineUpdateContext)
 
       if (cmd.getDry) {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 92f64875d337..91e728d73e13 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -70,12 +70,12 @@ trait SparkConnectServerTest extends SharedSparkSession {
     super.afterAll()
   }
 
-  override def beforeEach(): Unit = {
+  protected override def beforeEach(): Unit = {
     super.beforeEach()
     clearAllExecutions()
   }
 
-  override def afterEach(): Unit = {
+  protected override def afterEach(): Unit = {
     clearAllExecutions()
     super.afterEach()
   }
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
index 923a85cb36f1..0a0cca488f1f 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
@@ -160,6 +160,7 @@ class EndToEndAPISuite extends PipelineTest with APITest 
with SparkConnectServer
       |name: test-pipeline
       |${spec.catalog.map(catalog => s"""catalog: "$catalog"""").getOrElse("")}
       |${spec.database.map(database => s"""database: 
"$database"""").getOrElse("")}
+      |storage: "${projectDir.resolve("storage").toAbsolutePath}"
       |configuration:
       |  "spark.remote": "sc://localhost:$serverPort"
       |libraries:
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
index 83862545a723..214d3b86fb78 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
@@ -42,7 +42,11 @@ class PipelineEventStreamSuite extends 
SparkDeclarativePipelinesServerTest {
       val capturedEvents = new ArrayBuffer[PipelineEvent]()
       withClient { client =>
         val startRunRequest = buildStartRunPlan(
-          
proto.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())
+          proto.PipelineCommand.StartRun
+            .newBuilder()
+            .setDataflowGraphId(graphId)
+            .setStorage(storageRoot)
+            .build())
         val responseIterator = client.execute(startRunRequest)
         while (responseIterator.hasNext) {
           val response = responseIterator.next()
@@ -94,6 +98,7 @@ class PipelineEventStreamSuite extends 
SparkDeclarativePipelinesServerTest {
             proto.PipelineCommand.StartRun
               .newBuilder()
               .setDataflowGraphId(graphId)
+              .setStorage(storageRoot)
               .setDry(dry)
               .build())
           val ex = intercept[AnalysisException] {
@@ -144,6 +149,7 @@ class PipelineEventStreamSuite extends 
SparkDeclarativePipelinesServerTest {
           proto.PipelineCommand.StartRun
             .newBuilder()
             .setDataflowGraphId(graphId)
+            .setStorage(storageRoot)
             .setDry(true)
             .build())
         val responseIterator = client.execute(startRunRequest)
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
index 794932544d5f..0e08b39abd19 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala
@@ -128,6 +128,7 @@ class PipelineRefreshFunctionalSuite
           PipelineCommand.StartRun
             .newBuilder()
             .setDataflowGraphId(graphId)
+            .setStorage(storageRoot)
             .addAllFullRefreshSelection(List("a").asJava)
             .build())
       },
@@ -168,6 +169,7 @@ class PipelineRefreshFunctionalSuite
             .setDataflowGraphId(graphId)
             .addAllFullRefreshSelection(Seq("a", "mv").asJava)
             .addRefreshSelection("b")
+            .setStorage(storageRoot)
             .build())
       },
       expectedContentAfterRefresh = Map(
@@ -211,6 +213,7 @@ class PipelineRefreshFunctionalSuite
           PipelineCommand.StartRun
             .newBuilder()
             .setDataflowGraphId(graphId)
+            .setStorage(storageRoot)
             .setFullRefreshAll(true)
             .build())
       },
@@ -231,6 +234,7 @@ class PipelineRefreshFunctionalSuite
       val startRun = PipelineCommand.StartRun
         .newBuilder()
         .setDataflowGraphId(graphId)
+        .setStorage(storageRoot)
         .setFullRefreshAll(true)
         .addRefreshSelection("a")
         .build()
@@ -253,6 +257,7 @@ class PipelineRefreshFunctionalSuite
       val startRun = PipelineCommand.StartRun
         .newBuilder()
         .setDataflowGraphId(graphId)
+        .setStorage(storageRoot)
         .setFullRefreshAll(true)
         .addFullRefreshSelection("a")
         .build()
@@ -275,6 +280,7 @@ class PipelineRefreshFunctionalSuite
       val startRun = PipelineCommand.StartRun
         .newBuilder()
         .setDataflowGraphId(graphId)
+        .setStorage(storageRoot)
         .addRefreshSelection("a")
         .addFullRefreshSelection("a")
         .build()
@@ -298,6 +304,7 @@ class PipelineRefreshFunctionalSuite
       val startRun = PipelineCommand.StartRun
         .newBuilder()
         .setDataflowGraphId(graphId)
+        .setStorage(storageRoot)
         .addRefreshSelection("a")
         .addRefreshSelection("b")
         .addFullRefreshSelection("a")
@@ -321,6 +328,7 @@ class PipelineRefreshFunctionalSuite
 
       val startRun = PipelineCommand.StartRun
         .newBuilder()
+        .setStorage(storageRoot)
         .setDataflowGraphId(graphId)
         .addRefreshSelection("spark_catalog.default.a")
         .addFullRefreshSelection("a") // This should be treated as the same 
table
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index b5461c415b0d..c4ed554ef978 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -129,7 +129,7 @@ class PythonPipelineSuite
         |    return df.select("name")
         |""".stripMargin)
 
-    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, 
storageRoot)
     updateContext.pipelineExecution.runPipeline()
 
     assertFlowProgressEvent(
@@ -174,7 +174,7 @@ class PythonPipelineSuite
         |   return spark.readStream.table('mv2')
         |""".stripMargin)
 
-    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, 
storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -602,7 +602,8 @@ class PythonPipelineSuite
         )
     """)
 
-    val updateContext = new PipelineUpdateContextImpl(graph, _ => ())
+    val updateContext =
+      new PipelineUpdateContextImpl(graph, _ => (), storageRoot = storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -628,18 +629,19 @@ class PythonPipelineSuite
   test("MV/ST with partition columns works") {
     withTable("mv", "st") {
       val graph = buildGraph("""
-             |from pyspark.sql.functions import col
-             |
-             |@dp.materialized_view(partition_cols = ["id_mod"])
-             |def mv():
-             |  return spark.range(5).withColumn("id_mod", col("id") % 2)
-             |
-             |@dp.table(partition_cols = ["id_mod"])
-             |def st():
-             |  return spark.readStream.table("mv")
-             |""".stripMargin)
-
-      val updateContext = new PipelineUpdateContextImpl(graph, eventCallback = 
_ => ())
+            |from pyspark.sql.functions import col
+            |
+            |@dp.materialized_view(partition_cols = ["id_mod"])
+            |def mv():
+            |  return spark.range(5).withColumn("id_mod", col("id") % 2)
+            |
+            |@dp.table(partition_cols = ["id_mod"])
+            |def st():
+            |  return spark.readStream.table("mv")
+            |""".stripMargin)
+
+      val updateContext =
+        new PipelineUpdateContextImpl(graph, eventCallback = _ => (), 
storageRoot = storageRoot)
       updateContext.pipelineExecution.runPipeline()
       updateContext.pipelineExecution.awaitCompletion()
 
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
index a31883677f92..32c728cb6a64 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala
@@ -24,9 +24,9 @@ import org.apache.spark.connect.proto.{PipelineCommand, 
PipelineEvent}
 import org.apache.spark.sql.connect.{SparkConnectServerTest, 
SparkConnectTestUtils}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{SessionHolder, SessionKey, 
SparkConnectService}
-import org.apache.spark.sql.pipelines.utils.PipelineTest
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, StorageRootMixin}
 
-class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest {
+class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest with 
StorageRootMixin {
 
   override def afterEach(): Unit = {
     SparkConnectService.sessionManager
@@ -138,7 +138,11 @@ class SparkDeclarativePipelinesServerTest extends 
SparkConnectServerTest {
 
   def startPipelineAndWaitForCompletion(graphId: String): 
ArrayBuffer[PipelineEvent] = {
     val defaultStartRunCommand =
-      PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build()
+      PipelineCommand.StartRun
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .setStorage(storageRoot)
+        .build()
     startPipelineAndWaitForCompletion(defaultStartRunCommand)
   }
 
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index a110b0164f19..36c7e9bb80ec 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -434,7 +434,8 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
     val graphId = "test_graph"
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq()),
-      (_: PipelineEvent) => None)
+      (_: PipelineEvent) => None,
+      storageRoot = "test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index f8b387039a7a..ee227168dc23 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -160,7 +160,8 @@ class SparkConnectSessionManagerSuite extends 
SharedSparkSession with BeforeAndA
     val graphId = "test_graph"
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq()),
-      (_: PipelineEvent) => None)
+      (_: PipelineEvent) => None,
+      storageRoot = "test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
index c907d603199f..30ee77a2315e 100644
--- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
@@ -191,6 +191,14 @@ trait StreamingFlowExecution extends FlowExecution with 
Logging {
   /** Starts a stream and returns its streaming query. */
   protected def startStream(): StreamingQuery
 
+  /** The query returned by `startStream()`, recorded once this execution has started it. */
+  private var _streamingQuery: Option[StreamingQuery] = None
+
+  /**
+   * Visible for testing. Returns the streaming query started by this execution.
+   *
+   * @throws IllegalStateException if the stream has not been started yet
+   */
+  def getStreamingQuery: StreamingQuery = _streamingQuery match {
+    case Some(query) => query
+    case None => throw new IllegalStateException("StreamingPhysicalFlow has not been started")
+  }
+
   /**
    * Executes this `StreamingFlowExecution` by starting its stream with the 
correct scheduling pool
    * and confs.
@@ -201,6 +209,7 @@ trait StreamingFlowExecution extends FlowExecution with 
Logging {
       log"checkpoint location ${MDC(LogKeys.CHECKPOINT_PATH, checkpointPath)}"
     )
     val streamingQuery = SparkSessionUtils.withSqlConf(spark, sqlConf.toList: 
_*)(startStream())
+    _streamingQuery = Option(streamingQuery)
     Future(streamingQuery.awaitTermination())
   }
 }
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
index bb154a0081da..c4790e872951 100644
--- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
@@ -50,6 +50,7 @@ class FlowPlanner(
           updateContext = updateContext
         )
       case sf: StreamingFlow =>
+        val flowMetadata = FlowSystemMetadata(updateContext, sf, graph)
         output match {
           case o: Table =>
             new StreamingTableWrite(
@@ -60,7 +61,7 @@ class FlowPlanner(
               updateContext = updateContext,
               sqlConf = sf.sqlConf,
               trigger = triggerFor(sf),
-              checkpointPath = output.path
+              checkpointPath = flowMetadata.latestCheckpointLocation
             )
           case _ =>
             throw new UnsupportedOperationException(
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
index 5bb6e25eaf45..6090b764d543 100644
--- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
@@ -46,6 +46,10 @@ class PipelineExecution(context: PipelineUpdateContext) {
   def startPipeline(): Unit = synchronized {
     // Initialize the graph.
     val resolvedGraph = resolveGraph()
+    if (context.fullRefreshTables.nonEmpty) {
+      State.reset(resolvedGraph, context)
+    }
+
     val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph, 
context)
 
     // Execute the graph.
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
index d6f202080933..d72c895180b5 100644
--- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala
@@ -33,6 +33,11 @@ trait PipelineUpdateContext {
 
   def resetCheckpointFlows: FlowFilter
 
+  /**
+   * The root storage location for pipeline metadata, including checkpoints for streaming flows.
+   * Streaming flow checkpoints are kept under `<storageRoot>/_checkpoints`.
+   */
+  def storageRoot: String
+
   /**
    * Filter for which flows should be refreshed when performing this update. 
Should be a superset of
    * fullRefreshFlows.
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
index 5a298c2f17d9..bb2009b25912 100644
--- 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
@@ -32,7 +32,8 @@ class PipelineUpdateContextImpl(
     override val unresolvedGraph: DataflowGraph,
     override val eventCallback: PipelineEvent => Unit,
     override val refreshTables: TableFilter = AllTables,
-    override val fullRefreshTables: TableFilter = NoTables
+    override val fullRefreshTables: TableFilter = NoTables,
+    override val storageRoot: String
 ) extends PipelineUpdateContext {
 
   override val spark: SparkSession = SparkSession.getActiveSession.getOrElse(
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
new file mode 100644
index 000000000000..31fc065cadc8
--- /dev/null
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.AnalysisException
+
+object State extends Logging {
+
+  /** Whether the table's `PipelinesTableProperties.resetAllowed` property permits a reset. */
+  private def isResetAllowed(table: Table): Boolean =
+    PipelinesTableProperties.resetAllowed.fromMap(table.properties)
+
+  /**
+   * Find the graph elements to reset given the current update context.
+   *
+   * If `env.fullRefreshTables` is a [[SomeTables]] selection, every selected table must allow
+   * resets, otherwise an [[AnalysisException]] (`TABLE_NOT_RESETTABLE`) is thrown. If it is
+   * [[AllTables]], only the tables that allow resets are selected; the others are skipped.
+   * [[NoTables]] selects nothing.
+   *
+   * @param graph The graph to reset.
+   * @param env The current update context.
+   * @return each selected table followed by the resolved flows that write to it
+   */
+  private def findElementsToReset(graph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = {
+    val specifiedTables = env.fullRefreshTables.filter(graph.tables)
+    val tablesToReset = env.fullRefreshTables match {
+      case SomeTables(_) =>
+        // An explicit refresh selection may only name tables that allow resets.
+        specifiedTables.find(t => !isResetAllowed(t)).foreach { t =>
+          throw new AnalysisException(
+            "TABLE_NOT_RESETTABLE",
+            Map("tableName" -> t.displayName)
+          )
+        }
+        specifiedTables
+      case AllTables =>
+        specifiedTables.filter(isResetAllowed)
+      case NoTables => Seq.empty
+    }
+
+    tablesToReset.flatMap(t => t +: graph.resolvedFlowsTo(t.identifier))
+  }
+
+  /**
+   * Performs the following on targets selected for full refresh:
+   * - Clearing checkpoint data (by starting a new checkpoint version for each selected flow)
+   * - Truncating table data (handled later, when tables are materialized)
+   *
+   * @param resolvedGraph the resolved graph whose elements may be reset
+   * @param env the current update context, which carries the full refresh selection
+   * @return the tables and flows selected for reset
+   */
+  def reset(resolvedGraph: DataflowGraph, env: PipelineUpdateContext): Seq[Input] = {
+    val elementsToReset: Seq[Input] = findElementsToReset(resolvedGraph, env)
+
+    elementsToReset.foreach {
+      case f: ResolvedFlow => reset(f, env, resolvedGraph)
+      // Tables are handled in materializeTables since the Hive metastore does not support
+      // removing all columns from a table.
+      case _ =>
+    }
+
+    elementsToReset
+  }
+
+  /**
+   * Resets the checkpoint for the given flow by creating the next consecutive version directory.
+   * Later lookups of the flow's latest checkpoint location (e.g. when the flow is planned) return
+   * this new, empty directory, so the stream restarts without prior state.
+   */
+  private def reset(flow: ResolvedFlow, env: PipelineUpdateContext, graph: DataflowGraph): Unit = {
+    logInfo(log"Clearing out state for flow ${MDC(LogKeys.FLOW_NAME, flow.displayName)}")
+    val flowMetadata = FlowSystemMetadata(env, flow, graph)
+    if (flowMetadata.latestCheckpointLocationOpt().isEmpty) {
+      logInfo(
+        log"Skipping resetting flow ${MDC(LogKeys.FLOW_NAME, flow.displayName)} since its " +
+          log"destination has not been previously materialized and we can't find the " +
+          log"checkpoint location."
+      )
+    } else {
+      val hadoopConf = env.spark.sessionState.newHadoopConf()
+
+      // Write a new checkpoint version folder next to the latest one, if the latest exists.
+      val checkpointDir = new Path(flowMetadata.latestCheckpointLocation)
+      val fs = checkpointDir.getFileSystem(hadoopConf)
+      if (fs.exists(checkpointDir)) {
+        // Version directories are named by integers, so the next version is the name plus one.
+        val nextVersion = checkpointDir.getName.toInt + 1
+        val nextPath = new Path(checkpointDir.getParent, nextVersion.toString)
+        fs.mkdirs(nextPath)
+        logInfo(
+          log"Created new checkpoint for stream ${MDC(LogKeys.FLOW_NAME, flow.displayName)} " +
+            log"at ${MDC(LogKeys.CHECKPOINT_PATH, nextPath.toString)}."
+        )
+      }
+    }
+  }
+}
diff --git 
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
new file mode 100644
index 000000000000..d805f4a689ec
--- /dev/null
+++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.util.Try
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.classic.SparkSession
+
+/** Common parent of the system metadata kinds, all of which are declared in this file. */
+sealed trait SystemMetadata
+
+/**
+ * Represents the system metadata associated with a [[Flow]].
+ *
+ * @param context the update context; supplies the storage root and the Spark session
+ * @param flow the flow this metadata describes
+ * @param graph the graph containing `flow`, used to validate its destination
+ */
+case class FlowSystemMetadata(
+    context: PipelineUpdateContext,
+    flow: Flow,
+    graph: DataflowGraph
+) extends SystemMetadata with Logging {
+
+  /**
+   * Returns the checkpoint root directory for the flow, which is
+   * `storage/_checkpoints/flow_destination_table/flow_name`.
+   *
+   * @throws IllegalArgumentException if the flow's destination is not a table in `graph`
+   */
+  private def flowCheckpointsDir(): Path = {
+    if (!graph.table.contains(flow.destinationIdentifier)) {
+      throw new IllegalArgumentException(
+        s"Flow ${flow.identifier} does not have a valid destination for checkpoints."
+      )
+    }
+    val checkpointRoot = new Path(context.storageRoot, "_checkpoints")
+    val flowTableName = flow.destinationIdentifier.table
+    val flowName = flow.identifier.table
+    val checkpointDir = new Path(new Path(checkpointRoot, flowTableName), flowName)
+    logInfo(
+      log"Flow ${MDC(LogKeys.FLOW_NAME, flowName)} using checkpoint " +
+        log"directory: ${MDC(LogKeys.CHECKPOINT_PATH, checkpointDir)}"
+    )
+    checkpointDir
+  }
+
+  /**
+   * Returns the location for the most recent checkpoint of a given flow.
+   *
+   * @throws IllegalArgumentException if the flow's destination is not a table in `graph`
+   */
+  def latestCheckpointLocation: String =
+    SystemMetadata.getLatestCheckpointDir(flowCheckpointsDir())
+
+  /**
+   * Same as [[latestCheckpointLocation]] but returns None if the flow checkpoints directory
+   * does not exist.
+   *
+   * @throws IllegalArgumentException if the flow's destination is not a table in `graph`
+   */
+  def latestCheckpointLocationOpt(): Option[String] = {
+    val checkpointsDir = flowCheckpointsDir()
+    val fs = checkpointsDir.getFileSystem(context.spark.sessionState.newHadoopConf())
+    if (fs.exists(checkpointsDir)) {
+      Some(SystemMetadata.getLatestCheckpointDir(checkpointsDir))
+    } else {
+      None
+    }
+  }
+}
+
+object SystemMetadata {
+  private def spark = SparkSession.getActiveSession.get
+
+  /**
+   * Finds the largest checkpoint version subdirectory path within a checkpoint directory.
+   *
+   * Version subdirectories are the immediate child directories of `rootDir` whose names are
+   * non-negative integers; any other children are ignored.
+   *  - If `rootDir` does not exist, the path of version `0` is returned without creating
+   *    anything, regardless of `createNewCheckpointDir`.
+   *  - If `rootDir` exists but holds no version subdirectories, version `0` is created and
+   *    returned.
+   *  - Otherwise the largest version is returned or, when `createNewCheckpointDir` is true, the
+   *    version after it is created and returned.
+   *
+   * The file system is accessed with the active Spark session's Hadoop configuration.
+   *
+   * @param rootDir The root/parent directory where all the numbered checkpoint subdirectories are
+   *                stored
+   * @param createNewCheckpointDir If true and a version already exists, a new latest numbered
+   *                               checkpoint directory is created and returned
+   * @return The string URI path to the selected checkpoint directory
+   */
+  def getLatestCheckpointDir(
+      rootDir: Path,
+      createNewCheckpointDir: Boolean = false
+  ): String = {
+    val fs = rootDir.getFileSystem(spark.sessionState.newHadoopConf())
+    val defaultDir = new Path(rootDir, "0")
+    val checkpoint = if (!fs.exists(rootDir)) {
+      defaultDir
+    } else {
+      // Parse each candidate's version once; non-numeric and negative names are not versions.
+      val versions = fs.listStatus(rootDir).toSeq.filter(_.isDirectory).flatMap { status =>
+        Try(status.getPath.getName.toInt).toOption.filter(_ >= 0).map(v => v -> status.getPath)
+      }
+      versions.maxByOption(_._1) match {
+        case Some((latestVersion, latestPath)) =>
+          if (createNewCheckpointDir) {
+            val nextPath = new Path(rootDir, (latestVersion + 1).toString)
+            fs.mkdirs(nextPath)
+            nextPath
+          } else {
+            latestPath
+          }
+        case None =>
+          fs.mkdirs(defaultDir)
+          defaultDir
+      }
+    }
+    checkpoint.toUri.toString
+  }
+}
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 37c32a349866..afe769f7b204 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -52,7 +52,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           comment = Option("p-comment"),
           query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
@@ -80,7 +81,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           comment = Option("p-comment"),
           query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
     val catalogTable2 = catalog.loadTable(identifier)
     assert(
@@ -104,7 +106,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           comment = Option("p-comment"),
           query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val catalogTable3 = catalog.loadTable(identifier)
@@ -136,7 +139,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         )
         registerTable("t1")
         registerTable("t2")
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val identifier1 = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t1")
@@ -171,7 +175,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           "t1",
           dfFlowFunc(Seq(1, 2, 3).toDF("x"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -201,7 +206,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         query = sqlFlowFunc(spark, "SELECT value AS timestamp FROM a")
       )
     }
-    materializeGraph(new P1().resolveToDataflowGraph())
+    materializeGraph(new P1().resolveToDataflowGraph(), storageRoot = 
storageRoot)
 
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val b =
@@ -222,7 +227,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         query = sqlFlowFunc(spark, "SELECT timestamp FROM a")
       )
     }
-    materializeGraph(new P2().resolveToDataflowGraph())
+    materializeGraph(new P2().resolveToDataflowGraph(), storageRoot = 
storageRoot)
     val b2 =
       
catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE),
 "b"))
     assert(
@@ -249,7 +254,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
       new TestGraphRegistrationContext(spark) {
         registerFlow("t2", "t2", query = dfFlowFunc(Seq(1, 2, 3).toDF("x")))
         registerTable("t2")
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val table2 = catalog.loadTable(identifier)
@@ -271,7 +277,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
       new TestGraphRegistrationContext(spark) {
         registerView("a", query = dfFlowFunc(streamInts.toDF()))
         registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS 
x FROM STREAM a")))
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val streamStrings = MemoryStream[String]
@@ -282,7 +289,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
     }.resolveToDataflowGraph()
 
     val ex = intercept[TableMaterializationException] {
-      materializeGraph(graph2)
+      materializeGraph(graph2, storageRoot = storageRoot)
     }
     val cause = ex.cause
     val exStr = exceptionString(cause)
@@ -314,7 +321,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           ),
           query = dfFlowFunc(Seq[Short](1, 2).toDF("x"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val table2 = catalog.loadTable(identifier)
@@ -352,7 +360,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           query = Option(dfFlowFunc(source.toDF().select($"value" as "x")))
         )
 
-      }.resolveToDataflowGraph())
+      }.resolveToDataflowGraph(), storageRoot = storageRoot)
     }
     val cause = ex.cause
     val exStr = exceptionString(cause)
@@ -365,7 +373,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         specifiedSchema = Option(new StructType().add("x", IntegerType)),
         query = dfFlowFunc(Seq(1, 2).toDF("x"))
       )
-    }.resolveToDataflowGraph())
+    }.resolveToDataflowGraph(),
+    storageRoot = storageRoot)
     val table2 = catalog.loadTable(identifier)
     assert(
       table2.columns() sameElements CatalogV2Util
@@ -389,7 +398,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           ),
           partitionCols = Option(Seq("x2"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")
@@ -434,7 +444,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           "t7",
           partitionCols = Option(Seq("x"))
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val table2 = catalog.loadTable(identifier)
@@ -454,7 +465,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
             query = dfFlowFunc(Seq((true, 1), (false, 3)).toDF("x", "y"))
           )
           registerTable("t7")
-        }.resolveToDataflowGraph()
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
       )
     }
     assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == 
"CANNOT_UPDATE_PARTITION_COLUMNS")
@@ -490,7 +502,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
     }.resolveToDataflowGraph()
 
     val ex = intercept[TableMaterializationException] {
-      materializeGraph(graph)
+      materializeGraph(graph, storageRoot = storageRoot)
     }
     assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == 
"CANNOT_UPDATE_PARTITION_COLUMNS")
     val table = catalog.loadTable(identifier)
@@ -513,7 +525,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           query = Option(sqlFlowFunc(spark, "SELECT * FROM STREAM a")),
           properties = Map("pipelines.reset.alloweD" -> "true", "some.prop" -> 
"foo")
         )
-      }.resolveToDataflowGraph()
+      }.resolveToDataflowGraph(),
+      storageRoot = storageRoot
     )
 
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -546,7 +559,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
       }.resolveToDataflowGraph()
     val ex1 =
       intercept[TableMaterializationException] {
-        materializeGraph(graph1)
+        materializeGraph(graph1, storageRoot = storageRoot)
       }
 
     assert(ex1.cause.isInstanceOf[IllegalArgumentException])
@@ -565,7 +578,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
       registerTable("a", query = 
Option(dfFlowFunc(spark.readStream.format("rate").load())))
     }.resolveToDataflowGraph().validate()
 
-    materializeGraph(graph1)
+    materializeGraph(graph1, storageRoot = storageRoot)
   }
 
   for (isFullRefresh <- Seq(true, false)) {
@@ -581,7 +594,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           registerMaterializedView("b", query = sqlFlowFunc(spark, "SELECT x 
FROM a"))
         }.resolveToDataflowGraph()
 
-      val graph = materializeGraph(rawGraph)
+      val graph = materializeGraph(rawGraph, storageRoot = storageRoot)
       val (refreshSelection, fullRefreshSelection) = if (isFullRefresh) {
         (NoTables, AllTables)
       } else {
@@ -595,9 +608,11 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
             spark = spark,
             unresolvedGraph = graph,
             refreshTables = refreshSelection,
-            fullRefreshTables = fullRefreshSelection
+            fullRefreshTables = fullRefreshSelection,
+            storageRoot = storageRoot
           )
-        )
+        ),
+        storageRoot = storageRoot
       )
 
       val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -613,7 +628,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         new TestGraphRegistrationContext(spark) {
           registerView("a", query = dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x", 
"y")))
           registerMaterializedView("b", query = sqlFlowFunc(spark, "SELECT y 
FROM a"))
-        }.resolveToDataflowGraph()
+        }.resolveToDataflowGraph(),
+        storageRoot = storageRoot
       )
       val table2 = catalog.loadTable(identifier)
       assert(
@@ -650,10 +666,11 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           spark = spark,
           unresolvedGraph = graph,
           refreshTables = refreshSelection,
-          fullRefreshTables = fullRefreshSelection
+          fullRefreshTables = fullRefreshSelection,
+          storageRoot = storageRoot
         )
       )
-      materializeGraph(graph, contextOpt = updateContextOpt)
+      materializeGraph(graph, contextOpt = updateContextOpt, storageRoot = 
storageRoot)
 
       val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
       val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b")
@@ -668,7 +685,8 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           registerView("a", query = dfFlowFunc(streamInts.toDF()))
           registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value 
AS y FROM STREAM a")))
         }.resolveToDataflowGraph().validate(),
-        contextOpt = updateContextOpt
+        contextOpt = updateContextOpt,
+        storageRoot = storageRoot
       )
 
       val table2 = catalog.loadTable(identifier)
@@ -709,9 +727,11 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           spark = spark,
           unresolvedGraph = graph,
           refreshTables = SomeTables(Set(fullyQualifiedIdentifier("a"))),
-          fullRefreshTables = SomeTables(Set(fullyQualifiedIdentifier("c")))
+          fullRefreshTables = SomeTables(Set(fullyQualifiedIdentifier("c"))),
+          storageRoot = storageRoot
         )
-      )
+      ),
+      storageRoot = storageRoot
     )
 
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -756,9 +776,9 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           )
         )
       }.resolveToDataflowGraph()
-    materializeGraph(rawGraph)
+    materializeGraph(rawGraph, storageRoot = storageRoot)
     // Materialize twice because some logic compares the incoming schema with 
the previous one.
-    materializeGraph(rawGraph)
+    materializeGraph(rawGraph, storageRoot = storageRoot)
 
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val tableA =
@@ -805,9 +825,9 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
         )
 
       }.resolveToDataflowGraph()
-    materializeGraph(rawGraph)
+    materializeGraph(rawGraph, storageRoot = storageRoot)
     // Materialize twice because some logic compares the incoming schema with 
the previous one.
-    materializeGraph(rawGraph)
+    materializeGraph(rawGraph, storageRoot = storageRoot)
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val tableA =
       
catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE),
 "a"))
@@ -849,7 +869,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
       registerTable("a")
     }.resolveToDataflowGraph()
 
-    materializeGraph(graph1)
+    materializeGraph(graph1, storageRoot = storageRoot)
     materializeGraph(
       graph2,
       contextOpt = Option(
@@ -857,9 +877,11 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
           spark = spark,
           unresolvedGraph = graph2,
           refreshTables = NoTables,
-          fullRefreshTables = NoTables
+          fullRefreshTables = NoTables,
+          storageRoot = storageRoot
         )
-      )
+      ),
+      storageRoot = storageRoot
     )
   }
 }
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
new file mode 100644
index 000000000000..15dfa6576171
--- /dev/null
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryWrapper}
+import org.apache.spark.sql.pipelines.utils.{ExecutionTest, 
TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SystemMetadataSuite
+    extends ExecutionTest
+    with SystemMetadataTestHelpers
+    with SharedSparkSession {
+
+  gridTest(
+    "flow checkpoint for ST wrote to the expected location when fullRefresh ="
+  )(Seq(false, true)) { fullRefresh =>
+    val session = spark
+    import session.implicits._
+
+    // Create a pipeline with only a single ST, fed by a view over an in-memory stream.
+    val graph = new TestGraphRegistrationContext(spark) {
+      val mem: MemoryStream[Int] = MemoryStream[Int]
+      mem.addData(1, 2, 3)
+      registerView("a", query = dfFlowFunc(mem.toDF()))
+      registerTable("st")
+      registerFlow("st", "st", query = readStreamFlowFunc("a"))
+    }.toDataflowGraph
+
+    val updateContext1 = TestPipelineUpdateContext(
+      unresolvedGraph = graph,
+      spark = spark,
+      storageRoot = storageRoot,
+      failOnErrorEvent = true
+    )
+
+    // Run the first update.
+    updateContext1.pipelineExecution.startPipeline()
+    updateContext1.pipelineExecution.awaitCompletion()
+
+    val graphExecution1 = updateContext1.pipelineExecution.graphExecution.get
+    val executionGraph1 = graphExecution1.graphForExecution
+
+    val stIdentifier = fullyQualifiedIdentifier("st")
+
+    // The first update is expected to write checkpoint version 0.
+    assertFlowCheckpointDirExists(
+      tableOrSinkElement = executionGraph1.table(stIdentifier),
+      flowElement = executionGraph1.flow(stIdentifier),
+      expectedCheckpointVersion = 0,
+      graphExecution = graphExecution1,
+      updateContext = updateContext1
+    )
+
+    // Run the second update as either a refresh (checkpoint reused) or a full refresh
+    // (checkpoint version incremented).
+    val updateContext2 = TestPipelineUpdateContext(
+      unresolvedGraph = graph,
+      spark = spark,
+      storageRoot = storageRoot,
+      refreshTables = if (!fullRefresh) AllTables else NoTables,
+      fullRefreshTables = if (fullRefresh) AllTables else NoTables,
+      failOnErrorEvent = true
+    )
+
+    updateContext2.pipelineExecution.startPipeline()
+    updateContext2.pipelineExecution.awaitCompletion()
+
+    val graphExecution2 = updateContext2.pipelineExecution.graphExecution.get
+    val executionGraph2 = graphExecution2.graphForExecution
+
+    // The checkpoint is reused on refresh and its version is incremented on full refresh.
+    val expectedVersion = if (fullRefresh) 1 else 0
+    assertFlowCheckpointDirExists(
+      tableOrSinkElement = executionGraph2.table(stIdentifier),
+      // Look the flow up in the second update's graph, consistent with the table lookup.
+      flowElement = executionGraph2.flow(stIdentifier),
+      expectedCheckpointVersion = expectedVersion,
+      graphExecution = graphExecution2,
+      updateContext = updateContext2
+    )
+  }
+
+  test(
+    "flow checkpoint for ST (with flow name different from table name) wrote to the expected " +
+    "location"
+  ) {
+    val session = spark
+    import session.implicits._
+
+    val graph = new TestGraphRegistrationContext(spark) {
+      val mem: MemoryStream[Int] = MemoryStream[Int]
+      mem.addData(1, 2, 3)
+      registerView("a", query = dfFlowFunc(mem.toDF()))
+      registerTable("st")
+      registerFlow("st", "st_flow", query = readStreamFlowFunc("a"))
+    }.toDataflowGraph
+
+    val updateContext1 = TestPipelineUpdateContext(
+      unresolvedGraph = graph,
+      spark = spark,
+      storageRoot = storageRoot
+    )
+
+    // Run the first update; the checkpoint is written by the flow's streaming query.
+    updateContext1.pipelineExecution.startPipeline()
+    updateContext1.pipelineExecution.awaitCompletion()
+
+    val graphExecution1 = updateContext1.pipelineExecution.graphExecution.get
+    val executionGraph1 = graphExecution1.graphForExecution
+
+    val stIdentifier = fullyQualifiedIdentifier("st")
+    val stFlowIdentifier = fullyQualifiedIdentifier("st_flow")
+
+    // Assert that the checkpoint dir for the ST is keyed by table name and flow name.
+    assertFlowCheckpointDirExists(
+      tableOrSinkElement = executionGraph1.table(stIdentifier),
+      flowElement = executionGraph1.flow(stFlowIdentifier),
+      // the default checkpoint version is 0
+      expectedCheckpointVersion = 0,
+      graphExecution = graphExecution1,
+      updateContext = updateContext1
+    )
+
+    // Run a second update with a full refresh of all tables; a new checkpoint dir is expected
+    // to be created with the version number incremented to 1.
+    val updateContext2 = TestPipelineUpdateContext(
+      unresolvedGraph = graph,
+      spark = spark,
+      storageRoot = storageRoot,
+      fullRefreshTables = AllTables
+    )
+
+    updateContext2.pipelineExecution.startPipeline()
+    updateContext2.pipelineExecution.awaitCompletion()
+    val graphExecution2 = updateContext2.pipelineExecution.graphExecution.get
+    val executionGraph2 = graphExecution2.graphForExecution
+
+    // Due to the full refresh, the new checkpoint dir has its version incremented to 1.
+    assertFlowCheckpointDirExists(
+      tableOrSinkElement = executionGraph2.table(stIdentifier),
+      // Look the flow up in the second update's graph, consistent with the table lookup.
+      flowElement = executionGraph2.flow(stFlowIdentifier),
+      // new checkpoint directory is created
+      expectedCheckpointVersion = 1,
+      graphExecution = graphExecution2,
+      updateContext = updateContext2
+    )
+  }
+}
+
+trait SystemMetadataTestHelpers {
+  this: ExecutionTest =>
+
+  /**
+   * Return the expected checkpoint directory for a flow writing into a table or sink.
+   * The expected location is
+   * `file://<storageRoot>/_checkpoints/<table name>/<flow name>/<expectedCheckpointVersion>`.
+   * Only tables are currently handled; any other element type fails the test.
+   */
+  private def getExpectedFlowCheckpointDirForTableOrSink(
+      tableOrSinkElement: GraphElement,
+      flowElement: Flow,
+      expectedCheckpointVersion: Int,
+      updateContext: PipelineUpdateContext
+  ): Path = {
+    val expectedRawCheckPointDir = tableOrSinkElement match {
+      case t: Table =>
+        new Path(updateContext.storageRoot)
+          .suffix(s"/_checkpoints/${t.identifier.table}/${flowElement.identifier.table}")
+          .toString
+      case _ =>
+        fail(
+          "unexpected table element type for getExpectedFlowCheckpointDirForTableOrSink: " +
+          tableOrSinkElement.getClass.getSimpleName
+        )
+    }
+
+    // Qualify with the `file://` scheme so the path can compare equal to the streaming query's
+    // resolved checkpoint root. NOTE(review): assumes a local-filesystem storage root.
+    new Path("file://", expectedRawCheckPointDir).suffix(s"/$expectedCheckpointVersion")
+  }
+
+  /**
+   * Return the actual checkpoint directory used by the streaming query of the given flow.
+   * We use a Flow object as input since the checkpoints for a table or sink are associated with
+   * each of their incoming flows. This does not wait for the query to start; callers are
+   * expected to invoke it after the update has run.
+   */
+  private def getActualFlowCheckpointDirForTableOrSink(
+      flowElement: Flow,
+      graphExecution: GraphExecution
+  ): Path = {
+    val flowStream = graphExecution
+      .flowExecutions(flowElement.identifier)
+      .asInstanceOf[StreamingFlowExecution]
+      .getStreamingQuery
+
+    // Grab the checkpoint location from the streaming query actually executed in the update,
+    // which is the best source of truth.
+    new Path(
+      flowStream.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot
+    )
+  }
+
+  /** Compute the (expected, actual) checkpoint directories for the given flow. */
+  private def expectedAndActualCheckpointDirs(
+      tableOrSinkElement: GraphElement,
+      flowElement: Flow,
+      expectedCheckpointVersion: Int,
+      graphExecution: GraphExecution,
+      updateContext: PipelineUpdateContext
+  ): (Path, Path) = {
+    val expected = getExpectedFlowCheckpointDirForTableOrSink(
+      tableOrSinkElement = tableOrSinkElement,
+      flowElement = flowElement,
+      expectedCheckpointVersion = expectedCheckpointVersion,
+      updateContext = updateContext
+    )
+    val actual = getActualFlowCheckpointDirForTableOrSink(
+      flowElement = flowElement,
+      graphExecution = graphExecution
+    )
+    (expected, actual)
+  }
+
+  /**
+   * Assert that the flow's streaming query uses the expected checkpoint directory and that the
+   * directory exists in the filesystem. Returns the actual checkpoint directory.
+   */
+  protected def assertFlowCheckpointDirExists(
+      tableOrSinkElement: GraphElement,
+      flowElement: Flow,
+      expectedCheckpointVersion: Int,
+      graphExecution: GraphExecution,
+      updateContext: PipelineUpdateContext
+  ): Path = {
+    val (expectedCheckPointDir, actualCheckpointDir) = expectedAndActualCheckpointDirs(
+      tableOrSinkElement = tableOrSinkElement,
+      flowElement = flowElement,
+      expectedCheckpointVersion = expectedCheckpointVersion,
+      graphExecution = graphExecution,
+      updateContext = updateContext
+    )
+    val fs = expectedCheckPointDir.getFileSystem(spark.sessionState.newHadoopConf())
+    assert(actualCheckpointDir == expectedCheckPointDir)
+    assert(fs.exists(actualCheckpointDir))
+    actualCheckpointDir
+  }
+
+  /**
+   * Assert that the flow's streaming query does NOT use the given checkpoint directory and that
+   * the directory does not exist in the filesystem. Returns the actual checkpoint directory.
+   */
+  protected def assertFlowCheckpointDirNotExists(
+      tableOrSinkElement: GraphElement,
+      flowElement: Flow,
+      expectedCheckpointVersion: Int,
+      graphExecution: GraphExecution,
+      updateContext: PipelineUpdateContext
+  ): Path = {
+    val (expectedCheckPointDir, actualCheckpointDir) = expectedAndActualCheckpointDirs(
+      tableOrSinkElement = tableOrSinkElement,
+      flowElement = flowElement,
+      expectedCheckpointVersion = expectedCheckpointVersion,
+      graphExecution = graphExecution,
+      updateContext = updateContext
+    )
+    val fs = expectedCheckPointDir.getFileSystem(spark.sessionState.newHadoopConf())
+    assert(actualCheckpointDir != expectedCheckPointDir)
+    assert(!fs.exists(expectedCheckPointDir))
+    actualCheckpointDir
+  }
+}
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
index 4fcd9dad93fe..57baf4c2d5b1 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
@@ -68,7 +68,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with 
SharedSparkSession
       resolvedGraph.resolvedFlows.filter(_.identifier == 
fullyQualifiedIdentifier("b")).head
     assert(bFlow.inputs == Set(fullyQualifiedIdentifier("a")))
 
-    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, 
storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -128,7 +128,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
       resolvedGraph.resolvedFlows.filter(_.identifier == 
fullyQualifiedIdentifier("d")).head
     assert(dFlow.inputs == Set(fullyQualifiedIdentifier("c", isTemporaryView = 
true)))
 
-    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+    val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, 
storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -201,7 +201,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
     val graph = pipelineDef.toDataflowGraph
 
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -268,7 +268,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
     val graph = pipelineDef.toDataflowGraph
 
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -329,7 +329,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
       registerTable("z", query = Option(readStreamFlowFunc("x")))
     }
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -437,7 +437,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph1 = pipelineDef1.toDataflowGraph
-    val updateContext1 = TestPipelineUpdateContext(spark, graph1)
+    val updateContext1 = TestPipelineUpdateContext(spark, graph1, storageRoot)
     updateContext1.pipelineExecution.runPipeline()
     updateContext1.pipelineExecution.awaitCompletion()
     assertFlowProgressEvent(
@@ -458,7 +458,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
       )
     }
     val graph2 = pipelineDef2.toDataflowGraph
-    val updateContext2 = TestPipelineUpdateContext(spark, graph2)
+    val updateContext2 = TestPipelineUpdateContext(spark, graph2, storageRoot)
     updateContext2.pipelineExecution.runPipeline()
     updateContext2.pipelineExecution.awaitCompletion()
 
@@ -506,6 +506,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     val ctx = TestPipelineUpdateContext(
       spark,
       pipelineDef.toDataflowGraph,
+      storageRoot,
       fullRefreshTables = AllTables,
       resetCheckpointFlows = AllFlows
     )
@@ -562,7 +563,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
 
     val graphExecution = updateContext.pipelineExecution.graphExecution.get
@@ -626,7 +627,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -688,7 +689,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -767,7 +768,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -826,7 +827,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
       )
     }
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -883,6 +884,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     val updateContext1 = TestPipelineUpdateContext(
       spark = spark,
       unresolvedGraph = graph1,
+      storageRoot = storageRoot,
       refreshTables = SomeTables(
         Set(fullyQualifiedIdentifier("source"), 
fullyQualifiedIdentifier("all"))
       ),
@@ -932,6 +934,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     val updateContext2 = TestPipelineUpdateContext(
       spark = spark,
       unresolvedGraph = graph1,
+      storageRoot = storageRoot,
       refreshTables = SomeTables(
         Set(fullyQualifiedIdentifier("source"), 
fullyQualifiedIdentifier("max_evens"))
       ),
@@ -986,7 +989,8 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
       registerTable("table3", query = Option(sqlFlowFunc(spark, "SELECT * FROM 
table1")))
     }.toDataflowGraph
 
-    val updateContext = TestPipelineUpdateContext(spark = spark, 
unresolvedGraph = graph)
+    val updateContext = TestPipelineUpdateContext(spark = spark,
+      unresolvedGraph = graph, storageRoot = storageRoot)
     updateContext.pipelineExecution.runPipeline()
 
     assertFlowProgressEvent(
@@ -1041,7 +1045,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest 
with SharedSparkSession
     }
 
     val graph = pipelineDef.toDataflowGraph
-    val updateContext = TestPipelineUpdateContext(spark, graph)
+    val updateContext = TestPipelineUpdateContext(spark, graph, storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
index 3b40f887fe08..452c66844e9c 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ViewSuite.scala
@@ -48,7 +48,8 @@ class ViewSuite extends ExecutionTest with SharedSparkSession 
{
     val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
       sqlText = s"CREATE VIEW $viewName AS SELECT * FROM range(1, 4);"
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -81,7 +82,8 @@ class ViewSuite extends ExecutionTest with SharedSparkSession 
{
     val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
       sqlText = s"CREATE VIEW $viewName AS SELECT * FROM $source;"
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -108,7 +110,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
       sqlText = s"""CREATE TEMPORARY VIEW temp_view AS SELECT * FROM range(1, 
4);
                    |CREATE VIEW $viewName AS SELECT * FROM 
temp_view;""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
 
     val ex = intercept[AnalysisException] {
       updateContext.pipelineExecution.startPipeline()
@@ -131,7 +134,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
     val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
       sqlText = s"CREATE VIEW myview AS SELECT * FROM nonexistent_view;"
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
 
     val ex = intercept[Exception] {
       updateContext.pipelineExecution.startPipeline()
@@ -150,7 +154,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
       sqlText = s"""CREATE STREAMING TABLE source AS SELECT * FROM 
STREAM($externalTable1Ident);
                    |CREATE VIEW $viewName AS SELECT * FROM 
STREAM(source);""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
 
     val ex = intercept[AnalysisException] {
       updateContext.pipelineExecution.startPipeline()
@@ -168,7 +173,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
     val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
       sqlText = s"CREATE VIEW $viewName AS SELECT * FROM 
STREAM($externalTable1Ident);"
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
 
     val ex = intercept[AnalysisException] {
       updateContext.pipelineExecution.startPipeline()
@@ -190,7 +196,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
                    |CREATE VIEW pv2 AS SELECT * FROM pv1;
                    |CREATE VIEW pv1 AS SELECT * FROM range(1, 
4);""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -217,7 +224,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
         |CREATE VIEW pv2 AS SELECT * FROM pv1;
         |CREATE VIEW pv1 AS SELECT 1 + 1;""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -253,7 +261,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
       sqlText = s"""CREATE MATERIALIZED VIEW mymv AS SELECT * FROM RANGE(1, 4);
                    |CREATE VIEW $viewName AS SELECT * FROM 
$source;""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -286,7 +295,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
       sqlText = s"""CREATE STREAMING TABLE myst AS SELECT * FROM 
STREAM($externalTable1Ident);
                    |CREATE VIEW $viewName AS SELECT * FROM 
$source;""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
@@ -317,7 +327,8 @@ class ViewSuite extends ExecutionTest with 
SharedSparkSession {
         |CREATE VIEW $viewName AS SELECT * FROM range(1, 4);
         |CREATE MATERIALIZED VIEW myviewreader AS SELECT * FROM 
$viewName;""".stripMargin
     )
-    val updateContext = TestPipelineUpdateContext(spark, 
unresolvedDataflowGraph)
+    val updateContext =
+      TestPipelineUpdateContext(spark, unresolvedDataflowGraph, storageRoot)
     updateContext.pipelineExecution.startPipeline()
     updateContext.pipelineExecution.awaitCompletion()
 
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
index 3ee73f9394e9..317dba3dd67c 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala
@@ -27,10 +27,11 @@ trait BaseCoreExecutionTest extends ExecutionTest {
    */
   protected def materializeGraph(
       graph: DataflowGraph,
-      contextOpt: Option[PipelineUpdateContext] = None
+      contextOpt: Option[PipelineUpdateContext] = None,
+      // Pipeline storage root; only used to build a fresh TestPipelineUpdateContext when
+      // `contextOpt` is empty (a provided context carries its own storage root).
+      storageRoot: String
   ): DataflowGraph = {
     val contextToUse = contextOpt.getOrElse(
-      TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph)
+      TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph, 
storageRoot = storageRoot)
     )
     DatasetManager.materializeDatasets(graph, contextToUse)
   }
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
index 6c2c07e57498..2fba90d3a0d0 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala
@@ -58,6 +58,7 @@ trait TestPipelineUpdateContextMixin {
   case class TestPipelineUpdateContext(
       spark: SparkSession,
       unresolvedGraph: DataflowGraph,
+      storageRoot: String,
       fullRefreshTables: TableFilter = NoTables,
       refreshTables: TableFilter = AllTables,
       resetCheckpointFlows: FlowFilter = AllFlows,
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
index 54b324c182f2..f9d6aba9e22d 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.pipelines.utils
 
 import java.io.{BufferedReader, FileNotFoundException, InputStreamReader}
-import java.nio.file.Files
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Try}
@@ -34,22 +33,23 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, 
PipelineUpdateContextImpl, SqlGraphRegistrationContext}
-import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, 
createTempDir}
+import org.apache.spark.sql.pipelines.utils.PipelineTest.cleanupMetastore
 import org.apache.spark.sql.test.SQLTestUtils
 
 abstract class PipelineTest
   extends QueryTest
+  with StorageRootMixin
   with SQLTestUtils
   with SparkErrorTestMixin
   with TargetCatalogAndDatabaseMixin
   with Logging
   with Eventually {
 
-  final protected val storageRoot = createTempDir()
-
-  protected def startPipelineAndWaitForCompletion(unresolvedDataflowGraph: 
DataflowGraph): Unit = {
+  /**
+   * Run the given unresolved pipeline graph to completion, using this test's `storageRoot`
+   * as the pipeline storage root. Pipeline events are ignored (no-op event callback).
+   */
+  protected def startPipelineAndWaitForCompletion(
+       unresolvedDataflowGraph: DataflowGraph): Unit = {
     val updateContext = new PipelineUpdateContextImpl(
-      unresolvedDataflowGraph, eventCallback = _ => ())
+      unresolvedDataflowGraph, eventCallback = _ => (),
+      storageRoot = storageRoot)
     updateContext.pipelineExecution.runPipeline()
     updateContext.pipelineExecution.awaitCompletion()
   }
@@ -299,11 +299,6 @@ object PipelineTest extends Logging {
     "main"
   )
 
-  /** Creates a temporary directory. */
-  protected def createTempDir(): String = {
-    Files.createTempDirectory(getClass.getSimpleName).normalize.toString
-  }
-
   /**
    * Try to drop the schema in the catalog and return whether it is 
successfully dropped.
    */
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
new file mode 100644
index 000000000000..420e2c6ad0e9
--- /dev/null
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.utils
+
+
+import java.io.File
+import java.nio.file.Files
+
+import org.scalatest.{BeforeAndAfterEach, Suite}
+
+import org.apache.spark.util.Utils
+
+/**
+ * A mixin trait for tests that need a temporary directory as the storage root for pipelines.
+ * This trait creates a fresh temporary directory before each test and deletes it after each
+ * test, even when a parent `afterEach()` throws. The path is exposed via `storageRoot`.
+ */
+trait StorageRootMixin extends BeforeAndAfterEach { self: Suite =>
+
+  /** A temporary directory created as the pipeline storage root for each test. */
+  protected var storageRoot: String = _
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    storageRoot = Files.createTempDirectory(getClass.getSimpleName).normalize.toString
+  }
+
+  override protected def afterEach(): Unit = {
+    // try/finally so the temp directory is not leaked if a parent afterEach() throws.
+    try super.afterEach()
+    finally Utils.deleteRecursively(new File(storageRoot))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to