This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2db8cfb3bd9 [SPARK-44060][SQL] Code-gen for build side outer shuffled hash join
2db8cfb3bd9 is described below
commit 2db8cfb3bd9bf5e85379c6d5ca414d36cfd9292d
Author: Szehon Ho <[email protected]>
AuthorDate: Fri Jun 30 22:04:22 2023 -0700
[SPARK-44060][SQL] Code-gen for build side outer shuffled hash join
### What changes were proposed in this pull request?
Code generation for shuffled hash join with a build-side outer join (i.e., a left outer join with build left, or a right outer join with build right).
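For context, the plan shape this enables codegen for can be produced as below (a minimal Scala sketch mirroring the JoinSuite cases in this patch; `spark` is an active SparkSession and the column names are illustrative):

    import spark.implicits._

    val left = spark.range(10).selectExpr("id as k1")
    val right = spark.range(30).selectExpr("id as k2")

    // Hinting the LEFT (outer) side makes it the hash-build side, so this is a
    // left outer join with build left; previously such a plan ran without codegen.
    val joined = left.hint("SHUFFLE_HASH").join(right, $"k1" === $"k2", "left_outer")
    joined.explain()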
### Why are the changes needed?
The implementation in https://github.com/apache/spark/pull/41398 covered only the non-codegen version, and codegen was disabled in this scenario.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit test in WholeStageCodegenSuite
Closes #41614 from szehon-ho/same_side_outer_join_codegen_master.
Authored-by: Szehon Ho <[email protected]>
Signed-off-by: huaxingao <[email protected]>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 9 ++
.../sql/execution/joins/ShuffledHashJoinExec.scala | 68 ++++++----
.../scala/org/apache/spark/sql/JoinSuite.scala | 146 +++++++++++----------
.../sql/execution/WholeStageCodegenSuite.scala | 89 +++++++++++++
4 files changed, 217 insertions(+), 95 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d60f5d170e7..270508139e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2182,6 +2182,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN =
+ buildConf("spark.sql.codegen.join.buildSideOuterShuffledHashJoin.enabled")
+ .internal()
+ .doc("When true, enable code-gen for an OUTER shuffled hash join where
outer side" +
+ " is the build side.")
+ .version("3.5.0")
+ .booleanConf
+ .createWithDefault(true)
+
val ENABLE_FULL_OUTER_SORT_MERGE_JOIN_CODEGEN =
buildConf("spark.sql.codegen.join.fullOuterSortMergeJoin.enabled")
.internal()
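This flag is internal and on by default; it can be switched off to fall back to the non-codegen path from SPARK-36612, for example while diagnosing a codegen issue (a sketch using the runtime conf API; the key is taken verbatim from the change above):

    // Disable codegen for build-side outer shuffled hash join; the plan then
    // uses the interpreted implementation.
    spark.conf.set("spark.sql.codegen.join.buildSideOuterShuffledHashJoin.enabled", "false")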
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 8953bf19f35..974f6f9e50c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -340,8 +340,10 @@ case class ShuffledHashJoinExec(
override def supportCodegen: Boolean = joinType match {
case FullOuter =>
conf.getConf(SQLConf.ENABLE_FULL_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
- case LeftOuter if buildSide == BuildLeft => false
- case RightOuter if buildSide == BuildRight => false
+ case LeftOuter if buildSide == BuildLeft =>
+ conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
+ case RightOuter if buildSide == BuildRight =>
+ conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
case _ => true
}
@@ -362,9 +364,15 @@ case class ShuffledHashJoinExec(
}
override def doProduce(ctx: CodegenContext): String = {
- // Specialize `doProduce` code for full outer join, because full outer join needs to
- // iterate streamed and build side separately.
- if (joinType != FullOuter) {
+ // Specialize `doProduce` code for full outer join and build-side outer join,
+ // because we need to iterate streamed and build side separately.
+ val specializedProduce = joinType match {
+ case FullOuter => true
+ case LeftOuter if buildSide == BuildLeft => true
+ case RightOuter if buildSide == BuildRight => true
+ case _ => false
+ }
+ if (!specializedProduce) {
return super.doProduce(ctx)
}
@@ -407,21 +415,24 @@ case class ShuffledHashJoinExec(
case BuildLeft => buildResultVars ++ streamedResultVars
case BuildRight => streamedResultVars ++ buildResultVars
}
- val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
- ctx.addNewFunction(consumeFullOuterJoinRow,
+ val consumeOuterJoinRow = ctx.freshName("consumeOuterJoinRow")
+ ctx.addNewFunction(consumeOuterJoinRow,
s"""
- |private void $consumeFullOuterJoinRow() throws java.io.IOException {
+ |private void $consumeOuterJoinRow() throws java.io.IOException {
| ${metricTerm(ctx, "numOutputRows")}.add(1);
| ${consume(ctx, resultVars)}
|}
""".stripMargin)
- val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
+ val isFullOuterJoin = joinType == FullOuter
+ val joinWithUniqueKey = codegenBuildSideOrFullOuterJoinWithUniqueKey(
ctx, (streamedRow, buildRow), (streamedInput, buildInput),
streamedKeyEv, streamedKeyAnyNull,
- streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
- val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
+ streamedKeyExprCode.value, relationTerm, conditionCheck, consumeOuterJoinRow,
+ isFullOuterJoin)
+ val joinWithNonUniqueKey = codegenBuildSideOrFullOuterJoinNonUniqueKey(
ctx, (streamedRow, buildRow), (streamedInput, buildInput),
streamedKeyEv, streamedKeyAnyNull,
- streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
+ streamedKeyExprCode.value, relationTerm, conditionCheck, consumeOuterJoinRow,
+ isFullOuterJoin)
s"""
|if ($keyIsUnique) {
@@ -433,10 +444,10 @@ case class ShuffledHashJoinExec(
}
/**
- * Generates the code for full outer join with unique join keys.
- * This is code-gen version of `fullOuterJoinWithUniqueKey()`.
+ * Generates the code for build-side or full outer join with unique join keys.
+ * This is code-gen version of `buildSideOrFullOuterJoinUniqueKey()`.
*/
- private def codegenFullOuterJoinWithUniqueKey(
+ private def codegenBuildSideOrFullOuterJoinWithUniqueKey(
ctx: CodegenContext,
rows: (String, String),
inputs: (String, String),
@@ -445,7 +456,8 @@ case class ShuffledHashJoinExec(
streamedKeyValue: ExprValue,
relationTerm: String,
conditionCheck: String,
- consumeFullOuterJoinRow: String): String = {
+ consumeOuterJoinRow: String,
+ isFullOuterJoin: Boolean): String = {
// Inline mutable state since not many join operations in a task
val matchedKeySetClsName = classOf[BitSet].getName
val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, "matchedKeySet",
@@ -484,7 +496,10 @@ case class ShuffledHashJoinExec(
| }
| }
|
- | $consumeFullOuterJoinRow();
+ | if ($foundMatch || $isFullOuterJoin) {
+ | $consumeOuterJoinRow();
+ | }
+ |
| if (shouldStop()) return;
|}
""".stripMargin
@@ -500,7 +515,7 @@ case class ShuffledHashJoinExec(
| // check if key index is not in matched keys set
| if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
| $buildRow = $rowWithIndex.getValue();
- | $consumeFullOuterJoinRow();
+ | $consumeOuterJoinRow();
| }
|
| if (shouldStop()) return;
@@ -514,10 +529,10 @@ case class ShuffledHashJoinExec(
}
/**
- * Generates the code for full outer join with non-unique join keys.
- * This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
+ * Generates the code for build-side or full outer join with non-unique join keys.
+ * This is code-gen version of `buildSideOrFullOuterJoinNonUniqueKey()`.
*/
- private def codegenFullOuterJoinWithNonUniqueKey(
+ private def codegenBuildSideOrFullOuterJoinNonUniqueKey(
ctx: CodegenContext,
rows: (String, String),
inputs: (String, String),
@@ -526,7 +541,8 @@ case class ShuffledHashJoinExec(
streamedKeyValue: ExprValue,
relationTerm: String,
conditionCheck: String,
- consumeFullOuterJoinRow: String): String = {
+ consumeOuterJoinRow: String,
+ isFullOuterJoin: Boolean): String = {
// Inline mutable state since not many join operations in a task
val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, "matchedRowSet",
@@ -572,13 +588,15 @@ case class ShuffledHashJoinExec(
| // set row index in matched row set
| $matchedRowSet.add($rowIndex);
| $foundMatch = true;
- | $consumeFullOuterJoinRow();
+ | $consumeOuterJoinRow();
| }
| }
|
| if (!$foundMatch) {
| $buildRow = null;
- | $consumeFullOuterJoinRow();
+ | if ($isFullOuterJoin) {
+ | $consumeOuterJoinRow();
+ | }
| }
|
| if (shouldStop()) return;
@@ -603,7 +621,7 @@ case class ShuffledHashJoinExec(
| // check if row index is not in matched row set
| if (!$matchedRowSet.contains($rowIndex)) {
| $buildRow = $rowWithIndex.getValue();
- | $consumeFullOuterJoinRow();
+ | $consumeOuterJoinRow();
| }
|
| if (shouldStop()) return;
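Taken together, the generated loops above implement a two-phase scheme. A minimal plain-Scala sketch of the unique-key case (illustrative names; not the generated Java itself):

    // Phase 1: scan the streamed side, emit a row only on a match (or always for
    // FULL OUTER), and record which build keys matched (cf. matchedKeySet).
    // Phase 2: scan the build side and emit unmatched build rows padded with nulls.
    def buildSideOuterJoinUnique[K, L, R](
        streamed: Iterator[(K, L)],
        build: Map[K, R],
        isFullOuterJoin: Boolean): Seq[(Option[L], Option[R])] = {
      val matchedKeys = scala.collection.mutable.Set.empty[K]
      val out = scala.collection.mutable.ArrayBuffer.empty[(Option[L], Option[R])]
      for ((k, l) <- streamed) {
        build.get(k) match {
          case Some(r) =>
            matchedKeys += k
            out += ((Some(l), Some(r)))
          case None =>
            // The streamed side is the inner side of a build-side outer join, so
            // an unmatched streamed row produces output only for FULL OUTER.
            if (isFullOuterJoin) out += ((Some(l), None))
        }
      }
      for ((k, r) <- build if !matchedKeys(k)) {
        out += ((None, Some(r))) // unmatched build rows keep the outer semantics
      }
      out.toSeq
    }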
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 4d0fd2e6513..eb58a77704e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1315,78 +1315,84 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("SPARK-36612: Support left outer join build left or right outer join build right in " +
"shuffled hash join") {
- val inputDFs = Seq(
- // Test unique join key
- (spark.range(10).selectExpr("id as k1"),
- spark.range(30).selectExpr("id as k2"),
- $"k1" === $"k2"),
- // Test non-unique join key
- (spark.range(10).selectExpr("id % 5 as k1"),
- spark.range(30).selectExpr("id % 5 as k2"),
- $"k1" === $"k2"),
- // Test empty build side
- (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
- spark.range(30).selectExpr("id as k2"),
- $"k1" === $"k2"),
- // Test empty stream side
- (spark.range(10).selectExpr("id as k1"),
- spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
- $"k1" === $"k2"),
- // Test empty build and stream side
- (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
- spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
- $"k1" === $"k2"),
- // Test string join key
- (spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
- spark.range(30).selectExpr("cast(id as string) as k2"),
- $"k1" === $"k2"),
- // Test build side at right
- (spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
- spark.range(10).selectExpr("cast(id as string) as k2"),
- $"k1" === $"k2"),
- // Test NULL join key
- (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"),
- spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"),
- $"k1" === $"k2"),
- (spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"),
- spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"),
- $"k1" === $"k2"),
- // Test multiple join keys
- (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
- "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
- spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
- "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"),
- $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
- )
-
- // test left outer with left side build
- inputDFs.foreach { case (df1, df2, joinExprs) =>
- val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
- assert(collect(smjDF.queryExecution.executedPlan) {
- case _: SortMergeJoinExec => true }.size === 1)
- val smjResult = smjDF.collect()
-
- val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
- assert(collect(shjDF.queryExecution.executedPlan) {
- case _: ShuffledHashJoinExec => true
- }.size === 1)
- // Same result between shuffled hash join and sort merge join
- checkAnswer(shjDF, smjResult)
- }
+ Seq("true", "false").foreach{ codegen =>
+ withSQLConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN.key -> codegen) {
+ val inputDFs = Seq(
+ // Test unique join key
+ (spark.range(10).selectExpr("id as k1"),
+ spark.range(30).selectExpr("id as k2"),
+ $"k1" === $"k2"),
+ // Test non-unique join key
+ (spark.range(10).selectExpr("id % 5 as k1"),
+ spark.range(30).selectExpr("id % 5 as k2"),
+ $"k1" === $"k2"),
+ // Test empty build side
+ (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
+ spark.range(30).selectExpr("id as k2"),
+ $"k1" === $"k2"),
+ // Test empty stream side
+ (spark.range(10).selectExpr("id as k1"),
+ spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
+ $"k1" === $"k2"),
+ // Test empty build and stream side
+ (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
+ spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
+ $"k1" === $"k2"),
+ // Test string join key
+ (spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
+ spark.range(30).selectExpr("cast(id as string) as k2"),
+ $"k1" === $"k2"),
+ // Test build side at right
+ (spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
+ spark.range(10).selectExpr("cast(id as string) as k2"),
+ $"k1" === $"k2"),
+ // Test NULL join key
+ (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"),
+ spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"),
+ $"k1" === $"k2"),
+ (spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"),
+ spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"),
+ $"k1" === $"k2"),
+ // Test multiple join keys
+ (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
+ "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
+ spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
+ "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"),
+ $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
+ )
- // test right outer with right side build
- inputDFs.foreach { case (df2, df1, joinExprs) =>
- val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
- assert(collect(smjDF.queryExecution.executedPlan) {
- case _: SortMergeJoinExec => true }.size === 1)
- val smjResult = smjDF.collect()
+ // test left outer with left side build
+ inputDFs.foreach { case (df1, df2, joinExprs) =>
+ val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
+ assert(collect(smjDF.queryExecution.executedPlan) {
+ case _: SortMergeJoinExec => true
+ }.size === 1)
+ val smjResult = smjDF.collect()
+
+ val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
+ assert(collect(shjDF.queryExecution.executedPlan) {
+ case _: ShuffledHashJoinExec => true
+ }.size === 1)
+ // Same result between shuffled hash join and sort merge join
+ checkAnswer(shjDF, smjResult)
+ }
- val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
- assert(collect(shjDF.queryExecution.executedPlan) {
- case _: ShuffledHashJoinExec => true
- }.size === 1)
- // Same result between shuffled hash join and sort merge join
- checkAnswer(shjDF, smjResult)
+ // test right outer with right side build
+ inputDFs.foreach { case (df2, df1, joinExprs) =>
+ val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
+ assert(collect(smjDF.queryExecution.executedPlan) {
+ case _: SortMergeJoinExec => true
+ }.size === 1)
+ val smjResult = smjDF.collect()
+
+ val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
+ assert(collect(shjDF.queryExecution.executedPlan) {
+ case _: ShuffledHashJoinExec => true
+ }.size === 1)
+ // Same result between shuffled hash join and sort merge join
+ checkAnswer(shjDF, smjResult)
+ }
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c32296..0aaeedd5f06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -232,6 +232,95 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}
+
+ test("SPARK-44060 Code-gen for build side outer shuffled hash join") {
+ val df1 = spark.range(0, 5).select($"id".as("k1"))
+ val df2 = spark.range(1, 11).select($"id".as("k2"))
+ val df3 = spark.range(2, 5).select($"id".as("k3"))
+
+ withSQLConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN.key -> "true") {
+ Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint =>
+ // test right join with unique key from build side
+ val rightJoinUniqueDf = df1.join(df2.hint(hint), $"k1" === $"k2", "right_outer")
+ assert(rightJoinUniqueDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(rightJoinUniqueDf, Seq(Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
+ Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9),
+ Row(null, 10)))
+ assert(rightJoinUniqueDf.count() === 10)
+
+ // test left join with unique key from build side
+ val leftJoinUniqueDf = df1.hint(hint).join(df2, $"k1" === $"k2", "left_outer")
+ assert(leftJoinUniqueDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(leftJoinUniqueDf, Seq(Row(0, null), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
+ assert(leftJoinUniqueDf.count() === 5)
+
+ // test right join with non-unique key from build side
+ val rightJoinNonUniqueDf = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "right_outer")
+ assert(rightJoinNonUniqueDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(rightJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8)))
+
+ // test left join with non-unique key from build side
+ val leftJoinNonUniqueDf = df1.hint(hint).join(df2, $"k1" === $"k2" % 3, "left_outer")
+ assert(leftJoinNonUniqueDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(leftJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+ Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null),
+ Row(4, null)))
+
+ // test right join with non-equi condition
+ val rightJoinWithNonEquiDf = df1.join(df2.hint(hint),
+ $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "right_outer")
+ assert(rightJoinWithNonEquiDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(rightJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
+ Row(1, 10), Row(2, 2), Row(2, 8), Row(null, 3), Row(null, 4), Row(null, 5)))
+
+ // test left join with non-equi condition
+ val leftJoinWithNonEquiDf = df1.hint(hint).join(df2,
+ $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "left_outer")
+ assert(leftJoinWithNonEquiDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 1)
+ checkAnswer(leftJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
+ Row(1, 10), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null)))
+
+ // test two right joins
+ val twoRightJoinsDf = df1.join(df2.hint(hint), $"k1" === $"k2", "right_outer")
+ .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "right_outer")
+ assert(twoRightJoinsDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 2)
+ checkAnswer(twoRightJoinsDf, Seq(Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
+
+ // test two left joins
+ val twoLeftJoinsDf = df1.hint(hint).join(df2, $"k1" === $"k2", "left_outer").hint(hint)
+ .join(df3, $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "left_outer")
+ assert(twoLeftJoinsDf.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ }.size === 2)
+ checkAnswer(twoLeftJoinsDf,
+ Seq(Row(0, null, null), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
+ }
+ }
+ }
+
test("Left/Right Outer SortMergeJoin should be included in
WholeStageCodegen") {
val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
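To eyeball the Java emitted for these joins, the codegen explain mode can be used (a sketch; `joined` stands for any DataFrame built as in the tests above):

    // With the new flag enabled, ShuffledHashJoinExec appears inside a
    // WholeStageCodegen subtree and its generated code is printed in full.
    joined.explain("codegen")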
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]