This is an automated email from the ASF dual-hosted git repository.

sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 93748ccc5979 [SPARK-52511][SDP] Support dry-run mode in spark-pipelines command
93748ccc5979 is described below

commit 93748ccc59795c36278a9962d1851c44e106ef05
Author: Sandy Ryza <[email protected]>
AuthorDate: Thu Jul 17 18:37:39 2025 -0700

    [SPARK-52511][SDP] Support dry-run mode in spark-pipelines command
    
    ### What changes were proposed in this pull request?
    
    Adds a new `dry-run` subcommand to the `spark-pipelines` command that launches an execution of a pipeline that doesn't write or read any data, but catches many kinds of errors that would be caught if the pipeline were to actually run. For example (a usage sketch follows the list below):
    - Syntax errors – e.g. invalid Python or SQL code
    - Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
    - Graph validation errors – e.g. cyclic dependencies
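
    A minimal sketch of what a dry run amounts to through the Spark Connect Python API, mirroring the new `test_dry_run` test added in this commit (assumes an existing Spark Connect `SparkSession` named `spark`; everything else comes from the diff below):

        from pyspark import pipelines as sdp
        from pyspark.pipelines.graph_element_registry import graph_element_registration_context
        from pyspark.pipelines.spark_connect_graph_element_registry import (
            SparkConnectGraphElementRegistry,
        )
        from pyspark.pipelines.spark_connect_pipeline import (
            create_dataflow_graph,
            handle_pipeline_events,
            start_run,
        )

        # Register a single materialized view against a fresh dataflow graph.
        dataflow_graph_id = create_dataflow_graph(spark, None, None, None)
        registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)

        with graph_element_registration_context(registry):

            @sdp.materialized_view
            def mv():
                return spark.range(1)

        # dry=True resolves and validates the graph without reading or writing any data;
        # handle_pipeline_events raises if validation surfaces an error.
        result_iter = start_run(
            spark,
            dataflow_graph_id,
            full_refresh=None,
            refresh=None,
            full_refresh_all=False,
            dry=True,
        )
        handle_pipeline_events(result_iter)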
    
    ### Why are the changes needed?
    
    Leverage the declarative nature of Declarative Pipelines to make pipeline development easier.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Adds behavior; doesn't change existing behavior.
    
    ### How was this patch tested?
    
    - Added unit tests
    - Executed `dry-run` on the CLI, for both success and error scenarios
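
    For reference, a dry run can be launched from the CLI with `spark-pipelines dry-run` (which, like `run`, looks for a `pipeline.yml` or `pipeline.yaml` in the current or parent directories) or with an explicit spec file via `spark-pipelines dry-run --spec path/to/pipeline.yml`, where `path/to/pipeline.yml` is a placeholder.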
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #51489 from sryza/dry-run.
    
    Lead-authored-by: Sandy Ryza <[email protected]>
    Co-authored-by: Sandy Ryza <[email protected]>
    Signed-off-by: Sandy Ryza <[email protected]>
---
 docs/declarative-pipelines-programming-guide.md    |  9 +-
 python/pyspark/pipelines/cli.py                    | 36 ++++++--
 python/pyspark/pipelines/spark_connect_pipeline.py |  8 +-
 python/pyspark/pipelines/tests/test_cli.py         |  3 +
 .../pyspark/pipelines/tests/test_spark_connect.py  | 97 ++++++++++++++++++++++
 python/pyspark/sql/connect/proto/pipelines_pb2.py  | 30 +++----
 python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 18 ++++
 .../main/protobuf/spark/connect/pipelines.proto    |  6 +-
 .../sql/connect/pipelines/PipelinesHandler.scala   |  6 +-
 .../pipelines/PipelineEventStreamSuite.scala       | 95 ++++++++++++++++-----
 .../sql/pipelines/graph/PipelineExecution.scala    | 48 +++++++----
 11 files changed, 288 insertions(+), 68 deletions(-)

diff --git a/docs/declarative-pipelines-programming-guide.md b/docs/declarative-pipelines-programming-guide.md
index 929cd07e5daa..5f938f38e1e4 100644
--- a/docs/declarative-pipelines-programming-guide.md
+++ b/docs/declarative-pipelines-programming-guide.md
@@ -94,7 +94,7 @@ The `spark-pipelines init` command, described below, makes it easy to generate a
 
 ## The `spark-pipelines` Command Line Interface
 
-The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project.
+The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project and a `dry-run` subcommand for validating a pipeline.
 
 `spark-pipelines` is built on top of `spark-submit`, meaning that it supports all cluster managers supported by `spark-submit`. It supports all `spark-submit` arguments except for `--class`.
 
@@ -106,6 +106,13 @@ The `spark-pipelines` command line interface (CLI) is the primary way to execute
 
 `spark-pipelines run` launches an execution of a pipeline and monitors its progress until it completes. The `--spec` parameter allows selecting the pipeline spec file. If not provided, the CLI will look in the current directory and parent directories for a file named `pipeline.yml` or `pipeline.yaml`.
 
+### `spark-pipelines dry-run`
+
+`spark-pipelines dry-run` launches an execution of a pipeline that doesn't write or read any data, but catches many kinds of errors that would be caught if the pipeline were to actually run. E.g.
+- Syntax errors – e.g. invalid Python or SQL code
+- Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
+- Graph validation errors - e.g. cyclic dependencies
+
 ## Programming with SDP in Python
 
 SDP Python functions are defined in the `pyspark.pipelines` module. Your pipelines implemented with the Python API must import this module. It's common to alias the module to `sdp` to limit the number of characters you need to type when using its APIs.
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 2a0cf880d10c..cbcac35cf1b3 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -222,6 +222,7 @@ def run(
     full_refresh: Sequence[str],
     full_refresh_all: bool,
     refresh: Sequence[str],
+    dry: bool,
 ) -> None:
     """Run the pipeline defined with the given spec.
 
@@ -276,6 +277,7 @@ def run(
         full_refresh=full_refresh,
         full_refresh_all=full_refresh_all,
         refresh=refresh,
+        dry=dry,
     )
     try:
         handle_pipeline_events(result_iter)
@@ -317,6 +319,13 @@ if __name__ == "__main__":
         default=[],
     )
 
+    # "dry-run" subcommand
+    dry_run_parser = subparsers.add_parser(
+        "dry-run",
+        help="Launch a run that just validates the graph and checks for errors.",
+    )
+    dry_run_parser.add_argument("--spec", help="Path to the pipeline spec.")
+
     # "init" subcommand
     init_parser = subparsers.add_parser(
         "init",
@@ -330,9 +339,9 @@ if __name__ == "__main__":
     )
 
     args = parser.parse_args()
-    assert args.command in ["run", "init"]
+    assert args.command in ["run", "dry-run", "init"]
 
-    if args.command == "run":
+    if args.command in ["run", "dry-run"]:
         if args.spec is not None:
             spec_path = Path(args.spec)
             if not spec_path.is_file():
@@ -343,11 +352,22 @@ if __name__ == "__main__":
         else:
             spec_path = find_pipeline_spec(Path.cwd())
 
-        run(
-            spec_path=spec_path,
-            full_refresh=args.full_refresh,
-            full_refresh_all=args.full_refresh_all,
-            refresh=args.refresh,
-        )
+        if args.command == "run":
+            run(
+                spec_path=spec_path,
+                full_refresh=args.full_refresh,
+                full_refresh_all=args.full_refresh_all,
+                refresh=args.refresh,
+                dry=args.command == "dry-run",
+            )
+        else:
+            assert args.command == "dry-run"
+            run(
+                spec_path=spec_path,
+                full_refresh=[],
+                full_refresh_all=False,
+                refresh=[],
+                dry=True,
+            )
     elif args.command == "init":
         init(args.name)
diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py
index f430d33be4a1..61b72956e5cc 100644
--- a/python/pyspark/pipelines/spark_connect_pipeline.py
+++ b/python/pyspark/pipelines/spark_connect_pipeline.py
@@ -68,9 +68,10 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
 def start_run(
     spark: SparkSession,
     dataflow_graph_id: str,
-    full_refresh: Optional[Sequence[str]] = None,
-    full_refresh_all: bool = False,
-    refresh: Optional[Sequence[str]] = None,
+    full_refresh: Optional[Sequence[str]],
+    full_refresh_all: bool,
+    refresh: Optional[Sequence[str]],
+    dry: bool,
 ) -> Iterator[Dict[str, Any]]:
     """Start a run of the dataflow graph in the Spark Connect server.
 
@@ -84,6 +85,7 @@ def start_run(
         full_refresh_selection=full_refresh or [],
         full_refresh_all=full_refresh_all,
         refresh_selection=refresh or [],
+        dry=dry,
     )
     command = pb2.Command()
     command.pipeline_command.start_run.CopyFrom(inner_command)
diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py
index ded00e691db4..ff62ce42c4a3 100644
--- a/python/pyspark/pipelines/tests/test_cli.py
+++ b/python/pyspark/pipelines/tests/test_cli.py
@@ -373,6 +373,7 @@ class CLIUtilityTests(unittest.TestCase):
                     full_refresh=["table1", "table2"],
                     full_refresh_all=True,
                     refresh=[],
+                    dry=False,
                 )
 
             self.assertEqual(
@@ -396,6 +397,7 @@ class CLIUtilityTests(unittest.TestCase):
                     full_refresh=[],
                     full_refresh_all=True,
                     refresh=["table1", "table2"],
+                    dry=False,
                 )
 
             self.assertEqual(
@@ -421,6 +423,7 @@ class CLIUtilityTests(unittest.TestCase):
                     full_refresh=["table1"],
                     full_refresh_all=True,
                     refresh=["table2"],
+                    dry=False,
                 )
 
             self.assertEqual(
diff --git a/python/pyspark/pipelines/tests/test_spark_connect.py b/python/pyspark/pipelines/tests/test_spark_connect.py
new file mode 100644
index 000000000000..935295ec4a8c
--- /dev/null
+++ b/python/pyspark/pipelines/tests/test_spark_connect.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests that run Pipelines against a Spark Connect server.
+"""
+
+import unittest
+
+from pyspark.errors.exceptions.connect import AnalysisException
+from pyspark.pipelines.graph_element_registry import graph_element_registration_context
+from pyspark.pipelines.spark_connect_graph_element_registry import (
+    SparkConnectGraphElementRegistry,
+)
+from pyspark.pipelines.spark_connect_pipeline import (
+    create_dataflow_graph,
+    start_run,
+    handle_pipeline_events,
+)
+from pyspark import pipelines as sdp
+from pyspark.testing.connectutils import (
+    ReusedConnectTestCase,
+    should_test_connect,
+    connect_requirement_message,
+)
+
+
[email protected](not should_test_connect, connect_requirement_message)
+class SparkConnectPipelinesTest(ReusedConnectTestCase):
+    def test_dry_run(self):
+        dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
+        registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)
+
+        with graph_element_registration_context(registry):
+
+            @sdp.materialized_view
+            def mv():
+                return self.spark.range(1)
+
+        result_iter = start_run(
+            self.spark,
+            dataflow_graph_id,
+            full_refresh=None,
+            refresh=None,
+            full_refresh_all=False,
+            dry=True,
+        )
+        handle_pipeline_events(result_iter)
+
+    def test_dry_run_failure(self):
+        dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
+        registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)
+
+        with graph_element_registration_context(registry):
+
+            @sdp.table
+            def st():
+                # Invalid because a streaming query is expected
+                return self.spark.range(1)
+
+        result_iter = start_run(
+            self.spark,
+            dataflow_graph_id,
+            full_refresh=None,
+            refresh=None,
+            full_refresh_all=False,
+            dry=True,
+        )
+        with self.assertRaises(AnalysisException) as context:
+            handle_pipeline_events(result_iter)
+        self.assertIn(
+            "INVALID_FLOW_QUERY_TYPE.BATCH_RELATION_FOR_STREAMING_TABLE", str(context.exception)
+        )
+
+
+if __name__ == "__main__":
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index e13877d05a60..9558696a3964 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9a\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_f [...]
+    b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xb9\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_f [...]
 )
 
 _globals = globals()
@@ -59,10 +59,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options = b"8\001"
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001"
-    _globals["_DATASETTYPE"]._serialized_start = 3194
-    _globals["_DATASETTYPE"]._serialized_end = 3291
+    _globals["_DATASETTYPE"]._serialized_start = 3225
+    _globals["_DATASETTYPE"]._serialized_end = 3322
     _globals["_PIPELINECOMMAND"]._serialized_start = 140
-    _globals["_PIPELINECOMMAND"]._serialized_end = 2726
+    _globals["_PIPELINECOMMAND"]._serialized_end = 2757
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110
     _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 928
@@ -80,15 +80,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 928
     _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986
     _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2260
-    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2508
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2511
-    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2710
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2729
-    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2999
-    _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2886
-    _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2984
-    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3001
-    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3074
-    _globals["_PIPELINEEVENT"]._serialized_start = 3076
-    _globals["_PIPELINEEVENT"]._serialized_end = 3192
+    _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2539
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2542
+    _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2741
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2760
+    _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 3030
+    _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2917
+    _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 3015
+    _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3032
+    _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3105
+    _globals["_PIPELINEEVENT"]._serialized_start = 3107
+    _globals["_PIPELINEEVENT"]._serialized_end = 3223
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index fff130ac3b93..6cf395a75cdb 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -533,6 +533,7 @@ class PipelineCommand(google.protobuf.message.Message):
         FULL_REFRESH_SELECTION_FIELD_NUMBER: builtins.int
         FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int
         REFRESH_SELECTION_FIELD_NUMBER: builtins.int
+        DRY_FIELD_NUMBER: builtins.int
         dataflow_graph_id: builtins.str
         """The graph to start."""
         @property
@@ -547,6 +548,10 @@ class PipelineCommand(google.protobuf.message.Message):
             self,
         ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
             """List of dataset to update."""
+        dry: builtins.bool
+        """If true, the run will not actually execute any flows, but will only validate the graph and
+        check for any errors. This is useful for testing and validation purposes.
+        """
         def __init__(
             self,
             *,
@@ -554,16 +559,21 @@ class PipelineCommand(google.protobuf.message.Message):
             full_refresh_selection: collections.abc.Iterable[builtins.str] | None = ...,
             full_refresh_all: builtins.bool | None = ...,
             refresh_selection: collections.abc.Iterable[builtins.str] | None = ...,
+            dry: builtins.bool | None = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
                 "_dataflow_graph_id",
                 b"_dataflow_graph_id",
+                "_dry",
+                b"_dry",
                 "_full_refresh_all",
                 b"_full_refresh_all",
                 "dataflow_graph_id",
                 b"dataflow_graph_id",
+                "dry",
+                b"dry",
                 "full_refresh_all",
                 b"full_refresh_all",
             ],
@@ -573,10 +583,14 @@ class PipelineCommand(google.protobuf.message.Message):
             field_name: typing_extensions.Literal[
                 "_dataflow_graph_id",
                 b"_dataflow_graph_id",
+                "_dry",
+                b"_dry",
                 "_full_refresh_all",
                 b"_full_refresh_all",
                 "dataflow_graph_id",
                 b"dataflow_graph_id",
+                "dry",
+                b"dry",
                 "full_refresh_all",
                 b"full_refresh_all",
                 "full_refresh_selection",
@@ -591,6 +605,10 @@ class PipelineCommand(google.protobuf.message.Message):
             oneof_group: typing_extensions.Literal["_dataflow_graph_id", b"_dataflow_graph_id"],
         ) -> typing_extensions.Literal["dataflow_graph_id"] | None: ...
         @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_dry", b"_dry"]
+        ) -> typing_extensions.Literal["dry"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_full_refresh_all", b"_full_refresh_all"]
         ) -> typing_extensions.Literal["full_refresh_all"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index 751b14abe478..18a1170ceebb 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -116,7 +116,7 @@ message PipelineCommand {
   message StartRun {
     // The graph to start.
     optional string dataflow_graph_id = 1;
-    
+
     // List of dataset to reset and recompute.
     repeated string full_refresh_selection = 2;
     
@@ -125,6 +125,10 @@ message PipelineCommand {
     
     // List of dataset to update.
     repeated string refresh_selection = 4;
+
+    // If true, the run will not actually execute any flows, but will only validate the graph and
+    // check for any errors. This is useful for testing and validation purposes.
+    optional bool dry = 5;
   }
 
   // Parses the SQL file and registers all datasets and flows.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7f92aa13944c..38d87278bbc8 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -288,7 +288,11 @@ private[connect] object PipelinesHandler extends Logging {
       tableFiltersResult.refresh,
       tableFiltersResult.fullRefresh)
     sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
-    pipelineUpdateContext.pipelineExecution.runPipeline()
+    if (cmd.getDry) {
+      pipelineUpdateContext.pipelineExecution.dryRunPipeline()
+    } else {
+      pipelineUpdateContext.pipelineExecution.runPipeline()
+    }
 
     // Rethrow any exceptions that caused the pipeline run to fail so that the exception is
     // propagated back to the SC client / CLI.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
index 100aa2e3b63a..83862545a723 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
@@ -70,14 +70,67 @@ class PipelineEventStreamSuite extends SparkDeclarativePipelinesServerTest {
     }
   }
 
-  test("check error events from stream") {
+  test("flow resolution failure") {
+    val dryOptions = Seq(true, false)
+
+    dryOptions.foreach { dry =>
+      withRawBlockingStub { implicit stub =>
+        val graphId = createDataflowGraph
+        val pipeline = new TestPipelineDefinition(graphId) {
+          createTable(
+            name = "a",
+            datasetType = proto.DatasetType.MATERIALIZED_VIEW,
+            sql = Some("SELECT * FROM unknown_table"))
+          createTable(
+            name = "b",
+            datasetType = proto.DatasetType.TABLE,
+            sql = Some("SELECT * FROM STREAM a"))
+        }
+        registerPipelineDatasets(pipeline)
+
+        val capturedEvents = new ArrayBuffer[PipelineEvent]()
+        withClient { client =>
+          val startRunRequest = buildStartRunPlan(
+            proto.PipelineCommand.StartRun
+              .newBuilder()
+              .setDataflowGraphId(graphId)
+              .setDry(dry)
+              .build())
+          val ex = intercept[AnalysisException] {
+            val responseIterator = client.execute(startRunRequest)
+            while (responseIterator.hasNext) {
+              val response = responseIterator.next()
+              if (response.hasPipelineEventResult) {
+                capturedEvents.append(response.getPipelineEventResult.getEvent)
+              }
+            }
+          }
+          // (?s) enables wildcard matching on newline characters
+          val runFailureErrorMsg = "(?s).*Failed to resolve flows in the pipeline.*".r
+          assert(runFailureErrorMsg.matches(ex.getMessage))
+          val expectedLogPatterns = Set(
+            "(?s).*Failed to resolve flow.*Failed to read dataset 'spark_catalog.default.a'.*".r,
+            "(?s).*Failed to resolve flow.*[TABLE_OR_VIEW_NOT_FOUND].*".r)
+          expectedLogPatterns.foreach { logPattern =>
+            assert(
+              capturedEvents.exists(e => logPattern.matches(e.getMessage)),
+              s"Did not receive expected event matching pattern: $logPattern")
+          }
+          // Ensure that the error causing the run failure is not surfaced to the user twice
+          assert(capturedEvents.forall(e => !runFailureErrorMsg.matches(e.getMessage)))
+        }
+      }
+    }
+  }
+
+  test("successful dry run") {
     withRawBlockingStub { implicit stub =>
       val graphId = createDataflowGraph
       val pipeline = new TestPipelineDefinition(graphId) {
         createTable(
           name = "a",
           datasetType = proto.DatasetType.MATERIALIZED_VIEW,
-          sql = Some("SELECT * FROM unknown_table"))
+          sql = Some("SELECT * FROM RANGE(5)"))
         createTable(
           name = "b",
           datasetType = proto.DatasetType.TABLE,
@@ -88,29 +141,29 @@ class PipelineEventStreamSuite extends SparkDeclarativePipelinesServerTest {
       val capturedEvents = new ArrayBuffer[PipelineEvent]()
       withClient { client =>
         val startRunRequest = buildStartRunPlan(
-          proto.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())
-        val ex = intercept[AnalysisException] {
-          val responseIterator = client.execute(startRunRequest)
-          while (responseIterator.hasNext) {
-            val response = responseIterator.next()
-            if (response.hasPipelineEventResult) {
-              capturedEvents.append(response.getPipelineEventResult.getEvent)
-            }
+          proto.PipelineCommand.StartRun
+            .newBuilder()
+            .setDataflowGraphId(graphId)
+            .setDry(true)
+            .build())
+        val responseIterator = client.execute(startRunRequest)
+        while (responseIterator.hasNext) {
+          val response = responseIterator.next()
+          if (response.hasPipelineEventResult) {
+            capturedEvents.append(response.getPipelineEventResult.getEvent)
           }
         }
-        // (?s) enables wildcard matching on newline characters
-        val runFailureErrorMsg = "(?s).*Failed to resolve flows in the pipeline.*".r
-        assert(runFailureErrorMsg.matches(ex.getMessage))
-        val expectedLogPatterns = Set(
-          "(?s).*Failed to resolve flow.*Failed to read dataset 'spark_catalog.default.a'.*".r,
-          "(?s).*Failed to resolve flow.*[TABLE_OR_VIEW_NOT_FOUND].*".r)
-        expectedLogPatterns.foreach { logPattern =>
+        val expectedEventMessages = Set("Run is COMPLETED")
+        expectedEventMessages.foreach { eventMessage =>
           assert(
-            capturedEvents.exists(e => logPattern.matches(e.getMessage)),
-            s"Did not receive expected event matching pattern: $logPattern")
+            capturedEvents.exists(e => e.getMessage.contains(eventMessage)),
+            s"Did not receive expected event: $eventMessage")
         }
-        // Ensure that the error causing the run failure is not surfaced to the user twice
-        assert(capturedEvents.forall(e => !runFailureErrorMsg.matches(e.getMessage)))
+      }
+
+      // No flows should be started in dry run mode
+      capturedEvents.foreach { event =>
+        assert(!event.getMessage.contains("is QUEUED"))
       }
     }
   }
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
index a2c54a908af1..5bb6e25eaf45 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.pipelines.common.RunState
 import org.apache.spark.sql.pipelines.logging.{
   ConstructPipelineEvent,
   EventLevel,
+  PipelineEvent,
   PipelineEventOrigin,
   RunProgress
 }
@@ -38,31 +39,19 @@ class PipelineExecution(context: PipelineUpdateContext) {
 
   def executionStarted: Boolean = synchronized { graphExecution.nonEmpty }
 
-
   /**
    * Starts the pipeline execution by initializing the graph and starting the 
graph execution
    * thread. This function does not block on the completion of the graph 
execution thread.
    */
   def startPipeline(): Unit = synchronized {
     // Initialize the graph.
-    val initializedGraph = initializeGraph()
+    val resolvedGraph = resolveGraph()
+    val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph, context)
 
     // Execute the graph.
     graphExecution = Option(
       new TriggeredGraphExecution(initializedGraph, context, onCompletion = terminationReason => {
-        context.eventCallback(
-          ConstructPipelineEvent(
-            origin = PipelineEventOrigin(
-              flowName = None,
-              datasetName = None,
-              sourceCodeLocation = None
-            ),
-            level = EventLevel.INFO,
-            message = terminationReason.message,
-            details = RunProgress(terminationReason.terminalState),
-            exception = terminationReason.cause
-          )
-        )
+        context.eventCallback(constructTerminationEvent(terminationReason))
       })
     )
     graphExecution.foreach(_.start())
@@ -91,15 +80,38 @@ class PipelineExecution(context: PipelineUpdateContext) {
     }
   }
 
-  private def initializeGraph(): DataflowGraph = {
-    val resolvedGraph = try {
+  /** Validates that the pipeline graph can be successfully resolved and validates it. */
+  def dryRunPipeline(): Unit = synchronized {
+    resolveGraph()
+    context.eventCallback(
+      constructTerminationEvent(RunCompletion())
+    )
+  }
+
+  private def constructTerminationEvent(
+      terminationReason: RunTerminationReason
+  ): PipelineEvent = {
+    ConstructPipelineEvent(
+      origin = PipelineEventOrigin(
+        flowName = None,
+        datasetName = None,
+        sourceCodeLocation = None
+      ),
+      level = EventLevel.INFO,
+      message = terminationReason.message,
+      details = RunProgress(terminationReason.terminalState),
+      exception = terminationReason.cause
+    )
+  }
+
+  private def resolveGraph(): DataflowGraph = {
+    try {
       context.unresolvedGraph.resolve().validate()
     } catch {
       case e: UnresolvedPipelineException =>
         handleInvalidPipeline(e)
         throw e
     }
-    DatasetManager.materializeDatasets(resolvedGraph, context)
   }
 
   /** Waits for the execution to complete. Only used in tests */

