This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 41c6b9267643 [SPARK-55695][SQL] Avoid double planning in row-level operations
41c6b9267643 is described below

commit 41c6b9267643f67cba79eddc292b5fa8d53d513b
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Wed Mar 4 20:04:35 2026 -0800

    [SPARK-55695][SQL] Avoid double planning in row-level operations
    
    ### What changes were proposed in this pull request?
    
    This PR avoids unnecessary job planning in row-level operations like DELETE, UPDATE, and MERGE.
    
    ### Why are the changes needed?
    
    These changes improve performance and avoid useless job planning. Right now, `ReplaceData` and `WriteDelta` conditions may include subqueries that will undergo optimization (hence planning). This is not needed as those conditions are handled in a special way (see `RowLevelOperationRuntimeGroupFiltering`).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR comes with tests that would previously fail.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Claude Code v2.0.49 Sonnet 4.5.
    
    Closes #54492 from aokolnychyi/spark-55695.
    
    Authored-by: Anton Okolnychyi <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../catalyst/analysis/RewriteDeleteFromTable.scala |  3 +-
 .../catalyst/analysis/RewriteMergeIntoTable.scala  |  6 +-
 .../catalyst/analysis/RewriteRowLevelCommand.scala |  2 +
 .../sql/catalyst/analysis/RewriteUpdateTable.scala |  6 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 11 ++-
 .../connector/DeltaBasedDeleteFromTableSuite.scala | 27 ++++++++
 .../sql/connector/DeltaBasedUpdateTableSuite.scala | 27 ++++++++
 .../connector/GroupBasedDeleteFromTableSuite.scala | 64 ++++++++++++++++++
 .../connector/GroupBasedMergeIntoTableSuite.scala  | 78 ++++++++++++++++++++++
 .../sql/connector/GroupBasedUpdateTableSuite.scala | 65 +++++++++++++++++-
 .../sql/connector/RowLevelOperationSuiteBase.scala | 25 +++++--
 11 files changed, 304 insertions(+), 10 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
index 0dc217788fd0..13cfc6b73ccb 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala
@@ -88,7 +88,8 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {
     val writeRelation = relation.copy(table = operationTable)
     val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, 
remainingRowsPlan)
     val projections = buildReplaceDataProjections(query, relation.output, 
metadataAttrs)
-    ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
+    val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
+    ReplaceData(writeRelation, cond, query, relation, projections, 
groupFilterCond)
   }
 
   // build a rewrite plan for sources that support row deltas
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 1d2e2fef2096..9675ee232786 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -175,7 +175,11 @@ object RewriteMergeIntoTable extends 
RewriteRowLevelCommand with PredicateHelper
     // predicates of the ON condition can be used to filter the target table 
(planning & runtime)
     // only if there is no NOT MATCHED BY SOURCE clause
     val (pushableCond, groupFilterCond) = if 
(notMatchedBySourceActions.isEmpty) {
-      (cond, Some(toGroupFilterCondition(relation, source, cond)))
+      if (groupFilterEnabled) {
+        (cond, Some(toGroupFilterCondition(relation, source, cond)))
+      } else {
+        (cond, None)
+      }
     } else {
       (TrueLiteral, None)
     }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index 118ed4e99190..c5b81dec87c9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -45,6 +45,8 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
   private final val DELTA_OPERATIONS_WITH_ROW_ID =
     Set(DELETE_OPERATION, UPDATE_OPERATION)
 
+  protected def groupFilterEnabled: Boolean = 
conf.runtimeRowLevelOperationGroupFilterEnabled
+
   protected def buildOperationTable(
       table: SupportsRowLevelOperations,
       command: Command,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
index a4453ae51734..caf7579da889 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
@@ -79,7 +79,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
     val writeRelation = relation.copy(table = operationTable)
     val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, 
updatedAndRemainingRowsPlan)
     val projections = buildReplaceDataProjections(query, relation.output, 
metadataAttrs)
-    ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
+    val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
+    ReplaceData(writeRelation, cond, query, relation, projections, 
groupFilterCond)
   }
 
   // build a rewrite plan for sources that support replacing groups of data 
