This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a4f2870b0855 [SPARK-50525][SQL] Define InsertMapSortInRepartitionExpressions Optimizer Rule
a4f2870b0855 is described below
commit a4f2870b08551031ace305a953ace23a6aa6e71a
Author: Dima <[email protected]>
AuthorDate: Fri Jan 10 12:20:37 2025 +0800
[SPARK-50525][SQL] Define InsertMapSortInRepartitionExpressions Optimizer Rule
### What changes were proposed in this pull request?
In the current version of Spark, it is possible to use a `MapType` column for
repartitioning. However, `MapData` implements neither `equals` nor `hashCode`
(per [SPARK-9415](https://issues.apache.org/jira/browse/SPARK-9415)
and [[SPARK-16135][SQL] Remove hashCode and equals in
ArrayBasedMapData](https://github.com/apache/spark/pull/13847)), so hash
values computed for logically-equal maps can differ.
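Below is a minimal sketch of the underlying problem (the column name `m` and the partition count are illustrative, not part of this patch): two rows carry logically-equal maps built in different key orders, so their underlying `MapData` can hash differently.
```scala
import org.apache.spark.sql.functions._

// Two logically-equal maps, materialized with different key order.
val df = spark.range(2).withColumn("m",
  when(col("id") === 0, typedLit(Map(1 -> "a", 2 -> "b")))
    .otherwise(typedLit(Map(2 -> "b", 1 -> "a"))))

// Before this rule, the hash of the unsorted MapData is not guaranteed to
// match across the two rows, so they may land in different partitions.
df.repartition(4, col("m"))
```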
Attempting to run the `xxhash64` or `hash` function on a `MapType` column fails with:
```org.apache.spark.sql.catalyst.ExtendedAnalysisException:
[DATATYPE_MISMATCH.HASH_MAP_TYPE] Cannot resolve "xxhash64(value)" due to data
type mismatch: Input to the function `xxhash64` cannot contain elements of the
"MAP" type. In Spark, same maps may have different hashcode, thus hash
expressions are prohibited on "MAP" elements. To restore previous behavior set
"spark.sql.legacy.allowHashOnMapType" to "true".;```
Also, when computing distinct values of a `MapType` column `value`, the
following exception is thrown:
```org.apache.spark.sql.catalyst.ExtendedAnalysisException:
[UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE] The feature is not supported:
Cannot have MAP type columns in DataFrame which calls set operations
(INTERSECT, EXCEPT, etc.), but the type of column `value` is "MAP<INT,
STRING>".;```
With the above in mind, a new `InsertMapSortInRepartitionExpressions`
`Rule[LogicalPlan]` was implemented to insert `MapSort` for every `MapType`
expression in `RepartitionByExpression.partitionExpressions`.
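Sketched effect of the new rule on the logical plan (plan shapes abbreviated, reusing the illustrative `df` and column `m` from above):
```scala
// Before the rule:  RepartitionByExpression [m#1], 4
// After the rule:   RepartitionByExpression [map_sort(m#1)], 4
//
// With keys in canonical order, logically-equal maps hash identically,
// so the two equal-map rows now share a partition.
val grouped = df.repartition(4, col("m"))
  .groupBy(spark_partition_id()).count()
grouped.show()
```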
### Why are the changes needed?
To make the `repartition` API behave consistently for `MapType` columns.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49144 from ostronaut/features/map_repartition.
Authored-by: Dima <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/CheckAnalysis.scala | 2 +-
...essions.scala => InsertMapSortExpression.scala} | 66 ++++++++++++++++------
.../spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
.../org/apache/spark/sql/DataFrameSuite.scala | 37 ++++++++++--
4 files changed, 84 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6cd394fd79e9..46ca8e793218 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -884,7 +884,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
o.failAnalysis(
errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
messageParameters = Map(
- "expr" -> variantExpr.sql,
+ "expr" -> toSQLExpr(variantExpr),
"dataType" -> toSQLType(variantExpr.dataType)))
case o if o.expressions.exists(!_.deterministic) &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
similarity index 69%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
index b6ced6c49a36..9e613c54a49b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortExpression.scala
@@ -20,32 +20,30 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, RepartitionByExpression}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, REPARTITION_OPERATION}
import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps
/**
- * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
- * in the correct order before grouping:
+ * Adds [[MapSort]] to [[Aggregate]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before grouping:
*
- * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT _groupingmapsort as map_column, COUNT(*) FROM (
* SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
* ) GROUP BY _groupingmapsort
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
- private def shouldAddMapSort(expr: Expression): Boolean = {
- expr.dataType.existsRecursively(_.isInstanceOf[MapType])
- }
+ import InsertMapSortExpression._
override def apply(plan: LogicalPlan): LogicalPlan = {
- if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+ if (!plan.containsPattern(AGGREGATE)) {
return plan
}
val shouldRewrite = plan.exists {
- case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+ case agg: Aggregate if agg.groupingExpressions.exists(mapTypeExistsRecursively) => true
case _ => false
}
if (!shouldRewrite) {
@@ -53,8 +51,7 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
}
plan transformUpWithNewOutput {
- case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
- if agg.groupingExpressions.exists(shouldAddMapSort) =>
+ case agg @ Aggregate(groupingExprs, aggregateExpressions, child, hint) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
val inserted = insertMapSortRecursively(expr)
@@ -77,15 +74,53 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
}.asInstanceOf[NamedExpression]
}
val newChild = Project(child.output ++ exprToMapSort.values, child)
- val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+ val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild, hint)
newAgg -> agg.output.zip(newAgg.output)
}
}
+}
+
+/**
+ * Adds [[MapSort]] to [[RepartitionByExpression]] expressions containing map columns,
+ * as the key/value pairs need to be in the correct order before repartitioning:
+ *
+ * SELECT * FROM TABLE DISTRIBUTE BY map_column =>
+ * SELECT * FROM TABLE DISTRIBUTE BY map_sort(map_column)
+ */
+object InsertMapSortInRepartitionExpressions extends Rule[LogicalPlan] {
+ import InsertMapSortExpression._
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.transformUpWithPruning(_.containsPattern(REPARTITION_OPERATION)) {
+ case rep: RepartitionByExpression
+ if rep.partitionExpressions.exists(mapTypeExistsRecursively) =>
+ val exprToMapSort = new mutable.HashMap[Expression, Expression]
+ val newPartitionExprs = rep.partitionExpressions.map { expr =>
+ val inserted = insertMapSortRecursively(expr)
+ if (expr.ne(inserted)) {
+ exprToMapSort.getOrElseUpdate(expr.canonicalized, inserted)
+ } else {
+ expr
+ }
+ }
+ rep.copy(partitionExpressions = newPartitionExprs)
+ }
+ }
+}
+
+private[optimizer] object InsertMapSortExpression {
/**
- * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
+ * Returns true if the expression contains a [[MapType]] in DataType tree.
*/
- private def insertMapSortRecursively(e: Expression): Expression = {
+ def mapTypeExistsRecursively(expr: Expression): Boolean = {
+ expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+ }
+
+ /**
+ * Inserts [[MapSort]] recursively taking into account when it is nested inside a struct or array.
+ */
+ def insertMapSortRecursively(e: Expression): Expression = {
e.dataType match {
case m: MapType =>
// Check if value type of MapType contains MapType (possibly nested)
@@ -122,5 +157,4 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
case _ => e
}
}
-
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c0c76dd44ad5..8ee2226947ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -322,6 +322,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// so the grouping keys can only be attribute and literal which makes
// `InsertMapSortInGroupingExpressions` easy to insert `MapSort`.
InsertMapSortInGroupingExpressions,
+ InsertMapSortInRepartitionExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0972a63a2495..317a88edf8e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -316,7 +316,7 @@ class DataFrameSuite extends QueryTest
exception = intercept[AnalysisException](df.repartition(5, col("v"))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
- "expr" -> "v",
+ "expr" -> "\"v\"",
"dataType" -> "\"VARIANT\"")
)
// nested variant column
@@ -324,7 +324,7 @@ class DataFrameSuite extends QueryTest
exception = intercept[AnalysisException](df.repartition(5, col("s"))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
- "expr" -> "s",
+ "expr" -> "\"s\"",
"dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\"")
)
// variant producing expression
@@ -333,7 +333,7 @@ class DataFrameSuite extends QueryTest
intercept[AnalysisException](df.repartition(5, parse_json(col("id").cast("string")))),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
- "expr" -> "parse_json(CAST(id AS STRING))",
+ "expr" -> "\"parse_json(CAST(id AS STRING))\"",
"dataType" -> "\"VARIANT\"")
)
// Partitioning by non-variant column works
@@ -350,7 +350,7 @@ class DataFrameSuite extends QueryTest
exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY v")),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
- "expr" -> "tv.v",
+ "expr" -> "\"v\"",
"dataType" -> "\"VARIANT\""),
context = ExpectedContext(
fragment = "DISTRIBUTE BY v",
@@ -361,7 +361,7 @@ class DataFrameSuite extends QueryTest
exception = intercept[AnalysisException](sql("SELECT * FROM tv DISTRIBUTE BY s")),
condition = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT",
parameters = Map(
- "expr" -> "tv.s",
+ "expr" -> "\"s\"",
"dataType" -> "\"STRUCT<v: VARIANT NOT NULL>\""),
context = ExpectedContext(
fragment = "DISTRIBUTE BY s",
@@ -428,6 +428,33 @@ class DataFrameSuite extends QueryTest
}
}
+ test("repartition by MapType") {
+ Seq("int", "long", "float", "double", "decimal(10, 2)", "string",
"varchar(6)").foreach { dt =>
+ val df = spark.range(20)
+ .withColumn("c1",
+ when(col("id") % 3 === 1, typedLit(Map(1 -> 1)))
+ .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2)))
+ .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>"))
+ .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>"))
+ .withColumn("c3", lit(null).cast(s"map<$dt, $dt>"))
+
+ assertPartitionNumber(df.repartition(4, col("c1")), 2)
+ assertPartitionNumber(df.repartition(4, col("c2")), 1)
+ assertPartitionNumber(df.repartition(4, col("c3")), 1)
+ assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2)
+ assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2)
+ assertPartitionNumber(df.repartition(4, col("c1"), col("c2"),
col("c3")), 2)
+ assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2)
+ }
+ }
+
+ private def assertPartitionNumber(df: => DataFrame, max: Int): Unit = {
+ val dfGrouped = df.groupBy(spark_partition_id()).count()
+ // Result number of partition can be lower or equal to max,
+ // but no more than that.
+ assert(dfGrouped.count() <= max, dfGrouped.queryExecution.simpleString)
+ }
+
test("coalesce") {
intercept[IllegalArgumentException] {
testData.select("key").coalesce(0)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]