This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 5b28a6750bd2 [SPARK-57027][SQL] SortMergeJoinExec: skip statically-dead branches in codegen
5b28a6750bd2 is described below

commit 5b28a6750bd28af9e9dc5ce314e4f5cb1ffb596f
Author: Gengliang Wang <[email protected]>
AuthorDate: Sun May 31 13:03:42 2026 -0700

    [SPARK-57027][SQL] SortMergeJoinExec: skip statically-dead branches in codegen
    
    ### What changes were proposed in this pull request?
    
    This is a sub-task of [SPARK-56908](https://issues.apache.org/jira/browse/SPARK-56908).
    
    Two statically-dead patterns in `SortMergeJoinExec` codegen:
    
    1. `genComparison` emits
    
        ```
        comp = 0;
        if (comp == 0) { comp = compare(k1); }
        if (comp == 0) { comp = compare(k2); }
        ```
    
        The first `if (comp == 0)` is always true (we just assigned 0). Emit
        `comp = compare(k1);` directly and wrap only the subsequent keys.
        `genComparison` is invoked from 5 call sites in SMJ codegen (two in
        `genScanner`, three in `codegenFullOuter`). For single-key joins (the
        common case), each call collapses to a single line.
    
    2. `genScanner` and `codegenFullOuter` emit
       `if (k1IsNull || k2IsNull || ...) { handler }`. When every key `ExprValue` has
       `isNull == FalseLiteral`, the disjunction is statically `false` and the whole
       block (including its `handleStreamedAnyNull` / "join with null row" handler) is
       dead. Detect this case and omit the block. When only some keys are statically
       non-nullable, their `FalseLiteral` operands are dropped from the disjunction.
       This applies, for example, to fact/dimension joins on numeric keys where Spark
       has already proved the keys non-nullable.
    
    ### Why are the changes needed?
    
    Smaller generated Java per SMJ stage. The JIT already eliminates this dead code at
    runtime, so the win is not execution speed but smaller generated source, more
    headroom under the 64KB method-size limit, and slightly faster Janino compilation.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test suites cover both paths with whole-stage codegen on and off:
    - `OuterJoinSuite` (SMJ full-outer codegen + interpreted scanner).
    - `InnerJoinSuite` (SMJ codegen and non-codegen paths).
    - `ExistenceJoinSuite` (SMJ existence path).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code
    
    Closes #56075 from gengliangwang/SPARK-57027-smj-dead-branches.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 29d73467c23f9b14e96ba13845b8b41e58cc13f3)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/execution/joins/SortMergeJoinExec.scala    | 103 +++++++++++++++------
 1 file changed, 77 insertions(+), 26 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index b206fb528dcd..51604cdfedf1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -191,7 +191,13 @@ case class SortMergeJoinExec(
   }
 
   private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: 
Seq[ExprCode]): String = {
-    val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
+    // The first key compare always runs, so emit it unguarded. Each 
subsequent key compare runs
+    // only when previous keys were equal (comp == 0).
+    val pairs = a.zip(b).zipWithIndex
+    val firstCompare = pairs.headOption.map { case ((l, r), i) =>
+      s"comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};"
+    }.getOrElse("comp = 0;")
+    val restCompares = pairs.drop(1).map { case ((l, r), i) =>
       s"""
          |if (comp == 0) {
          |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
@@ -199,8 +205,8 @@ case class SortMergeJoinExec(
        """.stripMargin.trim
     }
     s"""
-       |comp = 0;
-       |${comparisons.mkString("\n")}
+       |$firstCompare
+       |${restCompares.mkString("\n")}
      """.stripMargin
   }
 
@@ -216,11 +222,18 @@ case class SortMergeJoinExec(
     val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", 
forceInline = true)
     val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", 
forceInline = true)
 
-    // Create variables for join keys from both sides.
+    // Create variables for join keys from both sides. Filter out 
`FalseLiteral` `isNull`
+    // terms before building the disjunction so the emitted check has no 
statically-dead
+    // `false` operands. When every key is statically non-nullable, the 
disjunction is
+    // empty and we skip emitting the check (and the dead handler branch) 
entirely.
     val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, 
streamedOutput)
-    val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
+    val nullableStreamedIsNulls = streamedKeyVars.map(_.isNull).filter(_ != 
FalseLiteral)
+    val streamedKeysNullable = nullableStreamedIsNulls.nonEmpty
+    val streamedAnyNull = nullableStreamedIsNulls.mkString(" || ")
     val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, 