(e.g. files, partitions)
@@ -113,7 +114,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
     val writeRelation = relation.copy(table = operationTable)
     val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, 
updatedAndRemainingRowsPlan)
     val projections = buildReplaceDataProjections(query, relation.output, 
metadataAttrs)
-    ReplaceData(writeRelation, cond, query, relation, projections, Some(cond))
+    val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
+    ReplaceData(writeRelation, cond, query, relation, projections, 
groupFilterCond)
   }
 
   // this method assumes the assignments have been already aligned before
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c953258cff2f..ad4769ff8e31 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -370,7 +370,16 @@ abstract class Optimizer(catalogManager: CatalogManager)
       s.withNewPlan(removeTopLevelSort(newPlan))
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = 
plan.transformAllExpressionsWithPruning(
+    // optimizes subquery expressions, ignoring row-level operation conditions
+    def apply(plan: LogicalPlan): LogicalPlan = {
+      plan.transformWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
+        case wd: WriteDelta => wd
+        case rd: ReplaceData => rd
+        case p => optimize(p)
+      }
+    }
+
+    private def optimize(plan: LogicalPlan): LogicalPlan = 
plan.transformExpressionsWithPruning(
       _.containsPattern(PLAN_EXPRESSION), ruleId) {
       // Do not optimize DPP subquery, as it was created from optimized plan 
and we should not
       // optimize it again, to save optimization time and avoid breaking 
broadcast/subquery reuse.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
index eaa3f2f42b21..9046123ddbd3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions.InSubquery
 import org.apache.spark.sql.types.StructType
 
 class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
@@ -136,4 +137,30 @@ class DeltaBasedDeleteFromTableSuite extends 
DeleteFromTableSuiteBase {
       sql(s"SELECT * FROM $tableNameAsString"),
       Row(2, 2, "us", "software") :: Row(3, 3, "canada", "hr") :: Nil)
   }
+
+  test("delete does not double plan table") {
+    createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING",
+      """{ "pk": 1, "id": 1, "salary": 300, "dep": 'hr' }
+        |{ "pk": 2, "id": 2, "salary": 150, "dep": 'software' }
+        |{ "pk": 3, "id": 3, "salary": 120, "dep": 'hr' }
+        |""".stripMargin)
+
+    val (cond, groupFilterCond) = executeAndKeepConditions {
+      sql(
+        s"""DELETE FROM $tableNameAsString
+           |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 200)
+           |""".stripMargin)
+    }
+
+    cond match {
+      case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+      case _ => fail(s"unexpected condition: $cond")
+    }
+
+    assert(groupFilterCond.isEmpty, "delta operations must not have group 
filter")
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString"),
+      Row(2, 2, 150, "software") :: Row(3, 3, 120, "hr") :: Nil)
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
index c9fd5d6e3ff0..813e8779f5f9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.InSubquery
 import org.apache.spark.sql.types.StructType
 
 class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
@@ -91,4 +92,30 @@ class DeltaBasedUpdateTableSuite extends 
DeltaBasedUpdateTableSuiteBase {
         updateWriteLogEntry(id = 1, metadata = Row("hr", null), data = Row(1, 
-1, "hr")))
     }
   }
+
+  test("update does not double plan table") {
+    createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING",
+      """{ "pk": 1, "id": 1, "salary": 300, "dep": 'hr' }
+        |{ "pk": 2, "id": 2, "salary": 150, "dep": 'software' }
+        |{ "pk": 3, "id": 3, "salary": 120, "dep": 'hr' }
+        |""".stripMargin)
+
+    val (cond, groupFilterCond) = executeAndKeepConditions {
+      sql(
+        s"""UPDATE $tableNameAsString SET salary = -1
+           |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 200)
+           |""".stripMargin)
+    }
+
+    cond match {
+      case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+      case _ => fail(s"unexpected condition: $cond")
+    }
+
+    assert(groupFilterCond.isEmpty, "delta operations must not have group 
filter")
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString"),
+      Row(1, 1, -1, "hr") :: Row(2, 2, 150, "software") :: Row(3, 3, 120, 
"hr") :: Nil)
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
index 4dd09a2f1c83..2f922295010f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.expressions.InSubquery
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
@@ -167,4 +168,67 @@ class GroupBasedDeleteFromTableSuite extends 
DeleteFromTableSuiteBase {
       checkReplacedPartitions(Seq("hr"))
     }
   }
