This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 93748ccc5979 [SPARK-52511][SDP] Support dry-run mode in spark-pipelines command
93748ccc5979 is described below
commit 93748ccc59795c36278a9962d1851c44e106ef05
Author: Sandy Ryza <[email protected]>
AuthorDate: Thu Jul 17 18:37:39 2025 -0700
[SPARK-52511][SDP] Support dry-run mode in spark-pipelines command
### What changes were proposed in this pull request?
Adds a new `spark-pipelines dry-run` command that launches an execution of a
pipeline that doesn't write or read any data, but catches many kinds of errors
that would be caught if the pipeline were to actually run (see the usage sketch
after the list below), e.g.:
- Syntax errors – e.g. invalid Python or SQL code
- Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
- Graph validation errors - e.g. cyclic dependencies
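
A minimal usage sketch (illustrative only, not part of this change's diff):
running `spark-pipelines dry-run --spec pipeline.yml` resolves the spec and
calls the CLI's `run` helper with `dry=True`. The spec path and a reachable
Spark Connect server are assumptions made for the sketch.

```python
# Hedged sketch of the dry-run code path exercised by the CLI.
# Assumptions: a local `pipeline.yml` spec file and a reachable Spark Connect server.
from pathlib import Path

from pyspark.pipelines.cli import run

run(
    spec_path=Path("pipeline.yml"),  # assumed spec file location
    full_refresh=[],
    full_refresh_all=False,
    refresh=[],
    dry=True,  # validate the graph only; no data is read or written
)
```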
### Why are the changes needed?
Leverages the declarative nature of Declarative Pipelines to make pipeline
development easier.
### Does this PR introduce _any_ user-facing change?
Yes: it adds the `spark-pipelines dry-run` subcommand; existing behavior is unchanged.
### How was this patch tested?
- Added unit tests
- Executed `dry-run` on the CLI for both success and error scenarios (see the programmatic sketch below)
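
For reference, a hedged sketch of the programmatic path that the new
`test_spark_connect.py` test (in the diff below) exercises, assuming `spark` is
an active Spark Connect `SparkSession`:

```python
# Sketch mirroring the added test; `spark` is an assumed Spark Connect session.
from pyspark import pipelines as sdp
from pyspark.pipelines.graph_element_registry import graph_element_registration_context
from pyspark.pipelines.spark_connect_graph_element_registry import (
    SparkConnectGraphElementRegistry,
)
from pyspark.pipelines.spark_connect_pipeline import (
    create_dataflow_graph,
    handle_pipeline_events,
    start_run,
)

graph_id = create_dataflow_graph(spark, None, None, None)
registry = SparkConnectGraphElementRegistry(spark, graph_id)

# Register a trivial dataset so the graph has something to resolve.
with graph_element_registration_context(registry):

    @sdp.materialized_view
    def mv():
        return spark.range(1)

# dry=True resolves and validates the graph without executing any flows.
events = start_run(
    spark,
    graph_id,
    full_refresh=None,
    full_refresh_all=False,
    refresh=None,
    dry=True,
)
handle_pipeline_events(events)  # raises AnalysisException on validation errors
```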
### Was this patch authored or co-authored using generative AI tooling?
Closes #51489 from sryza/dry-run.
Lead-authored-by: Sandy Ryza <[email protected]>
Co-authored-by: Sandy Ryza <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
docs/declarative-pipelines-programming-guide.md | 9 +-
python/pyspark/pipelines/cli.py | 36 ++++++--
python/pyspark/pipelines/spark_connect_pipeline.py | 8 +-
python/pyspark/pipelines/tests/test_cli.py | 3 +
.../pyspark/pipelines/tests/test_spark_connect.py | 97 ++++++++++++++++++++++
python/pyspark/sql/connect/proto/pipelines_pb2.py | 30 +++----
python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 18 ++++
.../main/protobuf/spark/connect/pipelines.proto | 6 +-
.../sql/connect/pipelines/PipelinesHandler.scala | 6 +-
.../pipelines/PipelineEventStreamSuite.scala | 95 ++++++++++++++++-----
.../sql/pipelines/graph/PipelineExecution.scala | 48 +++++++----
11 files changed, 288 insertions(+), 68 deletions(-)
diff --git a/docs/declarative-pipelines-programming-guide.md b/docs/declarative-pipelines-programming-guide.md
index 929cd07e5daa..5f938f38e1e4 100644
--- a/docs/declarative-pipelines-programming-guide.md
+++ b/docs/declarative-pipelines-programming-guide.md
@@ -94,7 +94,7 @@ The `spark-pipelines init` command, described below, makes it easy to generate a
## The `spark-pipelines` Command Line Interface
-The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project.
+The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project and a `dry-run` subcommand for validating a pipeline.
`spark-pipelines` is built on top of `spark-submit`, meaning that it supports all cluster managers supported by `spark-submit`. It supports all `spark-submit` arguments except for `--class`.
@@ -106,6 +106,13 @@ The `spark-pipelines` command line interface (CLI) is the primary way to execute
`spark-pipelines run` launches an execution of a pipeline and monitors its progress until it completes. The `--spec` parameter allows selecting the pipeline spec file. If not provided, the CLI will look in the current directory and parent directories for a file named `pipeline.yml` or `pipeline.yaml`.
+### `spark-pipelines dry-run`
+
+`spark-pipelines dry-run` launches an execution of a pipeline that doesn't write or read any data, but catches many kinds of errors that would be caught if the pipeline were to actually run. E.g.
+- Syntax errors – e.g. invalid Python or SQL code
+- Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
+- Graph validation errors - e.g. cyclic dependencies
+
## Programming with SDP in Python
SDP Python functions are defined in the `pyspark.pipelines` module. Your pipelines implemented with the Python API must import this module. It's common to alias the module to `sdp` to limit the number of characters you need to type when using its APIs.
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 2a0cf880d10c..cbcac35cf1b3 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -222,6 +222,7 @@ def run(
full_refresh: Sequence[str],
full_refresh_all: bool,
refresh: Sequence[str],
+ dry: bool,
) -> None:
"""Run the pipeline defined with the given spec.
@@ -276,6 +277,7 @@ def run(
full_refresh=full_refresh,
full_refresh_all=full_refresh_all,
refresh=refresh,
+ dry=dry,
)
try:
handle_pipeline_events(result_iter)
@@ -317,6 +319,13 @@ if __name__ == "__main__":
default=[],
)
+ # "dry-run" subcommand
+ dry_run_parser = subparsers.add_parser(
+ "dry-run",
+ help="Launch a run that just validates the graph and checks for
errors.",
+ )
+ dry_run_parser.add_argument("--spec", help="Path to the pipeline spec.")
+
# "init" subcommand
init_parser = subparsers.add_parser(
"init",
@@ -330,9 +339,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
- assert args.command in ["run", "init"]
+ assert args.command in ["run", "dry-run", "init"]
- if args.command == "run":
+ if args.command in ["run", "dry-run"]:
if args.spec is not None:
spec_path = Path(args.spec)
if not spec_path.is_file():
@@ -343,11 +352,22 @@ if __name__ == "__main__":
else:
spec_path = find_pipeline_spec(Path.cwd())
- run(
- spec_path=spec_path,
- full_refresh=args.full_refresh,
- full_refresh_all=args.full_refresh_all,
- refresh=args.refresh,
- )
+ if args.command == "run":
+ run(
+ spec_path=spec_path,
+ full_refresh=args.full_refresh,
+ full_refresh_all=args.full_refresh_all,
+ refresh=args.refresh,
+ dry=args.command == "dry-run",
+ )
+ else:
+ assert args.command == "dry-run"
+ run(
+ spec_path=spec_path,
+ full_refresh=[],
+ full_refresh_all=False,
+ refresh=[],
+ dry=True,
+ )
elif args.command == "init":
init(args.name)
diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py
index f430d33be4a1..61b72956e5cc 100644
--- a/python/pyspark/pipelines/spark_connect_pipeline.py
+++ b/python/pyspark/pipelines/spark_connect_pipeline.py
@@ -68,9 +68,10 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
def start_run(
spark: SparkSession,
dataflow_graph_id: str,
- full_refresh: Optional[Sequence[str]] = None,
- full_refresh_all: bool = False,
- refresh: Optional[Sequence[str]] = None,
+ full_refresh: Optional[Sequence[str]],
+ full_refresh_all: bool,
+ refresh: Optional[Sequence[str]],
+ dry: bool,
) -> Iterator[Dict[str, Any]]:
"""Start a run of the dataflow graph in the Spark Connect server.
@@ -84,6 +85,7 @@ def start_run(
full_refresh_selection=full_refresh or [],
full_refresh_all=full_refresh_all,
refresh_selection=refresh or [],
+ dry=dry,
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py
index ded00e691db4..ff62ce42c4a3 100644
--- a/python/pyspark/pipelines/tests/test_cli.py
+++ b/python/pyspark/pipelines/tests/test_cli.py
@@ -373,6 +373,7 @@ class CLIUtilityTests(unittest.TestCase):
full_refresh=["table1", "table2"],
full_refresh_all=True,
refresh=[],
+ dry=False,
)
self.assertEqual(
@@ -396,6 +397,7 @@ class CLIUtilityTests(unittest.TestCase):
full_refresh=[],
full_refresh_all=True,
refresh=["table1", "table2"],
+ dry=False,
)
self.assertEqual(
@@ -421,6 +423,7 @@ class CLIUtilityTests(unittest.TestCase):
full_refresh=["table1"],
full_refresh_all=True,
refresh=["table2"],
+ dry=False,
)
self.assertEqual(
diff --git a/python/pyspark/pipelines/tests/test_spark_connect.py b/python/pyspark/pipelines/tests/test_spark_connect.py
new file mode 100644
index 000000000000..935295ec4a8c
--- /dev/null
+++ b/python/pyspark/pipelines/tests/test_spark_connect.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests that run Pipelines against a Spark Connect server.
+"""
+
+import unittest
+
+from pyspark.errors.exceptions.connect import AnalysisException
+from pyspark.pipelines.graph_element_registry import graph_element_registration_context
+from pyspark.pipelines.spark_connect_graph_element_registry import (
+ SparkConnectGraphElementRegistry,
+)
+from pyspark.pipelines.spark_connect_pipeline import (
+ create_dataflow_graph,
+ start_run,
+ handle_pipeline_events,
+)
+from pyspark import pipelines as sdp
+from pyspark.testing.connectutils import (
+ ReusedConnectTestCase,
+ should_test_connect,
+ connect_requirement_message,
+)
+
+
[email protected](not should_test_connect, connect_requirement_message)
+class SparkConnectPipelinesTest(ReusedConnectTestCase):
+ def test_dry_run(self):
+ dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
+ registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)
+
+ with graph_element_registration_context(registry):
+
+ @sdp.materialized_view
+ def mv():
+ return self.spark.range(1)
+
+ result_iter = start_run(
+ self.spark,
+ dataflow_graph_id,
+ full_refresh=None,
+ refresh=None,
+ full_refresh_all=False,
+ dry=True,
+ )
+ handle_pipeline_events(result_iter)
+
+ def test_dry_run_failure(self):
+ dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
+ registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)
+
+ with graph_element_registration_context(registry):
+
+ @sdp.table
+ def st():
+ # Invalid because a streaming query is expected
+ return self.spark.range(1)
+
+ result_iter = start_run(
+ self.spark,
+ dataflow_graph_id,
+ full_refresh=None,
+ refresh=None,
+ full_refresh_all=False,
+ dry=True,
+ )
+ with self.assertRaises(AnalysisException) as context:
+ handle_pipeline_events(result_iter)
+ self.assertIn(
+ "INVALID_FLOW_QUERY_TYPE.BATCH_RELATION_FOR_STREAMING_TABLE",
str(context.exception)
+ )
+
+
+if __name__ == "__main__":
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index e13877d05a60..9558696a3964 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9a\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_f [...]
+ b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xb9\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_f [...]
)
_globals = globals()
@@ -59,10 +59,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options
= b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options =
b"8\001"
- _globals["_DATASETTYPE"]._serialized_start = 3194
- _globals["_DATASETTYPE"]._serialized_end = 3291
+ _globals["_DATASETTYPE"]._serialized_start = 3225
+ _globals["_DATASETTYPE"]._serialized_end = 3322
_globals["_PIPELINECOMMAND"]._serialized_start = 140
- _globals["_PIPELINECOMMAND"]._serialized_end = 2726
+ _globals["_PIPELINECOMMAND"]._serialized_end = 2757
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 928
@@ -80,15 +80,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start =
928
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986
_globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2260
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2508
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
2511
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2710
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2729
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2999
- _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2886
- _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2984
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3001
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3074
- _globals["_PIPELINEEVENT"]._serialized_start = 3076
- _globals["_PIPELINEEVENT"]._serialized_end = 3192
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2539
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
2542
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2741
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2760
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 3030
+ _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2917
+ _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 3015
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3032
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3105
+ _globals["_PIPELINEEVENT"]._serialized_start = 3107
+ _globals["_PIPELINEEVENT"]._serialized_end = 3223
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index fff130ac3b93..6cf395a75cdb 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -533,6 +533,7 @@ class PipelineCommand(google.protobuf.message.Message):
FULL_REFRESH_SELECTION_FIELD_NUMBER: builtins.int
FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int
REFRESH_SELECTION_FIELD_NUMBER: builtins.int
+ DRY_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to start."""
@property
@@ -547,6 +548,10 @@ class PipelineCommand(google.protobuf.message.Message):
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""List of dataset to update."""
+ dry: builtins.bool
+ """If true, the run will not actually execute any flows, but will only
validate the graph and
+ check for any errors. This is useful for testing and validation
purposes.
+ """
def __init__(
self,
*,
@@ -554,16 +559,21 @@ class PipelineCommand(google.protobuf.message.Message):
full_refresh_selection: collections.abc.Iterable[builtins.str] | None = ...,
full_refresh_all: builtins.bool | None = ...,
refresh_selection: collections.abc.Iterable[builtins.str] | None = ...,
+ dry: builtins.bool | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_dataflow_graph_id",
b"_dataflow_graph_id",
+ "_dry",
+ b"_dry",
"_full_refresh_all",
b"_full_refresh_all",
"dataflow_graph_id",
b"dataflow_graph_id",
+ "dry",
+ b"dry",
"full_refresh_all",
b"full_refresh_all",
],
@@ -573,10 +583,14 @@ class PipelineCommand(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"_dataflow_graph_id",
b"_dataflow_graph_id",
+ "_dry",
+ b"_dry",
"_full_refresh_all",
b"_full_refresh_all",
"dataflow_graph_id",
b"dataflow_graph_id",
+ "dry",
+ b"dry",
"full_refresh_all",
b"full_refresh_all",
"full_refresh_selection",
@@ -591,6 +605,10 @@ class PipelineCommand(google.protobuf.message.Message):
oneof_group: typing_extensions.Literal["_dataflow_graph_id",
b"_dataflow_graph_id"],
) -> typing_extensions.Literal["dataflow_graph_id"] | None: ...
@typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_dry", b"_dry"]
+ ) -> typing_extensions.Literal["dry"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_full_refresh_all", b"_full_refresh_all"]
) -> typing_extensions.Literal["full_refresh_all"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index 751b14abe478..18a1170ceebb 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -116,7 +116,7 @@ message PipelineCommand {
message StartRun {
// The graph to start.
optional string dataflow_graph_id = 1;
-
+
// List of dataset to reset and recompute.
repeated string full_refresh_selection = 2;
@@ -125,6 +125,10 @@ message PipelineCommand {
// List of dataset to update.
repeated string refresh_selection = 4;
+
+ // If true, the run will not actually execute any flows, but will only validate the graph and
+ // check for any errors. This is useful for testing and validation purposes.
+ optional bool dry = 5;
}
// Parses the SQL file and registers all datasets and flows.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7f92aa13944c..38d87278bbc8 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -288,7 +288,11 @@ private[connect] object PipelinesHandler extends Logging {
tableFiltersResult.refresh,
tableFiltersResult.fullRefresh)
sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext)
- pipelineUpdateContext.pipelineExecution.runPipeline()
+ if (cmd.getDry) {
+ pipelineUpdateContext.pipelineExecution.dryRunPipeline()
+ } else {
+ pipelineUpdateContext.pipelineExecution.runPipeline()
+ }
// Rethrow any exceptions that caused the pipeline run to fail so that the exception is
// propagated back to the SC client / CLI.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
index 100aa2e3b63a..83862545a723 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineEventStreamSuite.scala
@@ -70,14 +70,67 @@ class PipelineEventStreamSuite extends SparkDeclarativePipelinesServerTest {
}
}
- test("check error events from stream") {
+ test("flow resolution failure") {
+ val dryOptions = Seq(true, false)
+
+ dryOptions.foreach { dry =>
+ withRawBlockingStub { implicit stub =>
+ val graphId = createDataflowGraph
+ val pipeline = new TestPipelineDefinition(graphId) {
+ createTable(
+ name = "a",
+ datasetType = proto.DatasetType.MATERIALIZED_VIEW,
+ sql = Some("SELECT * FROM unknown_table"))
+ createTable(
+ name = "b",
+ datasetType = proto.DatasetType.TABLE,
+ sql = Some("SELECT * FROM STREAM a"))
+ }
+ registerPipelineDatasets(pipeline)
+
+ val capturedEvents = new ArrayBuffer[PipelineEvent]()
+ withClient { client =>
+ val startRunRequest = buildStartRunPlan(
+ proto.PipelineCommand.StartRun
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDry(dry)
+ .build())
+ val ex = intercept[AnalysisException] {
+ val responseIterator = client.execute(startRunRequest)
+ while (responseIterator.hasNext) {
+ val response = responseIterator.next()
+ if (response.hasPipelineEventResult) {
+ capturedEvents.append(response.getPipelineEventResult.getEvent)
+ }
+ }
+ }
+ // (?s) enables wildcard matching on newline characters
+ val runFailureErrorMsg = "(?s).*Failed to resolve flows in the pipeline.*".r
+ assert(runFailureErrorMsg.matches(ex.getMessage))
+ val expectedLogPatterns = Set(
+ "(?s).*Failed to resolve flow.*Failed to read dataset
'spark_catalog.default.a'.*".r,
+ "(?s).*Failed to resolve flow.*[TABLE_OR_VIEW_NOT_FOUND].*".r)
+ expectedLogPatterns.foreach { logPattern =>
+ assert(
+ capturedEvents.exists(e => logPattern.matches(e.getMessage)),
+ s"Did not receive expected event matching pattern: $logPattern")
+ }
+ // Ensure that the error causing the run failure is not surfaced to the user twice
+ assert(capturedEvents.forall(e => !runFailureErrorMsg.matches(e.getMessage)))
+ }
+ }
+ }
+ }
+
+ test("successful dry run") {
withRawBlockingStub { implicit stub =>
val graphId = createDataflowGraph
val pipeline = new TestPipelineDefinition(graphId) {
createTable(
name = "a",
datasetType = proto.DatasetType.MATERIALIZED_VIEW,
- sql = Some("SELECT * FROM unknown_table"))
+ sql = Some("SELECT * FROM RANGE(5)"))
createTable(
name = "b",
datasetType = proto.DatasetType.TABLE,
@@ -88,29 +141,29 @@ class PipelineEventStreamSuite extends SparkDeclarativePipelinesServerTest {
val capturedEvents = new ArrayBuffer[PipelineEvent]()
withClient { client =>
val startRunRequest = buildStartRunPlan(
- proto.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())
- val ex = intercept[AnalysisException] {
- val responseIterator = client.execute(startRunRequest)
- while (responseIterator.hasNext) {
- val response = responseIterator.next()
- if (response.hasPipelineEventResult) {
- capturedEvents.append(response.getPipelineEventResult.getEvent)
- }
+ proto.PipelineCommand.StartRun
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setDry(true)
+ .build())
+ val responseIterator = client.execute(startRunRequest)
+ while (responseIterator.hasNext) {
+ val response = responseIterator.next()
+ if (response.hasPipelineEventResult) {
+ capturedEvents.append(response.getPipelineEventResult.getEvent)
}
}
- // (?s) enables wildcard matching on newline characters
- val runFailureErrorMsg = "(?s).*Failed to resolve flows in the pipeline.*".r
- assert(runFailureErrorMsg.matches(ex.getMessage))
- val expectedLogPatterns = Set(
- "(?s).*Failed to resolve flow.*Failed to read dataset
'spark_catalog.default.a'.*".r,
- "(?s).*Failed to resolve flow.*[TABLE_OR_VIEW_NOT_FOUND].*".r)
- expectedLogPatterns.foreach { logPattern =>
+ val expectedEventMessages = Set("Run is COMPLETED")
+ expectedEventMessages.foreach { eventMessage =>
assert(
- capturedEvents.exists(e => logPattern.matches(e.getMessage)),
- s"Did not receive expected event matching pattern: $logPattern")
+ capturedEvents.exists(e => e.getMessage.contains(eventMessage)),
+ s"Did not receive expected event: $eventMessage")
}
- // Ensure that the error causing the run failure is not surfaced to the user twice
- assert(capturedEvents.forall(e => !runFailureErrorMsg.matches(e.getMessage)))
+ }
+
+ // No flows should be started in dry run mode
+ capturedEvents.foreach { event =>
+ assert(!event.getMessage.contains("is QUEUED"))
}
}
}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
index a2c54a908af1..5bb6e25eaf45 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.pipelines.common.RunState
import org.apache.spark.sql.pipelines.logging.{
ConstructPipelineEvent,
EventLevel,
+ PipelineEvent,
PipelineEventOrigin,
RunProgress
}
@@ -38,31 +39,19 @@ class PipelineExecution(context: PipelineUpdateContext) {
def executionStarted: Boolean = synchronized { graphExecution.nonEmpty }
-
/**
* Starts the pipeline execution by initializing the graph and starting the graph execution
* thread. This function does not block on the completion of the graph execution thread.
*/
def startPipeline(): Unit = synchronized {
// Initialize the graph.
- val initializedGraph = initializeGraph()
+ val resolvedGraph = resolveGraph()
+ val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph, context)
// Execute the graph.
graphExecution = Option(
new TriggeredGraphExecution(initializedGraph, context, onCompletion = terminationReason => {
- context.eventCallback(
- ConstructPipelineEvent(
- origin = PipelineEventOrigin(
- flowName = None,
- datasetName = None,
- sourceCodeLocation = None
- ),
- level = EventLevel.INFO,
- message = terminationReason.message,
- details = RunProgress(terminationReason.terminalState),
- exception = terminationReason.cause
- )
- )
+ context.eventCallback(constructTerminationEvent(terminationReason))
})
)
graphExecution.foreach(_.start())
@@ -91,15 +80,38 @@ class PipelineExecution(context: PipelineUpdateContext) {
}
}
- private def initializeGraph(): DataflowGraph = {
- val resolvedGraph = try {
+ /** Validates that the pipeline graph can be successfully resolved and validates it. */
+ def dryRunPipeline(): Unit = synchronized {
+ resolveGraph()
+ context.eventCallback(
+ constructTerminationEvent(RunCompletion())
+ )
+ }
+
+ private def constructTerminationEvent(
+ terminationReason: RunTerminationReason
+ ): PipelineEvent = {
+ ConstructPipelineEvent(
+ origin = PipelineEventOrigin(
+ flowName = None,
+ datasetName = None,
+ sourceCodeLocation = None
+ ),
+ level = EventLevel.INFO,
+ message = terminationReason.message,
+ details = RunProgress(terminationReason.terminalState),
+ exception = terminationReason.cause
+ )
+ }
+
+ private def resolveGraph(): DataflowGraph = {
+ try {
context.unresolvedGraph.resolve().validate()
} catch {
case e: UnresolvedPipelineException =>
handleInvalidPipeline(e)
throw e
}
- DatasetManager.materializeDatasets(resolvedGraph, context)
}
/** Waits for the execution to complete. Only used in tests */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]