This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 2a37f221631 [SPARK-42168][3.2][SQL][PYTHON] Fix required child 
distribution of FlatMapCoGroupsInPandas (as in CoGroup)
2a37f221631 is described below

commit 2a37f221631b11b40a8fa38535b6ed1f0d06e49d
Author: Enrico Minack <[email protected]>
AuthorDate: Thu Jan 26 10:43:24 2023 +0900

    [SPARK-42168][3.2][SQL][PYTHON] Fix required child distribution of 
FlatMapCoGroupsInPandas (as in CoGroup)
    
    ### What changes were proposed in this pull request?
    Make `FlatMapCoGroupsInPandas` (used by PySpark) report its required child 
distribution as `HashClusteredDistribution`, rather than 
`ClusteredDistribution`. That is the same distribution as reported by `CoGroup` 
(used by Scala).
    
    ### Why are the changes needed?
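    This allows the `EnsureRequirements` rule to correctly recognize that 
`FlatMapCoGroupsInPandas` requiring `HashClusteredDistribution(id, day)` is not 
compatible with `HashPartitioning(day, id)`, whereas `ClusteredDistribution(id, 
day)` is considered compatible with `HashPartitioning(day, id)`. With 
`ClusteredDistribution`, each side is checked independently, so the left side can 
end up hash-partitioned by `(id, day)` and the right side by `(day, id)`. Rows 
with the same key then land in different partitions and are never cogrouped 
together.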
    
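    The following example returns an incorrect result in Spark 3.0, 3.1, and 
3.2. (NOTE: the physical plans and output below appear to come from the 
equivalent PySpark `cogroup(...).applyInPandas(...)` query, as in the added 
Python test. The Scala `cogroup` shown here plans a `CoGroup`, which already 
requires `HashClusteredDistribution`.)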
    
    ```Scala
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, lit, sum}
    
    val ids = 1000
    val days = 1000
    val parts = 10
    
    val id_df = spark.range(ids)
    val day_df = spark.range(days).withColumnRenamed("id", "day")
    val id_day_df = id_df.join(day_df)
    // these redundant aliases are needed to workaround bug SPARK-42132
    val left_df = id_day_df.select($"id".as("id"), $"day".as("day"), 
lit("left").as("side")).repartition(parts).cache()
    val right_df = id_day_df.select($"id".as("id"), $"day".as("day"), 
lit("right").as("side")).repartition(parts).cache()  //.withColumnRenamed("id", 
"id2")
    
    // note the column order is different to the groupBy("id", "day") column 
order below
    val window = Window.partitionBy("day", "id")
    
    case class Key(id: BigInt, day: BigInt)
    case class Value(id: BigInt, day: BigInt, side: String)
    case class Sum(id: BigInt, day: BigInt, side: String, day_sum: BigInt)
    
    val left_grouped_df = left_df.groupBy("id", "day").as[Key, Value]
    val right_grouped_df = right_df.withColumn("day_sum", 
sum(col("day")).over(window)).groupBy("id", "day").as[Key, Sum]
    
    val df = left_grouped_df.cogroup(right_grouped_df)((key: Key, left: 
Iterator[Value], right: Iterator[Sum]) => left)
    
    df.explain()
    df.show(5)
    ```
    
    Output was
    ```
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], 
cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), 
[id#64L, day#65L, lefts#66, rights#67]
       :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
       :  +- Exchange hashpartitioning(id#8L, day#9L, 200), 
ENSURE_REQUIREMENTS, [plan_id=117]
       :     +- ...
       +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
          +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
             +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) 
AS day_sum#54L], [day#30L, id#29L]
                +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], 
false, 0
                   +- Exchange hashpartitioning(day#30L, id#29L, 200), 
ENSURE_REQUIREMENTS, [plan_id=112]
                      +- ...
    
    +---+---+-----+------+
    | id|day|lefts|rights|
    +---+---+-----+------+
    |  0|  3|    0|     1|
    |  0|  4|    0|     1|
    |  0| 13|    1|     0|
    |  0| 27|    0|     1|
    |  0| 31|    0|     1|
    +---+---+-----+------+
    only showing top 5 rows
    ```
    
    Output now is
    ```
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], 
cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), 
[id#64L, day#65L, lefts#66, rights#67]
       :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
       :  +- Exchange hashpartitioning(id#8L, day#9L, 200), 
ENSURE_REQUIREMENTS, [plan_id=117]
       :     +- ...
       +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
          +- Exchange hashpartitioning(id#29L, day#30L, 200), 
ENSURE_REQUIREMENTS, [plan_id=118]
             +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
                +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) 
AS day_sum#54L], [day#30L, id#29L]
                   +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], 
false, 0
                      +- Exchange hashpartitioning(day#30L, id#29L, 200), 
ENSURE_REQUIREMENTS, [plan_id=112]
                         +- ...
    
    +---+---+-----+------+
    | id|day|lefts|rights|
    +---+---+-----+------+
    |  0| 13|    1|     1|
    |  0| 63|    1|     1|
    |  0| 89|    1|     1|
    |  0| 95|    1|     1|
    |  0| 96|    1|     1|
    +---+---+-----+------+
    only showing top 5 rows
    ```
    
    Spark 3.3 
[reworked](https://github.com/apache/spark/pull/32875/files#diff-e938569a4ca4eba8f7e10fe473d4f9c306ea253df151405bcaba880a601f075fR75-R76)
 `HashClusteredDistribution` (#32875), so it is not sensitive to the use of 
`ClusteredDistribution` and is not affected by this bug.
    
    ### Does this PR introduce _any_ user-facing change?
    This fixes a correctness bug: affected queries previously returned incorrect results.
    
    ### How was this patch tested?
    A unit test in `EnsureRequirementsSuite` and a PySpark test in `test_pandas_cogrouped_map.py`.
    
    Closes #39717 from EnricoMi/branch-3.2-cogroup-window-bug.
    
    Authored-by: Enrico Minack <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../pyspark/sql/tests/test_pandas_cogrouped_map.py | 47 +++++++++++++++++-
 .../python/FlatMapCoGroupsInPandasExec.scala       |  6 +--
 .../exchange/EnsureRequirementsSuite.scala         | 58 +++++++++++++++++++++-
 3 files changed, 106 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
index 94a12bfb3f6..485bb880437 100644
--- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -17,8 +17,9 @@
 
 import unittest
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
+from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, 
sum
 from pyspark.sql.types import DoubleType, StructType, StructField, Row
+from pyspark.sql.window import Window
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, 
have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -215,6 +216,50 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
         self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
 
+    def test_with_window_function(self):
+        # SPARK-42168: a window function with same partition keys but 
differing key order
+        ids = 2
+        days = 100
+        vals = 10000
+        parts = 10
+
+        id_df = self.spark.range(ids)
+        day_df = self.spark.range(days).withColumnRenamed("id", "day")
+        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
+        df = id_df.join(day_df).join(vals_df)
+
+        left_df = df.withColumnRenamed("value", 
"left").repartition(parts).cache()
+        # SPARK-42132: this bug requires us to alias all columns from df here
+        right_df = df.select(
+            col("id").alias("id"), col("day").alias("day"), 
col("value").alias("right")
+        ).repartition(parts).cache()
+
+        # note the column order is different to the groupBy("id", "day") 
column order below
+        window = Window.partitionBy("day", "id")
+
+        left_grouped_df = left_df.groupBy("id", "day")
+        right_grouped_df = right_df \
+            .withColumn("day_sum", sum(col("day")).over(window)) \
+            .groupBy("id", "day")
+
+        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame([{
+                "id": left["id"][0] if not left.empty else (
+                    right["id"][0] if not right.empty else None
+                ),
+                "day": left["day"][0] if not left.empty else (
+                    right["day"][0] if not right.empty else None
+                ),
+                "lefts": len(left.index),
+                "rights": len(right.index)
+            }])
+
+        df = left_grouped_df.cogroup(right_grouped_df) \
+            .applyInPandas(cogroup, schema="id long, day long, lefts integer, 
rights integer")
+
+        actual = df.orderBy("id", "day").take(days)
+        self.assertEqual(actual, [Row(0, day, vals, vals) for day in 
range(days)])
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index e830ea6b546..e4503bdd9f4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, 
PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, 
HashClusteredDistribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, 
SparkPlan}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
 import org.apache.spark.sql.types.StructType
@@ -66,8 +66,8 @@ case class FlatMapCoGroupsInPandasExec(
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else 
ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else 
ClusteredDistribution(rightGroup)
+    val leftDist = if (leftGroup.isEmpty) AllTuples else 
HashClusteredDistribution(leftGroup)
+    val rightDist = if (rightGroup.isEmpty) AllTuples else 
HashClusteredDistribution(rightGroup)
     leftDist :: rightDist :: Nil
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 0425be6f9a7..e61619fac6b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,13 +17,18 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class EnsureRequirementsSuite extends SharedSparkSession {
   private val exprA = Literal(1)
@@ -135,4 +140,55 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       }.size == 2)
     }
   }
+
+  test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing 
key order") {
+    val lKey = AttributeReference("key", IntegerType)()
+    val lKey2 = AttributeReference("key2", IntegerType)()
+
+    val rKey = AttributeReference("key", IntegerType)()
+    val rKey2 = AttributeReference("key2", IntegerType)()
+    val rValue = AttributeReference("value", IntegerType)()
+
+    val left = DummySparkPlan()
+    val right = WindowExec(
+      Alias(
+        WindowExpression(
+          Sum(rValue).toAggregateExpression(),
+          WindowSpecDefinition(
+            Seq(rKey2, rKey),
+            Nil,
+            SpecifiedWindowFrame(RowFrame, UnboundedPreceding, 
UnboundedFollowing)
+          )
+        ), "sum")() :: Nil,
+      Seq(rKey2, rKey),
+      Nil,
+      DummySparkPlan()
+    )
+
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("value", IntegerType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    val flapMapCoGroup = FlatMapCoGroupsInPandasExec(
+      Seq(lKey, lKey2),
+      Seq(rKey, rKey2),
+      pythonUdf,
+      AttributeReference("value", IntegerType)() :: Nil,
+      left,
+      right
+    )
+
+    val result = EnsureRequirements.apply(flapMapCoGroup)
+    result match {
+      case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
+        SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
+        assert(leftKeys === Seq(lKey, lKey2))
+        assert(rightKeys === Seq(rKey, rKey2))
+        assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
+        assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
+      case other => fail(other.toString)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to