This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e38a65180564 [SPARK-53845] SDP Sinks
e38a65180564 is described below
commit e38a651805649f1f922255deb883d5231a39cdf8
Author: Jacky Wang <[email protected]>
AuthorDate: Mon Oct 13 10:07:07 2025 -0700
[SPARK-53845] SDP Sinks
### What changes were proposed in this pull request?
Create the API and implementation to allow users to define a sink in an SDP
pipeline, as outlined in the SPIP. A sink is an external location that the pipeline
can write data to.
```python
dp.create_sink(
  "myParquetSink",
  format = "parquet",
  options = {"path": "${dir.getPath}"}
)

@dp.append_flow(
  target = "myParquetSink",
)
def mySinkFlow():
  return spark.readStream.table("src")
```
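On the server side, the diff below maps this Python call to a `SinkImpl` registered on the pipeline's `GraphRegistrationContext` (see `PipelinesHandler.scala` and `elements.scala`). A minimal sketch of that registration, assuming a `GraphRegistrationContext` named `graphElementRegistry` is already in scope and using a hypothetical output path:
```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.pipelines.graph.{GraphIdentifierManager, QueryOrigin, SinkImpl}

// Sinks must have single-part names; parse and validate the raw identifier first.
val sinkIdentifier = GraphIdentifierManager
  .parseAndValidateSinkIdentifier(rawSinkIdentifier = TableIdentifier("myParquetSink"))

// Register the sink definition so flows can target it by name.
graphElementRegistry.registerSink(
  SinkImpl(
    identifier = sinkIdentifier,
    format = "parquet",
    options = Map("path" -> "/tmp/my_parquet_sink"), // hypothetical path
    origin = QueryOrigin.empty
  )
)
```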
### Why are the changes needed?
New Feature
### Does this PR introduce _any_ user-facing change?
New API for the unreleased SDP
### How was this patch tested?
New and existing tests to ensure sinks work end-to-end
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52563 from JiaqiWang18/SPARK-53845-sdp-sinks.
Lead-authored-by: Jacky Wang <[email protected]>
Co-authored-by: Jacky Wang <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 8 +-
python/pyspark/pipelines/__init__.py | 2 +
python/pyspark/pipelines/api.py | 43 +++++++
python/pyspark/pipelines/output.py | 9 ++
.../spark_connect_graph_element_registry.py | 10 ++
.../pipelines/tests/test_graph_element_registry.py | 18 ++-
.../sql/connect/pipelines/PipelinesHandler.scala | 38 ++++--
.../connect/pipelines/PythonPipelineSuite.scala | 32 ++++-
.../service/SparkConnectSessionHolderSuite.scala | 2 +-
.../service/SparkConnectSessionManagerSuite.scala | 2 +-
.../graph/CoreDataflowNodeProcessor.scala | 1 +
.../spark/sql/pipelines/graph/DataflowGraph.scala | 15 ++-
.../pipelines/graph/DataflowGraphTransformer.scala | 56 ++++++---
.../spark/sql/pipelines/graph/FlowExecution.scala | 27 +++++
.../spark/sql/pipelines/graph/FlowPlanner.scala | 11 ++
.../pipelines/graph/GraphIdentifierManager.scala | 21 ++++
.../pipelines/graph/GraphRegistrationContext.scala | 46 ++++---
.../sql/pipelines/graph/QueryOriginType.scala | 2 +-
.../apache/spark/sql/pipelines/graph/State.scala | 13 +-
.../spark/sql/pipelines/graph/SystemMetadata.scala | 3 +-
.../spark/sql/pipelines/graph/elements.scala | 23 +++-
.../graph/ConnectValidPipelineSuite.scala | 20 ++++
.../pipelines/graph/MaterializeTablesSuite.scala | 2 +-
.../sql/pipelines/graph/SinkExecutionSuite.scala | 132 +++++++++++++++++++++
.../sql/pipelines/graph/SqlPipelineSuite.scala | 8 +-
.../sql/pipelines/graph/SystemMetadataSuite.scala | 64 +++++++++-
.../apache/spark/sql/pipelines/utils/APITest.scala | 55 +++++++++
.../utils/TestGraphRegistrationContext.scala | 66 ++++++++---
28 files changed, 648 insertions(+), 81 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index afe56f6db2fd..a53992a85187 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4783,14 +4783,14 @@
"A duplicate identifier was found for elements registered in the
pipeline's dataflow graph."
],
"subClass" : {
- "DATASET" : {
+ "FLOW" : {
"message" : [
- "Attempted to register a <datasetType1> with identifier
<datasetName>, but a <datasetType2> has already been registered with that
identifier. Please ensure all datasets created within this pipeline have unique
identifiers."
+ "Flow <flowName> was found in multiple datasets: <datasetNames>"
]
},
- "FLOW" : {
+ "OUTPUT" : {
"message" : [
- "Flow <flowName> was found in multiple datasets: <datasetNames>"
+ "Attempted to register a <outputType1> with identifier <outputName>,
but a <outputType2> has already been registered with that identifier. Please
ensure all outputs created within this pipeline have unique identifiers."
]
}
},
diff --git a/python/pyspark/pipelines/__init__.py
b/python/pyspark/pipelines/__init__.py
index 1bb1c79a2926..d93320e96376 100644
--- a/python/pyspark/pipelines/__init__.py
+++ b/python/pyspark/pipelines/__init__.py
@@ -20,6 +20,7 @@ from pyspark.pipelines.api import (
materialized_view,
table,
temporary_view,
+ create_sink,
)
__all__ = [
@@ -28,4 +29,5 @@ __all__ = [
"materialized_view",
"table",
"temporary_view",
+ "create_sink",
]
diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py
index f0aaf1a6ee41..b68cc30b43a7 100644
--- a/python/pyspark/pipelines/api.py
+++ b/python/pyspark/pipelines/api.py
@@ -27,6 +27,7 @@ from pyspark.pipelines.output import (
MaterializedView,
StreamingTable,
TemporaryView,
+ Sink,
)
from pyspark.sql.types import StructType
@@ -447,3 +448,45 @@ def create_streaming_table(
format=format,
)
get_active_graph_element_registry().register_output(table)
+
+
+def create_sink(
+ name: str,
+ format: str,
+ options: Optional[Dict[str, str]] = None,
+) -> None:
+ """
+ Creates a sink that can be targeted by streaming flows, providing a
generic destination \
+ for flows to send data external to the pipeline.
+
+ :param name: The name of the sink.
+ :param format: The format of the sink, e.g. "parquet".
+ :param options: A dict where the keys are the property names and the
values are the \
+ property values. These properties will be set on the sink.
+ """
+ if type(name) is not str:
+ raise PySparkTypeError(
+ errorClass="NOT_STR",
+ messageParameters={"arg_name": "name", "arg_type":
type(name).__name__},
+ )
+ if type(format) is not str:
+ raise PySparkTypeError(
+ errorClass="NOT_STR",
+ messageParameters={"arg_name": "format", "arg_type":
type(format).__name__},
+ )
+ if options is not None and not isinstance(options, dict):
+ raise PySparkTypeError(
+ errorClass="NOT_DICT",
+ messageParameters={
+ "arg_name": "options",
+ "arg_type": type(options).__name__,
+ },
+ )
+ sink = Sink(
+ name=name,
+ format=format,
+ options=options or {},
+ source_code_location=get_caller_source_code_location(stacklevel=1),
+ comment=None,
+ )
+ get_active_graph_element_registry().register_output(sink)
diff --git a/python/pyspark/pipelines/output.py
b/python/pyspark/pipelines/output.py
index 4be5f509635d..84e950f16174 100644
--- a/python/pyspark/pipelines/output.py
+++ b/python/pyspark/pipelines/output.py
@@ -74,3 +74,12 @@ class TemporaryView(Output):
referenced by flows within the dataflow graph, but are not visible outside
of the graph."""
pass
+
+
+@dataclass(frozen=True)
+class Sink(Output):
+ """Definition of an external sink in a pipeline dataflow graph. An
external sink's
+ contents are written to an external system rather than managed by the
pipeline itself."""
+
+ format: str
+ options: Mapping[str, str]
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index 74e8ac0da3bf..5c5ef9fc3040 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -24,6 +24,7 @@ from pyspark.pipelines.output import (
Output,
MaterializedView,
Table,
+ Sink,
StreamingTable,
TemporaryView,
)
@@ -46,6 +47,8 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
self._dataflow_graph_id = dataflow_graph_id
def register_output(self, output: Output) -> None:
+ table_details = None
+ sink_details = None
if isinstance(output, Table):
if isinstance(output.schema, str):
schema_string = output.schema
@@ -79,6 +82,12 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
elif isinstance(output, TemporaryView):
output_type = pb2.OutputType.TEMPORARY_VIEW
table_details = None
+ elif isinstance(output, Sink):
+ output_type = pb2.OutputType.SINK
+ sink_details = pb2.PipelineCommand.DefineOutput.SinkDetails(
+ options=output.options,
+ format=output.format,
+ )
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
@@ -90,6 +99,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
output_name=output.name,
output_type=output_type,
comment=output.comment,
+ sink_details=sink_details,
table_details=table_details,
source_code_location=source_code_location_to_proto(output.source_code_location),
)
diff --git a/python/pyspark/pipelines/tests/test_graph_element_registry.py
b/python/pyspark/pipelines/tests/test_graph_element_registry.py
index 8e24e28f1233..e3eafd192ca5 100644
--- a/python/pyspark/pipelines/tests/test_graph_element_registry.py
+++ b/python/pyspark/pipelines/tests/test_graph_element_registry.py
@@ -20,7 +20,9 @@ import unittest
from pyspark.errors import PySparkException
from pyspark.pipelines.graph_element_registry import
graph_element_registration_context
from pyspark import pipelines as dp
+from pyspark.pipelines.output import Sink
from pyspark.pipelines.tests.local_graph_element_registry import
LocalGraphElementRegistry
+from typing import cast
class GraphElementRegistryTest(unittest.TestCase):
@@ -46,7 +48,15 @@ class GraphElementRegistryTest(unittest.TestCase):
def flow2():
raise NotImplementedError()
- self.assertEqual(len(registry.outputs), 3)
+ dp.create_sink(
+ name="sink",
+ format="parquet",
+ options={
+ "key1": "value1",
+ },
+ )
+
+ self.assertEqual(len(registry.outputs), 4)
self.assertEqual(len(registry.flows), 4)
mv_obj = registry.outputs[0]
@@ -81,6 +91,12 @@ class GraphElementRegistryTest(unittest.TestCase):
self.assertEqual(st2_flow1_obj.target, "st2")
assert
mv_flow_obj.source_code_location.filename.endswith("test_graph_element_registry.py")
+ sink_obj = cast(Sink, registry.outputs[3])
+ self.assertEqual(sink_obj.name, "sink")
+ self.assertEqual(sink_obj.format, "parquet")
+ self.assertEqual(sink_obj.options["key1"], "value1")
+ assert
sink_obj.source_code_location.filename.endswith("test_graph_element_registry.py")
+
def test_definition_without_graph_element_registry(self):
for decorator in [dp.table, dp.temporary_view, dp.materialized_view]:
with self.assertRaises(PySparkException) as context:
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index c2e4fe3dea07..7e69e546893e 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -32,7 +32,7 @@ import
org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.pipelines.Language.Python
import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED}
-import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis,
GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables,
PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType,
SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView,
UnresolvedFlow}
+import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis,
GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables,
PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink,
SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter,
TemporaryView, UnresolvedFlow}
import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress}
import org.apache.spark.sql.types.StructType
@@ -235,6 +235,27 @@ private[connect] object PipelinesHandler extends Logging {
properties = Map.empty,
sqlText = None))
viewIdentifier
+ case proto.OutputType.SINK =>
+ val dataflowGraphId = output.getDataflowGraphId
+ val graphElementRegistry =
+
sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+ val identifier = GraphIdentifierManager
+ .parseTableIdentifier(name = output.getOutputName, spark =
sessionHolder.session)
+ val sinkDetails = output.getSinkDetails
+ graphElementRegistry.registerSink(
+ SinkImpl(
+ identifier = identifier,
+ format = sinkDetails.getFormat,
+ options = sinkDetails.getOptionsMap.asScala.toMap,
+ origin = QueryOrigin(
+ filePath = Option.when(output.getSourceCodeLocation.hasFileName)(
+ output.getSourceCodeLocation.getFileName),
+ line = Option.when(output.getSourceCodeLocation.hasLineNumber)(
+ output.getSourceCodeLocation.getLineNumber),
+ objectType = Option(QueryOriginType.Sink.toString),
+ objectName = Option(identifier.unquotedString),
+ language = Option(Python()))))
+ identifier
case _ =>
throw new IllegalArgumentException(s"Unknown output type:
${output.getOutputType}")
}
@@ -269,14 +290,17 @@ private[connect] object PipelinesHandler extends Logging {
.getViews()
.filter(_.isInstanceOf[TemporaryView])
.exists(_.identifier == rawDestinationIdentifier)
-
- // If the flow is created implicitly as part of defining a view, then we
do not
- // qualify the flow identifier and the flow destination. This is because
views are
- // not permitted to have multipart
- val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
+ val flowWritesToSink =
+ graphElementRegistry.getSinks
+ .filter(_.isInstanceOf[Sink])
+ .exists(_.identifier == rawDestinationIdentifier)
+ // If the flow is created implicitly as part of defining a view, or it writes to a sink,
+ // then we do not qualify the flow identifier and the flow destination. This is because
+ // views and sinks are not permitted to have multipart identifiers.
+ val isImplicitFlowForTempView = (isImplicitFlow && flowWritesToView)
val Seq(flowIdentifier, destinationIdentifier) =
Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
- if (isImplicitFlowForTempView) {
+ if (isImplicitFlowForTempView || flowWritesToSink) {
rawIdentifier
} else {
GraphIdentifierManager
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index c4ed554ef978..0b4e36c3cf58 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -289,6 +289,36 @@ class PythonPipelineSuite
graphIdentifier("supplement")))
}
+ test("external sink") {
+ val graph = buildGraph("""
+ |dp.create_sink(
+ | "myKafkaSink",
+ | format = "kafka",
+ | options = {"kafka.bootstrap.servers": "host1:port1,host2:port2"}
+ |)
+ |
+ |@dp.append_flow(
+ | target = "myKafkaSink"
+ |)
+ |def mySinkFlow():
+ | return spark.readStream.format("rate").load()
+ |""".stripMargin)
+
+ assert(graph.sinks.map(_.identifier) ==
Seq(TableIdentifier("myKafkaSink")))
+
+ // ensure format and options are properly set
+ graph.sinks.filter(_.identifier == TableIdentifier("myKafkaSink")).foreach
{ sink =>
+ assert(sink.format == "kafka")
+
assert(sink.options.get("kafka.bootstrap.servers").contains("host1:port1,host2:port2"))
+ }
+
+ // ensure the flow is properly linked to the sink
+ assert(
+ graph
+ .flowsTo(TableIdentifier("myKafkaSink"))
+ .map(_.identifier) == Seq(TableIdentifier("mySinkFlow")))
+ }
+
test("referencing internal datasets") {
val graph = buildGraph("""
|@dp.materialized_view
@@ -400,7 +430,7 @@ class PythonPipelineSuite
| return spark.range(1)
|""".stripMargin)
}
- assert(ex.getCondition == "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET")
+ assert(ex.getCondition == "PIPELINE_DUPLICATE_IDENTIFIERS.OUTPUT")
}
test("create datasets with fully/partially qualified names") {
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 36c7e9bb80ec..25302ef0a51f 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -433,7 +433,7 @@ class SparkConnectSessionHolderSuite extends
SharedSparkSession {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val graphId = "test_graph"
val pipelineUpdateContext = new PipelineUpdateContextImpl(
- new DataflowGraph(Seq(), Seq(), Seq()),
+ new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
(_: PipelineEvent) => None,
storageRoot = "test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index ee227168dc23..a3d851c1ce7b 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -159,7 +159,7 @@ class SparkConnectSessionManagerSuite extends
SharedSparkSession with BeforeAndA
val sessionHolder =
SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None)
val graphId = "test_graph"
val pipelineUpdateContext = new PipelineUpdateContextImpl(
- new DataflowGraph(Seq(), Seq(), Seq()),
+ new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
(_: PipelineEvent) => None,
storageRoot = "test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
index 6d12e2281874..b87c02d562cb 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
@@ -103,6 +103,7 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) {
s"${upstreamNodes.getClass}"
)
}
+ case sink: Sink => Seq(sink)
case _ =>
throw new IllegalArgumentException(s"Unsupported node type:
${node.getClass}")
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
index 2b6db5e5dd42..a1ae49b413a1 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala
@@ -27,15 +27,15 @@ import org.apache.spark.sql.types.StructType
/**
* DataflowGraph represents the core graph structure for Spark declarative
pipelines.
- * It manages the relationships between logical flows, tables, and views,
providing
+ * It manages the relationships between logical flows, tables, sinks, and
views, providing
* operations for graph traversal, validation, and transformation.
*/
-case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views:
Seq[View])
+case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], sinks:
Seq[Sink], views: Seq[View])
extends GraphOperations
with GraphValidations {
/** Map of [[Output]]s by their identifiers */
- lazy val output: Map[TableIdentifier, Output] = mapUnique(tables,
"output")(_.identifier)
+ lazy val output: Map[TableIdentifier, Output] = mapUnique(sinks ++ tables,
"output")(_.identifier)
/**
* [[Flow]]s in this graph that need to get planned and potentially executed
when
@@ -54,6 +54,10 @@ case class DataflowGraph(flows: Seq[Flow], tables:
Seq[Table], views: Seq[View])
lazy val table: Map[TableIdentifier, Table] =
mapUnique(tables, "table")(_.identifier)
+ /** Map of [[Sink]]s by their identifiers */
+ lazy val sink: Map[TableIdentifier, Sink] =
+ mapUnique(sinks, "sink")(_.identifier)
+
/** Map of [[Flow]]s by their identifier */
lazy val flow: Map[TableIdentifier, Flow] = {
// Better error message than using mapUnique.
@@ -130,7 +134,7 @@ case class DataflowGraph(flows: Seq[Flow], tables:
Seq[Table], views: Seq[View])
}
/**
- * Used to reanalyze the flow's DF for a given table. This is done by
finding all upstream
+ * Used to reanalyze the flow's DF for a given table or sink. This is done
by finding all upstream
* flows (until a table is reached) for the specified source and reanalyzing
all upstream
* flows.
*
@@ -153,7 +157,8 @@ case class DataflowGraph(flows: Seq[Flow], tables:
Seq[Table], views: Seq[View])
val subgraph = new DataflowGraph(
flows = upstreamFlows,
views = upstreamViews,
- tables = Seq(table(srcFlow.destinationIdentifier))
+ tables = table.get(srcFlow.destinationIdentifier).toSeq,
+ sinks = sink.get(srcFlow.destinationIdentifier).toSeq
)
subgraph.resolve().resolvedFlow(srcFlow.identifier)
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
index 23914a55f31e..4121591c46b2 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala
@@ -61,6 +61,8 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends
AutoCloseable {
private var flowsTo: Map[TableIdentifier, Seq[Flow]] = computeFlowsTo()
private var views: Seq[View] = graph.views
private var viewMap: Map[TableIdentifier, View] = computeViewMap()
+ private var sinks: Seq[Sink] = graph.sinks
+ private var sinkMap: Map[TableIdentifier, Sink] = computeSinkMap()
// Fail analysis nodes
// Failed flows are flows that are failed to resolve or its inputs are not
available or its
@@ -68,6 +70,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends
AutoCloseable {
private var failedFlows: Seq[ResolutionCompletedFlow] = Seq.empty
// We define a dataset is failed to resolve if it is a destination of a flow
that is unresolved.
private var failedTables: Seq[Table] = Seq.empty
+ private var failedSinks: Seq[Sink] = Seq.empty
private val parallelism = 10
@@ -92,6 +95,10 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends
AutoCloseable {
views.map(view => view.identifier -> view).toMap
}
+ private def computeSinkMap(): Map[TableIdentifier, Sink] = synchronized {
+ sinks.map(sink => sink.identifier -> sink).toMap
+ }
+
private def computeFlowsTo(): Map[TableIdentifier, Seq[Flow]] = synchronized
{
flows.groupBy(_.destinationIdentifier)
}
@@ -128,6 +135,7 @@ class DataflowGraphTransformer(graph: DataflowGraph)
extends AutoCloseable {
val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]()
val resolvedTables = new ConcurrentLinkedQueue[Table]()
val resolvedViews = new ConcurrentLinkedQueue[View]()
+ val resolvedSinks = new ConcurrentLinkedQueue[Sink]()
// Flow identifier to a list of transformed flows mapping to track
resolved flows
val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]()
val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier,
Boolean]()
@@ -238,22 +246,32 @@ class DataflowGraphTransformer(graph: DataflowGraph)
extends AutoCloseable {
resolvedFlows.addAll(
transformed.collect { case f: ResolvedFlow => f
}.asJava
)
- } else {
- if (viewMap.contains(flow.destinationIdentifier)) {
- resolvedViews.addAll {
- val transformed =
- transformer(
- viewMap(flow.destinationIdentifier),
- flowsTo(flow.destinationIdentifier)
- )
- transformed.map(_.asInstanceOf[View]).asJava
- }
- } else {
- throw new IllegalArgumentException(
- s"Unsupported destination
${flow.destinationIdentifier.unquotedString}" +
- s" in flow: ${flow.displayName} at
transformDownNodes"
+ } else if (viewMap.contains(flow.destinationIdentifier)) {
+ resolvedViews.addAll {
+ val transformed =
+ transformer(
+ viewMap(flow.destinationIdentifier),
+ flowsTo(flow.destinationIdentifier)
+ )
+ transformed.map(_.asInstanceOf[View]).asJava
+ }
+ } else if (sinkMap.contains(flow.destinationIdentifier)) {
+ resolvedSinks.addAll {
+ val transformed =
+ transformer(
+ sinkMap(flow.destinationIdentifier),
flowsTo(flow.destinationIdentifier)
+ )
+ require(
+ transformed.forall(_.isInstanceOf[Sink]),
+ "transformer must return a Seq[Sink]"
)
+ transformed.map(_.asInstanceOf[Sink]).asJava
}
+ } else {
+ throw new IllegalArgumentException(
+ s"Unsupported destination
${flow.destinationIdentifier.unquotedString}" +
+ s" in flow: ${flow.displayName} at transformDownNodes"
+ )
}
// Set flow destination as resolved now.
resolvedFlowDestinationsMap.computeIfPresent(
@@ -287,6 +305,11 @@ class DataflowGraphTransformer(graph: DataflowGraph)
extends AutoCloseable {
failedTables = tables.filterNot { table =>
resolvedFlowDestinationsMap.getOrDefault(table.identifier, false)
}
+ // A sink fails to analyze if:
+ // - It does not exist in the resolvedFlowDestinationsMap
+ failedSinks = sinks.filterNot { sink =>
+ resolvedFlowDestinationsMap.getOrDefault(sink.identifier, false)
+ }
// We maintain the topological sort order of successful flows always
val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) =
@@ -313,8 +336,10 @@ class DataflowGraphTransformer(graph: DataflowGraph)
extends AutoCloseable {
flowsTo = computeFlowsTo()
tables = resolvedTables.asScala.toSeq
views = resolvedViews.asScala.toSeq
+ sinks = resolvedSinks.asScala.toSeq
tableMap = computeTableMap()
viewMap = computeViewMap()
+ sinkMap = computeSinkMap()
this
}
@@ -326,7 +351,8 @@ class DataflowGraphTransformer(graph: DataflowGraph)
extends AutoCloseable {
// they will be front of the list in failedFlows and thus by definition
topologically sorted
// in the combined sequence too.
flows = flows ++ failedFlows,
- tables = tables ++ failedTables
+ tables = tables ++ failedTables,
+ sinks = sinks ++ failedSinks
)
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
index 30ee77a2315e..2c9029fdd34d 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala
@@ -274,3 +274,30 @@ class BatchTableWrite(
}
}
}
+
+/** A `StreamingFlowExecution` that writes a streaming `DataFrame` to a
`Sink`. */
+class SinkWrite(
+ val identifier: TableIdentifier,
+ val flow: ResolvedFlow,
+ val graph: DataflowGraph,
+ val updateContext: PipelineUpdateContext,
+ val checkpointPath: String,
+ val trigger: Trigger,
+ val destination: Sink,
+ val sqlConf: Map[String, String]
+) extends StreamingFlowExecution {
+
+ override def getOrigin: QueryOrigin = flow.origin
+
+ def startStream(): StreamingQuery = {
+ val data = graph.reanalyzeFlow(flow).df
+ data.writeStream
+ .queryName(displayName)
+ .option("checkpointLocation", checkpointPath)
+ .trigger(trigger)
+ .outputMode(OutputMode.Append())
+ .format(destination.format)
+ .options(destination.options)
+ .start()
+ }
+}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
index c4790e872951..3f3a9e2aaac2 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala
@@ -63,6 +63,17 @@ class FlowPlanner(
trigger = triggerFor(sf),
checkpointPath = flowMetadata.latestCheckpointLocation
)
+ case s: Sink =>
+ new SinkWrite(
+ graph = graph,
+ flow = flow,
+ identifier = sf.identifier,
+ destination = s,
+ updateContext = updateContext,
+ sqlConf = sf.sqlConf,
+ trigger = triggerFor(sf),
+ checkpointPath = flowMetadata.latestCheckpointLocation
+ )
case _ =>
throw new UnsupportedOperationException(
s"Streaming flow ${sf.identifier} cannot write to non-table
destination: " +
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
index 414d9d0effea..572c7a6158c3 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphIdentifierManager.scala
@@ -154,6 +154,27 @@ object GraphIdentifierManager {
internalDatasetIdentifier.identifier
}
+ /**
+ * Parses and validates the sink identifier from the raw sink identifier.
+ *
+ * @param rawSinkIdentifier the raw sink identifier
+ * @return the parsed sink identifier
+ */
+ @throws[AnalysisException]
+ def parseAndValidateSinkIdentifier(rawSinkIdentifier: TableIdentifier):
TableIdentifier = {
+ val internalDatasetIdentifier = parseAndValidatePipelineDatasetIdentifier(
+ rawDatasetIdentifier = rawSinkIdentifier
+ )
+ // Sinks are not persisted to the catalog in use, therefore should not be
qualified.
+ if (!isSinglePartIdentifier(internalDatasetIdentifier.identifier)) {
+ throw new AnalysisException(
+ "MULTIPART_SINK_NAME_NOT_SUPPORTED",
+ Map("viewName" -> rawSinkIdentifier.unquotedString)
+ )
+ }
+ internalDatasetIdentifier.identifier
+ }
+
/**
* Parses and validates the view identifier from the raw view identifier for
persisted views.
*
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
index b4f8315cc3fd..26432fd2960f 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
@@ -35,6 +35,7 @@ class GraphRegistrationContext(
protected val tables = new mutable.ListBuffer[Table]
protected val views = new mutable.ListBuffer[View]
+ protected val sinks = new mutable.ListBuffer[Sink]
protected val flows = new mutable.ListBuffer[UnresolvedFlow]
def registerTable(tableDef: Table): Unit = {
@@ -45,10 +46,16 @@ class GraphRegistrationContext(
views += viewDef
}
+ def registerSink(sinkDef: Sink): Unit = {
+ sinks += sinkDef
+ }
+
def getViews(): Seq[View] = {
return views.toSeq
}
+ def getSinks: Seq[Sink] = sinks.toSeq
+
def registerFlow(flowDef: UnresolvedFlow): Unit = {
flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf)
}
@@ -56,7 +63,7 @@ class GraphRegistrationContext(
def toDataflowGraph: DataflowGraph = {
if (tables.isEmpty && views.collect { case v: PersistedView =>
v
- }.isEmpty) {
+ }.isEmpty && sinks.isEmpty) {
throw new AnalysisException(
errorClass = "RUN_EMPTY_PIPELINE",
messageParameters = Map.empty)
@@ -65,12 +72,14 @@ class GraphRegistrationContext(
assertNoDuplicates(
qualifiedTables = tables.toSeq,
validatedViews = views.toSeq,
- qualifiedFlows = flows.toSeq
+ qualifiedFlows = flows.toSeq,
+ validatedSinks = sinks.toSeq
)
new DataflowGraph(
tables = tables.toSeq,
views = views.toSeq,
+ sinks = sinks.toSeq,
flows = flows.toSeq
)
}
@@ -78,13 +87,15 @@ class GraphRegistrationContext(
private def assertNoDuplicates(
qualifiedTables: Seq[Table],
validatedViews: Seq[View],
+ validatedSinks: Seq[Sink],
qualifiedFlows: Seq[UnresolvedFlow]): Unit = {
(qualifiedTables.map(_.identifier) ++ validatedViews.map(_.identifier))
.foreach { identifier =>
- assertDatasetIdentifierIsUnique(
+ assertOutputIdentifierIsUnique(
identifier = identifier,
tables = qualifiedTables,
+ sinks = validatedSinks,
views = validatedViews
)
}
@@ -92,34 +103,34 @@ class GraphRegistrationContext(
qualifiedFlows.foreach { flow =>
assertFlowIdentifierIsUnique(
flow = flow,
- datasetType = TableType,
flows = qualifiedFlows
)
}
}
- private def assertDatasetIdentifierIsUnique(
+ private def assertOutputIdentifierIsUnique(
identifier: TableIdentifier,
tables: Seq[Table],
+ sinks: Seq[Sink],
views: Seq[View]): Unit = {
// We need to check for duplicates in both tables and views, as they can
have the same name.
- val allDatasets = tables.map(t => t.identifier -> TableType) ++ views.map(
+ val allOutputs = tables.map(t => t.identifier -> TableType) ++ views.map(
v => v.identifier -> ViewType
- )
+ ) ++ sinks.map(s => s.identifier -> SinkType)
- val grouped = allDatasets.groupBy { case (id, _) => id }
+ val grouped = allOutputs.groupBy { case (id, _) => id }
grouped(identifier).toList match {
case (_, firstType) :: (_, secondType) :: _ =>
// Sort the types in lexicographic order to ensure consistent error
messages.
val sortedTypes = Seq(firstType.toString, secondType.toString).sorted
throw new AnalysisException(
- errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET",
+ errorClass = "PIPELINE_DUPLICATE_IDENTIFIERS.OUTPUT",
messageParameters = Map(
- "datasetName" -> identifier.quotedString,
- "datasetType1" -> sortedTypes.head,
- "datasetType2" -> sortedTypes.last
+ "outputName" -> identifier.quotedString,
+ "outputType1" -> sortedTypes.head,
+ "outputType2" -> sortedTypes.last
)
)
case _ => // No duplicates found.
@@ -128,7 +139,6 @@ class GraphRegistrationContext(
private def assertFlowIdentifierIsUnique(
flow: UnresolvedFlow,
- datasetType: DatasetType,
flows: Seq[UnresolvedFlow]): Unit = {
flows.groupBy(i => i.identifier).get(flow.identifier).filter(_.size >
1).foreach {
duplicateFlows =>
@@ -148,13 +158,17 @@ class GraphRegistrationContext(
}
object GraphRegistrationContext {
- sealed trait DatasetType
+ sealed trait OutputType
- private object TableType extends DatasetType {
+ private object TableType extends OutputType {
override def toString: String = "TABLE"
}
- private object ViewType extends DatasetType {
+ private object ViewType extends OutputType {
override def toString: String = "VIEW"
}
+
+ private object SinkType extends OutputType {
+ override def toString: String = "SINK"
+ }
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
index 007002c2ad45..51d4b4a50038 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
@@ -19,5 +19,5 @@ package org.apache.spark.sql.pipelines.graph
object QueryOriginType extends Enumeration {
type QueryOriginType = Value
- val Flow, Table, View = Value
+ val Flow, Table, View, Sink = Value
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
index 31fc065cadc8..efe5849d1cbd 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/State.scala
@@ -52,7 +52,18 @@ object State extends Logging {
}
}
- specifiedTablesToReset.flatMap(t => t +:
graph.resolvedFlowsTo(t.identifier))
+ val specifiedSinksToReset = {
+ env.fullRefreshTables match {
+ case SomeTables(tablesAndSinks) =>
+ graph.sinks.filter(sink => tablesAndSinks.contains(sink.identifier))
+ case AllTables =>
+ graph.sinks
+ case NoTables => Seq.empty
+ }
+ }
+
+ specifiedTablesToReset.flatMap(t => t +:
graph.resolvedFlowsTo(t.identifier)) ++
+ specifiedSinksToReset.flatMap(s => graph.resolvedFlowsTo(s.identifier))
}
/**
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
index d805f4a689ec..a9db28c33124 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SystemMetadata.scala
@@ -41,7 +41,8 @@ case class FlowSystemMetadata(
* @return the checkpoint root directory for `flow`
*/
private def flowCheckpointsDirOpt(): Option[Path] = {
- Option(if (graph.table.contains(flow.destinationIdentifier)) {
+ Option(if (graph.table.contains(flow.destinationIdentifier) ||
+ graph.sink.contains(flow.destinationIdentifier)) {
val checkpointRoot = new Path(context.storageRoot, "_checkpoints")
val flowTableName = flow.destinationIdentifier.table
val flowName = flow.identifier.table
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
index 95a57dcc4495..e2c3fdf7994e 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -78,8 +78,12 @@ trait Input extends GraphElement {
* Represents a node in a [[DataflowGraph]] that can be written to by a
[[Flow]].
* Must be backed by a file source.
*/
-sealed trait Output {
+sealed trait Output {}
+/**
+ * A type of [[Output]] that represents a materialized dataset in a
[[DataflowGraph]].
+ */
+sealed trait Dataset extends Output {
/**
* Normalized storage location used for storing materializations for this
[[Output]].
* If None, it means this [[Output]] has not been normalized yet.
@@ -127,7 +131,7 @@ case class Table(
isStreamingTable: Boolean,
format: Option[String]
) extends TableInput
- with Output {
+ with Dataset {
// Load this table's data from underlying storage.
override def load(readOptions: InputReadOptions): DataFrame = {
@@ -248,3 +252,18 @@ case class PersistedView(
comment: Option[String],
origin: QueryOrigin
) extends View {}
+
+trait Sink extends GraphElement with Output {
+ /** format of the sink */
+ val format: String
+
+ /** options defined for the sink */
+ val options: Map[String, String]
+}
+
+case class SinkImpl(
+ identifier: TableIdentifier,
+ format: String,
+ options: Map[String, String],
+ origin: QueryOrigin
+ ) extends Sink {}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
index 2c0e2a728c69..3ac3c0901750 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
@@ -489,6 +489,26 @@ class ConnectValidPipelineSuite extends PipelineTest with
SharedSparkSession {
)
}
+ test("external sink") {
+ val session = spark
+ import session.implicits._
+
+ val P = new TestGraphRegistrationContext(spark) {
+ val mem = MemoryStream[Int]
+ mem.addData(1, 2)
+ registerTemporaryView("a", query = dfFlowFunc(mem.toDF().select($"value"
as "x")))
+ registerSink("sink_a", format = "memory")
+ registerFlow("sink_a", "sink_flow", query = readStreamFlowFunc("a"))
+ }
+ val g = P.resolveToDataflowGraph()
+ g.validate()
+ assert(g.resolved)
+ assert(g.sink(TableIdentifier("sink_a")).isInstanceOf[Sink])
+ val sink = g.sink(TableIdentifier("sink_a"))
+ assert(sink.format == "memory")
+ assert(g.flow(TableIdentifier("sink_flow")).isInstanceOf[StreamingFlow])
+ }
+
/** Verifies the [[DataflowGraph]] has the specified [[Flow]] with the
specified schema. */
private def verifyFlowSchema(
pipeline: DataflowGraph,
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index afe769f7b204..17bddcd446b1 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -859,7 +859,7 @@ abstract class MaterializeTablesSuite extends
BaseCoreExecutionTest {
import session.implicits._
val graph1 =
- new DataflowGraph(flows = Seq.empty, tables = Seq.empty, views =
Seq.empty)
+ new DataflowGraph(flows = Seq.empty, tables = Seq.empty, views =
Seq.empty, sinks = Seq.empty)
val graph2 = new TestGraphRegistrationContext(spark) {
registerFlow(
"a",
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
new file mode 100644
index 000000000000..958ef5a80fd5
--- /dev/null
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.classic.DataFrame
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
StreamingQueryWrapper}
+import org.apache.spark.sql.pipelines.utils.{ExecutionTest,
TestGraphRegistrationContext}
+import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SinkExecutionSuite extends ExecutionTest with SharedSparkSession {
+ def createDataflowGraph(
+ inputs: DataFrame,
+ sinkName: String,
+ flowName: String,
+ format: String,
+ sinkOptions: Map[String, String] = Map.empty
+ ): DataflowGraph = {
+ val registrationContext = new TestGraphRegistrationContext(spark) {
+ registerTemporaryView("a", query = dfFlowFunc(inputs))
+ registerSink(sinkName, format, sinkOptions)
+ registerFlow(sinkName, flowName, query = readStreamFlowFunc("a"))
+ }
+ registrationContext.toDataflowGraph
+ }
+
+ test("writing to external sink - memory sink") {
+ val session = spark
+ import session.implicits._
+
+ val ints = MemoryStream[Int]
+ ints.addData(1, 2, 3, 4)
+
+ val unresolvedGraph =
+ createDataflowGraph(ints.toDF(), "sink_a", "flow_to_sink_a", "memory")
+ val updateContext = TestPipelineUpdateContext(
+ spark,
+ unresolvedGraph,
+ storageRoot,
+ failOnErrorEvent = true
+ )
+ updateContext.pipelineExecution.startPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+
+ verifyCheckpointLocation(
+ storageRoot,
+ updateContext.pipelineExecution.graphExecution.get,
+ TableIdentifier("sink_a"),
+ TableIdentifier("flow_to_sink_a")
+ )
+
+ checkAnswer(spark.sql("SELECT * FROM flow_to_sink_a"), Seq(1, 2, 3,
4).toDF().collect().toSeq)
+ }
+
+ test("writing to external sink - parquet sink with path") {
+ val session = spark
+ import session.implicits._
+
+ withTempDir { externalDeltaPath =>
+ val ints = MemoryStream[Int]
+ ints.addData(1, 2, 3, 4)
+ val unresolvedGraph = createDataflowGraph(
+ ints.toDF(),
+ "parquet_sink",
+ "flow_to_parquet_sink",
+ "parquet",
+ Map(
+ "path" -> externalDeltaPath.getPath
+ )
+ )
+
+ val updateContext = TestPipelineUpdateContext(
+ spark,
+ unresolvedGraph,
+ storageRoot
+ )
+ updateContext.pipelineExecution.startPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+
+ verifyCheckpointLocation(
+ storageRoot,
+ updateContext.pipelineExecution.graphExecution.get,
+ TableIdentifier("parquet_sink"),
+ TableIdentifier("flow_to_parquet_sink")
+ )
+
+ checkAnswer(
+ spark.read.format("parquet").load(externalDeltaPath.getPath),
+ Seq(1, 2, 3, 4).toDF().collect().toSeq
+ )
+ }
+ }
+
+ def verifyCheckpointLocation(
+ rootDirectory: String,
+ graphExecution: GraphExecution,
+ sinkIdentifier: TableIdentifier,
+ flowIdentifier: TableIdentifier): Unit = {
+ val expectedCheckpointLocation = new Path(
+ "file://" + rootDirectory +
s"/_checkpoints/${sinkIdentifier.table}/${flowIdentifier.table}/0"
+ )
+ val streamingQuery = graphExecution
+ .flowExecutions(flowIdentifier)
+ .asInstanceOf[StreamingFlowExecution]
+ .getStreamingQuery
+
+ val actualCheckpointLocation = new Path(getCheckpointPath(streamingQuery))
+
+ assert(actualCheckpointLocation == expectedCheckpointLocation)
+ }
+
+ private def getCheckpointPath(q: StreamingQuery): String =
+ q.asInstanceOf[StreamingQueryWrapper].streamingQuery.resolvedCheckpointRoot
+}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
index e921a6bfe2ab..950df9167926 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
@@ -109,12 +109,12 @@ class SqlPipelineSuite extends PipelineTest with
SharedSparkSession {
exception = intercept[AnalysisException] {
graphRegistrationContext.toDataflowGraph
},
- condition = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET",
+ condition = "PIPELINE_DUPLICATE_IDENTIFIERS.OUTPUT",
sqlState = Option("42710"),
parameters = Map(
- "datasetName" -> fullyQualifiedIdentifier("table").quotedString,
- "datasetType1" -> "TABLE",
- "datasetType2" -> "VIEW"
+ "outputName" -> fullyQualifiedIdentifier("table").quotedString,
+ "outputType1" -> "TABLE",
+ "outputType2" -> "VIEW"
)
)
}
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
index 15dfa6576171..a8db049b2b68 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.pipelines.graph
import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
StreamingQueryWrapper}
import org.apache.spark.sql.pipelines.utils.{ExecutionTest,
TestGraphRegistrationContext}
import org.apache.spark.sql.test.SharedSparkSession
@@ -162,6 +163,66 @@ class SystemMetadataSuite
updateContext2
)
}
+
+ test("flow checkpoint for sink wrote to the expected location for full
refresh") {
+ val session = spark
+ import session.implicits._
+
+ val graph = new TestGraphRegistrationContext(spark) {
+ val mem: MemoryStream[Int] = MemoryStream[Int]
+ mem.addData(1, 2, 3)
+ registerView("a", query = dfFlowFunc(mem.toDF()))
+ registerSink("sink", format = "console")
+ registerFlow("sink", "sink", query = readStreamFlowFunc("a"))
+ }.toDataflowGraph
+ val sinkIdentifier = TableIdentifier("sink")
+
+ val updateContext1 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot
+ )
+
+ updateContext1.pipelineExecution.startPipeline()
+ updateContext1.pipelineExecution.awaitCompletion()
+ val graphExecution1 = updateContext1.pipelineExecution.graphExecution.get
+ val executionGraph1 = graphExecution1.graphForExecution
+
+ // assert that the checkpoint dir for the sink is created as expected
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph1.sink(sinkIdentifier),
+ flowElement = executionGraph1.flow(sinkIdentifier),
+ // the default checkpoint version is 0
+ expectedCheckpointVersion = 0,
+ graphExecution = graphExecution1,
+ updateContext = updateContext1
+ )
+
+ // start another update in full refresh, expecting a new checkpoint dir to be created
+ // with the version number incremented to 1
+ val updateContext2 = TestPipelineUpdateContext(
+ unresolvedGraph = graph,
+ spark = spark,
+ storageRoot = storageRoot,
+ fullRefreshTables = AllTables
+ )
+
+ updateContext2.pipelineExecution.startPipeline()
+ updateContext2.pipelineExecution.awaitCompletion()
+ val graphExecution2 = updateContext2.pipelineExecution.graphExecution.get
+ val executionGraph2 = graphExecution2.graphForExecution
+
+ // due to full refresh, assert that new checkpoint dir is created with
version number
+ // incremented to 1
+ assertFlowCheckpointDirExists(
+ tableOrSinkElement = executionGraph2.sink(sinkIdentifier),
+ flowElement = executionGraph2.flow(sinkIdentifier),
+ // new checkpoint directory is created with version number incremented
to 1
+ expectedCheckpointVersion = 1,
+ graphExecution = graphExecution2,
+ updateContext2
+ )
+ }
}
trait SystemMetadataTestHelpers {
@@ -176,7 +237,8 @@ trait SystemMetadataTestHelpers {
updateContext: PipelineUpdateContext
): Path = {
val expectedRawCheckPointDir = tableOrSinkElement match {
- case t: Table => new Path(updateContext.storageRoot)
+ case t if t.isInstanceOf[Table] || t.isInstanceOf[Sink] =>
+ new Path(updateContext.storageRoot)
.suffix(s"/_checkpoints/${t.identifier.table}/${flowElement.identifier.table}")
.toString
case _ =>
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
index efba7aba0a41..bb7c8e833f84 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/APITest.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.QueryTest.checkAnswer
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.util.Utils
/**
* Representation of a pipeline specification
@@ -454,6 +455,60 @@ trait APITest
checkAnswer(spark.sql(s"SELECT * FROM mv_1"), Seq(Row(0), Row(1), Row(2),
Row(3), Row(4)))
}
+ Seq("parquet", "json").foreach { format =>
+ test(s"Python Pipeline with $format sink") {
+ val session = spark
+ import session.implicits._
+
+ // create source data
+ spark.sql(s"CREATE TABLE src AS SELECT * FROM RANGE(5)")
+
+ val dir = Utils.createTempDir()
+ try {
+ val pipelineSpec =
+ TestPipelineSpec(include = Seq("transformations/definition.py"))
+ val pipelineConfig = TestPipelineConfiguration(pipelineSpec)
+
+ val sources = Seq(
+ PipelineSourceFile(
+ name = "transformations/definition.py",
+ contents =
+ s"""
+ |from pyspark import pipelines as dp
+ |from pyspark.sql import DataFrame, SparkSession
+ |
+ |spark = SparkSession.active()
+ |
+ |dp.create_sink(
+ | "mySink",
+ | format = "$format",
+ | options = {"path": "${dir.getPath}"}
+ |)
+ |
+ |@dp.append_flow(
+ | target = "mySink",
+ |)
+ |def mySinkFlow():
+ | return spark.readStream.table("src")
+ |""".stripMargin
+ )
+ )
+
+ val pipeline = createAndRunPipeline(pipelineConfig, sources)
+ awaitPipelineTermination(pipeline)
+
+ // verify sink output
+ checkAnswer(
+ spark.read.format(format).load(dir.getPath),
+ Seq(0, 1, 2, 3, 4).toDF().collect().toSeq
+ )
+ } finally {
+ // clean up temp directory
+ Utils.deleteRecursively(dir)
+ }
+ }
+ }
+
test("Python Pipeline with partition columns") {
val pipelineSpec =
TestPipelineSpec(include = Seq("transformations/**"))
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index 4a33dd2c61a8..599aab87d1f7 100644
---
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -20,20 +20,7 @@ package org.apache.spark.sql.pipelines.utils
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{LocalTempView, PersistedView =>
PersistedViewType, UnresolvedRelation, ViewType}
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
-import org.apache.spark.sql.pipelines.graph.{
- DataflowGraph,
- FlowAnalysis,
- FlowFunction,
- GraphIdentifierManager,
- GraphRegistrationContext,
- PersistedView,
- QueryContext,
- QueryOrigin,
- QueryOriginType,
- Table,
- TemporaryView,
- UnresolvedFlow
-}
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, FlowAnalysis,
FlowFunction, GraphIdentifierManager, GraphRegistrationContext, PersistedView,
QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, Table,
TemporaryView, UnresolvedFlow}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -80,6 +67,26 @@ class TestGraphRegistrationContext(
)
// scalastyle:on
+ def registerTemporaryView(
+ name: String,
+ query: FlowFunction,
+ sqlConf: Map[String, String] = Map.empty,
+ comment: Option[String] = None,
+ origin: QueryOrigin = QueryOrigin.empty,
+ catalog: Option[String] = None,
+ database: Option[String] = None): Unit = {
+ registerView(
+ name = name,
+ query = query,
+ sqlConf = sqlConf,
+ comment = comment,
+ origin = origin,
+ viewType = LocalTempView,
+ catalog = catalog,
+ database = database
+ )
+ }
+
// scalastyle:off
// Disable scalastyle to ignore argument count.
/** Registers a materialized view in this [[TestGraphRegistrationContext]] */
@@ -264,6 +271,25 @@ class TestGraphRegistrationContext(
)
}
+ def registerSink(
+ name: String,
+ format: String,
+ options: Map[String, String] = Map.empty,
+ origin: QueryOrigin = QueryOrigin.empty
+ ): Unit = {
+ val sinkIdentifier = GraphIdentifierManager
+ .parseAndValidateSinkIdentifier(rawSinkIdentifier =
TableIdentifier(name))
+
+ registerSink(
+ SinkImpl(
+ identifier = sinkIdentifier,
+ format = format,
+ origin = origin,
+ options = options
+ )
+ )
+ }
+
def registerFlow(
destinationName: String,
name: String,
@@ -279,15 +305,17 @@ class TestGraphRegistrationContext(
val flowWritesToView = getViews()
.filter(_.isInstanceOf[TemporaryView])
.exists(_.identifier == rawDestinationIdentifier)
-
- // If the flow is created implicitly as part of defining a view, then we
do not
- // qualify the flow identifier and the flow destination. This is because
views are
- // not permitted to have multipart
+ val flowWritesToSink = getSinks
+ .filter(_.isInstanceOf[Sink])
+ .exists(_.identifier == rawDestinationIdentifier)
+ // If the flow is created implicitly as part of defining a view, or it writes to a sink,
+ // then we do not qualify the flow identifier and the flow destination. This is because
+ // views and sinks are not permitted to have multipart identifiers.
val isImplicitFlow = rawFlowIdentifier == rawDestinationIdentifier
val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
val Seq(flowIdentifier, flowDestinationIdentifier) =
Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
- if (isImplicitFlowForTempView) {
+ if (isImplicitFlowForTempView || flowWritesToSink) {
rawIdentifier
} else {
GraphIdentifierManager