This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 32105b5373a0 [SPARK-57026][SQL] SortMergeJoinExec and
ShuffledHashJoinExec: replace anonymous TaskCompletionListener with shared
JoinHelper methods
32105b5373a0 is described below
commit 32105b5373a03109a742f2b74ba6f9b4cf5505d7
Author: Gengliang Wang <[email protected]>
AuthorDate: Sat May 30 20:13:12 2026 -0700
[SPARK-57026][SQL] SortMergeJoinExec and ShuffledHashJoinExec: replace
anonymous TaskCompletionListener with shared JoinHelper methods
### What changes were proposed in this pull request?
This is a sub-task of
[SPARK-56908](https://issues.apache.org/jira/browse/SPARK-56908).
Two join operators emit anonymous `TaskCompletionListener`s whose bodies
are type-independent:
- `SortMergeJoinExec.doProduce` registers a per-stage anonymous inner class
that adds `matches.spillSize()` to the `spillSize` metric.
- `ShuffledHashJoinExec.buildSideOrFullOuterJoinNonUniqueKey` registers a
runtime anonymous closure that adds the `OpenHashSet[Long]` memory footprint
(bit-set + data array) to `buildDataSize`.
Hoist both into shared static helpers in a new file
`sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java`:
```java
recordSpillSizeOnTaskCompletion(ExternalAppendOnlyUnsafeRowArray, SQLMetric)
recordOpenHashSetMemoryUsageOnTaskCompletion(OpenHashSet<?>, SQLMetric)
```
Also remove the now-unused `SortMergeJoinExec.getTaskContext()`, whose only
caller was the generated code that registered the anonymous listener.
### Why are the changes needed?
- Smaller generated Java per `SortMergeJoinExec` whole-stage-codegen stage:
one anonymous inner class is no longer emitted per stage.
- Centralises the metric-recording listener bodies in one place where the
JIT can compile them once instead of once per stage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test suites cover both paths with whole-stage codegen on and off:
- `OuterJoinSuite` (SMJ full-outer codegen path).
- `InnerJoinSuite` (SMJ codegen path with spill).
- ShuffledHashJoin full-outer non-unique-key path tests in `OuterJoinSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code
Closes #56074 from gengliangwang/SPARK-57026-listener-helpers.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit be8a32d67656bfff1ac75422498db8314a37e939)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../spark/sql/execution/joins/JoinHelper.java | 35 ++++++++++++++++++++--
.../sql/execution/joins/ShuffledHashJoinExec.scala | 12 +++-----
.../sql/execution/joins/SortMergeJoinExec.scala | 18 ++---------
3 files changed, 39 insertions(+), 26 deletions(-)
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
index 91156b2600fd..041bfa04081f 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
@@ -17,12 +17,17 @@
package org.apache.spark.sql.execution.joins;
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray;
+import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.util.collection.BitSet;
+import org.apache.spark.util.collection.OpenHashSet;
/**
* Static helpers shared by join operators in this package, used both from whole-stage codegen and
- * from interpreted execution paths. Hoisting recurring snippets here keeps the generated Java
- * source smaller and lets the JIT compile the bodies once instead of once per stage.
+ * from interpreted execution paths. Hoisting recurring snippets here (especially the ones that
+ * would otherwise be emitted as anonymous inner classes per generated stage) keeps the generated
+ * Java source smaller and lets the JIT compile the bodies once instead of once per stage.
*/
public final class JoinHelper {
@@ -44,4 +49,30 @@ public final class JoinHelper {
}
return new BitSet(bufferSize);
}
+
+ /**
+ * Register a task-completion listener that adds the final spill size of {@code matches} to
+ * {@code spillSize}. Replaces an anonymous {@code TaskCompletionListener} that would otherwise
+ * be generated per {@code SortMergeJoinExec} whole-stage class.
+ */
+ public static void recordSpillSizeOnTaskCompletion(
+ ExternalAppendOnlyUnsafeRowArray matches, SQLMetric spillSize) {
+ TaskContext.get().addTaskCompletionListener(context -> {
+ spillSize.add(matches.spillSize());
+ });
+ }
+
+ /**
+ * Register a task-completion listener that adds the estimated memory footprint of
+ * {@code matchedRows} (the bit-set plus the data array) to {@code metric}. Used by
+ * {@code ShuffledHashJoinExec} to track {@code buildDataSize} for its matched-row tracker.
+ */
+ public static void recordOpenHashSetMemoryUsageOnTaskCompletion(
+ OpenHashSet<?> matchedRows, SQLMetric metric) {
+ TaskContext.get().addTaskCompletionListener(context -> {
+ long bitSetEstimatedSize = matchedRows.getBitSet().capacity() / 8L;
+ long dataEstimatedSize = matchedRows.capacity() * 8L;
+ metric.add(bitSetEstimatedSize + dataEstimatedSize);
+ });
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 0f90f443ad41..8d65a082984f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -270,14 +270,10 @@ case class ShuffledHashJoinExec private (
buildNullRow: GenericInternalRow,
isFullOuterJoin: Boolean): Iterator[InternalRow] = {
val matchedRows = new OpenHashSet[Long]
- TaskContext.get().addTaskCompletionListener[Unit](_ => {
- // At the end of the task, update the task's memory usage for this
- // [[OpenHashSet]] to track matched rows, which has two parts:
- // [[OpenHashSet._bitset]] and [[OpenHashSet._data]].
- val bitSetEstimatedSize = matchedRows.getBitSet.capacity / 8
- val dataEstimatedSize = matchedRows.capacity * 8
- longMetric("buildDataSize") += bitSetEstimatedSize + dataEstimatedSize
- })
+ // At the end of the task, update the task's memory usage for this OpenHashSet that tracks
+ // matched rows (its underlying bit-set plus data array).
+ JoinHelper.recordOpenHashSetMemoryUsageOnTaskCompletion(
+ matchedRows, longMetric("buildDataSize"))
def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 985fc518742c..b206fb528dcd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -440,13 +440,6 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true
- /**
- * This is called by generated Java class, should be public.
- */
- def getTaskContext(): TaskContext = {
- TaskContext.get()
- }
-
override def doProduce(ctx: CodegenContext): String = {
// Specialize `doProduce` code for full outer join, because full outer join needs to
// buffer both sides of join.
@@ -591,16 +584,9 @@ case class SortMergeJoinExec(
}
val initJoin = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initJoin")
+ val helperCls = classOf[JoinHelper].getName
val addHookToRecordMetrics =
- s"""
- |$thisPlan.getTaskContext().addTaskCompletionListener(
- | new org.apache.spark.util.TaskCompletionListener() {
- | @Override
- | public void onTaskCompletion(org.apache.spark.TaskContext context) {
- | ${metricTerm(ctx, "spillSize")}.add($matches.spillSize());
- | }
- |});
- """.stripMargin
+ s"$helperCls.recordSpillSizeOnTaskCompletion($matches, ${metricTerm(ctx, "spillSize")});"
s"""
|if (!$initJoin) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]