This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7398e93a2be [SPARK-44139][SQL] Discard completely pushed down filters in group-based MERGE operations
7398e93a2be is described below

commit 7398e93a2be226d0e71af5b632fa640b37ea5e43
Author: aokolnychyi <[email protected]>
AuthorDate: Thu Jun 22 21:28:21 2023 -0700

    [SPARK-44139][SQL] Discard completely pushed down filters in group-based MERGE operations
    
    ### What changes were proposed in this pull request?
    
    This PR adds logic to discard completely pushed down filters in group-based MERGE operations.
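    
    A minimal standalone sketch of that simplification, mirroring the new optimizeMergeJoinCondition helper in the diff below (the object and method names here are illustrative only):
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionSet, PredicateHelper}
    import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
    
    // Drop the conjuncts of the MERGE condition that the data source already
    // evaluated completely during filter pushdown; keep everything else.
    object MergeJoinConditionSketch extends PredicateHelper {
      def simplify(cond: Expression, evaluatedFilters: Seq[Expression]): Expression = {
        val evaluated = ExpressionSet(evaluatedFilters)
        val remaining = splitConjunctivePredicates(cond).filterNot(evaluated.contains)
        remaining.reduceLeftOption(And).getOrElse(TrueLiteral)
      }
    }
    ```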
    
    ### Why are the changes needed?
    
    These changes are needed to simplify join conditions used in group-based MERGE operations to avoid evaluating unnecessary expressions.
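    
    For example (hypothetical table and source names; run, e.g., in spark-shell where `spark` is predefined; assuming the connector reports the dep predicate as completely handled during pushdown, as the in-memory test connector in this PR does for its single partition column):
    
    ```scala
    // Hypothetical illustration: `dep` is the partition column of cat.ns.target,
    // so the scan can evaluate `t.dep IN ('hr', 'software')` entirely on its own.
    // With this change, the rewritten group-based join then only evaluates
    // `t.pk = s.pk` on the Spark side instead of re-checking the IN predicate.
    spark.sql(
      """MERGE INTO cat.ns.target t
        |USING source s
        |ON t.pk = s.pk AND t.dep IN ('hr', 'software')
        |WHEN MATCHED THEN
        |  UPDATE SET t.salary = t.salary + 1
        |WHEN NOT MATCHED THEN
        |  INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
        |""".stripMargin)
    ```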
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR comes with tests.
    
    Closes #41700 from aokolnychyi/spark-44139.
    
    Authored-by: aokolnychyi <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/connector/catalog/InMemoryBaseTable.scala  |  30 ++++-
 .../GroupBasedRowLevelOperationScanPlanning.scala  |  68 ++++++++--
 .../sql/connector/MergeIntoTableSuiteBase.scala    | 140 +++++++++++++++++++++
 3 files changed, 221 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 22155ade0aa..a0a4d8bdee9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -292,12 +292,31 @@ abstract class InMemoryBaseTable(
     new InMemoryScanBuilder(schema)
   }
 
+  private def canEvaluate(filter: Filter): Boolean = {
+    if (partitioning.length == 1 && partitioning.head.references.length == 1) {
+      filter match {
+        case In(attrName, _) if attrName == partitioning.head.references.head.toString => true
+        case _ => false
+      }
+    } else {
+      false
+    }
+  }
+
   class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
       with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
     private var schema: StructType = tableSchema
+    private var postScanFilters: Array[Filter] = Array.empty
+    private var evaluableFilters: Array[Filter] = Array.empty
+    private var _pushedFilters: Array[Filter] = Array.empty
 
-    override def build: Scan =
-      InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+    override def build: Scan = {
+      val scan = InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+      if (evaluableFilters.nonEmpty) {
+        scan.filter(evaluableFilters)
+      }
+      scan
+    }
 
     override def pruneColumns(requiredSchema: StructType): Unit = {
       // The required schema could contain conflict-renamed metadata columns, so we need to match
@@ -310,11 +329,12 @@ abstract class InMemoryBaseTable(
       schema = StructType(prunedFields)
     }
 
-    private var _pushedFilters: Array[Filter] = Array.empty
-
     override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+      val (evaluableFilters, postScanFilters) = filters.partition(canEvaluate)
+      this.evaluableFilters = evaluableFilters
+      this.postScanFilters = postScanFilters
       this._pushedFilters = filters
-      this._pushedFilters
+      postScanFilters
     }
 
     override def pushedFilters(): Array[Filter] = this._pushedFilters
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
index 546ce6d8deb..11dddb50831 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
-import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.planning.{GroupBasedRowLevelOperation, PhysicalOperation}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, ReplaceData}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter}
 import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources.Filter
 
@@ -40,10 +42,14 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
     // push down the filter from the command condition instead of the filter in the rewrite plan,
     // which is negated for data sources that only support replacing groups of data (e.g. files)
     case GroupBasedRowLevelOperation(rd: ReplaceData, cond, _, relation: DataSourceV2Relation) =>
+      assert(cond.deterministic, "row-level operation conditions must be deterministic")
+
       val table = relation.table.asRowLevelOperationTable
       val scanBuilder = table.newScanBuilder(relation.options)
 
-      val (pushedFilters, remainingFilters) = pushFilters(cond, relation.output, scanBuilder)
+      val (pushedFilters, evaluatedFilters, postScanFilters) =
+        pushFilters(cond, relation.output, scanBuilder)
+
       val pushedFiltersStr = if (pushedFilters.isLeft) {
         pushedFilters.left.get.mkString(", ")
       } else {
@@ -56,29 +62,67 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
         s"""
            |Pushing operators to ${relation.name}
            |Pushed filters: $pushedFiltersStr
-           |Filters that were not pushed: ${remainingFilters.mkString(", ")}
+           |Filters evaluated on data source side: ${evaluatedFilters.mkString(", ")}
+           |Filters evaluated on Spark side: ${postScanFilters.mkString(", ")}
            |Output: ${output.mkString(", ")}
          """.stripMargin)
 
-      // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table
-      // there may be multiple read relations for UPDATEs that are rewritten as UNION
-      rd transform {
+      rd transformDown {
+        // simplify the join condition in MERGE operations by discarding already evaluated filters
+        case j @ Join(
+            PhysicalOperation(_, _, r: DataSourceV2Relation), _, _, Some(cond), _)
+            if rd.operation.command == MERGE && evaluatedFilters.nonEmpty && r.table.eq(table) =>
+          j.copy(condition = Some(optimizeMergeJoinCondition(cond, evaluatedFilters)))
+
+        // replace DataSourceV2Relation with DataSourceV2ScanRelation for the row operation table
+        // there may be multiple read relations for UPDATEs that are rewritten as UNION
         case r: DataSourceV2Relation if r.table eq table =>
           DataSourceV2ScanRelation(r, scan, PushDownUtils.toOutputAttrs(scan.readSchema(), r))
       }
   }
 
+  // pushes down the operation condition and returns the following information:
+  // - pushed down filters
+  // - filter expressions that are fully evaluated on the data source side
+  //   (such filters can be discarded and don't have to be evaluated again on the Spark side)
+  // - post-scan filter expressions that must be evaluated on the Spark side
+  //   (such filters can overlap with pushed down filters, e.g. Parquet row group filtering)
   private def pushFilters(
       cond: Expression,
       tableAttrs: Seq[AttributeReference],
-      scanBuilder: ScanBuilder): (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression]) = {
+      scanBuilder: ScanBuilder)
+  : (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression], Seq[Expression]) = {
+
+    val (filtersWithSubquery, filtersWithoutSubquery) = findTableFilters(cond, tableAttrs)
+
+    val (pushedFilters, postScanFiltersWithoutSubquery) =
+      PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery)
+
+    val postScanFilterSetWithoutSubquery = ExpressionSet(postScanFiltersWithoutSubquery)
+    val evaluatedFilters = filtersWithoutSubquery.filterNot { filter =>
+      postScanFilterSetWithoutSubquery.contains(filter)
+    }
 
+    val postScanFilters = postScanFiltersWithoutSubquery ++ filtersWithSubquery
+
+    (pushedFilters, evaluatedFilters, postScanFilters)
+  }
+
+  private def findTableFilters(
+      cond: Expression,
+      tableAttrs: Seq[AttributeReference]): (Seq[Expression], Seq[Expression]) = {
     val tableAttrSet = AttributeSet(tableAttrs)
     val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
     val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
-    val (_, normalizedFiltersWithoutSubquery) =
-      normalizedFilters.partition(SubqueryExpression.hasSubquery)
+    normalizedFilters.partition(SubqueryExpression.hasSubquery)
+  }
 
-    PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+  private def optimizeMergeJoinCondition(
+      cond: Expression,
+      evaluatedFilters: Seq[Expression]): Expression = {
+    val evaluatedFilterSet = ExpressionSet(evaluatedFilters)
+    val predicates = splitConjunctivePredicates(cond)
+    val remainingPredicates = predicates.filterNot(evaluatedFilterSet.contains)
+    remainingPredicates.reduceLeftOption(And).getOrElse(TrueLiteral)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index bd641b2026b..e7555c23fa4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
 import org.apache.spark.sql.catalyst.optimizer.BuildLeft
 import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue}
 import org.apache.spark.sql.connector.expressions.LiteralValue
@@ -1422,6 +1423,145 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
     }
   }
 
+  test("all target filters are evaluated on data source side") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "hr" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "software" }
+          |{ "pk": 5, "salary": 500, "dep": "software" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val executedPlan = executeAndKeepPlan {
+        sql(
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk AND t.DeP IN ('hr', 'software')
+             |WHEN MATCHED THEN
+             | UPDATE SET t.salary = t.salary + 1
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+             |""".stripMargin)
+      }
+
+      val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+      val inFilterPushed = expressions.forall {
+        case In(attr: AttributeReference, _) if attr.name == "DeP" => false
+        case _ => true
+      }
+      assert(inFilterPushed, "IN filter must be evaluated on data source side")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 101, "hr"), // update
+          Row(2, 201, "hr"), // update
+          Row(3, 301, "hr"), // update
+          Row(4, 400, "software"), // unchanged
+          Row(5, 500, "software"), // unchanged
+          Row(6, 0, "hr"))) // insert
+    }
+  }
+
+  test("some target filters are evaluated on data source side") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "hr" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "software" }
+          |{ "pk": 5, "salary": 500, "dep": "software" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val executedPlan = executeAndKeepPlan {
+        sql(
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk AND t.dep IN ('hr', 'software') AND t.salary != -1
+             |WHEN MATCHED THEN
+             | UPDATE SET t.salary = t.salary + 1
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+             |""".stripMargin)
+      }
+
+      val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+
+      val inFilterPushed = expressions.forall {
+        case In(attr: AttributeReference, _) if attr.name == "dep" => false
+        case _ => true
+      }
+      assert(inFilterPushed, "IN filter must be evaluated on data source side")
+
+      val notEqualFilterPreserved = expressions.exists {
+        case Not(EqualTo(attr: AttributeReference, _)) if attr.name == "salary" => true
+        case _ => false
+      }
+      assert(notEqualFilterPreserved, "NOT filter must be evaluated on Spark side")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 101, "hr"), // update
+          Row(2, 201, "hr"), // update
+          Row(3, 301, "hr"), // update
+          Row(4, 400, "software"), // unchanged
+          Row(5, 500, "software"), // unchanged
+          Row(6, 0, "hr"))) // insert
+    }
+  }
+
+  test("pushable target filters are preserved with NOT MATCHED BY SOURCE 
clause") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "hr" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "software" }
+          |{ "pk": 5, "salary": 500, "dep": "software" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 3, 6).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val executedPlan = executeAndKeepPlan {
+        sql(
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk AND DeP IN ('hr', 'software')
+             |WHEN MATCHED THEN
+             | UPDATE SET t.salary = t.salary + 1
+             |WHEN NOT MATCHED THEN
+             | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'hr')
+             |WHEN NOT MATCHED BY SOURCE THEN
+             | DELETE
+             |""".stripMargin)
+      }
+
+      val expressions = flatMap(executedPlan)(_.expressions.flatMap(splitConjunctivePredicates))
+      val inFilterPreserved = expressions.exists {
+        case In(attr: AttributeReference, _) if attr.name == "DeP" => true
+        case _ => false
+      }
+      assert(inFilterPreserved, "IN filter must be preserved")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 101, "hr"), // update
+          Row(2, 201, "hr"), // update
+          Row(3, 301, "hr"), // update
+          Row(6, 0, "hr"))) // insert
+    }
+  }
+
   private def assertNoLeftBroadcastOrReplication(query: String): Unit = {
     val plan = executeAndKeepPlan {
       sql(query)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
