This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 80bba4463eba [SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spark Connect
80bba4463eba is described below
commit 80bba4463eba29a56cdd90642f0681c3710ce87c
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jun 18 17:25:32 2024 +0900
[SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spark Connect
### What changes were proposed in this pull request?
This PR proposes to implement `DataFrameQueryContext` in Spark Connect.
1. Add two new protobuf messages packed together with `Expression`:
```proto
message Origin {
  // (Required) Indicate the origin type.
  oneof function {
    PythonOrigin python_origin = 1;
  }
}

message PythonOrigin {
  // (Required) Name of the origin, for example, the name of the function
  string fragment = 1;

  // (Required) Callsite to show to end users, for example, stacktrace.
  string call_site = 2;
}
```
2. Merge `DataFrameQueryContext.pysparkFragment` and `DataFrameQueryContext.pysparkCallSite` into the existing `DataFrameQueryContext.fragment` and `DataFrameQueryContext.callSite`
3. Separate `QueryContext` into `SQLQueryContext` and `DataFrameQueryContext` for consistency with the Scala side
4. Implement the origin logic: the `current_origin` thread-local holds the current call site and function name, and each `Expression` reads them from it. They are set on individual expression messages and used when analysis happens, which resembles the Spark SQL implementation (see the sketch below).
See also https://github.com/apache/spark/pull/45377.
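A minimal sketch of the client-side capture, assuming a simplified single-frame call site (the change itself selects frames up to `spark.sql.stackTracesInDataFrameContext`, defaults to one frame under Connect to avoid an extra RPC, and keeps the JVM-backed `PySparkCurrentOrigin` path for classic PySpark); helper names mirror `pyspark/errors/utils.py` but are simplified here:

```python
import functools
import inspect
import threading
from typing import Any, Callable, Optional

# Thread-local slot holding the PySpark API name (fragment) and the user call site.
_current_origin = threading.local()


def set_current_origin(fragment: Optional[str], call_site: Optional[str]) -> None:
    _current_origin.fragment = fragment
    _current_origin.call_site = call_site


def with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
    """Record the function name and the caller's file:line while `func` runs."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        caller = inspect.stack()[1]
        set_current_origin(func.__name__, f"{caller.filename}:{caller.lineno}")
        try:
            return func(*args, **kwargs)
        finally:
            set_current_origin(None, None)

    return wrapper
```

Each `Expression` constructed while the wrapper is active copies the thread-local into `Expression.common.origin.python_origin`, and the server-side `SparkConnectPlanner` feeds it into `PySparkCurrentOrigin` (wrapped in `withOrigin`) before transforming the expression.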
### Why are the changes needed?
See https://github.com/apache/spark/pull/45377
### Does this PR introduce _any_ user-facing change?
Yes, the same as https://github.com/apache/spark/pull/45377, but in Spark Connect.
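For example, with ANSI mode enabled, an error raised from a DataFrame operation now carries a `DataFrameQueryContext` under Spark Connect as well. A rough illustration (assuming an active `spark` session; the exact call-site string depends on the user script):

```python
from pyspark.errors import ArithmeticException

spark.conf.set("spark.sql.ansi.enabled", True)
df = spark.range(10)

try:
    df.withColumn("div_zero", df.id / 0).collect()
except ArithmeticException as e:
    ctx = e.getQueryContext()[0]
    print(ctx.contextType())  # QueryContextType.DataFrame
    print(ctx.fragment())     # "__truediv__"
    print(ctx.callSite())     # e.g. "/path/to/user_script.py:7"
```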
### How was this patch tested?
The same unit tests are reused in Spark Connect.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46789 from HyukjinKwon/connect-context.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/protobuf/spark/connect/common.proto | 15 +++
.../main/protobuf/spark/connect/expressions.proto | 7 ++
.../main/protobuf/spark/connect/relations.proto | 2 +
.../sql/connect/planner/SparkConnectPlanner.scala | 19 ++-
.../spark/sql/connect/utils/ErrorUtils.scala | 22 +++-
python/pyspark/errors/exceptions/captured.py | 51 +++++---
python/pyspark/errors/exceptions/connect.py | 83 +++++++++++--
python/pyspark/errors/utils.py | 72 +++++++----
python/pyspark/sql/connect/column.py | 2 +
python/pyspark/sql/connect/expressions.py | 58 +++++----
python/pyspark/sql/connect/proto/common_pb2.py | 6 +-
python/pyspark/sql/connect/proto/common_pb2.pyi | 51 ++++++++
.../pyspark/sql/connect/proto/expressions_pb2.py | 133 +++++++++++----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 28 +++++
python/pyspark/sql/connect/proto/relations_pb2.pyi | 4 +-
.../connect/test_parity_dataframe_query_context.py | 6 +-
.../sql/tests/test_dataframe_query_context.py | 65 +++++-----
python/pyspark/testing/utils.py | 10 +-
.../spark/sql/catalyst/trees/QueryContexts.scala | 34 +++---
19 files changed, 463 insertions(+), 205 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/common.proto b/connector/connect/common/src/main/protobuf/spark/connect/common.proto
index da334bfd9ee8..b2848370b01d 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/common.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/common.proto
@@ -81,3 +81,18 @@ message ResourceProfile {
// (e.g., cores, memory, CPU) to its specific request.
map<string, TaskResourceRequest> task_resources = 2;
}
+
+message Origin {
+ // (Required) Indicate the origin type.
+ oneof function {
+ PythonOrigin python_origin = 1;
+ }
+}
+
+message PythonOrigin {
+ // (Required) Name of the origin, for example, the name of the function
+ string fragment = 1;
+
+ // (Required) Callsite to show to end users, for example, stacktrace.
+ string call_site = 2;
+}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index eb4be1586005..404a2fdcb2e8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -19,6 +19,7 @@ syntax = 'proto3';
import "google/protobuf/any.proto";
import "spark/connect/types.proto";
+import "spark/connect/common.proto";
package spark.connect;
@@ -30,6 +31,7 @@ option go_package = "internal/generated";
// expressions in SQL appear.
message Expression {
+ ExpressionCommon common = 18;
oneof expr_type {
Literal literal = 1;
UnresolvedAttribute unresolved_attribute = 2;
@@ -342,6 +344,11 @@ message Expression {
}
}
+message ExpressionCommon {
+ // (Required) Keep the information of the origin for this expression such as stacktrace.
+ Origin origin = 1;
+}
+
message CommonInlineUserDefinedFunction {
// (Required) Name of the user-defined function.
string function_name = 1;
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 0c4ca6290a76..ba1a633b0e61 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -106,6 +106,8 @@ message Unknown {}
// Common metadata of all relations.
message RelationCommon {
+ // TODO(SPARK-48639): Add origin like Expression.ExpressionCommon
+
// (Required) Shared relation metadata.
string source_info = 1;
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 453d2b30876f..a7fc87d8b65d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile,
TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias,
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer,
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex,
UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter,
Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup,
CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark,
DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith,
LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions,
Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union,
Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter,
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,21 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
- def transformExpression(exp: proto.Expression): Expression = {
+ def transformExpression(exp: proto.Expression): Expression = if (exp.hasCommon) {
+ try {
+ val origin = exp.getCommon.getOrigin
+ PySparkCurrentOrigin.set(
+ origin.getPythonOrigin.getFragment,
+ origin.getPythonOrigin.getCallSite)
+ withOrigin { doTransformExpression(exp) }
+ } finally {
+ PySparkCurrentOrigin.clear()
+ }
+ } else {
+ doTransformExpression(exp)
+ }
+
+ private def doTransformExpression(exp: proto.Expression): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL =>
transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 773a97e92973..355048cf3036 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods
-import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
- sparkThrowableBuilder.addQueryContexts(
- FetchErrorDetailsResponse.QueryContext
- .newBuilder()
+ val builder = FetchErrorDetailsResponse.QueryContext
+ .newBuilder()
+ val context = if (queryCtx.contextType() == QueryContextType.SQL) {
+ builder
+ .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
.setObjectType(queryCtx.objectType())
.setObjectName(queryCtx.objectName())
.setStartIndex(queryCtx.startIndex())
.setStopIndex(queryCtx.stopIndex())
.setFragment(queryCtx.fragment())
- .build())
+ .setSummary(queryCtx.summary())
+ .build()
+ } else {
+ builder
+ .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
+ .setFragment(queryCtx.fragment())
+ .setCallSite(queryCtx.callSite())
+ .setSummary(queryCtx.summary())
+ .build()
+ }
+ sparkThrowableBuilder.addQueryContexts(context)
}
if (sparkThrowable.getSqlState != null) {
sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 2a30eba3fb22..b5bb742161c0 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -166,7 +166,14 @@ class CapturedException(PySparkException):
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- return [QueryContext(q) for q in self._origin.getQueryContext()]
+ contexts: List[BaseQueryContext] = []
+ for q in self._origin.getQueryContext():
+ if q.contextType().toString() == "SQL":
+ contexts.append(SQLQueryContext(q))
+ else:
+ contexts.append(DataFrameQueryContext(q))
+
+ return contexts
else:
return []
@@ -379,17 +386,12 @@ class UnknownException(CapturedException,
BaseUnknownException):
"""
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
def __init__(self, q: "JavaObject"):
self._q = q
def contextType(self) -> QueryContextType:
- context_type = self._q.contextType().toString()
- assert context_type in ("SQL", "DataFrame")
- if context_type == "DataFrame":
- return QueryContextType.DataFrame
- else:
- return QueryContextType.SQL
+ return QueryContextType.SQL
def objectType(self) -> str:
return str(self._q.objectType())
@@ -409,13 +411,34 @@ class QueryContext(BaseQueryContext):
def callSite(self) -> str:
return str(self._q.callSite())
- def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
- if self.contextType() == QueryContextType.DataFrame:
- return str(self._q.pysparkFragment())
+ def summary(self) -> str:
+ return str(self._q.summary())
+
+
+class DataFrameQueryContext(BaseQueryContext):
+ def __init__(self, q: "JavaObject"):
+ self._q = q
+
+ def contextType(self) -> QueryContextType:
+ return QueryContextType.DataFrame
+
+ def objectType(self) -> str:
+ return str(self._q.objectType())
+
+ def objectName(self) -> str:
+ return str(self._q.objectName())
- def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
- if self.contextType() == QueryContextType.DataFrame:
- return str(self._q.pysparkCallSite())
+ def startIndex(self) -> int:
+ return int(self._q.startIndex())
+
+ def stopIndex(self) -> int:
+ return int(self._q.stopIndex())
+
+ def fragment(self) -> str:
+ return str(self._q.fragment())
+
+ def callSite(self) -> str:
+ return str(self._q.callSite())
def summary(self) -> str:
return str(self._q.summary())
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index 0cffe7268753..8a95358f2697 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -91,7 +91,10 @@ def convert_exception(
)
query_contexts = []
for query_context in
resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
- query_contexts.append(QueryContext(query_context))
+ if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
+ query_contexts.append(SQLQueryContext(query_context))
+ else:
+ query_contexts.append(DataFrameQueryContext(query_context))
if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
@@ -430,17 +433,12 @@ class
SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
"""
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q
def contextType(self) -> QueryContextType:
- context_type = self._q.context_type
-
- if int(context_type) == QueryContextType.DataFrame.value:
- return QueryContextType.DataFrame
- else:
- return QueryContextType.SQL
+ return QueryContextType.SQL
def objectType(self) -> str:
return str(self._q.object_type)
@@ -457,6 +455,75 @@ class QueryContext(BaseQueryContext):
def fragment(self) -> str:
return str(self._q.fragment)
+ def callSite(self) -> str:
+ raise UnsupportedOperationException(
+ "",
+ error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+ message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
+ sql_state="0A000",
+ server_stacktrace=None,
+ display_server_stacktrace=False,
+ query_contexts=[],
+ )
+
+ def summary(self) -> str:
+ return str(self._q.summary)
+
+
+class DataFrameQueryContext(BaseQueryContext):
+ def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+ self._q = q
+
+ def contextType(self) -> QueryContextType:
+ return QueryContextType.DataFrame
+
+ def objectType(self) -> str:
+ raise UnsupportedOperationException(
+ "",
+ error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+ message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
+ sql_state="0A000",
+ server_stacktrace=None,
+ display_server_stacktrace=False,
+ query_contexts=[],
+ )
+
+ def objectName(self) -> str:
+ raise UnsupportedOperationException(
+ "",
+ error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+ message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
+ sql_state="0A000",
+ server_stacktrace=None,
+ display_server_stacktrace=False,
+ query_contexts=[],
+ )
+
+ def startIndex(self) -> int:
+ raise UnsupportedOperationException(
+ "",
+ error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+ message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
+ sql_state="0A000",
+ server_stacktrace=None,
+ display_server_stacktrace=False,
+ query_contexts=[],
+ )
+
+ def stopIndex(self) -> int:
+ raise UnsupportedOperationException(
+ "",
+ error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+ message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
+ sql_state="0A000",
+ server_stacktrace=None,
+ display_server_stacktrace=False,
+ query_contexts=[],
+ )
+
+ def fragment(self) -> str:
+ return str(self._q.fragment)
+
def callSite(self) -> str:
return str(self._q.call_site)
diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index cddec3319964..cd3046380284 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -19,16 +19,34 @@ import re
import functools
import inspect
import os
-from typing import Any, Callable, Dict, Match, TypeVar, Type, TYPE_CHECKING
+import threading
+from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, TYPE_CHECKING
from pyspark.errors.error_classes import ERROR_CLASSES_MAP
-
if TYPE_CHECKING:
from pyspark.sql import SparkSession
- from py4j.java_gateway import JavaClass
T = TypeVar("T")
+_current_origin = threading.local()
+
+
+def current_origin() -> threading.local:
+ global _current_origin
+
+ if not hasattr(_current_origin, "fragment"):
+ _current_origin.fragment = None
+ if not hasattr(_current_origin, "call_site"):
+ _current_origin.call_site = None
+ return _current_origin
+
+
+def set_current_origin(fragment: Optional[str], call_site: Optional[str]) -> None:
+ global _current_origin
+
+ _current_origin.fragment = fragment
+ _current_origin.call_site = call_site
+
class ErrorClassesReader:
"""
@@ -130,9 +148,7 @@ class ErrorClassesReader:
return message_template
-def _capture_call_site(
- spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str
-) -> None:
+def _capture_call_site(spark_session: "SparkSession", depth: int) -> str:
"""
Capture the call site information including file name, line number, and
function name.
This function updates the thread-local storage from JVM side
(PySparkCurrentOrigin)
@@ -142,10 +158,6 @@ def _capture_call_site(
----------
spark_session : SparkSession
Current active Spark session.
- pyspark_origin : py4j.JavaClass
- PySparkCurrentOrigin from current active Spark session.
- fragment : str
- The name of the PySpark API function being captured.
Notes
-----
@@ -153,14 +165,11 @@ def _capture_call_site(
in the user code that led to the error.
"""
stack = list(reversed(inspect.stack()))
- depth = int(
- spark_session.conf.get("spark.sql.stackTracesInDataFrameContext")  # type: ignore[arg-type]
- )
selected_frames = stack[:depth]
call_sites = [f"{frame.filename}:{frame.lineno}" for frame in
selected_frames]
call_sites_str = "\n".join(call_sites)
- pyspark_origin.set(fragment, call_sites_str)
+ return call_sites_str
def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -172,19 +181,38 @@ def _with_origin(func: Callable[..., Any]) ->
Callable[..., Any]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
from pyspark.sql import SparkSession
+ from pyspark.sql.utils import is_remote
spark = SparkSession.getActiveSession()
if spark is not None and hasattr(func, "__name__"):
- assert spark._jvm is not None
- pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+ if is_remote():
+ global current_origin
- # Update call site when the function is called
- _capture_call_site(spark, pyspark_origin, func.__name__)
+ # Getting the configuration requires RPC call. Uses the default value for now.
+ depth = 1
+ set_current_origin(func.__name__, _capture_call_site(spark, depth))
- try:
- return func(*args, **kwargs)
- finally:
- pyspark_origin.clear()
+ try:
+ return func(*args, **kwargs)
+ finally:
+ set_current_origin(None, None)
+ else:
+ assert spark._jvm is not None
+ jvm_pyspark_origin = (
+ spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+ )
+ depth = int(
+ spark.conf.get( # type: ignore[arg-type]
+ "spark.sql.stackTracesInDataFrameContext"
+ )
+ )
+ # Update call site when the function is called
+ jvm_pyspark_origin.set(func.__name__, _capture_call_site(spark, depth))
+
+ try:
+ return func(*args, **kwargs)
+ finally:
+ jvm_pyspark_origin.clear()
else:
return func(*args, **kwargs)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c38717afccda..b63e06bccae1 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -46,6 +46,7 @@ from pyspark.sql.connect.expressions import (
WithField,
DropField,
)
+from pyspark.errors.utils import with_origin_to_class
if TYPE_CHECKING:
@@ -95,6 +96,7 @@ def _unary_op(name: str, self: ParentColumn) -> ParentColumn:
return Column(UnresolvedFunction(name, [self._expr])) # type:
ignore[list-item]
+@with_origin_to_class
class Column(ParentColumn):
def __new__(
cls,
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index e4e31ad60034..c10bef56c3b8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
check_dependencies(__name__)
@@ -77,6 +76,8 @@ from pyspark.sql.connect.types import (
proto_schema_to_pyspark_data_type,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import current_origin
+from pyspark.sql.utils import is_timestamp_ntz_preferred
if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
@@ -89,7 +90,16 @@ class Expression:
"""
def __init__(self) -> None:
- pass
+ origin = current_origin()
+ fragment = origin.fragment
+ call_site = origin.call_site
+ self.origin = None
+ if fragment is not None and call_site is not None:
+ self.origin = proto.Origin(
+ python_origin=proto.PythonOrigin(
+ fragment=origin.fragment, call_site=origin.call_site
+ )
+ )
def to_plan( # type: ignore[empty-body]
self, session: "SparkConnectClient"
@@ -112,6 +122,12 @@ class Expression:
def name(self) -> str: # type: ignore[empty-body]
...
+ def _create_proto_expression(self) -> proto.Expression:
+ plan = proto.Expression()
+ if self.origin is not None:
+ plan.common.origin.CopyFrom(self.origin)
+ return plan
+
class CaseWhen(Expression):
def __init__(
@@ -162,7 +178,7 @@ class ColumnAlias(Expression):
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
if len(self._alias) == 1:
- exp = proto.Expression()
+ exp = self._create_proto_expression()
exp.alias.name.append(self._alias[0])
exp.alias.expr.CopyFrom(self._child.to_plan(session))
@@ -175,7 +191,7 @@ class ColumnAlias(Expression):
error_class="CANNOT_PROVIDE_METADATA",
message_parameters={},
)
- exp = proto.Expression()
+ exp = self._create_proto_expression()
exp.alias.name.extend(self._alias)
exp.alias.expr.CopyFrom(self._child.to_plan(session))
return exp
@@ -407,7 +423,7 @@ class LiteralExpression(Expression):
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
"""Converts the literal expression to the literal in proto."""
- expr = proto.Expression()
+ expr = self._create_proto_expression()
if self._value is None:
expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
@@ -483,7 +499,7 @@ class ColumnReference(Expression):
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the expression."""
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.unresolved_attribute.unparsed_identifier =
self._unparsed_identifier
if self._plan_id is not None:
expr.unresolved_attribute.plan_id = self._plan_id
@@ -512,7 +528,7 @@ class UnresolvedStar(Expression):
self._plan_id = plan_id
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.unresolved_star.SetInParent()
if self._unparsed_target is not None:
expr.unresolved_star.unparsed_target = self._unparsed_target
@@ -546,7 +562,7 @@ class SQLExpression(Expression):
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the SQL expression."""
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.expression_string.expression = self._expr
return expr
@@ -572,7 +588,7 @@ class SortOrder(Expression):
)
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- sort = proto.Expression()
+ sort = self._create_proto_expression()
sort.sort_order.child.CopyFrom(self._child.to_plan(session))
if self._ascending:
@@ -611,7 +627,7 @@ class UnresolvedFunction(Expression):
self._is_distinct = is_distinct
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- fun = proto.Expression()
+ fun = self._create_proto_expression()
fun.unresolved_function.function_name = self._name
if len(self._args) > 0:
fun.unresolved_function.arguments.extend([arg.to_plan(session) for
arg in self._args])
@@ -708,7 +724,7 @@ class CommonInlineUserDefinedFunction(Expression):
self._function = function
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.common_inline_user_defined_function.function_name =
self._function_name
expr.common_inline_user_defined_function.deterministic =
self._deterministic
if len(self._arguments) > 0:
@@ -762,7 +778,7 @@ class WithField(Expression):
self._valueExpr = valueExpr
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
@@ -787,7 +803,7 @@ class DropField(Expression):
self._fieldName = fieldName
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
return expr
@@ -811,7 +827,7 @@ class UnresolvedExtractValue(Expression):
self._extraction = extraction
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
return expr
@@ -831,7 +847,7 @@ class UnresolvedRegex(Expression):
self._plan_id = plan_id
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.unresolved_regex.col_name = self.col_name
if self._plan_id is not None:
expr.unresolved_regex.plan_id = self._plan_id
@@ -858,7 +874,7 @@ class CastExpression(Expression):
self._eval_mode = eval_mode
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- fun = proto.Expression()
+ fun = self._create_proto_expression()
fun.cast.expr.CopyFrom(self._expr.to_plan(session))
if isinstance(self._data_type, str):
fun.cast.type_str = self._data_type
@@ -909,7 +925,7 @@ class UnresolvedNamedLambdaVariable(Expression):
self._name_parts = name_parts
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.unresolved_named_lambda_variable.name_parts.extend(self._name_parts)
return expr
@@ -951,7 +967,7 @@ class LambdaFunction(Expression):
self._arguments = arguments
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.lambda_function.function.CopyFrom(self._function.to_plan(session))
expr.lambda_function.arguments.extend(
[arg.to_plan(session).unresolved_named_lambda_variable for arg in
self._arguments]
@@ -984,7 +1000,7 @@ class WindowExpression(Expression):
self._windowSpec = windowSpec
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.window.window_function.CopyFrom(self._windowFunction.to_plan(session))
@@ -1091,7 +1107,7 @@ class CallFunction(Expression):
self._args = args
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.call_function.function_name = self._name
if len(self._args) > 0:
expr.call_function.arguments.extend([arg.to_plan(session) for arg
in self._args])
@@ -1115,7 +1131,7 @@ class NamedArgumentExpression(Expression):
self._value = value
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
- expr = proto.Expression()
+ expr = self._create_proto_expression()
expr.named_argument_expression.key = self._key
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
return expr
diff --git a/python/pyspark/sql/connect/proto/common_pb2.py b/python/pyspark/sql/connect/proto/common_pb2.py
index a77d1463e51d..fd528fae3369 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.py
+++ b/python/pyspark/sql/connect/proto/common_pb2.py
@@ -29,7 +29,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01
\x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12
\n\x0cuse_off_heap\x18\x03
\x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04
\x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05
\x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01
\x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
+
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01
\x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12
\n\x0cuse_off_heap\x18\x03
\x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04
\x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05
\x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01
\x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -59,4 +59,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RESOURCEPROFILE_EXECUTORRESOURCESENTRY._serialized_end = 899
_RESOURCEPROFILE_TASKRESOURCESENTRY._serialized_start = 901
_RESOURCEPROFILE_TASKRESOURCESENTRY._serialized_end = 1001
+ _ORIGIN._serialized_start = 1003
+ _ORIGIN._serialized_end = 1091
+ _PYTHONORIGIN._serialized_start = 1093
+ _PYTHONORIGIN._serialized_end = 1164
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/common_pb2.pyi b/python/pyspark/sql/connect/proto/common_pb2.pyi
index 163781b41998..eda172e26cf4 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/common_pb2.pyi
@@ -296,3 +296,54 @@ class ResourceProfile(google.protobuf.message.Message):
) -> None: ...
global___ResourceProfile = ResourceProfile
+
+class Origin(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ PYTHON_ORIGIN_FIELD_NUMBER: builtins.int
+ @property
+ def python_origin(self) -> global___PythonOrigin: ...
+ def __init__(
+ self,
+ *,
+ python_origin: global___PythonOrigin | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "function", b"function", "python_origin", b"python_origin"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "function", b"function", "python_origin", b"python_origin"
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["function", b"function"]
+ ) -> typing_extensions.Literal["python_origin"] | None: ...
+
+global___Origin = Origin
+
+class PythonOrigin(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ FRAGMENT_FIELD_NUMBER: builtins.int
+ CALL_SITE_FIELD_NUMBER: builtins.int
+ fragment: builtins.str
+ """(Required) Name of the origin, for example, the name of the function"""
+ call_site: builtins.str
+ """(Required) Callsite to show to end users, for example, stacktrace."""
+ def __init__(
+ self,
+ *,
+ fragment: builtins.str = ...,
+ call_site: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["call_site", b"call_site",
"fragment", b"fragment"],
+ ) -> None: ...
+
+global___PythonOrigin = PythonOrigin
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 07e1a53f3608..521e15d6950b 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -30,10 +30,11 @@ _sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__pb2
+from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xde.\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,68 +46,70 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._serialized_options = (
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
)
- _EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 6087
- _EXPRESSION_WINDOW._serialized_start = 1645
- _EXPRESSION_WINDOW._serialized_end = 2428
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428
- _EXPRESSION_SORTORDER._serialized_start = 2431
- _EXPRESSION_SORTORDER._serialized_end = 2856
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856
- _EXPRESSION_CAST._serialized_start = 2859
- _EXPRESSION_CAST._serialized_end = 3174
- _EXPRESSION_CAST_EVALMODE._serialized_start = 3060
- _EXPRESSION_CAST_EVALMODE._serialized_end = 3158
- _EXPRESSION_LITERAL._serialized_start = 3177
- _EXPRESSION_LITERAL._serialized_end = 4740
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4012
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4129
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4131
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4229
- _EXPRESSION_LITERAL_ARRAY._serialized_start = 4232
- _EXPRESSION_LITERAL_ARRAY._serialized_end = 4362
- _EXPRESSION_LITERAL_MAP._serialized_start = 4365
- _EXPRESSION_LITERAL_MAP._serialized_end = 4592
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 4595
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 4724
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4743
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4929
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4932
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5136
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5138
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5188
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5190
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5314
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5316
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5402
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5405
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5537
- _EXPRESSION_UPDATEFIELDS._serialized_start = 5540
- _EXPRESSION_UPDATEFIELDS._serialized_end = 5727
- _EXPRESSION_ALIAS._serialized_start = 5729
- _EXPRESSION_ALIAS._serialized_end = 5849
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5852
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6010
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6012
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6074
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6090
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6454
- _PYTHONUDF._serialized_start = 6457
- _PYTHONUDF._serialized_end = 6612
- _SCALARSCALAUDF._serialized_start = 6615
- _SCALARSCALAUDF._serialized_end = 6829
- _JAVAUDF._serialized_start = 6832
- _JAVAUDF._serialized_end = 6981
- _CALLFUNCTION._serialized_start = 6983
- _CALLFUNCTION._serialized_end = 7091
- _NAMEDARGUMENTEXPRESSION._serialized_start = 7093
- _NAMEDARGUMENTEXPRESSION._serialized_end = 7185
+ _EXPRESSION._serialized_start = 133
+ _EXPRESSION._serialized_end = 6172
+ _EXPRESSION_WINDOW._serialized_start = 1730
+ _EXPRESSION_WINDOW._serialized_end = 2513
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2020
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2513
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2287
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2432
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2434
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2513
+ _EXPRESSION_SORTORDER._serialized_start = 2516
+ _EXPRESSION_SORTORDER._serialized_end = 2941
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2746
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2854
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2856
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2941
+ _EXPRESSION_CAST._serialized_start = 2944
+ _EXPRESSION_CAST._serialized_end = 3259
+ _EXPRESSION_CAST_EVALMODE._serialized_start = 3145
+ _EXPRESSION_CAST_EVALMODE._serialized_end = 3243
+ _EXPRESSION_LITERAL._serialized_start = 3262
+ _EXPRESSION_LITERAL._serialized_end = 4825
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4097
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4214
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4216
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4314
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 4317
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 4447
+ _EXPRESSION_LITERAL_MAP._serialized_start = 4450
+ _EXPRESSION_LITERAL_MAP._serialized_end = 4677
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 4680
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 4809
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4828
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5014
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5017
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5221
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5223
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5273
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5275
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5399
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5401
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5487
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5490
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5622
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 5625
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 5812
+ _EXPRESSION_ALIAS._serialized_start = 5814
+ _EXPRESSION_ALIAS._serialized_end = 5934
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5937
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6095
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6097
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6159
+ _EXPRESSIONCOMMON._serialized_start = 6174
+ _EXPRESSIONCOMMON._serialized_end = 6239
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6242
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6606
+ _PYTHONUDF._serialized_start = 6609
+ _PYTHONUDF._serialized_end = 6764
+ _SCALARSCALAUDF._serialized_start = 6767
+ _SCALARSCALAUDF._serialized_end = 6981
+ _JAVAUDF._serialized_start = 6984
+ _JAVAUDF._serialized_end = 7133
+ _CALLFUNCTION._serialized_start = 7135
+ _CALLFUNCTION._serialized_end = 7243
+ _NAMEDARGUMENTEXPRESSION._serialized_start = 7245
+ _NAMEDARGUMENTEXPRESSION._serialized_end = 7337
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 0209ba411616..eaf4059b2dbc 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
+import pyspark.sql.connect.proto.common_pb2
import pyspark.sql.connect.proto.types_pb2
import sys
import typing
@@ -1163,6 +1164,7 @@ class Expression(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["name_parts",
b"name_parts"]
) -> None: ...
+ COMMON_FIELD_NUMBER: builtins.int
LITERAL_FIELD_NUMBER: builtins.int
UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
UNRESOLVED_FUNCTION_FIELD_NUMBER: builtins.int
@@ -1182,6 +1184,8 @@ class Expression(google.protobuf.message.Message):
NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
+ def common(self) -> global___ExpressionCommon: ...
+ @property
def literal(self) -> global___Expression.Literal: ...
@property
def unresolved_attribute(self) -> global___Expression.UnresolvedAttribute:
...
@@ -1225,6 +1229,7 @@ class Expression(google.protobuf.message.Message):
def __init__(
self,
*,
+ common: global___ExpressionCommon | None = ...,
literal: global___Expression.Literal | None = ...,
unresolved_attribute: global___Expression.UnresolvedAttribute | None =
...,
unresolved_function: global___Expression.UnresolvedFunction | None =
...,
@@ -1254,6 +1259,8 @@ class Expression(google.protobuf.message.Message):
b"call_function",
"cast",
b"cast",
+ "common",
+ b"common",
"common_inline_user_defined_function",
b"common_inline_user_defined_function",
"expr_type",
@@ -1297,6 +1304,8 @@ class Expression(google.protobuf.message.Message):
b"call_function",
"cast",
b"cast",
+ "common",
+ b"common",
"common_inline_user_defined_function",
b"common_inline_user_defined_function",
"expr_type",
@@ -1359,6 +1368,25 @@ class Expression(google.protobuf.message.Message):
global___Expression = Expression
+class ExpressionCommon(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ ORIGIN_FIELD_NUMBER: builtins.int
+ @property
+ def origin(self) -> pyspark.sql.connect.proto.common_pb2.Origin:
+ """(Required) Keep the information of the origin for this expression
such as stacktrace."""
+ def __init__(
+ self,
+ *,
+ origin: pyspark.sql.connect.proto.common_pb2.Origin | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["origin", b"origin"]
+ ) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal["origin",
b"origin"]) -> None: ...
+
+global___ExpressionCommon = ExpressionCommon
+
class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 819b5d717c52..ecf07a4ea9d7 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -608,7 +608,9 @@ class Unknown(google.protobuf.message.Message):
global___Unknown = Unknown
class RelationCommon(google.protobuf.message.Message):
- """Common metadata of all relations."""
+ """Common metadata of all relations.
+ TODO(SPARK-48639): Add origin like Expression.ExpressionCommon
+ """
DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
index 38bcd5643984..59107363571e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
@@ -21,10 +21,8 @@ from pyspark.sql.tests.test_dataframe_query_context import
DataFrameQueryContext
from pyspark.testing.connectutils import ReusedConnectTestCase
-class DataFrameParityTests(DataFrameQueryContextTestsMixin, ReusedConnectTestCase):
- @unittest.skip("Spark Connect does not support DataFrameQueryContext currently.")
- def test_dataframe_query_context(self):
- super().test_dataframe_query_context()
+class DataFrameQueryContextParityTests(DataFrameQueryContextTestsMixin, ReusedConnectTestCase):
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py b/python/pyspark/sql/tests/test_dataframe_query_context.py
index 42fb0b0e452f..e1a3e33df859 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -41,7 +41,7 @@ class DataFrameQueryContextTestsMixin:
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="divide",
+ fragment="__truediv__",
)
# DataFrameQueryContext with pysparkLoggingInfo - plus
@@ -57,7 +57,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="plus",
+ fragment="__add__",
)
# DataFrameQueryContext with pysparkLoggingInfo - minus
@@ -73,7 +73,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="minus",
+ fragment="__sub__",
)
# DataFrameQueryContext with pysparkLoggingInfo - multiply
@@ -89,7 +89,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="multiply",
+ fragment="__mul__",
)
# DataFrameQueryContext with pysparkLoggingInfo - mod
@@ -105,7 +105,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="mod",
+ fragment="__mod__",
)
# DataFrameQueryContext with pysparkLoggingInfo - equalTo
@@ -121,7 +121,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="__eq__",
+ fragment="__eq__",
)
# DataFrameQueryContext with pysparkLoggingInfo - lt
@@ -137,7 +137,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="lt",
+ fragment="__lt__",
)
# DataFrameQueryContext with pysparkLoggingInfo - leq
@@ -153,7 +153,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="leq",
+ fragment="__le__",
)
# DataFrameQueryContext with pysparkLoggingInfo - geq
@@ -169,7 +169,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="geq",
+ fragment="__ge__",
)
# DataFrameQueryContext with pysparkLoggingInfo - gt
@@ -185,7 +185,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="gt",
+ fragment="__gt__",
)
# DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
@@ -201,7 +201,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="eqNullSafe",
+ fragment="eqNullSafe",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
@@ -217,7 +217,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="bitwiseOR",
+ fragment="bitwiseOR",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
@@ -233,7 +233,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="bitwiseAND",
+ fragment="bitwiseAND",
)
# DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
@@ -249,7 +249,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="bitwiseXOR",
+ fragment="bitwiseXOR",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained
(`divide` is problematic)
@@ -262,7 +262,7 @@ class DataFrameQueryContextTestsMixin:
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="divide",
+ fragment="__truediv__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`plus`
is problematic)
@@ -282,7 +282,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="plus",
+ fragment="__add__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`minus`
is problematic)
@@ -302,7 +302,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="minus",
+ fragment="__sub__",
)
# DataFrameQueryContext with pysparkLoggingInfo - chained
(`multiply` is problematic)
@@ -320,7 +320,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="multiply",
+ fragment="__mul__",
)
# Multiple expressions in df.select (`divide` is problematic)
@@ -331,7 +331,7 @@ class DataFrameQueryContextTestsMixin:
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="divide",
+ fragment="__truediv__",
)
# Multiple expressions in df.select (`plus` is problematic)
@@ -347,7 +347,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="plus",
+ fragment="__add__",
)
# Multiple expressions in df.select (`minus` is problematic)
@@ -363,7 +363,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="minus",
+ fragment="__sub__",
)
# Multiple expressions in df.select (`multiply` is problematic)
@@ -379,7 +379,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="multiply",
+ fragment="__mul__",
)
# Multiple expressions with pre-declared expressions (`divide` is
problematic)
@@ -392,7 +392,7 @@ class DataFrameQueryContextTestsMixin:
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="divide",
+ fragment="__truediv__",
)
# Multiple expressions with pre-declared expressions (`plus` is
problematic)
@@ -410,7 +410,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="plus",
+ fragment="__add__",
)
# Multiple expressions with pre-declared expressions (`minus` is
problematic)
@@ -428,7 +428,7 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="minus",
+ fragment="__sub__",
)
# Multiple expressions with pre-declared expressions (`multiply`
is problematic)
@@ -446,20 +446,11 @@ class DataFrameQueryContextTestsMixin:
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="multiply",
- )
-
- # DataFrameQueryContext without pysparkLoggingInfo
- with self.assertRaises(AnalysisException) as pe:
- df.select("non-existing-column")
- self.check_error(
- exception=pe.exception,
- error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
- message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
- query_context_type=QueryContextType.DataFrame,
- pyspark_fragment="",
+ fragment="__mul__",
)
+ def test_sql_query_context(self):
+ with self.sql_conf({"spark.sql.ansi.enabled": True}):
# SQLQueryContext
with self.assertRaises(ArithmeticException) as pe:
self.spark.sql("select 10/0").collect()
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index fa58b7286fe8..c74291524dae 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -287,7 +287,7 @@ class PySparkErrorTestUtils:
error_class: str,
message_parameters: Optional[Dict[str, str]] = None,
query_context_type: Optional[QueryContextType] = None,
- pyspark_fragment: Optional[str] = None,
+ fragment: Optional[str] = None,
):
query_context = exception.getQueryContext()
assert bool(query_context) == (query_context_type is not None), (
@@ -326,10 +326,10 @@ class PySparkErrorTestUtils:
)
if actual == QueryContextType.DataFrame:
assert (
- pyspark_fragment is not None
- ), "`pyspark_fragment` is required when QueryContextType is DataFrame."
- expected = pyspark_fragment
- actual = actual_context.pysparkFragment()
+ fragment is not None
+ ), "`fragment` is required when QueryContextType is DataFrame."
+ expected = fragment
+ actual = actual_context.fragment()
self.assertEqual(
expected,
actual,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index 1c2456f00bcd..2b3f4674539e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -145,36 +145,30 @@ case class DataFrameQueryContext(
override def stopIndex: Int = throw SparkUnsupportedOperationException()
override val fragment: String = {
- stackTrace.headOption.map { firstElem =>
- val methodName = firstElem.getMethodName
- if (methodName.length > 1 && methodName(0) == '$') {
- methodName.substring(1)
- } else {
- methodName
- }
- }.getOrElse("")
+ pysparkErrorContext.map(_._1).getOrElse {
+ stackTrace.headOption.map { firstElem =>
+ val methodName = firstElem.getMethodName
+ if (methodName.length > 1 && methodName(0) == '$') {
+ methodName.substring(1)
+ } else {
+ methodName
+ }
+ }.getOrElse("")
+ }
}
- override val callSite: String = stackTrace.tail.mkString("\n")
-
- val pysparkFragment: String = pysparkErrorContext.map(_._1).getOrElse("")
- val pysparkCallSite: String = pysparkErrorContext.map(_._2).getOrElse("")
-
- val (displayedFragment, displayedCallsite) = if (pysparkErrorContext.nonEmpty) {
- (pysparkFragment, pysparkCallSite)
- } else {
- (fragment, callSite)
- }
+ override val callSite: String = pysparkErrorContext.map(
+ _._2).getOrElse(stackTrace.tail.mkString("\n"))
override lazy val summary: String = {
val builder = new StringBuilder
builder ++= "== DataFrame ==\n"
builder ++= "\""
- builder ++= displayedFragment
+ builder ++= fragment
builder ++= "\""
builder ++= " was called from\n"
- builder ++= displayedCallsite
+ builder ++= callSite
builder += '\n'
builder.result()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]