+
+  test("delete does not double plan table (group filter enabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "true") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |""".stripMargin)
+
+      val (cond, groupFilterCond) = executeAndKeepConditions {
+        sql(
+          s"""DELETE FROM $tableNameAsString
+             |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 
200)
+             |""".stripMargin)
+      }
+
+      cond match {
+        case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected condition: $cond")
+      }
+
+      groupFilterCond match {
+        case Some(InSubquery(_, query)) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected group filter: $groupFilterCond")
+      }
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil)
+
+      checkReplacedPartitions(Seq("hr"))
+    }
+  }
+
+  test("delete does not double plan table (group filter disabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "false") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |""".stripMargin)
+
+      val (cond, groupFilterCond) = executeAndKeepConditions {
+        sql(
+          s"""DELETE FROM $tableNameAsString
+             |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 
200)
+             |""".stripMargin)
+      }
+
+      cond match {
+        case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected condition: $cond")
+      }
+
+      assert(groupFilterCond.isEmpty, "group filter must be empty")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil)
+
+      checkReplacedPartitions(Seq("software", "hr"))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
index 63eba256d8f2..b549723247cd 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedMergeIntoTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Exists
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
@@ -176,4 +177,81 @@ class GroupBasedMergeIntoTableSuite extends 
MergeIntoTableSuiteBase {
       checkReplacedPartitions(Seq("hr"))
     }
   }
+
+  test("merge does not double plan table (group filter enabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "true") {
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |""".stripMargin)
+
+        sql(
+          s"""CREATE TEMP VIEW source AS
+             |SELECT pk, salary FROM $tableNameAsString WHERE salary > 150
+             |""".stripMargin)
+
+        val (_, groupFilterCond) = executeAndKeepConditions {
+          sql(
+            s"""MERGE INTO $tableNameAsString t
+               |USING source s
+               |ON t.pk = s.pk
+               |WHEN MATCHED THEN
+               | UPDATE SET t.salary = s.salary + 1
+               |WHEN NOT MATCHED THEN
+               | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'new')
+               |""".stripMargin)
+        }
+
+        groupFilterCond match {
+          case Some(p: Exists) => assertNoScanPlanning(p.plan)
+          case _ => fail(s"unexpected group filter: $groupFilterCond")
+        }
+
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(Row(1, 100, "hr"), Row(2, 201, "software"), Row(3, 301, "hr")))
+
+        checkReplacedPartitions(Seq("software", "hr"))
+      }
+    }
+  }
+
+  test("merge does not double plan table (group filter disabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "false") {
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |""".stripMargin)
+
+        sql(
+          s"""CREATE TEMP VIEW source AS
+             |SELECT pk, salary FROM $tableNameAsString WHERE salary > 150
+             |""".stripMargin)
+
+        val (_, groupFilterCond) = executeAndKeepConditions {
+          sql(
+            s"""MERGE INTO $tableNameAsString t
+               |USING source s
+               |ON t.pk = s.pk
+               |WHEN MATCHED THEN
+               | UPDATE SET t.salary = s.salary + 1
+               |WHEN NOT MATCHED THEN
+               | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'new')
+               |""".stripMargin)
+        }
+
+        assert(groupFilterCond.isEmpty, "group filter must be disabled")
+
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(Row(1, 100, "hr"), Row(2, 201, "software"), Row(3, 301, "hr")))
+
+        checkReplacedPartitions(Seq("hr", "software"))
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
index 30545f5aa01a..61defff7ebf0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedUpdateTableSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
+import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, 
InSubquery}
 import org.apache.spark.sql.execution.{InSubqueryExec, ReusedSubqueryExec}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
@@ -188,4 +188,67 @@ class GroupBasedUpdateTableSuite extends 
UpdateTableSuiteBase {
         stop = 75)
     )
   }
