This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new df534c355d9 [SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs
df534c355d9 is described below
commit df534c355d9059fb5b128491a8f037baa121cbd7
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Sep 1 10:41:37 2023 -0700
[SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs
### What changes were proposed in this pull request?
Supports named arguments in aggregate Pandas UDFs.
For example:
```py
>>> @pandas_udf("double")
... def weighted_mean(v: pd.Series, w: pd.Series) -> float:
... import numpy as np
... return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
... ("id", "v", "w"))
>>> df.groupby("id").agg(weighted_mean(v=df["v"], w=df["w"])).show()
+---+-----------------------------+
| id|weighted_mean(v => v, w => w)|
+---+-----------------------------+
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
+---+-----------------------------+
>>> df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show()
+---+-----------------------------+
| id|weighted_mean(w => w, v => v)|
+---+-----------------------------+
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
+---+-----------------------------+
```
or with window:
```py
>>> w = Window.partitionBy("id").orderBy("v").rowsBetween(-2, 1)
>>> df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)).show()
+---+----+---+------------------+
| id| v| w| wm|
+---+----+---+------------------+
| 1| 1.0|1.0|1.6666666666666667|
| 1| 2.0|2.0|1.6666666666666667|
| 2| 3.0|1.0| 4.333333333333333|
| 2| 5.0|2.0| 7.166666666666667|
| 2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+
>>> df.withColumn("wm", weighted_mean_udf(w=df.w, v=df.v).over(w)).show()
+---+----+---+------------------+
| id| v| w| wm|
+---+----+---+------------------+
| 1| 1.0|1.0|1.6666666666666667|
| 1| 2.0|2.0|1.6666666666666667|
| 2| 3.0|1.0| 4.333333333333333|
| 2| 5.0|2.0| 7.166666666666667|
| 2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+
```
### Why are the changes needed?
Named arguments support was added in https://github.com/apache/spark/pull/41796 and https://github.com/apache/spark/pull/42020, so aggregate Pandas UDFs can now support them as well.
### Does this PR introduce _any_ user-facing change?
Yes, named arguments will be available for aggregate Pandas UDFs.
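The named-argument call style also works from SQL once the UDF is registered, mirroring the new tests; a minimal sketch (the temp view name `v` is only illustrative):
```py
>>> spark.udf.register("weighted_mean", weighted_mean)
>>> df.createOrReplaceTempView("v")
>>> spark.sql(
...     "SELECT id, weighted_mean(w => w, v => v) AS wm FROM v GROUP BY id"
... ).show()
```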
### How was this patch tested?
Added related tests.
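The new tests also cover aggregate UDFs declared with `**kwargs` only; a rough sketch of that pattern against the same `df` as above (the name `weighted_mean_kw` is made up for illustration):
```py
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)
... def weighted_mean_kw(**kwargs):
...     import numpy as np
...     # arguments arrive keyed by the names used at the call site
...     return np.average(kwargs["v"], weights=kwargs["w"])
...
>>> df.groupby("id").agg(weighted_mean_kw(v=df.v, w=df.w).alias("wm")).show()
```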
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42663 from ueshin/issues/SPARK-44952/kwargs.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/pyspark/sql/pandas/functions.py | 20 ++-
.../tests/pandas/test_pandas_udf_grouped_agg.py | 147 ++++++++++++++++-
.../sql/tests/pandas/test_pandas_udf_window.py | 173 ++++++++++++++++++++-
python/pyspark/sql/tests/test_udf.py | 15 ++
python/pyspark/sql/tests/test_udtf.py | 15 ++
python/pyspark/worker.py | 25 +--
.../spark/sql/catalyst/analysis/Analyzer.scala | 11 +-
.../execution/python/AggregateInPandasExec.scala | 23 ++-
.../python/UserDefinedPythonFunction.scala | 3 +-
.../python/WindowInPandasEvaluatorFactory.scala | 37 +++--
10 files changed, 429 insertions(+), 40 deletions(-)
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index ad9fdac9706..652129180df 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Supports Spark Connect.
.. versionchanged:: 4.0.0
- Supports keyword-arguments in SCALAR type.
+ Supports keyword-arguments in SCALAR and GROUPED_AGG type.
Parameters
----------
@@ -267,6 +267,24 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 6.0|
+---+-----------+
+ This type of Pandas UDF can use keyword arguments:
+
+ >>> @pandas_udf("double")
+ ... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
+ ... import numpy as np
+ ... return np.average(v, weights=w)
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+ ... ("id", "v", "w"))
+ >>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+ +---+---------------------------------+
+ | id|weighted_mean_udf(w => w, v => v)|
+ +---+---------------------------------+
+ | 1| 1.6666666666666667|
+ | 2| 7.166666666666667|
+ +---+---------------------------------+
+
This UDF can also be used as window functions as below:
>>> from pyspark.sql import Window
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index f434489a6fb..b500be7a969 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -32,7 +32,7 @@ from pyspark.sql.functions import (
PandasUDFType,
)
from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
@@ -40,7 +40,7 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
if have_pandas:
@@ -575,6 +575,149 @@ class GroupedAggPandasUDFTestsMixin:
assert filtered.collect()[0]["mean"] == 42.0
+ def test_named_arguments(self):
+ df = self.data
+ weighted_mean = self.pandas_agg_weighted_mean_udf
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ for i, aggregated in enumerate(
+ [
+ df.groupby("id").agg(weighted_mean(df.v,
w=df.w).alias("wm")),
+ df.groupby("id").agg(weighted_mean(v=df.v,
w=df.w).alias("wm")),
+ df.groupby("id").agg(weighted_mean(w=df.w,
v=df.v).alias("wm")),
+ self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm
FROM v GROUP BY id"),
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, w => w) as wm FROM v
GROUP BY id"
+ ),
+ self.spark.sql(
+ "SELECT id, weighted_mean(w => w, v => v) as wm FROM v
GROUP BY id"
+ ),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(aggregated,
df.groupby("id").agg(mean(df.v).alias("wm")))
+
+ def test_named_arguments_negative(self):
+ df = self.data
+ weighted_mean = self.pandas_agg_weighted_mean_udf
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, v => w) as wm FROM v
GROUP BY id"
+ ).show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+ ).show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"weighted_mean\(\) got an unexpected keyword
argument 'x'"
+ ):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, x => w) as wm FROM v
GROUP BY id"
+ ).show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"weighted_mean\(\) got multiple values for
argument 'v'"
+ ):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY
id"
+ ).show()
+
+ def test_kwargs(self):
+ df = self.data
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def weighted_mean(**kwargs):
+ import numpy as np
+
+ return np.average(kwargs["v"], weights=kwargs["w"])
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ for i, aggregated in enumerate(
+ [
+ df.groupby("id").agg(weighted_mean(v=df.v,
w=df.w).alias("wm")),
+ df.groupby("id").agg(weighted_mean(w=df.w,
v=df.v).alias("wm")),
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, w => w) as wm FROM v
GROUP BY id"
+ ),
+ self.spark.sql(
+ "SELECT id, weighted_mean(w => w, v => v) as wm FROM v
GROUP BY id"
+ ),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
+
+ # negative
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, v => w) as wm FROM v
GROUP BY id"
+ ).show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql(
+ "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+ ).show()
+
+ def test_named_arguments_and_defaults(self):
+ df = self.data
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def biased_sum(v, w=None):
+ return v.sum() + (w.sum() if w is not None else 100)
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("biased_sum", biased_sum)
+
+ # without "w"
+ for i, aggregated in enumerate(
+ [
+ df.groupby("id").agg(biased_sum(df.v).alias("s")),
+ df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
+ self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP
BY id"),
+ self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v
GROUP BY id"),
+ ]
+ ):
+ with self.subTest(with_w=False, query_no=i):
+ assertDataFrameEqual(
+ aggregated, df.groupby("id").agg((sum(df.v) +
lit(100)).alias("s"))
+ )
+
+ # with "w"
+ for i, aggregated in enumerate(
+ [
+ df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
+ df.groupby("id").agg(biased_sum(v=df.v,
w=df.w).alias("s")),
+ df.groupby("id").agg(biased_sum(w=df.w,
v=df.v).alias("s")),
+ self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM
v GROUP BY id"),
+ self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s
FROM v GROUP BY id"),
+ self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s
FROM v GROUP BY id"),
+ ]
+ ):
+ with self.subTest(with_w=True, query_no=i):
+ assertDataFrameEqual(
+ aggregated, df.groupby("id").agg((sum(df.v) +
sum(df.w)).alias("s"))
+ )
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index e74e3783b12..6968c074094 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -18,7 +18,7 @@
import unittest
from typing import cast
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
from pyspark.sql.functions import (
array,
explode,
@@ -40,7 +40,7 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
if have_pandas:
from pandas.testing import assert_frame_equal
@@ -107,6 +107,16 @@ class WindowPandasUDFTestsMixin:
return min
+ @property
+ def pandas_agg_weighted_mean_udf(self):
+ import numpy as np
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def weighted_mean(v, w):
+ return np.average(v, weights=w)
+
+ return weighted_mean
+
@property
def unbounded_window(self):
return (
@@ -394,6 +404,165 @@ class WindowPandasUDFTestsMixin:
assert_frame_equal(expected1.toPandas(), result1.toPandas())
+ def test_named_arguments(self):
+ df = self.data
+ weighted_mean = self.pandas_agg_weighted_mean_udf
+
+ for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
+ for i, windowed in enumerate(
+ [
+ df.withColumn("wm", weighted_mean(df.v, w=df.w).over(w)),
+ df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+ df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+ ]
+ ):
+ with self.subTest(bound=bound, query_no=i):
+ assertDataFrameEqual(windowed, df.withColumn("wm",
mean(df.v).over(w)))
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ for w in [
+ "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+ "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+ ]:
+ window_spec = f"PARTITION BY id ORDER BY v {w}"
+ for i, func_call in enumerate(
+ [
+ "weighted_mean(v, w => w)",
+ "weighted_mean(v => v, w => w)",
+ "weighted_mean(w => w, v => v)",
+ ]
+ ):
+ with self.subTest(window_spec=window_spec, query_no=i):
+ assertDataFrameEqual(
+ self.spark.sql(
+ f"SELECT id, {func_call} OVER ({window_spec})
as wm FROM v"
+ ),
+ self.spark.sql(f"SELECT id, mean(v) OVER
({window_spec}) as wm FROM v"),
+ )
+
+ def test_named_arguments_negative(self):
+ df = self.data
+ weighted_mean = self.pandas_agg_weighted_mean_udf
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM
v"
+
+ for w in [
+ "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+ "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+ ]:
+ window_spec = f"PARTITION BY id ORDER BY v {w}"
+ with self.subTest(window_spec=window_spec):
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v => v, v => w)",
window_spec=window_spec
+ )
+ ).show()
+
+ with self.assertRaisesRegex(
+ AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v => v, w)",
window_spec=window_spec
+ )
+ ).show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"weighted_mean\(\) got an unexpected
keyword argument 'x'"
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v => v, x => w)",
window_spec=window_spec
+ )
+ ).show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"weighted_mean\(\) got multiple
values for argument 'v'"
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v, v => w)",
window_spec=window_spec
+ )
+ ).show()
+
+ def test_kwargs(self):
+ df = self.data
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def weighted_mean(**kwargs):
+ import numpy as np
+
+ return np.average(kwargs["v"], weights=kwargs["w"])
+
+ for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
+ for i, windowed in enumerate(
+ [
+ df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+ df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+ ]
+ ):
+ with self.subTest(bound=bound, query_no=i):
+ assertDataFrameEqual(windowed, df.withColumn("wm",
mean(df.v).over(w)))
+
+ with self.tempView("v"):
+ df.createOrReplaceTempView("v")
+ self.spark.udf.register("weighted_mean", weighted_mean)
+
+ base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM
v"
+
+ for w in [
+ "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+ "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+ ]:
+ window_spec = f"PARTITION BY id ORDER BY v {w}"
+ with self.subTest(window_spec=window_spec):
+ for i, func_call in enumerate(
+ [
+ "weighted_mean(v => v, w => w)",
+ "weighted_mean(w => w, v => v)",
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(
+ self.spark.sql(
+ base_sql.format(func_call=func_call, window_spec=window_spec)
+ ),
+ self.spark.sql(
+ base_sql.format(func_call="mean(v)",
window_spec=window_spec)
+ ),
+ )
+
+ # negative
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v => v, v => w)",
window_spec=window_spec
+ )
+ ).show()
+
+ with self.assertRaisesRegex(
+ AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+ ):
+ self.spark.sql(
+ base_sql.format(
+ func_call="weighted_mean(v => v, w)",
window_spec=window_spec
+ )
+ ).show()
+
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index f72bf288230..32ea05bd00a 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -939,6 +939,11 @@ class BaseUDFTestsMixin(object):
):
self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got multiple values for argument
'a'"
+ ):
+ self.spark.sql("SELECT test_udf(id, a => id * 10) FROM
range(2)").show()
+
def test_kwargs(self):
@udf("int")
def test_udf(**kwargs):
@@ -957,6 +962,16 @@ class BaseUDFTestsMixin(object):
with self.subTest(query_no=i):
assertDataFrameEqual(df, [Row(0), Row(101)])
+ # negative
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM
range(2)").show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
def test_named_arguments_and_defaults(self):
@udf("int")
def test_udf(a, b=0):
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index a7545c332e6..95e46ba433c 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1848,6 +1848,11 @@ class BaseUDTFTestsMixin:
):
self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()
+ with self.assertRaisesRegex(
+ PythonException, r"eval\(\) got multiple values for argument 'a'"
+ ):
+ self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show()
+
def test_udtf_with_kwargs(self):
@udtf(returnType="a: int, b: string")
class TestUDTF:
@@ -1867,6 +1872,16 @@ class BaseUDTFTestsMixin:
with self.subTest(query_no=i):
assertDataFrameEqual(df, [Row(a=10, b="x")])
+ # negative
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
+
def test_udtf_with_analyze_kwargs(self):
@udtf
class TestUDTF:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 19c8c9c897b..d95a5c4672f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -452,13 +452,13 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
def wrap_grouped_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
- def wrapped(*series):
+ def wrapped(*args, **kwargs):
import pandas as pd
- result = f(*series)
+ result = f(*args, **kwargs)
return pd.Series([result])
- return lambda *a: (wrapped(*a), arrow_return_type)
+ return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
@@ -484,19 +484,19 @@ def wrap_unbounded_window_agg_pandas_udf(f, return_type):
# the scalar value.
arrow_return_type = to_arrow_type(return_type)
- def wrapped(*series):
+ def wrapped(*args, **kwargs):
import pandas as pd
- result = f(*series)
- return pd.Series([result]).repeat(len(series[0]))
+ result = f(*args, **kwargs)
+ return pd.Series([result]).repeat(len((list(args) + list(kwargs.values()))[0]))
- return lambda *a: (wrapped(*a), arrow_return_type)
+ return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
def wrap_bounded_window_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
- def wrapped(begin_index, end_index, *series):
+ def wrapped(begin_index, end_index, *args, **kwargs):
import pandas as pd
result = []
@@ -521,11 +521,12 @@ def wrap_bounded_window_agg_pandas_udf(f, return_type):
# Note: Calling reset_index on the slices will increase the cost
# of creating slices by about 100%. Therefore, for performance
# reasons we don't do it here.
- series_slices = [s.iloc[begin_array[i] : end_array[i]] for s in series]
- result.append(f(*series_slices))
+ args_slices = [s.iloc[begin_array[i] : end_array[i]] for s in args]
+ kwargs_slices = {k: s.iloc[begin_array[i] : end_array[i]] for k, s in kwargs.items()}
+ result.append(f(*args_slices, **kwargs_slices))
return pd.Series(result)
- return lambda *a: (wrapped(*a), arrow_return_type)
+ return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
@@ -535,6 +536,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+ # The below doesn't support named argument, but shares the same protocol.
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9a6d9c8b735..b93f87e77b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3003,6 +3003,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// we need to make sure that col1 to col5 are all projected from the child of the Window
// operator.
val extractedExprMap = mutable.LinkedHashMap.empty[Expression, NamedExpression]
+ def getOrExtract(key: Expression, value: Expression): Expression = {
+ extractedExprMap.getOrElseUpdate(key.canonicalized,
+ Alias(value, s"_w${extractedExprMap.size}")()).toAttribute
+ }
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
// If a named expression is not in regularExpressions, add it to
@@ -3016,11 +3020,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ne
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
+ case e: NamedArgumentExpression =>
+ // For NamedArgumentExpression, we extract the value and replace it with
+ // an AttributeReference (with an internal column name, e.g. "_w0").
+ NamedArgumentExpression(e.key, getOrExtract(e, e.value))
case e: Expression =>
// For other expressions, we extract it and replace it with an AttributeReference (with
// an internal column name, e.g. "_w0").
- extractedExprMap.getOrElseUpdate(e.canonicalized,
- Alias(e, s"_w${extractedExprMap.size}")()).toAttribute
+ getOrExtract(e, e)
}
// Now, we extract regular expressions from expressionsWithWindowFunctions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 73560a596ca..7e349b665f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -109,14 +110,20 @@ case class AggregateInPandasExec(
// Also eliminate duplicate UDF inputs.
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
- val argOffsets = inputs.map { input =>
+ val argMetas = inputs.map { input =>
input.map { e =>
- if (allInputs.exists(_.semanticEquals(e))) {
- allInputs.indexWhere(_.semanticEquals(e))
+ val (key, value) = e match {
+ case NamedArgumentExpression(key, value) =>
+ (Some(key), value)
+ case _ =>
+ (None, e)
+ }
+ if (allInputs.exists(_.semanticEquals(value))) {
+ ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
} else {
- allInputs += e
- dataTypes += e.dataType
- allInputs.length - 1
+ allInputs += value
+ dataTypes += value.dataType
+ ArgumentMetadata(allInputs.length - 1, key)
}
}.toArray
}.toArray
@@ -164,10 +171,10 @@ case class AggregateInPandasExec(
rows
}
- val columnarBatchIter = new ArrowPythonRunner(
+ val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
pyFuncs,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
- argOffsets,
+ argMetas,
aggInputSchema,
sessionLocalTimeZone,
largeVarTypes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index f576637aa25..2fcc428407e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -52,7 +52,8 @@ case class UserDefinedPythonFunction(
def builder(e: Seq[Expression]): Expression = {
if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF
|| pythonEvalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
- || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) {
+ || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+ || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
/*
* Check if the named arguments:
* - don't have duplicated names
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index a32d892622b..cf9f8c22ea0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -25,11 +25,12 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedArgumentExpression, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
@@ -170,14 +171,20 @@ class WindowInPandasEvaluatorFactory(
// handles UDF inputs.
private val dataInputs = new ArrayBuffer[Expression]
private val dataInputTypes = new ArrayBuffer[DataType]
- private val argOffsets = inputs.map { input =>
+ private val argMetas = inputs.map { input =>
input.map { e =>
- if (dataInputs.exists(_.semanticEquals(e))) {
- dataInputs.indexWhere(_.semanticEquals(e))
+ val (key, value) = e match {
+ case NamedArgumentExpression(key, value) =>
+ (Some(key), value)
+ case _ =>
+ (None, e)
+ }
+ if (dataInputs.exists(_.semanticEquals(value))) {
+ ArgumentMetadata(dataInputs.indexWhere(_.semanticEquals(value)), key)
} else {
- dataInputs += e
- dataInputTypes += e.dataType
- dataInputs.length - 1
+ dataInputs += value
+ dataInputTypes += value.dataType
+ ArgumentMetadata(dataInputs.length - 1, key)
}
}.toArray
}.toArray
@@ -206,11 +213,15 @@ class WindowInPandasEvaluatorFactory(
pyFuncs.indices.foreach { exprIndex =>
val frameIndex = expressionIndexToFrameIndex(exprIndex)
if (isBounded(frameIndex)) {
- argOffsets(exprIndex) =
- Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
- argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+ argMetas(exprIndex) =
+ Array(
+ ArgumentMetadata(lowerBoundIndex(frameIndex), None),
+ ArgumentMetadata(upperBoundIndex(frameIndex), None)) ++
+ argMetas(exprIndex).map(
+ meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name))
} else {
- argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+ argMetas(exprIndex) = argMetas(exprIndex).map(
+ meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name))
}
}
@@ -346,10 +357,10 @@ class WindowInPandasEvaluatorFactory(
}
}
- val windowFunctionResult = new ArrowPythonRunner(
+ val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
pyFuncs,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
- argOffsets,
+ argMetas,
pythonInputSchema,
sessionLocalTimeZone,
largeVarTypes,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]