This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new df534c355d9 [SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs
df534c355d9 is described below

commit df534c355d9059fb5b128491a8f037baa121cbd7
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Sep 1 10:41:37 2023 -0700

    [SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs
    
    ### What changes were proposed in this pull request?
    
    Supports named arguments in aggregate Pandas UDFs.
    
    For example:
    
    ```py
    >>> @pandas_udf("double")
    ... def weighted_mean(v: pd.Series, w: pd.Series) -> float:
    ...     import numpy as np
    ...     return np.average(v, weights=w)
    ...
    >>> df = spark.createDataFrame(
    ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
    ...     ("id", "v", "w"))
    
    >>> df.groupby("id").agg(weighted_mean(v=df["v"], w=df["w"])).show()
    +---+-----------------------------+
    | id|weighted_mean(v => v, w => w)|
    +---+-----------------------------+
    |  1|           1.6666666666666667|
    |  2|            7.166666666666667|
    +---+-----------------------------+
    
    >>> df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show()
    +---+-----------------------------+
    | id|weighted_mean(w => w, v => v)|
    +---+-----------------------------+
    |  1|           1.6666666666666667|
    |  2|            7.166666666666667|
    +---+-----------------------------+
    ```
    
    or with window:
    
    ```py
    >>> w = Window.partitionBy("id").orderBy("v").rowsBetween(-2, 1)
    
    >>> df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)).show()
    +---+----+---+------------------+
    | id|   v|  w|                wm|
    +---+----+---+------------------+
    |  1| 1.0|1.0|1.6666666666666667|
    |  1| 2.0|2.0|1.6666666666666667|
    |  2| 3.0|1.0| 4.333333333333333|
    |  2| 5.0|2.0| 7.166666666666667|
    |  2|10.0|3.0| 7.166666666666667|
    +---+----+---+------------------+
    
    >>> df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)).show()
    +---+----+---+------------------+
    | id|   v|  w|                wm|
    +---+----+---+------------------+
    |  1| 1.0|1.0|1.6666666666666667|
    |  1| 2.0|2.0|1.6666666666666667|
    |  2| 3.0|1.0| 4.333333333333333|
    |  2| 5.0|2.0| 7.166666666666667|
    |  2|10.0|3.0| 7.166666666666667|
    +---+----+---+------------------+
    ```
    
    ### Why are the changes needed?
    
    Named argument support was recently added
    (https://github.com/apache/spark/pull/41796, https://github.com/apache/spark/pull/42020),
    so aggregate Pandas UDFs can now support it as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, named arguments will be available for aggregate Pandas UDFs.
    
    ### How was this patch tested?
    
    Added related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42663 from ueshin/issues/SPARK-44952/kwargs.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Takuya UESHIN <[email protected]>
---
 python/pyspark/sql/pandas/functions.py             |  20 ++-
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 147 ++++++++++++++++-
 .../sql/tests/pandas/test_pandas_udf_window.py     | 173 ++++++++++++++++++++-
 python/pyspark/sql/tests/test_udf.py               |  15 ++
 python/pyspark/sql/tests/test_udtf.py              |  15 ++
 python/pyspark/worker.py                           |  25 +--
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  11 +-
 .../execution/python/AggregateInPandasExec.scala   |  23 ++-
 .../python/UserDefinedPythonFunction.scala         |   3 +-
 .../python/WindowInPandasEvaluatorFactory.scala    |  37 +++--
 10 files changed, 429 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index ad9fdac9706..652129180df 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         Supports Spark Connect.
 
     .. versionchanged:: 4.0.0
-        Supports keyword-arguments in SCALAR type.
+        Supports keyword-arguments in SCALAR and GROUPED_AGG type.
 
     Parameters
     ----------
@@ -267,6 +267,24 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         |  2|        6.0|
         +---+-----------+
 
+        This type of Pandas UDF can use keyword arguments:
+
+        >>> @pandas_udf("double")
+        ... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
+        ...     import numpy as np
+        ...     return np.average(v, weights=w)
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), 
(2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], 
v=df["v"])).show()
+        +---+---------------------------------+
+        | id|weighted_mean_udf(w => w, v => v)|
+        +---+---------------------------------+
+        |  1|               1.6666666666666667|
+        |  2|                7.166666666666667|
+        +---+---------------------------------+
+
         This UDF can also be used as window functions as below:
 
         >>> from pyspark.sql import Window
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index f434489a6fb..b500be7a969 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -32,7 +32,7 @@ from pyspark.sql.functions import (
     PandasUDFType,
 )
 from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, 
PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -40,7 +40,7 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 if have_pandas:
@@ -575,6 +575,149 @@ class GroupedAggPandasUDFTestsMixin:
 
         assert filtered.collect()[0]["mean"] == 42.0
 
