This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7398e93a2be [SPARK-44139][SQL] Discard completely pushed down filters in group-based MERGE operations
7398e93a2be is described below
commit 7398e93a2be226d0e71af5b632fa640b37ea5e43
Author: aokolnychyi <[email protected]>
AuthorDate: Thu Jun 22 21:28:21 2023 -0700
[SPARK-44139][SQL] Discard completely pushed down filters in group-based MERGE operations
### What changes were proposed in this pull request?
This PR adds logic to discard completely pushed down filters in group-based
MERGE operations.
### Why are the changes needed?
These changes are needed to simplify the join condition used in group-based MERGE operations, so that filters that are already fully evaluated by the data source are not evaluated again on the Spark side.
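As a rough illustration (a sketch only, not part of the patch; the table and column names below are placeholders modeled on the new tests), consider a MERGE whose ON clause contains a target-only filter:

```sql
MERGE INTO target t
USING source s
ON t.pk = s.pk AND t.dep IN ('hr', 'software')
WHEN MATCHED THEN
  UPDATE SET t.salary = t.salary + 1
WHEN NOT MATCHED THEN
  INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
```

If the data source reports that it fully evaluates `t.dep IN ('hr', 'software')` during the scan (i.e. the filter is not returned as a post-scan filter), the join condition in the rewritten plan can be reduced to `t.pk = s.pk`. The new tests also verify that such target filters are preserved when the statement contains a `WHEN NOT MATCHED BY SOURCE` clause.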
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
This PR comes with tests.
Closes #41700 from aokolnychyi/spark-44139.
Authored-by: aokolnychyi <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/connector/catalog/InMemoryBaseTable.scala | 30 ++++-
.../GroupBasedRowLevelOperationScanPlanning.scala | 68 ++++++++--
.../sql/connector/MergeIntoTableSuiteBase.scala | 140 +++++++++++++++++++++
3 files changed, 221 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 22155ade0aa..a0a4d8bdee9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -292,12 +292,31 @@ abstract class InMemoryBaseTable(
new InMemoryScanBuilder(schema)
}
+ private def canEvaluate(filter: Filter): Boolean = {
+ if (partitioning.length == 1 && partitioning.head.references.length == 1) {
+ filter match {
+ case In(attrName, _) if attrName == partitioning.head.references.head.toString => true
+ case _ => false
+ }
+ } else {
+ false
+ }
+ }
+
class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
private var schema: StructType = tableSchema
+ private var postScanFilters: Array[Filter] = Array.empty
+ private var evaluableFilters: Array[Filter] = Array.empty
+ private var _pushedFilters: Array[Filter] = Array.empty
- override def build: Scan =
- InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+ override def build: Scan = {
+ val scan = InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+ if (evaluableFilters.nonEmpty) {
+ scan.filter(evaluableFilters)
+ }
+ scan
+ }
override def pruneColumns(requiredSchema: StructType): Unit = {
// The required schema could contain conflict-renamed metadata columns, so we need to match
@@ -310,11 +329,12 @@ abstract class InMemoryBaseTable(
schema = StructType(prunedFields)
}
- private var _pushedFilters: Array[Filter] = Array.empty
-
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ val (evaluableFilters, postScanFilters) = filters.partition(canEvaluate)
+ this.evaluableFilters = evaluableFilters
+ this.postScanFilters = postScanFilters
this._pushedFilters = filters
- this._pushedFilters
+ postScanFilters
}
override def pushedFilters(): Array[Filter] = this._pushedFilters
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
index 546ce6d8deb..11dddb50831 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
-import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.planning.{GroupBasedRowLevelOperation, PhysicalOperation}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, ReplaceData}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter}
import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources.Filter
@@ -40,10 +42,14 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
// push down the filter from the command condition instead of the filter in the rewrite plan,
// which is negated for data sources that only support replacing groups of data (e.g. files)
case GroupBasedRowLevelOperation(rd: ReplaceData, cond, _, relation: DataSourceV2Relation) =>
+ assert(cond.deterministic, "row-level operation conditions must be deterministic")
+
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)
- val (pushedFilters, remainingFilters) = pushFilters(cond, relation.output, scanBuilder)
+ val (pushedFilters, evaluatedFilters, postScanFilters) =
+ pushFilters(cond, relation.output, scanBuilder)
+
val pushedFiltersStr = if (pushedFilters.isLeft) {
pushedFilters.left.get.mkString(", ")
} else {
@@ -56,29 +62,67 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
s"""
|Pushing operators to ${relation.name}
|Pushed filters: $pushedFiltersStr
- |Filters that were not pushed: ${remainingFilters.mkString(", ")}
+ |Filters evaluated on data source side: ${evaluatedFilters.mkString(", ")}
+ |Filters evaluated on Spark side: ${postScanFilters.mkString(", ")}
|Output: ${output.mkString(", ")}
""".stripMargin)
- // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table
- // there may be multiple read relations for UPDATEs that are rewritten as UNION
- rd transform {
+ rd transformDown {
+ // simplify the join condition in MERGE operations by discarding already evaluated filters
+ case j @ Join(
+ PhysicalOperation(_, _, r: DataSourceV2Relation), _, _, Some(cond), _)
+ if rd.operation.command == MERGE && evaluatedFilters.nonEmpty && r.table.eq(table) =>
+ j.copy(condition = Some(optimizeMergeJoinCondition(cond, evaluatedFilters)))
+
+ // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table
+ // there may be multiple read relations for UPDATEs that are rewritten as UNION
case r: DataSourceV2Relation if r.table eq table =>
DataSourceV2ScanRelation(r, scan, PushDownUtils.toOutputAttrs(scan.readSchema(), r))
}
}
+ // pushes down the operation condition and returns the following information:
+ // - pushed down filters
+ // - filter expressions that are fully evaluated on the data source side
+ //   (such filters can be discarded and don't have to be evaluated again on the Spark side)
+ // - post-scan filter expressions that must be evaluated on the Spark side
+ //   (such filters can overlap with pushed down filters, e.g. Parquet row group filtering)
private def pushFilters(
cond: Expression,
tableAttrs: Seq[AttributeReference],
- scanBuilder: ScanBuilder): (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression]) = {
+ scanBuilder: ScanBuilder)
+ : (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression], Seq[Expression]) = {
+
+ val (filtersWithSubquery, filtersWithoutSubquery) = findTableFilters(cond, tableAttrs)
+
+ val (pushedFilters, postScanFiltersWithoutSubquery) =
+ PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery)
+
+ val postScanFilterSetWithoutSubquery = ExpressionSet(postScanFiltersWithoutSubquery)
+ val evaluatedFilters = filtersWithoutSubquery.filterNot { filter =>
+ postScanFilterSetWithoutSubquery.contains(filter)
+ }
+ val postScanFilters = postScanFiltersWithoutSubquery ++ filtersWithSubquery
+
+ (pushedFilters, evaluatedFilters, postScanFilters)
+ }
+
+ private def findTableFilters(
+ cond: Expression,
+ tableAttrs: Seq[AttributeReference]): (Seq[Expression], Seq[Expression]) = {
val tableAttrSet = AttributeSet(tableAttrs)
val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
- val (_, normalizedFiltersWithoutSubquery) =
- normalizedFilters.partition(SubqueryExpression.hasSubquery)
+ normalizedFilters.partition(SubqueryExpression.hasSubquery)
+ }
- PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+ private def optimizeMergeJoinCondition(
+ cond: Expression,
+ evaluatedFilters: Seq[Expression]): Expression = {
+ val evaluatedFilterSet = ExpressionSet(evaluatedFilters)
+ val predicates = splitConjunctivePredicates(cond)
+ val remainingPredicates = predicates.filterNot(evaluatedFilterSet.contains)
+ remainingPredicates.reduceLeftOption(And).getOrElse(TrueLiteral)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index bd641b2026b..e7555c23fa4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue}
import org.apache.spark.sql.connector.expressions.LiteralValue
@@ -1422,6 +1423,145 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
}
}
+ test("all target filters are evaluated on data source side") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "hr" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "software" }
+ |{ "pk": 5, "salary": 500, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ val executedPlan = executeAndKeepPlan {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk AND t.DeP IN ('hr', 'software')
+ |WHEN MATCHED THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+ |""".stripMargin)
+ }
+
+ val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+ val inFilterPushed = expressions.forall {
+ case In(attr: AttributeReference, _) if attr.name == "DeP" => false
+ case _ => true
+ }
+ assert(inFilterPushed, "IN filter must be evaluated on data source side")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 201, "hr"), // update
+ Row(3, 301, "hr"), // update
+ Row(4, 400, "software"), // unchanged
+ Row(5, 500, "software"), // unchanged
+ Row(6, 0, "hr"))) // insert
+ }
+ }
+
+ test("some target filters are evaluated on data source side") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "hr" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "software" }
+ |{ "pk": 5, "salary": 500, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ val executedPlan = executeAndKeepPlan {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk AND t.dep IN ('hr', 'software') AND t.salary != -1
+ |WHEN MATCHED THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+ |""".stripMargin)
+ }
+
+ val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+
+ val inFilterPushed = expressions.forall {
+ case In(attr: AttributeReference, _) if attr.name == "dep" => false
+ case _ => true
+ }
+ assert(inFilterPushed, "IN filter must be evaluated on data source side")
+
+ val notEqualFilterPreserved = expressions.exists {
+ case Not(EqualTo(attr: AttributeReference, _)) if attr.name == "salary" => true
+ case _ => false
+ }
+ assert(notEqualFilterPreserved, "NOT filter must be evaluated on Spark side")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 201, "hr"), // update
+ Row(3, 301, "hr"), // update
+ Row(4, 400, "software"), // unchanged
+ Row(5, 500, "software"), // unchanged
+ Row(6, 0, "hr"))) // insert
+ }
+ }
+
+ test("pushable target filters are preserved with NOT MATCHED BY SOURCE clause") {
+ withTempView("source") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "hr" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |{ "pk": 4, "salary": 400, "dep": "software" }
+ |{ "pk": 5, "salary": 500, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+ sourceDF.createOrReplaceTempView("source")
+
+ val executedPlan = executeAndKeepPlan {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source s
+ |ON t.pk = s.pk AND DeP IN ('hr', 'software')
+ |WHEN MATCHED THEN
+ | UPDATE SET t.salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | DELETE
+ |""".stripMargin)
+ }
+
+ val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+ val inFilterPreserved = expressions.exists {
+ case In(attr: AttributeReference, _) if attr.name == "DeP" => true
+ case _ => false
+ }
+ assert(inFilterPreserved, "IN filter must be preserved")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 201, "hr"), // update
+ Row(3, 301, "hr"), // update
+ Row(6, 0, "hr"))) // insert
+ }
+ }
+
private def assertNoLeftBroadcastOrReplication(query: String): Unit = {
val plan = executeAndKeepPlan {
sql(query)