This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fcf8636105c0 [SPARK-55144][SS] Introduce new state format version for 
performant stream-stream join
fcf8636105c0 is described below

commit fcf8636105c060b0cc05712d503d4f6fe5a5dff5
Author: Jungtaek Lim <[email protected]>
AuthorDate: Fri Feb 27 13:58:53 2026 +0900

    [SPARK-55144][SS] Introduce new state format version for performant 
stream-stream join
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement the new state format for stream-stream join, 
based on the new event-time-aware state key encoding.
    
    The new state format focuses on eliminating the need for a full scan during 
eviction and while populating unmatched rows. The overhead of eviction should 
be bounded by the actual number of state rows to be evicted (indirectly 
impacted by the amount of watermark advancement), but the existing state format 
performs a full scan, which could take more than 2 seconds with 1,000,000 rows 
even when there are zero rows to evict. The overhead of eviction with the new 
state format wo [...]
    
    To achieve the above, we make a drastic change to the data structure, 
moving away from the logical array, and introduce a secondary index in addition 
to the main data.
    
    Each side of the join uses two (virtual) column families (4 column families 
in total), which are the following:
    
    * KeyWithTsToValuesStore
      * Primary data store
      * (key, event time) -> values
      * each element in values consists of (value, matched)
    * TsWithKeyTypeStore
      * Secondary index for efficient eviction
      * (event time, key) -> empty value (configured as multi-values)
      * numValues is derived from the number of elements on the value side; a 
new element is added whenever a new value is appended to values in the primary 
data store
        * This is to track the number of deleted rows accurately. It's 
optional, but the metric has been useful, so we want to keep it as it is.
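
    The two stores above can be modeled with plain in-memory collections. This 
is only an illustrative sketch of the layout, not the actual implementation 
(the real stores are virtual column families in RocksDB, and the helper names 
below are hypothetical):

```scala
import scala.collection.mutable

// Primary data store: (key, eventTimeMs) -> values, each value carrying a
// matched flag. Secondary index: (eventTimeMs, key) pairs, kept sorted so
// that everything below a watermark can be found without a full scan.
val primary = mutable.Map.empty[(String, Long), mutable.Buffer[(String, Boolean)]]
val secondary = mutable.TreeSet.empty[(Long, String)]

def append(key: String, eventTimeMs: Long, value: String, matched: Boolean): Unit = {
  primary.getOrElseUpdate((key, eventTimeMs), mutable.Buffer.empty) += ((value, matched))
  secondary += ((eventTimeMs, key))
}

// Evict all state with event time strictly below the watermark. Only the
// expired range of the secondary index is visited, and the returned count
// reflects the number of deleted values (the numValues bookkeeping above).
def evictByTimestamp(watermarkMs: Long): Long = {
  val expired = secondary.rangeUntil((watermarkMs, "")).toList
  var removed = 0L
  expired.foreach { case (ts, key) =>
    primary.remove((key, ts)).foreach(vs => removed += vs.size)
    secondary -= ((ts, key))
  }
  removed
}
```

    The point of the secondary index is visible in `evictByTimestamp`: eviction 
cost scales with the number of expired entries rather than with the total 
state size.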
    
    As the format of the key part implies, KeyWithTsToValuesStore will use 
`TimestampAsPostfixKeyStateEncoderSpec`, and TsWithKeyTypeStore will use 
`TimestampAsPrefixKeyStateEncoderSpec`.
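
    The property these encoder specs depend on is that the timestamp portion 
of the key sorts byte-wise in numeric order, so a sorted store can range-scan 
the timestamp-prefixed key space. A minimal sketch of that property, assuming 
non-negative timestamps (the real encoders also deal with column-family 
prefixes and other details):

```scala
import java.nio.ByteBuffer
import java.util.Arrays

// Encode a non-negative timestamp as 8 big-endian bytes: unsigned
// lexicographic byte order then matches numeric order, which is what makes
// a timestamp-prefixed key space range-scannable in a sorted store.
def encodeTs(ts: Long): Array[Byte] =
  ByteBuffer.allocate(8).putLong(ts).array()

// Unsigned byte-wise comparison, as a sorted key-value store would do.
def lexCompare(a: Array[Byte], b: Array[Byte]): Int =
  Arrays.compareUnsigned(a, b)
```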
    
    The granularity of the timestamp for event time is 1 millisecond, in line 
with the granularity of watermark advancement. This granularity can serve as a 
knob controlling the number of keys vs. the number of values per key, trading 
off the granularity of watermark-based eviction against the size of the key 
space (which may impact performance).
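
    As an illustration of that knob (hypothetical; this PR fixes the 
granularity at 1 millisecond), coarsening the timestamp component would simply 
bucket event times, collapsing more values under each key at the cost of 
coarser eviction:

```scala
// Hypothetical knob: truncate event time to a bucket boundary. A larger
// granularity shrinks the key space (fewer distinct (key, timestamp) pairs)
// but makes watermark-based eviction coarser.
def bucketTs(eventTimeMs: Long, granularityMs: Long): Long =
  eventTimeMs - (eventTimeMs % granularityMs)
```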
    
    There are several follow-ups to this state format implementation, which 
can be addressed on top of this PR:
    
    * further optimizations using RocksDB features: WriteBatch (for batched 
writes), MGET, etc.
    * retrieving matched rows within a "scope" of timestamps (in time-interval 
join)
      * while the format is ready to support an ordered scan of timestamps, 
this needs another state store API to define the range of keys to scan, which 
requires some effort
    * stop updating the matched flag for the non-outer join side
    
    ### Why are the changes needed?
    
    The cost of full-scan-based eviction is a severe obstacle to lowering the 
latency of stream-stream join. Also, the logic of maintaining the logical 
array is complicated to maintain, and its performance characteristics are less 
predictable given the behavior of deleting an element at a random index (the 
value at the last index is moved into the deleted index).
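
    The swap-delete step mentioned above can be sketched as follows (an 
illustrative model of removal in the old logical-array layout, not the actual 
code):

```scala
import scala.collection.mutable

// Delete the element at idx by moving the value at the last index into its
// slot and shrinking the array. O(1) per delete, but it reorders values and
// rewrites two state rows per deletion, making the cost hard to predict.
def swapRemove[T](values: mutable.ArrayBuffer[T], idx: Int): Unit = {
  values(idx) = values(values.length - 1)
  values.remove(values.length - 1)
}
```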
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. At this point, this state format is not integrated with the actual 
stream-stream join operator; follow-up work for integration is needed before 
the change becomes user-facing.
    
    ### How was this patch tested?
    
    New UT suites; the existing suite is also refactored to test both time 
window and time interval cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53930 from HeartSaVioR/SPARK-55144-on-top-of-SPARK-55129.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |   5 +-
 .../join/StreamingSymmetricHashJoinExec.scala      |  91 +-
 .../join/StreamingSymmetricHashJoinHelper.scala    |  18 +-
 ...reamingSymmetricHashJoinValueRowConverter.scala | 118 +++
 .../join/SymmetricHashJoinStateManager.scala       | 957 ++++++++++++++++++---
 .../operators/stateful/statefulOperators.scala     |  38 +
 .../state/SymmetricHashJoinStateManagerSuite.scala | 896 ++++++++++++++-----
 7 files changed, 1747 insertions(+), 376 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b2a2b7027394..1874ff195516 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3087,9 +3087,12 @@ object SQLConf {
       .doc("State format version used by streaming join operations in a 
streaming query. " +
         "State between versions are tend to be incompatible, so state format 
version shouldn't " +
         "be modified after running. Version 3 uses a single state store with 
virtual column " +
-        "families instead of four stores and is only supported with RocksDB.")
+        "families instead of four stores and is only supported with RocksDB. 
NOTE: version " +
+        "1 is DEPRECATED and should not be explicitly set by users.")
       .version("3.0.0")
       .intConf
+      // TODO: [SPARK-55628] Add version 4 once we integrate the state format 
version 4 into
+      //  stream-stream join operator.
       .checkValue(v => Set(1, 2, 3).contains(v), "Valid versions are 1, 2, and 
3")
       .createWithDefault(2)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index ef37185ce416..d8ad576bb68a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -433,7 +433,7 @@ case class StreamingSymmetricHashJoinExec(
         }
 
         val initIterFn = { () =>
-          val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
+          val removedRowIter = 
joinerManager.leftSideJoiner.removeAndReturnOldState()
           removedRowIter.filterNot { kv =>
             stateFormatVersion match {
               case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, 
kv.value))
@@ -459,7 +459,7 @@ case class StreamingSymmetricHashJoinExec(
         }
 
         val initIterFn = { () =>
-          val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
+          val removedRowIter = 
joinerManager.rightSideJoiner.removeAndReturnOldState()
           removedRowIter.filterNot { kv =>
             stateFormatVersion match {
               case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, 
kv.value))
@@ -484,13 +484,13 @@ case class StreamingSymmetricHashJoinExec(
           }
 
         val leftSideInitIterFn = { () =>
-          val removedRowIter = joinerManager.leftSideJoiner.removeOldState()
+          val removedRowIter = 
joinerManager.leftSideJoiner.removeAndReturnOldState()
           removedRowIter.filterNot(isKeyToValuePairMatched)
             .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
         }
 
         val rightSideInitIterFn = { () =>
-          val removedRowIter = joinerManager.rightSideJoiner.removeOldState()
+          val removedRowIter = 
joinerManager.rightSideJoiner.removeAndReturnOldState()
           removedRowIter.filterNot(isKeyToValuePairMatched)
             .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
         }
@@ -539,22 +539,19 @@ case class StreamingSymmetricHashJoinExec(
         // the outer side (e.g., left side for left outer join) while 
generating the outer "null"
         // outputs. Now, we have to remove unnecessary state rows from the 
other side (e.g., right
         // side for the left outer join) if possible. In all cases, nothing 
needs to be outputted,
-        // hence the removal needs to be done greedily by immediately 
consuming the returned
-        // iterator.
+        // hence the removal needs to be done greedily.
         //
         // For full outer joins, we have already removed unnecessary states 
from both sides, so
         // nothing needs to be outputted here.
-        val cleanupIter = joinType match {
-          case Inner | LeftSemi => joinerManager.removeOldState()
-          case LeftOuter => joinerManager.rightSideJoiner.removeOldState()
-          case RightOuter => joinerManager.leftSideJoiner.removeOldState()
-          case FullOuter => Iterator.empty
-          case _ => throwBadJoinTypeException()
-        }
-        while (cleanupIter.hasNext) {
-          cleanupIter.next()
-          numRemovedStateRows += 1
-        }
+        numRemovedStateRows += (
+          joinType match {
+            case Inner | LeftSemi => joinerManager.removeOldState()
+            case LeftOuter => joinerManager.rightSideJoiner.removeOldState()
+            case RightOuter => joinerManager.leftSideJoiner.removeOldState()
+            case FullOuter => 0L
+            case _ => throwBadJoinTypeException()
+          }
+        )
       }
 
       // Commit all state changes and update state store metrics
@@ -643,7 +640,7 @@ case class StreamingSymmetricHashJoinExec(
     private[this] val keyGenerator = UnsafeProjection.create(joinKeys, 
inputAttributes)
 
     private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate 
match {
-      case Some(JoinStateKeyWatermarkPredicate(expr)) =>
+      case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
         // inputSchema can be empty as expr should only have BoundReferences 
and does not require
         // the schema to generated predicate. See 
[[StreamingSymmetricHashJoinHelper]].
         Predicate.create(expr, Seq.empty).eval _
@@ -652,7 +649,7 @@ case class StreamingSymmetricHashJoinExec(
     }
 
     private[this] val stateValueWatermarkPredicateFunc = 
stateWatermarkPredicate match {
-      case Some(JoinStateValueWatermarkPredicate(expr)) =>
+      case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
         Predicate.create(expr, inputAttributes).eval _
       case _ =>
         Predicate.create(Literal(false), Seq.empty).eval _  // false = do not 
remove if no predicate
@@ -792,6 +789,32 @@ case class StreamingSymmetricHashJoinExec(
       joinStateManager.get(key)
     }
 
+    /**
+     * Remove the old state key-value pairs from this joiner's state manager 
based on the state
+     * watermark predicate, and return the number of removed rows.
+     */
+    def removeOldState(): Long = {
+      stateWatermarkPredicate match {
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+          joinStateManager match {
+            case s: SupportsEvictByCondition =>
+              s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)
+
+            case s: SupportsEvictByTimestamp =>
+              s.evictByTimestamp(stateWatermark)
+          }
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+          joinStateManager match {
+            case s: SupportsEvictByCondition =>
+              s.evictByValueCondition(stateValueWatermarkPredicateFunc)
+
+            case s: SupportsEvictByTimestamp =>
+              s.evictByTimestamp(stateWatermark)
+          }
+        case _ => 0L
+      }
+    }
+
     /**
      * Builds an iterator over old state key-value pairs, removing them lazily 
as they're produced.
      *
@@ -802,12 +825,24 @@ case class StreamingSymmetricHashJoinExec(
      * We do this to avoid requiring either two passes or full materialization 
when
      * processing the rows for outer join.
      */
-    def removeOldState(): Iterator[KeyToValuePair] = {
+    def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
       stateWatermarkPredicate match {
-        case Some(JoinStateKeyWatermarkPredicate(expr)) =>
-          joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
-        case Some(JoinStateValueWatermarkPredicate(expr)) =>
-          
joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc)
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+          joinStateManager match {
+            case s: SupportsEvictByCondition =>
+              s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)
+
+            case s: SupportsEvictByTimestamp =>
+              s.evictAndReturnByTimestamp(stateWatermark)
+          }
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+          joinStateManager match {
+            case s: SupportsEvictByCondition =>
+              
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)
+
+            case s: SupportsEvictByTimestamp =>
+              s.evictAndReturnByTimestamp(stateWatermark)
+          }
         case _ => Iterator.empty
       }
     }
@@ -836,8 +871,12 @@ case class StreamingSymmetricHashJoinExec(
   private case class OneSideHashJoinerManager(
       leftSideJoiner: OneSideHashJoiner, rightSideJoiner: OneSideHashJoiner) {
 
-    def removeOldState(): Iterator[KeyToValuePair] = {
-      leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+    def removeOldState(): Long = {
+      leftSideJoiner.removeOldState() + rightSideJoiner.removeOldState()
+    }
+
+    def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
+      leftSideJoiner.removeAndReturnOldState() ++ 
rightSideJoiner.removeAndReturnOldState()
     }
 
     def metrics: StateStoreMetrics = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
index 7b02a43cd5a9..a916b0d626d3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
@@ -46,12 +46,12 @@ object StreamingSymmetricHashJoinHelper extends Logging {
     override def toString: String = s"$desc: $expr"
   }
   /** Predicate for watermark on state keys */
-  case class JoinStateKeyWatermarkPredicate(expr: Expression)
+  case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: 
Long)
     extends JoinStateWatermarkPredicate {
     def desc: String = "key predicate"
   }
   /** Predicate for watermark on state values */
-  case class JoinStateValueWatermarkPredicate(expr: Expression)
+  case class JoinStateValueWatermarkPredicate(expr: Expression, 
stateWatermark: Long)
     extends JoinStateWatermarkPredicate {
     def desc: String = "value predicate"
   }
@@ -212,8 +212,11 @@ object StreamingSymmetricHashJoinHelper extends Logging {
           oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType,
           oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable)
         val expr = watermarkExpression(Some(keyExprWithWatermark), 
eventTimeWatermarkForEviction)
-        expr.map(JoinStateKeyWatermarkPredicate.apply _)
-
+        expr.map { e =>
+          // watermarkExpression only provides the expression when 
eventTimeWatermarkForEviction
+          // is defined
+          JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
+        }
       } else if (isWatermarkDefinedOnInput) { // case 2 in the 
StreamingSymmetricHashJoinExec docs
         val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
           attributesToFindStateWatermarkFor = 
AttributeSet(oneSideInputAttributes),
@@ -222,8 +225,11 @@ object StreamingSymmetricHashJoinHelper extends Logging {
           eventTimeWatermarkForEviction)
         val inputAttributeWithWatermark = 
oneSideInputAttributes.find(_.metadata.contains(delayKey))
         val expr = watermarkExpression(inputAttributeWithWatermark, 
stateValueWatermark)
-        expr.map(JoinStateValueWatermarkPredicate.apply _)
-
+        expr.map { e =>
+          // watermarkExpression only provides the expression when 
eventTimeWatermarkForEviction
+          // is defined
+          JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
+        }
       } else {
         None
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala
new file mode 100644
index 000000000000..b4258125bf89
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.operators.stateful.join
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Literal, UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.ValueAndMatchPair
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * Converter between the value row stored in state store and the (actual 
value, match) pair.
+ */
+trait StreamingSymmetricHashJoinValueRowConverter {
+  /** Defines the schema of the value row (the value side of K-V in state 
store). */
+  def valueAttributes: Seq[Attribute]
+
+  /**
+   * Convert the value row to (actual value, match) pair.
+   *
+   * NOTE: implementations should ensure the result row is NOT reused during 
execution, so
+   * that caller can safely read the value in any time.
+   */
+  def convertValue(value: UnsafeRow): ValueAndMatchPair
+
+  /**
+   * Build the value row from (actual value, match) pair. This is expected to 
be called just
+   * before storing to the state store.
+   *
+   * NOTE: depending on the implementation, the result row "may" be reused 
during execution
+   * (to avoid initialization of object), so the caller should ensure that the 
logic doesn't
+   * get affected by such behavior. Call copy() against the result row if 
needed.
+   */
+  def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
+}
+
+/**
+ * V1 implementation of the converter, which simply stores the actual value in 
state store and
+ * treats the match status as false. Note that only state format version 1 
uses this converter,
+ * and this is only for backward compatibility. There is known correctness 
issue for outer join
+ * with this converter - see SPARK-26154 for more details.
+ */
+class StreamingSymmetricHashJoinValueRowConverterFormatV1(
+    inputValueAttributes: Seq[Attribute]) extends 
StreamingSymmetricHashJoinValueRowConverter {
+  override val valueAttributes: Seq[Attribute] = inputValueAttributes
+
+  override def convertValue(value: UnsafeRow): ValueAndMatchPair =
+    if (value != null) ValueAndMatchPair(value, false) else null
+
+  override def convertToValueRow(value: UnsafeRow, matched: Boolean): 
UnsafeRow = value
+}
+
+/**
+ * V2 implementation of the converter, which adds an extra boolean field to 
store the match status
+ * in state store. This is the default implementation for state format version 
2 and above, which
+ * fixes the correctness issue for outer join in V1 implementation.
+ */
+class StreamingSymmetricHashJoinValueRowConverterFormatV2(
+    inputValueAttributes: Seq[Attribute]) extends 
StreamingSymmetricHashJoinValueRowConverter {
+  private val valueWithMatchedExprs = inputValueAttributes :+ Literal(true)
+  private val indexOrdinalInValueWithMatchedRow = inputValueAttributes.size
+
+  private val valueWithMatchedRowGenerator = 
UnsafeProjection.create(valueWithMatchedExprs,
+    inputValueAttributes)
+
+  override val valueAttributes: Seq[Attribute] = inputValueAttributes :+
+    AttributeReference("matched", BooleanType)()
+
+  // Projection to generate key row from (value + matched) row
+  private val valueRowGenerator = UnsafeProjection.create(
+    inputValueAttributes, valueAttributes)
+
+  override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
+    if (value != null) {
+      ValueAndMatchPair(valueRowGenerator(value).copy(),
+        value.getBoolean(indexOrdinalInValueWithMatchedRow))
+    } else {
+      null
+    }
+  }
+
+  override def convertToValueRow(value: UnsafeRow, matched: Boolean): 
UnsafeRow = {
+    val row = valueWithMatchedRowGenerator(value)
+    row.setBoolean(indexOrdinalInValueWithMatchedRow, matched)
+    row
+  }
+}
+
+/**
+ * The entry point to create the converter for value row in state store. The 
converter is created
+ * based on the state format version.
+ */
+object StreamingSymmetricHashJoinValueRowConverter {
+  def create(
+      inputValueAttributes: Seq[Attribute],
+      stateFormatVersion: Int): StreamingSymmetricHashJoinValueRowConverter = {
+    stateFormatVersion match {
+      case 1 => new 
StreamingSymmetricHashJoinValueRowConverterFormatV1(inputValueAttributes)
+      case 2 | 3 | 4 =>
+        new 
StreamingSymmetricHashJoinValueRowConverterFormatV2(inputValueAttributes)
+      case _ => throw new IllegalArgumentException ("Incorrect state format 
version! " +
+        s"version $stateFormatVersion")
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
index 49d08d52b3bf..9aa41e196659 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
@@ -27,16 +27,736 @@ import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, 
STATE_STORE_ID}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, JoinedRow, Literal, SafeProjection, 
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, JoinedRow, Literal, NamedExpression, 
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
 StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import 
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
 KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, 
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, 
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, 
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, 
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField, 
StructType}
+import 
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
 KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, 
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, 
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, 
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, 
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, 
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeySt [...]
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, 
StructField, StructType}
 import org.apache.spark.util.NextIterator
 
+/**
+ * Base trait of the state manager for stream-stream symmetric hash join 
operator.
+ *
+ * This defines the basic APIs for the state manager, except the methods for 
eviction which are
+ * defined in separate traits - See [[SupportsEvictByCondition]] and 
[[SupportsEvictByTimestamp]].
+ *
+ * Implementation classes are expected to inherit those traits as needed, 
depending on the eviction
+ * strategy they support.
+ */
+trait SymmetricHashJoinStateManager {
+  import SymmetricHashJoinStateManager._
+
+  /** Append a new value to the given key, with the flag of matched or not. */
+  def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+  /**
+   * Retrieve all matched values from given key. This doesn't update the 
matched flag for the
+   * values being returned, hence should be only used if appropriate for the 
join type and the
+   * side.
+   */
+  def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+  /**
+   * Retrieve all joined rows for the given key. The joined rows are generated 
with the provided
+   * generateJoinedRow function and filtered with the provided predicate.
+   *
+   * For excludeRowsAlreadyMatched = true, the method will only return the 
joined rows for the
+   * values which have not been marked as matched yet. The matched flag will 
be updated to true
+   * for the values being returned, if it is semantically required to do so.
+   *
+   * It is caller's responsibility to consume the whole iterator.
+   */
+  def getJoinedRows(
+      key: UnsafeRow,
+      generateJoinedRow: InternalRow => JoinedRow,
+      predicate: JoinedRow => Boolean,
+      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+  /**
+   * Provide all key-value pairs in the state manager.
+   *
+   * It is caller's responsibility to consume the whole iterator.
+   */
+  def iterator: Iterator[KeyToValuePair]
+
+  /** Commit all the changes to all the state stores */
+  def commit(): Unit
+
+  /** Abort any changes to the state stores if needed */
+  def abortIfNeeded(): Unit
+
+  /** Provide the metrics. */
+  def metrics: StateStoreMetrics
+
+  /**
+   * Get state store checkpoint information of the two state stores for this 
joiner, after
+   * they finished data processing.
+   */
+  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+/**
+ * This trait is specific to help the old version of state manager 
implementation (v1-v3) to work
+ * with existing tests which look up the state store with key with index.
+ */
+trait SupportsIndexedKeys {
+  def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: 
Long): Unit
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by 
condition.
+ * This is for the state manager implementations which have to perform full 
scan
+ * for eviction.
+ */
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+  import SymmetricHashJoinStateManager._
+
+  /** Evict the state via condition on the key. Returns the number of values 
evicted. */
+  def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+  /**
+   * Evict the state via condition on the key, and return the evicted 
key-value pairs.
+   *
+   * It is caller's responsibility to consume the whole iterator.
+   */
+  def evictAndReturnByKeyCondition(
+      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+  /** Evict the state via condition on the value. Returns the number of values 
evicted. */
+  def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+  /**
+   * Evict the state via condition on the value, and return the evicted 
key-value pairs.
+   *
+   * It is caller's responsibility to consume the whole iterator.
+   */
+  def evictAndReturnByValueCondition(
+      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by 
timestamp. This is for
+ * the state manager implementations which maintain the state with event time 
and can efficiently
+ * scan the keys with event time smaller than the given timestamp for eviction.
+ */
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+  import SymmetricHashJoinStateManager._
+
+  /** Evict the state by timestamp. Returns the number of values evicted. */
+  def evictByTimestamp(endTimestamp: Long): Long
+
+  /**
+   * Evict the state by timestamp and return the evicted key-value pairs.
+   *
+   * It is caller's responsibility to consume the whole iterator.
+   */
+  def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+/**
+ * The version 4 of stream-stream join state manager implementation, which is 
designed to optimize
+ * the eviction with watermark. Previous versions require full scan to find 
the keys to evict,
+ * while this version only scans the keys with event time smaller than the 
watermark.
+ *
+ * In this implementation, we no longer build a logical array of values; 
instead, we store the
+ * (key, timestamp) -> values in the primary store, and maintain a secondary 
index of
+ * (timestamp, key) to scan the keys to evict for each watermark. To retrieve 
the values for a key,
+ * we perform prefix scan with the key to get all the (key, timestamp) -> 
values.
+ *
+ * This implementation leverages (virtual) column family, and uses features 
which are only
+ * available in RocksDB state store provider. Please make sure to set the 
state store provider to
+ * RocksDBStateStoreProvider.
+ *
+ * Refer to the [[KeyWithTsToValuesStore]] and [[TsWithKeyTypeStore]] for more 
details.
+ */
+class SymmetricHashJoinStateManagerV4(
+    joinSide: JoinSide,
+    inputValueAttributes: Seq[Attribute],
+    joinKeys: Seq[Expression],
+    stateInfo: Option[StatefulOperatorStateInfo],
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration,
+    partitionId: Int,
+    keyToNumValuesStateStoreCkptId: Option[String],
+    keyWithIndexToValueStateStoreCkptId: Option[String],
+    stateFormatVersion: Int,
+    skippedNullValueCount: Option[SQLMetric] = None,
+    useStateStoreCoordinator: Boolean = true,
+    snapshotOptions: Option[SnapshotOptions] = None,
+    joinStoreGenerator: JoinStateManagerStoreGenerator)
+  extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with 
Logging {
+
+  // TODO: [SPARK-55729] Once the new state manager is integrated to 
stream-stream join operator,
+  //  we should support state data source to understand the state for state 
format version v4.
+
+  import SymmetricHashJoinStateManager._
+
+  protected val keySchema = StructType(
+    joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", 
k.dataType, k.nullable) })
+  protected val keyAttributes = toAttributes(keySchema)
+  private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+    inputValueAttributes,
+    // NOTE: This does not accept multiple event time columns. This differs 
from the operator,
+    // for which we offer backward compatibility, but it would involve too 
many layers to pass
+    // the information through. The information is available in SQLConf.
+    allowMultipleEventTimeColumns = false)
+
+  private val random = new scala.util.Random(System.currentTimeMillis())
+  private val bucketCountForNoEventTime = 1024
+  private val extractEventTimeFn: UnsafeRow => Long = { row =>
+    eventTimeColIdxOpt match {
+      case Some(idx) =>
+        val attr = inputValueAttributes(idx)
+
+        if (attr.dataType.isInstanceOf[StructType]) {
+          // NOTE: We assume this is a window struct, the same as in 
WatermarkSupport.watermarkExpression
+          row.getStruct(idx, 2).getLong(1)
+        } else {
+          row.getLong(idx)
+        }
+
+      case _ =>
+        // When the event time column is not available, we use a random 
bucketing strategy to
+        // decide where the new value will be stored. There is a trade-off 
between the bucket
+        // count and the number of values in each bucket; we can make the 
bucket count
+        // configurable if the current magic number turns out not to work well.
+        random.nextInt(bucketCountForNoEventTime)
+    }
+  }
+
+  private val eventTimeColIdxOptInKey: Option[Int] = {
+    joinKeys.zipWithIndex.collectFirst {
+      case (ne: NamedExpression, index)
+        if ne.metadata.contains(EventTimeWatermark.delayKey) => index
+    }
+  }
+
+  private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+    eventTimeColIdxOptInKey.map { idx =>
+      val attr = keyAttributes(idx)
+      if (attr.dataType.isInstanceOf[StructType]) {
+        // NOTE: We assume this is a window struct, the same as in 
WatermarkSupport.watermarkExpression
+        row.getStruct(idx, 2).getLong(1)
+      } else {
+        row.getLong(idx)
+      }
+    }
+  }
+
+  private val dummySchema = StructType(
+    Seq(StructField("dummy", NullType, nullable = true))
+  )
+
+  // TODO: [SPARK-55628] Below two fields need to be handled properly during 
integration with
+  //   the operator.
+  private val stateStoreCkptId: Option[String] = None
+  private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
+
+  private var stateStoreProvider: StateStoreProvider = _
+
+  // We use the dummy schema for the default CF since we register the actual 
CFs separately.
+  private val stateStore = getStateStore(
+    dummySchema, dummySchema, useVirtualColumnFamilies = true,
+    NoPrefixKeyStateEncoderSpec(dummySchema), useMultipleValuesPerKey = false
+  )
+
+  private def getStateStore(
+      keySchema: StructType,
+      valueSchema: StructType,
+      useVirtualColumnFamilies: Boolean,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useMultipleValuesPerKey: Boolean): StateStore = {
+    val storeName = StateStoreId.DEFAULT_STORE_NAME
+    val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId, 
storeName)
+    val store = if (useStateStoreCoordinator) {
+      assert(handlerSnapshotOptions.isEmpty, "Should not use state store 
coordinator " +
+        "when reading state as data source.")
+      joinStoreGenerator.getStore(
+        storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+        stateInfo.get.storeVersion, stateStoreCkptId, None, 
useVirtualColumnFamilies,
+        useMultipleValuesPerKey, storeConf, hadoopConf)
+    } else {
+      // This class will manage the state store provider by itself.
+      stateStoreProvider = StateStoreProvider.createAndInit(
+        storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+        useColumnFamilies = useVirtualColumnFamilies,
+        storeConf, hadoopConf, useMultipleValuesPerKey = 
useMultipleValuesPerKey,
+        stateSchemaProvider = None)
+      if (handlerSnapshotOptions.isDefined) {
+        if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
+          throw 
StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+            stateStoreProvider.getClass.toString)
+        }
+        val opts = handlerSnapshotOptions.get
+        stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay]
+          .replayStateFromSnapshot(
+            opts.snapshotVersion,
+            opts.endVersion,
+            readOnly = true,
+            opts.startStateStoreCkptId,
+            opts.endStateStoreCkptId)
+      } else {
+        stateStoreProvider.getStore(stateInfo.get.storeVersion, 
stateStoreCkptId)
+      }
+    }
+    logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}")
+    store
+  }
+
+  private val keyWithTsToValues = new KeyWithTsToValuesStore
+
+  private val tsWithKey = new TsWithKeyTypeStore
+
+  override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): 
Unit = {
+    val eventTime = extractEventTimeFn(value)
+    // We always do blind merge for appending new value.
+    keyWithTsToValues.append(key, eventTime, value, matched)
+    tsWithKey.add(eventTime, key)
+  }
+
+  override def getJoinedRows(
+      key: UnsafeRow,
+      generateJoinedRow: InternalRow => JoinedRow,
+      predicate: JoinedRow => Boolean,
+      excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
+    // TODO: [SPARK-55147] We could improve this method to narrow the range of 
timestamps and
+    //  scan keys more efficiently. For now, we just get all values for the 
key.
+    def getJoinedRowsFromTsAndValues(
+        ts: Long,
+        valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
+      new NextIterator[JoinedRow] {
+        private var currentIndex = 0
+
+        private var shouldUpdateValuesIntoStateStore = false
+
+        override protected def getNext(): JoinedRow = {
+          var ret: JoinedRow = null
+          while (ret == null && currentIndex < valuesAndMatched.length) {
+            val vmp = valuesAndMatched(currentIndex)
+
+            if (excludeRowsAlreadyMatched && vmp.matched) {
+              // Skip this one
+            } else {
+              val joinedRow = generateJoinedRow(vmp.value)
+              if (predicate(joinedRow)) {
+                if (!vmp.matched) {
+                  // Update the array to contain the value having matched = 
true
+                  valuesAndMatched(currentIndex) = vmp.copy(matched = true)
+                  // Need to update matched flag
+                  shouldUpdateValuesIntoStateStore = true
+                }
+
+                ret = joinedRow
+              } else {
+                // skip this one
+              }
+            }
+
+            currentIndex += 1
+          }
+
+          if (ret == null) {
+            assert(currentIndex == valuesAndMatched.length)
+            finished = true
+            null
+          } else {
+            ret
+          }
+        }
+
+        override protected def close(): Unit = {
+          if (shouldUpdateValuesIntoStateStore) {
+            // Update back to the state store
+            val updatedValuesWithMatched = valuesAndMatched.map { vmp =>
+              (vmp.value, vmp.matched)
+            }.toSeq
+            keyWithTsToValues.put(key, ts, updatedValuesWithMatched)
+          }
+        }
+      }
+    }
+
+    val ret = extractEventTimeFnFromKey(key) match {
+      case Some(ts) =>
+        val valuesAndMatchedIter = keyWithTsToValues.get(key, ts)
+        getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
+
+      case _ =>
+        keyWithTsToValues.getValues(key).flatMap { result =>
+          val ts = result.timestamp
+          val valuesAndMatched = result.values.toArray
+          getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
+        }
+    }
+    ret.filter(_ != null)
+  }
+
+  /**
+   * NOTE: The entry provided by Iterator.next() will be reused. It is the 
caller's responsibility
+   * to copy it properly if the caller needs to keep the reference after 
next() is called again.
+   */
+  override def iterator: Iterator[KeyToValuePair] = {
+    val reusableKeyToValuePair = KeyToValuePair()
+    keyWithTsToValues.iterator().map { kv =>
+      reusableKeyToValuePair.withNew(kv.key, kv.value, kv.matched)
+    }
+  }
+
+  override def evictByTimestamp(endTimestamp: Long): Long = {
+    var removed = 0L
+    tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
+      val key = evicted.key
+      val timestamp = evicted.timestamp
+      val numValues = evicted.numValues
+
+      // Remove from both primary and secondary stores
+      keyWithTsToValues.remove(key, timestamp)
+      tsWithKey.remove(key, timestamp)
+
+      removed += numValues
+    }
+    removed
+  }
+
+  override def evictAndReturnByTimestamp(endTimestamp: Long): 
Iterator[KeyToValuePair] = {
+    val reusableKeyToValuePair = KeyToValuePair()
+
+    tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
+      val key = evicted.key
+      val timestamp = evicted.timestamp
+      val values = keyWithTsToValues.get(key, timestamp)
+
+      // Remove from both primary and secondary stores
+      keyWithTsToValues.remove(key, timestamp)
+      tsWithKey.remove(key, timestamp)
+
+      values.map { value =>
+        reusableKeyToValuePair.withNew(key, value)
+      }
+    }
+  }
+
+  override def commit(): Unit = {
+    stateStore.commit()
+    logDebug("Committed, metrics = " + stateStore.metrics)
+  }
+
+  override def abortIfNeeded(): Unit = {
+    if (!stateStore.hasCommitted) {
+      logInfo(log"Aborted store ${MDC(STATE_STORE_ID, stateStore.id)}")
+      stateStore.abort()
+    }
+    // If this class manages a state store provider by itself, it should take 
care of closing
+    // provider instance as well.
+    if (stateStoreProvider != null) {
+      stateStoreProvider.close()
+    }
+  }
+
+  // Clean up any state store resources if necessary at the end of the task
+  Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => 
abortIfNeeded() } }
+
+  class GetValuesResult(var timestamp: Long = -1, var values: 
Seq[ValueAndMatchPair] = Seq.empty) {
+    def withNew(newTimestamp: Long, newValues: Seq[ValueAndMatchPair]): 
GetValuesResult = {
+      this.timestamp = newTimestamp
+      this.values = newValues
+      this
+    }
+  }
+
+  /**
+   * The primary store to store the key-value pairs.
+   *
+   * The state format of the primary store is as follows:
+   * [key][timestamp (event time)] -> [(value, matched), (value, matched), ...]
+   *
+   * The values are bucketed by event time to facilitate efficient eviction by 
watermark; the
+   * secondary index provides a way to scan the (key, timestamp) pairs for 
eviction, which makes
+   * it easy to retrieve/remove the values based on those (key, timestamp) 
pairs. There is no
+   * case where we evict only part of the values for the same (key, timestamp).
+   *
+   * The matched flag is used to indicate whether the value has been matched 
with any row from the
+   * other side.
+   */
+  private class KeyWithTsToValuesStore {
+    private val valueRowConverter = 
StreamingSymmetricHashJoinValueRowConverter.create(
+      inputValueAttributes, stateFormatVersion = 4)
+
+    // Set up virtual column family name in the store if it is being used
+    private val colFamilyName = getStateStoreName(joinSide, 
KeyWithTsToValuesType)
+
+    private val keySchemaWithTimestamp = 
TimestampKeyStateEncoder.keySchemaWithTimestamp(keySchema)
+    private val detachTimestampProjection: UnsafeProjection =
+      
TimestampKeyStateEncoder.getDetachTimestampProjection(keySchemaWithTimestamp)
+    private val attachTimestampProjection: UnsafeProjection =
+      TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)
+
+    // Create the specific column family in the store for this join side's 
KeyWithTsToValuesStore
+    stateStore.createColFamilyIfAbsent(
+      colFamilyName,
+      keySchema,
+      valueRowConverter.valueAttributes.toStructType,
+      TimestampAsPostfixKeyStateEncoderSpec(keySchemaWithTimestamp),
+      useMultipleValuesPerKey = true
+    )
+
+    private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
+      TimestampKeyStateEncoder.attachTimestamp(
+        attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
+    }
+
+    def append(key: UnsafeRow, timestamp: Long, value: UnsafeRow, matched: 
Boolean): Unit = {
+      val valueWithMatched = valueRowConverter.convertToValueRow(value, 
matched)
+      stateStore.merge(createKeyRow(key, timestamp), valueWithMatched, 
colFamilyName)
+    }
+
+    def put(
+        key: UnsafeRow,
+        timestamp: Long,
+        valuesWithMatched: Seq[(UnsafeRow, Boolean)]): Unit = {
+      // copy() is required because convertToValueRow reuses its internal 
UnsafeProjection output
+      // TODO: [SPARK-55732] StateStore.putList should allow iterator to be 
passed in, so that we
+      //  don't need to materialize the array and copy the values here.
+      val valuesToPut = valuesWithMatched.map { case (value, matched) =>
+        valueRowConverter.convertToValueRow(value, matched).copy()
+      }.toArray
+      stateStore.putList(createKeyRow(key, timestamp), valuesToPut, 
colFamilyName)
+    }
+
+    def get(key: UnsafeRow, timestamp: Long): Iterator[ValueAndMatchPair] = {
+      stateStore.valuesIterator(createKeyRow(key, timestamp), 
colFamilyName).map { valueRow =>
+        valueRowConverter.convertValue(valueRow)
+      }
+    }
+
+    // NOTE: We do not have a case where we remove only a part of the values. 
Even if that is
+    // needed, we handle it via put() by writing a new array.
+    def remove(key: UnsafeRow, timestamp: Long): Unit = {
+      stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
+    }
+
+    // NOTE: This assumes we consume the whole iterator to trigger completion.
+    def getValues(key: UnsafeRow): Iterator[GetValuesResult] = {
+      val reusableGetValuesResult = new GetValuesResult()
+
+      new NextIterator[GetValuesResult] {
+        private val iter = stateStore.prefixScanWithMultiValues(key, 
colFamilyName)
+
+        private var currentTs = -1L
+        private val valueAndMatchPairs = 
scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
+
+        @tailrec
+        override protected def getNext(): GetValuesResult = {
+          if (iter.hasNext) {
+            val unsafeRowPair = iter.next()
+
+            val ts = 
TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
+
+            if (currentTs == -1L) {
+              // First time
+              currentTs = ts
+            }
+
+            if (currentTs != ts) {
+              assert(valueAndMatchPairs.nonEmpty,
+                "timestamp has changed but no values collected from previous 
timestamp! " +
+                s"This should not happen. currentTs: $currentTs, new ts: $ts")
+
+              // Return previous batch
+              val result = reusableGetValuesResult.withNew(
+                currentTs, valueAndMatchPairs.toSeq)
+
+              // Reset for new timestamp
+              currentTs = ts
+              valueAndMatchPairs.clear()
+
+              // Add current value
+              val value = valueRowConverter.convertValue(unsafeRowPair.value)
+              valueAndMatchPairs += value
+              result
+            } else {
+              // Same timestamp, accumulate values
+              val value = valueRowConverter.convertValue(unsafeRowPair.value)
+              valueAndMatchPairs += value
+
+              // Continue to next
+              getNext()
+            }
+          } else {
+            if (currentTs != -1L) {
+              assert(valueAndMatchPairs.nonEmpty)
+
+              // Return last batch
+              val result = reusableGetValuesResult.withNew(
+                currentTs, valueAndMatchPairs.toSeq)
+
+              // Mark as finished
+              currentTs = -1L
+              valueAndMatchPairs.clear()
+              result
+            } else {
+              finished = true
+              null
+            }
+          }
+        }
+
+        override protected def close(): Unit = iter.close()
+      }
+    }
+
+    def iterator(): Iterator[KeyAndTsToValuePair] = {
+      val iter = stateStore.iteratorWithMultiValues(colFamilyName)
+      val reusableKeyAndTsToValuePair = KeyAndTsToValuePair()
+      iter.map { kv =>
+        val keyRow = detachTimestampProjection(kv.key)
+        val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key)
+        val value = valueRowConverter.convertValue(kv.value)
+
+        reusableKeyAndTsToValuePair.withNew(keyRow, ts, value)
+      }
+    }
+  }
+
+  /**
+   * The secondary index for efficient state removal with watermark.
+   *
+   * The state format of the secondary index is as follows:
+   * [timestamp (adjusted for ordering, 8 bytes)][key] -> [list of empty 
values]
+   *
+   * The value part tracks the number of values for the same (key, timestamp) 
in the primary
+   * store, so that we can count the number of values removed on each eviction 
and update
+   * metrics accordingly. Alternatively, we could maintain an integer count in 
the value part,
+   * but we found blind merge with counting later to be more efficient than 
read-modify-write.
+   */
+  private class TsWithKeyTypeStore {
+    private val valueStructType = StructType(Array(StructField("__dummy__", 
NullType)))
+    private val EMPTY_ROW =
+      
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+    // Set up virtual column family name in the store if it is being used
+    private val colFamilyName = getStateStoreName(joinSide, TsWithKeyType)
+
+    private val keySchemaWithTimestamp = 
TimestampKeyStateEncoder.keySchemaWithTimestamp(keySchema)
+    private val detachTimestampProjection: UnsafeProjection =
+      
TimestampKeyStateEncoder.getDetachTimestampProjection(keySchemaWithTimestamp)
+    private val attachTimestampProjection: UnsafeProjection =
+      TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)
+
+    // Create the specific column family in the store for this join side's 
TsWithKeyTypeStore
+    stateStore.createColFamilyIfAbsent(
+      colFamilyName,
+      keySchema,
+      valueStructType,
+      TimestampAsPrefixKeyStateEncoderSpec(keySchemaWithTimestamp),
+      useMultipleValuesPerKey = true
+    )
+
+    private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
+      TimestampKeyStateEncoder.attachTimestamp(
+        attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
+    }
+
+    def add(timestamp: Long, key: UnsafeRow): Unit = {
+      stateStore.merge(createKeyRow(key, timestamp), EMPTY_ROW, colFamilyName)
+    }
+
+    def remove(key: UnsafeRow, timestamp: Long): Unit = {
+      stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
+    }
+
+    case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: 
Int)
+
+    // NOTE: This assumes we consume the whole iterator to trigger completion.
+    def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
+      val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
+      new NextIterator[EvictedKeysResult]() {
+        var currentKeyRow: UnsafeRow = null
+        var currentEventTime: Long = -1L
+        var count: Int = 0
+        var isBeyondUpperBound: Boolean = false
+
+        override protected def getNext(): EvictedKeysResult = {
+          var ret: EvictedKeysResult = null
+          while (evictIterator.hasNext && ret == null && !isBeyondUpperBound) {
+            val kv = evictIterator.next()
+            val keyRow = detachTimestampProjection(kv.key)
+            val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key)
+
+            if (keyRow == currentKeyRow && ts == currentEventTime) {
+              // new value with same (key, ts)
+              count += 1
+            } else if (ts > endTimestamp) {
+              // we found the timestamp beyond the range - we shouldn't 
continue further
+              isBeyondUpperBound = true
+
+              // We don't need to construct the last (key, ts) into an 
EvictedKeysResult here -
+              // the code after the loop will handle that if there is a 
leftover. Hence we do
+              // not reset the current (key, ts) info here.
+            } else if (currentKeyRow == null && currentEventTime == -1L) {
+              // first value to process
+              currentKeyRow = keyRow.copy()
+              currentEventTime = ts
+              count = 1
+            } else {
+              // construct the last (key, ts) into EvictedKeysResult
+              ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
+
+              // register the next (key, ts) to process
+              currentKeyRow = keyRow.copy()
+              currentEventTime = ts
+              count = 1
+            }
+          }
+
+          if (ret != null) {
+            ret
+          } else if (count > 0) {
+            // there is a final leftover (key, ts) to return
+            ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
+
+            // we shouldn't continue further
+            currentKeyRow = null
+            currentEventTime = -1L
+            count = 0
+
+            ret
+          } else {
+            finished = true
+            null
+          }
+        }
+
+        override protected def close(): Unit = {
+          evictIterator.close()
+        }
+      }
+    }
+  }
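The run-length grouping in scanEvictedKeys above collapses consecutive entries with the same (key, timestamp) into a single result carrying a count, stopping at the first timestamp beyond the bound. A simplified standalone sketch of that grouping over an already-sorted sequence (hypothetical plain types, not the actual UnsafeRow-based code):

```scala
// Group consecutive (key, ts) entries up to an upper bound into (key, ts, count).
// Mirrors the shape of scanEvictedKeys: stop at the first ts beyond the bound,
// and emit the leftover group after the loop, if any.
def groupEvicted(
    entries: Seq[(String, Long)],
    endTimestamp: Long): Seq[(String, Long, Int)] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(String, Long, Int)]
  var current: Option[(String, Long)] = None
  var count = 0
  var stopped = false
  val it = entries.iterator
  while (it.hasNext && !stopped) {
    val (key, ts) = it.next()
    if (current.contains((key, ts))) {
      // another value for the same (key, ts) group
      count += 1
    } else if (ts > endTimestamp) {
      // entries are sorted by ts; nothing further can qualify
      stopped = true
    } else {
      // close the previous group and start a new one
      current.foreach { case (k, t) => out += ((k, t, count)) }
      current = Some((key, ts))
      count = 1
    }
  }
  // leftover group (also covers the case where we stopped at the bound)
  if (count > 0) current.foreach { case (k, t) => out += ((k, t, count)) }
  out.toSeq
}
```

The counts produced this way are what evictByTimestamp sums to report the number of evicted values without reading the primary store's value rows.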
+
+  override def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+    keyWithTsToValues.getValues(key).flatMap { result =>
+      result.values.map(_.value)
+    }
+  }
+
+  def metrics: StateStoreMetrics = stateStore.metrics
+
+  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo = {
+    val keyToNumValuesCkptInfo = stateStore.getStateStoreCheckpointInfo()
+    val keyWithIndexToValueCkptInfo = stateStore.getStateStoreCheckpointInfo()
+
+    assert(keyToNumValuesCkptInfo == keyWithIndexToValueCkptInfo)
+
+    JoinerStateStoreCkptInfo(keyToNumValuesCkptInfo, 
keyWithIndexToValueCkptInfo)
+  }
+}
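A note on the "[timestamp (adjusted for ordering, 8 bytes)]" layout above: for the lexicographic byte order of the encoded key to match numeric timestamp order, the timestamp bytes must be order-preserving. One common technique is big-endian encoding with the sign bit flipped; this is an illustrative assumption, not necessarily the exact encoding TimestampKeyStateEncoder uses:

```scala
// Order-preserving 8-byte encoding for signed Long timestamps: flipping the
// sign bit makes unsigned lexicographic byte order match signed numeric order.
object TsOrdering {
  def encodeTs(ts: Long): Array[Byte] = {
    val flipped = ts ^ Long.MinValue // flip the sign bit
    // big-endian: most significant byte first
    Array.tabulate(8)(i => ((flipped >>> (56 - 8 * i)) & 0xFF).toByte)
  }

  // Unsigned lexicographic comparison, as a byte-wise key comparator sees it.
  def unsignedCompare(a: Array[Byte], b: Array[Byte]): Int = {
    var i = 0
    while (i < a.length && i < b.length) {
      val cmp = java.lang.Integer.compare(a(i) & 0xFF, b(i) & 0xFF)
      if (cmp != 0) return cmp
      i += 1
    }
    a.length - b.length
  }
}
```

With such an encoding, a plain iterator over the secondary index column family yields entries in ascending event-time order, which is what scanEvictedKeys relies on to stop early at the watermark.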
+
 /**
  * Helper class to manage state required by a single side of
  * [[org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinExec]].
@@ -84,7 +804,7 @@ import org.apache.spark.util.NextIterator
  *          by overwriting with the value of (key, maxIndex), and removing 
[(key, maxIndex),
  *          decrement corresponding num values in KeyToNumValuesStore
  */
-abstract class SymmetricHashJoinStateManager(
+abstract class SymmetricHashJoinStateManagerBase(
     joinSide: JoinSide,
     inputValueAttributes: Seq[Attribute],
     joinKeys: Seq[Expression],
@@ -98,17 +818,22 @@ abstract class SymmetricHashJoinStateManager(
     skippedNullValueCount: Option[SQLMetric] = None,
     useStateStoreCoordinator: Boolean = true,
     snapshotOptions: Option[SnapshotOptions] = None,
-    joinStoreGenerator: JoinStateManagerStoreGenerator) extends Logging {
+    joinStoreGenerator: JoinStateManagerStoreGenerator)
+  extends SymmetricHashJoinStateManager
+  with SupportsEvictByCondition
+  with SupportsIndexedKeys
+  with Logging {
+
   import SymmetricHashJoinStateManager._
 
-  private[streaming] val keySchema = StructType(
+  protected val keySchema = StructType(
     joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", 
k.dataType, k.nullable) })
   protected val keyAttributes = toAttributes(keySchema)
 
-  private[streaming] val keyToNumValues = new KeyToNumValuesStore(
+  protected val keyToNumValues = new KeyToNumValuesStore(
     stateFormatVersion,
     snapshotOptions.map(_.getKeyToNumValuesHandlerOpts()))
-  private[streaming] val keyWithIndexToValue = new KeyWithIndexToValueStore(
+  protected val keyWithIndexToValue = new KeyWithIndexToValueStore(
     stateFormatVersion,
     snapshotOptions.map(_.getKeyWithIndexToValueHandlerOpts()))
 
@@ -161,6 +886,25 @@ abstract class SymmetricHashJoinStateManager(
     }.filter(_ != null)
   }
 
+  /** Remove using a predicate on keys. */
+  override def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): 
Long = {
+    var numRemoved = 0L
+    keyToNumValues.iterator.foreach { keyAndNumValues =>
+      val key = keyAndNumValues.key
+      if (removalCondition(key)) {
+        val numValue = keyAndNumValues.numValue
+
+        (0L until numValue).foreach { idx =>
+          keyWithIndexToValue.remove(key, idx)
+        }
+
+        numRemoved += numValue
+        keyToNumValues.remove(key)
+      }
+    }
+    numRemoved
+  }
+
   /**
    * Remove using a predicate on keys.
    *
@@ -170,7 +914,8 @@ abstract class SymmetricHashJoinStateManager(
    * This implies the iterator must be consumed fully without any other 
operations on this manager
    * or the underlying store being interleaved.
    */
-  def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): 
Iterator[KeyToValuePair] = {
+  override def evictAndReturnByKeyCondition(
+      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair] = {
     new NextIterator[KeyToValuePair] {
 
       private val allKeyToNumValues = keyToNumValues.iterator
@@ -268,6 +1013,14 @@ abstract class SymmetricHashJoinStateManager(
     }
   }
 
+  override def evictByValueCondition(removalCondition: UnsafeRow => Boolean): 
Long = {
+    var numRemoved = 0L
+    evictAndReturnByValueCondition(removalCondition).foreach { _ =>
+      numRemoved += 1
+    }
+    numRemoved
+  }
+
   /**
    * Remove using a predicate on values.
    *
@@ -278,7 +1031,8 @@ abstract class SymmetricHashJoinStateManager(
    * This implies the iterator must be consumed fully without any other 
operations on this manager
    * or the underlying store being interleaved.
    */
-  def removeByValueCondition(removalCondition: UnsafeRow => Boolean): 
Iterator[KeyToValuePair] = {
+  override def evictAndReturnByValueCondition(
+      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair] = {
     new NextIterator[KeyToValuePair] {
 
       // Reuse this object to avoid creation+GC overhead.
@@ -447,7 +1201,7 @@ abstract class SymmetricHashJoinStateManager(
    * NOTE: this function is only intended for use in unit tests
    * to simulate null values.
    */
-  private[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: 
Long): Unit = {
+  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: 
Long): Unit = {
     keyToNumValues.put(key, numValues)
   }
 
@@ -510,7 +1264,7 @@ abstract class SymmetricHashJoinStateManager(
         joinStoreGenerator.getStore(
           storeProviderId, keySchema, valueSchema, 
NoPrefixKeyStateEncoderSpec(keySchema),
           stateInfo.get.storeVersion, stateStoreCkptId, None, 
useVirtualColumnFamilies,
-          storeConf, hadoopConf)
+          useMultipleValuesPerKey = false, storeConf, hadoopConf)
       } else {
         // This class will manage the state store provider by itself.
         stateStoreProvider = StateStoreProvider.createAndInit(
@@ -551,7 +1305,6 @@ abstract class SymmetricHashJoinStateManager(
     }
   }
 
-
   /** A wrapper around a [[StateStore]] that stores [key -> number of values]. 
*/
   protected class KeyToNumValuesStore(
       val stateFormatVersion: Int,
@@ -648,78 +1401,6 @@ SnapshotOptions
     }
   }
 
-  private trait KeyWithIndexToValueRowConverter {
-    /** Defines the schema of the value row (the value side of K-V in state 
store). */
-    def valueAttributes: Seq[Attribute]
-
-    /**
-     * Convert the value row to (actual value, match) pair.
-     *
-     * NOTE: implementations should ensure the result row is NOT reused during 
execution, so
-     * that caller can safely read the value in any time.
-     */
-    def convertValue(value: UnsafeRow): ValueAndMatchPair
-
-    /**
-     * Build the value row from (actual value, match) pair. This is expected 
to be called just
-     * before storing to the state store.
-     *
-     * NOTE: depending on the implementation, the result row "may" be reused 
during execution
-     * (to avoid initialization of object), so the caller should ensure that 
the logic doesn't
-     * affect by such behavior. Call copy() against the result row if needed.
-     */
-    def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
-  }
-
-  private object KeyWithIndexToValueRowConverter {
-    def create(version: Int): KeyWithIndexToValueRowConverter = version match {
-      case 1 => new KeyWithIndexToValueRowConverterFormatV1()
-      case 2 | 3 => new KeyWithIndexToValueRowConverterFormatV2()
-      case _ => throw new IllegalArgumentException("Incorrect state format 
version! " +
-        s"version $version")
-    }
-  }
-
-  private class KeyWithIndexToValueRowConverterFormatV1 extends 
KeyWithIndexToValueRowConverter {
-    override val valueAttributes: Seq[Attribute] = inputValueAttributes
-
-    override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
-      if (value != null) ValueAndMatchPair(value, false) else null
-    }
-
-    override def convertToValueRow(value: UnsafeRow, matched: Boolean): 
UnsafeRow = value
-  }
-
-  private class KeyWithIndexToValueRowConverterFormatV2 extends 
KeyWithIndexToValueRowConverter {
-    private val valueWithMatchedExprs = inputValueAttributes :+ Literal(true)
-    private val indexOrdinalInValueWithMatchedRow = inputValueAttributes.size
-
-    private val valueWithMatchedRowGenerator = 
UnsafeProjection.create(valueWithMatchedExprs,
-      inputValueAttributes)
-
-    override val valueAttributes: Seq[Attribute] = inputValueAttributes :+
-      AttributeReference("matched", BooleanType)()
-
-    // Projection to generate key row from (value + matched) row
-    private val valueRowGenerator = UnsafeProjection.create(
-      inputValueAttributes, valueAttributes)
-
-    override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
-      if (value != null) {
-        ValueAndMatchPair(valueRowGenerator(value).copy(),
-          value.getBoolean(indexOrdinalInValueWithMatchedRow))
-      } else {
-        null
-      }
-    }
-
-    override def convertToValueRow(value: UnsafeRow, matched: Boolean): 
UnsafeRow = {
-      val row = valueWithMatchedRowGenerator(value)
-      row.setBoolean(indexOrdinalInValueWithMatchedRow, matched)
-      row
-    }
-  }
-
   /**
    * A wrapper around a [[StateStore]] that stores the mapping; the mapping 
depends on the
    * state format version - please refer implementations of 
[[KeyWithIndexToValueRowConverter]].
@@ -742,7 +1423,8 @@ SnapshotOptions
     private val keyRowGenerator = UnsafeProjection.create(
       keyAttributes, keyAttributes :+ AttributeReference("index", LongType)())
 
-    private val valueRowConverter = 
KeyWithIndexToValueRowConverter.create(stateFormatVersion)
+    private val valueRowConverter = StreamingSymmetricHashJoinValueRowConverter
+      .create(inputValueAttributes, stateFormatVersion)
 
     protected val stateStore = getStateStore(keyWithIndexSchema,
       valueRowConverter.valueAttributes.toStructType, useVirtualColumnFamilies)
@@ -869,11 +1551,12 @@ class SymmetricHashJoinStateManagerV1(
     skippedNullValueCount: Option[SQLMetric] = None,
     useStateStoreCoordinator: Boolean = true,
     snapshotOptions: Option[SnapshotOptions] = None,
-    joinStoreGenerator: JoinStateManagerStoreGenerator) extends 
SymmetricHashJoinStateManager(
-  joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
-  partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
-  stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, 
snapshotOptions,
-  joinStoreGenerator) {
+    joinStoreGenerator: JoinStateManagerStoreGenerator)
+  extends SymmetricHashJoinStateManagerBase(
+    joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+    partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
+    stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, 
snapshotOptions,
+    joinStoreGenerator) {
 
   /** Commit all the changes to all the state stores */
   override def commit(): Unit = {
@@ -948,11 +1631,12 @@ class SymmetricHashJoinStateManagerV2(
     skippedNullValueCount: Option[SQLMetric] = None,
     useStateStoreCoordinator: Boolean = true,
     snapshotOptions: Option[SnapshotOptions] = None,
-    joinStoreGenerator: JoinStateManagerStoreGenerator) extends 
SymmetricHashJoinStateManager(
-  joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
-  partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
-  stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, 
snapshotOptions,
-  joinStoreGenerator) {
+    joinStoreGenerator: JoinStateManagerStoreGenerator)
+  extends SymmetricHashJoinStateManagerBase(
+    joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
+    partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
+    stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, 
snapshotOptions,
+    joinStoreGenerator) {
 
   /** Commit all the changes to the state store */
   override def commit(): Unit = {
@@ -1001,6 +1685,7 @@ class JoinStateManagerStoreGenerator() extends Logging {
    * Creates the state store used for join operations, or returns the existing 
instance
    * if it has been previously created and virtual column families are enabled.
    */
+  // scalastyle:off argcount
   def getStore(
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
@@ -1010,6 +1695,7 @@ class JoinStateManagerStoreGenerator() extends Logging {
       stateStoreCkptId: Option[String],
       stateSchemaBroadcast: Option[StateSchemaBroadcast],
       useColumnFamilies: Boolean,
+      useMultipleValuesPerKey: Boolean,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
     if (useColumnFamilies) {
@@ -1019,7 +1705,7 @@ class JoinStateManagerStoreGenerator() extends Logging {
           StateStore.get(
             storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, 
version,
             stateStoreCkptId, stateSchemaBroadcast, useColumnFamilies = 
useColumnFamilies,
-            storeConf, hadoopConf
+            storeConf, hadoopConf, useMultipleValuesPerKey = 
useMultipleValuesPerKey
           )
         )
       }
@@ -1029,14 +1715,15 @@ class JoinStateManagerStoreGenerator() extends Logging {
       StateStore.get(
         storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, version,
         stateStoreCkptId, stateSchemaBroadcast, useColumnFamilies = 
useColumnFamilies,
-        storeConf, hadoopConf
+        storeConf, hadoopConf, useMultipleValuesPerKey = 
useMultipleValuesPerKey
       )
     }
   }
+  // scalastyle:on
 }
 
 object SymmetricHashJoinStateManager {
-  val supportedVersions = Seq(1, 2, 3)
+  val supportedVersions = Seq(1, 2, 3, 4)
   val legacyVersion = 1
 
   // scalastyle:off argcount
@@ -1056,7 +1743,14 @@ object SymmetricHashJoinStateManager {
       useStateStoreCoordinator: Boolean = true,
       snapshotOptions: Option[SnapshotOptions] = None,
       joinStoreGenerator: JoinStateManagerStoreGenerator): 
SymmetricHashJoinStateManager = {
-    if (stateFormatVersion == 3) {
+    if (stateFormatVersion == 4) {
+      new SymmetricHashJoinStateManagerV4(
+        joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, 
hadoopConf,
+        partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
+        stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, 
snapshotOptions,
+        joinStoreGenerator
+      )
+    } else if (stateFormatVersion == 3) {
       new SymmetricHashJoinStateManagerV2(
         joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, 
hadoopConf,
         partitionId, keyToNumValuesStateStoreCkptId, 
keyWithIndexToValueStateStoreCkptId,
@@ -1254,7 +1948,7 @@ object SymmetricHashJoinStateManager {
     }
   }
 
-  private[streaming] sealed trait StateStoreType
+  private[sql] sealed trait StateStoreType
 
   private[sql] case object KeyToNumValuesType extends StateStoreType {
     override def toString(): String = "keyToNumValues"
@@ -1264,7 +1958,15 @@ object SymmetricHashJoinStateManager {
     override def toString(): String = "keyWithIndexToValue"
   }
 
-  private[streaming] def getStateStoreName(
+  private[sql] case object KeyWithTsToValuesType extends StateStoreType {
+    override def toString(): String = "keyWithTsToValues"
+  }
+
+  private[sql] case object TsWithKeyType extends StateStoreType {
+    override def toString(): String = "tsWithKey"
+  }
+
+  private[join] def getStateStoreName(
       joinSide: JoinSide, storeType: StateStoreType): String = {
     s"$joinSide-$storeType"
   }
@@ -1277,7 +1979,9 @@ object SymmetricHashJoinStateManager {
       storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) {
       KeyWithIndexToValueType
     } else {
-      throw new IllegalArgumentException(s"Unknown join store name: 
$storeName")
+      // TODO: [SPARK-55628] Add support of KeyWithTsToValuesType and 
TsWithKeyType during
+      //  integration.
+      throw new IllegalArgumentException(s"Unsupported join store name: 
$storeName")
     }
   }
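
The store name above is simply the join side and the store type's `toString` joined by a hyphen, so the new V4 store types yield names like `left-keyWithTsToValues`. A minimal standalone sketch of that naming scheme (the string values mirror the `toString` overrides in the patch; the surrounding object is illustrative, not Spark's code):

```scala
object StoreNameDemo {
  sealed trait StateStoreType
  case object KeyWithTsToValuesType extends StateStoreType {
    override def toString: String = "keyWithTsToValues"
  }
  case object TsWithKeyType extends StateStoreType {
    override def toString: String = "tsWithKey"
  }

  // Mirrors getStateStoreName: "<joinSide>-<storeType>".
  def getStateStoreName(joinSide: String, storeType: StateStoreType): String =
    s"$joinSide-$storeType"

  def main(args: Array[String]): Unit = {
    assert(getStateStoreName("left", KeyWithTsToValuesType) == "left-keyWithTsToValues")
    assert(getStateStoreName("right", TsWithKeyType) == "right-tsWithKey")
  }
}
```

Note that `getStoreType` in the patch deliberately rejects these names for now; the TODO (SPARK-55628) wires them in during integration.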
 
@@ -1289,15 +1993,19 @@ object SymmetricHashJoinStateManager {
       colFamilyName: String,
       stateKeySchema: StructType,
       stateFormatVersion: Int): StatePartitionKeyExtractor = {
-    assert(stateFormatVersion <= 3, "State format version must be less than or 
equal to 3")
-    val name = if (stateFormatVersion == 3) colFamilyName else storeName
+    assert(stateFormatVersion <= 4, "State format version must be less than or 
equal to 4")
+    val name = if (stateFormatVersion >= 3) colFamilyName else storeName
     if (getStoreType(name) == KeyWithIndexToValueType) {
       // For KeyWithIndex, the index is added to the join (i.e. partition) key.
       // Drop the last field (index) to get the partition key
       new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, 
numLastColsToDrop = 1)
-    } else {
+    } else if (getStoreType(name) == KeyToNumValuesType) {
       // State key is the partition key
       new NoopStatePartitionKeyExtractor(stateKeySchema)
+    } else {
+      // TODO: [SPARK-55628] Add support of KeyWithTsToValuesType and 
TsWithKeyType during
+      //  integration.
+      throw new IllegalArgumentException(s"Unsupported join store name: 
$storeName")
     }
   }
 
@@ -1331,6 +2039,39 @@ object SymmetricHashJoinStateManager {
       this
     }
   }
+
+  /**
+   * Helper class for representing a mapping from (key, timestamp) to (value, matched).
+   * Designed for object reuse.
+   */
+  case class KeyAndTsToValuePair(
+      var key: UnsafeRow = null,
+      var timestamp: Long = -1L,
+      var value: UnsafeRow = null,
+      var matched: Boolean = false) {
+    def withNew(
+        newKey: UnsafeRow,
+        newTimestamp: Long,
+        newValue: UnsafeRow,
+        newMatched: Boolean): this.type = {
+      this.key = newKey
+      this.timestamp = newTimestamp
+      this.value = newValue
+      this.matched = newMatched
+      this
+    }
+
+    def withNew(
+        newKey: UnsafeRow,
+        newTimestamp: Long,
+        newValue: ValueAndMatchPair): this.type = {
+      this.key = newKey
+      this.timestamp = newTimestamp
+      this.value = newValue.value
+      this.matched = newValue.matched
+      this
+    }
+  }
 }
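
`KeyAndTsToValuePair` above is built for object reuse: iterators hand back the same mutable instance via `withNew` instead of allocating a holder per row. A minimal standalone sketch of this pattern (plain `String` stands in for `UnsafeRow`; the names here are illustrative, not Spark's API):

```scala
// Sketch of the mutable-holder reuse pattern used by KeyAndTsToValuePair.
case class ReusablePair(var key: String = null, var timestamp: Long = -1L) {
  def withNew(newKey: String, newTimestamp: Long): this.type = {
    this.key = newKey
    this.timestamp = newTimestamp
    this
  }
}

object ReuseDemo {
  def main(args: Array[String]): Unit = {
    val holder = ReusablePair()
    val data = Seq(("a", 1L), ("b", 2L))
    // One instance is reused across iterations; callers must copy out the
    // fields (as done here) if they need to retain a row beyond the current step.
    val seen = data.iterator
      .map { case (k, ts) => holder.withNew(k, ts) }
      .map(p => (p.key, p.timestamp))
      .toList
    assert(seen == List(("a", 1L), ("b", 2L)))
  }
}
```

The trade-off is the usual one for hot state-store paths: no per-row garbage, at the cost of callers having to copy rows they keep.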
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
index 6206e6832618..3e01a6e55d7b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
@@ -669,6 +669,7 @@ object WatermarkSupport {
       val eventTimeColsSet = eventTimeCols.map(_.exprId).toSet
       if (eventTimeColsSet.size > 1) {
         throw new AnalysisException(
+          // TODO: [SPARK-55731] Assign error class for _LEGACY_ERROR_TEMP_3077
           errorClass = "_LEGACY_ERROR_TEMP_3077",
           messageParameters = Map("eventTimeCols" -> 
eventTimeCols.mkString("(", ",", ")")))
       }
@@ -684,6 +685,43 @@ object WatermarkSupport {
     // pick the first element if exists
     eventTimeCols.headOption
   }
+
+  /**
+   * Find the index of the column which is marked as "event time" column.
+   *
+   * If there are multiple event time columns in the given column list, the behavior depends on
+   * the parameter `allowMultipleEventTimeColumns`. If it's set to true, the first occurring
+   * column will be returned. If not, this method will throw an AnalysisException, as it is not
+   * allowed to have multiple event time columns.
+   */
+  def findEventTimeColumnIndex(
+      attrs: Seq[Attribute],
+      allowMultipleEventTimeColumns: Boolean): Option[Int] = {
+    val eventTimeCols = attrs.zipWithIndex
+      .filter(_._1.metadata.contains(EventTimeWatermark.delayKey))
+    if (!allowMultipleEventTimeColumns) {
+      // There is a case where a projection leads the same column (same exprId) to appear more
+      // than once. Allowing such duplicates does not hurt the correctness of state row eviction,
+      // hence we start with allowing them.
+      val eventTimeColsSet = eventTimeCols.map(_._1.exprId).toSet
+      if (eventTimeColsSet.size > 1) {
+        throw new AnalysisException(
+          // TODO: [SPARK-55731] Assign error class for _LEGACY_ERROR_TEMP_3077
+          errorClass = "_LEGACY_ERROR_TEMP_3077",
+          messageParameters = Map("eventTimeCols" -> 
eventTimeCols.mkString("(", ",", ")")))
+      }
+
+      // With the above check, even if there are multiple columns in eventTimeCols, all of them
+      // must be the same.
+    } else {
+      // This is for compatibility with previous behavior - we allow multiple distinct event
+      // time columns and pick the first occurrence. This is incorrect if a non-first occurrence
+      // is not smaller than the first one, but we allow this as an "escape hatch" in case we
+      // would otherwise break an existing query.
+    // pick the first element if exists
+    eventTimeCols.headOption.map(_._2)
+  }
 }
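
The new `findEventTimeColumnIndex` helper scans the attribute list for the metadata key that marks a column as "event time" and returns its positional index. A simplified standalone model of the lookup (a `(name, isEventTime)` tuple stands in for Spark's `Attribute` with watermark metadata; this is a sketch, not Spark's implementation):

```scala
object EventTimeIndexDemo {
  // (name, isEventTime) stands in for an Attribute carrying EventTimeWatermark.delayKey metadata.
  def findEventTimeColumnIndex(attrs: Seq[(String, Boolean)]): Option[Int] = {
    val eventTimeCols = attrs.zipWithIndex.filter(_._1._2)
    // Pick the index of the first event time column, if any exists.
    eventTimeCols.headOption.map(_._2)
  }

  def main(args: Array[String]): Unit = {
    val attrs = Seq(("key", false), ("time", true), ("value", false))
    assert(findEventTimeColumnIndex(attrs) == Some(1))
    assert(findEventTimeColumnIndex(Seq(("key", false))).isEmpty)
  }
}
```

The real method additionally enforces uniqueness (by `exprId`) when `allowMultipleEventTimeColumns` is false, throwing an AnalysisException if distinct event time columns coexist.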
 
 /**
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index 2e46202b3b0b..fd601394f12d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -17,54 +17,256 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.io.File
 import java.sql.Timestamp
 import java.util.UUID
 
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.BeforeAndAfter
-import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, 
UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, Expression, GenericInternalRow, JoinedRow, 
LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
 StatefulOperatorsUtils, StatePartitionKeyExtractorFactory}
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.{JoinStateManagerStoreGenerator,
 SymmetricHashJoinStateManager}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.{JoinStateManagerStoreGenerator,
 SupportsEvictByCondition, SupportsEvictByTimestamp, SupportsIndexedKeys, 
SymmetricHashJoinStateManager}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.LeftSide
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.KeyToValuePair
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
-  with PrivateMethodTester {
-
+abstract class SymmetricHashJoinStateManagerBaseSuite extends StreamTest with 
BeforeAndAfter {
   before {
     SparkSession.setActiveSession(spark) // set this before force initializing 
'joinExec'
     spark.streams.stateStoreCoordinator // initialize the lazy coordinator
   }
 
-  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
-    test(s"StreamingJoinStateManager V${version} - all operations") {
-      testAllOperations(version)
+  protected def inputValueAttributes: Seq[AttributeReference]
+  protected def inputValueAttributesWithWatermark: AttributeReference
+  protected def joinKeyExpressions: Seq[Expression]
+
+  private def inputValueGen = 
UnsafeProjection.create(inputValueAttributes.map(_.dataType).toArray)
+  private def joinKeyGen = 
UnsafeProjection.create(joinKeyExpressions.map(_.dataType).toArray)
+
+  protected def toInputValue(key: Int, value: Int): UnsafeRow = {
+    inputValueGen.apply(new GenericInternalRow(Array[Any](key, value)))
+  }
+
+  protected def toJoinKeyRow(key: Int): UnsafeRow = {
+    joinKeyGen.apply(new GenericInternalRow(Array[Any](false, key, 10.0)))
+  }
+
+  protected def toValueInt(inputValueRow: UnsafeRow): Int = 
inputValueRow.getInt(1)
+
+  protected def append(key: Int, value: Int)
+                      (implicit manager: SymmetricHashJoinStateManager): Unit 
= {
+    // we only put matched = false for simplicity - StreamingJoinSuite will 
test the functionality
+    manager.append(toJoinKeyRow(key), toInputValue(key, value), matched = 
false)
+  }
+
+  protected def appendAndTest(key: Int, values: Int*)
+                             (implicit manager: 
SymmetricHashJoinStateManager): Unit = {
+    values.foreach { value => append(key, value)}
+    require(get(key) === values)
+  }
+
+  protected def getNumValues(key: Int)
+                            (implicit manager: SymmetricHashJoinStateManager): 
Int = {
+    manager.get(toJoinKeyRow(key)).size
+  }
+
+  protected def get(key: Int)(implicit manager: 
SymmetricHashJoinStateManager): Seq[Int] = {
+    manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
+  }
+
+  /** Remove keys (and corresponding values) where `time <= threshold` */
+  protected def removeByKey(threshold: Long)
+                           (implicit manager: SymmetricHashJoinStateManager): 
Long = {
+    manager match {
+      case evictByTimestamp: SupportsEvictByTimestamp =>
+        evictByTimestamp.evictByTimestamp(threshold)
+
+      case evictByCondition: SupportsEvictByCondition =>
+        val expr =
+          LessThanOrEqual(
+            BoundReference(
+              1,
+              inputValueAttributesWithWatermark.dataType,
+              inputValueAttributesWithWatermark.nullable),
+            Literal(threshold))
+        
evictByCondition.evictByKeyCondition(GeneratePredicate.generate(expr).eval _)
+    }
+  }
+
+  /** Remove keys (and corresponding values) where `time <= threshold` */
+  protected def removeAndReturnByKey(threshold: Long)(
+      implicit manager: SymmetricHashJoinStateManager): 
Iterator[KeyToValuePair] = {
+    manager match {
+      case evictByTimestamp: SupportsEvictByTimestamp =>
+        evictByTimestamp.evictAndReturnByTimestamp(threshold)
+
+      case evictByCondition: SupportsEvictByCondition =>
+        val expr =
+          LessThanOrEqual(
+            BoundReference(
+              1,
+              inputValueAttributesWithWatermark.dataType,
+              inputValueAttributesWithWatermark.nullable),
+            Literal(threshold))
+        
evictByCondition.evictAndReturnByKeyCondition(GeneratePredicate.generate(expr).eval
 _)
+    }
+  }
+
+  /** Remove values where `time <= threshold` */
+  protected def removeByValue(watermark: Long)
+                             (implicit manager: 
SymmetricHashJoinStateManager): Long = {
+    manager match {
+      case evictByTimestamp: SupportsEvictByTimestamp =>
+        evictByTimestamp.evictByTimestamp(watermark)
+
+      case evictByCondition: SupportsEvictByCondition =>
+        val expr = LessThanOrEqual(inputValueAttributesWithWatermark, 
Literal(watermark))
+        evictByCondition.evictByValueCondition(
+          GeneratePredicate.generate(expr, inputValueAttributes).eval _)
+    }
+  }
+
+  /** Remove values where `time <= threshold` */
+  protected def removeAndReturnByValue(watermark: Long)(
+      implicit manager: SymmetricHashJoinStateManager): 
Iterator[KeyToValuePair] = {
+    manager match {
+      case evictByTimestamp: SupportsEvictByTimestamp =>
+        evictByTimestamp.evictAndReturnByTimestamp(watermark)
+
+      case evictByCondition: SupportsEvictByCondition =>
+        val expr = LessThanOrEqual(inputValueAttributesWithWatermark, 
Literal(watermark))
+        evictByCondition.evictAndReturnByValueCondition(
+          GeneratePredicate.generate(expr, inputValueAttributes).eval _)
     }
   }
 
-  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
+  protected def assertNumRows(stateFormatVersion: Int, target: Long)(
+    implicit manager: SymmetricHashJoinStateManager): Unit = {
+    // This suite originally uses HDFSBackedStateStoreProvider, which provides instantaneous
+    // metrics for numRows.
+    // But for version 3 with virtual column families, RocksDBStateStoreProvider updates metrics
+    // asynchronously. This means the number of keys obtained from the metrics is very likely
+    // to be outdated right after a put/remove.
+    if (stateFormatVersion <= 2) {
+      assert(manager.metrics.numKeys == target)
+    }
+  }
+
+  protected def withJoinStateManager(
+      inputValueAttribs: Seq[Attribute],
+      joinKeyExprs: Seq[Expression],
+      stateFormatVersion: Int,
+      skipNullsForStreamStreamJoins: Boolean = false,
+      metric: Option[SQLMetric] = None)
+    (f: SymmetricHashJoinStateManager => Unit): Unit = {
+    // HDFS store providers do not support virtual column families
+    val storeProvider = if (stateFormatVersion >= 3) {
+      classOf[RocksDBStateStoreProvider].getName
+    } else {
+      classOf[HDFSBackedStateStoreProvider].getName
+    }
+    withTempDir { file =>
+      withSQLConf(
+        SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key ->
+          skipNullsForStreamStreamJoins.toString,
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> storeProvider
+      ) {
+        val storeConf = new StateStoreConf(spark.sessionState.conf)
+        val stateInfo = StatefulOperatorStateInfo(
+          file.getAbsolutePath, UUID.randomUUID, 0, 0, 5, None)
+        val manager = SymmetricHashJoinStateManager(
+          LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), 
storeConf, new Configuration,
+          partitionId = 0, None, None, stateFormatVersion, metric,
+          joinStoreGenerator = new JoinStateManagerStoreGenerator())
+        try {
+          f(manager)
+        } finally {
+          manager.abortIfNeeded()
+        }
+      }
+    }
+    StateStore.stop()
+  }
+
+  protected def withJoinStateManagerWithCheckpointDir(
+      inputValueAttribs: Seq[Attribute],
+      joinKeyExprs: Seq[Expression],
+      stateFormatVersion: Int,
+      checkpointDir: File,
+      storeVersion: Int,
+      changelogCheckpoint: Boolean,
+      skipNullsForStreamStreamJoins: Boolean = false,
+      metric: Option[SQLMetric] = None)
+    (f: SymmetricHashJoinStateManager => Unit): Unit = {
+    // HDFS store providers do not support virtual column families
+    val storeProvider = if (stateFormatVersion >= 3) {
+      classOf[RocksDBStateStoreProvider].getName
+    } else {
+      classOf[HDFSBackedStateStoreProvider].getName
+    }
+    withSQLConf(
+      SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key ->
+        skipNullsForStreamStreamJoins.toString,
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> storeProvider,
+      "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" 
->
+        changelogCheckpoint.toString
+    ) {
+      val storeConf = new StateStoreConf(spark.sessionState.conf)
+      val stateInfo = StatefulOperatorStateInfo(
+        checkpointDir.getAbsolutePath, UUID.randomUUID, 0, storeVersion, 5, 
None)
+      val manager = SymmetricHashJoinStateManager(
+        LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, 
new Configuration,
+        partitionId = 0, None, None, stateFormatVersion, metric,
+        joinStoreGenerator = new JoinStateManagerStoreGenerator())
+      try {
+        f(manager)
+      } finally {
+        manager.abortIfNeeded()
+      }
+    }
+    StateStore.stop()
+  }
+}
+
+class SymmetricHashJoinStateManagerSuite extends 
SymmetricHashJoinStateManagerBaseSuite {
+  private val watermarkMetadata = new MetadataBuilder()
+    .putLong(EventTimeWatermark.delayKey, 10).build()
+  private val inputValueSchema = new StructType()
+    .add(StructField("key", IntegerType))
+    .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+
+  override protected val inputValueAttributes = toAttributes(inputValueSchema)
+  override protected val inputValueAttributesWithWatermark = 
inputValueAttributes(1)
+  override protected val joinKeyExpressions = Seq[Expression](
+    Literal(false), inputValueAttributes(0), Literal(10.0))
+
+  // V4 is excluded because it does not use indexed values 
(SupportsIndexedKeys) and therefore
+  // cannot have the null-hole problem that these tests exercise via 
updateNumValuesTestOnly.
+  // V4 is covered by EventTimeInKeySuite and EventTimeInValueSuite for 
standard operations.
+  private val versionsInTest = Seq(1, 2, 3)
+
+  versionsInTest.foreach { version =>
     test(s"StreamingJoinStateManager V${version} - all operations with nulls") 
{
       testAllOperationsWithNulls(version)
     }
   }
 
-  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
+  versionsInTest.foreach { version =>
     test(s"StreamingJoinStateManager V${version} - all operations with nulls 
in middle") {
       testAllOperationsWithNullsInMiddle(version)
     }
   }
 
-  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
+  versionsInTest.foreach { version =>
     test(s"SPARK-35689: StreamingJoinStateManager V${version} - " +
         "printable key of keyWithIndexToValue") {
 
@@ -75,83 +277,23 @@ class SymmetricHashJoinStateManagerSuite extends 
StreamTest with BeforeAndAfter
         Literal(Timestamp.valueOf("2021-6-8 10:25:50")))
       val keyGen = UnsafeProjection.create(keyExprs.map(_.dataType).toArray)
 
-      withJoinStateManager(inputValueAttribs, keyExprs, version) { manager =>
+      withJoinStateManager(inputValueAttributes, keyExprs, version) { manager 
=>
+        assert(manager.isInstanceOf[SupportsIndexedKeys])
+
         val currentKey = keyGen.apply(new GenericInternalRow(Array[Any](
           false, 10.0, UTF8String.fromString("string"),
           Timestamp.valueOf("2021-6-8 10:25:50").getTime)))
 
-        val projectedRow = manager.getInternalRowOfKeyWithIndex(currentKey)
+        val projectedRow = manager.asInstanceOf[SupportsIndexedKeys]
+          .getInternalRowOfKeyWithIndex(currentKey)
         assert(s"$projectedRow" == "[false,10.0,string,1623173150000]")
       }
     }
   }
 
-  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
-    test(s"Partition key extraction - SymmetricHashJoinStateManager 
v$version") {
-      testPartitionKeyExtraction(version)
-    }
-  }
-
-  private def testAllOperations(stateFormatVersion: Int): Unit = {
-    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) 
{ manager =>
-      implicit val mgr = manager
-
-      assert(get(20) === Seq.empty)     // initially empty
-      append(20, 2)
-      assert(get(20) === Seq(2))        // should first value correctly
-      assertNumRows(stateFormatVersion, 1)
-
-      append(20, 3)
-      assert(get(20) === Seq(2, 3))     // should append new values
-      append(20, 3)
-      assert(get(20) === Seq(2, 3, 3))  // should append another copy if same 
value added again
-      assertNumRows(stateFormatVersion, 3)
-
-      assert(get(30) === Seq.empty)
-      append(30, 1)
-      assert(get(30) === Seq(1))
-      assert(get(20) === Seq(2, 3, 3))  // add another key-value should not 
affect existing ones
-      assertNumRows(stateFormatVersion, 4)
-
-      removeByKey(25)
-      assert(get(20) === Seq.empty)
-      assert(get(30) === Seq(1))        // should remove 20, not 30
-      assertNumRows(stateFormatVersion, 1)
-
-      removeByKey(30)
-      assert(get(30) === Seq.empty)     // should remove 30
-      assertNumRows(stateFormatVersion, 0)
-
-      appendAndTest(40, 100, 200, 300)
-      appendAndTest(50, 125)
-      appendAndTest(60, 275)              // prepare for testing removeByValue
-      assertNumRows(stateFormatVersion, 5)
-
-      removeByValue(125)
-      assert(get(40) === Seq(200, 300))
-      assert(get(50) === Seq.empty)
-      assert(get(60) === Seq(275))        // should remove only some values, 
not all
-      assertNumRows(stateFormatVersion, 3)
-
-      append(40, 50)
-      assert(get(40) === Seq(50, 200, 300))
-      assertNumRows(stateFormatVersion, 4)
-
-      removeByValue(200)
-      assert(get(40) === Seq(300))
-      assert(get(60) === Seq(275))        // should remove only some values, 
not all
-      assertNumRows(stateFormatVersion, 2)
-
-      removeByValue(300)
-      assert(get(40) === Seq.empty)
-      assert(get(60) === Seq.empty)       // should remove all values now
-      assertNumRows(stateFormatVersion, 0)
-    }
-  }
-
   /* Test removeByValue with nulls simulated by updating numValues on the 
state manager */
   private def testAllOperationsWithNulls(stateFormatVersion: Int): Unit = {
-    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) 
{ manager =>
+    withJoinStateManager(inputValueAttributes, joinKeyExpressions, 
stateFormatVersion) { manager =>
       implicit val mgr = manager
 
       appendAndTest(40, 100, 200, 300)
@@ -189,7 +331,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest 
with BeforeAndAfter
   private def testAllOperationsWithNullsInMiddle(stateFormatVersion: Int): 
Unit = {
     // Test with skipNullsForStreamStreamJoins set to false which would throw a
     // NullPointerException while iterating and also return null values as 
part of get
-    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) 
{ manager =>
+    withJoinStateManager(inputValueAttributes, joinKeyExpressions, 
stateFormatVersion) { manager =>
       implicit val mgr = manager
 
       val ex = intercept[Exception] {
@@ -214,7 +356,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest 
with BeforeAndAfter
     // Test with skipNullsForStreamStreamJoins set to true which would skip 
nulls
     // and continue iterating as part of removeByValue as well as get
     val metric = new SQLMetric("sum")
-    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion, 
true,
+    withJoinStateManager(inputValueAttributes, joinKeyExpressions, 
stateFormatVersion, true,
         Some(metric)) { manager =>
       implicit val mgr = manager
 
@@ -245,200 +387,484 @@ class SymmetricHashJoinStateManagerSuite extends 
StreamTest with BeforeAndAfter
     }
   }
 
-  val watermarkMetadata = new 
MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
-  val inputValueSchema = new StructType()
-    .add(StructField("time", IntegerType, metadata = watermarkMetadata))
-    .add(StructField("value", BooleanType))
-  val inputValueAttribs = toAttributes(inputValueSchema)
-  val inputValueAttribWithWatermark = inputValueAttribs(0)
-  val joinKeyExprs = Seq[Expression](Literal(false), 
inputValueAttribWithWatermark, Literal(10.0))
+  protected def updateNumValues(key: Int, numValues: Long)
+                               (implicit manager: 
SymmetricHashJoinStateManager): Unit = {
+    assert(manager.isInstanceOf[SupportsIndexedKeys])
+    
manager.asInstanceOf[SupportsIndexedKeys].updateNumValuesTestOnly(toJoinKeyRow(key),
 numValues)
+  }
+}
 
-  val inputValueGen = 
UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
-  val joinKeyGen = 
UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
+class SymmetricHashJoinStateManagerEventTimeInKeySuite
+  extends SymmetricHashJoinStateManagerBaseSuite {
 
+  private val versionsInTest = SymmetricHashJoinStateManager.supportedVersions
 
-  def toInputValue(i: Int): UnsafeRow = {
-    inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
+  private val watermarkMetadata = new MetadataBuilder()
+    .putLong(EventTimeWatermark.delayKey, 10).build()
+  private val inputValueSchema = new StructType()
+    .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+    .add(StructField("value", IntegerType))
+
+  override protected val inputValueAttributes: Seq[AttributeReference] =
+    toAttributes(inputValueSchema)
+  override protected val inputValueAttributesWithWatermark: AttributeReference 
=
+    inputValueAttributes(0)
+  override protected val joinKeyExpressions: Seq[Expression] = Seq[Expression](
+    Literal(false),
+    inputValueAttributesWithWatermark,
+    Literal(10.0))
+
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - all operations") {
+      testAllOperations(ver)
+    }
   }
 
-  def toJoinKeyRow(i: Int): UnsafeRow = {
-    joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - all operations, with commit and 
load in between") {
+      testAllOperationsWithCommitAndLoad(ver, changelogCheckpoint = false)
+      testAllOperationsWithCommitAndLoad(ver, changelogCheckpoint = true)
+    }
   }
 
-  def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
+  private def testAllOperations(stateFormatVersion: Int): Unit = {
+    withJoinStateManager(
+      inputValueAttributes,
+      joinKeyExpressions,
+      stateFormatVersion = stateFormatVersion) { manager =>
+      implicit val mgr = manager
 
-  def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = {
-    // we only put matched = false for simplicity - StreamingJoinSuite will test the functionality
-    manager.append(toJoinKeyRow(key), toInputValue(value), matched = false)
+      assert(get(20) === Seq.empty)     // initially empty
+      append(20, 2)
+      assert(get(20) === Seq(2))        // should return the first value correctly
+      assertNumRows(stateFormatVersion, 1)
+
+      append(20, 3)
+      assert(get(20) === Seq(2, 3))     // should append new values
+      append(20, 3)
+      assert(get(20) === Seq(2, 3, 3))  // should append another copy if same value added again
+      assertNumRows(stateFormatVersion, 3)
+
+      assert(get(30) === Seq.empty)
+      append(30, 1)
+      assert(get(30) === Seq(1))
+      assert(get(20) === Seq(2, 3, 3))  // adding another key-value should not affect existing ones
+      assertNumRows(stateFormatVersion, 4)
+
+      assert(removeByKey(25) === 3)
+      assert(get(20) === Seq.empty)
+      assert(get(30) === Seq(1))        // should remove 20, not 30
+      assertNumRows(stateFormatVersion, 1)
+
+      assert(removeByKey(30) === 1)
+      assert(get(30) === Seq.empty)     // should remove 30
+      assertNumRows(stateFormatVersion, 0)
+    }
   }
 
-  def appendAndTest(key: Int, values: Int*)
-                   (implicit manager: SymmetricHashJoinStateManager): Unit = {
-    values.foreach { value => append(key, value)}
-    require(get(key) === values)
+  private def testAllOperationsWithCommitAndLoad(
+      stateFormatVersion: Int,
+      changelogCheckpoint: Boolean): Unit = {
+    withTempDir { checkpointDir =>
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 0,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(get(20) === Seq.empty)     // initially empty
+        append(20, 2)
+        assert(get(20) === Seq(2))        // should return the first value correctly
+
+        append(20, 3)
+        assert(get(20) === Seq(2, 3))     // should append new values
+        append(20, 3)
+        assert(get(20) === Seq(2, 3, 3))  // should append another copy if same value added again
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 1,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(get(30) === Seq.empty)
+        append(30, 1)
+        assert(get(30) === Seq(1))
+        assert(get(20) === Seq(2, 3, 3))  // adding another key-value should not affect existing ones
+
+        assert(removeByKey(25) === 3)
+        assert(get(20) === Seq.empty)
+        assert(get(30) === Seq(1))        // should remove 20, not 30
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 2,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(removeByKey(30) === 1)
+        assert(get(30) === Seq.empty)     // should remove 30
+
+        mgr.commit()
+      }
+    }
   }
 
-  def updateNumValues(key: Int, numValues: Long)
-                     (implicit manager: SymmetricHashJoinStateManager): Unit = {
-    manager.updateNumValuesTestOnly(toJoinKeyRow(key), numValues)
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - evictAndReturnByKey returns correct rows") {
+      withJoinStateManager(
+        inputValueAttributes, joinKeyExpressions, stateFormatVersion = ver) { manager =>
+        implicit val mgr = manager
+
+        append(20, 2)
+        append(20, 3)
+        append(30, 1)
+
+        val evicted = removeAndReturnByKey(25)
+        val evictedPairs = evicted.map(p => (toValueInt(p.value), p.matched)).toSeq
+        assert(evictedPairs.map(_._1).sorted === Seq(2, 3))
+        assert(evictedPairs.forall(!_._2))
+
+        assert(get(20) === Seq.empty)
+        assert(get(30) === Seq(1))
+      }
+    }
   }
 
-  def getNumValues(key: Int)
-                  (implicit manager: SymmetricHashJoinStateManager): Int = {
-    manager.get(toJoinKeyRow(key)).size
+  // V1 excluded: V1 converter does not persist matched flags (SPARK-26154)
+  versionsInTest.filter(_ >= 2).foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - matched flag update + eviction roundtrip") {
+      withTempDir { checkpointDir =>
+        withJoinStateManagerWithCheckpointDir(
+          inputValueAttributes, joinKeyExpressions, ver,
+          checkpointDir, storeVersion = 0, changelogCheckpoint = false) { manager =>
+          implicit val mgr = manager
+
+          append(20, 2)
+          append(20, 3)
+          append(30, 1)
+
+          val dummyRow = new GenericInternalRow(0)
+          val matched = manager.getJoinedRows(
+            toJoinKeyRow(20),
+            row => new JoinedRow(row, dummyRow),
+            jr => jr.getInt(1) == 2
+          ).toSeq
+          // Consume the iterator from getJoinedRows so the matched flags are updated.
+          assert(matched.size == 1)
+
+          mgr.commit()
+        }
+
+        withJoinStateManagerWithCheckpointDir(
+          inputValueAttributes, joinKeyExpressions, ver,
+          checkpointDir, storeVersion = 1, changelogCheckpoint = false) { manager =>
+          implicit val mgr = manager
+
+          val evicted = removeAndReturnByKey(25)
+          val evictedPairs = evicted.map(p => (toValueInt(p.value), p.matched)).toSeq
+          val matchedByValue = evictedPairs.toMap
+          assert(matchedByValue(2) === true)
+          assert(matchedByValue(3) === false)
+
+          mgr.commit()
+        }
+      }
+    }
   }
 
-  def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
-    manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
+  // V1 excluded: V1 converter does not persist matched flags (SPARK-26154)
+  versionsInTest.filter(_ >= 2).foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - " +
+        "getJoinedRows with excludeRowsAlreadyMatched") {
+      withJoinStateManager(
+        inputValueAttributes, joinKeyExpressions, stateFormatVersion = ver) { manager =>
+        implicit val mgr = manager
+
+        append(20, 2)
+        append(20, 3)
+        append(20, 4)
+
+        val dummyRow = new GenericInternalRow(0)
+        val firstPass = manager.getJoinedRows(
+          toJoinKeyRow(20),
+          row => new JoinedRow(row, dummyRow),
+          // intentionally exclude 4, which should only be returned in the next pass
+          jr => jr.getInt(1) < 4
+        ).toSeq
+        assert(firstPass.size == 2)
+
+        val secondPass = manager.getJoinedRows(
+          toJoinKeyRow(20),
+          row => new JoinedRow(row, dummyRow),
+          _ => true,
+          excludeRowsAlreadyMatched = true
+        ).map(_.getInt(1)).toSeq
+        assert(secondPass === Seq(4))
+      }
+    }
   }
+}
 
-  /** Remove keys (and corresponding values) where `time <= threshold` */
-  def removeByKey(threshold: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
-    val expr =
-      LessThanOrEqual(
-        BoundReference(
-          1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
-        Literal(threshold))
-    val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
-    while (iter.hasNext) iter.next()
+class SymmetricHashJoinStateManagerEventTimeInValueSuite
+  extends SymmetricHashJoinStateManagerBaseSuite {
+
+  private val versionsInTest = SymmetricHashJoinStateManager.supportedVersions
+
+  private val watermarkMetadata = new MetadataBuilder()
+    .putLong(EventTimeWatermark.delayKey, 10).build()
+  private val inputValueSchema = new StructType()
+    .add(StructField("key", IntegerType))
+    .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+
+  protected override val inputValueAttributes = toAttributes(inputValueSchema)
+  protected override val inputValueAttributesWithWatermark = inputValueAttributes(1)
+  protected override val joinKeyExpressions = Seq[Expression](
+    Literal(false), inputValueAttributes(0), Literal(10.0))
+
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - all operations") {
+      testAllOperations(ver)
+    }
   }
 
-  /** Remove values where `time <= threshold` */
-  def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = {
-    val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
-    val iter = manager.removeByValueCondition(
-      GeneratePredicate.generate(expr, inputValueAttribs).eval _)
-    while (iter.hasNext) iter.next()
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - all operations, with commit and load in between") {
+      testAllOperationsWithCommitAndLoad(ver, changelogCheckpoint = false)
+      testAllOperationsWithCommitAndLoad(ver, changelogCheckpoint = true)
+    }
   }
 
-  def assertNumRows(stateFormatVersion: Int, target: Long)(
-    implicit manager: SymmetricHashJoinStateManager): Unit = {
-    // This suite originally uses HDFSBackedStateStoreProvider, which provides instantaneous
-    // metrics for numRows.
-    // But for version 3 with virtual column families, RocksDBStateStoreProvider updates metrics
-    // asynchronously. This means the number of keys obtained from the metrics are very likely
-    // to be outdated right after a put/remove.
-    if (stateFormatVersion <= 2) {
-      assert(manager.metrics.numKeys == target)
+  private def testAllOperations(stateFormatVersion: Int): Unit = {
+    withJoinStateManager(
+      inputValueAttributes,
+      joinKeyExpressions,
+      stateFormatVersion = stateFormatVersion) { manager =>
+      implicit val mgr = manager
+
+      appendAndTest(40, 100, 200, 300)
+      appendAndTest(50, 125)
+      appendAndTest(60, 275)              // prepare for testing removeByValue
+      assertNumRows(stateFormatVersion, 5)
+
+      assert(removeByValue(125) === 2)
+      assert(get(40) === Seq(200, 300))
+      assert(get(50) === Seq.empty)
+      assert(get(60) === Seq(275))        // should remove only some values, not all
+      assertNumRows(stateFormatVersion, 3)
+
+      append(40, 50)
+      assert(get(40) === Seq(50, 200, 300))
+      assertNumRows(stateFormatVersion, 4)
+
+      assert(removeByValue(200) === 2)
+      assert(get(40) === Seq(300))
+      assert(get(60) === Seq(275))        // should remove only some values, not all
+      assertNumRows(stateFormatVersion, 2)
+
+      assert(removeByValue(300) === 2)
+      assert(get(40) === Seq.empty)
+      assert(get(60) === Seq.empty)       // should remove all values now
+      assertNumRows(stateFormatVersion, 0)
     }
   }
 
-  def withJoinStateManager(
-      inputValueAttribs: Seq[Attribute],
-      joinKeyExprs: Seq[Expression],
+  private def testAllOperationsWithCommitAndLoad(
       stateFormatVersion: Int,
-      skipNullsForStreamStreamJoins: Boolean = false,
-      metric: Option[SQLMetric] = None)
-      (f: SymmetricHashJoinStateManager => Unit): Unit = {
-    // HDFS store providers do not support virtual column families
-    val storeProvider = if (stateFormatVersion == 3) {
-      classOf[RocksDBStateStoreProvider].getName
-    } else {
-      classOf[HDFSBackedStateStoreProvider].getName
-    }
-    withTempDir { file =>
-      withSQLConf(
-        SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key ->
-          skipNullsForStreamStreamJoins.toString,
-        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> storeProvider
-      ) {
-        val storeConf = new StateStoreConf(spark.sessionState.conf)
-        val stateInfo = StatefulOperatorStateInfo(
-          file.getAbsolutePath, UUID.randomUUID, 0, 0, 5, None)
-        val manager = SymmetricHashJoinStateManager(
-          LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
-          partitionId = 0, None, None, stateFormatVersion, metric,
-          joinStoreGenerator = new JoinStateManagerStoreGenerator())
-        try {
-          f(manager)
-        } finally {
-          manager.abortIfNeeded()
-        }
+      changelogCheckpoint: Boolean): Unit = {
+    withTempDir { checkpointDir =>
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 0,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        appendAndTest(40, 100, 200, 300)
+        appendAndTest(50, 125)
+        appendAndTest(60, 275) // prepare for testing removeByValue
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 1,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(removeByValue(125) === 2)
+        assert(get(40) === Seq(200, 300))
+        assert(get(50) === Seq.empty)
+        assert(get(60) === Seq(275))        // should remove only some values, not all
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 2,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        append(40, 50)
+        assert(get(40) === Seq(50, 200, 300))
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 3,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(removeByValue(200) === 2)
+        assert(get(40) === Seq(300))
+        assert(get(60) === Seq(275))        // should remove only some values, not all
+
+        mgr.commit()
+      }
+
+      withJoinStateManagerWithCheckpointDir(
+        inputValueAttributes,
+        joinKeyExpressions,
+        stateFormatVersion = stateFormatVersion,
+        checkpointDir,
+        storeVersion = 4,
+        changelogCheckpoint = changelogCheckpoint) { manager =>
+
+        implicit val mgr = manager
+
+        assert(removeByValue(300) === 2)
+        assert(get(40) === Seq.empty)
+        assert(get(60) === Seq.empty)       // should remove all values now
+
+        mgr.commit()
       }
     }
-    StateStore.stop()
   }
 
-  private def testPartitionKeyExtraction(stateFormatVersion: Int): Unit = {
-    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager =>
-      implicit val mgr = manager
+  versionsInTest.foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - evictAndReturnByValue returns correct rows") {
+      withJoinStateManager(
+        inputValueAttributes, joinKeyExpressions, stateFormatVersion = ver) { manager =>
+        implicit val mgr = manager
 
-      val joinKeySchema = StructType(
-        joinKeyExprs.zipWithIndex.map { case (expr, i) =>
-          StructField(s"field$i", expr.dataType, expr.nullable)
-        })
-
-      // Add some test data
-      append(key = 20, value = 100)
-      append(key = 20, value = 200)
-      append(key = 30, value = 150)
-
-      Seq(
-        (getKeyToNumValuesStoreAndKeySchema(), SymmetricHashJoinStateManager
-          .getStateStoreName(LeftSide, SymmetricHashJoinStateManager.KeyToNumValuesType),
-          // expect 1 for both key 20 & 30
-          1, 1),
-        (getKeyWithIndexToValueStoreAndKeySchema(), SymmetricHashJoinStateManager
-          .getStateStoreName(LeftSide, SymmetricHashJoinStateManager.KeyWithIndexToValueType),
-          // expect 2 for key 20 & 1 for key 30
-          2, 1)
-      ).foreach { case ((store, keySchema), name, expectedNumKey20, expectedNumKey30) =>
-        val storeName = if (stateFormatVersion == 3) {
-          StateStoreId.DEFAULT_STORE_NAME
-        } else {
-          name
-        }
+        appendAndTest(40, 100, 200, 300)
+        appendAndTest(50, 125)
+        appendAndTest(60, 275)
 
-        val colFamilyName = if (stateFormatVersion == 3) {
-          name
-        } else {
-          StateStore.DEFAULT_COL_FAMILY_NAME
-        }
+        val evicted = removeAndReturnByValue(125)
+        val evictedPairs = evicted.map(p => (toValueInt(p.value), p.matched)).toSeq
+        assert(evictedPairs.map(_._1).sorted === Seq(100, 125))
+        assert(evictedPairs.forall(!_._2))
 
-        val extractor = StatePartitionKeyExtractorFactory.create(
-          StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME,
-          keySchema,
-          storeName,
-          colFamilyName,
-          stateFormatVersion = Some(stateFormatVersion)
-        )
-
-        assert(extractor.partitionKeySchema === joinKeySchema,
-          "Partition key schema should match the join key schema")
-
-        // Copy both the state key and partition key to avoid UnsafeRow reuse issues
-        val stateKeys = store.iterator(colFamilyName).map(_.key.copy()).toList
-        val partitionKeys = stateKeys.map(extractor.partitionKey(_).copy())
-
-        assert(partitionKeys.length === expectedNumKey20 + expectedNumKey30,
-          "Should have same num partition keys as num state store keys")
-        assert(partitionKeys.count(_ === toJoinKeyRow(20)) === expectedNumKey20,
-          "Should have the expected num partition keys for join key 20")
-        assert(partitionKeys.count(_ === toJoinKeyRow(30)) === expectedNumKey30,
-          "Should have the expected num partition keys for join key 30")
+        assert(get(40) === Seq(200, 300))
+        assert(get(50) === Seq.empty)
+        assert(get(60) === Seq(275))
       }
     }
   }
 
-  def getKeyToNumValuesStoreAndKeySchema()
-      (implicit manager: SymmetricHashJoinStateManager): (StateStore, StructType) = {
-    val keyToNumValuesHandler = manager.keyToNumValues
-    val keyToNumValuesStoreMethod = PrivateMethod[StateStore](Symbol("stateStore"))
-    val keyToNumValuesStore = keyToNumValuesHandler.invokePrivate(keyToNumValuesStoreMethod())
+  // V1 excluded: V1 converter does not persist matched flags (SPARK-26154)
+  versionsInTest.filter(_ >= 2).foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - matched flag update + eviction roundtrip") {
+      withTempDir { checkpointDir =>
+        withJoinStateManagerWithCheckpointDir(
+          inputValueAttributes, joinKeyExpressions, ver,
+          checkpointDir, storeVersion = 0, changelogCheckpoint = false) { manager =>
+          implicit val mgr = manager
+
+          appendAndTest(40, 100, 200, 300)
+
+          val dummyRow = new GenericInternalRow(0)
+          val matched = manager.getJoinedRows(
+            toJoinKeyRow(40),
+            row => new JoinedRow(row, dummyRow),
+            jr => jr.getInt(1) == 100
+          ).toSeq
+          assert(matched.size == 1)
+
+          mgr.commit()
+        }
 
-    (keyToNumValuesStore, manager.keySchema)
-  }
+        withJoinStateManagerWithCheckpointDir(
+          inputValueAttributes, joinKeyExpressions, ver,
+          checkpointDir, storeVersion = 1, changelogCheckpoint = false) { manager =>
+          implicit val mgr = manager
 
-  def getKeyWithIndexToValueStoreAndKeySchema()
-      (implicit manager: SymmetricHashJoinStateManager): (StateStore, StructType) = {
-    val keyWithIndexToValueHandler = manager.keyWithIndexToValue
+          val evicted = removeAndReturnByValue(125)
+          val evictedPairs = evicted.map(p => (toValueInt(p.value), 
p.matched)).toSeq
+          val matchedByValue = evictedPairs.toMap
+          assert(matchedByValue(100) === true)
 
-    val keyWithIndexToValueStoreMethod = PrivateMethod[StateStore](Symbol("stateStore"))
-    val keyWithIndexToValueStore =
-      keyWithIndexToValueHandler.invokePrivate(keyWithIndexToValueStoreMethod())
+          mgr.commit()
+        }
+      }
+    }
+  }
 
-    val keySchemaMethod = PrivateMethod[StructType](Symbol("keyWithIndexSchema"))
-    val keyWithIndexToValueKeySchema = keyWithIndexToValueHandler.invokePrivate(keySchemaMethod())
-    (keyWithIndexToValueStore, keyWithIndexToValueKeySchema)
+  // V1 excluded: V1 converter does not persist matched flags (SPARK-26154)
+  versionsInTest.filter(_ >= 2).foreach { ver =>
+    test(s"StreamingJoinStateManager V$ver - " +
+        "getJoinedRows with excludeRowsAlreadyMatched") {
+      withJoinStateManager(
+        inputValueAttributes, joinKeyExpressions, stateFormatVersion = ver) { manager =>
+        implicit val mgr = manager
+
+        appendAndTest(40, 100, 200, 300)
+
+        val dummyRow = new GenericInternalRow(0)
+        val firstPass = manager.getJoinedRows(
+          toJoinKeyRow(40),
+          row => new JoinedRow(row, dummyRow),
+          // intentionally exclude 300, which should only be returned in the next pass
+          jr => jr.getInt(1) < 300
+        ).toSeq
+        assert(firstPass.size == 2)
+
+        val secondPass = manager.getJoinedRows(
+          toJoinKeyRow(40),
+          row => new JoinedRow(row, dummyRow),
+          _ => true,
+          excludeRowsAlreadyMatched = true
+        ).map(_.getInt(1)).toSeq
+        assert(secondPass === Seq(300))
+      }
+    }
   }
 }
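The suites above exercise the behavior the new state format is designed for: eviction by watermark should visit only the rows actually being evicted, via a secondary (event time, key) index, rather than scanning the whole store. As a rough illustration of that layout, here is a toy in-memory sketch in plain Scala. It is not the SymmetricHashJoinStateManager implementation; all names below are hypothetical stand-ins for the KeyWithTsToValuesStore / TsWithKeyTypeStore column families described in the commit message.

```scala
import scala.collection.mutable

// Toy model of the two column families: a primary (key, eventTime) -> values
// map, plus a secondary eventTime -> keys index used only for eviction.
class ToyJoinStateStore {
  // value side stores (value, matched) pairs, mirroring the primary store
  private val primary =
    mutable.Map.empty[(Int, Long), mutable.Buffer[(Int, Boolean)]]
  // sorted by event time so eviction can stop at the watermark
  private val byTime = mutable.SortedMap.empty[Long, mutable.Set[Int]]

  def append(key: Int, eventTime: Long, value: Int): Unit = {
    primary.getOrElseUpdate((key, eventTime), mutable.Buffer.empty) += ((value, false))
    byTime.getOrElseUpdate(eventTime, mutable.Set.empty) += key
  }

  def get(key: Int): Seq[Int] =
    primary.collect { case ((k, _), vs) if k == key => vs.map(_._1) }.flatten.toSeq.sorted

  // Eviction cost is bound to the number of expired entries: only timestamps
  // at or below the watermark are visited, via the secondary index.
  def evict(watermark: Long): Seq[Int] = {
    val expired = byTime.rangeTo(watermark).toSeq
    expired.flatMap { case (ts, keys) =>
      byTime.remove(ts)
      keys.toSeq.flatMap(k => primary.remove((k, ts)).toSeq.flatten.map(_._1))
    }
  }
}
```

Because `byTime` is sorted, `evict` never touches entries above the watermark, which is the property that lets eviction cost track the amount of watermark advancement instead of total state size.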


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