+    def test_named_arguments(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(df.v, 
w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(v=df.v, 
w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, 
v=df.v).alias("wm")),
+                    self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm 
FROM v GROUP BY id"),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v 
GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v 
GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, 
df.groupby("id").agg(mean(df.v).alias("wm")))
+
+    def test_named_arguments_negative(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            with self.assertRaisesRegex(
+                AnalysisException,
+                
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v 
GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY 
id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got an unexpected keyword 
argument 'x'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, x => w) as wm FROM v 
GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got multiple values for 
argument 'v'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY 
id"
+                ).show()
+
+    def test_kwargs(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(**kwargs):
+            import numpy as np
+
+            return np.average(kwargs["v"], weights=kwargs["w"])
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(v=df.v, 
w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, 
v=df.v).alias("wm")),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v 
GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v 
GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, 
df.groupby("id").agg(mean(df.v).alias("wm")))
+
+            # negative
+            with self.assertRaisesRegex(
+                AnalysisException,
+                
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v 
GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY 
id"
+                ).show()
+
+    def test_named_arguments_and_defaults(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def biased_sum(v, w=None):
+            return v.sum() + (w.sum() if w is not None else 100)
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("biased_sum", biased_sum)
+
+            # without "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP 
BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v 
GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=False, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + 
lit(100)).alias("s"))
+                    )
+
+            # with "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v, 
w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(w=df.w, 
v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM 
v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s 
FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s 
FROM v GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=True, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + 
sum(df.w)).alias("s"))
+                    )
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, 
ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index e74e3783b12..6968c074094 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -18,7 +18,7 @@
 import unittest
 from typing import cast
 
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.functions import (
     array,
     explode,
@@ -40,7 +40,7 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 if have_pandas:
     from pandas.testing import assert_frame_equal
@@ -107,6 +107,16 @@ class WindowPandasUDFTestsMixin:
 
         return min
 
+    @property
+    def pandas_agg_weighted_mean_udf(self):
+        import numpy as np
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(v, w):
+            return np.average(v, weights=w)
+
+        return weighted_mean
+
     @property
     def unbounded_window(self):
         return (
@@ -394,6 +404,165 @@ class WindowPandasUDFTestsMixin:
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
+    def test_named_arguments(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        for w, bound in [(self.sliding_row_window, True), 
(self.unbounded_window, False)]:
+            for i, windowed in enumerate(
+                [
+                    df.withColumn("wm", weighted_mean(df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+                ]
+            ):
+                with self.subTest(bound=bound, query_no=i):
+                    assertDataFrameEqual(windowed, df.withColumn("wm", 
mean(df.v).over(w)))
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                for i, func_call in enumerate(
+                    [
+                        "weighted_mean(v, w => w)",
+                        "weighted_mean(v => v, w => w)",
+                        "weighted_mean(w => w, v => v)",
+                    ]
+                ):
+                    with self.subTest(window_spec=window_spec, query_no=i):
+                        assertDataFrameEqual(
+                            self.spark.sql(
+                                f"SELECT id, {func_call} OVER ({window_spec}) 
as wm FROM v"
+                            ),
+                            self.spark.sql(f"SELECT id, mean(v) OVER 
({window_spec}) as wm FROM v"),
+                        )
+
+    def test_named_arguments_negative(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM 
v"
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                with self.subTest(window_spec=window_spec):
+                    with self.assertRaisesRegex(
+                        AnalysisException,
+                        
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, v => w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        PythonException, r"weighted_mean\(\) got an unexpected 
keyword argument 'x'"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, x => w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        PythonException, r"weighted_mean\(\) got multiple 
values for argument 'v'"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v, v => w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
+    def test_kwargs(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(**kwargs):
+            import numpy as np
+
+            return np.average(kwargs["v"], weights=kwargs["w"])
+
+        for w, bound in [(self.sliding_row_window, True), 
(self.unbounded_window, False)]:
+            for i, windowed in enumerate(
+                [
+                    df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+                ]
+            ):
+                with self.subTest(bound=bound, query_no=i):
+                    assertDataFrameEqual(windowed, df.withColumn("wm", 
mean(df.v).over(w)))
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM 
v"
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                with self.subTest(window_spec=window_spec):
+                    for i, func_call in enumerate(
+                        [
+                            "weighted_mean(v => v, w => w)",
+                            "weighted_mean(w => w, v => v)",
+                        ]
+                    ):
+                        with self.subTest(query_no=i):
+                            assertDataFrameEqual(
+                                self.spark.sql(
+                                    base_sql.format(func_call=func_call, 
window_spec=window_spec)
+                                ),
+                                self.spark.sql(
+                                    base_sql.format(func_call="mean(v)", 
window_spec=window_spec)
+                                ),
+                            )
+
+                    # negative
+                    with self.assertRaisesRegex(
+                        AnalysisException,
+                        
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, v => w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, w)", 
window_spec=window_spec
+                            )
+                        ).show()
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index f72bf288230..32ea05bd00a 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -939,6 +939,11 @@ class BaseUDFTestsMixin(object):
         ):
             self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
 
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got multiple values for argument 
'a'"
+        ):
+            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM 
range(2)").show()
+
     def test_kwargs(self):
         @udf("int")
         def test_udf(**kwargs):
@@ -957,6 +962,16 @@ class BaseUDFTestsMixin(object):
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(0), Row(101)])
 
+        # negative
+        with self.assertRaisesRegex(
+            AnalysisException,
+            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
+
     def test_named_arguments_and_defaults(self):
         @udf("int")
         def test_udf(a, b=0):
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index a7545c332e6..95e46ba433c 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1848,6 +1848,11 @@ class BaseUDTFTestsMixin:
         ):
             self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()
 
+        with self.assertRaisesRegex(
+            PythonException, r"eval\(\) got multiple values for argument 'a'"
+        ):
+            self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show()
+
     def test_udtf_with_kwargs(self):
         @udtf(returnType="a: int, b: string")
         class TestUDTF:
@@ -1867,6 +1872,16 @@ class BaseUDTFTestsMixin:
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="x")])
 
