This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch branch-4.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.2 by this push:
     new 012958467845 [SPARK-56032][SQL][FOLLOWUP] Skip FilterExec subexpression elimination codegen when there is no common subexpression
012958467845 is described below

commit 01295846784501620f3c22c9b09d31fd6144f6cf
Author: Wenchen Fan <[email protected]>
AuthorDate: Tue Jun 2 09:24:38 2026 +0800

    [SPARK-56032][SQL][FOLLOWUP] Skip FilterExec subexpression elimination codegen when there is no common subexpression
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of #54862, which introduced subexpression elimination (CSE) in `FilterExec` whole-stage codegen.
    
    `FilterExec` takes the CSE codegen path whenever `subexpressionEliminationEnabled && otherPreds.nonEmpty`, regardless of whether any common subexpression actually exists. That path emits an `inputVarsEvalCode` prologue at the top of the per-row loop that eagerly evaluates every input column referenced by `otherPreds` (required so eliminated subexpressions can be materialized into shared variables). When there is nothing to eliminate, this prologue provides no benefit but still defeats [...]
    
    This PR gates the CSE path on whether `otherPreds` actually contain a common subexpression. The gate uses the same `EquivalentExpressions` analysis (and `output` binding) as the CSE codegen, so it agrees exactly with whether that path would find anything. When there is none, `FilterExec` falls back to the non-CSE `generatePredicateCode`, which loads columns lazily and preserves short-circuiting. Filters that do have a common subexpression are unaffected.
    
    To avoid analyzing the predicates twice, `subexpressionEliminationForWholeStageCodegen` gains an overload that accepts a pre-built `EquivalentExpressions`, so the single analysis used by the gate is reused by the codegen.
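The overload follows a familiar pattern. The original signature becomes a thin wrapper that builds the analysis and delegates, and a caller that already holds an analysis can pass it in directly. The sketch below is minimal and uses hypothetical names (`Analysis`, `BuildCounter`, `CodegenSketch`, `FilterSketch`), with strings standing in for expressions. `FilterSketch.analysis` is a `lazy val`, in the spirit of `otherPredsEquivalentExpressions`, so the gate and the codegen step share a single build.

```scala
// Hypothetical build counter, used only to show that the analysis is built once.
object BuildCounter { var count = 0 }

// Hypothetical stand-in for EquivalentExpressions.
final class Analysis(exprs: Seq[String]) {
  BuildCounter.count += 1
  // Expressions that occur at least twice, in a deterministic order.
  val common: Seq[String] =
    exprs.groupBy(identity).collect { case (e, occ) if occ.size >= 2 => e }.toSeq.sorted
}

object CodegenSketch {
  // Original-style entry point: builds its own analysis, then delegates.
  def eliminate(exprs: Seq[String]): Seq[String] = eliminate(new Analysis(exprs))

  // New-style overload: reuses an analysis the caller already built.
  def eliminate(analysis: Analysis): Seq[String] = analysis.common
}

final class FilterSketch(preds: Seq[String]) {
  // Built lazily, at most once, and shared by the gate and the codegen step.
  lazy val analysis: Analysis = new Analysis(preds)

  def consume(): Seq[String] =
    if (analysis.common.nonEmpty) CodegenSketch.eliminate(analysis) else Seq.empty
}
```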
    
    ### Why are the changes needed?
    
    For a filter with no common subexpression but multiple conjuncts over different columns (e.g. `q_int BETWEEN ... AND (decimal_a BETWEEN ... OR decimal_b BETWEEN ...)`), the eager prologue decodes the decimal columns for every row, including rows a cheaper earlier predicate would have rejected. Decoding a high-precision decimal allocates a `BigInteger`/`BigDecimal` per call, so this is pure waste and shows up as a measurable performance regression versus the lazy non-CSE path (observed [...]
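The cost difference can be illustrated with a small model. This is a minimal sketch, not Spark's generated code. It assumes a hypothetical row type (`InputRow`) with a cheap integer column and a decimal column stored as a string. Parsing that string into a `BigDecimal` stands in for decimal decoding, and a hypothetical `EagerVsLazy.decodes` counter records how often it runs.

```scala
// Hypothetical model: `rawDecimal` stands in for an encoded high-precision decimal.
object EagerVsLazy {
  final case class InputRow(qInt: Int, rawDecimal: String)

  // Number of times the "expensive" decode has run.
  var decodes = 0

  private def decode(raw: String): BigDecimal = {
    decodes += 1
    BigDecimal(raw) // allocates a new BigDecimal on every call
  }

  private val threshold = BigDecimal("1.5")

  // Eager, like the CSE prologue: decode before any predicate runs.
  def eagerFilter(rows: Seq[InputRow]): Seq[InputRow] = rows.filter { r =>
    val d = decode(r.rawDecimal)
    r.qInt > 10 && d > threshold
  }

  // Lazy, like the non-CSE path: && skips the decode when the cheap check fails.
  def lazyFilter(rows: Seq[InputRow]): Seq[InputRow] = rows.filter { r =>
    r.qInt > 10 && decode(r.rawDecimal) > threshold
  }
}
```

Take the rows `(1, "2.0")`, `(20, "2.0")`, `(30, "1.0")` and `(5, "9.9")`. Both filters keep only `InputRow(20, "2.0")`, but the eager version decodes four times and the lazy one only twice.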
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This is a codegen-only change; query results are unchanged.
    
    ### How was this patch tested?
    
    New unit test in `WholeStageCodegenSuite` asserting that, for a filter with no common subexpression, the CSE-enabled generated code is identical to the CSE-disabled code (i.e. it falls back to the lazy, short-circuiting path). The existing `FilterExec` CSE tests, which use genuine common subexpressions, still exercise the CSE path and pass.
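The comparison in that test only holds after exprIds are normalized. Each `createDataFrame` call mints fresh attribute ids (e.g. `a#16`), and these appear in the plan-tree header of the codegen dump. The sketch below lifts the test's `normalize` helper, with the same regex, into a hypothetical standalone `CodegenDumpNormalizer` object. The plan strings in the checks are made up for illustration.

```scala
object CodegenDumpNormalizer {
  // Same regex as the test: replace `#<digits>` (an exprId suffix such as `a#16`) with `#`.
  def normalize(code: String): String = code.replaceAll("#\\d+", "#")
}
```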
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude (Claude Code)
    
    Closes #56209 from cloud-fan/wenchen/filter-cse-skip-when-no-common-subexpr.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit d6036f470d750a8296016f4a8a88c17679295505)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/codegen/CodeGenerator.scala        | 16 ++++++---
 .../sql/execution/basicPhysicalOperators.scala     | 31 ++++++++++++++++--
 .../sql/execution/WholeStageCodegenSuite.scala     | 38 ++++++++++++++++++++++
 3 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 139a7f03cfa4..ae5774b200cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1150,14 +1150,22 @@ class CodegenContext extends Logging {
    *      evaluation, we can look for generated subexpressions and do replacement.
    */
   def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
-    // Create a clear EquivalentExpressions and SubExprEliminationState mapping
+    // Create a clear EquivalentExpressions and compute the common subexpressions.
     val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+    expressions.foreach(equivalentExpressions.addExprTree(_))
+    subexpressionEliminationForWholeStageCodegen(equivalentExpressions)
+  }
+
+  /**
+   * Same as above, but takes a pre-built [[EquivalentExpressions]]. A caller that has already
+   * analyzed the expressions (e.g. to decide whether any common subexpression exists) can reuse
+   * that analysis here instead of rebuilding it.
+   */
+  def subexpressionEliminationForWholeStageCodegen(
+      equivalentExpressions: EquivalentExpressions): SubExprCodes = {
     val localSubExprEliminationExprsForNonSplit =
       mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]
 
-    // Add each expression tree and compute the common subexpressions.
-    expressions.foreach(equivalentExpressions.addExprTree(_))
-
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
     val commonExprs = equivalentExpressions.getCommonSubexpressions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 92cf3f59d575..8d183f915e8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -242,6 +242,22 @@ case class FilterExec(condition: Expression, child: SparkPlan)
   // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
 
+  // `otherPreds` bound against this operator's `output`, shared between the CSE gate in
+  // `doConsume` and the CSE codegen itself. Codegen-only derived state, so `@transient`: it is
+  // computed on the driver during code generation and never accessed on executors.
+  @transient private lazy val boundOtherPreds: Seq[Expression] =
+    otherPreds.map(BindReferences.bindReference(_, output))
+
+  // CSE analysis of `boundOtherPreds`, built once and reused. `doConsume` consults it to decide
+  // whether any common subexpression is worth eliminating; when one is, the same analysis is
+  // handed to `subexpressionEliminationForWholeStageCodegen` rather than rebuilt. `@transient`
+  // because `EquivalentExpressions` is not serializable (and this is driver-only codegen state).
+  @transient private lazy val otherPredsEquivalentExpressions: EquivalentExpressions = {
+    val equivalentExpressions = new EquivalentExpressions
+    boundOtherPreds.foreach(equivalentExpressions.addExprTree(_))
+    equivalentExpressions
+  }
+
   // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
   // all the variables at the beginning to take advantage of short circuiting.
   override def usedInputs: AttributeSet = AttributeSet.empty
@@ -291,8 +307,17 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     //       without consulting `isNull_X`. The (b) interleaving gives us that ordering
     //       for free, since the IsNotNull check fires before the CSE precompute keyed
     //       off the same reference.
+    // Only take the CSE path when there is actually a common subexpression to eliminate. That
+    // path emits the `inputVarsEvalCode` prologue below, which eagerly evaluates every
+    // `otherPreds` input column at the top of the row loop -- required so eliminated
+    // subexpressions can be materialized into shared variables, but it defeats the
+    // short-circuiting the non-CSE path gets from loading columns lazily, just before the
+    // predicate that needs them. With no common subexpression the prologue is pure overhead
+    // (e.g. decoding a decimal column for rows a cheaper earlier predicate would reject), so we
+    // fall back to `generatePredicateCode`.
     val (prologueCode, predicateCode) =
-      if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty) {
+      if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty &&
+          otherPredsEquivalentExpressions.getCommonSubexpressions.nonEmpty) {
         // Pre-evaluate input variables before CSE analysis: CSE clears
         // ctx.currentVars[i].code as a side effect; without this pre-evaluation, Janino
         // fails when otherPreds reference the same input columns that CSE already
@@ -301,8 +326,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
         val inputVarsEvalCode = evaluateRequiredVariables(
           child.output, input, otherPredInputAttrs)
 
-        val boundOtherPreds = otherPreds.map(BindReferences.bindReference(_, output))
-        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundOtherPreds)
+        val subExprs =
+          ctx.subexpressionEliminationForWholeStageCodegen(otherPredsEquivalentExpressions)
 
         // Group CSE states by the index of the first otherPred that references them.
         // `evaluateSubExprEliminationState` recursively emits each state's children
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index e5b9e7016841..8f0ec0ffd6f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -1186,4 +1186,42 @@ class WholeStageCodegenSuite extends SharedSparkSession
       }
     }
   }
+
+  test("SPARK-56032: FilterExec skips CSE codegen when there is no common subexpression") {
+    // When otherPreds share no common subexpression, the CSE codegen path provides no benefit
+    // but would still eagerly evaluate every referenced input column at the top of the row loop
+    // (the inputVarsEvalCode prologue), defeating the lazy, short-circuiting column loads of the
+    // non-CSE path. Verify that with CSE enabled we fall back to the exact same generated code as
+    // with CSE disabled, so no column is decoded for rows an earlier predicate would reject.
+    val schema = StructType(Seq(
+      StructField("a", IntegerType, nullable = true),
+      StructField("b", IntegerType, nullable = true)))
+    val data = spark.sparkContext.parallelize(Seq(
+      Row(1, 5), Row(null, 3), Row(4, null), Row(5, 6), Row(7, 8), Row(2, 3)))
+    val expected = Seq(Row(5, 6), Row(7, 8))
+
+    def filterCode(cseEnabled: Boolean): String = {
+      withSQLConf(
+        SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString,
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+        val df = spark.createDataFrame(data, schema)
+        // `a > 4` and `b > 4` reference different columns and share no subexpression.
+        val filtered = df.where("a IS NOT NULL AND a > 4 AND b > 4")
+        val plan = filtered.queryExecution.executedPlan
+        assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]),
+          "Filter should be in whole-stage codegen")
+        checkAnswer(filtered, expected)
+        codegenString(plan)
+      }
+    }
+
+    // Each `createDataFrame` mints fresh attribute exprIds (e.g. `a#16`), which appear in the
+    // plan-tree header of the codegen dump but not in the generated Java. Normalize them away so
+    // the comparison reflects the generated code, not the id counter.
+    def normalize(code: String): String = code.replaceAll("#\\d+", "#")
+    assert(normalize(filterCode(cseEnabled = true)) == normalize(filterCode(cseEnabled = false)),
+      "With no common subexpression, CSE-enabled FilterExec codegen should be identical to " +
+        "CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)")
+  }
 }

