This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ad8250562cee [SPARK-50324][PYTHON][CONNECT] Make `createDataFrame` trigger `Config` RPC at most once
ad8250562cee is described below

commit ad8250562cee26ebcdf6c99ac4d78f431965c9a4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Nov 22 17:04:52 2024 +0800

    [SPARK-50324][PYTHON][CONNECT] Make `createDataFrame` trigger `Config` RPC at most once
    
    ### What changes were proposed in this pull request?
    Fetch all related configs in a single batch.
    
    ### Why are the changes needed?
    There are many related configs in `createDataFrame`, and they are fetched one by one (or group by group) in different branches:
    1. It is possible that no Config RPC is triggered at all, e.g. in this branch:
    
https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L502-L509
    
    2. Multiple Config RPCs may be triggered for different configs, e.g. in this branch:
    
https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L599-L601
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48856 from zhengruifeng/lazy_config.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/client/core.py |  5 +++
 python/pyspark/sql/connect/session.py     | 57 ++++++++++++++++++-------------
 2 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 3de425505405..78d4e0fc1c4f 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -43,6 +43,7 @@ from typing import (
     Dict,
     Set,
     NoReturn,
+    Mapping,
     cast,
     TYPE_CHECKING,
     Type,
@@ -1576,6 +1577,10 @@ class SparkConnectClient(object):
         configs = dict(self.config(op).pairs)
         return tuple(configs.get(key) for key in keys)
 
+    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
+        op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+        return dict(self.config(op).pairs)
+
     def get_config_with_defaults(
         self, *pairs: Tuple[str, Optional[str]]
     ) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 83b0496a8427..bfd79092ccf4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -37,6 +36,7 @@ from typing import (
     cast,
     overload,
     Iterable,
+    Mapping,
     TYPE_CHECKING,
     ClassVar,
 )
@@ -407,7 +407,10 @@ class SparkSession:
     clearProgressHandlers.__doc__ = 
PySparkSession.clearProgressHandlers.__doc__
 
     def _inferSchemaFromList(
-        self, data: Iterable[Any], names: Optional[List[str]] = None
+        self,
+        data: Iterable[Any],
+        names: Optional[List[str]],
+        configs: Mapping[str, Optional[str]],
     ) -> StructType:
         """
         Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ class SparkSession:
             infer_dict_as_struct,
             infer_array_from_first_element,
             infer_map_from_first_pair,
-            prefer_timestamp_ntz,
-        ) = self._client.get_configs(
-            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-            "spark.sql.timestampType",
+            prefer_timestamp,
+        ) = (
+            configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+            
configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+            
configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+            configs["spark.sql.timestampType"],
         )
         return functools.reduce(
             _merge_type,
@@ -438,7 +441,7 @@ class SparkSession:
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     
infer_array_from_first_element=(infer_array_from_first_element == "true"),
                     infer_map_from_first_pair=(infer_map_from_first_pair == 
"true"),
-                    prefer_timestamp_ntz=(prefer_timestamp_ntz == 
"TIMESTAMP_NTZ"),
+                    prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
                 )
                 for row in data
             ),
@@ -508,8 +511,21 @@ class SparkSession:
                     messageParameters={},
                 )
 
+        # Get all related configs in a batch
+        configs = self._client.get_config_dict(
+            "spark.sql.timestampType",
+            "spark.sql.session.timeZone",
+            "spark.sql.session.localRelationCacheThreshold",
+            "spark.sql.execution.pandas.convertToArrowArraySafely",
+            "spark.sql.execution.pandas.inferPandasDictAsMap",
+            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+        )
+        timezone = configs["spark.sql.session.timeZone"]
+        prefer_timestamp = configs["spark.sql.timestampType"]
+
         _table: Optional[pa.Table] = None
-        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@ class SparkSession:
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in 
data.columns]
                 infer_pandas_dict_as_map = (
-                    
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
-                    == "true"
+                    configs["spark.sql.execution.pandas.inferPandasDictAsMap"] 
== "true"
                 )
                 if infer_pandas_dict_as_map:
                     struct = StructType()
@@ -572,9 +587,7 @@ class SparkSession:
                 ]
                 arrow_types = [to_arrow_type(dt) if dt is not None else None 
for dt in spark_types]
 
-            timezone, safecheck = self._client.get_configs(
-                "spark.sql.session.timeZone", 
"spark.sql.execution.pandas.convertToArrowArraySafely"
-            )
+            safecheck = 
configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck 
== "true")
 
@@ -596,10 +609,6 @@ class SparkSession:
                 ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
-            (timezone,) = 
self._client.get_configs("spark.sql.session.timeZone")
-
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = data.column_names
@@ -609,7 +618,9 @@ class SparkSession:
                 _num_cols = len(_cols)
 
             if not isinstance(schema, StructType):
-                schema = from_arrow_schema(data.schema, 
prefer_timestamp_ntz=prefer_timestamp_ntz)
+                schema = from_arrow_schema(
+                    data.schema, prefer_timestamp_ntz=prefer_timestamp == 
"TIMESTAMP_NTZ"
+                )
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, 
timezone)
@@ -671,7 +682,7 @@ class SparkSession:
                 if not isinstance(_schema, StructType):
                     _schema = StructType().add("value", _schema)
             else:
-                _schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols, configs)
 
                 if _cols is not None and cast(int, _num_cols) < len(_cols):
                     _num_cols = len(_cols)
@@ -706,9 +717,9 @@ class SparkSession:
         else:
             local_relation = LocalRelation(_table)
 
-        cache_threshold = 
self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        cache_threshold = 
configs["spark.sql.session.localRelationCacheThreshold"]
         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= 
_table.nbytes:
+        if cache_threshold is not None and int(cache_threshold) <= 
_table.nbytes:
             plan = 
CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to