This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a4f2870b0855 [SPARK-50525][SQL] Define 
InsertMapSortInRepartitionExpressions Optimizer Rule
a4f2870b0855 is described below

commit a4f2870b08551031ace305a953ace23a6aa6e71a
Author: Dima <[email protected]>
AuthorDate: Fri Jan 10 12:20:37 2025 +0800

    [SPARK-50525][SQL] Define InsertMapSortInRepartitionExpressions Optimizer 
Rule
    
    ### What changes were proposed in this pull request?
    In the current version of Spark, it's possible to use `MapType` as a column 
for repartitioning. But `MapData` does not implement `equals` and `hashCode` 
(according to [SPARK-9415](https://issues.apache.org/jira/browse/SPARK-9415) 
and [[SPARK-16135][SQL] Remove hashCode and equals in 
ArrayBasedMapData](https://github.com/apache/spark/pull/13847)). Because of 
that, the hash value for identical maps can differ.
    
    When attempting to run the `xxhash64` or `hash` function on a `MapType`, 
```org.apache.spark.sql.catalyst.ExtendedAnalysisException: 
[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "xxhash64(value)" due to data 
type mismatch: Input to the function `xxhash64` cannot contain elements of the 
"MAP" type. In Spark, same maps may have different hashcode, thus hash 
expressions are prohibited on "MAP" elements. To restore previous behavior set 
"spark.sql.legacy.allowHashOnMapType" to "true".;``` will be thrown.
    
    Also, when trying to run `ds.distinct(col("value"))`, where `value` has 
`MapType`, the following exception is thrown: 
```org.apache.spark.sql.catalyst.ExtendedAnalysisException: 
[UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE] The feature is not supported: 
Cannot have MAP type columns in DataFrame which calls set operations 
(INTERSECT, EXCEPT, etc.), but the type of column `value` is "MAP<INT, 
STRING>".;```
    
    With the above consideration, a new `InsertMapSortInRepartitionExpressions` 
`Rule[LogicalPlan]` was implemented to insert `mapsort` for every `MapType` in 
`RepartitionByExpression.partitionExpressions`.
    
    ### Why are the changes needed?
    
    To keep the `repartition` API consistent for `MapType`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49144 from ostronaut/features/map_repartition.
    
    Authored-by: Dima <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  2 +-
 ...essions.scala => InsertMapSortExpression.scala} | 66 ++++++++++++++++------
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  1 +
 .../org/apache/spark/sql/DataFrameSuite.scala      | 37 ++++++++++--
 4 files changed, 84 insertions(+), 22 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6cd394fd79e9..46ca8e793218 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -884,7 +884,7 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
             o.failAnalysis(
               errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
               messageParameters = Map(
-                "expr" -> variantExpr.sql,
+                "expr" -> toSQLExpr(variantExpr),
                 "dataType" -> toSQLType(variantExpr.dataType)))
 
           case o if o.expressions.exists(!_.deterministic) &&
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
similarity index 69%
rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
rename to 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
index b6ced6c49a36..9e613c54a49b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
@@ -20,32 +20,30 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, 
CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, 
Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, 
NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, 
REPARTITION_OPERATION}
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 /**
- * Adds [[MapSort]] to group expressions containing map columns, as the 
key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
  *
- * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column  =>
+ * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
  * SELECT _groupingmapsort as map_column, COUNT(*) FROM (
  *   SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
  * ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  private def shouldAddMapSort(expr: Expression): Boolean = {
-    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
-  }
+  import InsertMapSortExpression._
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+    if (!plan.containsPattern(AGGREGATE)) {
       return plan
     }
     val shouldRewrite = plan.exists {
-      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) 
=> true
+      case agg: Aggregate if 
agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
       case _ => false
     }
     if (!shouldRewrite) {
@@ -53,8 +51,7 @@ object InsertMapSortInGroupingExpressions extends 
Rule[LogicalPlan] {
     }
 
     plan transformUpWithNewOutput {
-      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
-          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
         val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
         val newGroupingKeys = groupingExprs.map { expr =>
           val inserted = insertMapSortRecursively(expr)
@@ -77,15 +74,53 @@ object InsertMapSortInGroupingExpressions extends 
Rule[LogicalPlan] {
             }.asInstanceOf[NamedExpression]
         }
         val newChild = Project(child.output ++ exprToMapSort.values, child)
-        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, 
hint)
         newAgg -> agg.output.zip(newAgg.output)
     }
   }
+}
+
+/**
+ * Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map 
columns,
+ * as the key/value pairs need to be in the correct order before 
repartitioning:
+ *
+ * SELECT * FROM TABLE DISTRIBUTE BY map_column =>
+ * SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
+ */
+object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
+  import InsertMapSortExpression._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
+      case rep: RepartitionByExpression
+        if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
+        val exprToMapSort = new mutable.HashMap[Expression, Expression]
+        val newPartitionExprs = rep.partitionExpressions.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
+          } else {
+            expr
+          }
+        }
+        rep.copy(partitionExpressions = newPartitionExprs)
+    }
+  }
+}
+
+private[optimizer] object InsertMapSortExpression {
 
   /**
-   * Inserts MapSort recursively taking into account when it is nested inside 
a struct or array.
+   * Returns true if the expression contains a [[MapType]] in DataType tree.
    */
-  private def insertMapSortRecursively(e: Expression): Expression = {
+  def mapTypeExistsRecursively(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  /**
+   * Inserts [[MapSort]] recursively taking into account when it is nested 
inside a struct or array.
+   */
+  def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
       case m: MapType =>
         // Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@ object InsertMapSortInGroupingExpressions extends 
Rule[LogicalPlan] {
       case _ => e
     }
   }
-
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c0c76dd44ad5..8ee2226947ec 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -322,6 +322,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       // so the grouping keys can only be attribute and literal which makes
       // `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
       InsertMapSortInGroupingExpressions,
+      InsertMapSortInRepartitionExpressions,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0972a63a2495..317a88edf8e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -316,7 +316,7 @@ class DataFrameSuite extends QueryTest
       exception = intercept[AnalysisException](df.repartition(5, col("v"))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "v",
+        "expr" -> "\"v\"",
         "dataType" -> "\"VARIANT\"")
     )
     // nested variant column
@@ -324,7 +324,7 @@ class DataFrameSuite extends QueryTest
       exception = intercept[AnalysisException](df.repartition(5, col("s"))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "s",
+        "expr" -> "\"s\"",
         "dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\"")
     )
     // variant producing expression
