This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d9ca9820384 [SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test
FlatMapCoGroupsInPandas with Window function
d9ca9820384 is described below
commit d9ca9820384b84aa5004f4c407d72d3fbc6cbb97
Author: Enrico Minack <[email protected]>
AuthorDate: Fri Jan 27 09:20:08 2023 -0800
[SPARK-42168][SQL][PYTHON][FOLLOW-UP] Test FlatMapCoGroupsInPandas with
Window function
### What changes were proposed in this pull request?
This ports tests from #39717 in branch-3.2 to master.
### Why are the changes needed?
To make sure this use case is tested.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E test in `test_pandas_cogrouped_map.py` and analysis test in
`EnsureRequirementsSuite.scala`.
Closes #39752 from EnricoMi/branch-cogroup-window-bug-test.
Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 54 ++++++++++++++++++++-
.../exchange/EnsureRequirementsSuite.scala | 56 ++++++++++++++++++++++
2 files changed, 109 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index d92a105f5d4..5cbc9e1caa4 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -18,8 +18,9 @@
import unittest
from typing import cast
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
+from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf,
sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
+from pyspark.sql.window import Window
from pyspark.errors import IllegalArgumentException, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -365,6 +366,57 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
+    def test_with_window_function(self):
+        # SPARK-42168: a window function with same partition keys but differing key order
+        ids = 2
+        days = 100
+        vals = 10000
+        parts = 10
+
+        id_df = self.spark.range(ids)
+        day_df = self.spark.range(days).withColumnRenamed("id", "day")
+        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
+        df = id_df.join(day_df).join(vals_df)
+
+        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
+        # SPARK-42132: this bug requires us to alias all columns from df here
+        right_df = (
+            df.select(col("id").alias("id"), col("day").alias("day"), col("value").alias("right"))
+            .repartition(parts)
+            .cache()
+        )
+
+        # note the column order is different to the groupBy("id", "day") column order below
+        window = Window.partitionBy("day", "id")
+
+        left_grouped_df = left_df.groupBy("id", "day")
+        right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy(
+            "id", "day"
+        )
+
+        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame(
+                [
+                    {
+                        "id": left["id"][0]
+                        if not left.empty
+                        else (right["id"][0] if not right.empty else None),
+                        "day": left["day"][0]
+                        if not left.empty
+                        else (right["day"][0] if not right.empty else None),
+                        "lefts": len(left.index),
+                        "rights": len(right.index),
+                    }
+                ]
+            )
+
+        df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
+            cogroup, schema="id long, day long, lefts integer, rights integer"
+        )
+
+        actual = df.orderBy("id", "day").take(days)
+        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])
+
@staticmethod
def _test_with_key(left, right, isLeft):
def right_assign_key(key, lft, rgt):
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index bc1fd7a5fa5..844037339ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.exchange
+import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
@@ -25,9 +27,12 @@ import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
+import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import
org.apache.spark.sql.internal.SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
@@ -1104,6 +1109,57 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}
}
+  test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
+    val lKey = AttributeReference("key", IntegerType)()
+    val lKey2 = AttributeReference("key2", IntegerType)()
+
+    val rKey = AttributeReference("key", IntegerType)()
+    val rKey2 = AttributeReference("key2", IntegerType)()
+    val rValue = AttributeReference("value", IntegerType)()
+
+    val left = DummySparkPlan()
+    val right = WindowExec(
+      Alias(
+        WindowExpression(
+          Sum(rValue).toAggregateExpression(),
+          WindowSpecDefinition(
+            Seq(rKey2, rKey),
+            Nil,
+            SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+          )
+        ), "sum")() :: Nil,
+      Seq(rKey2, rKey),
+      Nil,
+      DummySparkPlan()
+    )
+
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("value", IntegerType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    val flapMapCoGroup = FlatMapCoGroupsInPandasExec(
+      Seq(lKey, lKey2),
+      Seq(rKey, rKey2),
+      pythonUdf,
+      AttributeReference("value", IntegerType)() :: Nil,
+      left,
+      right
+    )
+
+    val result = EnsureRequirements.apply(flapMapCoGroup)
+    result match {
+      case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
+          SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
+        assert(leftKeys === Seq(lKey, lKey2))
+        assert(rightKeys === Seq(rKey, rKey2))
+        assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
+        assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
+      case other => fail(other.toString)
+    }
+  }
+
def bucket(numBuckets: Int, expr: Expression): TransformExpression = {
TransformExpression(BucketFunction, Seq(expr), Some(numBuckets))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]