+
+  test("update does not double plan table (group filter enabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "true") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |""".stripMargin)
+
+      val (cond, groupFilterCond) = executeAndKeepConditions {
+        sql(
+          s"""UPDATE $tableNameAsString SET salary = -1
+             |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 
200)
+             |""".stripMargin)
+      }
+
+      cond match {
+        case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected condition: $cond")
+      }
+
+      groupFilterCond match {
+        case Some(InSubquery(_, query)) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected group filter: $groupFilterCond")
+      }
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(1, -1, "hr") :: Row(2, 150, "software") :: Row(3, 120, "hr") :: 
Nil)
+
+      checkReplacedPartitions(Seq("hr"))
+    }
+  }
+
+  test("update does not double plan table (group filter disabled)") {
+    withSQLConf(SQLConf.RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED.key 
-> "false") {
+      createAndInitTable("id INT, salary INT, dep STRING",
+        """{ "id": 1, "salary": 300, "dep": 'hr' }
+          |{ "id": 2, "salary": 150, "dep": 'software' }
+          |{ "id": 3, "salary": 120, "dep": 'hr' }
+          |""".stripMargin)
+
+      val (cond, groupFilterCond) = executeAndKeepConditions {
+        sql(
+          s"""UPDATE $tableNameAsString SET salary = -1
+             |WHERE id IN (SELECT id FROM $tableNameAsString WHERE salary > 
200)
+             |""".stripMargin)
+      }
+
+      cond match {
+        case InSubquery(_, query) => assertNoScanPlanning(query.plan)
+        case _ => fail(s"unexpected condition: $cond")
+      }
+
+      assert(groupFilterCond.isEmpty, "group filter must be disabled")
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Row(1, -1, "hr") :: Row(2, 150, "software") :: Row(3, 120, "hr") :: 
Nil)
+
+      checkReplacedPartitions(Seq("hr", "software"))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
index f7aec678292e..8c51fb17b2cf 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
@@ -22,9 +22,10 @@ import java.util.Collections
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
-import org.apache.spark.sql.QueryTest.sameRows
+import org.apache.spark.sql.QueryTest.{sameRows, withQueryExecutionsCaptured}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, 
GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, 
Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, 
WriteDelta}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, 
Identifier, InMemoryRowLevelOperationTable, 
InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, 
Reinsert, TableInfo, Update, Write}
@@ -32,7 +33,7 @@ import 
org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity,
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, 
SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, 
DataSourceV2Relation, DataSourceV2ScanRelation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, 
StructField, StructType}
@@ -152,6 +153,22 @@ abstract class RowLevelOperationSuiteBase
     stripAQEPlan(executedPlan)
   }
 
+  // executes an operation and extracts conditions from ReplaceData or 
WriteDelta
+  protected def executeAndKeepConditions(func: => Unit): (Expression, 
Option[Expression]) = {
+    val Seq(qe) = withQueryExecutionsCaptured(spark)(func)
+    qe.optimizedPlan.collectFirst {
+      case rd: ReplaceData => (rd.condition, rd.groupFilterCondition)
+      case wd: WriteDelta => (wd.condition, None)
+    }.getOrElse(fail("couldn't find row-level operation in optimized plan"))
+  }
+
+  protected def assertNoScanPlanning(plan: LogicalPlan): Unit = {
+    val relations = plan.collect { case r: DataSourceV2Relation => r }
+    assert(relations.nonEmpty, "plan must contain relations")
+    val scans = plan.collect { case s: DataSourceV2ScanRelation => s }
+    assert(scans.isEmpty, "plan must not contain scan relations")
+  }
+
   protected def executeAndCheckScan(
       query: String,
       expectedScanSchema: String): Unit = {
@@ -206,7 +223,7 @@ abstract class RowLevelOperationSuiteBase
       case Seq(partValue) => partValue
       case other => fail(s"expected only one partition value: $other" )
     }
-    assert(actualPartitions == expectedPartitions, "replaced partitions must 
match")
+    assert(actualPartitions.toSet == expectedPartitions.toSet, "replaced 
partitions must match")
   }
 
   protected def checkLastWriteInfo(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to