+        # negative
+        with self.assertRaisesRegex(
+            AnalysisException,
+            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()
+
+        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
+
     def test_udtf_with_analyze_kwargs(self):
         @udtf
         class TestUDTF:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 19c8c9c897b..d95a5c4672f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -452,13 +452,13 @@ def wrap_grouped_map_pandas_udf_with_state(f, 
return_type):
 def wrap_grouped_agg_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
-    def wrapped(*series):
+    def wrapped(*args, **kwargs):
         import pandas as pd
 
-        result = f(*series)
+        result = f(*args, **kwargs)
         return pd.Series([result])
 
-    return lambda *a: (wrapped(*a), arrow_return_type)
+    return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
 
 
 def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
@@ -484,19 +484,19 @@ def wrap_unbounded_window_agg_pandas_udf(f, return_type):
     # the scalar value.
     arrow_return_type = to_arrow_type(return_type)
 
-    def wrapped(*series):
+    def wrapped(*args, **kwargs):
         import pandas as pd
 
-        result = f(*series)
-        return pd.Series([result]).repeat(len(series[0]))
+        result = f(*args, **kwargs)
+        return pd.Series([result]).repeat(len((list(args) + 
list(kwargs.values()))[0]))
 
-    return lambda *a: (wrapped(*a), arrow_return_type)
+    return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
 
 
 def wrap_bounded_window_agg_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
-    def wrapped(begin_index, end_index, *series):
+    def wrapped(begin_index, end_index, *args, **kwargs):
         import pandas as pd
 
         result = []
@@ -521,11 +521,12 @@ def wrap_bounded_window_agg_pandas_udf(f, return_type):
             # Note: Calling reset_index on the slices will increase the cost
             #       of creating slices by about 100%. Therefore, for 
performance
             #       reasons we don't do it here.
-            series_slices = [s.iloc[begin_array[i] : end_array[i]] for s in 
series]
-            result.append(f(*series_slices))
+            args_slices = [s.iloc[begin_array[i] : end_array[i]] for s in args]
+            kwargs_slices = {k: s.iloc[begin_array[i] : end_array[i]] for k, s 
in kwargs.items()}
+            result.append(f(*args_slices, **kwargs_slices))
         return pd.Series(result)
 
-    return lambda *a: (wrapped(*a), arrow_return_type)
+    return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type)
 
 
 def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
@@ -535,6 +536,8 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
         PythonEvalType.SQL_BATCHED_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         # The below doesn't support named argument, but shares the same 
