This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 76d31f7853f2 [SPARK-50881][PYTHON] Use cached schema where possible in connect dataframe.py
76d31f7853f2 is described below

commit 76d31f7853f22c857c1b4cc6a557e7a9848d48b6
Author: Garland Zhang <[email protected]>
AuthorDate: Mon Feb 10 18:51:22 2025 +0800

    [SPARK-50881][PYTHON] Use cached schema where possible in connect dataframe.py
    
    ### What changes were proposed in this pull request?
    
    * The `schema` property returns a deep copy on every access so that callers cannot mutate the cached schema. However, this degrades performance when dataframe.py reads the schema internally. This PR makes the following changes:
    
    1. `columns` now builds a new list of field names from the cached schema instead of deep-copying the whole schema. This matches the behavior of classic PySpark.
    2. All internal uses of `schema` in dataframe.py now go through the cached `_schema` property, avoiding a deep copy.
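    
    A minimal, self-contained sketch of the pattern described above. The `Field`, `Schema`, and `CachedSchemaFrame` classes are hypothetical stand-ins for `StructField`, `StructType`, and the Connect `DataFrame`; only the split between an internal cached `_schema`, a deep-copying public `schema`, and a `columns` property that builds a fresh list mirrors the patch.
    
    ```python
    import copy
    from dataclasses import dataclass, field
    from typing import Callable, List, Optional
    
    
    @dataclass
    class Field:
        # Hypothetical stand-in for pyspark.sql.types.StructField.
        name: str
        data_type: str
    
    
    @dataclass
    class Schema:
        # Hypothetical stand-in for pyspark.sql.types.StructType.
        fields: List[Field] = field(default_factory=list)
    
    
    class CachedSchemaFrame:
        """Hypothetical lazily analyzed frame; not the PySpark API."""
    
        def __init__(self, fetch_schema: Callable[[], Schema]) -> None:
            self._fetch_schema = fetch_schema  # e.g. one analysis round trip to the server
            self._cached_schema: Optional[Schema] = None
            self.fetch_count = 0
    
        @property
        def _schema(self) -> Schema:
            # Internal accessor: resolve once, then return the cached object itself.
            if self._cached_schema is None:
                self.fetch_count += 1
                self._cached_schema = self._fetch_schema()
            return self._cached_schema
    
        @property
        def schema(self) -> Schema:
            # Public accessor: deep copy so callers cannot mutate the cache.
            return copy.deepcopy(self._schema)
    
        @property
        def columns(self) -> List[str]:
            # Builds a new list of names; the schema itself is not copied.
            return [f.name for f in self._schema.fields]
    
    
    df = CachedSchemaFrame(lambda: Schema([Field("x", "double"), Field("y", "bigint")]))
    cols = df.columns
    cols.append("z")                    # mutating the returned list...
    public = df.schema
    public.fields[0].name = "renamed"   # ...or the public copy leaves the cache intact
    print(df.columns)                   # → ['x', 'y']
    print(df.fetch_count)               # → 1
    ```
    
    Internal callers pay only for attribute access, while external callers still receive an object they can safely modify.
    
    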
    
    ### Why are the changes needed?
    * This does not scale when these methods are called thousands of times, for example `columns` inside `pivot`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    * existing tests
    
    Benchmarking shows a clear improvement: the profiles below report 8.886 s without the deep copy versus 23.181 s with it, i.e. about 38% of the runtime (roughly 2.6x faster).
    
    ```
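    import cProfile, pstats
    import copy
    
    import numpy as np
    import pandas as pd
    
    # Assumes an active Spark Connect session bound to `spark`.
    # Without a filename, cProfile.run prints its stats directly, ordered by standard name.
    cProfile.run("""
    x = pd.DataFrame(
        zip(np.random.rand(1_000_000), np.random.randint(1, 3000, 1_000_000), list(range(1000)) * 1000),
        columns=['x', 'y', 'z'],
    )
    df = spark.createDataFrame(x)
    schema = df.schema
    for _ in range(1_000_000):
      [name for name in schema.names]
    """)
    # Loads stats from an existing "profile_results" file, e.g. one written by
    # cProfile.run(stmt, "profile_results"); it does not read the run above.
    p = pstats.Stats("profile_results")
    p.sort_stats("cumtime").print_stats(.1)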
    ```
    ```
             17000003 function calls in 8.886 seconds
    
       Ordered by: standard name
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.931    0.931    8.886    8.886 <string>:1(<module>)
      1000000    0.391    0.000    0.391    0.000 <string>:3(<listcomp>)
      1000000    0.933    0.000    5.516    0.000 DatasetInfo.py:22(gather_imported_dataframes)
      1000000    0.948    0.000    6.669    0.000 DatasetInfo.py:75(_maybe_handle_dataframe_assignment)
      1000000    0.895    0.000    7.564    0.000 DatasetInfo.py:90(__setitem__)
      3000000    2.853    0.000    4.583    0.000 utils.py:54(retrieve_imported_type)
            1    0.000    0.000    8.886    8.886 {built-in method builtins.exec}
      3000000    0.667    0.000    0.667    0.000 {built-in method builtins.getattr}
      1000000    0.204    0.000    0.204    0.000 {built-in method builtins.isinstance}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
      3000000    0.473    0.000    0.473    0.000 {method 'get' of 'dict' objects}
      3000000    0.590    0.000    0.590    0.000 {method 'rsplit' of 'str' objects}
    
    Thu Jan 16 20:13:47 2025    profile_results
    
             3 function calls in 0.000 seconds
    
    ```
    vs
    
    ```
             55000003 function calls (50000003 primitive calls) in 23.181 seconds
    
       Ordered by: standard name
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.987    0.987   23.181   23.181 <string>:1(<module>)
      1000000    1.060    0.000    5.750    0.000 DatasetInfo.py:22(gather_imported_dataframes)
      1000000    0.956    0.000    6.907    0.000 DatasetInfo.py:75(_maybe_handle_dataframe_assignment)
      1000000    0.930    0.000    7.837    0.000 DatasetInfo.py:90(__setitem__)
    6000000/1000000    7.420    0.000   14.357    0.000 copy.py:128(deepcopy)
      5000000    0.494    0.000    0.494    0.000 copy.py:182(_deepcopy_atomic)
      1000000    2.734    0.000   11.015    0.000 copy.py:201(_deepcopy_list)
      1000000    0.951    0.000    1.160    0.000 copy.py:243(_keep_alive)
      3000000    2.946    0.000    4.690    0.000 utils.py:54(retrieve_imported_type)
            1    0.000    0.000   23.181   23.181 {built-in method builtins.exec}
      3000000    0.686    0.000    0.686    0.000 {built-in method builtins.getattr}
      9000000    0.976    0.000    0.976    0.000 {built-in method builtins.id}
      1000000    0.201    0.000    0.201    0.000 {built-in method builtins.isinstance}
      5000000    0.560    0.000    0.560    0.000 {method 'append' of 'list' objects}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
     15000000    1.673    0.000    1.673    0.000 {method 'get' of 'dict' objects}
      3000000    0.607    0.000    0.607    0.000 {method 'rsplit' of 'str' objects}
    
    Thu Jan 16 20:13:47 2025    profile_results
    
             3 function calls in 0.000 seconds
    
    ```
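    
    The deep-copy overhead can also be isolated without a Spark cluster. This stdlib-only sketch uses a hypothetical `Field` dataclass as a stand-in for `StructField`; absolute timings depend on the machine and will not match the numbers above, but the deep-copy path is consistently slower than reading the cached objects directly.
    
    ```python
    import copy
    import timeit
    from dataclasses import dataclass
    from typing import List
    
    
    @dataclass
    class Field:
        # Hypothetical stand-in for a schema field; not pyspark.sql.types.StructField.
        name: str
        data_type: str
    
    
    # A three-column schema, like the x/y/z frame in the benchmark above.
    cached: List[Field] = [Field("x", "double"), Field("y", "bigint"), Field("z", "bigint")]
    
    
    def names_via_deepcopy() -> List[str]:
        # Mimics the old behavior: every access deep-copies the cached schema.
        return [f.name for f in copy.deepcopy(cached)]
    
    
    def names_via_cache() -> List[str]:
        # Mimics the new internal path: read the cached schema directly.
        return [f.name for f in cached]
    
    
    if __name__ == "__main__":
        n = 20_000
        t_copy = timeit.timeit(names_via_deepcopy, number=n)
        t_cache = timeit.timeit(names_via_cache, number=n)
        print(f"deepcopy: {t_copy:.3f}s  cached: {t_cache:.3f}s  ratio: {t_copy / t_cache:.1f}x")
    ```
    
    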
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #49749 from garlandz-db/SPARK-50881.
    
    Authored-by: Garland Zhang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 9f866475fca4f4e94a64d3b01356a115f6f8c4e0)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index fa06322c39ba..601ca8f9f1fa 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -294,11 +294,11 @@ class DataFrame(ParentDataFrame):
 
     @property
     def dtypes(self) -> List[Tuple[str, str]]:
-        return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
+        return [(str(f.name), f.dataType.simpleString()) for f in self._schema.fields]
 
     @property
     def columns(self) -> List[str]:
-        return self.schema.names
+        return [field.name for field in self._schema.fields]
 
     @property
     def sparkSession(self) -> "SparkSession":
@@ -1742,7 +1742,7 @@ class DataFrame(ParentDataFrame):
 
                     # Try best to verify the column name with cached schema
                     # If fails, fall back to the server side validation
-                    if not verify_col_name(item, self.schema):
+                    if not verify_col_name(item, self._schema):
                         self.select(item).isLocal()
 
                 return self._col(item)
@@ -1827,7 +1827,7 @@ class DataFrame(ParentDataFrame):
         return ConnectColumn(SubqueryExpression(self._plan, subquery_type="exists"))
 
     @property
-    def schema(self) -> StructType:
+    def _schema(self) -> StructType:
         # Schema caching is correct in most cases. Connect is lazy by nature. This means that
         # we only resolve the plan when it is submitted for execution or analysis. We do not
         # cache intermediate resolved plan. If the input (changes table, view redefinition,
@@ -1836,7 +1836,11 @@ class DataFrame(ParentDataFrame):
         if self._cached_schema is None:
             query = self._plan.to_proto(self._session.client)
             self._cached_schema = self._session.client.schema(query)
-        return copy.deepcopy(self._cached_schema)
+        return self._cached_schema
+
+    @property
+    def schema(self) -> StructType:
+        return copy.deepcopy(self._schema)
 
     @functools.cache
     def isLocal(self) -> bool:
@@ -2099,12 +2103,12 @@ class DataFrame(ParentDataFrame):
         def foreach_func(row: Any) -> None:
             f(row)
 
-        self.select(F.struct(*self.schema.fieldNames()).alias("row")).select(
+        self.select(F.struct(*self._schema.fieldNames()).alias("row")).select(
             F.udf(foreach_func, StructType())("row")  # type: ignore[arg-type]
         ).collect()
 
     def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
-        schema = self.schema
+        schema = self._schema
         field_converters = [
             ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
         ]

