Repository: spark Updated Branches: refs/heads/master 6e6320122 -> 95e372141
[SPARK-14781] [SQL] support nested predicate subquery ## What changes were proposed in this pull request? In order to support nested predicate subqueries, this PR introduces an internal join type, ExistenceJoin, which emits all the rows from the left side plus an additional boolean column that indicates whether any rows from the right side matched (it is not null-aware right now). This additional column can be used to replace the subquery in the Filter. In theory, all predicate subqueries could use this join type, but it is slower than LeftSemi and LeftAnti, so it is only used for nested subqueries (subqueries inside OR). For example, it is used for the following SQL: ```sql SELECT a FROM t WHERE EXISTS (select 0) OR EXISTS (select 1) ``` This PR also fixes a bug in pushing predicate subqueries down through joins (they should not be pushed down). Nested null-aware subqueries are still not supported, for example `a > 3 OR b NOT IN (select bb from t)`. After this change, we can run TPCDS queries Q10, Q35 and Q45. ## How was this patch tested? Added unit tests. Author: Davies Liu <[email protected]> Closes #12820 from davies/or_exists. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95e37214 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95e37214 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95e37214 Branch: refs/heads/master Commit: 95e372141a102f933045fe9472bbe1ce8c91b5d5 Parents: 6e63201 Author: Davies Liu <[email protected]> Authored: Mon May 2 12:58:59 2016 -0700 Committer: Davies Liu <[email protected]> Committed: Mon May 2 12:58:59 2016 -0700 ---------------------------------------------------------------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../sql/catalyst/expressions/subquery.scala | 15 +++- .../sql/catalyst/optimizer/Optimizer.scala | 41 +++++++-- .../spark/sql/catalyst/plans/joinTypes.scala | 10 +++ .../plans/logical/basicLogicalOperators.scala | 4 + .../catalyst/analysis/AnalysisErrorSuite.scala | 11 ++- .../spark/sql/execution/SparkStrategies.scala | 1 + .../execution/joins/BroadcastHashJoinExec.scala | 66 +++++++++++++- .../joins/BroadcastNestedLoopJoinExec.scala | 94 ++++++++++++++------ .../spark/sql/execution/joins/HashJoin.scala | 31 ++++++- .../execution/joins/ShuffledHashJoinExec.scala | 13 +-- .../sql/execution/joins/SortMergeJoinExec.scala | 40 +++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 25 ++++++ .../execution/joins/ExistenceJoinSuite.scala | 50 ++++++++++- 14 files changed, 345 insertions(+), 61 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 61a7d9e..6e3a14d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -115,8 +115,9 @@ trait CheckAnalysis extends PredicateHelper { case f @ Filter(condition, child) => splitConjunctivePredicates(condition).foreach { case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasPredicateSubquery(e) => - failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e") + case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => + failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + + s" conditions: $e") case e => } http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index cd6d3a0..eed062f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -92,7 +92,7 @@ case class PredicateSubquery( extends SubqueryExpression with Predicate with Unevaluable { override lazy val resolved = childrenResolved && query.resolved override lazy val references: AttributeSet = super.references -- query.outputSet - override def nullable: Boolean = false + override def nullable: Boolean = nullAware override def plan: LogicalPlan = SubqueryAlias(toString, query) override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan) override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" @@ -105,6 +105,19 @@ object PredicateSubquery { case _ => false }.isDefined } + + /** + * Returns whether 
there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { + e.find{ x => + x.isInstanceOf[Not] && e.find { + case p: PredicateSubquery => p.nullAware + case _ => false + }.isDefined + }.isDefined + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a147fff..e1c969f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -100,8 +100,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, - EliminateSerialization, - RewritePredicateSubquery) :: + EliminateSerialization) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, @@ -109,7 +108,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: Batch("OptimizeCodegen", Once, - OptimizeCodegen(conf)) :: Nil + OptimizeCodegen(conf)) :: + Batch("RewriteSubquery", Once, + RewritePredicateSubquery, + CollapseProject) :: Nil } /** @@ -1078,7 +1080,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { - Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + val (joinConditions, others) 
= conditions.partition( + e => !PredicateSubquery.hasPredicateSubquery(e)) + val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } } else { val left :: rest = input.toList // find out the first join that have at least one join condition @@ -1091,7 +1100,8 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val right = conditionalJoin.getOrElse(rest.head) val joinedRefs = left.outputSet ++ right.outputSet - val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val (joinConditions, others) = conditions.partition( + e => e.references.subsetOf(joinedRefs) && !PredicateSubquery.hasPredicateSubquery(e)) val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) // should not have reference to same logical plan @@ -1201,9 +1211,16 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) + val (newJoinConditions, others) = + commonFilterCondition.partition(e => !PredicateSubquery.hasPredicateSubquery(e)) + val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + val join = Join(newLeft, newRight, Inner, newJoinCond) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } case RightOuter => // push down the right side only `where` condition val newLeft = left @@ -1543,6 +1560,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS // if performance matters to you. 
Join(p, sub, LeftAnti, Option(Or(anyNull, condition))) + case (p, predicate) => + var joined = p + val replaced = predicate transformUp { + case PredicateSubquery(sub, conditions, nullAware, _) => + // TODO: support null-aware join + val exists = AttributeReference("exists", BooleanType, false)() + joined = Join(joined, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) + exists + } + Project(p.output, Filter(replaced, joined)) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 13f57c5..80674d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { @@ -69,6 +70,14 @@ case object LeftAnti extends JoinType { override def sql: String = "LEFT ANTI" } +case class ExistenceJoin(exists: Attribute) extends JoinType { + override def sql: String = { + // This join type is only used in the end of optimizer and physical plans, we will not + // generate SQL for this join type + throw new UnsupportedOperationException + } +} + case class NaturalJoin(tpe: JoinType) extends JoinType { require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), "Unsupported natural join type " + tpe) @@ -84,6 +93,7 @@ case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) exte object LeftExistence { def unapply(joinType: 
JoinType): Option[JoinType] = joinType match { case LeftSemi | LeftAnti => Some(joinType) + case j: ExistenceJoin => Some(joinType) case _ => None } } http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b2297bb..830a7ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -273,6 +273,8 @@ case class Join( override def output: Seq[Attribute] = { joinType match { + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case LeftOuter => @@ -295,6 +297,8 @@ case class Join( case LeftSemi if condition.isDefined => left.constraints .union(splitConjunctivePredicates(condition.get).toSet) + case j: ExistenceJoin => + left.constraints case Inner => left.constraints.union(right.constraints) case LeftExistence(_) => http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 1b08913..10bff3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -459,11 +459,14 @@ class 
AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) - assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + LocalRelation(a)) + assertAnalysisError(plan1, + "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) - assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + assertAnalysisError(plan2, + "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } test("PredicateSubQuery correlated predicate is nested in an illegal plan") { http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 238334e..9747e58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -118,6 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { private def canBuildRight(joinType: JoinType): Boolean = joinType match { case Inner | LeftOuter | LeftSemi | LeftAnti => true + case j: ExistenceJoin => true case _ => false } 
http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 587c603..7c194ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -48,8 +48,6 @@ case class BroadcastHashJoinExec( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { @@ -85,6 +83,7 @@ case class BroadcastHashJoinExec( case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) + case j: ExistenceJoin => codegenExistence(ctx, input) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -407,4 +406,67 @@ case class BroadcastHashJoinExec( """.stripMargin } } + + /** + * Generates the code for existence join. 
+ */ + private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + val existsVar = ctx.freshName("exists") + + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + s""" + |$eval + |${ev.code} + |$existsVar = !${ev.isNull} && ${ev.value}; + """.stripMargin + } else { + s"$existsVar = true;" + } + + val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |boolean $existsVar = false; + |if ($matched != null) { + | $checkCondition + |} + |$numOutput.add(1); + |${consume(ctx, resultVar)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? 
null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |boolean $existsVar = false; + |if ($matches != null) { + | while (!$existsVar && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, resultVar)} + """.stripMargin + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index a659bf2..2a250ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -50,19 +50,16 @@ case class BroadcastNestedLoopJoinExec( UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil } - private[this] def genResultProjection: InternalRow => InternalRow = { - if (joinType == LeftSemi) { + private[this] def genResultProjection: InternalRow => InternalRow = joinType match { + case LeftExistence(j) => UnsafeProjection.create(output, output) - } else { + case other => // Always put the stream side on left to simplify implementation // both of left and right side could be null UnsafeProjection.create( output, (streamed.output ++ broadcast.output).map(_.withNullability(true))) - } } - override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output: Seq[Attribute] = { joinType match { case Inner => @@ -73,6 +70,8 @@ case class BroadcastNestedLoopJoinExec( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) 
+ case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -197,6 +196,28 @@ case class BroadcastNestedLoopJoinExec( } } + private def existenceJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + + if (condition.isDefined) { + val resultRow = new GenericMutableRow(Array[Any](null)) + streamedIter.map { row => + val result = buildRows.exists(r => boundCondition(joinedRow(row, r))) + resultRow.setBoolean(0, result) + joinedRow(row, resultRow) + } + } else { + val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty)) + streamedIter.map { row => + joinedRow(row, resultRow) + } + } + } + } + /** * The implementation for these joins: * @@ -204,7 +225,8 @@ case class BroadcastNestedLoopJoinExec( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft - * Anti with BuildLeft + * LeftAnti with BuildLeft + * ExistenceJoin with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ @@ -231,27 +253,50 @@ case class BroadcastNestedLoopJoinExec( new BitSet(relation.value.length) )(_ | _) - if (joinType == LeftSemi) { - assert(buildSide == BuildLeft) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val rel = relation.value - while (i < rel.length) { - if (matchedBroadcastRows.get(i)) { - buf += rel(i).copy() + joinType match { + case LeftSemi => + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() + } + i += 1 } - i += 1 - } - return sparkContext.makeRDD(buf) + return sparkContext.makeRDD(buf) + case j: ExistenceJoin => + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i))) + buf += new JoinedRow(rel(i).copy(), result) + i += 1 + } + return sparkContext.makeRDD(buf) + case LeftAnti => + val notMatched: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (!matchedBroadcastRows.get(i)) { + notMatched += rel(i).copy() + } + i += 1 + } + return sparkContext.makeRDD(notMatched) + case o => } val notMatchedBroadcastRows: Seq[InternalRow] = { val nulls = new GenericMutableRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value val joinedRow = new JoinedRow joinedRow.withLeft(nulls) + var i = 0 + val buildRows = relation.value while (i < buildRows.length) { if (!matchedBroadcastRows.get(i)) { buf += joinedRow.withRight(buildRows(i)).copy() @@ -261,10 +306,6 @@ case class BroadcastNestedLoopJoinExec( buf } - if (joinType == LeftAnti) { - return sparkContext.makeRDD(notMatchedBroadcastRows) - } - val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => val 
buildRows = relation.value val joinedRow = new JoinedRow @@ -308,13 +349,16 @@ case class BroadcastNestedLoopJoinExec( leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, BuildRight) => leftExistenceJoin(broadcastedRelation, exists = false) + case (j: ExistenceJoin, BuildRight) => + existenceJoin(broadcastedRelation) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft - * Anti with BuildLeft + * LeftAnti with BuildLeft + * ExistenceJoin with BuildLeft */ defaultJoin(broadcastedRelation) } http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 9c173d7..d46a804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{IntegralType, LongType} @@ -43,6 +44,8 @@ trait HashJoin { left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -50,6 +53,8 @@ trait HashJoin { } } + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + protected lazy val 
(buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) @@ -110,15 +115,14 @@ trait HashJoin { (r: InternalRow) => true } - protected def createResultProjection(): (InternalRow) => InternalRow = { - if (joinType == LeftSemi) { + protected def createResultProjection(): (InternalRow) => InternalRow = joinType match { + case LeftExistence(_) => UnsafeProjection.create(output, output) - } else { + case _ => // Always put the stream side on left to simplify implementation // both of left and right side could be null UnsafeProjection.create( output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) - } } private def innerJoin( @@ -184,6 +188,23 @@ trait HashJoin { } } + private def existenceJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val result = new GenericMutableRow(Array[Any](null)) + val joinedRow = new JoinedRow + streamIter.map { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + val exists = !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + }) + result.setBoolean(0, exists) + joinedRow(current, result) + } + } + private def antiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { @@ -212,6 +233,8 @@ trait HashJoin { semiJoin(streamedIter, hashed) case LeftAnti => antiJoin(streamedIter, hashed) + case j: ExistenceJoin => + existenceJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 3ef2fec..0036f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} @@ -44,17 +44,6 @@ case class ShuffledHashJoinExec( "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) - override def outputPartitioning: Partitioning = joinType match { - case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - case LeftAnti => left.outputPartitioning - case LeftSemi => left.outputPartitioning - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType") - } - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 775f8ac..f0efa52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -53,6 +53,8 @@ case class SortMergeJoinExec( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => (left.output ++ right.output).map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -269,6 +271,44 @@ case class SortMergeJoinExec( override def getRow: InternalRow = currentLeftRow }.toScala + case j: ExistenceJoin => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val result: MutableRow = new GenericMutableRow(Array[Any](null)) + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextOuterJoinRows()) { + currentLeftRow = smjScanner.getStreamedRow + val currentRightMatches = smjScanner.getBufferedMatches + var found = false + if (currentRightMatches != null) { + var i = 0 + while (!found && i < currentRightMatches.length) { + joinRow(currentLeftRow, currentRightMatches(i)) + if (boundCondition(joinRow)) { + found = true + } + i += 1 + } + } + result.setBoolean(0, found) + numOutputRows += 1 + return true + } + false + } + + override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result)) + }.toScala + case x => throw new IllegalArgumentException( s"SortMergeJoin should not take $x as the JoinType") 
http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0bf4c6f..ff3f9bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -152,6 +152,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) } + test("EXISTS predicate subquery within OR") { + checkAnswer( + sql("select * from l where exists (select * from r where l.a = r.c)" + + " or exists (select * from r where l.a = r.c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)" + + " or not exists (select * from r where l.a = r.c)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + test("IN predicate subquery") { checkAnswer( sql("select * from l where l.a in (select c from r)"), @@ -187,6 +200,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } + test("IN predicate subquery within OR") { + checkAnswer( + sql("select * from l where l.a in (select c from r)" + + " or l.a in (select c from r where l.b < r.d)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + intercept[AnalysisException] { + sql("select * from l where a not in (select c from r)" + + " or a not in (select c from r where c is not null)") + } + } + test("complex IN predicate subquery") { checkAnswer( sql("select * from l where (a, b) not in (select c, d from r)"), 
http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index b32b644..8093054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -18,15 +18,15 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType} class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { @@ -89,6 +89,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } + val existsAttr = AttributeReference("exists", BooleanType, false)() + val leftSemiPlus = ExistenceJoin(existsAttr) + def createLeftSemiPlusJoin(join: SparkPlan): SparkPlan = { + val output = join.output.dropRight(1) 
+ val condition = if (joinType == LeftSemi) { + existsAttr + } else { + Not(existsAttr) + } + ProjectExec(output, FilterExec(condition, join)) + } + test(s"$testName using ShuffledHashJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -98,6 +110,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(ShuffledHashJoinExec( + leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -111,6 +129,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastHashJoinExec( + leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -123,6 +147,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(SortMergeJoinExec( + leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -134,6 +164,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { BroadcastNestedLoopJoinExec(left, right, 
BuildLeft, joinType, Some(condition))), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec( + left, right, BuildLeft, leftSemiPlus, Some(condition)))), + expectedAnswer, + sortAnswers = true) } } @@ -144,6 +180,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition))), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec( + left, right, BuildRight, leftSemiPlus, Some(condition)))), + expectedAnswer, + sortAnswers = true) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org
