This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new f77834fe8963 [SPARK-52450][CONNECT] Improve performance of schema deepcopy
f77834fe8963 is described below

commit f77834fe8963d7fa3b2a25f7977cc51d65cd099c
Author: Xi Lyu <[email protected]>
AuthorDate: Fri Jun 20 11:00:56 2025 -0400

    [SPARK-52450][CONNECT] Improve performance of schema deepcopy
    
    ### What changes were proposed in this pull request?
    
    In Spark Connect, `DataFrame.schema` returns a deep copy of the schema to
    prevent unexpected behavior caused by user modifications to the returned
    schema object. However, if a user accesses `df.schema` repeatedly on a
    DataFrame with a complex schema, it can lead to noticeable performance
    degradation.
    
    The performance issue can be reproduced with the code snippet below. Since
    copy.deepcopy is known to be slow when handling complex objects, this PR
    replaces it with pickle-based ser/de to improve the performance of
    df.schema access. Given the limitations of pickle, the implementation
    falls back to deepcopy in cases where pickling fails.
    
    ```
    from pyspark.sql.types import StructType, StructField, StringType
    
    def make_nested_struct(level, max_level, fields_per_level):
        if level == max_level - 1:
            return StructType(
                [StructField(f"f{level}_{i}", StringType(), True) for i in 
range(fields_per_level)])
        else:
            return StructType(
                [StructField(f"s{level}_{i}",
                             make_nested_struct(level + 1, max_level, 
fields_per_level), True) for i in
                 range(fields_per_level)])
    
    # Create a 4-level nested schema with 10,000 leaf fields in total
    schema = make_nested_struct(0, 4, 10)
    ```
    
    The existing approach needs 21.9s to copy the schema 100 times:
    ```
    import copy
    timeit.timeit(lambda: copy.deepcopy(schema), number=100)
    # 21.9
    ```
    
    The updated approach needs only 2.0s for 100 copies:
    ```
    from pyspark.serializers import CPickleSerializer
    cached_schema_serialized = CPickleSerializer().dumps(schema)
    
    timeit.timeit(lambda: CPickleSerializer().loads(cached_schema_serialized), number=100)
    # 2.0
    ```
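    
    For illustration, here is a minimal, self-contained sketch of the
    cache-then-fallback pattern described above (the `SchemaCache` helper is a
    hypothetical stand-in for the cached fields on `DataFrame`, not the
    committed code; it only assumes `CPickleSerializer` and `copy.deepcopy`):
    
    ```
    import copy
    
    from pyspark.serializers import CPickleSerializer
    from pyspark.sql.types import StringType, StructField, StructType
    
    
    class SchemaCache:
        """Hypothetical helper: pickle the schema once, then hand out fast
        copies by unpickling the cached bytes, falling back to deepcopy."""
    
        def __init__(self, schema: StructType):
            self._schema = schema
            try:
                # Serialize once up front; every copy reuses these bytes.
                self._serialized = CPickleSerializer().dumps(schema)
            except Exception:
                # Some objects cannot be pickled; remember that and fall back.
                self._serialized = None
    
        def copy(self) -> StructType:
            if self._serialized is not None:
                try:
                    # Unpickling the cached bytes is much faster than deepcopy.
                    return CPickleSerializer().loads(self._serialized)
                except Exception:
                    pass
            # Fallback: deepcopy still guarantees an independent copy.
            return copy.deepcopy(self._schema)
    
    
    cache = SchemaCache(StructType([StructField("f0_0", StringType(), True)]))
    copied = cache.copy()
    assert copied.jsonValue() == cache._schema.jsonValue()
    assert copied is not cache._schema
    ```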
    
    ### Why are the changes needed?
    
    It improves performance when df.schema is called many times on the same DataFrame.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests and new tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51157 from xi-db/schema-deepcopy-improvement.
    
    Lead-authored-by: Xi Lyu <[email protected]>
    Co-authored-by: Xi Lyu <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
    (cherry picked from commit f502d66fa2955b810eca3e5bc1d30d8d37925860)
    Signed-off-by: Herman van Hovell <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py                | 18 +++++++++++++++++-
 .../tests/connect/test_connect_dataframe_property.py   |  9 +++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 01566644071c..b6637f6a1014 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -22,6 +22,7 @@ from pyspark.errors.exceptions.base import (
     PySparkAttributeError,
 )
 from pyspark.resource import ResourceProfile
+from pyspark.sql.connect.logging import logger
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -69,6 +70,7 @@ from pyspark.errors import (
     PySparkRuntimeError,
 )
 from pyspark.util import PythonEvalType
+from pyspark.serializers import CPickleSerializer
 from pyspark.storagelevel import StorageLevel
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.conversion import ArrowTableToRowsConversion
@@ -141,6 +143,7 @@ class DataFrame(ParentDataFrame):
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
+        self._cached_schema_serialized: Optional[bytes] = None
         self._execution_info: Optional["ExecutionInfo"] = None
 
     def __reduce__(self) -> Tuple:
@@ -1836,11 +1839,24 @@ class DataFrame(ParentDataFrame):
         if self._cached_schema is None:
             query = self._plan.to_proto(self._session.client)
             self._cached_schema = self._session.client.schema(query)
+            try:
+                self._cached_schema_serialized = CPickleSerializer().dumps(self._schema)
+            except Exception as e:
+                logger.warn(f"DataFrame schema pickle dumps failed with exception: {e}.")
+                self._cached_schema_serialized = None
         return self._cached_schema
 
     @property
     def schema(self) -> StructType:
-        return copy.deepcopy(self._schema)
+        # self._schema call will cache the schema and serialize it if it is not cached yet.
+        _schema = self._schema
+        if self._cached_schema_serialized is not None:
+            try:
+                return CPickleSerializer().loads(self._cached_schema_serialized)
+            except Exception as e:
+                logger.warn(f"DataFrame schema pickle loads failed with exception: {e}.")
+        # In case of pickle ser/de failure, fallback to deepcopy approach.
+        return copy.deepcopy(_schema)
 
     @functools.cache
     def isLocal(self) -> bool:
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index c4c10c963a48..6b213d0ecb0f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -64,6 +64,15 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
             df_columns.remove(col)
         assert len(df.columns) == 4
 
+        cdf = self.connect.createDataFrame(data, schema)
+        cdf_schema = cdf.schema
+        assert len(cdf._cached_schema_serialized) > 0
+        assert cdf_schema.jsonValue() == cdf._cached_schema.jsonValue()
+        assert len(cdf_schema.fields) == 4
+        cdf_schema.fields.pop(0)
+        assert cdf.schema.jsonValue() == cdf._cached_schema.jsonValue()
+        assert len(cdf.schema.fields) == 4
+
     def test_cached_schema_to(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
