[SPARK-22136][SS] Implement stream-stream outer joins.

## What changes were proposed in this pull request?

Allow one-sided outer joins between two streams when a watermark is defined.
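
A hedged usage sketch of what this enables (not code from the patch; the `rate` source and column names are illustrative stand-ins for any watermarked streaming inputs):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object StreamStreamOuterJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("outer-join").getOrCreate()

    // Two streams with event-time watermarks; rate's (timestamp, value) columns renamed.
    val impressions = spark.readStream.format("rate").load()
      .toDF("impressionTime", "adId").withWatermark("impressionTime", "10 seconds")
    val clicks = spark.readStream.format("rate").load()
      .toDF("clickTime", "clickAdId").withWatermark("clickTime", "20 seconds")

    // Left outer join: once the watermark guarantees no matching click can still
    // arrive, unmatched impressions are emitted with null click columns.
    val joined = impressions.join(
      clicks,
      expr("""adId = clickAdId AND
              clickTime >= impressionTime AND
              clickTime <= impressionTime + interval 1 hour"""),
      "leftOuter")

    joined.writeStream.format("console").start().awaitTermination()
  }
}
```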

## How was this patch tested?

New unit tests: StreamingJoinHelperSuite, plus extended coverage in UnsupportedOperationsSuite and StreamingJoinSuite.

Author: Jose Torres <j...@databricks.com>

Closes #19327 from joseph-torres/outerjoin.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3099c574
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3099c574
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3099c574

Branch: refs/heads/master
Commit: 3099c574c56cab86c3fcf759864f89151643f837
Parents: 5f69433
Author: Jose Torres <j...@databricks.com>
Authored: Tue Oct 3 21:42:51 2017 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Oct 3 21:42:51 2017 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/StreamingJoinHelper.scala | 286 ++++++++++++++++++
 .../analysis/UnsupportedOperationChecker.scala  |  53 +++-
 .../analysis/StreamingJoinHelperSuite.scala     | 140 +++++++++
 .../analysis/UnsupportedOperationsSuite.scala   | 108 ++++++-
 .../StreamingSymmetricHashJoinExec.scala        | 152 ++++++++--
 .../StreamingSymmetricHashJoinHelper.scala      | 241 +--------------
 .../state/SymmetricHashJoinStateManager.scala   | 200 +++++++++----
 .../SymmetricHashJoinStateManagerSuite.scala    |   6 +-
 .../sql/streaming/StreamingJoinSuite.scala      | 298 ++++++++++++-------
 9 files changed, 1051 insertions(+), 433 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
new file mode 100644
index 0000000..072dc95
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+
+/**
+ * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] in SQL for more details.
+ */
+object StreamingJoinHelper extends PredicateHelper with Logging {
+
+  /**
+   * Check the provided logical plan to see if its join keys contain a watermark attribute.
+   *
+   * Will return false if the plan is not an equijoin.
+   * @param plan the logical plan to check
+   */
+  def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = {
+    plan match {
+      case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) =>
+        (leftKeys ++ rightKeys).exists {
+          case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey)
+          case _ => false
+        }
+      case _ => false
+    }
+  }
+
+  /**
+   * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it)
+   * given the join condition and the event time watermark. This is how it works.
+   * - The condition is split into conjunctive predicates, and we find the predicates of the
+   *   form `leftTime + c1 < rightTime + c2`   (or <=, >, >=).
+   * - We canonicalize the predicate and solve it with the event time watermark value to find
+   *   the value of the state watermark.
+   * This function makes a best-effort attempt to get the state watermark. If there is
+   * any error, it will return None.
+   *
+   * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark
+   *                                          is to be calculated
+   * @param attributesWithEventWatermark  attributes of the other side which has a watermark column
+   * @param joinCondition                 join condition
+   * @param eventWatermark                watermark defined on the input event data
+   * @return state value watermark in milliseconds, if possible
+   */
+  def getStateValueWatermark(
+      attributesToFindStateWatermarkFor: AttributeSet,
+      attributesWithEventWatermark: AttributeSet,
+      joinCondition: Option[Expression],
+      eventWatermark: Option[Long]): Option[Long] = {
+
+    // If condition or event time watermark is not provided, then cannot calculate state watermark
+    if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
+
+    // If there is no watermark attribute, then cannot define state watermark
+    if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None
+
+    def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = {
+      try {
+        getStateWatermarkFromLessThenPredicate(
+          l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark)
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error trying to extract state constraint from condition 
$joinCondition", e)
+          None
+      }
+    }
+
+    val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
+
+      // The generated state watermark cleanup expression is inclusive of the state watermark.
+      // If the state watermark is W, all state where timestamp <= W will be cleaned up.
+      // Now when the canonicalized join condition solves to leftTime >= W, we don't want to clean
+      // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below.
+      val stateWatermark = predicate match {
+        case LessThan(l, r) => getStateWatermarkSafely(l, r)
+        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1)
+        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
+        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1)
+        case _ => None
+      }
+      if (stateWatermark.nonEmpty) {
+        logInfo(s"Condition $joinCondition generated watermark constraint = 
${stateWatermark.get}")
+      }
+      stateWatermark
+    }
+    allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
+  }
+
+  /**
+   * Extract the state value watermark (milliseconds) from the condition
+   * `LessThan(leftExpr, rightExpr)`. For example, suppose we want to find the constraint for
+   * leftTime using the watermark on the rightTime:
+   *
+   * Input:                 rightTime-with-watermark + c1 < leftTime + c2
+   * Canonical form:        rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0
+   * Solving for leftTime:  rightTime-with-watermark + c1 + (-c2) < leftTime
+   * With watermark value:  watermark-value + c1 + (-c2) < leftTime
+   */
+  private def getStateWatermarkFromLessThenPredicate(
+      leftExpr: Expression,
+      rightExpr: Expression,
+      attributesToFindStateWatermarkFor: AttributeSet,
+      attributesWithEventWatermark: AttributeSet,
+      eventWatermark: Option[Long]): Option[Long] = {
+
+    val attributesInCondition = AttributeSet(
+      leftExpr.collect { case a: AttributeReference => a } ++
+      rightExpr.collect { case a: AttributeReference => a }
+    )
+    if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 ||
+        attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) {
+      // If more than one attribute from one side is present in the condition, it cannot be solved
+      return None
+    }
+
+    def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = {
+      e.collectLeaves().collectFirst {
+        case a @ AttributeReference(_, _, _, _)
+          if attributesToFindStateWatermarkFor.contains(a) => a
+      }.nonEmpty
+    }
+
+    // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0
+    val allOnLeftExpr = Subtract(leftExpr, rightExpr)
+    logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
+
+    // Canonicalization step 2: extract commutative terms
+    //    rightTime-with-watermark, c1, -leftTime, -c2
+    val terms = ExpressionSet(collectTerms(allOnLeftExpr))
+    logDebug("Terms extracted from join condition:\n\t" + 
terms.mkString("\n\t"))
+
+    // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor)
+    val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor)
+
+    // Verify there is only one constraint term, and that it is of the correct type
+    if (constraintTerms.size > 1) {
+      logWarning("Failed to extract state constraint terms: multiple time 
terms in condition\n\t" +
+        terms.mkString("\n\t"))
+      return None
+    }
+    if (constraintTerms.isEmpty) {
+      logDebug("Failed to extract state constraint terms: no time terms in 
condition\n\t" +
+        terms.mkString("\n\t"))
+      return None
+    }
+    val constraintTerm = constraintTerms.head
+    if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) {
+      // Incorrect condition. We want the constraint term in canonical form to be `-leftTime`
+      // so that we can resolve it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`.
+      // Now, if the original condition is `rightTime-with-watermark > leftTime` and the watermark
+      // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about
+      // `leftTime` can be inferred. In this case, after canonicalization and collection of terms,
+      // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None.
+      return None
+    }
+
+    // Replace watermark attribute with watermark value, and generate the resolved expression
+    // from the other terms. That is,
+    // rightTime-with-watermark, c1, -c2  =>  watermark, c1, -c2  =>  watermark + c1 + (-c2)
+    logDebug(s"Constraint term from join condition:\t$constraintTerm")
+    val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term =>
+      term.transform {
+        case a @ AttributeReference(_, _, _, metadata)
+          if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) =>
+          Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0))
+      }
+    }.reduceLeft(Add)
+
+    // Calculate the constraint value
+    logInfo(s"Final expression to evaluate 
constraint:\t$exprWithWatermarkSubstituted")
+    val constraintValue = 
exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double]
+    Some((Double2double(constraintValue) / 1000.0).toLong)
+  }
+
+  /**
+   * Collect all the terms present in an expression after converting it into the form
+   * a + b + c + d, where each term is either an attribute or a literal cast to long,
+   * optionally wrapped in a unary minus.
+   */
+  private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = {
+    var invalid = false
+
+    /** Wrap a term with UnaryMinus if it needs to be negated. */
+    def negateIfNeeded(expr: Expression, minus: Boolean): Expression = {
+      if (minus) UnaryMinus(expr) else expr
+    }
+
+    /**
+     * Recursively split the expression into its leaf terms containing attributes or literals.
+     * Returns terms only of the forms:
+     *    Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)),
+     *    Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double))
+     *    Multiply(Literal), UnaryMinus(Multiply(Literal))
+     *    Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal)))
+     *
+     * Note:
+     * - If a term needs to be negated to make it a commutative term,
+     *   then it will be wrapped in UnaryMinus(...)
+     * - Each term represents a timestamp value or a time interval in microseconds,
+     *   typed as a double.
+     */
+    def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
+      expr match {
+        case Add(left, right) =>
+          collect(left, negate) ++ collect(right, negate)
+        case Subtract(left, right) =>
+          collect(left, negate) ++ collect(right, !negate)
+        case TimeAdd(left, right, _) =>
+          collect(left, negate) ++ collect(right, negate)
+        case TimeSub(left, right, _) =>
+          collect(left, negate) ++ collect(right, !negate)
+        case UnaryMinus(child) =>
+          collect(child, !negate)
+        case CheckOverflow(child, _) =>
+          collect(child, negate)
+        case Cast(child, dataType, _) =>
+          dataType match {
+            case _: NumericType | _: TimestampType => collect(child, negate)
+            case _ =>
+              invalid = true
+              Seq.empty
+          }
+        case a: AttributeReference =>
+          val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a
+          Seq(negateIfNeeded(castedRef, negate))
+        case lit: Literal =>
+          // If the literal is a calendar interval, explicitly convert it to microseconds.
+          // Convert other numeric literals to doubles representing microseconds (by x1000000).
+          val castedLit = lit.dataType match {
+            case CalendarIntervalType =>
+              val calendarInterval = lit.value.asInstanceOf[CalendarInterval]
+              if (calendarInterval.months > 0) {
+                invalid = true
+                logWarning(
+                  s"Failed to extract state value watermark from condition $exprToCollectFrom " +
+                    s"as imprecise intervals like months and years cannot be used for " +
+                    s"watermark calculation. Use intervals in terms of days instead.")
+                Literal(0.0)
+              } else {
+                Literal(calendarInterval.microseconds.toDouble)
+              }
+            case DoubleType =>
+              Multiply(lit, Literal(1000000.0))
+            case _: NumericType =>
+              Multiply(Cast(lit, DoubleType), Literal(1000000.0))
+            case _: TimestampType =>
+              Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0))
+          }
+          Seq(negateIfNeeded(castedLit, negate))
+        case a @ _ =>
+          logWarning(
+            s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a")
+          invalid = true
+          Seq.empty
+      }
+    }
+
+    val terms = collect(exprToCollectFrom, negate = false)
+    if (!invalid) terms else Seq.empty
+  }
+}
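
To make the canonicalization above concrete: with a 10 second event-time watermark on `rightTime`, the conjunct `leftTime > rightTime - interval 3 seconds` canonicalizes to `(-leftTime) + rightTime + (-3000 ms) < 0`, which solves to `leftTime > watermark - 3000 ms`, i.e. a state value watermark of 7000 ms. A minimal standalone sketch of that arithmetic (hypothetical helper; the real code operates on Catalyst expressions, not plain numbers):

```scala
object StateWatermarkArithmetic {
  // For a conjunct `leftTime > rightTime + deltaMs` with rightTime watermarked,
  // left-side state with leftTime <= eventWatermarkMs + deltaMs can be dropped.
  def stateValueWatermarkMs(eventWatermarkMs: Long, deltaMs: Long): Long =
    eventWatermarkMs + deltaMs

  def main(args: Array[String]): Unit = {
    // These match the expectations in StreamingJoinHelperSuite below.
    println(stateValueWatermarkMs(10000L, -3000L)) // leftTime > rightTime - 3s => 7000
    println(stateValueWatermarkMs(10000L, 1000L))  // leftTime > rightTime + 1s => 11000
  }
}
```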

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index d1d7056..dee6fbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
@@ -217,7 +218,7 @@ object UnsupportedOperationChecker {
           throwError("dropDuplicates is not supported after aggregation on a " 
+
             "streaming DataFrame/Dataset")
 
-        case Join(left, right, joinType, _) =>
+        case Join(left, right, joinType, condition) =>
 
           joinType match {
 
@@ -233,16 +234,52 @@ object UnsupportedOperationChecker {
                 throwError("Full outer joins with streaming 
DataFrames/Datasets are not supported")
               }
 
-            case LeftOuter | LeftSemi | LeftAnti =>
+            case LeftSemi | LeftAnti =>
               if (right.isStreaming) {
-                throwError("Left outer/semi/anti joins with a streaming 
DataFrame/Dataset " +
-                    "on the right is not supported")
+                throwError("Left semi/anti joins with a streaming 
DataFrame/Dataset " +
+                    "on the right are not supported")
               }
 
+            // We support streaming left outer joins with static on the right always, and with
+            // stream on both sides under the appropriate conditions.
+            case LeftOuter =>
+              if (!left.isStreaming && right.isStreaming) {
+                throwError("Left outer join with a streaming DataFrame/Dataset 
" +
+                  "on the right and a static DataFrame/Dataset on the left is 
not supported")
+              } else if (left.isStreaming && right.isStreaming) {
+                val watermarkInJoinKeys = 
StreamingJoinHelper.isWatermarkInJoinKeys(subPlan)
+
+                val hasValidWatermarkRange =
+                  StreamingJoinHelper.getStateValueWatermark(
+                    left.outputSet, right.outputSet, condition, 
Some(1000000)).isDefined
+
+                if (!watermarkInJoinKeys && !hasValidWatermarkRange) {
+                  throwError("Stream-stream outer join between two streaming 
DataFrame/Datasets " +
+                    "is not supported without a watermark in the join keys, or 
a watermark on " +
+                    "the nullable side and an appropriate range condition")
+                }
+              }
+
+            // We support streaming right outer joins with static on the left always, and with
+            // stream on both sides under the appropriate conditions.
             case RightOuter =>
-              if (left.isStreaming) {
-                throwError("Right outer join with a streaming DataFrame/Dataset on the left is " +
-                    "not supported")
+              if (left.isStreaming && !right.isStreaming) {
+                throwError("Right outer join with a streaming DataFrame/Dataset on the left and " +
+                    "a static DataFrame/Dataset on the right is not supported")
+              } else if (left.isStreaming && right.isStreaming) {
+                val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan)
+
+                // Check if the nullable side has a watermark, and there's a range condition which
+                // implies a state value watermark on the first side.
+                val hasValidWatermarkRange =
+                    StreamingJoinHelper.getStateValueWatermark(
+                      right.outputSet, left.outputSet, condition, Some(1000000)).isDefined
+
+                if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) {
+                  throwError("Stream-stream outer join between two streaming DataFrame/Datasets " +
+                    "is not supported without a watermark in the join keys, or a watermark on " +
+                    "the nullable side and an appropriate range condition")
+                }
               }
 
             case NaturalJoin(_) | UsingJoin(_, _) =>
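
In user-facing terms, the check above admits two patterns, sketched below under stated assumptions (`left` and `right` are watermarked streaming Datasets with hypothetical column names; the rejected variant hits the "Stream-stream outer join ... is not supported" error at query start):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object OuterJoinPatterns {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("patterns").getOrCreate()
    val left = spark.readStream.format("rate").load()
      .toDF("leftTime", "leftValue").withWatermark("leftTime", "10 seconds")
    val right = spark.readStream.format("rate").load()
      .toDF("rightTime", "rightValue").withWatermark("rightTime", "10 seconds")

    // Pattern 1: watermark in the join keys (equi-join on a watermarked column).
    val byKeys = left.join(right, expr("leftTime = rightTime"), "leftOuter")

    // Pattern 2: watermark on the nullable side plus a range condition on event times.
    val byRange = left.join(right,
      expr("leftValue = rightValue AND leftTime > rightTime - interval 1 minute"), "leftOuter")

    // Rejected at analysis time: no watermark in the keys and no range condition.
    // left.join(right, expr("leftValue = rightValue"), "leftOuter")

    byRange.writeStream.format("console").start().awaitTermination()
  }
}
```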

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala
new file mode 100644
index 0000000..8cf41a0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet}
+import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation}
+import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType}
+
+class StreamingJoinHelperSuite extends AnalysisTest {
+
+  test("extract watermark from time condition") {
+    val attributesToFindConstraintFor = Seq(
+      AttributeReference("leftTime", TimestampType)(),
+      AttributeReference("leftOther", IntegerType)())
+    val metadataWithWatermark = new MetadataBuilder()
+      .putLong(EventTimeWatermark.delayKey, 1000)
+      .build()
+    val attributesWithWatermark = Seq(
+      AttributeReference("rightTime", TimestampType, metadata = 
metadataWithWatermark)(),
+      AttributeReference("rightOther", IntegerType)())
+
+    case class DummyLeafNode() extends LeafNode {
+      override def output: Seq[Attribute] =
+        attributesToFindConstraintFor ++ attributesWithWatermark
+    }
+
+    def watermarkFrom(
+        conditionStr: String,
+        rightWatermark: Option[Long] = Some(10000)): Option[Long] = {
+      val conditionExpr = Some(conditionStr).map { str =>
+        val plan =
+          Filter(
+            CatalystSqlParser.parseExpression(str),
+            DummyLeafNode())
+        val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan))
+        optimized.asInstanceOf[Filter].condition
+      }
+      StreamingJoinHelper.getStateValueWatermark(
+        AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark),
+        conditionExpr, rightWatermark)
+    }
+
+    // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark,
+    // then we cannot define a constraint on leftTime.
+    assert(watermarkFrom("leftTime > rightTime") === Some(10000))
+    assert(watermarkFrom("leftTime >= rightTime") === Some(9999))
+    assert(watermarkFrom("leftTime < rightTime") === None)
+    assert(watermarkFrom("leftTime <= rightTime") === None)
+    assert(watermarkFrom("rightTime > leftTime") === None)
+    assert(watermarkFrom("rightTime >= leftTime") === None)
+    assert(watermarkFrom("rightTime < leftTime") === Some(10000))
+    assert(watermarkFrom("rightTime <= leftTime") === Some(9999))
+
+    // Test type conversions
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") 
=== None)
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS 
DOUBLE)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") 
=== Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS 
FLOAT)") === Some(10000))
+    assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS 
STRING)") === None)
+
+    // Test with timestamp type + calendar interval on either side of equation
+    // Note: TimestampType and calendar intervals don't commute, so there are fewer valid
+    // combinations to test.
+    assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000))
+    assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000))
+    assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000))
+    assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000))
+    assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second")
+      === Some(12000))
+
+    // Test with casted long type + constants on either side of equation
+    // Note: long type and constants commute, so more combinations to test.
+    // -- Constants on the right
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)")
+      === Some(11000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1")
+      === Some(10100))
+    assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2")
+      === Some(10200))
+    // -- Constants on the left
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000))
+    assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000))
+    assert(watermarkFrom("CAST((leftTime  + interval 3 second) AS LONG) > CAST(rightTime AS LONG)")
+      === Some(7000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000))
+    assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500))
+    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0")
+      === Some(12000))
+    assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0")
+      === Some(10100))
+    // -- Constants on both sides, mixed types
+    assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1")
+      === Some(13000))
+
+    // Test multiple conditions, should return minimum watermark
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") ===
+      Some(7000))  // first condition wins
+    assert(watermarkFrom(
+      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") ===
+      Some(6000))  // second condition wins
+
+    // Test invalid comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None)      // non-time attributes
+    assert(watermarkFrom("leftOther > rightOther") === None)                  // non-time attributes
+    assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000))
+    assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None)  // non-time attributes
+    assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed
+
+    // Test static comparisons
+    assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000))
+
+    // Test non-positive results
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0))
+    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 11f48a3..e5057c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
+import org.apache.spark.unsafe.types.CalendarInterval
 
 /** A dummy command for testing unsupported operations. */
 case class DummyCommand() extends Command
@@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
   testBinaryOperationInStreamingPlan(
     "left outer join",
     _.join(_, joinType = LeftOuter),
-    streamStreamSupported = false,
     batchStreamSupported = false,
-    expectedMsg = "left outer/semi/anti joins")
+    streamStreamSupported = false,
+    expectedMsg = "outer join")
+
+  // Left outer joins: stream-stream allowed with join on watermark attribute
+  // Note that the attribute need not be watermarked on both sides.
+  assertSupportedInStreamingPlan(
+    s"left outer join with stream-stream relations and join on attribute with left watermark",
+    streamRelation.join(streamRelation, joinType = LeftOuter,
+      condition = Some(attributeWithWatermark === attribute)),
+    OutputMode.Append())
+  assertSupportedInStreamingPlan(
+    s"left outer join with stream-stream relations and join on attribute with right watermark",
+    streamRelation.join(streamRelation, joinType = LeftOuter,
+      condition = Some(attribute === attributeWithWatermark)),
+    OutputMode.Append())
+  assertNotSupportedInStreamingPlan(
+    s"left outer join with stream-stream relations and join on non-watermarked attribute",
+    streamRelation.join(streamRelation, joinType = LeftOuter,
+      condition = Some(attribute === attribute)),
+    OutputMode.Append(),
+    Seq("watermark in the join keys"))
+
+  // Left outer joins: stream-stream allowed with range condition yielding state value watermark
+  assertSupportedInStreamingPlan(
+    s"left outer join with stream-stream relations and state value watermark", {
+      val leftRelation = streamRelation
+      val rightTimeWithWatermark =
+        AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
+      val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
+      leftRelation.join(
+        rightRelation,
+        joinType = LeftOuter,
+        condition = Some(attribute > rightTimeWithWatermark + 10))
+    },
+    OutputMode.Append())
+
+  // Left outer joins: stream-stream not allowed with insufficient range condition
+  assertNotSupportedInStreamingPlan(
+    s"left outer join with stream-stream relations and insufficient range condition", {
+      val leftRelation = streamRelation
+      val rightTimeWithWatermark =
+        AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
+      val rightRelation = new TestStreamingRelation(rightTimeWithWatermark)
+      leftRelation.join(
+        rightRelation,
+        joinType = LeftOuter,
+        condition = Some(attribute < rightTimeWithWatermark + 10))
+    },
+    OutputMode.Append(),
+    Seq("appropriate range condition"))
 
   // Left semi joins: stream-* not allowed
   testBinaryOperationInStreamingPlan(
@@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     _.join(_, joinType = LeftSemi),
     streamStreamSupported = false,
     batchStreamSupported = false,
-    expectedMsg = "left outer/semi/anti joins")
+    expectedMsg = "left semi/anti joins")
 
   // Left anti joins: stream-* not allowed
   testBinaryOperationInStreamingPlan(
@@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     _.join(_, joinType = LeftAnti),
     streamStreamSupported = false,
     batchStreamSupported = false,
-    expectedMsg = "left outer/semi/anti joins")
+    expectedMsg = "left semi/anti joins")
 
   // Right outer joins: stream-* not allowed
   testBinaryOperationInStreamingPlan(
     "right outer join",
     _.join(_, joinType = RightOuter),
+    streamBatchSupported = false,
     streamStreamSupported = false,
-    streamBatchSupported = false)
+    expectedMsg = "outer join")
+
+  // Right outer joins: stream-stream allowed with join on watermark attribute
+  // Note that the attribute need not be watermarked on both sides.
+  assertSupportedInStreamingPlan(
+    s"right outer join with stream-stream relations and join on attribute with left watermark",
+    streamRelation.join(streamRelation, joinType = RightOuter,
+      condition = Some(attributeWithWatermark === attribute)),
+    OutputMode.Append())
+  assertSupportedInStreamingPlan(
+    s"right outer join with stream-stream relations and join on attribute with right watermark",
+    streamRelation.join(streamRelation, joinType = RightOuter,
+      condition = Some(attribute === attributeWithWatermark)),
+    OutputMode.Append())
+  assertNotSupportedInStreamingPlan(
+    s"right outer join with stream-stream relations and join on non-watermarked attribute",
+    streamRelation.join(streamRelation, joinType = RightOuter,
+      condition = Some(attribute === attribute)),
+    OutputMode.Append(),
+    Seq("watermark in the join keys"))
+
+  // Right outer joins: stream-stream allowed with range condition yielding state value watermark
+  assertSupportedInStreamingPlan(
+    s"right outer join with stream-stream relations and state value watermark", {
+      val leftTimeWithWatermark =
+        AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
+      val leftRelation = new TestStreamingRelation(leftTimeWithWatermark)
+      val rightRelation = streamRelation
+      leftRelation.join(
+        rightRelation,
+        joinType = RightOuter,
+        condition = Some(leftTimeWithWatermark + 10 < attribute))
+    },
+    OutputMode.Append())
+
+  // Right outer joins: stream-stream not allowed with insufficient range condition
+  assertNotSupportedInStreamingPlan(
+    s"right outer join with stream-stream relations and insufficient range condition", {
+      val leftTimeWithWatermark =
+        AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata)
+      val leftRelation = new TestStreamingRelation(leftTimeWithWatermark)
+      val rightRelation = streamRelation
+      leftRelation.join(
+        rightRelation,
+        joinType = RightOuter,
+        condition = Some(leftTimeWithWatermark + 10 > attribute))
+    },
+    OutputMode.Append(),
+    Seq("appropriate range condition"))
 
   // Cogroup: only batch-batch is allowed
   testBinaryOperationInStreamingPlan(

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 44f1fa5..9bd2127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -146,7 +146,14 @@ case class StreamingSymmetricHashJoinExec(
       stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right)
   }
 
-  require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+  private def throwBadJoinTypeException(): Nothing = {
+    throw new IllegalArgumentException(
+      s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+  }
+
+  require(
+    joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
+    s"${getClass.getSimpleName} should not take $joinType as the JoinType")
   require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
 
   private val storeConf = new StateStoreConf(sqlContext.conf)
@@ -157,11 +164,18 @@ case class StreamingSymmetricHashJoinExec(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def output: Seq[Attribute] = left.output ++ right.output
+  override def output: Seq[Attribute] = joinType match {
+    case _: InnerLike => left.output ++ right.output
+    case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
+    case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
+    case _ => throwBadJoinTypeException()
+  }
 
   override def outputPartitioning: Partitioning = joinType match {
     case _: InnerLike =>
       PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning))
+    case RightOuter => PartitioningCollection(Seq(right.outputPartitioning))
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -207,31 +221,108 @@ case class StreamingSymmetricHashJoinExec(
     //    matching new left input with new right input, since the new left input has become stored
     //    by that point. This tiny asymmetry is necessary to avoid duplication.
     val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
-      (inputRow: UnsafeRow, matchedRow: UnsafeRow) =>
-        joinedRow.withLeft(inputRow).withRight(matchedRow)
+      (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(input).withRight(matched)
     }
     val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) {
-      (inputRow: UnsafeRow, matchedRow: UnsafeRow) =>
-        joinedRow.withLeft(matchedRow).withRight(inputRow)
+      (input: UnsafeRow, matched: UnsafeRow) => joinedRow.withLeft(matched).withRight(input)
     }
 
     // Filter the joined rows based on the given condition.
-    val outputFilterFunction =
-      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
-    val filteredOutputIter =
-      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
-        numOutputRows += 1
-        row
-      }
+    val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
+
+    // We need to save the time that the inner join output iterator completes, since outer join
+    // output counts as both update and removal time.
+    var innerOutputCompletionTimeNs: Long = 0
+    def onInnerOutputCompletion = {
+      innerOutputCompletionTimeNs = System.nanoTime
+    }
+    val filteredInnerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
+      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction), onInnerOutputCompletion)
+
+    def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
+      rightSideJoiner.get(leftKeyValue.key).exists(rightValue => {
+          outputFilterFunction(
+            joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
+        })
+    }
+
+    def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
+      leftSideJoiner.get(rightKeyValue.key).exists(leftValue => {
+          outputFilterFunction(
+            joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
+        })
+    }
+
+    val outputIter: Iterator[InternalRow] = joinType match {
+      case Inner =>
+        filteredInnerOutputIter
+      case LeftOuter =>
+        // We generate the outer join output by:
+        // * Getting an iterator over the rows that have aged out on the left side. These rows are
+        //   candidates for being null joined. Note that to avoid doing two passes, this iterator
+        //   removes the rows from the state manager as they're processed.
+        // * Checking whether the current row matches a key in the right side state, and that key
+        //   has any value which satisfies the filter function when joined. If it doesn't,
+        //   we know we can join with null, since there was never (including this batch) a match
+        //   within the watermark period. If it does, there must have been a match at some point, so
+        //   we know we can't join with null.
+        val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+        val removedRowIter = leftSideJoiner.removeOldState()
+        val outerOutputIter = removedRowIter
+          .filterNot(pair => matchesWithRightSideState(pair))
+          .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+
+        filteredInnerOutputIter ++ outerOutputIter
+      case RightOuter =>
+        // See comments for the left outer case.
+        val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+        val removedRowIter = rightSideJoiner.removeOldState()
+        val outerOutputIter = removedRowIter
+          .filterNot(pair => matchesWithLeftSideState(pair))
+          .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+        filteredInnerOutputIter ++ outerOutputIter
+      case _ => throwBadJoinTypeException()
+    }
+
+    val outputIterWithMetrics = outputIter.map { row =>
+      numOutputRows += 1
+      row
+    }
 
     // Function to remove old state after all the input has been consumed and output generated
     def onOutputCompletion = {
+      // All processing time counts as update time.
       allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
 
-      // Remove old state if needed
+      // Processing time between inner output completion and here comes from the outer portion of a
+      // join, and thus counts as removal time as we remove old state from one side while iterating.
+      if (innerOutputCompletionTimeNs != 0) {
+        allRemovalsTimeMs +=
+          math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0)
+      }
+
       allRemovalsTimeMs += timeTakenMs {
-        leftSideJoiner.removeOldState()
-        rightSideJoiner.removeOldState()
+        // Remove any remaining state rows which aren't needed because they're below the watermark.
+        //
+        // For inner joins, we have to remove unnecessary state rows from both sides if possible.
+        // For outer joins, we have already removed unnecessary state rows from the outer side
+        // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we
+        // have to remove unnecessary state rows from the other side (e.g., right side for the left
+        // outer join) if possible. In all cases, nothing needs to be output, hence the removal
+        // needs to be done greedily by immediately consuming the returned iterator.
+        val cleanupIter = joinType match {
+          case Inner =>
+            leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+          case LeftOuter => rightSideJoiner.removeOldState()
+          case RightOuter => leftSideJoiner.removeOldState()
+          case _ => throwBadJoinTypeException()
+        }
+        while (cleanupIter.hasNext) {
+          cleanupIter.next()
+        }
       }
 
       // Commit all state changes and update state store metrics
@@ -251,7 +342,8 @@ case class StreamingSymmetricHashJoinExec(
       }
     }
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion)
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      outputIterWithMetrics, onOutputCompletion)
   }
 
   /**
@@ -324,14 +416,32 @@ case class StreamingSymmetricHashJoinExec(
       }
     }
 
-    /** Remove old buffered state rows using watermarks for state keys and values */
-    def removeOldState(): Unit = {
+    /**
+     * Get an iterator over the values stored in this joiner's state manager for the given key.
+     *
+     * Should not be interleaved with mutations.
+     */
+    def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+      joinStateManager.get(key)
+    }
+
+    /**
+     * Builds an iterator over old state key-value pairs, removing them lazily as they're produced.
+     *
+     * @note This iterator must be consumed fully before any other operations are made against
+     * this joiner's join state manager. For efficiency reasons, the intermediate states of the
+     * iterator leave the state manager in an undefined state.
+     *
+     * We do this to avoid requiring either two passes or full materialization when
+     * processing the rows for outer join.
+     */
+    def removeOldState(): Iterator[UnsafeRowPair] = {
       stateWatermarkPredicate match {
         case Some(JoinStateKeyWatermarkPredicate(expr)) =>
           joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
         case Some(JoinStateValueWatermarkPredicate(expr)) =>
           joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc)
-        case _ =>
+        case _ => Iterator.empty
       }
     }
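
To restate the LeftOuter branch above in isolation: inner matches are emitted first, then a null-joined row for each left-side row evicted by the state watermark that never found a match. A plain-Scala stand-in (illustrative names, not the patch's API; the real code also applies the join's filter condition to each candidate match):

```scala
object LeftOuterFlowSketch {
  case class Pair(key: Int, value: String)

  def leftOuterOutput(
      innerMatches: Iterator[(String, String)],  // rows already joined this batch
      agedOutLeft: Iterator[Pair],               // left rows evicted by the state watermark
      rightState: Map[Int, Seq[String]]): Iterator[(String, String)] = {
    val nullJoined = agedOutLeft
      .filterNot(p => rightState.contains(p.key))  // a match existed => no null row
      .map(p => (p.value, null: String))           // never matched => null right side
    innerMatches ++ nullJoined
  }

  def main(args: Array[String]): Unit = {
    val out = leftOuterOutput(
      Iterator(("l1", "r1")),
      Iterator(Pair(1, "l1"), Pair(2, "l2")),
      rightState = Map(1 -> Seq("r1")))
    out.foreach(println) // prints (l1,r1) then (l2,null)
  }
}
```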
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
index e50274a..64c7189 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala
@@ -23,6 +23,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.{Partition, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
 import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
 import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression
@@ -34,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
 /**
  * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details.
  */
-object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging {
+object StreamingSymmetricHashJoinHelper extends Logging {
 
   sealed trait JoinSide
   case object LeftSide extends JoinSide { override def toString(): String = "left" }
@@ -111,7 +112,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging {
         expr.map(JoinStateKeyWatermarkPredicate.apply _)
 
       } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
-        val stateValueWatermark = getStateValueWatermark(
+        val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
           attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
           attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
           condition,
@@ -133,242 +134,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging {
   }
 
   /**
-   * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it)
-   * given the join condition and the event time watermark. This is how it works.
-   * - The condition is split into conjunctive predicates, and we find the predicates of the
-   *   form `leftTime + c1 < rightTime + c2`   (or <=, >, >=).
-   * - We canoncalize the predicate and solve it with the event time watermark value to find the
-   *  value of the state watermark.
-   * This function is supposed to make best-effort attempt to get the state watermark. If there is
-   * any error, it will return None.
-   *
-   * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark
-   *                                         is to be calculated
-   * @param attributesWithEventWatermark  attributes of the other side which has a watermark column
-   * @param joinCondition                 join condition
-   * @param eventWatermark                watermark defined on the input event data
-   * @return state value watermark in milliseconds, is possible.
-   */
-  def getStateValueWatermark(
-      attributesToFindStateWatermarkFor: AttributeSet,
-      attributesWithEventWatermark: AttributeSet,
-      joinCondition: Option[Expression],
-      eventWatermark: Option[Long]): Option[Long] = {
-
-    // If condition or event time watermark is not provided, then cannot calculate state watermark
-    if (joinCondition.isEmpty || eventWatermark.isEmpty) return None
-
-    // If there is not watermark attribute, then cannot define state watermark
-    if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None
-
-    def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = {
-      try {
-        getStateWatermarkFromLessThenPredicate(
-          l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark)
-      } catch {
-        case NonFatal(e) =>
-          logWarning(s"Error trying to extract state constraint from condition $joinCondition", e)
-          None
-      }
-    }
-
-    val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate =>
-
-      // The generated the state watermark cleanup expression is inclusive of the state watermark.
-      // If state watermark is W, all state where timestamp <= W will be cleaned up.
-      // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean
-      // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below.
-      val stateWatermark = predicate match {
-        case LessThan(l, r) => getStateWatermarkSafely(l, r)
-        case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1)
-        case GreaterThan(l, r) => getStateWatermarkSafely(r, l)
-        case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1)
-        case _ => None
-      }
-      if (stateWatermark.nonEmpty) {
-        logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}")
-      }
-      stateWatermark
-    }
-    allStateWatermarks.reduceOption((x, y) => Math.min(x, y))
-  }
-
-  /**
-   * Extract the state value watermark (milliseconds) from the condition
-   * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for
-   * leftTime using the watermark on the rightTime. Example:
-   *
-   * Input:                 rightTime-with-watermark + c1 < leftTime + c2
-   * Canonical form:        rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0
-   * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime
-   * With watermark value:  watermark-value + c1 + (-c2) < leftTime
-   */
-  private def getStateWatermarkFromLessThenPredicate(
-      leftExpr: Expression,
-      rightExpr: Expression,
-      attributesToFindStateWatermarkFor: AttributeSet,
-      attributesWithEventWatermark: AttributeSet,
-      eventWatermark: Option[Long]): Option[Long] = {
-
-    val attributesInCondition = AttributeSet(
-      leftExpr.collect { case a: AttributeReference => a } ++
-      rightExpr.collect { case a: AttributeReference => a }
-    )
-    if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 ||
-        attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) {
-      // If more than attributes present in condition from one side, then it cannot be solved
-      return None
-    }
-
-    def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = {
-      e.collectLeaves().collectFirst {
-        case a @ AttributeReference(_, TimestampType, _, _)
-          if attributesToFindStateWatermarkFor.contains(a) => a
-      }.nonEmpty
-    }
-
-    // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0
-    val allOnLeftExpr = Subtract(leftExpr, rightExpr)
-    logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}")
-
-    // Canonicalization step 2: extract commutative terms
-    //    rightTime-with-watermark, c1, -leftTime, -c2
-    val terms = ExpressionSet(collectTerms(allOnLeftExpr))
-    logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t"))
-
-
-
-    // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor
-    val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor)
-
-    // Verify there is only one correct constraint term and of the correct type
-    if (constraintTerms.size > 1) {
-      logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" +
-        terms.mkString("\n\t"))
-      return None
-    }
-    if (constraintTerms.isEmpty) {
-      logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" +
-        terms.mkString("\n\t"))
-      return None
-    }
-    val constraintTerm = constraintTerms.head
-    if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) {
-      // Incorrect condition. We want the constraint term in canonical form to be `-leftTime`
-      // so that we can resolve it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`.
-      // Now, if the original condition is `rightTime-with-watermark > leftTime` and the watermark
-      // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about
-      // `leftTime` can be inferred. In this case, after canonicalization and collection of terms,
-      // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None.
-      return None
-    }
-
-    // Replace watermark attribute with watermark value, and generate the resolved expression
-    // from the other terms. That is,
-    // rightTime-with-watermark, c1, -c2  =>  watermark, c1, -c2  =>  watermark + c1 + (-c2)
-    logDebug(s"Constraint term from join condition:\t$constraintTerm")
-    val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term =>
-      term.transform {
-        case a @ AttributeReference(_, TimestampType, _, metadata)
-          if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) =>
-          Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0))
-      }
-    }.reduceLeft(Add)
-
-    // Calculate the constraint value
-    logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted")
-    val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double]
-    Some((Double2double(constraintValue) / 1000.0).toLong)
-  }
-
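To make the substitution above concrete, here is a hand-worked instance in plain Scala
rather than Catalyst expressions. The inputs are assumptions for illustration: condition
`rightTime + interval 5 seconds < leftTime` and an event watermark of 10 seconds on
rightTime. All terms are kept in microseconds as doubles, matching collectTerms below:

    val watermarkTermMicros = 10000.0 * 1000.0  // 10s watermark in ms, scaled to micros
    val intervalTermMicros  = 5.0 * 1000000.0   // "interval 5 seconds" scaled to micros
    // Canonical form: watermark + c1 + (-leftTime) < 0  ==>  leftTime > watermark + c1
    val stateWatermarkMs = ((watermarkTermMicros + intervalTermMicros) / 1000.0).toLong
    // stateWatermarkMs == 15000: left-side state with leftTime <= 15s can be cleaned up.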
-  /**
-   * Collect all the terms present in an expression after converting it into the form
-   * a + b + c + d, where each term is either an attribute or a literal cast to double,
-   * optionally wrapped in a unary minus.
-   */
-  private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = {
-    var invalid = false
-
-    /** Wrap a term with UnaryMinus if it needs to be negated. */
-    def negateIfNeeded(expr: Expression, minus: Boolean): Expression = {
-      if (minus) UnaryMinus(expr) else expr
-    }
-
-    /**
-     * Recursively split the expression into its leaf terms containing attributes or literals.
-     * Returns terms only of the forms:
-     *    Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)),
-     *    Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)),
-     *    Multiply(Literal), UnaryMinus(Multiply(Literal)),
-     *    Multiply(Cast(Literal)), UnaryMinus(Multiply(Cast(Literal)))
-     *
-     * Note:
-     * - If a term needs to be negated to make it a commutative term,
-     *   it will be wrapped in UnaryMinus(...).
-     * - Each term represents a timestamp value or a time interval in microseconds,
-     *   typed as a double.
-     */
-    def collect(expr: Expression, negate: Boolean): Seq[Expression] = {
-      expr match {
-        case Add(left, right) =>
-          collect(left, negate) ++ collect(right, negate)
-        case Subtract(left, right) =>
-          collect(left, negate) ++ collect(right, !negate)
-        case TimeAdd(left, right, _) =>
-          collect(left, negate) ++ collect(right, negate)
-        case TimeSub(left, right, _) =>
-          collect(left, negate) ++ collect(right, !negate)
-        case UnaryMinus(child) =>
-          collect(child, !negate)
-        case CheckOverflow(child, _) =>
-          collect(child, negate)
-        case Cast(child, dataType, _) =>
-          dataType match {
-            case _: NumericType | _: TimestampType => collect(child, negate)
-            case _ =>
-              invalid = true
-              Seq.empty
-          }
-        case a: AttributeReference =>
-          val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a
-          Seq(negateIfNeeded(castedRef, negate))
-        case lit: Literal =>
-          // If the literal is a calendar interval, explicitly convert it to microseconds.
-          // Convert other numeric literals to doubles representing microseconds (by x1000000).
-          val castedLit = lit.dataType match {
-            case CalendarIntervalType =>
-              val calendarInterval = lit.value.asInstanceOf[CalendarInterval]
-              if (calendarInterval.months > 0) {
-                invalid = true
-                logWarning(
-                  s"Failed to extract state value watermark from condition $exprToCollectFrom " +
-                    s"as imprecise intervals like months and years cannot be used for " +
-                    s"watermark calculation. Use intervals in terms of days instead.")
-                Literal(0.0)
-              } else {
-                Literal(calendarInterval.microseconds.toDouble)
-              }
-            case DoubleType =>
-              Multiply(lit, Literal(1000000.0))
-            case _: NumericType =>
-              Multiply(Cast(lit, DoubleType), Literal(1000000.0))
-            case _: TimestampType =>
-              Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0))
-          }
-          Seq(negateIfNeeded(castedLit, negate))
-        case a @ _ =>
-          logWarning(
-            s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a")
-          invalid = true
-          Seq.empty
-      }
-    }
-
-    val terms = collect(exprToCollectFrom, negate = false)
-    if (!invalid) terms else Seq.empty
-  }
-
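Since `collect` threads the pending sign down the tree, flipping it for the right child of
Subtract/TimeSub and under UnaryMinus, a toy version over a minimal AST may make the flow
easier to follow. This is a sketch in plain Scala, not the Catalyst types:

    sealed trait Expr
    case class Leaf(name: String) extends Expr
    case class Plus(l: Expr, r: Expr) extends Expr
    case class Minus(l: Expr, r: Expr) extends Expr
    case class Neg(e: Expr) extends Expr

    def terms(e: Expr, negate: Boolean): Seq[String] = e match {
      case Leaf(n)     => Seq(if (negate) s"-$n" else n)
      case Plus(l, r)  => terms(l, negate) ++ terms(r, negate)
      case Minus(l, r) => terms(l, negate) ++ terms(r, !negate) // right side flips sign
      case Neg(c)      => terms(c, !negate)
    }

    // (a - b) + (-c)  ==>  Seq("a", "-b", "-c")
    terms(Plus(Minus(Leaf("a"), Leaf("b")), Neg(Leaf("c"))), negate = false)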
-  /**
    * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks'
    * preferred location is based on which executors have the required join state stores already
    * loaded. This class is a modified version of [[ZippedPartitionsRDD2]].

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index 3764871..d256fb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager(
   /** Get all the values of a key */
   def get(key: UnsafeRow): Iterator[UnsafeRow] = {
     val numValues = keyToNumValues.get(key)
-    keyWithIndexToValue.getAll(key, numValues)
+    keyWithIndexToValue.getAll(key, numValues).map(_.value)
   }
 
   /** Append a new value to the key */
@@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager(
   }
 
   /**
-   * Remove using a predicate on keys. See class docs for more context and implementation details.
+   * Remove using a predicate on keys.
+   *
+   * This produces an iterator over the (key, value) pairs satisfying condition(key), where the
+   * underlying store is updated as a side-effect of producing next.
+   *
+   * This implies the iterator must be consumed fully, without any other operations on this manager
+   * or the underlying store being interleaved.
    */
-  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
-    val allKeyToNumValues = keyToNumValues.iterator
-
-    while (allKeyToNumValues.hasNext) {
-      val keyToNumValue = allKeyToNumValues.next
-      if (condition(keyToNumValue.key)) {
-        keyToNumValues.remove(keyToNumValue.key)
-        keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue)
+  def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
+
+      private val allKeyToNumValues = keyToNumValues.iterator
+
+      private var currentKeyToNumValue: KeyAndNumValues = null
+      private var currentValues: Iterator[KeyWithIndexAndValue] = null
+
+      private def currentKey = currentKeyToNumValue.key
+
+      private val reusedPair = new UnsafeRowPair()
+
+      private def getAndRemoveValue() = {
+        val keyWithIndexAndValue = currentValues.next()
+        keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex)
+        reusedPair.withRows(currentKey, keyWithIndexAndValue.value)
+      }
+
+      override def getNext(): UnsafeRowPair = {
+        // If there are more values for the current key, remove and return the next one.
+        if (currentValues != null && currentValues.hasNext) {
+          return getAndRemoveValue()
+        }
+
+        // If there weren't any values left, try and find the next key that satisfies the removal
+        // condition and has values.
+        while (allKeyToNumValues.hasNext) {
+          currentKeyToNumValue = allKeyToNumValues.next()
+          if (removalCondition(currentKey)) {
+            currentValues = keyWithIndexToValue.getAll(
+              currentKey, currentKeyToNumValue.numValue)
+            keyToNumValues.remove(currentKey)
+
+            if (currentValues.hasNext) {
+              return getAndRemoveValue()
+            }
+          }
+        }
+
+        // We only reach here if there were no satisfying keys left, which means we're done.
+        finished = true
+        return null
       }
+
+      override def close: Unit = {}
     }
   }
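Note the contract change here: removal now happens lazily, as a side effect of next(). A
caller that wants only the cleanup and not the removed rows must drain the iterator, which
is exactly what the updated tests further below do. A sketch of that pattern, assuming a
`manager` and a compiled `predicate` in scope:

    // Draining the lazy removal iterator; each next() deletes one (key, value) pair.
    val removed = manager.removeByKeyCondition(predicate)
    while (removed.hasNext) removed.next()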
 
   /**
-   * Remove using a predicate on values. See class docs for more context and implementation details.
+   * Remove using a predicate on values.
+   *
+   * At a high level, this produces an iterator over the (key, value) pairs such that value
+   * satisfies the predicate, where producing an element removes the value from the state store
+   * and producing all elements with a given key updates it accordingly.
+   *
+   * This implies the iterator must be consumed fully, without any other operations on this manager
+   * or the underlying store being interleaved.
    */
-  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
-    val allKeyToNumValues = keyToNumValues.iterator
+  def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = {
+    new NextIterator[UnsafeRowPair] {
 
-    while (allKeyToNumValues.hasNext) {
-      val keyToNumValue = allKeyToNumValues.next
-      val key = keyToNumValue.key
+      // Reuse this object to avoid creation+GC overhead.
+      private val reusedPair = new UnsafeRowPair()
 
-      var numValues: Long = keyToNumValue.numValue
-      var index: Long = 0L
-      var valueRemoved: Boolean = false
-      var valueForIndex: UnsafeRow = null
+      private val allKeyToNumValues = keyToNumValues.iterator
 
-      while (index < numValues) {
-        if (valueForIndex == null) {
-          valueForIndex = keyWithIndexToValue.get(key, index)
+      private var currentKey: UnsafeRow = null
+      private var numValues: Long = 0L
+      private var index: Long = 0L
+      private var valueRemoved: Boolean = false
+
+      // Push the data for the current key to the numValues store, and reset the tracking variables
+      // to their empty state.
+      private def updateNumValueForCurrentKey(): Unit = {
+        if (valueRemoved) {
+          if (numValues >= 1) {
+            keyToNumValues.put(currentKey, numValues)
+          } else {
+            keyToNumValues.remove(currentKey)
+          }
         }
-        if (condition(valueForIndex)) {
-          if (numValues > 1) {
-            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
-            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
-            keyWithIndexToValue.remove(key, numValues - 1)
-            valueForIndex = valueAtMaxIndex
+
+        currentKey = null
+        numValues = 0
+        index = 0
+        valueRemoved = false
+      }
+
+      // Find the next value satisfying the condition, updating `currentKey` and `numValues` if
+      // needed. Returns null when no value can be found.
+      private def findNextValueForIndex(): UnsafeRow = {
+        // Loop across all values for the current key, and then all other keys, until we find a
+        // value satisfying the removal condition.
+        def hasMoreValuesForCurrentKey = currentKey != null && index < numValues
+        def hasMoreKeys = allKeyToNumValues.hasNext
+        while (hasMoreValuesForCurrentKey || hasMoreKeys) {
+          if (hasMoreValuesForCurrentKey) {
+            // First search the values for the current key.
+            val currentValue = keyWithIndexToValue.get(currentKey, index)
+            if (removalCondition(currentValue)) {
+              return currentValue
+            } else {
+              index += 1
+            }
+          } else if (hasMoreKeys) {
+            // If we can't find a value for the current key, clean up and start looking at the next.
+            // This will also happen the first time the iterator is called.
+            updateNumValueForCurrentKey()
+
+            val currentKeyToNumValue = allKeyToNumValues.next()
+            currentKey = currentKeyToNumValue.key
+            numValues = currentKeyToNumValue.numValue
           } else {
-            keyWithIndexToValue.remove(key, 0)
-            valueForIndex = null
+            // Should be unreachable, but in any case means a value couldn't be found.
+            return null
           }
-          numValues -= 1
-          valueRemoved = true
-        } else {
-          valueForIndex = null
-          index += 1
         }
+
+        // We tried and failed to find the next value.
+        return null
       }
-      if (valueRemoved) {
-        if (numValues >= 1) {
-          keyToNumValues.put(key, numValues)
+
+      override def getNext(): UnsafeRowPair = {
+        val currentValue = findNextValueForIndex()
+
+        // If there's no value, clean up and finish. There aren't any more available.
+        if (currentValue == null) {
+          updateNumValueForCurrentKey()
+          finished = true
+          return null
+        }
+
+        // The backing store is arraylike - we as the caller are responsible for filling back in
+        // any hole. So we swap the last element into the hole and decrement numValues to shorten.
+        if (numValues > 1) {
+          val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1)
+          keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex)
+          keyWithIndexToValue.remove(currentKey, numValues - 1)
         } else {
-          keyToNumValues.remove(key)
+          keyWithIndexToValue.remove(currentKey, 0)
         }
+        numValues -= 1
+        valueRemoved = true
+
+        return reusedPair.withRows(currentKey, currentValue)
       }
-    }
-  }
 
-  def iterator(): Iterator[UnsafeRowPair] = {
-    val pair = new UnsafeRowPair()
-    keyWithIndexToValue.iterator.map { x =>
-      pair.withRows(x.key, x.value)
+      override def close: Unit = {}
     }
   }
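The hole-filling described in the comment above is the classic swap-with-last deletion for
array-like storage: O(1) per removal, at the cost of element order. A standalone sketch on
a plain buffer, not the state store API:

    import scala.collection.mutable.ArrayBuffer

    // Remove index i in O(1) by moving the last element into the hole.
    def removeAt[T](buf: ArrayBuffer[T], i: Int): Unit = {
      buf(i) = buf(buf.length - 1)  // fill the hole with the last element
      buf.remove(buf.length - 1)    // shrink by one; order is not preserved
    }

    val xs = ArrayBuffer(10, 20, 30, 40)
    removeAt(xs, 1)  // xs is now ArrayBuffer(10, 40, 30)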
 
@@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager(
       stateStore.get(keyWithIndexRow(key, valueIndex))
     }
 
-    /** Get all the values for key and all indices. */
-    def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = {
+    /**
+     * Get all values and indices for the provided key.
+     * Should not return null.
+     */
+    def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
+      val keyWithIndexAndValue = new KeyWithIndexAndValue()
       var index = 0
-      new NextIterator[UnsafeRow] {
-        override protected def getNext(): UnsafeRow = {
+      new NextIterator[KeyWithIndexAndValue] {
+        override protected def getNext(): KeyWithIndexAndValue = {
           if (index >= numValues) {
             finished = true
             null
           } else {
             val keyWithIndex = keyWithIndexRow(key, index)
             val value = stateStore.get(keyWithIndex)
+            keyWithIndexAndValue.withNew(key, index, value)
             index += 1
-            value
+            keyWithIndexAndValue
           }
         }
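One caveat worth noting: getAll now hands back a single reused KeyWithIndexAndValue (and
the removal iterators a single reused UnsafeRowPair), so a caller that buffers results must
copy rows before the next call to next() overwrites them. A hedged sketch of the defensive
pattern, assuming a `manager` and `key` in scope:

    // UnsafeRow.copy() materializes a stable copy; without it, every buffered entry
    // could alias the same reused holder object.
    val stableRows = manager.get(key).map(_.copy()).toArray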
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index ffa4c3c..d44af1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
         BoundReference(
           1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
         Literal(threshold))
-    manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+    val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+    while (iter.hasNext) iter.next()
   }
 
   /** Remove values where `time <= threshold` */
   def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
     val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
-    manager.removeByValueCondition(
+    val iter = manager.removeByValueCondition(
       GeneratePredicate.generate(expr, inputValueAttribs).eval _)
+    while (iter.hasNext) iter.next()
   }
 
   def numRows(implicit manager: SymmetricHashJoinStateManager): Long = {

