This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 80bba4463eba [SPARK-48459][CONNECT][PYTHON] Implement 
DataFrameQueryContext in Spark Connect
80bba4463eba is described below

commit 80bba4463eba29a56cdd90642f0681c3710ce87c
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jun 18 17:25:32 2024 +0900

    [SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spark 
Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement DataFrameQueryContext in Spark Connect.
    
    1.  Add new protobuf messages `Origin` and `PythonOrigin`, attached to each `Expression` through a new `ExpressionCommon` message:
    
        ```proto
        message Origin {
          // (Required) Indicate the origin type.
          oneof function {
            PythonOrigin python_origin = 1;
          }
        }
    
        message PythonOrigin {
          // (Required) Name of the origin, for example, the name of the function
          string fragment = 1;
    
          // (Required) Callsite to show to end users, for example, stacktrace.
          string call_site = 2;
        }
        ```
    
    2. Merge `DataFrameQueryContext.pysparkFragment` and `DataFrameQueryContext.pysparkCallSite` into the existing `DataFrameQueryContext.fragment` and `DataFrameQueryContext.callSite`
    
    3. Separate `QueryContext` into `SQLQueryContext` and `DataFrameQueryContext` for consistency with the Scala side
    
    4. Implement the origin logic. The `current_origin` thread-local holds the current call site and function name, and each `Expression` reads them when it is created.
        They are set on the individual expression messages and used when analysis happens on the server, which mirrors the Spark SQL implementation.
    
    See also https://github.com/apache/spark/pull/45377.
    
    ### Why are the changes needed?
    
    See https://github.com/apache/spark/pull/45377
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, same as https://github.com/apache/spark/pull/45377 but in Spark 
Connect.
    
    ### How was this patch tested?
    
    The same unit tests are reused for Spark Connect.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46789 from HyukjinKwon/connect-context.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../src/main/protobuf/spark/connect/common.proto   |  15 +++
 .../main/protobuf/spark/connect/expressions.proto  |   7 ++
 .../main/protobuf/spark/connect/relations.proto    |   2 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  19 ++-
 .../spark/sql/connect/utils/ErrorUtils.scala       |  22 +++-
 python/pyspark/errors/exceptions/captured.py       |  51 +++++---
 python/pyspark/errors/exceptions/connect.py        |  83 +++++++++++--
 python/pyspark/errors/utils.py                     |  72 +++++++----
 python/pyspark/sql/connect/column.py               |   2 +
 python/pyspark/sql/connect/expressions.py          |  58 +++++----
 python/pyspark/sql/connect/proto/common_pb2.py     |   6 +-
 python/pyspark/sql/connect/proto/common_pb2.pyi    |  51 ++++++++
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 133 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  28 +++++
 python/pyspark/sql/connect/proto/relations_pb2.pyi |   4 +-
 .../connect/test_parity_dataframe_query_context.py |   6 +-
 .../sql/tests/test_dataframe_query_context.py      |  65 +++++-----
 python/pyspark/testing/utils.py                    |  10 +-
 .../spark/sql/catalyst/trees/QueryContexts.scala   |  34 +++---
 19 files changed, 463 insertions(+), 205 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/common.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/common.proto
index da334bfd9ee8..b2848370b01d 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/common.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/common.proto
@@ -81,3 +81,18 @@ message ResourceProfile {
   // (e.g., cores, memory, CPU) to its specific request.
   map<string, TaskResourceRequest> task_resources = 2;
 }
+
+message Origin {
+  // (Required) Indicate the origin type.
+  oneof function {
+    PythonOrigin python_origin = 1;
+  }
+}
+
+message PythonOrigin {
+  // (Required) Name of the origin, for example, the name of the function
+  string fragment = 1;
+
+  // (Required) Callsite to show to end users, for example, stacktrace.
+  string call_site = 2;
+}
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index eb4be1586005..404a2fdcb2e8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -19,6 +19,7 @@ syntax = 'proto3';
 
 import "google/protobuf/any.proto";
 import "spark/connect/types.proto";
+import "spark/connect/common.proto";
 
 package spark.connect;
 
@@ -30,6 +31,7 @@ option go_package = "internal/generated";
 // expressions in SQL appear.
 message Expression {
 
+  ExpressionCommon common = 18;
   oneof expr_type {
     Literal literal = 1;
     UnresolvedAttribute unresolved_attribute = 2;
@@ -342,6 +344,11 @@ message Expression {
   }
 }
 
+message ExpressionCommon {
+  // (Required) Keep the information of the origin for this expression such as 
stacktrace.
+  Origin origin = 1;
+}
+
 message CommonInlineUserDefinedFunction {
   // (Required) Name of the user-defined function.
   string function_name = 1;
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 0c4ca6290a76..ba1a633b0e61 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -106,6 +106,8 @@ message Unknown {}
 
 // Common metadata of all relations.
 message RelationCommon {
+  // TODO(SPARK-48639): Add origin like Expression.ExpressionCommon
+
   // (Required) Shared relation metadata.
   string source_info = 1;
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 453d2b30876f..a7fc87d8b65d 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, 
TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, 
Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, 
ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, 
Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, 
CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, 
DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, 
LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, 
Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, 
Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, 
CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, 
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,21 @@ class SparkConnectPlanner(
    *   Catalyst expression
    */
   @DeveloperApi
-  def transformExpression(exp: proto.Expression): Expression = {
+  def transformExpression(exp: proto.Expression): Expression = if 
(exp.hasCommon) {
+    try {
+      val origin = exp.getCommon.getOrigin
+      PySparkCurrentOrigin.set(
+        origin.getPythonOrigin.getFragment,
+        origin.getPythonOrigin.getCallSite)
+      withOrigin { doTransformExpression(exp) }
+    } finally {
+      PySparkCurrentOrigin.clear()
+    }
+  } else {
+    doTransformExpression(exp)
+  }
+
+  private def doTransformExpression(exp: proto.Expression): Expression = {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => 
transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 773a97e92973..355048cf3036 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods
 
-import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.{QueryContextType, SparkEnv, SparkException, 
SparkThrowable}
 import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
 import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
             sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
           }
           for (queryCtx <- sparkThrowable.getQueryContext) {
-            sparkThrowableBuilder.addQueryContexts(
-              FetchErrorDetailsResponse.QueryContext
-                .newBuilder()
+            val builder = FetchErrorDetailsResponse.QueryContext
+              .newBuilder()
+            val context = if (queryCtx.contextType() == QueryContextType.SQL) {
+              builder
+                
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
                 .setObjectType(queryCtx.objectType())
                 .setObjectName(queryCtx.objectName())
                 .setStartIndex(queryCtx.startIndex())
                 .setStopIndex(queryCtx.stopIndex())
                 .setFragment(queryCtx.fragment())
-                .build())
+                .setSummary(queryCtx.summary())
+                .build()
+            } else {
+              builder
+                
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
+                .setFragment(queryCtx.fragment())
+                .setCallSite(queryCtx.callSite())
+                .setSummary(queryCtx.summary())
+                .build()
+            }
+            sparkThrowableBuilder.addQueryContexts(context)
           }
           if (sparkThrowable.getSqlState != null) {
             sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
diff --git a/python/pyspark/errors/exceptions/captured.py 
b/python/pyspark/errors/exceptions/captured.py
index 2a30eba3fb22..b5bb742161c0 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -166,7 +166,14 @@ class CapturedException(PySparkException):
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            return [QueryContext(q) for q in self._origin.getQueryContext()]
+            contexts: List[BaseQueryContext] = []
+            for q in self._origin.getQueryContext():
+                if q.contextType().toString() == "SQL":
+                    contexts.append(SQLQueryContext(q))
+                else:
+                    contexts.append(DataFrameQueryContext(q))
+
+            return contexts
         else:
             return []
 
@@ -379,17 +386,12 @@ class UnknownException(CapturedException, 
BaseUnknownException):
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
     def __init__(self, q: "JavaObject"):
         self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.contextType().toString()
-        assert context_type in ("SQL", "DataFrame")
-        if context_type == "DataFrame":
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.objectType())
@@ -409,13 +411,34 @@ class QueryContext(BaseQueryContext):
     def callSite(self) -> str:
         return str(self._q.callSite())
 
-    def pysparkFragment(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkFragment())
+    def summary(self) -> str:
+        return str(self._q.summary())
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: "JavaObject"):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        return str(self._q.objectType())
+
+    def objectName(self) -> str:
+        return str(self._q.objectName())
 
-    def pysparkCallSite(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkCallSite())
+    def startIndex(self) -> int:
+        return int(self._q.startIndex())
+
+    def stopIndex(self) -> int:
+        return int(self._q.stopIndex())
+
+    def fragment(self) -> str:
+        return str(self._q.fragment())
+
+    def callSite(self) -> str:
+        return str(self._q.callSite())
 
     def summary(self) -> str:
         return str(self._q.summary())
diff --git a/python/pyspark/errors/exceptions/connect.py 
b/python/pyspark/errors/exceptions/connect.py
index 0cffe7268753..8a95358f2697 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -91,7 +91,10 @@ def convert_exception(
         )
         query_contexts = []
         for query_context in 
resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
-            query_contexts.append(QueryContext(query_context))
+            if query_context.context_type == 
pb2.FetchErrorDetailsResponse.QueryContext.SQL:
+                query_contexts.append(SQLQueryContext(query_context))
+            else:
+                query_contexts.append(DataFrameQueryContext(query_context))
 
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
         return ParseException(
@@ -430,17 +433,12 @@ class 
SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
     def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
         self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.context_type
-
-        if int(context_type) == QueryContextType.DataFrame.value:
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.object_type)
@@ -457,6 +455,75 @@ class QueryContext(BaseQueryContext):
     def fragment(self) -> str:
         return str(self._q.fragment)
 
+    def callSite(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "SQLQueryContext", "methodName": 
"callSite"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def summary(self) -> str:
+        return str(self._q.summary)
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", 
"methodName": "objectType"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def objectName(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", 
"methodName": "objectName"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def startIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", 
"methodName": "startIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def stopIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", 
"methodName": "stopIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def fragment(self) -> str:
+        return str(self._q.fragment)
+
     def callSite(self) -> str:
         return str(self._q.call_site)
 
diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index cddec3319964..cd3046380284 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -19,16 +19,34 @@ import re
 import functools
 import inspect
 import os
-from typing import Any, Callable, Dict, Match, TypeVar, Type, TYPE_CHECKING
+import threading
+from typing import Any, Callable, Dict, Match, TypeVar, Type, Optional, 
TYPE_CHECKING
 from pyspark.errors.error_classes import ERROR_CLASSES_MAP
 
-
 if TYPE_CHECKING:
     from pyspark.sql import SparkSession
-    from py4j.java_gateway import JavaClass
 
 T = TypeVar("T")
 
+_current_origin = threading.local()
+
+
+def current_origin() -> threading.local:
+    global _current_origin
+
+    if not hasattr(_current_origin, "fragment"):
+        _current_origin.fragment = None
+    if not hasattr(_current_origin, "call_site"):
+        _current_origin.call_site = None
+    return _current_origin
+
+
+def set_current_origin(fragment: Optional[str], call_site: Optional[str]) -> 
None:
+    global _current_origin
+
+    _current_origin.fragment = fragment
+    _current_origin.call_site = call_site
+
 
 class ErrorClassesReader:
     """
@@ -130,9 +148,7 @@ class ErrorClassesReader:
         return message_template
 
 
-def _capture_call_site(
-    spark_session: "SparkSession", pyspark_origin: "JavaClass", fragment: str
-) -> None:
+def _capture_call_site(spark_session: "SparkSession", depth: int) -> str:
     """
     Capture the call site information including file name, line number, and 
function name.
     This function updates the thread-local storage from JVM side 
(PySparkCurrentOrigin)
@@ -142,10 +158,6 @@ def _capture_call_site(
     ----------
     spark_session : SparkSession
         Current active Spark session.
-    pyspark_origin : py4j.JavaClass
-        PySparkCurrentOrigin from current active Spark session.
-    fragment : str
-        The name of the PySpark API function being captured.
 
     Notes
     -----
@@ -153,14 +165,11 @@ def _capture_call_site(
     in the user code that led to the error.
     """
     stack = list(reversed(inspect.stack()))
-    depth = int(
-        spark_session.conf.get("spark.sql.stackTracesInDataFrameContext")  # 
type: ignore[arg-type]
-    )
     selected_frames = stack[:depth]
     call_sites = [f"{frame.filename}:{frame.lineno}" for frame in 
selected_frames]
     call_sites_str = "\n".join(call_sites)
 
-    pyspark_origin.set(fragment, call_sites_str)
+    return call_sites_str
 
 
 def _with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -172,19 +181,38 @@ def _with_origin(func: Callable[..., Any]) -> 
Callable[..., Any]:
     @functools.wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         from pyspark.sql import SparkSession
+        from pyspark.sql.utils import is_remote
 
         spark = SparkSession.getActiveSession()
         if spark is not None and hasattr(func, "__name__"):
-            assert spark._jvm is not None
-            pyspark_origin = 
spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+            if is_remote():
+                global current_origin
 
-            # Update call site when the function is called
-            _capture_call_site(spark, pyspark_origin, func.__name__)
+                # Getting the configuration requires RPC call. Uses the 
default value for now.
+                depth = 1
+                set_current_origin(func.__name__, _capture_call_site(spark, 
depth))
 
-            try:
-                return func(*args, **kwargs)
-            finally:
-                pyspark_origin.clear()
+                try:
+                    return func(*args, **kwargs)
+                finally:
+                    set_current_origin(None, None)
+            else:
+                assert spark._jvm is not None
+                jvm_pyspark_origin = (
+                    
spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+                )
+                depth = int(
+                    spark.conf.get(  # type: ignore[arg-type]
+                        "spark.sql.stackTracesInDataFrameContext"
+                    )
+                )
+                # Update call site when the function is called
+                jvm_pyspark_origin.set(func.__name__, 
_capture_call_site(spark, depth))
+
+                try:
+                    return func(*args, **kwargs)
+                finally:
+                    jvm_pyspark_origin.clear()
         else:
             return func(*args, **kwargs)
 
diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index c38717afccda..b63e06bccae1 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -46,6 +46,7 @@ from pyspark.sql.connect.expressions import (
     WithField,
     DropField,
 )
+from pyspark.errors.utils import with_origin_to_class
 
 
 if TYPE_CHECKING:
@@ -95,6 +96,7 @@ def _unary_op(name: str, self: ParentColumn) -> ParentColumn:
     return Column(UnresolvedFunction(name, [self._expr]))  # type: 
ignore[list-item]
 
 
+@with_origin_to_class
 class Column(ParentColumn):
     def __new__(
         cls,
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index e4e31ad60034..c10bef56c3b8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -77,6 +76,8 @@ from pyspark.sql.connect.types import (
     proto_schema_to_pyspark_data_type,
 )
 from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import current_origin
+from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
@@ -89,7 +90,16 @@ class Expression:
     """
 
     def __init__(self) -> None:
-        pass
+        origin = current_origin()
+        fragment = origin.fragment
+        call_site = origin.call_site
+        self.origin = None
+        if fragment is not None and call_site is not None:
+            self.origin = proto.Origin(
+                python_origin=proto.PythonOrigin(
+                    fragment=origin.fragment, call_site=origin.call_site
+                )
+            )
 
     def to_plan(  # type: ignore[empty-body]
         self, session: "SparkConnectClient"
@@ -112,6 +122,12 @@ class Expression:
     def name(self) -> str:  # type: ignore[empty-body]
         ...
 
+    def _create_proto_expression(self) -> proto.Expression:
+        plan = proto.Expression()
+        if self.origin is not None:
+            plan.common.origin.CopyFrom(self.origin)
+        return plan
+
 
 class CaseWhen(Expression):
     def __init__(
@@ -162,7 +178,7 @@ class ColumnAlias(Expression):
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         if len(self._alias) == 1:
-            exp = proto.Expression()
+            exp = self._create_proto_expression()
             exp.alias.name.append(self._alias[0])
             exp.alias.expr.CopyFrom(self._child.to_plan(session))
 
@@ -175,7 +191,7 @@ class ColumnAlias(Expression):
                     error_class="CANNOT_PROVIDE_METADATA",
                     message_parameters={},
                 )
-            exp = proto.Expression()
+            exp = self._create_proto_expression()
             exp.alias.name.extend(self._alias)
             exp.alias.expr.CopyFrom(self._child.to_plan(session))
             return exp
@@ -407,7 +423,7 @@ class LiteralExpression(Expression):
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         """Converts the literal expression to the literal in proto."""
 
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
 
         if self._value is None:
             
expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
@@ -483,7 +499,7 @@ class ColumnReference(Expression):
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         """Returns the Proto representation of the expression."""
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.unresolved_attribute.unparsed_identifier = 
self._unparsed_identifier
         if self._plan_id is not None:
             expr.unresolved_attribute.plan_id = self._plan_id
@@ -512,7 +528,7 @@ class UnresolvedStar(Expression):
         self._plan_id = plan_id
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.unresolved_star.SetInParent()
         if self._unparsed_target is not None:
             expr.unresolved_star.unparsed_target = self._unparsed_target
@@ -546,7 +562,7 @@ class SQLExpression(Expression):
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         """Returns the Proto representation of the SQL expression."""
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.expression_string.expression = self._expr
         return expr
 
@@ -572,7 +588,7 @@ class SortOrder(Expression):
         )
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        sort = proto.Expression()
+        sort = self._create_proto_expression()
         sort.sort_order.child.CopyFrom(self._child.to_plan(session))
 
         if self._ascending:
@@ -611,7 +627,7 @@ class UnresolvedFunction(Expression):
         self._is_distinct = is_distinct
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        fun = proto.Expression()
+        fun = self._create_proto_expression()
         fun.unresolved_function.function_name = self._name
         if len(self._args) > 0:
             fun.unresolved_function.arguments.extend([arg.to_plan(session) for 
arg in self._args])
@@ -708,7 +724,7 @@ class CommonInlineUserDefinedFunction(Expression):
         self._function = function
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.common_inline_user_defined_function.function_name = 
self._function_name
         expr.common_inline_user_defined_function.deterministic = 
self._deterministic
         if len(self._arguments) > 0:
@@ -762,7 +778,7 @@ class WithField(Expression):
         self._valueExpr = valueExpr
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
         expr.update_fields.field_name = self._fieldName
         
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
@@ -787,7 +803,7 @@ class DropField(Expression):
         self._fieldName = fieldName
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
         expr.update_fields.field_name = self._fieldName
         return expr
@@ -811,7 +827,7 @@ class UnresolvedExtractValue(Expression):
         self._extraction = extraction
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         
expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
         
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
         return expr
@@ -831,7 +847,7 @@ class UnresolvedRegex(Expression):
         self._plan_id = plan_id
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.unresolved_regex.col_name = self.col_name
         if self._plan_id is not None:
             expr.unresolved_regex.plan_id = self._plan_id
@@ -858,7 +874,7 @@ class CastExpression(Expression):
         self._eval_mode = eval_mode
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        fun = proto.Expression()
+        fun = self._create_proto_expression()
         fun.cast.expr.CopyFrom(self._expr.to_plan(session))
         if isinstance(self._data_type, str):
             fun.cast.type_str = self._data_type
@@ -909,7 +925,7 @@ class UnresolvedNamedLambdaVariable(Expression):
         self._name_parts = name_parts
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         
expr.unresolved_named_lambda_variable.name_parts.extend(self._name_parts)
         return expr
 
@@ -951,7 +967,7 @@ class LambdaFunction(Expression):
         self._arguments = arguments
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.lambda_function.function.CopyFrom(self._function.to_plan(session))
         expr.lambda_function.arguments.extend(
             [arg.to_plan(session).unresolved_named_lambda_variable for arg in 
self._arguments]
@@ -984,7 +1000,7 @@ class WindowExpression(Expression):
         self._windowSpec = windowSpec
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
 
         
expr.window.window_function.CopyFrom(self._windowFunction.to_plan(session))
 
@@ -1091,7 +1107,7 @@ class CallFunction(Expression):
         self._args = args
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.call_function.function_name = self._name
         if len(self._args) > 0:
             expr.call_function.arguments.extend([arg.to_plan(session) for arg 
in self._args])
@@ -1115,7 +1131,7 @@ class NamedArgumentExpression(Expression):
         self._value = value
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
-        expr = proto.Expression()
+        expr = self._create_proto_expression()
         expr.named_argument_expression.key = self._key
         
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
         return expr
diff --git a/python/pyspark/sql/connect/proto/common_pb2.py 
b/python/pyspark/sql/connect/proto/common_pb2.py
index a77d1463e51d..fd528fae3369 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.py
+++ b/python/pyspark/sql/connect/proto/common_pb2.py
@@ -29,7 +29,7 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01
 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 
\n\x0cuse_off_heap\x18\x03 
\x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 
\x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 
\x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 
\x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
+    
b'\n\x1aspark/connect/common.proto\x12\rspark.connect"\xb0\x01\n\x0cStorageLevel\x12\x19\n\x08use_disk\x18\x01
 \x01(\x08R\x07useDisk\x12\x1d\n\nuse_memory\x18\x02 \x01(\x08R\tuseMemory\x12 
\n\x0cuse_off_heap\x18\x03 
\x01(\x08R\nuseOffHeap\x12"\n\x0c\x64\x65serialized\x18\x04 
\x01(\x08R\x0c\x64\x65serialized\x12 \n\x0breplication\x18\x05 
\x01(\x05R\x0breplication"G\n\x13ResourceInformation\x12\x12\n\x04name\x18\x01 
\x01(\tR\x04name\x12\x1c\n\taddresses\x18\x02 \x03(\tR\taddresses"\xc3 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -59,4 +59,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RESOURCEPROFILE_EXECUTORRESOURCESENTRY._serialized_end = 899
     _RESOURCEPROFILE_TASKRESOURCESENTRY._serialized_start = 901
     _RESOURCEPROFILE_TASKRESOURCESENTRY._serialized_end = 1001
+    _ORIGIN._serialized_start = 1003
+    _ORIGIN._serialized_end = 1091
+    _PYTHONORIGIN._serialized_start = 1093
+    _PYTHONORIGIN._serialized_end = 1164
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/common_pb2.pyi 
b/python/pyspark/sql/connect/proto/common_pb2.pyi
index 163781b41998..eda172e26cf4 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/common_pb2.pyi
@@ -296,3 +296,54 @@ class ResourceProfile(google.protobuf.message.Message):
     ) -> None: ...
 
 global___ResourceProfile = ResourceProfile
+
+class Origin(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    PYTHON_ORIGIN_FIELD_NUMBER: builtins.int
+    @property
+    def python_origin(self) -> global___PythonOrigin: ...
+    def __init__(
+        self,
+        *,
+        python_origin: global___PythonOrigin | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "function", b"function", "python_origin", b"python_origin"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "function", b"function", "python_origin", b"python_origin"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["function", b"function"]
+    ) -> typing_extensions.Literal["python_origin"] | None: ...
+
+global___Origin = Origin
+
+class PythonOrigin(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    FRAGMENT_FIELD_NUMBER: builtins.int
+    CALL_SITE_FIELD_NUMBER: builtins.int
+    fragment: builtins.str
+    """(Required) Name of the origin, for example, the name of the function"""
+    call_site: builtins.str
+    """(Required) Callsite to show to end users, for example, stacktrace."""
+    def __init__(
+        self,
+        *,
+        fragment: builtins.str = ...,
+        call_site: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal["call_site", b"call_site", 
"fragment", b"fragment"],
+    ) -> None: ...
+
+global___PythonOrigin = PythonOrigin
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 07e1a53f3608..521e15d6950b 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -30,10 +30,11 @@ _sym_db = _symbol_database.Default()
 
 from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__pb2
+from pyspark.sql.connect.proto import common_pb2 as 
spark_dot_connect_dot_common__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xde.\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,68 +46,70 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._serialized_options = (
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
-    _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 6087
-    _EXPRESSION_WINDOW._serialized_start = 1645
-    _EXPRESSION_WINDOW._serialized_end = 2428
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428
-    _EXPRESSION_SORTORDER._serialized_start = 2431
-    _EXPRESSION_SORTORDER._serialized_end = 2856
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856
-    _EXPRESSION_CAST._serialized_start = 2859
-    _EXPRESSION_CAST._serialized_end = 3174
-    _EXPRESSION_CAST_EVALMODE._serialized_start = 3060
-    _EXPRESSION_CAST_EVALMODE._serialized_end = 3158
-    _EXPRESSION_LITERAL._serialized_start = 3177
-    _EXPRESSION_LITERAL._serialized_end = 4740
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4012
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4129
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4131
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4229
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4232
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4362
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4365
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4592
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4595
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4724
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4743
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4929
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4932
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5136
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5138
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5188
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5190
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5314
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5316
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5402
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5405
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5537
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5540
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5727
-    _EXPRESSION_ALIAS._serialized_start = 5729
-    _EXPRESSION_ALIAS._serialized_end = 5849
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5852
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6010
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6012
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6074
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6090
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6454
-    _PYTHONUDF._serialized_start = 6457
-    _PYTHONUDF._serialized_end = 6612
-    _SCALARSCALAUDF._serialized_start = 6615
-    _SCALARSCALAUDF._serialized_end = 6829
-    _JAVAUDF._serialized_start = 6832
-    _JAVAUDF._serialized_end = 6981
-    _CALLFUNCTION._serialized_start = 6983
-    _CALLFUNCTION._serialized_end = 7091
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 7093
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 7185
+    _EXPRESSION._serialized_start = 133
+    _EXPRESSION._serialized_end = 6172
+    _EXPRESSION_WINDOW._serialized_start = 1730
+    _EXPRESSION_WINDOW._serialized_end = 2513
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2020
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2513
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2287
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2432
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2434
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2513
+    _EXPRESSION_SORTORDER._serialized_start = 2516
+    _EXPRESSION_SORTORDER._serialized_end = 2941
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2746
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2854
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2856
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2941
+    _EXPRESSION_CAST._serialized_start = 2944
+    _EXPRESSION_CAST._serialized_end = 3259
+    _EXPRESSION_CAST_EVALMODE._serialized_start = 3145
+    _EXPRESSION_CAST_EVALMODE._serialized_end = 3243
+    _EXPRESSION_LITERAL._serialized_start = 3262
+    _EXPRESSION_LITERAL._serialized_end = 4825
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4097
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4214
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4216
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4314
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4317
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4447
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4450
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4677
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4680
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4809
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4828
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5014
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5017
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5221
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5223
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5273
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5275
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5399
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5401
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5487
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5490
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5622
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5625
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5812
+    _EXPRESSION_ALIAS._serialized_start = 5814
+    _EXPRESSION_ALIAS._serialized_end = 5934
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5937
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6095
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6097
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6159
+    _EXPRESSIONCOMMON._serialized_start = 6174
+    _EXPRESSIONCOMMON._serialized_end = 6239
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6242
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6606
+    _PYTHONUDF._serialized_start = 6609
+    _PYTHONUDF._serialized_end = 6764
+    _SCALARSCALAUDF._serialized_start = 6767
+    _SCALARSCALAUDF._serialized_end = 6981
+    _JAVAUDF._serialized_start = 6984
+    _JAVAUDF._serialized_end = 7133
+    _CALLFUNCTION._serialized_start = 7135
+    _CALLFUNCTION._serialized_end = 7243
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 7245
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 7337
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 0209ba411616..eaf4059b2dbc 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.descriptor
 import google.protobuf.internal.containers
 import google.protobuf.internal.enum_type_wrapper
 import google.protobuf.message
+import pyspark.sql.connect.proto.common_pb2
 import pyspark.sql.connect.proto.types_pb2
 import sys
 import typing
@@ -1163,6 +1164,7 @@ class Expression(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["name_parts", 
b"name_parts"]
         ) -> None: ...
 
+    COMMON_FIELD_NUMBER: builtins.int
     LITERAL_FIELD_NUMBER: builtins.int
     UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
     UNRESOLVED_FUNCTION_FIELD_NUMBER: builtins.int
@@ -1182,6 +1184,8 @@ class Expression(google.protobuf.message.Message):
     NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
+    def common(self) -> global___ExpressionCommon: ...
+    @property
     def literal(self) -> global___Expression.Literal: ...
     @property
     def unresolved_attribute(self) -> global___Expression.UnresolvedAttribute: 
...
@@ -1225,6 +1229,7 @@ class Expression(google.protobuf.message.Message):
     def __init__(
         self,
         *,
+        common: global___ExpressionCommon | None = ...,
         literal: global___Expression.Literal | None = ...,
         unresolved_attribute: global___Expression.UnresolvedAttribute | None = 
...,
         unresolved_function: global___Expression.UnresolvedFunction | None = 
...,
@@ -1254,6 +1259,8 @@ class Expression(google.protobuf.message.Message):
             b"call_function",
             "cast",
             b"cast",
+            "common",
+            b"common",
             "common_inline_user_defined_function",
             b"common_inline_user_defined_function",
             "expr_type",
@@ -1297,6 +1304,8 @@ class Expression(google.protobuf.message.Message):
             b"call_function",
             "cast",
             b"cast",
+            "common",
+            b"common",
             "common_inline_user_defined_function",
             b"common_inline_user_defined_function",
             "expr_type",
@@ -1359,6 +1368,25 @@ class Expression(google.protobuf.message.Message):
 
 global___Expression = Expression
 
+class ExpressionCommon(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    ORIGIN_FIELD_NUMBER: builtins.int
+    @property
+    def origin(self) -> pyspark.sql.connect.proto.common_pb2.Origin:
+        """(Required) Keep the information of the origin for this expression 
such as stacktrace."""
+    def __init__(
+        self,
+        *,
+        origin: pyspark.sql.connect.proto.common_pb2.Origin | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["origin", b"origin"]
+    ) -> builtins.bool: ...
+    def ClearField(self, field_name: typing_extensions.Literal["origin", 
b"origin"]) -> None: ...
+
+global___ExpressionCommon = ExpressionCommon
+
 class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi 
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 819b5d717c52..ecf07a4ea9d7 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -608,7 +608,9 @@ class Unknown(google.protobuf.message.Message):
 global___Unknown = Unknown
 
 class RelationCommon(google.protobuf.message.Message):
-    """Common metadata of all relations."""
+    """Common metadata of all relations.
+    TODO(SPARK-48639): Add origin like Expression.ExpressionCommon
+    """
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
diff --git 
a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
index 38bcd5643984..59107363571e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe_query_context.py
@@ -21,10 +21,8 @@ from pyspark.sql.tests.test_dataframe_query_context import 
DataFrameQueryContext
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class DataFrameParityTests(DataFrameQueryContextTestsMixin, 
ReusedConnectTestCase):
-    @unittest.skip("Spark Connect does not support DataFrameQueryContext 
currently.")
-    def test_dataframe_query_context(self):
-        super().test_dataframe_query_context()
+class DataFrameQueryContextParityTests(DataFrameQueryContextTestsMixin, 
ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py 
b/python/pyspark/sql/tests/test_dataframe_query_context.py
index 42fb0b0e452f..e1a3e33df859 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -41,7 +41,7 @@ class DataFrameQueryContextTestsMixin:
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - plus
@@ -57,7 +57,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - minus
@@ -73,7 +73,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - multiply
@@ -89,7 +89,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - mod
@@ -105,7 +105,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="mod",
+                fragment="__mod__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - equalTo
@@ -121,7 +121,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="__eq__",
+                fragment="__eq__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - lt
@@ -137,7 +137,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="lt",
+                fragment="__lt__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - leq
@@ -153,7 +153,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="leq",
+                fragment="__le__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - geq
@@ -169,7 +169,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="geq",
+                fragment="__ge__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - gt
@@ -185,7 +185,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="gt",
+                fragment="__gt__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
@@ -201,7 +201,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="eqNullSafe",
+                fragment="eqNullSafe",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
@@ -217,7 +217,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseOR",
+                fragment="bitwiseOR",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
@@ -233,7 +233,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseAND",
+                fragment="bitwiseAND",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
@@ -249,7 +249,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="bitwiseXOR",
+                fragment="bitwiseXOR",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`divide` is problematic)
@@ -262,7 +262,7 @@ class DataFrameQueryContextTestsMixin:
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` 
is problematic)
@@ -282,7 +282,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` 
is problematic)
@@ -302,7 +302,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # DataFrameQueryContext with pysparkLoggingInfo - chained 
(`multiply` is problematic)
@@ -320,7 +320,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # Multiple expressions in df.select (`divide` is problematic)
@@ -331,7 +331,7 @@ class DataFrameQueryContextTestsMixin:
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # Multiple expressions in df.select (`plus` is problematic)
@@ -347,7 +347,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # Multiple expressions in df.select (`minus` is problematic)
@@ -363,7 +363,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # Multiple expressions in df.select (`multiply` is problematic)
@@ -379,7 +379,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
+                fragment="__mul__",
             )
 
             # Multiple expressions with pre-declared expressions (`divide` is 
problematic)
@@ -392,7 +392,7 @@ class DataFrameQueryContextTestsMixin:
                 error_class="DIVIDE_BY_ZERO",
                 message_parameters={"config": '"spark.sql.ansi.enabled"'},
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="divide",
+                fragment="__truediv__",
             )
 
             # Multiple expressions with pre-declared expressions (`plus` is 
problematic)
@@ -410,7 +410,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="plus",
+                fragment="__add__",
             )
 
             # Multiple expressions with pre-declared expressions (`minus` is 
problematic)
@@ -428,7 +428,7 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="minus",
+                fragment="__sub__",
             )
 
             # Multiple expressions with pre-declared expressions (`multiply` 
is problematic)
@@ -446,20 +446,11 @@ class DataFrameQueryContextTestsMixin:
                     "ansiConfig": '"spark.sql.ansi.enabled"',
                 },
                 query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="multiply",
-            )
-
-            # DataFrameQueryContext without pysparkLoggingInfo
-            with self.assertRaises(AnalysisException) as pe:
-                df.select("non-existing-column")
-            self.check_error(
-                exception=pe.exception,
-                error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                message_parameters={"objectName": "`non-existing-column`", 
"proposal": "`id`"},
-                query_context_type=QueryContextType.DataFrame,
-                pyspark_fragment="",
+                fragment="__mul__",
             )
 
+    def test_sql_query_context(self):
+        with self.sql_conf({"spark.sql.ansi.enabled": True}):
             # SQLQueryContext
             with self.assertRaises(ArithmeticException) as pe:
                 self.spark.sql("select 10/0").collect()
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index fa58b7286fe8..c74291524dae 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -287,7 +287,7 @@ class PySparkErrorTestUtils:
         error_class: str,
         message_parameters: Optional[Dict[str, str]] = None,
         query_context_type: Optional[QueryContextType] = None,
-        pyspark_fragment: Optional[str] = None,
+        fragment: Optional[str] = None,
     ):
         query_context = exception.getQueryContext()
         assert bool(query_context) == (query_context_type is not None), (
@@ -326,10 +326,10 @@ class PySparkErrorTestUtils:
                 )
                 if actual == QueryContextType.DataFrame:
                     assert (
-                        pyspark_fragment is not None
-                    ), "`pyspark_fragment` is required when QueryContextType is DataFrame."
-                    expected = pyspark_fragment
-                    actual = actual_context.pysparkFragment()
+                        fragment is not None
+                    ), "`fragment` is required when QueryContextType is DataFrame."
+                    expected = fragment
+                    actual = actual_context.fragment()
                     self.assertEqual(
                         expected,
                         actual,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index 1c2456f00bcd..2b3f4674539e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -145,36 +145,30 @@ case class DataFrameQueryContext(
   override def stopIndex: Int = throw SparkUnsupportedOperationException()
 
   override val fragment: String = {
-    stackTrace.headOption.map { firstElem =>
-      val methodName = firstElem.getMethodName
-      if (methodName.length > 1 && methodName(0) == '$') {
-        methodName.substring(1)
-      } else {
-        methodName
-      }
-    }.getOrElse("")
+    pysparkErrorContext.map(_._1).getOrElse {
+      stackTrace.headOption.map { firstElem =>
+        val methodName = firstElem.getMethodName
+        if (methodName.length > 1 && methodName(0) == '$') {
+          methodName.substring(1)
+        } else {
+          methodName
+        }
+      }.getOrElse("")
+    }
   }
 
-  override val callSite: String = stackTrace.tail.mkString("\n")
-
-  val pysparkFragment: String = pysparkErrorContext.map(_._1).getOrElse("")
-  val pysparkCallSite: String = pysparkErrorContext.map(_._2).getOrElse("")
-
-  val (displayedFragment, displayedCallsite) = if (pysparkErrorContext.nonEmpty) {
-    (pysparkFragment, pysparkCallSite)
-  } else {
-    (fragment, callSite)
-  }
+  override val callSite: String = pysparkErrorContext.map(
+    _._2).getOrElse(stackTrace.tail.mkString("\n"))
 
   override lazy val summary: String = {
     val builder = new StringBuilder
     builder ++= "== DataFrame ==\n"
     builder ++= "\""
 
-    builder ++= displayedFragment
+    builder ++= fragment
     builder ++= "\""
     builder ++= " was called from\n"
-    builder ++= displayedCallsite
+    builder ++= callSite
     builder += '\n'
 
     builder.result()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to