bufferedOutput)
-    val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
+    val nullableBufferedIsNulls = bufferedKeyTmpVars.map(_.isNull).filter(_ != 
FalseLiteral)
+    val bufferedKeysNullable = nullableBufferedIsNulls.nonEmpty
+    val bufferedAnyNull = nullableBufferedIsNulls.mkString(" || ")
     // Copy the buffered key as class members so they could be used in next 
function call.
     val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
 
@@ -287,6 +300,27 @@ case class SortMergeJoinExec(
         s"$matches.add((UnsafeRow) $bufferedRow);"
       }
 
+    val checkStreamedAnyNull = if (streamedKeysNullable) {
+      s"""
+         |if ($streamedAnyNull) {
+         |  $handleStreamedAnyNull
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    val checkBufferedAnyNull = if (bufferedKeysNullable) {
+      s"""
+         |if ($bufferedAnyNull) {
+         |  $bufferedRow = null;
+         |  continue;
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
     // Generate a function to scan both streamed and buffered sides to find a 
match.
     // Return whether a match is found.
     //
@@ -329,9 +363,7 @@ case class SortMergeJoinExec(
          |    if (!streamedIter.hasNext()) return false;
          |    $streamedRow = (InternalRow) streamedIter.next();
          |    ${streamedKeyVars.map(_.code).mkString("\n")}
-         |    if ($streamedAnyNull) {
-         |      $handleStreamedAnyNull
-         |    }
+         |    ${checkStreamedAnyNull.trim}
          |    if (!$matches.isEmpty()) {
          |      ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
          |      if (comp == 0) {
@@ -348,10 +380,7 @@ case class SortMergeJoinExec(
          |        }
          |        $bufferedRow = (InternalRow) bufferedIter.next();
          |        ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
-         |        if ($bufferedAnyNull) {
-         |          $bufferedRow = null;
-         |          continue;
-         |        }
+         |        ${checkBufferedAnyNull.trim}
          |        ${bufferedKeyVars.map(_.code).mkString("\n")}
          |      }
          |      ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
@@ -788,11 +817,17 @@ case class SortMergeJoinExec(
     val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", 
forceInline = true)
     val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", 
forceInline = true)
 
-    // Create variables for join keys from both sides.
+    // Create variables for join keys from both sides. As in `genScanner`, 
drop FalseLiteral
+    // `isNull` terms before joining the disjunction so the emitted check has 
no dead `false`
+    // operands; omit the check entirely when every key is statically 
non-nullable.
     val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
-    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+    val nullableLeftIsNulls = leftKeyVars.map(_.isNull).filter(_ != 
FalseLiteral)
+    val leftKeysNullable = nullableLeftIsNulls.nonEmpty
+    val leftAnyNull = nullableLeftIsNulls.mkString(" || ")
     val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, 
right.output)
-    val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
+    val nullableRightIsNulls = rightKeyVars.map(_.isNull).filter(_ != 
FalseLiteral)
+    val rightKeysNullable = nullableRightIsNulls.nonEmpty
+    val rightAnyNull = nullableRightIsNulls.mkString(" || ")
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)
     val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, 
left.output)
     val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, 
right.output)
@@ -866,6 +901,30 @@ case class SortMergeJoinExec(
     //  - Step 3: Buffer rows with same join keys from both sides into 
`leftBuffer` and
     //            `rightBuffer`. Reset bit sets for both buffers accordingly 
(`leftMatched` and
     //            `rightMatched`).
+    val checkLeftAnyNull = if (leftKeysNullable) {
+      s"""
+         |if ($leftAnyNull) {
+         |  // The left row join key is null, join it with null row
+         |  $outputLeftNoMatch
+         |  return;
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    val checkRightAnyNull = if (rightKeysNullable) {
+      s"""
+         |if ($rightAnyNull) {
+         |  // The right row join key is null, join it with null row
+         |  $outputRightNoMatch
+         |  return;
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
     val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
     ctx.addNewFunction(findNextJoinRowsFuncName,
       s"""
@@ -884,18 +943,10 @@ case class SortMergeJoinExec(
          |  }
          |
          |  ${leftKeyVars.map(_.code).mkString("\n")}
-         |  if ($leftAnyNull) {
-         |    // The left row join key is null, join it with null row
-         |    $outputLeftNoMatch
-         |    return;
-         |  }
+         |  ${checkLeftAnyNull.trim}
          |
          |  ${rightKeyVars.map(_.code).mkString("\n")}
-         |  if ($rightAnyNull) {
-         |    // The right row join key is null, join it with null row
-         |    $outputRightNoMatch
-         |    return;
-         |  }
+         |  ${checkRightAnyNull.trim}
          |
          |  ${genComparison(ctx, leftKeyVars, rightKeyVars)}
          |  if (comp < 0) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to