This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 928f25386c35 [SPARK-54027] Kafka Source RTM support
928f25386c35 is described below

commit 928f25386c356ce38a44fb49918b26219ba6b671
Author: Jerry Peng <[email protected]>
AuthorDate: Fri Oct 31 23:01:28 2025 -0700

    [SPARK-54027] Kafka Source RTM support
    
    ### What changes were proposed in this pull request?
    
    Add support for Real-time Mode in the Kafka Source. This means that
KafkaMicroBatchStream needs to implement the SupportsRealTimeMode interface and
KafkaBatchPartitionReader needs to extend the SupportsRealTimeRead interface.
    
    ### Why are the changes needed?
    
    So that Kafka source and sink can be used by Real-time Mode queries
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, Kafka source and sink can be used by Real-time Mode queries
    
    ### How was this patch tested?
    
    Many tests added
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52729 from jerrypeng/SPARK-54027-int.
    
    Authored-by: Jerry Peng <[email protected]>
    Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
 .../java/org/apache/spark/internal/LogKeys.java    |   1 +
 .../sql/kafka010/KafkaBatchPartitionReader.scala   |  49 +-
 .../spark/sql/kafka010/KafkaMicroBatchStream.scala | 140 ++++-
 .../sql/kafka010/consumer/KafkaDataConsumer.scala  |  98 ++-
 .../sql/kafka010/KafkaMicroBatchSourceSuite.scala  |   6 +-
 .../kafka010/KafkaRealTimeIntegrationSuite.scala   | 293 +++++++++
 .../sql/kafka010/KafkaRealTimeModeSuite.scala      | 681 +++++++++++++++++++++
 .../streaming/StreamRealTimeModeSuiteBase.scala    | 127 +++-
 .../apache/spark/sql/streaming/StreamTest.scala    |  16 +
 9 files changed, 1396 insertions(+), 15 deletions(-)

