This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ad8250562cee [SPARK-50324][PYTHON][CONNECT] Make `createDataFrame` trigger `Config` RPC at most once
ad8250562cee is described below
commit ad8250562cee26ebcdf6c99ac4d78f431965c9a4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Nov 22 17:04:52 2024 +0800
[SPARK-50324][PYTHON][CONNECT] Make `createDataFrame` trigger `Config` RPC at most once
### What changes were proposed in this pull request?
Get all configs in batch
### Why are the changes needed?
There are many related configs used in `createDataFrame`, and they are currently fetched one by one (or group by group) in different branches:
1. It is possible that no Config RPC is triggered at all, e.g. in this branch:
https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L502-L509
2. Multiple Config RPCs may be triggered for different configs, e.g. in this branch:
https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L599-L601
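
The new flow is roughly the following. This is a condensed sketch, not the full `createDataFrame` logic; `client` stands in for the session's `SparkConnectClient`, while `get_config_dict` and the config keys are the ones actually used in the patch below:

```python
# Sketch only: batch every config lookup into a single Config RPC,
# then answer each per-branch question from the local dict.
configs = client.get_config_dict(
    "spark.sql.timestampType",
    "spark.sql.session.timeZone",
    "spark.sql.session.localRelationCacheThreshold",
    "spark.sql.execution.pandas.convertToArrowArraySafely",
    "spark.sql.execution.pandas.inferPandasDictAsMap",
    "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
    "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
    "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
)

# Each branch reads from the dict instead of issuing its own Config RPC.
timezone = configs["spark.sql.session.timeZone"]
prefer_timestamp_ntz = configs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"
safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] == "true"
```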
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #48856 from zhengruifeng/lazy_config.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/client/core.py |  5 +++
 python/pyspark/sql/connect/session.py     | 57 ++++++++++++++++++-------------
 2 files changed, 39 insertions(+), 23 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3de425505405..78d4e0fc1c4f 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -43,6 +43,7 @@ from typing import (
     Dict,
     Set,
     NoReturn,
+    Mapping,
     cast,
     TYPE_CHECKING,
     Type,
@@ -1576,6 +1577,10 @@ class SparkConnectClient(object):
         configs = dict(self.config(op).pairs)
         return tuple(configs.get(key) for key in keys)
 
+    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
+        op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+        return dict(self.config(op).pairs)
+
     def get_config_with_defaults(
         self, *pairs: Tuple[str, Optional[str]]
     ) -> Tuple[Optional[str], ...]:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 83b0496a8427..bfd79092ccf4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
@@ -37,6 +36,7 @@ from typing import (
     cast,
     overload,
     Iterable,
+    Mapping,
     TYPE_CHECKING,
     ClassVar,
 )
@@ -407,7 +407,10 @@ class SparkSession:
     clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__
 
     def _inferSchemaFromList(
-        self, data: Iterable[Any], names: Optional[List[str]] = None
+        self,
+        data: Iterable[Any],
+        names: Optional[List[str]],
+        configs: Mapping[str, Optional[str]],
     ) -> StructType:
         """
         Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ class SparkSession:
             infer_dict_as_struct,
             infer_array_from_first_element,
             infer_map_from_first_pair,
-            prefer_timestamp_ntz,
-        ) = self._client.get_configs(
-            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-            "spark.sql.timestampType",
+            prefer_timestamp,
+        ) = (
+            configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+            configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+            configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+            configs["spark.sql.timestampType"],
         )
         return functools.reduce(
             _merge_type,
@@ -438,7 +441,7 @@ class SparkSession:
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
                     infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
-                    prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
+                    prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
                 )
                 for row in data
             ),
@@ -508,8 +511,21 @@ class SparkSession:
                 messageParameters={},
             )
 
+        # Get all related configs in a batch
+        configs = self._client.get_config_dict(
+            "spark.sql.timestampType",
+            "spark.sql.session.timeZone",
+            "spark.sql.session.localRelationCacheThreshold",
+            "spark.sql.execution.pandas.convertToArrowArraySafely",
+            "spark.sql.execution.pandas.inferPandasDictAsMap",
+            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+        )
+        timezone = configs["spark.sql.session.timeZone"]
+        prefer_timestamp = configs["spark.sql.timestampType"]
+
         _table: Optional[pa.Table] = None
-        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@ class SparkSession:
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
                 infer_pandas_dict_as_map = (
-                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
-                    == "true"
+                    configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
                 )
                 if infer_pandas_dict_as_map:
                     struct = StructType()
@@ -572,9 +587,7 @@ class SparkSession:
                 ]
                 arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
-                timezone, safecheck = self._client.get_configs(
-                    "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
-                )
+                safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
                 ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
@@ -596,10 +609,6 @@ class SparkSession:
                 ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
-            (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
-
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = data.column_names
@@ -609,7 +618,9 @@ class SparkSession:
                 _num_cols = len(_cols)
 
             if not isinstance(schema, StructType):
-                schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
+                schema = from_arrow_schema(
+                    data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
+                )
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -671,7 +682,7 @@ class SparkSession:
             if not isinstance(_schema, StructType):
                 _schema = StructType().add("value", _schema)
             else:
-                _schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols, configs)
 
             if _cols is not None and cast(int, _num_cols) < len(_cols):
                 _num_cols = len(_cols)
@@ -706,9 +717,9 @@ class SparkSession:
         else:
             local_relation = LocalRelation(_table)
 
-            cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+            cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
             plan: LogicalPlan = local_relation
-            if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+            if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
                 plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]