This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b3d5bc0c109 [SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them
b3d5bc0c109 is described below
commit b3d5bc0c10908aa66510844eaabc43b6764dd7c0
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Thu Sep 28 14:02:46 2023 -0700
[SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them
### What changes were proposed in this pull request?
This PR projects out PARTITION BY expressions before the Python UDTF 'eval' method consumes them.
Before this PR, if a query included this `PARTITION BY` clause:
```
SELECT * FROM udtf(TABLE(SELECT a, b FROM t) PARTITION BY (c, d))
```
Then the `eval` method received four columns in each row: `a, b, c, d`, because Catalyst appends the `PARTITION BY` expressions to the table argument's output so that the Python worker can detect partition boundaries. After this PR, the `eval` method receives only the two expected columns: `a, b`.
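For illustration, here is a minimal sketch (not from the patch) of a UDTF that relies on this behavior. It assumes an active SparkSession named `spark` and a table `t` with integer columns `a` and `b`; the UDTF name and columns are hypothetical:
```
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

# Hypothetical UDTF: with this change, `eval` receives only the columns that
# the table argument's SELECT list produces (a, b), not the PARTITION BY
# expressions that Catalyst appends internally to track partition boundaries.
@udtf(returnType="a: int, b: int")
class EchoUDTF:
    def eval(self, row: Row):
        assert set(row.asDict().keys()) == {"a", "b"}
        yield row["a"], row["b"]

spark.udtf.register("echo_udtf", EchoUDTF)
spark.sql("SELECT * FROM echo_udtf(TABLE(SELECT a, b FROM t) PARTITION BY b)").show()
```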
### Why are the changes needed?
This makes the columns that the Python UDTF's `eval` method receives consistently match the columns of the `TABLE` argument, as users expect.
### Does this PR introduce _any_ user-facing change?
Yes, see above.
### How was this patch tested?
This PR adds new unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43156 from dtenedor/project-out-partition-exprs.
Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
 python/pyspark/sql/tests/test_udtf.py | 12 ++++++++++++
 python/pyspark/worker.py              | 31 +++++++++++++++++++++++++++----
 2 files changed, 39 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 97d5190a506..a1d82056c50 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -2009,6 +2009,10 @@ class BaseUDTFTestsMixin:
                 self._partition_col = None
 
             def eval(self, row: Row):
+                # Make sure that the PARTITION BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 self._sum += row["input"]
                 if self._partition_col is not None and self._partition_col != row["partition_col"]:
                     # Make sure that all values of the partitioning column are the same
@@ -2092,6 +2096,10 @@ class BaseUDTFTestsMixin:
                 self._partition_col = None
 
             def eval(self, row: Row, partition_col: str):
+                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 # Make sure that all values of the partitioning column are the same
                 # for each row consumed by this method for this instance of the class.
                 if self._partition_col is not None and self._partition_col != row[partition_col]:
@@ -2247,6 +2255,10 @@ class BaseUDTFTestsMixin:
             )
 
             def eval(self, row: Row):
+                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 # Make sure that all values of the partitioning column are the same
                 # for each row consumed by this method for this instance of the class.
                 if self._partition_col is not None and self._partition_col != row["partition_col"]:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 77481704979..4cffb02a64a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -51,7 +51,14 @@ from pyspark.sql.pandas.serializers import (
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
+from pyspark.sql.types import (
+    BinaryType,
+    Row,
+    StringType,
+    StructType,
+    _create_row,
+    _parse_datatype_json_string,
+)
 from pyspark.util import fail_on_stopiteration, handle_worker_exception
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -735,7 +742,12 @@ def read_udtf(pickleSer, infile, eval_type):
                             yield row
                 self._udtf = self._create_udtf()
             if self._udtf.eval is not None:
-                result = self._udtf.eval(*args, **kwargs)
+                # Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
+                filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
+                filtered_kwargs = {
+                    key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
+                }
+                result = self._udtf.eval(*filtered_args, **filtered_kwargs)
                 if result is not None:
                     for row in result:
                         yield row
@@ -752,10 +764,9 @@ def read_udtf(pickleSer, infile, eval_type):
                 prev_table_arg = self._get_table_arg(self._prev_arguments)
                 cur_partitions_args = []
                 prev_partitions_args = []
-                for i in partition_child_indexes:
+                for i in self._partition_child_indexes:
                     cur_partitions_args.append(cur_table_arg[i])
                     prev_partitions_args.append(prev_table_arg[i])
-                self._prev_arguments = arguments
                 result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
             self._prev_arguments = arguments
             return result
@@ -763,6 +774,18 @@ def read_udtf(pickleSer, infile, eval_type):
         def _get_table_arg(self, inputs: list) -> Row:
             return [x for x in inputs if type(x) is Row][0]
 
+        def _remove_partition_by_exprs(self, arg: Any) -> Any:
+            if isinstance(arg, Row):
+                new_row_keys = []
+                new_row_values = []
+                for i, (key, value) in enumerate(zip(arg.__fields__, arg)):
+                    if i not in self._partition_child_indexes:
+                        new_row_keys.append(key)
+                        new_row_values.append(value)
+                return _create_row(new_row_keys, new_row_values)
+            else:
+                return arg
+
     # Instantiate the UDTF class.
     try:
         if len(partition_child_indexes) > 0:
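To see the new filtering logic in isolation, here is a minimal standalone sketch (not part of the patch; the field names and the appended-column index are hypothetical, and the real worker derives `partition_child_indexes` from the plan):
```
from pyspark.sql.types import Row, _create_row

# Standalone illustration of how the worker drops the projected PARTITION BY
# columns from a Row by positional index before calling the UDTF's `eval`.
def remove_partition_by_exprs(row: Row, partition_child_indexes: list) -> Row:
    keys, values = [], []
    for i, (key, value) in enumerate(zip(row.__fields__, row)):
        if i not in partition_child_indexes:
            keys.append(key)
            values.append(value)
    return _create_row(keys, values)

# Suppose Catalyst appended the PARTITION BY expression as a trailing column
# at index 2; `eval` should then see only the first two fields.
row = Row(input=1, partition_col="x", partition_by_expr="x")
print(remove_partition_by_exprs(row, [2]))  # Row(input=1, partition_col='x')
```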