diff --git 
a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java 
b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index 7486f4ac4fab..8b6d3614b86d 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -824,6 +824,7 @@ public enum LogKeys implements LogKey {
   TIMEOUT,
   TIMER,
   TIMESTAMP,
+  TIMESTAMP_COLUMN_NAME,
   TIME_UNITS,
   TIP,
   TOKEN,
diff --git 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
index 02568aa89eb1..9fcdf1a7d9bf 100644
--- 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
+++ 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
@@ -19,15 +19,19 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
 
+import org.apache.kafka.common.record.TimestampType
+
 import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead
+import 
org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
 import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, 
StreamExecution}
-import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
+import org.apache.spark.sql.kafka010.consumer.{KafkaDataConsumer, 
KafkaDataConsumerIterator}
 
 /** A [[InputPartition]] for reading Kafka data in a batch based streaming 
query. */
 private[kafka010] case class KafkaBatchInputPartition(
@@ -67,7 +71,8 @@ private case class KafkaBatchPartitionReader(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging 
{
+    includeHeaders: Boolean)
+  extends SupportsRealTimeRead[InternalRow] with Logging {
 
   private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, 
executorKafkaParams)
 
@@ -77,6 +82,12 @@ private case class KafkaBatchPartitionReader(
 
   private var nextOffset = rangeToRead.fromOffset
   private var nextRow: UnsafeRow = _
+  private var iteratorForRealTimeMode: Option[KafkaDataConsumerIterator] = None
+
+  // Boolean flag that indicates whether we have logged the type of timestamp 
(i.e. create time,
+  // log-append time, etc.) for the Kafka source. We log upon reading the 
first record, and we
+  // then skip logging for subsequent records.
+  private var timestampTypeLogged = false
 
   override def next(): Boolean = {
     if (nextOffset < rangeToRead.untilOffset) {
@@ -93,6 +104,38 @@ private case class KafkaBatchPartitionReader(
     }
   }
 
+  /**
+   * Real-time-mode read: returns the next record from a long-lived consumer iterator,
+   * waiting up to `timeoutMs` for one to become available.
+   *
+   * The iterator is created lazily on the first call and starts at `nextOffset`. When a
+   * record is fetched, `nextRow` and `nextOffset` advance. The returned status carries the
+   * Kafka record timestamp as arrival time if the timestamp type is CREATE_TIME or
+   * LOG_APPEND_TIME; otherwise it is built without an arrival time. When nothing arrives
+   * before the timeout, `RecordStatus.newStatusWithoutArrivalTime(false)` is returned.
+   */
+  override def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus = {
+    val iterator = iteratorForRealTimeMode.getOrElse {
+      logInfo(s"Getting a new kafka consuming iterator for ${offsetRange.topicPartition} " +
+        s"starting from ${nextOffset}, timeoutMs ${timeoutMs}")
+      val created = consumer.getIterator(nextOffset)
+      iteratorForRealTimeMode = Some(created)
+      created
+    }
+    val nextRecord = iterator.nextWithTimeout(timeoutMs)
+    // The per-record status must be the method's result. Computing it inside `foreach`
+    // (as before) discarded it, so the arrival time was never reported.
+    nextRecord.map { record =>
+      nextRow = unsafeRowProjector(record)
+      nextOffset = record.offset + 1
+      val timestampType = record.timestampType()
+      if (timestampType == TimestampType.LOG_APPEND_TIME ||
+        timestampType == TimestampType.CREATE_TIME) {
+        // Log the timestamp type once, on the first record that has one.
+        if (!timestampTypeLogged) {
+          logInfo(log"Kafka source record timestamp type is " +
+            log"${MDC(LogKeys.TIMESTAMP_COLUMN_NAME, timestampType)}")
+          timestampTypeLogged = true
+        }
+        RecordStatus.newStatusWithArrivalTimeMs(record.timestamp())
+      } else {
+        RecordStatus.newStatusWithoutArrivalTime(true)
+      }
+    }.getOrElse(RecordStatus.newStatusWithoutArrivalTime(false))
+  }
+
+  override def getOffset(): KafkaSourcePartitionOffset = {
+    KafkaSourcePartitionOffset(offsetRange.topicPartition, nextOffset)
+  }
+
   override def get(): UnsafeRow = {
     assert(nextRow != null)
     nextRow
diff --git 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
index 7449e9123033..828891f0b498 100644
--- 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
+++ 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP}
+import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP, 
TOPIC_PARTITION_OFFSET}
 import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
@@ -60,7 +60,11 @@ private[kafka010] class KafkaMicroBatchStream(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends SupportsTriggerAvailableNow with ReportsSourceMetrics with 
MicroBatchStream with Logging {
+    extends SupportsTriggerAvailableNow
+    with SupportsRealTimeMode
+    with ReportsSourceMetrics
+    with MicroBatchStream
+    with Logging {
 
   private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
@@ -93,6 +97,11 @@ private[kafka010] class KafkaMicroBatchStream(
 
   private var isTriggerAvailableNow: Boolean = false
 
+  private var inRealTimeMode = false
+  override def prepareForRealTimeMode(): Unit = {
+    inRealTimeMode = true
+  }
+
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that 
`KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while 
running
@@ -218,6 +227,107 @@ private[kafka010] class KafkaMicroBatchStream(
     }.toArray
   }
 
+  override def planInputPartitions(start: Offset): Array[InputPartition] = {
+    // This function is used for real time mode. Trigger restrictions won't be 
supported.
+    if (maxOffsetsPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "maxOffsetsPerTrigger is not compatible with real time mode")
+    }
+    if (minOffsetPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "minOffsetsPerTrigger is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "minpartitions is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.ENDING_TIMESTAMP_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "endingtimestamp is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MAX_TRIGGER_DELAY)) {
+      throw new UnsupportedOperationException(
+        "maxtriggerdelay is not compatible with real time mode"
+      )
+    }
+
+    // This function is used by Real-time Mode, where we expect 1:1 mapping 
between a
+    // topic partition and an input partition.
+    // We are skipping partition range check for performance reason. We can 
always try to do
+    // it in tasks if needed.
+    val startPartitionOffsets = 
start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+
+    // Here we check previous topic partitions with latest partition offsets 
to see if we need to
+    // update the partition list. Here we don't need the updated partition 
topic to be absolutely
+    // up to date, because there might already be minutes' delay since new 
partition is created.
+    // latestPartitionOffsets should be fetched not long ago anyway.
+    // If the topic partitions change, we fetch the earliest offsets for all 
new partitions
+    // and add them to the list.
+    assert(latestPartitionOffsets != null, "latestPartitionOffsets should be 
set in latestOffset")
+    val latestTopicPartitions = latestPartitionOffsets.keySet
+    val newStartPartitionOffsets = if (startPartitionOffsets.keySet == 
latestTopicPartitions) {
+      startPartitionOffsets
+    } else {
+      val newPartitions = 
latestTopicPartitions.diff(startPartitionOffsets.keySet)
+      // Instead of fetching earliest offsets, we could fill offset 0 here and 
avoid this extra
+      // admin function call. But we consider new partition is rare and 
getting earliest offset
+      // aligns with what we do in micro-batch mode and can potentially enable 
more sanity checks
+      // in executor side.
+      val newPartitionOffsets = 
kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+      assert(
+        newPartitionOffsets.keys.forall(!startPartitionOffsets.contains(_)),
+        "startPartitionOffsets should not contain any key in 
newPartitionOffsets")
+
+      logInfo(log"Partitions added: ${MDC(TOPIC_PARTITION_OFFSET, 
newPartitionOffsets)}")
+      // Filter out new partition offsets that are not 0 and log a warning
+      val nonZeroNewPartitionOffsets = newPartitionOffsets.filter {
+        case (_, offset) => offset != 0
+      }
+      // Log the non-zero new partition offsets
+      if (nonZeroNewPartitionOffsets.nonEmpty) {
+        logWarning(log"new partitions should start from offset 0: " +
+          log"${MDC(OFFSETS, nonZeroNewPartitionOffsets)}")
+        nonZeroNewPartitionOffsets.foreach {
+          case (p, o) =>
+            reportDataLoss(
+              s"Added partition $p starts from $o instead of 0. Some data may 
have been missed",
+              () => KafkaExceptions.addedPartitionDoesNotStartFromZero(p, o))
+        }
+      }
+
+      val deletedPartitions = 
startPartitionOffsets.keySet.diff(latestTopicPartitions)
+      if (deletedPartitions.nonEmpty) {
+        reportDataLoss(
+          s"$deletedPartitions are gone. Some data may have been missed",
+          () =>
+            KafkaExceptions.partitionsDeleted(deletedPartitions, None))
+      }
+
+      startPartitionOffsets ++ newPartitionOffsets
+    }
+
+    newStartPartitionOffsets.keySet.toSeq.map { tp =>
+      val fromOffset = newStartPartitionOffsets(tp)
+      KafkaBatchInputPartition(
+        KafkaOffsetRange(tp, fromOffset, Long.MaxValue, preferredLoc = None),
+        executorKafkaParams,
+        pollTimeoutMs,
+        failOnDataLoss,
+        includeHeaders)
+    }.toArray
+  }
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    val mergedMap = offsets.map {
+      case KafkaSourcePartitionOffset(p, o) => (p, o)
+    }.toMap
+    KafkaSourceOffset(mergedMap)
+  }
+
   override def createReaderFactory(): PartitionReaderFactory = {
     KafkaBatchReaderFactory
   }
@@ -235,7 +345,22 @@ private[kafka010] class KafkaMicroBatchStream(
   override def toString(): String = s"KafkaV2[$kafkaOffsetReader]"
 
   override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String, 
String] = {
-    KafkaMicroBatchStream.metrics(latestConsumedOffset, latestPartitionOffsets)
+    val reCalculatedLatestPartitionOffsets =
+      if (inRealTimeMode) {
+        if (!latestConsumedOffset.isPresent) {
+          // this means a batch has no end offsets, which should not happen
+          None
+        } else {
+          Some {
+            kafkaOffsetReader.fetchLatestOffsets(
+              
Some(latestConsumedOffset.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets))
+          }
+        }
+      } else {
+        Some(latestPartitionOffsets)
+      }
+
+      KafkaMicroBatchStream.metrics(latestConsumedOffset, 
reCalculatedLatestPartitionOffsets)
   }
 
   /**
@@ -386,13 +511,14 @@ object KafkaMicroBatchStream extends Logging {
    */
   def metrics(
       latestConsumedOffset: Optional[Offset],
-      latestAvailablePartitionOffsets: PartitionOffsetMap): ju.Map[String, 
String] = {
+      latestAvailablePartitionOffsets: Option[PartitionOffsetMap]): 
ju.Map[String, String] = {
     val offset = Option(latestConsumedOffset.orElse(null))
 
-    if (offset.nonEmpty && latestAvailablePartitionOffsets != null) {
+    if (offset.nonEmpty && latestAvailablePartitionOffsets.isDefined) {
       val consumedPartitionOffsets = 
offset.map(KafkaSourceOffset(_)).get.partitionToOffsets
-      val offsetsBehindLatest = latestAvailablePartitionOffsets
-        .map(partitionOffset => partitionOffset._2 - 
consumedPartitionOffsets(partitionOffset._1))
+      val offsetsBehindLatest = latestAvailablePartitionOffsets.get
+        .map(partitionOffset => partitionOffset._2 -
+          consumedPartitionOffsets.getOrElse(partitionOffset._1, 0L))
       if (offsetsBehindLatest.nonEmpty) {
         val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble / 
offsetsBehindLatest.size
         return Map[String, String](
diff --git 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
index 450039ab6800..af4e5bab2947 100644
--- 
a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
+++ 
b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
@@ -63,6 +63,13 @@ private[kafka010] class InternalKafkaConsumer(
   private[consumer] var kafkaParamsWithSecurity: ju.Map[String, Object] = _
   private val consumer = createConsumer()
 
+  def poll(pollTimeoutMs: Long): ju.List[ConsumerRecord[Array[Byte], 
Array[Byte]]] = {
+    val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
+    val r = p.records(topicPartition)
+    logDebug(s"Polled $groupId ${p.partitions()}  ${r.size}")
+    r
+  }
+
   /**
    * Poll messages from Kafka starting from `offset` and returns a pair of 
"list of consumer record"
    * and "offset after poll". The list of consumer record may be empty if the 
Kafka consumer fetches
@@ -131,7 +138,7 @@ private[kafka010] class InternalKafkaConsumer(
     c
   }
 
-  private def seek(offset: Long): Unit = {
+  def seek(offset: Long): Unit = {
     logDebug(s"Seeking to $groupId $topicPartition $offset")
     consumer.seek(topicPartition, offset)
   }
@@ -228,6 +235,19 @@ private[consumer] case class FetchedRecord(
   }
 }
 
+/**
+ * This class keeps returning the next records. If no new record is available, 
it will keep
+ * polling until timeout. It is used by 
KafkaBatchPartitionReader.nextWithTimeout(), to reduce
+ * seeking overhead in real time mode.
+ */
+private[sql] trait KafkaDataConsumerIterator {
+  /**
+   * Return the next record
+   * @return None if no new record is available after `timeoutMs`.
+   */
+  def nextWithTimeout(timeoutMs: Long): Option[ConsumerRecord[Array[Byte], 
Array[Byte]]]
+}
+
 /**
  * This class helps caller to read from Kafka leveraging consumer pool as well 
as fetched data pool.
  * This class throws error when data loss is detected while reading from Kafka.
@@ -272,6 +292,82 @@ private[kafka010] class KafkaDataConsumer(
   // Starting timestamp when the consumer is created.
   private var startTimestampNano: Long = System.nanoTime()
 
+  /**
+   * Get an iterator that can return the next entry. It is used exclusively 
for real-time
+   * mode.
+   *
+   * It is called by KafkaBatchPartitionReader.nextWithTimeout(). Unlike 
get(), there is no
+   * out-of-bound check in this function. Since there is no endOffset given, 
we assume anything
+   * record is valid to return as long as it is at or after `offset`.
+   *
+   * @param startOffset, the starting positions to read from, inclusive.
+   */
+  def getIterator(startOffset: Long): KafkaDataConsumerIterator = {
+    new KafkaDataConsumerIterator {
+      private var fetchedRecordList
+          : Option[ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]]] 
= None
+      private val consumer = getOrRetrieveConsumer()
+      private var firstRecord = true
+      private var _currentOffset: Long = startOffset - 1
+
+      private def fetchedRecordListHasNext(): Boolean = {
+        fetchedRecordList.map(_.hasNext).getOrElse(false)
+      }
+
+      override def nextWithTimeout(
+          timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]] = 
{
+        var timeLeftMs = timeoutMs
+
+        def timeAndDeductFromTimeLeftMs[T](body: => T): Unit = {
+          // To reduce timing the same operator twice, we reuse the timing 
results for
+          // totalTimeReadNanos and for timeoutMs.
+          val prevTime = totalTimeReadNanos
+          timeNanos {
+            body
+          }
+          timeLeftMs -= (totalTimeReadNanos - prevTime) / 1000000
+        }
+
+        if (firstRecord) {
+          timeAndDeductFromTimeLeftMs {
+            consumer.seek(startOffset)
+            firstRecord = false
+          }
+        }
+        while (!fetchedRecordListHasNext() && timeLeftMs > 0) {
+          timeAndDeductFromTimeLeftMs {
+            try {
+              val records = consumer.poll(timeLeftMs)
+              numPolls += 1
+              if (!records.isEmpty) {
+                numRecordsPolled += records.size
+                fetchedRecordList = Some(records.listIterator)
+              }
+            } catch {
+              case ex: OffsetOutOfRangeException =>
+                if (_currentOffset != -1) {
+                  throw ex
+                } else {
+                  Thread.sleep(10) // retry until the source partition is 
populated
+                  assert(startOffset == 0)
+                  consumer.seek(startOffset)
+                }
+            }
+          }
+        }
+        if (fetchedRecordListHasNext()) {
+          totalRecordsRead += 1
+          val nextRecord = fetchedRecordList.get.next()
+          assert(nextRecord.offset > _currentOffset, "Kafka offset should be 
incremental.")
+          _currentOffset = nextRecord.offset
+          Some(nextRecord)
+        } else {
+          None
+        }
+      }
+    }
+  }
+
   /**
    * Get the record for the given offset if available.
    *
diff --git 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 3da5b0c039bb..e619adfce17b 100644
--- 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -1838,20 +1838,20 @@ abstract class KafkaMicroBatchV2SourceSuite extends 
KafkaMicroBatchSourceSuiteBa
     val latestOffset = Map[TopicPartition, Long]((topicPartition1, 3L), 
(topicPartition2, 6L))
 
     // test empty offset.
-    assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null), 
latestOffset).isEmpty)
+    assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null), 
Some(latestOffset)).isEmpty)
 
     // test valid offsetsBehindLatest
     val offset = KafkaSourceOffset(
       Map[TopicPartition, Long]((topicPartition1, 1L), (topicPartition2, 2L)))
     assert(
-      KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), latestOffset) 
===
+      KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), 
Some(latestOffset)) ===
         Map[String, String](
           "minOffsetsBehindLatest" -> "2",
           "maxOffsetsBehindLatest" -> "4",
           "avgOffsetsBehindLatest" -> "3.0").asJava)
 
     // test null latestAvailablePartitionOffsets
-    assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), 
null).isEmpty)
+    assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), 
None).isEmpty)
   }
 }
 
diff --git 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala
 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala
new file mode 100644
index 000000000000..a359dc355478
--- /dev/null
+++ 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.nio.file.Files
+import java.util.Properties
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+import org.apache.kafka.clients.producer.{KafkaProducer, Producer, 
ProducerRecord}
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkContext, ThreadAudit}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.execution.streaming.RealTimeTrigger
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{OutputMode, ResultsCollector, 
StreamingQuery, StreamRealTimeModeE2ESuiteBase, StreamRealTimeModeSuiteBase}
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class KafkaRealTimeModeE2ESuite extends KafkaSourceTest with StreamRealTimeModeE2ESuiteBase {
+
+  override protected val defaultTrigger: RealTimeTrigger = RealTimeTrigger.apply("5 seconds")
+
+  override protected def createSparkSession =
+    new TestSparkSession(
+      new SparkContext(
+        "local[15]",
+        "streaming-key-cuj"
+      )
+    )
+
+  override def beforeEach(): Unit = {
+    super[KafkaSourceTest].beforeEach()
+    super[StreamRealTimeModeE2ESuiteBase].beforeEach()
+  }
+
+  /**
+   * Kafka *producer* settings for the embedded test broker: string key/value serializers and
+   * snappy compression. Despite the method name, these are not consumer settings.
+   */
+  def getKafkaConsumerProperties: Properties = {
+    val props: Properties = new Properties()
+    props.put("bootstrap.servers", testUtils.brokerAddress)
+    props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
+    props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")
+    props.put("compression.type", "snappy")
+
+    props
+  }
+
+  test("Union two kafka streams, for each write to sink") {
+    var q: StreamingQuery = null
+    var producer1: Producer[String, String] = null
+    var producer2: Producer[String, String] = null
+    try {
+      val topic1 = newTopic()
+      val topic2 = newTopic()
+      testUtils.createTopic(topic1, partitions = 2)
+      testUtils.createTopic(topic2, partitions = 2)
+
+      val props: Properties = getKafkaConsumerProperties
+      producer1 = new KafkaProducer[String, String](props)
+      producer2 = new KafkaProducer[String, String](props)
+
+      val readStream1 = spark.readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("subscribe", topic1)
+        .load()
+
+      val readStream2 = spark.readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("subscribe", topic2)
+        .load()
+
+      val df = readStream1
+        .union(readStream2)
+        .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value")
+        .selectExpr("key || ',' || value")
+        .toDF()
+
+      q = runStreamingQuery("union-kafka", df)
+
+      // Two topics with two partitions each => four long-running real-time tasks.
+      waitForTasksToStart(4)
+
+      val expectedResults = new mutable.ListBuffer[String]()
+      for (batch <- 0 until 3) {
+        (1 to 100).foreach(i => {
+          producer1
+            .send(
+              new ProducerRecord[String, String](
+                topic1,
+                java.lang.Long.toString(i),
+                s"input1-${batch}-${i}"
+              )
+            )
+            .get()
+          producer2
+            .send(
+              new ProducerRecord[String, String](
+                topic2,
+                java.lang.Long.toString(i),
+                s"input2-${batch}-${i}"
+              )
+            )
+            .get()
+        })
+        producer1.flush()
+        producer2.flush()
+
+        expectedResults ++= (1 to 100)
+          .flatMap(v => {
+            Seq(
+              s"${v},input1-${batch}-${v}",
+              s"${v},input2-${batch}-${v}"
+            )
+          })
+          .toList
+
+        eventually(timeout(60.seconds)) {
+          ResultsCollector
+            .get(sinkName)
+            .toArray(new Array[String](ResultsCollector.get(sinkName).size()))
+            .toList
+            .sorted should equal(expectedResults.sorted)
+        }
+      }
+    } finally {
+      if (q != null) {
+        q.stop()
+      }
+      // Each KafkaProducer owns a background I/O thread; close them so they do not outlive
+      // the test.
+      Option(producer1).foreach(_.close())
+      Option(producer2).foreach(_.close())
+    }
+  }
+}
+
+
+/**
+ * Kafka real-time mode integration test suite.
+ *
+ * Runs against a `local-cluster` Spark deployment, so executors run in JVM processes
+ * separate from the driver.
+ */
+class KafkaRealTimeIntegrationSuite
+  extends KafkaSourceTest
+    with StreamRealTimeModeSuiteBase
+    with ThreadAudit
+    with BeforeAndAfterEach
+    with Matchers {
+
+  override protected def createSparkSession =
+    new TestSparkSession(
+      new SparkContext(
+        "local-cluster[3, 5, 1024]", // Ensure we have enough for both stages.
+        "microbatch-context",
+        sparkConf
+          .set("spark.sql.testkey", "true")
+          .set("spark.scheduler.mode", "FAIR")
+          .set("spark.executor.extraJavaOptions", "-Dio.netty.leakDetection.level=paranoid")
+      )
+    )
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Smoke-test the Kafka broker before running any query.
+    testUtils.createTopic("_test")
+    testUtils.sendMessage(new ProducerRecord[String, String]("_test", "", ""))
+    testUtils.deleteTopic("_test")
+    logInfo("Kafka cluster setup complete....")
+
+    // local-cluster[3, ...] requests three executors; wait until all of them have registered.
+    eventually(timeout(10.seconds)) {
+      val executors = sparkContext.getExecutorIds()
+      assert(executors.size == 3, s"executors: ${executors}")
+    }
+  }
+
+  test("e2e stateless") {
+    var query: StreamingQuery = null
+    var producer: Producer[String, String] = null
+    try {
+      val inputTopic = newTopic()
+      testUtils.createTopic(inputTopic, partitions = 5)
+
+      val outputTopic = newTopic()
+      testUtils.createTopic(outputTopic, partitions = 5)
+
+      val props: Properties = new Properties()
+      props.put("bootstrap.servers", testUtils.brokerAddress)
+      props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
+      props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")
+      props.put("compression.type", "snappy")
+
+      producer = new KafkaProducer[String, String](props)
+
+      // Use the full path of the temp directory. `getName` returns only the last path
+      // component, which would put the checkpoint in a relative directory under the CWD.
+      val checkpointDir = Files.createTempDirectory("some-prefix").toFile.getCanonicalPath
+
+      query = spark.readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("startingOffsets", "earliest")
+        .option("subscribe", inputTopic)
+        .option("kafka.fetch.max.wait.ms", 10)
+        .load()
+        .withColumn("value", substring(col("value"), 0, 500 * 1000))
+        .withColumn("value", base64(col("value")))
+        .withColumn(
+          "headers",
+          array(
+            struct(
+              lit("source-timestamp") as "key",
+              unix_millis(col("timestamp")).cast("STRING").cast("BINARY") as "value"
+            )
+          )
+        )
+        .drop(col("timestamp"))
+        .writeStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("topic", outputTopic)
+        .option("checkpointLocation", checkpointDir)
+        .option("kafka.max.block.ms", "100")
+        .trigger(RealTimeTrigger.apply("5 minutes"))
+        .outputMode(OutputMode.Update())
+        .start()
+
+      // One long-running real-time task per input partition.
+      waitForTasksToStart(5)
+
+      var expectedResults: ListBuffer[GenericRowWithSchema] = new ListBuffer
+      for (_ <- 0 until 3) {
+        (1 to 100).foreach(i => {
+          producer
+            .send(
+              new ProducerRecord[String, String](
+                inputTopic,
+                java.lang.Long.toString(i),
+                s"payload-${i}"
+              )
+            )
+            .get()
+        })
+
+        producer.flush()
+
+        val kafkaSinkData = spark.read
+          .format("kafka")
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("subscribe", outputTopic)
+          .option("includeHeaders", "true")
+          .option("startingOffsets", "earliest")
+          .load()
+          .withColumn("value", unbase64(col("value")).cast("STRING"))
+          .withColumn("headers-map", map_from_entries(col("headers")))
+          .withColumn("source-timestamp", conv(hex(col("headers-map.source-timestamp")), 16, 10))
+          .withColumn("sink-timestamp", unix_millis(col("timestamp")))
+
+        // Check the answers
+        val newResults = (1 to 100)
+          .map(v => {
+            new GenericRowWithSchema(
+              Array(s"payload-${v}"),
+              schema = new StructType().add(StructField("value", StringType))
+            )
+          })
+          .toList
+
+        expectedResults ++= newResults
+        expectedResults =
+          expectedResults.sorted((x: GenericRowWithSchema, y: GenericRowWithSchema) => {
+            x.getString(0).compareTo(y.getString(0))
+          })
+
+        eventually(timeout(1.minute)) {
+          checkAnswer(kafkaSinkData.select("value"), expectedResults.toSeq)
+        }
+      }
+    } finally {
+      if (query != null) {
+        query.stop()
+      }
+      // The KafkaProducer owns a background I/O thread; close it so it is not left running
+      // after the test (this suite mixes in ThreadAudit).
+      if (producer != null) {
+        producer.close()
+      }
+    }
+  }
+}
diff --git 
a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
new file mode 100644
index 000000000000..83aae64d84f7
--- /dev/null
+++ 
b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkIllegalStateException}
+import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
+import org.apache.spark.sql.streaming.OutputMode.Update
+import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.SystemClock
+
+class KafkaRealTimeModeSuite
+  extends KafkaSourceTest
+    with Matchers {
+
+  override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
+
+  override protected def sparkConf: SparkConf = {
+    // Should turn to use StreamingShuffleManager when it is ready.
+    // NOTE(review): the "spark.databricks.*" key below looks like a leftover from another
+    // build; confirm it is still read by Apache Spark before relying on it.
+    super.sparkConf
+      .set("spark.databricks.streaming.realTimeMode.enabled", "true")
+      .set(
+        SQLConf.STATE_STORE_PROVIDER_CLASS,
+        classOf[RocksDBStateStoreProvider].getName)
+  }
+
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[8]", // Ensure enough number of cores to ensure concurrent schedule of all tasks.
+      "streaming-rtm-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
+  import testImplicits._
+
+  /** Test action that blocks the test thread for one second of wall-clock time. */
+  val sleepOneSec = new ExternalAction() {
+    override def runAction(): Unit = {
+      Thread.sleep(1000)
+    }
+  }
+
+  // Manual clock that individual tests install via LowLatencyClock.setClock to make query
+  // timing deterministic. It is never reassigned.
+  val clock = new GlobalSingletonManualClock()
+
+  /** Test action that advances [[clock]] by `timeMs` milliseconds. */
+  private def advanceRealTimeClock(timeMs: Int) = new ExternalAction {
+    override def runAction(): Unit = {
+      clock.advance(timeMs)
+    }
+
+    override def toString(): String = {
+      s"advanceRealTimeClock($timeMs)"
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Set the minimum real-time batch duration to this suite's trigger interval, presumably so
+    // the short "3 seconds" trigger is accepted.
+    spark.conf.set(
+      SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION,
+      defaultTrigger.batchDurationMs
+    )
+  }
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    GlobalSingletonManualClock.reset()
+  }
+
+  override def afterEach(): Unit = {
+    // Undo any LowLatencyClock.setClock(clock) done by a test so later tests use a real clock.
+    LowLatencyClock.setClock(new SystemClock)
+    super.afterEach()
+  }
+
+  /**
+   * Blocks (for up to 60 seconds) until batch `batchId` has either started -- at least one task
+   * is running and progress has been reported for batch `batchId - 1` or later -- or has already
+   * been processed, i.e. progress has been reported for `batchId` or later.
+   */
+  def waitUntilBatchStartedOrProcessed(q: StreamingQuery, batchId: Long): Unit = {
+    eventually(timeout(60.seconds)) {
+      val tasksRunning =
+        spark.sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum
+      // -1 until the query has reported any progress at all.
+      val lastBatch = Option(q.lastProgress).map(_.batchId).getOrElse(-1L)
+      val batchStarted = tasksRunning >= 1 && lastBatch >= batchId - 1
+      val batchProcessed = lastBatch >= batchId
+      assert(batchStarted || batchProcessed,
+        s"tasksRunning: ${tasksRunning} lastBatch: ${lastBatch}")
+    }
+  }
+
+  // A simple unit test that reads from Kafka source, does a simple map and writes to memory
+  // sink.
+  test("simple map") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(reader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4),
+      sleepOneSec,
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("7"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("8"), Some(0))
+          testUtils.sendMessages(topic, Array("9"), Some(1))
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      WaitUntilCurrentBatchProcessed)
+  }
+
+  // A simple unit test that reads from Kafka source, does a simple map and writes to memory
+  // sink. Make sure there is no data for a whole batch. Also, after restart the first batch
+  // has no data.
+  test("simple map with empty batch") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    val reader = spark.readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(reader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      WaitUntilBatchProcessed(0),
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("1"), Some(0))
+          testUtils.sendMessages(topic, Array("2"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3),
+      WaitUntilCurrentBatchProcessed,
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("3"), Some(1))
+        }
+      },
+      WaitUntilCurrentBatchProcessed,
+      CheckAnswerWithTimeout(5000, 2, 3, 4),
+      StopStream,
+      StartStream(),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4"), Some(0))
+          testUtils.sendMessages(topic, Array("5"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6),
+      WaitUntilCurrentBatchProcessed
+    )
+  }
+
+  // Adds partitions to the Kafka topic while the query is running and checks that records
+  // written to the new partitions are read, both live and after a restart.
+  test("add partition") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+
+    testUtils.sendMessages(topic, Array("1", "2", "3"), Some(0))
+
+    val reader = spark.readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(reader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4),
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.addPartitions(topic, 2)
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.addPartitions(topic, 4)
+          testUtils.sendMessages(topic, Array("7"), Some(2))
+        }
+      },
+      CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("8"), Some(3))
+          testUtils.sendMessages(topic, Array("9"), Some(2))
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+      WaitUntilCurrentBatchProcessed
+    )
+  }
+
+  test("Real-Time Mode fetches latestOffset again at end of the batch") {
+    // LowLatencyClock does not affect the wait time of kafka iterator, so advancing the clock
+    // does not affect the test finish time. The purpose of using it is to make the query start
+    // time consistent, so the test behaves the same.
+    LowLatencyClock.setClock(clock)
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+
+    val reader = spark.readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      // extra large number to make sure fetch does
+      // not return within batch duration
+      .option("kafka.fetch.max.wait.ms", "20000000")
+      .option("kafka.fetch.min.bytes", "20000000")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(reader, Update, sink = new ContinuousMemorySink())(
+      StartStream(Trigger.RealTime(10000)),
+      advanceRealTimeClock(2000),
+      Execute { q =>
+        waitUntilBatchStartedOrProcessed(q, 0)
+        testUtils.sendMessages(topic, Array("1"), Some(0))
+      },
+      advanceRealTimeClock(8000),
+      WaitUntilBatchProcessed(0),
+      Execute { q =>
+        val expectedMetrics = Map(
+          "minOffsetsBehindLatest" -> "1",
+          "maxOffsetsBehindLatest" -> "1",
+          "avgOffsetsBehindLatest" -> "1.0",
+          "estimatedTotalBytesBehindLatest" -> null
+        )
+        eventually(timeout(60.seconds)) {
+          expectedMetrics.foreach { case (metric, expectedValue) =>
+            assert(q.lastProgress.sources(0).metrics.get(metric) === expectedValue)
+          }
+        }
+      },
+      advanceRealTimeClock(2000),
+      Execute { q =>
+        waitUntilBatchStartedOrProcessed(q, 1)
+        testUtils.sendMessages(topic, Array("2", "3"), Some(0))
+      },
+      advanceRealTimeClock(8000),
+      WaitUntilBatchProcessed(1),
+      Execute { q =>
+        val expectedMetrics = Map(
+          "minOffsetsBehindLatest" -> "3",
+          "maxOffsetsBehindLatest" -> "3",
+          "avgOffsetsBehindLatest" -> "3.0",
+          "estimatedTotalBytesBehindLatest" -> null
+        )
+        eventually(timeout(60.seconds)) {
+          expectedMetrics.foreach { case (metric, expectedValue) =>
+            assert(q.lastProgress.sources(0).metrics.get(metric) === expectedValue)
+          }
+        }
+      }
+    )
+  }
+
+  // Each of these options is incompatible with Real-Time Mode: starting a query with any one of
+  // them set must fail with an UnsupportedOperationException whose message names the option.
+  Seq(
+    "maxoffsetspertrigger",
+    "minoffsetspertrigger",
+    "minpartitions",
+    "endingtimestamp",
+    "maxtriggerdelay").foreach { opt =>
+    test(s"$opt incompatible") {
+      val topic = newTopic()
+      testUtils.createTopic(topic, partitions = 2)
+
+      val reader = spark.readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("subscribe", topic)
+        .option("startingOffsets", "earliest")
+        .option(opt, "5")
+        .load()
+      testStream(reader, Update, sink = new ContinuousMemorySink())(
+        StartStream(),
+        ExpectFailure[UnsupportedOperationException] { t =>
+          // Locale.ROOT keeps the comparison independent of the JVM's default locale.
+          assert(t.getMessage.toLowerCase(java.util.Locale.ROOT).contains(opt))
+        }
+      )
+    }
+  }
+
+  test("union 2 dataframes after projection") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    val topic1 = newTopic()
+    testUtils.createTopic(topic1, partitions = 2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+
+    testUtils.sendMessages(topic1, Array("11", "12"), Some(0))
+    testUtils.sendMessages(topic1, Array("13"), Some(1))
+
+    val reader1 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    val reader2 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic1)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    val unionedReader = reader1.union(reader2)
+
+    testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14),
+      sleepOneSec,
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+          testUtils.sendMessages(topic1, Array("14", "15"), Some(0))
+          testUtils.sendMessages(topic1, Array("16"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("7"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17, 8),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("8"), Some(0))
+          testUtils.sendMessages(topic, Array("9"), Some(1))
+          testUtils.sendMessages(topic1, Array("19"), Some(1))
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17, 8, 9, 10, 20),
+      WaitUntilCurrentBatchProcessed)
+  }
+
+  test("union 3 dataframes with and without maxPartitions") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    val topic1 = newTopic()
+    testUtils.createTopic(topic1, partitions = 2)
+
+    val topic2 = newTopic()
+    testUtils.createTopic(topic2, partitions = 2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+
+    testUtils.sendMessages(topic1, Array("11", "12"), Some(0))
+    testUtils.sendMessages(topic1, Array("13"), Some(1))
+
+    testUtils.sendMessages(topic2, Array("21", "22"), Some(0))
+    testUtils.sendMessages(topic2, Array("23"), Some(1))
+
+    val reader1 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .option("maxPartitions", "1")
+      .load()
+
+    val reader2 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic1)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val reader3 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic2)
+      .option("startingOffsets", "earliest")
+      .option("maxPartitions", "3")
+      .load()
+
+    val unionedReader = reader1.union(reader2).union(reader3)
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(10000, 2, 3, 4, 12, 13, 14, 22, 23, 24),
+      sleepOneSec,
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4"), Some(0))
+          testUtils.sendMessages(topic, Array("5"), Some(1))
+          testUtils.sendMessages(topic1, Array("14", "15"), Some(0))
+          testUtils.sendMessages(topic2, Array("24"), Some(0))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+          testUtils.sendMessages(topic2, Array("25"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25, 7, 26),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic1, Array("16"), Some(1))
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25, 7, 26, 17),
+      WaitUntilCurrentBatchProcessed)
+  }
+
+  // Reading the same topic through two independently created readers and unioning them is
+  // accepted; every record therefore appears twice in the output.
+  test("self union workaround") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+
+    val reader1 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val reader2 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val unionedReader = reader1.union(reader2)
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4, 2, 3, 4),
+      sleepOneSec,
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("7"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("8"), Some(0))
+          testUtils.sendMessages(topic, Array("9"), Some(1))
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8, 9, 10, 9, 10),
+      WaitUntilCurrentBatchProcessed)
+  }
+
+  test("union 2 different sources - Kafka and LowLatencyMemoryStream") {
+    // testImplicits._ is already imported at class level.
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    val memoryStreamRead = LowLatencyMemoryStream[String](2)
+
+    testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+    testUtils.sendMessages(topic, Array("3"), Some(1))
+    memoryStreamRead.addData("11", "12", "13")
+
+    val reader1 = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+
+    val reader2 = memoryStreamRead.toDF()
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+
+    val unionedReader = reader1.union(reader2)
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14),
+      sleepOneSec,
+      sleepOneSec,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+          testUtils.sendMessages(topic, Array("6"), Some(1))
+          memoryStreamRead.addData("14")
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15),
+      WaitUntilCurrentBatchProcessed,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("7"), Some(1))
+        }
+      },
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8),
+      WaitUntilCurrentBatchProcessed,
+      StopStream,
+      new ExternalAction() {
+        override def runAction(): Unit = {
+          testUtils.sendMessages(topic, Array("8"), Some(0))
+          testUtils.sendMessages(topic, Array("9"), Some(1))
+          memoryStreamRead.addData("15", "16", "17")
+        }
+      },
+      StartStream(),
+      CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8, 9, 10, 16, 17, 18),
+      WaitUntilCurrentBatchProcessed)
+  }
+
+  // Unioning a DataFrame with itself (the identical source on both sides) must be rejected
+  // with STREAMING_REAL_TIME_MODE.IDENTICAL_SOURCES_IN_UNION_NOT_SUPPORTED.
+  test("self union - not allowed") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 2)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val unionedReader = reader.union(reader)
+      .selectExpr("CAST(value AS STRING)")
+      .as[String]
+      .map(_.toInt)
+      .map(_ + 1)
+
+    testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+      StartStream(),
+      ExpectFailure[SparkIllegalStateException] { ex =>
+        checkErrorMatchPVals(
+          ex.asInstanceOf[SparkIllegalStateException],
+          "STREAMING_REAL_TIME_MODE.IDENTICAL_SOURCES_IN_UNION_NOT_SUPPORTED",
+          parameters =
+            Map("sources" -> "(?s).*")
+        )
+      }
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
index 5bb01bdea26e..9199580f6587 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
@@ -17,11 +17,18 @@
 
 package org.apache.spark.sql.streaming
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+
+import scala.collection.mutable
+
 import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.ForeachWriter
 import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
-import org.apache.spark.sql.execution.streaming.RealTimeTrigger
+import org.apache.spark.sql.execution.streaming.{LowLatencyMemoryStream, 
RealTimeTrigger}
+import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
 import org.apache.spark.sql.test.TestSparkSession
@@ -43,6 +50,124 @@ trait StreamRealTimeModeSuiteBase extends StreamTest with Matchers {
       "local[10]", // Ensure enough number of cores to ensure concurrent 
schedule of all tasks.
       "streaming-rtm-context",
       sparkConf.set("spark.sql.testkey", "true")))
+
+  /**
+   * Blocks for up to 60 seconds until exactly `numTasks` tasks are running across all
+   * executors. Use this only in real-time mode, where each batch lasts long enough that the
+   * polling interval of `eventually` cannot step over the period in which the tasks run.
+   */
+  def waitForTasksToStart(numTasks: Int): Unit = {
+    eventually(timeout(60.seconds)) {
+      val executorInfos = spark.sparkContext.statusTracker.getExecutorInfos
+      val tasksRunning = executorInfos.iterator.map(_.numRunningTasks()).sum
+      assert(tasksRunning == numTasks, s"tasksRunning: ${tasksRunning}")
+    }
+  }
+}
+
+/**
+ * Shared store of rows emitted by test queries: maps a sink name to the queue of values written
+ * under that name. It is a singleton object so that it stays serializable when used from a
+ * ForeachWriter. Test suites sharing it must use distinct sink names to avoid racing with each
+ * other; note that `reset` drops the results of every sink name, not just one.
+ */
+object ResultsCollector extends ConcurrentHashMap[String, ConcurrentLinkedQueue[String]] {
+  /** Removes all collected results for all sink names. */
+  def reset(): Unit = this.clear()
+}
+
+/**
+ * Base class that contains helper methods to test Real-Time Mode streaming queries.
+ *
+ * The general procedure to use this suite is as follows:
+ * 1. Call createMemoryStream to create a memory stream with manual clock.
+ * 2. Call runStreamingQuery to start a streaming query with custom logic.
+ * 3. Call processBatches to add data to the memory stream and validate results.
+ *
+ * It uses foreach to collect results into [[ResultsCollector]]. It also tests whether
+ * results are emitted in real-time by having longer batch durations than the waiting time.
+ */
+trait StreamRealTimeModeE2ESuiteBase extends StreamRealTimeModeSuiteBase {
+  import testImplicits._
+
+  override protected val defaultTrigger = RealTimeTrigger.apply("300 seconds")
+
+  /** Key under which this suite's query output is stored in [[ResultsCollector]]. */
+  protected final def sinkName: String = getClass.getName + "Sink"
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    ResultsCollector.reset()
+  }
+
+  /**
+   * Creates a ForeachWriter that appends every value it receives to the queue stored under
+   * `sinkName` in [[ResultsCollector]], creating that queue on first use.
+   * (The parameter intentionally shadows the `sinkName` member.)
+   */
+  def foreachWriter(sinkName: String): ForeachWriter[String] = new ForeachWriter[String] {
+    override def open(partitionId: Long, epochId: Long): Boolean = {
+      true
+    }
+
+    override def process(value: String): Unit = {
+      val collector =
+        ResultsCollector.computeIfAbsent(sinkName, (_) => new ConcurrentLinkedQueue[String]())
+      collector.add(value)
+    }
+
+    override def close(errorOrNull: Throwable): Unit = {}
+  }
+
+  /**
+   * Installs a fresh manual clock as the LowLatencyClock and creates a memory stream with
+   * `numPartitions` partitions.
+   *
+   * @return the memory stream and the manual clock that is now the LowLatencyClock
+   */
+  def createMemoryStream(numPartitions: Int = 5)
+      : (LowLatencyMemoryStream[(String, Int)], GlobalSingletonManualClock) = {
+    val clock = new GlobalSingletonManualClock()
+    LowLatencyClock.setClock(clock)
+    val read = LowLatencyMemoryStream[(String, Int)](numPartitions)
+    (read, clock)
+  }
+
+  /**
+   * Starts `df`, viewed as a Dataset of strings, as an update-mode query named `queryName`
+   * with [[defaultTrigger]], writing its output into this suite's [[ResultsCollector]] sink.
+   */
+  def runStreamingQuery(queryName: String, df: org.apache.spark.sql.DataFrame): StreamingQuery = {
+    df.as[String]
+      .writeStream
+      .outputMode(OutputMode.Update())
+      .foreach(foreachWriter(sinkName))
+      .queryName(queryName)
+      .trigger(defaultTrigger)
+      .start()
+  }
+
+  /**
+   * Add test data to the memory source and validate results.
+   *
+   * For each of `numBatches` batches: adds `numRowsPerBatch` rows for each key ("a", "b", "c"),
+   * with values continuing across batches; waits until the sink holds exactly the rows produced
+   * so far by `expectedResultsGenerator` (checked before the clock moves, so output must be
+   * emitted in real time); then advances `clock` by one batch duration and waits until the
+   * query's latest execution context is the next batch and the last progress reports every
+   * row that was added.
+   */
+  def processBatches(
+      query: StreamingQuery,
+      read: LowLatencyMemoryStream[(String, Int)],
+      clock: GlobalSingletonManualClock,
+      numRowsPerBatch: Int,
+      numBatches: Int,
+      expectedResultsGenerator: (String, Int) => Array[String]): Unit = {
+    val keys = List("a", "b", "c")
+    val expectedResults = mutable.ListBuffer[String]()
+    for (i <- 0 until numBatches) {
+      for (key <- keys) {
+        for (j <- 1 to numRowsPerBatch) {
+          val value = i * numRowsPerBatch + j
+          read.addData((key, value))
+          expectedResults ++= expectedResultsGenerator(key, value)
+        }
+      }
+
+      eventually(timeout(60.seconds)) {
+        // The writer creates the queue lazily, so it is absent until the first row arrives;
+        // treat that as "nothing collected yet" rather than failing with an NPE. The queue is
+        // looked up once so the snapshot is taken from a single, consistent reference.
+        val collected = Option(ResultsCollector.get(sinkName))
+          .map(_.toArray(new Array[String](0)).toList)
+          .getOrElse(Nil)
+        collected.sorted should equal(expectedResults.sorted)
+      }
+
+      clock.advance(defaultTrigger.batchDurationMs)
+
+      eventually(timeout(60.seconds)) {
+        query
+          .asInstanceOf[StreamingQueryWrapper]
+          .streamingQuery
+          .getLatestExecutionContext()
+          .batchId should be(i + 1)
+        query.lastProgress.sources(0).numInputRows should be(numRowsPerBatch * keys.size)
+      }
+    }
+  }
 }
 
 abstract class StreamRealTimeModeManualClockSuiteBase extends 
StreamRealTimeModeSuiteBase {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 3ff8cab64d65..e57c4e1e665c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -284,6 +284,8 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
 
   case class WaitUntilBatchProcessed(batchId: Long) extends StreamAction with 
StreamMustBeRunning
 
+  /**
+   * Waits until the commit log records a batch id greater than the latest committed batch id
+   * seen when this action started (-1 if none), and the most recent progress report has a
+   * batch id at least that large. Fails immediately if the stream already has an exception.
+   */
+  case object WaitUntilCurrentBatchProcessed extends StreamAction with 
StreamMustBeRunning
+
   /**
    * Signals that a failure is expected and should not kill the test.
    *
@@ -659,6 +661,20 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
             throw currentStream.exception.get
           }
 
+        case WaitUntilCurrentBatchProcessed =>
+          if (currentStream.exception.isDefined) {
+            throw currentStream.exception.get
+          }
+          val currBatch = 
currentStream.commitLog.getLatestBatchId().getOrElse(-1L)
+          eventually("Current batch never finishes") {
+            assert(currentStream.commitLog.getLatestBatchId() != None
+              && currentStream.commitLog.getLatestBatchId().get > currBatch)
+
+            // See WaitUntilBatchProcessed for an explanation of why we wait 
for the progress
+            val latestProgressBatchId =
+              
currentStream.recentProgress.lastOption.map(_.batchId).getOrElse(-1L)
+            assert(latestProgressBatchId >= currBatch)
+          }
         case StopStream =>
           verify(currentStream != null, "can not stop a stream that is not 
running")
           try failAfter(streamingTimeout) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to