This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d1d29c9840fe [SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations
d1d29c9840fe is described below
commit d1d29c9840fedecc9b5d74651526359a2b70377e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jun 12 16:48:24 2024 -0700
[SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations
### What changes were proposed in this pull request?
Propagate the cached schema in the following DataFrame operations (each applies the same pattern, sketched after the list):
- DataFrame.alias
- DataFrame.coalesce
- DataFrame.repartition
- DataFrame.repartitionByRange
- DataFrame.dropDuplicates
- DataFrame.distinct
- DataFrame.filter
- DataFrame.where
- DataFrame.limit
- DataFrame.sort
- DataFrame.sortWithinPartitions
- DataFrame.orderBy
- DataFrame.sample
- DataFrame.hint
- DataFrame.randomSplit
- DataFrame.observe
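Each of these operations leaves the schema unchanged, so the result can safely inherit the parent's cached schema. The diff below applies the same three-line pattern to every method; here is the `DataFrame.limit` instance of it, as it appears in this patch:
```python
def limit(self, n: int) -> ParentDataFrame:
    res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
    # limit cannot change the schema, so the parent's cached schema
    # (None if it was never fetched) is safe to copy onto the result.
    res._cached_schema = self._cached_schema
    return res
```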
### Why are the changes needed?
To avoid unnecessary `AnalyzePlan` RPCs where possible: these operations never change the schema, so when the parent DataFrame's schema is already cached, the result can reuse it instead of asking the server again.
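Concretely (a hypothetical usage sketch, not part of this patch; `spark` is assumed to be a Spark Connect session):
```python
df = spark.range(10)
_ = df.schema  # first access issues one AnalyzePlan RPC and caches the result
df2 = df.where(df.id > 3).limit(5).distinct()
# every operation in the chain propagated df's cached schema,
# so this access is answered locally, without another RPC
_ = df2.schema
```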
### Does this PR introduce _any_ user-facing change?
No, this is an internal optimization only.
### How was this patch tested?
Added unit tests (`test_cached_schema_in_chain_op`).
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46954 from zhengruifeng/py_connect_propagate_schema.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 69 ++++++++++++++++------
.../connect/test_connect_dataframe_property.py | 35 +++++++++++
2 files changed, 85 insertions(+), 19 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index baac1523c709..f2705ec7ad71 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -262,7 +262,9 @@ class DataFrame(ParentDataFrame):
return self.groupBy().agg(*exprs)
def alias(self, alias: str) -> ParentDataFrame:
- return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
+ res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
+ res._cached_schema = self._cached_schema
+ return res
def colRegex(self, colName: str) -> Column:
from pyspark.sql.connect.column import Column as ConnectColumn
@@ -314,10 +316,12 @@ class DataFrame(ParentDataFrame):
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "numPartitions", "arg_value":
str(numPartitions)},
)
- return DataFrame(
+ res = DataFrame(
plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False),
self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
@@ -340,12 +344,12 @@ class DataFrame(ParentDataFrame):
},
)
if len(cols) == 0:
- return DataFrame(
+ res = DataFrame(
plan.Repartition(self._plan, numPartitions, shuffle=True),
self._session,
)
else:
- return DataFrame(
+ res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._to_col(c) for c in cols]
),
@@ -353,7 +357,7 @@ class DataFrame(ParentDataFrame):
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
- return DataFrame(
+ res = DataFrame(
plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
self.sparkSession,
)
@@ -366,6 +370,9 @@ class DataFrame(ParentDataFrame):
},
)
+ res._cached_schema = self._cached_schema
+ return res
+
@overload
def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
...
@@ -392,14 +399,14 @@ class DataFrame(ParentDataFrame):
message_parameters={"item": "cols"},
)
else:
- return DataFrame(
+ res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._sort_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
- return DataFrame(
+ res = DataFrame(
plan.RepartitionByExpression(
self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
),
@@ -414,6 +421,9 @@ class DataFrame(ParentDataFrame):
},
)
+ res._cached_schema = self._cached_schema
+ return res
+
def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
# Acceptable args should be str, ... or a single List[str]
# So if subset length is 1, it can be either single str, or a list of str
@@ -422,20 +432,23 @@ class DataFrame(ParentDataFrame):
assert all(isinstance(c, str) for c in subset)
if not subset:
- return DataFrame(
+ res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
elif len(subset) == 1 and isinstance(subset[0], list):
- return DataFrame(
+ res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=subset[0]),
session=self._session,
)
else:
- return DataFrame(
+ res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)),
session=self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
+
drop_duplicates = dropDuplicates
def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
@@ -466,9 +479,11 @@ class DataFrame(ParentDataFrame):
)
def distinct(self) -> ParentDataFrame:
- return DataFrame(
+ res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
+ res._cached_schema = self._cached_schema
+ return res
@overload
def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
@@ -499,7 +514,9 @@ class DataFrame(ParentDataFrame):
expr = F.expr(condition)
else:
expr = condition
- return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
+ res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
+ res._cached_schema = self._cached_schema
+ return res
def first(self) -> Optional[Row]:
return self.head()
@@ -709,7 +726,9 @@ class DataFrame(ParentDataFrame):
)
def limit(self, n: int) -> ParentDataFrame:
- return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
+ res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
+ res._cached_schema = self._cached_schema
+ return res
def tail(self, num: int) -> List[Row]:
return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
@@ -766,7 +785,7 @@ class DataFrame(ParentDataFrame):
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- return DataFrame(
+ res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
@@ -774,6 +793,8 @@ class DataFrame(ParentDataFrame):
),
session=self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
orderBy = sort
@@ -782,7 +803,7 @@ class DataFrame(ParentDataFrame):
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
- return DataFrame(
+ res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
@@ -790,6 +811,8 @@ class DataFrame(ParentDataFrame):
),
session=self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
def sample(
self,
@@ -837,7 +860,7 @@ class DataFrame(ParentDataFrame):
seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
- return DataFrame(
+ res = DataFrame(
plan.Sample(
child=self._plan,
lower_bound=0.0,
@@ -847,6 +870,8 @@ class DataFrame(ParentDataFrame):
),
session=self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
return self.withColumnsRenamed({existing: new})
@@ -1050,10 +1075,12 @@ class DataFrame(ParentDataFrame):
},
)
- return DataFrame(
+ res = DataFrame(
plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]),
session=self._session,
)
+ res._cached_schema = self._cached_schema
+ return res
def randomSplit(
self,
@@ -1094,6 +1121,7 @@ class DataFrame(ParentDataFrame):
),
session=self._session,
)
+ samplePlan._cached_schema = self._cached_schema
splits.append(samplePlan)
j += 1
@@ -1118,9 +1146,9 @@ class DataFrame(ParentDataFrame):
)
if isinstance(observation, Observation):
- return observation._on(self, *exprs)
+ res = observation._on(self, *exprs)
elif isinstance(observation, str):
- return DataFrame(
+ res = DataFrame(
plan.CollectMetrics(self._plan, observation, list(exprs)),
self._session,
)
@@ -1133,6 +1161,9 @@ class DataFrame(ParentDataFrame):
},
)
+ res._cached_schema = self._cached_schema
+ return res
+
def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 4a7e1e1ea760..c712e5d6efcb 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -20,6 +20,9 @@ import unittest
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
from pyspark.sql.utils import is_remote
+from pyspark.sql import functions as SF
+from pyspark.sql.connect import functions as CF
+
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.sqlutils import (
have_pandas,
@@ -393,6 +396,38 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
# cannot infer when schemas mismatch
self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None)
+ def test_cached_schema_in_chain_op(self):
+ data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+
+ cdf = self.connect.createDataFrame(data, ("id", "v1"))
+ sdf = self.spark.createDataFrame(data, ("id", "v1"))
+
+ cdf1 = cdf.withColumn("v2", CF.lit(1))
+ sdf1 = sdf.withColumn("v2", SF.lit(1))
+
+ self.assertTrue(cdf1._cached_schema is None)
+ # trigger analysis of cdf1.schema
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertTrue(cdf1._cached_schema is not None)
+
+ cdf2 = cdf1.where(cdf1.v2 > 0)
+ sdf2 = sdf1.where(sdf1.v2 > 0)
+ self.assertEqual(cdf1._cached_schema, cdf2._cached_schema)
+
+ cdf3 = cdf2.repartition(10)
+ sdf3 = sdf2.repartition(10)
+ self.assertEqual(cdf1._cached_schema, cdf3._cached_schema)
+
+ cdf4 = cdf3.distinct()
+ sdf4 = sdf3.distinct()
+ self.assertEqual(cdf1._cached_schema, cdf4._cached_schema)
+
+ cdf5 = cdf4.sample(fraction=0.5)
+ sdf5 = sdf4.sample(fraction=0.5)
+ self.assertEqual(cdf1._cached_schema, cdf5._cached_schema)
+
+ self.assertEqual(cdf5.schema, sdf5.schema)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]