protocol.
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
     ):
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9a6d9c8b735..b93f87e77b9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3003,6 +3003,10 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       // we need to make sure that col1 to col5 are all projected from the 
child of the Window
       // operator.
       val extractedExprMap = mutable.LinkedHashMap.empty[Expression, 
NamedExpression]
+      def getOrExtract(key: Expression, value: Expression): Expression = {
+        extractedExprMap.getOrElseUpdate(key.canonicalized,
+          Alias(value, s"_w${extractedExprMap.size}")()).toAttribute
+      }
       def extractExpr(expr: Expression): Expression = expr match {
         case ne: NamedExpression =>
           // If a named expression is not in regularExpressions, add it to
@@ -3016,11 +3020,14 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
           ne
         case e: Expression if e.foldable =>
           e // No need to create an attribute reference if it will be 
evaluated as a Literal.
+        case e: NamedArgumentExpression =>
+          // For NamedArgumentExpression, we extract the value and replace it 
with
+          // an AttributeReference (with an internal column name, e.g. "_w0").
+          NamedArgumentExpression(e.key, getOrExtract(e, e.value))
         case e: Expression =>
           // For other expressions, we extract it and replace it with an 
AttributeReference (with
           // an internal column name, e.g. "_w0").
-          extractedExprMap.getOrElseUpdate(e.canonicalized,
-            Alias(e, s"_w${extractedExprMap.size}")()).toAttribute
+          getOrExtract(e, e)
       }
 
       // Now, we extract regular expressions from 
expressionsWithWindowFunctions
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 73560a596ca..7e349b665f3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -30,6 +30,7 @@ import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, 
UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -109,14 +110,20 @@ case class AggregateInPandasExec(
     // Also eliminate duplicate UDF inputs.
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
-    val argOffsets = inputs.map { input =>
+    val argMetas = inputs.map { input =>
       input.map { e =>
-        if (allInputs.exists(_.semanticEquals(e))) {
-          allInputs.indexWhere(_.semanticEquals(e))
+        val (key, value) = e match {
+          case NamedArgumentExpression(key, value) =>
+            (Some(key), value)
+          case _ =>
+            (None, e)
+        }
+        if (allInputs.exists(_.semanticEquals(value))) {
+          ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
         } else {
-          allInputs += e
-          dataTypes += e.dataType
-          allInputs.length - 1
+          allInputs += value
+          dataTypes += value.dataType
+          ArgumentMetadata(allInputs.length - 1, key)
         }
       }.toArray
     }.toArray
@@ -164,10 +171,10 @@ case class AggregateInPandasExec(
         rows
       }
 
-      val columnarBatchIter = new ArrowPythonRunner(
+      val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-        argOffsets,
+        argMetas,
         aggInputSchema,
         sessionLocalTimeZone,
         largeVarTypes,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index f576637aa25..2fcc428407e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -52,7 +52,8 @@ case class UserDefinedPythonFunction(
   def builder(e: Seq[Expression]): Expression = {
     if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF
         || pythonEvalType ==PythonEvalType.SQL_ARROW_BATCHED_UDF
-        || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) {
+        || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+        || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
       /*
        * Check if the named arguments:
        * - don't have duplicated names
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index a32d892622b..cf9f8c22ea0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -25,11 +25,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.{JobArtifactSet, PartitionEvaluator, 
PartitionEvaluatorFactory, SparkEnv, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, 
NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, 
SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, 
NamedArgumentExpression, NamedExpression, PythonFuncExpression, PythonUDAF, 
SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, 
UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, 
UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType, StructField, 
StructType}
@@ -170,14 +171,20 @@ class WindowInPandasEvaluatorFactory(
     // handles UDF inputs.
     private val dataInputs = new ArrayBuffer[Expression]
     private val dataInputTypes = new ArrayBuffer[DataType]
-    private val argOffsets = inputs.map { input =>
+    private val argMetas = inputs.map { input =>
       input.map { e =>
-        if (dataInputs.exists(_.semanticEquals(e))) {
-          dataInputs.indexWhere(_.semanticEquals(e))
+        val (key, value) = e match {
+          case NamedArgumentExpression(key, value) =>
+            (Some(key), value)
+          case _ =>
+            (None, e)
+        }
+        if (dataInputs.exists(_.semanticEquals(value))) {
+          ArgumentMetadata(dataInputs.indexWhere(_.semanticEquals(value)), key)
         } else {
-          dataInputs += e
-          dataInputTypes += e.dataType
-          dataInputs.length - 1
+          dataInputs += value
+          dataInputTypes += value.dataType
+          ArgumentMetadata(dataInputs.length - 1, key)
         }
       }.toArray
     }.toArray
@@ -206,11 +213,15 @@ class WindowInPandasEvaluatorFactory(
     pyFuncs.indices.foreach { exprIndex =>
       val frameIndex = expressionIndexToFrameIndex(exprIndex)
       if (isBounded(frameIndex)) {
-        argOffsets(exprIndex) =
-          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
-            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+        argMetas(exprIndex) =
+          Array(
+            ArgumentMetadata(lowerBoundIndex(frameIndex), None),
+            ArgumentMetadata(upperBoundIndex(frameIndex), None)) ++
+          argMetas(exprIndex).map(
+            meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, 
meta.name))
       } else {
-        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + 
windowBoundsInput.length)
+        argMetas(exprIndex) = argMetas(exprIndex).map(
+          meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, 
meta.name))
       }
     }
 
@@ -346,10 +357,10 @@ class WindowInPandasEvaluatorFactory(
         }
       }
 
-      val windowFunctionResult = new ArrowPythonRunner(
+      val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
-        argOffsets,
+        argMetas,
         pythonInputSchema,
         sessionLocalTimeZone,
         largeVarTypes,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to