This is an automated email from the ASF dual-hosted git repository.
sandy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 65ff85a31fe8 [SPARK-52640][SDP] Propagate Python Source Code Location
65ff85a31fe8 is described below
commit 65ff85a31fe8a8ea4a2ba713ba2c624709ce815a
Author: anishm-db <[email protected]>
AuthorDate: Thu Oct 2 06:38:47 2025 -0700
[SPARK-52640][SDP] Propagate Python Source Code Location
### What changes were proposed in this pull request?
Propagate source code location details (line number and file path) end-to-end for
declarative pipelines. That is, collect this information from the Python REPL
that registers SDP datasets/flows, propagate it through the appropriate Spark
Connect handlers, and associate it with the appropriate datasets/flows in
pipeline events/exceptions.
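For context, a minimal sketch of the Python-side capture using the standard `inspect`
module (the names mirror the new `source_code_location.py` module but are illustrative,
not necessarily the exact pyspark implementation):

    import inspect
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class SourceCodeLocation:
        filename: Optional[str]
        line_number: Optional[int]

    def get_caller_source_code_location(stacklevel: int) -> SourceCodeLocation:
        # stacklevel=0 -> the direct caller of this function, 1 -> that caller's caller, etc.
        frame_info = inspect.stack()[stacklevel + 1]
        return SourceCodeLocation(filename=frame_info.filename, line_number=frame_info.lineno)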
### Why are the changes needed?
Better observability and a smoother debugging experience for users. This allows users to
identify the exact lines that caused a particular exception.
### Does this PR introduce _any_ user-facing change?
Yes, we are populating source code information in the origin for pipeline
events, which is user-facing. However, SDP is not yet released in any Spark version.
### How was this patch tested?
Added tests to `org.apache.spark.sql.connect.pipelines.PythonPipelineSuite`
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51344 from AnishMahto/sdp-python-query-origins.
Authored-by: anishm-db <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
---
python/pyspark/pipelines/source_code_location.py | 28 +++++
.../spark_connect_graph_element_registry.py | 11 ++
python/pyspark/sql/connect/proto/pipelines_pb2.py | 68 +++++------
python/pyspark/sql/connect/proto/pipelines_pb2.pyi | 94 +++++++++++++++
.../main/protobuf/spark/connect/pipelines.proto | 14 +++
.../sql/connect/pipelines/PipelinesHandler.scala | 14 ++-
.../connect/pipelines/PythonPipelineSuite.scala | 131 ++++++++++++++++++++-
.../graph/CoreDataflowNodeProcessor.scala | 6 +-
.../apache/spark/sql/pipelines/graph/Flow.scala | 5 +-
.../spark/sql/pipelines/graph/FlowAnalysis.scala | 3 +-
.../graph/SqlGraphRegistrationContext.scala | 6 +-
.../spark/sql/pipelines/graph/elements.scala | 7 +-
.../utils/TestGraphRegistrationContext.scala | 25 +++-
13 files changed, 359 insertions(+), 53 deletions(-)
diff --git a/python/pyspark/pipelines/source_code_location.py b/python/pyspark/pipelines/source_code_location.py
index 5f23b819abd8..cbf4cbe514a6 100644
--- a/python/pyspark/pipelines/source_code_location.py
+++ b/python/pyspark/pipelines/source_code_location.py
@@ -30,6 +30,34 @@ def get_caller_source_code_location(stacklevel: int) -> SourceCodeLocation:
     """
     Returns a SourceCodeLocation object representing the location code that invokes this function.
+    If this function is called from a decorator (ex. @sdp.table), note that the returned line
+    number is affected by how the decorator was triggered - i.e. whether @sdp.table or @sdp.table()
+    was called - AND what python version is being used
+
+    Case 1:
+    |@sdp.table()
+    |def fn
+
+    @sdp.table() is executed immediately, on line 1. This is true for all python versions.
+
+    Case 2:
+    |@sdp.table
+    |def fn
+
+    In python < 3.10, @sdp.table will expand to fn = sdp.table(fn), replacing the line that `fn` is
+    defined on. This would be line 2. More interestingly, this means:
+
+    |@sdp.table
+    |
+    |
+    |def fn
+
+    Will expand to fn = sdp.table(fn) on line 4, where `fn` is defined.
+
+    However, in python 3.10+, the line number in the stack trace will still be the line that the
+    decorator was defined on. In other words, case 2 will be treated the same as case 1, and the
+    line number will be 1.
+
     :param stacklevel: The number of stack frames to go up. 0 means the direct caller of this
         function, 1 means the caller of the caller, and so on.
     """
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index 020c7989138d..8faf7eb9ef58 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -29,6 +29,7 @@ from pyspark.pipelines.dataset import (
)
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry
+from pyspark.pipelines.source_code_location import SourceCodeLocation
from typing import Any, cast
import pyspark.sql.connect.proto as pb2
@@ -79,6 +80,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
partition_cols=partition_cols,
schema=schema,
format=format,
+ source_code_location=source_code_location_to_proto(dataset.source_code_location),
)
command = pb2.Command()
command.pipeline_command.define_dataset.CopyFrom(inner_command)
@@ -95,6 +97,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
target_dataset_name=flow.target,
relation=relation,
sql_conf=flow.spark_conf,
+ source_code_location=source_code_location_to_proto(flow.source_code_location),
)
command = pb2.Command()
command.pipeline_command.define_flow.CopyFrom(inner_command)
@@ -109,3 +112,11 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
command = pb2.Command()
command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
self._client.execute_command(command)
+
+
+def source_code_location_to_proto(
+ source_code_location: SourceCodeLocation,
+) -> pb2.SourceCodeLocation:
+ return pb2.SourceCodeLocation(
+ file_name=source_code_location.filename, line_number=source_code_location.line_number
+ )
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index 8f4ba32baccb..849d141f9c49 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -41,7 +41,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xe0\x19\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
\x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
+
b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xc6\x1b\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01
\x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02
\x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineD [...]
)
_globals = globals()
@@ -60,10 +60,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options
= b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options =
b"8\001"
- _globals["_DATASETTYPE"]._serialized_start = 4489
- _globals["_DATASETTYPE"]._serialized_end = 4586
+ _globals["_DATASETTYPE"]._serialized_start = 4843
+ _globals["_DATASETTYPE"]._serialized_end = 4940
_globals["_PIPELINECOMMAND"]._serialized_start = 168
- _globals["_PIPELINECOMMAND"]._serialized_end = 3464
+ _globals["_PIPELINECOMMAND"]._serialized_end = 3694
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1050
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1358
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start
= 1259
@@ -71,35 +71,37 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1360
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1450
_globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_start = 1453
- _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 2046
-
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start
= 1890
-
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end
= 1956
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2049
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2579
+ _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 2161
+
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start
= 1980
+
_globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end
= 2046
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2164
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2809
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start =
1259
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1317
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2434
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2492
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2582
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2861
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
2864
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3063
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 3066
-
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 3224
-
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
3227
- _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 3448
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3467
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4223
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 3839
-
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
3937
- _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start =
3940
- _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end =
4074
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
4077
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4208
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4225
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4298
- _globals["_PIPELINEEVENT"]._serialized_start = 4300
- _globals["_PIPELINEEVENT"]._serialized_end = 4416
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4418
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4487
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2639
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2697
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2812
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3091
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start =
3094
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3293
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start
= 3296
+
_globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end
= 3454
+
_globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start =
3457
+ _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end
= 3678
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3697
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4453
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start
= 4069
+
_globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end =
4167
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start =
4170
+ _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end =
4304
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start =
4307
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4438
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4455
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4528
+ _globals["_PIPELINEEVENT"]._serialized_start = 4530
+ _globals["_PIPELINEEVENT"]._serialized_end = 4646
+ _globals["_SOURCECODELOCATION"]._serialized_start = 4648
+ _globals["_SOURCECODELOCATION"]._serialized_end = 4770
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4772
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4841
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index 097614de8a1f..b5ed1c216a83 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -233,6 +233,7 @@ class PipelineCommand(google.protobuf.message.Message):
PARTITION_COLS_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
FORMAT_FIELD_NUMBER: builtins.int
+ SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to attach this dataset to."""
dataset_name: builtins.str
@@ -260,6 +261,9 @@ class PipelineCommand(google.protobuf.message.Message):
"""The output table format of the dataset. Only applies to
dataset_type == TABLE and
dataset_type == MATERIALIZED_VIEW.
"""
+ @property
+ def source_code_location(self) -> global___SourceCodeLocation:
+ """The location in source code that this dataset was defined."""
def __init__(
self,
*,
@@ -271,6 +275,7 @@ class PipelineCommand(google.protobuf.message.Message):
partition_cols: collections.abc.Iterable[builtins.str] | None = ...,
schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
format: builtins.str | None = ...,
+ source_code_location: global___SourceCodeLocation | None = ...,
) -> None: ...
def HasField(
self,
@@ -287,6 +292,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_format",
"_schema",
b"_schema",
+ "_source_code_location",
+ b"_source_code_location",
"comment",
b"comment",
"dataflow_graph_id",
@@ -299,6 +306,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"format",
"schema",
b"schema",
+ "source_code_location",
+ b"source_code_location",
],
) -> builtins.bool: ...
def ClearField(
@@ -316,6 +325,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_format",
"_schema",
b"_schema",
+ "_source_code_location",
+ b"_source_code_location",
"comment",
b"comment",
"dataflow_graph_id",
@@ -330,6 +341,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"partition_cols",
"schema",
b"schema",
+ "source_code_location",
+ b"source_code_location",
"table_properties",
b"table_properties",
],
@@ -359,6 +372,13 @@ class PipelineCommand(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
) -> typing_extensions.Literal["schema"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal[
+ "_source_code_location", b"_source_code_location"
+ ],
+ ) -> typing_extensions.Literal["source_code_location"] | None: ...
class DefineFlow(google.protobuf.message.Message):
"""Request to define a flow targeting a dataset."""
@@ -415,6 +435,7 @@ class PipelineCommand(google.protobuf.message.Message):
RELATION_FIELD_NUMBER: builtins.int
SQL_CONF_FIELD_NUMBER: builtins.int
CLIENT_ID_FIELD_NUMBER: builtins.int
+ SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to attach this flow to."""
flow_name: builtins.str
@@ -435,6 +456,9 @@ class PipelineCommand(google.protobuf.message.Message):
"""Identifier for the client making the request. The server uses this
to determine what flow
evaluation request stream to dispatch evaluation requests to for this
flow.
"""
+ @property
+ def source_code_location(self) -> global___SourceCodeLocation:
+ """The location in source code that this flow was defined."""
def __init__(
self,
*,
@@ -444,6 +468,7 @@ class PipelineCommand(google.protobuf.message.Message):
relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
sql_conf: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
client_id: builtins.str | None = ...,
+ source_code_location: global___SourceCodeLocation | None = ...,
) -> None: ...
def HasField(
self,
@@ -456,6 +481,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_flow_name",
"_relation",
b"_relation",
+ "_source_code_location",
+ b"_source_code_location",
"_target_dataset_name",
b"_target_dataset_name",
"client_id",
@@ -466,6 +493,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"flow_name",
"relation",
b"relation",
+ "source_code_location",
+ b"source_code_location",
"target_dataset_name",
b"target_dataset_name",
],
@@ -481,6 +510,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_flow_name",
"_relation",
b"_relation",
+ "_source_code_location",
+ b"_source_code_location",
"_target_dataset_name",
b"_target_dataset_name",
"client_id",
@@ -491,6 +522,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"flow_name",
"relation",
b"relation",
+ "source_code_location",
+ b"source_code_location",
"sql_conf",
b"sql_conf",
"target_dataset_name",
@@ -515,6 +548,13 @@ class PipelineCommand(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_relation", b"_relation"]
) -> typing_extensions.Literal["relation"] | None: ...
@typing.overload
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal[
+ "_source_code_location", b"_source_code_location"
+ ],
+ ) -> typing_extensions.Literal["source_code_location"] | None: ...
+ @typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_target_dataset_name", b"_target_dataset_name"],
@@ -1134,6 +1174,60 @@ class PipelineEvent(google.protobuf.message.Message):
global___PipelineEvent = PipelineEvent
+class SourceCodeLocation(google.protobuf.message.Message):
+ """Source code location information associated with a particular dataset
or flow."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ FILE_NAME_FIELD_NUMBER: builtins.int
+ LINE_NUMBER_FIELD_NUMBER: builtins.int
+ file_name: builtins.str
+ """The file that this pipeline source code was defined in."""
+ line_number: builtins.int
+ """The specific line number that this pipeline source code is located at,
if applicable."""
+ def __init__(
+ self,
+ *,
+ file_name: builtins.str | None = ...,
+ line_number: builtins.int | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_file_name",
+ b"_file_name",
+ "_line_number",
+ b"_line_number",
+ "file_name",
+ b"file_name",
+ "line_number",
+ b"line_number",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_file_name",
+ b"_file_name",
+ "_line_number",
+ b"_line_number",
+ "file_name",
+ b"file_name",
+ "line_number",
+ b"line_number",
+ ],
+ ) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_file_name", b"_file_name"]
+ ) -> typing_extensions.Literal["file_name"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_line_number", b"_line_number"]
+ ) -> typing_extensions.Literal["line_number"] | None: ...
+
+global___SourceCodeLocation = SourceCodeLocation
+
class PipelineQueryFunctionExecutionSignal(google.protobuf.message.Message):
"""A signal from the server to the client to execute the query function
for one or more flows, and
to register their results with the server.
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index ef1ac5f38073..16d211f9f72d 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -86,6 +86,9 @@ message PipelineCommand {
// The output table format of the dataset. Only applies to dataset_type == TABLE and
// dataset_type == MATERIALIZED_VIEW.
optional string format = 8;
+
+ // The location in source code that this dataset was defined.
+ optional SourceCodeLocation source_code_location = 9;
}
// Request to define a flow targeting a dataset.
@@ -110,6 +113,9 @@ message PipelineCommand {
// evaluation request stream to dispatch evaluation requests to for this flow.
optional string client_id = 6;
+ // The location in source code that this flow was defined.
+ optional SourceCodeLocation source_code_location = 7;
+
message Response {
// Fully qualified flow name that uniquely identify a flow in the Dataflow graph.
optional string flow_name = 1;
@@ -217,6 +223,14 @@ message PipelineEvent {
optional string message = 2;
}
+// Source code location information associated with a particular dataset or flow.
+message SourceCodeLocation {
+ // The file that this pipeline source code was defined in.
+ optional string file_name = 1;
+ // The specific line number that this pipeline source code is located at, if applicable.
+ optional int32 line_number = 2;
+}
+
// A signal from the server to the client to execute the query function for one or more flows, and
// to register their results with the server.
message PipelineQueryFunctionExecutionSignal {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 1b2e039be715..01402c64e8a2 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -195,7 +195,11 @@ private[connect] object PipelinesHandler extends Logging {
partitionCols = Option(dataset.getPartitionColsList.asScala.toSeq)
.filter(_.nonEmpty),
properties = dataset.getTablePropertiesMap.asScala.toMap,
- baseOrigin = QueryOrigin(
+ origin = QueryOrigin(
+ filePath = Option.when(dataset.getSourceCodeLocation.hasFileName)(
+ dataset.getSourceCodeLocation.getFileName),
+ line = Option.when(dataset.getSourceCodeLocation.hasLineNumber)(
+ dataset.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.Table.toString),
objectName = Option(qualifiedIdentifier.unquotedString),
language = Option(Python())),
@@ -212,6 +216,10 @@ private[connect] object PipelinesHandler extends Logging {
identifier = viewIdentifier,
comment = Option(dataset.getComment),
origin = QueryOrigin(
+ filePath = Option.when(dataset.getSourceCodeLocation.hasFileName)(
+ dataset.getSourceCodeLocation.getFileName),
+ line = Option.when(dataset.getSourceCodeLocation.hasLineNumber)(
+ dataset.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.View.toString),
objectName = Option(viewIdentifier.unquotedString),
language = Option(Python())),
@@ -281,6 +289,10 @@ private[connect] object PipelinesHandler extends Logging {
once = false,
queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)),
origin = QueryOrigin(
+ filePath = Option.when(flow.getSourceCodeLocation.hasFileName)(
+ flow.getSourceCodeLocation.getFileName),
+ line = Option.when(flow.getSourceCodeLocation.hasLineNumber)(
+ flow.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.Flow.toString),
objectName = Option(flowIdentifier.unquotedString),
language = Option(Python()))))
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 897c0209153f..b99615062d45 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -31,7 +31,10 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
-import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
+import org.apache.spark.sql.pipelines.Language.Python
+import org.apache.spark.sql.pipelines.common.FlowStatus
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType}
+import org.apache.spark.sql.pipelines.logging.EventLevel
 import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin}
/**
@@ -116,6 +119,132 @@ class PythonPipelineSuite
assert(graph.tables.size == 1)
}
+ test("failed flow progress event has correct python source code location") {
+ // Note that pythonText will be inserted into line 26 of the python script that is run.
+ val unresolvedGraph = buildGraph(pythonText = """
+ |@dp.table()
+ |def table1():
+ | df = spark.createDataFrame([(25,), (30,), (45,)], ["age"])
+ | return df.select("name")
+ |""".stripMargin)
+
+ val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+ updateContext.pipelineExecution.runPipeline()
+
+ assertFlowProgressEvent(
+ updateContext.eventBuffer,
+ identifier = graphIdentifier("table1"),
+ expectedFlowStatus = FlowStatus.FAILED,
+ cond = flowProgressEvent =>
+ flowProgressEvent.origin.sourceCodeLocation == Option(
+ QueryOrigin(
+ language = Option(Python()),
+ filePath = Option("<string>"),
+ line = Option(28),
+ objectName = Option("spark_catalog.default.table1"),
+ objectType = Option(QueryOriginType.Flow.toString))),
+ errorChecker = ex =>
+ ex.getMessage.contains(
+ "A column, variable, or function parameter with name `name` cannot
be resolved."),
+ expectedEventLevel = EventLevel.WARN)
+ }
+
+ test("flow progress events have correct python source code location") {
+ val unresolvedGraph = buildGraph(pythonText = """
+ |@dp.table(
+ | comment = 'my table'
+ |)
+ |def table1():
+ | return spark.readStream.table('mv')
+ |
+ |@dp.materialized_view
+ |def mv2():
+ | return spark.range(26, 29)
+ |
+ |@dp.materialized_view
+ |def mv():
+ | df = spark.createDataFrame([(25,), (30,), (45,)], ["age"])
+ | return df.select("age")
+ |
+ |@dp.append_flow(
+ | target = 'table1'
+ |)
+ |def standalone_flow1():
+ | return spark.readStream.table('mv2')
+ |""".stripMargin)
+
+ val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph)
+ updateContext.pipelineExecution.runPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+
+ Seq(
+ FlowStatus.QUEUED,
+ FlowStatus.STARTING,
+ FlowStatus.PLANNING,
+ FlowStatus.RUNNING,
+ FlowStatus.COMPLETED).foreach { flowStatus =>
+ assertFlowProgressEvent(
+ updateContext.eventBuffer,
+ identifier = graphIdentifier("mv2"),
+ expectedFlowStatus = flowStatus,
+ cond = flowProgressEvent =>
+ flowProgressEvent.origin.sourceCodeLocation == Option(
+ QueryOrigin(
+ language = Option(Python()),
+ filePath = Option("<string>"),
+ line = Option(34),
+ objectName = Option("spark_catalog.default.mv2"),
+ objectType = Option(QueryOriginType.Flow.toString))),
+ expectedEventLevel = EventLevel.INFO)
+
+ assertFlowProgressEvent(
+ updateContext.eventBuffer,
+ identifier = graphIdentifier("mv"),
+ expectedFlowStatus = flowStatus,
+ cond = flowProgressEvent =>
+ flowProgressEvent.origin.sourceCodeLocation == Option(
+ QueryOrigin(
+ language = Option(Python()),
+ filePath = Option("<string>"),
+ line = Option(38),
+ objectName = Option("spark_catalog.default.mv"),
+ objectType = Option(QueryOriginType.Flow.toString))),
+ expectedEventLevel = EventLevel.INFO)
+ }
+
+ // Note that streaming flows do not have a PLANNING phase.
+ Seq(FlowStatus.QUEUED, FlowStatus.STARTING, FlowStatus.RUNNING, FlowStatus.COMPLETED)
+ .foreach { flowStatus =>
+ assertFlowProgressEvent(
+ updateContext.eventBuffer,
+ identifier = graphIdentifier("table1"),
+ expectedFlowStatus = flowStatus,
+ cond = flowProgressEvent =>
+ flowProgressEvent.origin.sourceCodeLocation == Option(
+ QueryOrigin(
+ language = Option(Python()),
+ filePath = Option("<string>"),
+ line = Option(28),
+ objectName = Option("spark_catalog.default.table1"),
+ objectType = Option(QueryOriginType.Flow.toString))),
+ expectedEventLevel = EventLevel.INFO)
+
+ assertFlowProgressEvent(
+ updateContext.eventBuffer,
+ identifier = graphIdentifier("standalone_flow1"),
+ expectedFlowStatus = flowStatus,
+ cond = flowProgressEvent =>
+ flowProgressEvent.origin.sourceCodeLocation == Option(
+ QueryOrigin(
+ language = Option(Python()),
+ filePath = Option("<string>"),
+ line = Option(43),
+ objectName = Option("spark_catalog.default.standalone_flow1"),
+ objectType = Option(QueryOriginType.Flow.toString))),
+ expectedEventLevel = EventLevel.INFO)
+ }
+ }
+
test("basic with inverted topological order") {
// This graph is purposefully in the wrong topological order to test the topological sort
val graph = buildGraph("""
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
index fcab53ae32ac..6d12e2281874 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala
@@ -123,7 +123,8 @@ private class FlowResolver(rawGraph: DataflowGraph) {
allInputs = allInputs,
availableInputs = availableResolvedInputs.values.toList,
configuration = flowToResolve.sqlConf,
- queryContext = flowToResolve.queryContext
+ queryContext = flowToResolve.queryContext,
+ queryOrigin = flowToResolve.origin
)
val result =
flowFunctionResult match {
@@ -169,7 +170,8 @@ private class FlowResolver(rawGraph: DataflowGraph) {
allInputs = allInputs,
availableInputs = availableResolvedInputs.values.toList,
configuration = newSqlConf,
- queryContext = flowToResolve.queryContext
+ queryContext = flowToResolve.queryContext,
+ queryOrigin = flowToResolve.origin
)
} else {
f
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
index 40fb8dbbe5dc..91feee936170 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala
@@ -72,13 +72,16 @@ trait FlowFunction extends Logging {
* @param availableInputs the list of all [[Input]]s available to this flow
* @param configuration the spark configurations that apply to this flow.
* @param queryContext The context of the query being evaluated.
+ * @param queryOrigin The source code location of the flow definition this flow function was
+ *   instantiated from.
* @return the inputs actually used, and the DataFrame expression for the flow
*/
def call(
allInputs: Set[TableIdentifier],
availableInputs: Seq[Input],
configuration: Map[String, String],
- queryContext: QueryContext
+ queryContext: QueryContext,
+ queryOrigin: QueryOrigin
): FlowFunctionResult
}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
index 311bdfd6a3d2..18ae45c4f340 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala
@@ -46,7 +46,8 @@ object FlowAnalysis {
allInputs: Set[TableIdentifier],
availableInputs: Seq[Input],
confs: Map[String, String],
- queryContext: QueryContext
+ queryContext: QueryContext,
+ queryOrigin: QueryOrigin
): FlowFunctionResult = {
val ctx = FlowAnalysisContext(
allInputs = allInputs,
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
index eb6b20b8a4ef..55a03a2d19f9 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
@@ -193,7 +193,7 @@ class SqlGraphRegistrationContext(
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
properties = cst.tableSpec.properties,
- baseOrigin = queryOrigin.copy(
+ origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
objectType = Option(QueryOriginType.Table.toString)
),
@@ -224,7 +224,7 @@ class SqlGraphRegistrationContext(
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
properties = cst.tableSpec.properties,
- baseOrigin = queryOrigin.copy(
+ origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
objectType = Option(QueryOriginType.Table.toString)
),
@@ -274,7 +274,7 @@ class SqlGraphRegistrationContext(
Option.when(cmv.columns.nonEmpty)(StructType(cmv.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cmv.partitioning, queryOrigin)),
properties = cmv.tableSpec.properties,
- baseOrigin = queryOrigin.copy(
+ origin = queryOrigin.copy(
objectName = Option(mvIdentifier.unquotedString),
objectType = Option(QueryOriginType.Table.toString)
),
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
index ee78f96d5316..95a57dcc4495 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala
@@ -123,17 +123,12 @@ case class Table(
normalizedPath: Option[String],
properties: Map[String, String] = Map.empty,
comment: Option[String],
- baseOrigin: QueryOrigin,
+ override val origin: QueryOrigin,
isStreamingTable: Boolean,
format: Option[String]
) extends TableInput
with Output {
- override val origin: QueryOrigin = baseOrigin.copy(
- objectType = Some("table"),
- objectName = Some(identifier.unquotedString)
- )
-
// Load this table's data from underlying storage.
override def load(readOptions: InputReadOptions): DataFrame = {
try {
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index 38bd858a688f..4a33dd2c61a8 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.pipelines.graph.{
PersistedView,
QueryContext,
QueryOrigin,
+ QueryOriginType,
Table,
TemporaryView,
UnresolvedFlow
@@ -129,7 +130,6 @@ class TestGraphRegistrationContext(
isStreamingTable: Boolean
): Unit = {
// scalastyle:on
- val tableIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
val qualifiedIdentifier = GraphIdentifierManager
.parseAndQualifyTableIdentifier(
rawTableIdentifier = GraphIdentifierManager
@@ -144,7 +144,12 @@ class TestGraphRegistrationContext(
specifiedSchema = specifiedSchema,
partitionCols = partitionCols,
properties = properties,
- baseOrigin = baseOrigin,
+ origin = baseOrigin.merge(
+ QueryOrigin(
+ objectName = Option(qualifiedIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Table.toString)
+ )
+ ),
format = format.orElse(Some("parquet")),
normalizedPath = None,
isStreamingTable = isStreamingTable
@@ -215,13 +220,20 @@ class TestGraphRegistrationContext(
case _ => persistedViewIdentifier
}
+ val viewOrigin: QueryOrigin = origin.merge(
+ QueryOrigin(
+ objectName = Option(viewIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.View.toString)
+ )
+ )
+
registerView(
viewType match {
case LocalTempView =>
TemporaryView(
identifier = viewIdentifier,
comment = comment,
- origin = origin,
+ origin = viewOrigin,
properties = Map.empty,
sqlText = sqlText
)
@@ -229,7 +241,7 @@ class TestGraphRegistrationContext(
PersistedView(
identifier = viewIdentifier,
comment = comment,
- origin = origin,
+ origin = viewOrigin,
properties = Map.empty,
sqlText = sqlText
)
@@ -298,7 +310,10 @@ class TestGraphRegistrationContext(
),
sqlConf = Map.empty,
once = once,
- origin = QueryOrigin()
+ origin = QueryOrigin(
+ objectName = Option(flowIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Flow.toString)
+ )
)
)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]