@@ -333,7 +333,7 @@ class DataFrameSuite extends QueryTest
         intercept[AnalysisException](df.repartition(5, 
parse_json(col("id").cast("string")))),
       condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
       parameters = Map(
-        "expr" -> "parse_json(CAST(id AS STRING))",
+        "expr" -> "\"parse_json(CAST(id AS STRING))\"",
         "dataType" -> "\"VARIANT\"")
     )
     // Partitioning by non-variant column works
@@ -350,7 +350,7 @@ class DataFrameSuite extends QueryTest
         exception = intercept[AnalysisException](sql("SELECT * FROM tv 
DISTRIBUTE BY v")),
         condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
         parameters = Map(
-          "expr" -> "tv.v",
+          "expr" -> "\"v\"",
           "dataType" -> "\"VARIANT\""),
         context = ExpectedContext(
           fragment = "DISTRIBUTE BY v",
@@ -361,7 +361,7 @@ class DataFrameSuite extends QueryTest
         exception = intercept[AnalysisException](sql("SELECT * FROM tv 
DISTRIBUTE BY s")),
         condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
         parameters = Map(
-          "expr" -> "tv.s",
+          "expr" -> "\"s\"",
           "dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\""),
         context = ExpectedContext(
           fragment = "DISTRIBUTE BY s",
@@ -428,6 +428,33 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("repartition by MapType") {
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", 
"varchar(6)").foreach { dt =>
+      val df = spark.range(20)
+        .withColumn("c1",
+          when(col("id") % 3 === 1, typedLit(Map(1 -> 1)))
+            .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
+            .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
+        .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
+        .withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))
+
+      assertPartitionNumber(df.repartition(4, col("c1")), 2)
+      assertPartitionNumber(df.repartition(4, col("c2")), 1)
+      assertPartitionNumber(df.repartition(4, col("c3")), 1)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
+      assertPartitionNumber(df.repartition(4, col("c1"), col("c2"), 
col("c3")), 2)
+      assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
+    }
+  }
+
+  private def assertPartitionNumber(df: => DataFrame, max: Int): Unit = {
+    val dfGrouped = df.groupBy(spark_partition_id()).count()
+    // Result number of partition can be lower or equal to max,
+    // but no more than that.
+    assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
+  }
+
   test("coalesce") {
     intercept[IllegalArgumentException] {
       testData.select("key").coalesce(0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to