[SPARK-22136][SS] Implement stream-stream outer joins.

## What changes were proposed in this pull request?

Allow one-sided outer joins between two streams when a watermark is defined.
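For example, after this change a query like the following left outer join between two watermarked streams is accepted (a minimal sketch; the rate source and all column names are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.functions.expr

// Two streaming Datasets, each with an event-time column and a watermark.
// `spark` is assumed to be an existing SparkSession (e.g. in spark-shell).
val impressions = spark.readStream.format("rate").load()
  .selectExpr("value AS impressionAdId", "timestamp AS impressionTime")
  .withWatermark("impressionTime", "2 hours")

val clicks = spark.readStream.format("rate").load()
  .selectExpr("value AS clickAdId", "timestamp AS clickTime")
  .withWatermark("clickTime", "3 hours")

// Left outer join with an equi-join key plus a time range condition, so unmatched
// impressions can eventually be emitted with null click-side columns.
val joined = impressions.join(
  clicks,
  expr("""
    clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 hour"""),
  "leftOuter")
```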
## How was this patch tested?

new unit tests

Author: Jose Torres <j...@databricks.com>

Closes #19327 from joseph-torres/outerjoin.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3099c574
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3099c574
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3099c574

Branch: refs/heads/master
Commit: 3099c574c56cab86c3fcf759864f89151643f837
Parents: 5f69433
Author: Jose Torres <j...@databricks.com>
Authored: Tue Oct 3 21:42:51 2017 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Oct 3 21:42:51 2017 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/StreamingJoinHelper.scala | 286 ++++++++++++++++++
 .../analysis/UnsupportedOperationChecker.scala | 53 +++-
 .../analysis/StreamingJoinHelperSuite.scala | 140 +++++++++
 .../analysis/UnsupportedOperationsSuite.scala | 108 ++++++-
 .../StreamingSymmetricHashJoinExec.scala | 152 ++++++++--
 .../StreamingSymmetricHashJoinHelper.scala | 241 +--------------
 .../state/SymmetricHashJoinStateManager.scala | 200 +++++++++----
 .../SymmetricHashJoinStateManagerSuite.scala | 6 +-
 .../sql/streaming/StreamingJoinSuite.scala | 298 ++++++++++++-------
 9 files changed, 1051 insertions(+), 433 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala new file mode 100644 index 0000000..072dc95 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] in SQL for more details. + */ +object StreamingJoinHelper extends PredicateHelper with Logging { + + /** + * Check the provided logical plan to see if its join keys contain a watermark attribute. + * + * Will return false if the plan is not an equijoin. + * @param plan the logical plan to check + */ + def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { + plan match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + (leftKeys ++ rightKeys).exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + case _ => false + } + } + + /** + * Get the state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canonicalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. + * This function makes a best-effort attempt to get the state watermark. If there is + * any error, it will return None. + * + * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds, if possible. + */ + def getStateValueWatermark( + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + + // If the condition or the event time watermark is not provided, then we cannot calculate the state watermark + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + // If there is no watermark attribute, then we cannot define a state watermark + if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatermarkFromLessThenPredicate( + l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + + // The generated state watermark cleanup expression is inclusive of the state watermark.
+ // If state watermark is W, all state where timestamp <= W will be cleaned up. + // Now when the canonicalized join condition solves to leftTime >= W, we don't want to clean + // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract the state value watermark (milliseconds) from the condition + * `LessThan(leftExpr, rightExpr)`. For example, suppose we want to find the constraint for + * leftTime using the watermark on rightTime: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatermarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + eventWatermark: Option[Long]): Option[Long] = { + + val attributesInCondition = AttributeSet( + leftExpr.collect { case a: AttributeReference => a } ++ + rightExpr.collect { case a: AttributeReference => a } + ) + if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || + attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { + // If more than one attribute from one side is present in the condition, then it cannot be solved + return None + } + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a @ AttributeReference(_, _, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor) + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify that there is only one constraint term and that it is of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition.
We want the constraint term in canonical form to be `-leftTime` + // so that we can resolve it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. + // Now, if the original condition is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a @ AttributeReference(_, _, _, metadata) + if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => + Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some((Double2double(constraintValue) / 1000.0).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term is either an attribute or a literal cast to a double, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if it needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into leaf terms containing attributes or literals. + * Returns terms only of the forms: + * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal))) + * + * Note: + * - If a term needs to be negated to make it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each term represents a timestamp value or a time interval in microseconds, + * typed as doubles.
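+ *
+ * Illustrative example (attribute names hypothetical): collecting terms from
+ * `Subtract(Add(rightTime, 2), leftTime)`, i.e. the canonicalized `rightTime + 2 < leftTime`,
+ * yields `Cast(rightTime, Double)`, `2 * 1000000.0` (the numeric literal scaled from seconds
+ * to microseconds) and `-Cast(leftTime, Double)`.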
+ */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { + expr match { + case Add(left, right) => + collect(left, negate) ++ collect(right, negate) + case Subtract(left, right) => + collect(left, negate) ++ collect(right, !negate) + case TimeAdd(left, right, _) => + collect(left, negate) ++ collect(right, negate) + case TimeSub(left, right, _) => + collect(left, negate) ++ collect(right, !negate) + case UnaryMinus(child) => + collect(child, !negate) + case CheckOverflow(child, _) => + collect(child, negate) + case Cast(child, dataType, _) => + dataType match { + case _: NumericType | _: TimestampType => collect(child, negate) + case _ => + invalid = true + Seq.empty + } + case a: AttributeReference => + val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a + Seq(negateIfNeeded(castedRef, negate)) + case lit: Literal => + // If the literal is of calendar interval type, then explicitly convert to microseconds + // Convert other number-like literals to doubles representing microseconds (by x1000000) + val castedLit = lit.dataType match { + case CalendarIntervalType => + val calendarInterval = lit.value.asInstanceOf[CalendarInterval] + if (calendarInterval.months > 0) { + invalid = true + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom " + + s"as imprecise intervals like months and years cannot be used for " + + s"watermark calculation. Use intervals in terms of days instead.") + Literal(0.0) + } else { + Literal(calendarInterval.microseconds.toDouble) + } + case DoubleType => + Multiply(lit, Literal(1000000.0)) + case _: NumericType => + Multiply(Cast(lit, DoubleType), Literal(1000000.0)) + case _: TimestampType => + Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) + } + Seq(negateIfNeeded(castedLit, negate)) + case a @ _ => + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") + invalid = true + Seq.empty + } + } + + val terms = collect(exprToCollectFrom, negate = false) + if (!invalid) terms else Seq.empty + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d1d7056..dee6fbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -217,7 +218,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, _)
=> + case Join(left, right, joinType, condition) => joinType match { @@ -233,16 +234,52 @@ throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - case LeftOuter | LeftSemi | LeftAnti => + case LeftSemi | LeftAnti => if (right.isStreaming) { - throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + - "on the right is not supported") + throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + "on the right are not supported") } + // We support streaming left outer joins with static on the right always, and with + // stream on both sides under the appropriate conditions. + case LeftOuter => + if (!left.isStreaming && right.isStreaming) { + throwError("Left outer join with a streaming DataFrame/Dataset " + + "on the right and a static DataFrame/Dataset on the left is not supported") + } else if (left.isStreaming && right.isStreaming) { + val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + left.outputSet, right.outputSet, condition, Some(1000000)).isDefined + + if (!watermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } + } + + // We support streaming right outer joins with static on the left always, and with + // stream on both sides under the appropriate conditions. case RightOuter => - if (left.isStreaming) { - throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + - "not supported") + if (left.isStreaming && !right.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left and " + + "a static DataFrame/Dataset on the right is not supported") + } else if (left.isStreaming && right.isStreaming) { + val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + // Check if the nullable side has a watermark, and there's a range condition which + // implies a state value watermark on the first side. + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + right.outputSet, left.outputSet, condition, Some(1000000)).isDefined + + if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } } case NaturalJoin(_) | UsingJoin(_, _) => http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala new file mode 100644 index 0000000..8cf41a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType} + +class StreamingJoinHelperSuite extends AnalysisTest { + + test("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + AttributeReference("rightOther", IntegerType)()) + + case class DummyLeafNode() extends LeafNode { + override def output: Seq[Attribute] = + attributesToFindConstraintFor ++ attributesWithWatermark + } + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + CatalystSqlParser.parseExpression(str), + DummyLeafNode()) + val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan)) + optimized.asInstanceOf[Filter].condition + } + StreamingJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. 
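+ // For example, with rightWatermark = 10000 ms: "leftTime > rightTime" together with
+ // rightTime > 10 s implies leftTime > 10 s, so left-side state up to 10 s can be dropped
+ // and we expect Some(10000); "leftTime < rightTime" puts no lower bound on leftTime, hence None.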
+ assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. + assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. 
+ // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + + // Test non-positive results + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 11f48a3..e5057c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBinaryOperationInStreamingPlan( "left outer join", _.join(_, joinType = LeftOuter), - streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + streamStreamSupported = false, + expectedMsg = "outer join") + + // Left outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Left semi joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = 
LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Right outer joins: stream-* not allowed testBinaryOperationInStreamingPlan( "right outer join", _.join(_, joinType = RightOuter), + streamBatchSupported = false, streamStreamSupported = false, - streamBatchSupported = false) + expectedMsg = "outer join") + + // Right outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Right outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 < attribute)) + }, + OutputMode.Append()) + + // Right outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 > attribute)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 44f1fa5..9bd2127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -146,7 +146,14 @@ case class StreamingSymmetricHashJoinExec( stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } - require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType") + private def throwBadJoinTypeException(): Nothing = { + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $joinType as the JoinType") + } + + require( + joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + s"${getClass.getSimpleName} should not take $joinType as the JoinType") require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) private val storeConf = new StateStoreConf(sqlContext.conf) @@ -157,11 +164,18 @@ case class StreamingSymmetricHashJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case _: InnerLike => left.output ++ right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case _ => throwBadJoinTypeException() + } override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning)) + case RightOuter => PartitioningCollection(Seq(right.outputPartitioning)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -207,31 +221,108 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(inputRow).withRight(matchedRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(matchedRow).withRight(inputRow) + (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input) } // Filter the joined rows based on the given condition. 
- val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ + + // We need to save the time that the inner join output iterator completes, since outer join + // output counts as both update and removal time. + var innerOutputCompletionTimeNs: Long = 0 + def onInnerOutputCompletion = { + innerOutputCompletionTimeNs = System.nanoTime + } + val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion) + + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists( + rightValue => { + outputFilterFunction( + joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + }) + } + + def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists( + leftValue => { + outputFilterFunction( + joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + }) + } + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + filteredInnerOutputIter + case LeftOuter => + // We generate the outer join input by: + // * Getting an iterator over the rows that have aged out on the left side. These rows are + // candidates for being null joined. Note that to avoid doing two passes, this iterator + // removes the rows from the state manager as they're processed. + // * Checking whether the current row matches a key in the right side state, and that key + // has any value which satisfies the filter function when joined. If it doesn't, + // we know we can join with null, since there was never (including this batch) a match + // within the watermark period. If it does, there must have been a match at some point, so + // we know we can't join with null. + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + val removedRowIter = leftSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithRightSideState(pair)) + .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + + filteredInnerOutputIter ++ outerOutputIter + case RightOuter => + // See comments for left outer case. + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val removedRowIter = rightSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithLeftSideState(pair)) + .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + + filteredInnerOutputIter ++ outerOutputIter + case _ => throwBadJoinTypeException() + } + + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + row + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { + // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // Processing time between inner output completion and here comes from the outer portion of a + // join, and thus counts as removal time as we remove old state from one side while iterating. 
+ if (innerOutputCompletionTimeNs != 0) { + allRemovalsTimeMs += + math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + } + allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + // Remove any remaining state rows which aren't needed because they're below the watermark. + // + // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // For outer joins, we have already removed unnecessary state rows from the outer side + // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we + // have to remove unnecessary state rows from the other side (e.g., right side for the left + // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal + // needs to be done greedily by immediately consuming the returned iterator. + val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throwBadJoinTypeException() + } + while (cleanupIter.hasNext) { + cleanupIter.next() + } } // Commit all state changes and update state store metrics @@ -251,7 +342,8 @@ case class StreamingSymmetricHashJoinExec( } } - CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterWithMetrics, onOutputCompletion) } /** @@ -324,14 +416,32 @@ case class StreamingSymmetricHashJoinExec( } } - /** Remove old buffered state rows using watermarks for state keys and values */ - def removeOldState(): Unit = { + /** + * Get an iterator over the values stored in this joiner's state manager for the given key. + * + * Should not be interleaved with mutations. + */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + joinStateManager.get(key) + } + + /** + * Builds an iterator over old state key-value pairs, removing them lazily as they're produced. + * + * @note This iterator must be consumed fully before any other operations are made + * against this joiner's join state manager. For efficiency reasons, the intermediate states of + * the iterator leave the state manager in an undefined state. + * + * We do this to avoid requiring either two passes or full materialization when + * processing the rows for outer join. 
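+ *
+ * As a sketch of how this is consumed (mirroring the left outer join path above):
+ * {{{
+ *   leftSideJoiner.removeOldState()
+ *     .filterNot(pair => matchesWithRightSideState(pair))
+ *     .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+ * }}}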
+ */ + def removeOldState(): Iterator[UnsafeRowPair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) case Some(JoinStateValueWatermarkPredicate(expr)) => joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc) - case _ => + case _ => Iterator.empty } } http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index e50274a..64c7189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression @@ -34,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details. */ -object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { +object StreamingSymmetricHashJoinHelper extends Logging { sealed trait JoinSide case object LeftSide extends JoinSide { override def toString(): String = "left" } @@ -111,7 +112,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs - val stateValueWatermark = getStateValueWatermark( + val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, @@ -133,242 +134,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { } /** - * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) - * given the join condition and the event time watermark. This is how it works. - * - The condition is split into conjunctive predicates, and we find the predicates of the - * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). - * - We canoncalize the predicate and solve it with the event time watermark value to find the - * value of the state watermark. - * This function is supposed to make best-effort attempt to get the state watermark. If there is - * any error, it will return None. 
- * - * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark - * is to be calculated - * @param attributesWithEventWatermark attributes of the other side which has a watermark column - * @param joinCondition join condition - * @param eventWatermark watermark defined on the input event data - * @return state value watermark in milliseconds, is possible. - */ - def getStateValueWatermark( - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - joinCondition: Option[Expression], - eventWatermark: Option[Long]): Option[Long] = { - - // If condition or event time watermark is not provided, then cannot calculate state watermark - if (joinCondition.isEmpty || eventWatermark.isEmpty) return None - - // If there is not watermark attribute, then cannot define state watermark - if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None - - def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { - try { - getStateWatermarkFromLessThenPredicate( - l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) - } catch { - case NonFatal(e) => - logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) - None - } - } - - val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => - - // The generated the state watermark cleanup expression is inclusive of the state watermark. - // If state watermark is W, all state where timestamp <= W will be cleaned up. - // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean - // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. - val stateWatermark = predicate match { - case LessThan(l, r) => getStateWatermarkSafely(l, r) - case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) - case GreaterThan(l, r) => getStateWatermarkSafely(r, l) - case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) - case _ => None - } - if (stateWatermark.nonEmpty) { - logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") - } - stateWatermark - } - allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) - } - - /** - * Extract the state value watermark (milliseconds) from the condition - * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for - * leftTime using the watermark on the rightTime. 
Example: - * - * Input: rightTime-with-watermark + c1 < leftTime + c2 - * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 - * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime - * With watermark value: watermark-value + c1 + (-c2) < leftTime - */ - private def getStateWatermarkFromLessThenPredicate( - leftExpr: Expression, - rightExpr: Expression, - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - eventWatermark: Option[Long]): Option[Long] = { - - val attributesInCondition = AttributeSet( - leftExpr.collect { case a: AttributeReference => a } ++ - rightExpr.collect { case a: AttributeReference => a } - ) - if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || - attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { - // If more than attributes present in condition from one side, then it cannot be solved - return None - } - - def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { - e.collectLeaves().collectFirst { - case a @ AttributeReference(_, TimestampType, _, _) - if attributesToFindStateWatermarkFor.contains(a) => a - }.nonEmpty - } - - // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 - val allOnLeftExpr = Subtract(leftExpr, rightExpr) - logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") - - // Canonicalization step 2: extract commutative terms - // rightTime-with-watermark, c1, -leftTime, -c2 - val terms = ExpressionSet(collectTerms(allOnLeftExpr)) - logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) - - - - // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor - val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) - - // Verify there is only one correct constraint term and of the correct type - if (constraintTerms.size > 1) { - logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - if (constraintTerms.isEmpty) { - logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { - // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` - // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. - // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark - // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about - // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, - // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. - return None - } - - // Replace watermark attribute with watermark value, and generate the resolved expression - // from the other terms. 
That is, - // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) - logDebug(s"Constraint term from join condition:\t$constraintTerm") - val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => - term.transform { - case a @ AttributeReference(_, TimestampType, _, metadata) - if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => - Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) - } - }.reduceLeft(Add) - - // Calculate the constraint value - logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") - val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] - Some((Double2double(constraintValue) / 1000.0).toLong) - } - - /** - * Collect all the terms present in an expression after converting it into the form - * a + b + c + d where each term be either an attribute or a literal casted to long, - * optionally wrapped in a unary minus. - */ - private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { - var invalid = false - - /** Wrap a term with UnaryMinus if its needs to be negated. */ - def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { - if (minus) UnaryMinus(expr) else expr - } - - /** - * Recursively split the expression into its leaf terms contains attributes or literals. - * Returns terms only of the forms: - * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), - * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) - * Multiply(Literal), UnaryMinus(Multiply(Literal)) - * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) - * - * Note: - * - If term needs to be negated for making it a commutative term, - * then it will be wrapped in UnaryMinus(...) - * - Each terms will be representing timestamp value or time interval in microseconds, - * typed as doubles. - */ - def collect(expr: Expression, negate: Boolean): Seq[Expression] = { - expr match { - case Add(left, right) => - collect(left, negate) ++ collect(right, negate) - case Subtract(left, right) => - collect(left, negate) ++ collect(right, !negate) - case TimeAdd(left, right, _) => - collect(left, negate) ++ collect(right, negate) - case TimeSub(left, right, _) => - collect(left, negate) ++ collect(right, !negate) - case UnaryMinus(child) => - collect(child, !negate) - case CheckOverflow(child, _) => - collect(child, negate) - case Cast(child, dataType, _) => - dataType match { - case _: NumericType | _: TimestampType => collect(child, negate) - case _ => - invalid = true - Seq.empty - } - case a: AttributeReference => - val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a - Seq(negateIfNeeded(castedRef, negate)) - case lit: Literal => - // If literal of type calendar interval, then explicitly convert to millis - // Convert other number like literal to doubles representing millis (by x1000) - val castedLit = lit.dataType match { - case CalendarIntervalType => - val calendarInterval = lit.value.asInstanceOf[CalendarInterval] - if (calendarInterval.months > 0) { - invalid = true - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom " + - s"as imprecise intervals like months and years cannot be used for" + - s"watermark calculation. 
Use interval in terms of day instead.") - Literal(0.0) - } else { - Literal(calendarInterval.microseconds.toDouble) - } - case DoubleType => - Multiply(lit, Literal(1000000.0)) - case _: NumericType => - Multiply(Cast(lit, DoubleType), Literal(1000000.0)) - case _: TimestampType => - Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) - } - Seq(negateIfNeeded(castedLit, negate)) - case a @ _ => - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") - invalid = true - Seq.empty - } - } - - val terms = collect(exprToCollectFrom, negate = false) - if (!invalid) terms else Seq.empty - } - - /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already * loaded. This is class is a modified verion of [[ZippedPartitionsRDD2]]. http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 3764871..d256fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager( /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues) + keyWithIndexToValue.getAll(key, numValues).map(_.value) } /** Append a new value to the key */ @@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager( } /** - * Remove using a predicate on keys. See class docs for more context and implement details. + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value) pairs satisfying condition(key), where the + * underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. 
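+ *
+ * A minimal usage sketch (caller-side names hypothetical): when only the removal side effect
+ * is needed, the returned iterator must be drained eagerly:
+ * {{{
+ *   val removedIter = manager.removeByKeyCondition(watermarkPredicate)
+ *   while (removedIter.hasNext) removedIter.next()
+ * }}}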
@@ -87,70 +87,163 @@
   }
 
   /**
-   * Remove using a predicate on keys. See class docs for more context and implement details.
+   * Remove using a predicate on keys.
+   *
+   * This produces an iterator over the (key, value) pairs satisfying condition(key), where the
+   * underlying store is updated as a side-effect of producing next.
+   *
+   * This implies the iterator must be consumed fully without any other operations on this manager
+   * or the underlying store being interleaved.
    */
-  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
-    val allKeyToNumValues = keyToNumValues.iterator
-
-    while (allKeyToNumValues.hasNext) {
-      val keyToNumValue = allKeyToNumValues.next
-      if (condition(keyToNumValue.key)) {
-        keyToNumValues.remove(keyToNumValue.key)
-        keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue)
+  def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
+
+      private val allKeyToNumValues = keyToNumValues.iterator
+
+      private var currentKeyToNumValue: KeyAndNumValues = null
+      private var currentValues: Iterator[KeyWithIndexAndValue] = null
+
+      private def currentKey = currentKeyToNumValue.key
+
+      private val reusedPair = new UnsafeRowPair()
+
+      private def getAndRemoveValue() = {
+        val keyWithIndexAndValue = currentValues.next()
+        keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex)
+        reusedPair.withRows(currentKey, keyWithIndexAndValue.value)
+      }
+
+      override def getNext(): UnsafeRowPair = {
+        // If there are more values for the current key, remove and return the next one.
+        if (currentValues != null && currentValues.hasNext) {
+          return getAndRemoveValue()
+        }
+
+        // If there weren't any values left, try to find the next key that satisfies the removal
+        // condition and has values.
+        while (allKeyToNumValues.hasNext) {
+          currentKeyToNumValue = allKeyToNumValues.next()
+          if (removalCondition(currentKey)) {
+            currentValues = keyWithIndexToValue.getAll(
+              currentKey, currentKeyToNumValue.numValue)
+            keyToNumValues.remove(currentKey)
+
+            if (currentValues.hasNext) {
+              return getAndRemoveValue()
+            }
+          }
+        }
+
+        // We only reach here if there were no satisfying keys left, which means we're done.
+        finished = true
+        return null
       }
+
+      override def close: Unit = {}
     }
   }
 
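A note on consuming these iterators: because each next() mutates the backing store, callers must drain them fully before doing anything else with the store, exactly as the updated tests below do with `while (iter.hasNext) iter.next()`. A toy model of the pattern, using plain Scala collections rather than Spark's state store (all names here are made up):

    import scala.collection.mutable

    object DrainRemovalIteratorDemo {
      // Removal iterator over a toy store: each next() deletes the element it returns.
      def removeByCondition(
          store: mutable.LinkedHashMap[String, Int],
          cond: Int => Boolean): Iterator[(String, Int)] = {
        val matching = store.filter { case (_, v) => cond(v) }.toList
        matching.iterator.map { kv => store.remove(kv._1); kv }
      }

      def main(args: Array[String]): Unit = {
        val store = mutable.LinkedHashMap("a" -> 1, "b" -> 2, "c" -> 3)
        val iter = removeByCondition(store, _ <= 2)
        while (iter.hasNext) iter.next()  // drain fully; removals happen here, not above
        println(store)                    // only c -> 3 remains
      }
    }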
   /**
-   * Remove using a predicate on values. See class docs for more context and implementation details.
+   * Remove using a predicate on values.
+   *
+   * At a high level, this produces an iterator over the (key, value) pairs such that value
+   * satisfies the predicate, where producing an element removes the value from the state store
+   * and producing all elements with a given key updates it accordingly.
+   *
+   * This implies the iterator must be consumed fully without any other operations on this manager
+   * or the underlying store being interleaved.
    */
-  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
-    val allKeyToNumValues = keyToNumValues.iterator
+  def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
 
-    while (allKeyToNumValues.hasNext) {
-      val keyToNumValue = allKeyToNumValues.next
-      val key = keyToNumValue.key
+      // Reuse this object to avoid creation+GC overhead.
+      private val reusedPair = new UnsafeRowPair()
 
-      var numValues: Long = keyToNumValue.numValue
-      var index: Long = 0L
-      var valueRemoved: Boolean = false
-      var valueForIndex: UnsafeRow = null
+      private val allKeyToNumValues = keyToNumValues.iterator
 
-      while (index < numValues) {
-        if (valueForIndex == null) {
-          valueForIndex = keyWithIndexToValue.get(key, index)
+      private var currentKey: UnsafeRow = null
+      private var numValues: Long = 0L
+      private var index: Long = 0L
+      private var valueRemoved: Boolean = false
+
+      // Push the data for the current key to the numValues store, and reset the tracking
+      // variables to their empty state.
+      private def updateNumValueForCurrentKey(): Unit = {
+        if (valueRemoved) {
+          if (numValues >= 1) {
+            keyToNumValues.put(currentKey, numValues)
+          } else {
+            keyToNumValues.remove(currentKey)
+          }
         }
-        if (condition(valueForIndex)) {
-          if (numValues > 1) {
-            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
-            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
-            keyWithIndexToValue.remove(key, numValues - 1)
-            valueForIndex = valueAtMaxIndex
+
+        currentKey = null
+        numValues = 0
+        index = 0
+        valueRemoved = false
+      }
+
+      // Find the next value satisfying the condition, updating `currentKey` and `numValues` if
+      // needed. Returns null when no value can be found.
+      private def findNextValueForIndex(): UnsafeRow = {
+        // Loop across all values for the current key, and then all other keys, until we find a
+        // value satisfying the removal condition.
+        def hasMoreValuesForCurrentKey = currentKey != null && index < numValues
+        def hasMoreKeys = allKeyToNumValues.hasNext
+        while (hasMoreValuesForCurrentKey || hasMoreKeys) {
+          if (hasMoreValuesForCurrentKey) {
+            // First search the values for the current key.
+            val currentValue = keyWithIndexToValue.get(currentKey, index)
+            if (removalCondition(currentValue)) {
+              return currentValue
+            } else {
+              index += 1
+            }
+          } else if (hasMoreKeys) {
+            // If we can't find a value for the current key, clean up and start looking at the
+            // next. This will also happen the first time the iterator is called.
+            updateNumValueForCurrentKey()
+
+            val currentKeyToNumValue = allKeyToNumValues.next()
+            currentKey = currentKeyToNumValue.key
+            numValues = currentKeyToNumValue.numValue
          } else {
-            keyWithIndexToValue.remove(key, 0)
-            valueForIndex = null
+            // Should be unreachable, but in any case means a value couldn't be found.
+            return null
          }
-          numValues -= 1
-          valueRemoved = true
-        } else {
-          valueForIndex = null
-          index += 1
        }
+
+        // We tried and failed to find the next value.
+        return null
      }
-      if (valueRemoved) {
-        if (numValues >= 1) {
-          keyToNumValues.put(key, numValues)
+
+      override def getNext(): UnsafeRowPair = {
+        val currentValue = findNextValueForIndex()
+
+        // If there's no value, clean up and finish. There aren't any more available.
+        if (currentValue == null) {
+          updateNumValueForCurrentKey()
+          finished = true
+          return null
+        }
+
+        // The backing store is arraylike - we as the caller are responsible for filling back in
+        // any hole. So we swap the last element into the hole and decrement numValues to shorten.
+        if (numValues > 1) {
+          val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1)
+          keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex)
+          keyWithIndexToValue.remove(currentKey, numValues - 1)
        } else {
-          keyToNumValues.remove(key)
+          keyWithIndexToValue.remove(currentKey, 0)
        }
+        numValues -= 1
+        valueRemoved = true
+
+        return reusedPair.withRows(currentKey, currentValue)
      }
-    }
-  }
 
-  def iterator(): Iterator[UnsafeRowPair] = {
-    val pair = new UnsafeRowPair()
-    keyWithIndexToValue.iterator.map { x =>
-      pair.withRows(x.key, x.value)
+
+      override def close: Unit = {}
     }
   }
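The "swap the last element into the hole" step above is the standard O(1) unordered delete for array-like storage; it keeps the per-key index range dense so numValues retains its meaning. A toy illustration of the technique (plain Scala, not Spark code):

    import scala.collection.mutable.ArrayBuffer

    object SwapIntoHoleDemo {
      // Delete index i in O(1) by moving the last element into the hole and shrinking.
      // Order is not preserved, which is fine for an unordered multi-value store.
      def removeAt[T](values: ArrayBuffer[T], i: Int): Unit = {
        val last = values.length - 1
        if (i != last) values(i) = values(last)  // fill the hole with the last element
        values.remove(last)                      // shorten by one
      }

      def main(args: Array[String]): Unit = {
        val values = ArrayBuffer("v0", "v1", "v2", "v3")
        removeAt(values, 1)
        println(values)  // ArrayBuffer(v0, v3, v2)
      }
    }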
@@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager(
       stateStore.get(keyWithIndexRow(key, valueIndex))
     }
 
-    /** Get all the values for key and all indices. */
-    def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = {
+    /**
+     * Get all values and indices for the provided key.
+     * Should not return null.
+     */
+    def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
+      val keyWithIndexAndValue = new KeyWithIndexAndValue()
       var index = 0
-      new NextIterator[UnsafeRow] {
-        override protected def getNext(): UnsafeRow = {
+      new NextIterator[KeyWithIndexAndValue] {
+        override protected def getNext(): KeyWithIndexAndValue = {
           if (index >= numValues) {
             finished = true
             null
           } else {
             val keyWithIndex = keyWithIndexRow(key, index)
             val value = stateStore.get(keyWithIndex)
+            keyWithIndexAndValue.withNew(key, index, value)
             index += 1
-            value
+            keyWithIndexAndValue
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index ffa4c3c..d44af1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
       BoundReference(
         1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
       Literal(threshold))
-    manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+    val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+    while (iter.hasNext) iter.next()
   }
 
   /** Remove values where `time <= threshold` */
   def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
     val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
-    manager.removeByValueCondition(
+    val iter = manager.removeByValueCondition(
       GeneratePredicate.generate(expr, inputValueAttribs).eval _)
+    while (iter.hasNext) iter.next()
   }
 
   def numRows(implicit manager: SymmetricHashJoinStateManager): Long = {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org