This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 928f25386c35 [SPARK-54027] Kafka Source RTM support
928f25386c35 is described below
commit 928f25386c356ce38a44fb49918b26219ba6b671
Author: Jerry Peng <[email protected]>
AuthorDate: Fri Oct 31 23:01:28 2025 -0700
[SPARK-54027] Kafka Source RTM support
### What changes were proposed in this pull request?
Add support for Real-Time Mode in the Kafka source. This means KafkaMicroBatchStream
needs to implement the SupportsRealTimeMode interface and KafkaBatchPartitionReader
needs to extend the SupportsRealTimeRead interface.
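Illustrative sketch (not part of this patch): the shape of the timeout-based read loop that SupportsRealTimeRead introduces. The reader waits up to a timeout for the next record instead of stopping at a fixed end offset, and tracks the offset to resume from. `RealTimeReadSketch`, `RealTimeReader`, `Record`, and `Status` below are local stand-ins, not Spark APIs.
```scala
object RealTimeReadSketch {
  sealed trait Status
  final case class GotRecord(arrivalTimeMs: Option[Long]) extends Status
  case object TimedOut extends Status

  final case class Record(offset: Long, timestampMs: Long)

  // On each call, wait up to `timeoutMs` for the next record instead of stopping at a
  // fixed end offset, and remember the offset to resume from.
  final class RealTimeReader(poll: Long => Option[Record]) {
    private var nextOffset: Long = 0L

    def nextWithTimeout(timeoutMs: Long): Status =
      poll(timeoutMs) match {
        case Some(rec) =>
          nextOffset = rec.offset + 1        // advance the resume point
          GotRecord(Some(rec.timestampMs))   // surface the record's arrival time
        case None =>
          TimedOut                           // no data within the timeout
      }

    // The per-partition offset the driver collects and merges into the batch's end offset.
    def currentOffset: Long = nextOffset
  }

  def main(args: Array[String]): Unit = {
    val records = Iterator(Record(0L, 1000L), Record(1L, 1005L))
    val reader = new RealTimeReader(_ => if (records.hasNext) Some(records.next()) else None)
    println(reader.nextWithTimeout(100)) // GotRecord(Some(1000))
    println(reader.nextWithTimeout(100)) // GotRecord(Some(1005))
    println(reader.nextWithTimeout(100)) // TimedOut
    println(reader.currentOffset)        // 2
  }
}
```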
### Why are the changes needed?
So that the Kafka source and sink can be used by Real-Time Mode queries.
### Does this PR introduce _any_ user-facing change?
Yes, the Kafka source and sink can now be used by Real-Time Mode queries.
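For illustration, a query would look roughly like the end-to-end test added in this patch (`brokers`, `inputTopic`, `outputTopic`, and `checkpointDir` are placeholders):
```scala
import org.apache.spark.sql.execution.streaming.RealTimeTrigger
import org.apache.spark.sql.streaming.OutputMode

val query = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", brokers)
  .option("subscribe", inputTopic)
  .load()
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", brokers)
  .option("topic", outputTopic)
  .option("checkpointLocation", checkpointDir)
  .trigger(RealTimeTrigger.apply("5 minutes")) // run in Real-Time Mode
  .outputMode(OutputMode.Update())
  .start()
```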
### How was this patch tested?
New tests were added, including KafkaRealTimeModeSuite and KafkaRealTimeIntegrationSuite.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52729 from jerrypeng/SPARK-54027-int.
Authored-by: Jerry Peng <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
.../java/org/apache/spark/internal/LogKeys.java | 1 +
.../sql/kafka010/KafkaBatchPartitionReader.scala | 49 +-
.../spark/sql/kafka010/KafkaMicroBatchStream.scala | 140 ++++-
.../sql/kafka010/consumer/KafkaDataConsumer.scala | 98 ++-
.../sql/kafka010/KafkaMicroBatchSourceSuite.scala | 6 +-
.../kafka010/KafkaRealTimeIntegrationSuite.scala | 293 +++++++++
.../sql/kafka010/KafkaRealTimeModeSuite.scala | 681 +++++++++++++++++++++
.../streaming/StreamRealTimeModeSuiteBase.scala | 127 +++-
.../apache/spark/sql/streaming/StreamTest.scala | 16 +
9 files changed, 1396 insertions(+), 15 deletions(-)
diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index 7486f4ac4fab..8b6d3614b86d 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -824,6 +824,7 @@ public enum LogKeys implements LogKey {
TIMEOUT,
TIMER,
TIMESTAMP,
+ TIMESTAMP_COLUMN_NAME,
TIME_UNITS,
TIP,
TOKEN,
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
index 02568aa89eb1..9fcdf1a7d9bf 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala
@@ -19,15 +19,19 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
+import org.apache.kafka.common.record.TimestampType
+
import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead
+import
org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution,
StreamExecution}
-import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
+import org.apache.spark.sql.kafka010.consumer.{KafkaDataConsumer,
KafkaDataConsumerIterator}
/** A [[InputPartition]] for reading Kafka data in a batch based streaming
query. */
private[kafka010] case class KafkaBatchInputPartition(
@@ -67,7 +71,8 @@ private case class KafkaBatchPartitionReader(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
- includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging
{
+ includeHeaders: Boolean)
+ extends SupportsRealTimeRead[InternalRow] with Logging {
private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition,
executorKafkaParams)
@@ -77,6 +82,12 @@ private case class KafkaBatchPartitionReader(
private var nextOffset = rangeToRead.fromOffset
private var nextRow: UnsafeRow = _
+ private var iteratorForRealTimeMode: Option[KafkaDataConsumerIterator] = None
+
+ // Boolean flag that indicates whether we have logged the type of timestamp (e.g. create time,
+ // log-append time) for the Kafka source. We log it upon reading the first record and skip
+ // logging for subsequent records.
+ private var timestampTypeLogged = false
override def next(): Boolean = {
if (nextOffset < rangeToRead.untilOffset) {
@@ -93,6 +104,38 @@ private case class KafkaBatchPartitionReader(
}
}
+ override def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus = {
+ if (!iteratorForRealTimeMode.isDefined) {
+ logInfo(s"Getting a new kafka consuming iterator for
${offsetRange.topicPartition} " +
+ s"starting from ${nextOffset}, timeoutMs ${timeoutMs}")
+ iteratorForRealTimeMode = Some(consumer.getIterator(nextOffset))
+ }
+ assert(iteratorForRealTimeMode.isDefined)
+ val nextRecord = iteratorForRealTimeMode.get.nextWithTimeout(timeoutMs)
+ nextRecord.foreach { record =>
+
+ nextRow = unsafeRowProjector(record)
+ nextOffset = record.offset + 1
+ if (record.timestampType() == TimestampType.LOG_APPEND_TIME ||
+ record.timestampType() == TimestampType.CREATE_TIME) {
+ if (!timestampTypeLogged) {
+ logInfo(log"Kafka source record timestamp type is " +
+ log"${MDC(LogKeys.TIMESTAMP_COLUMN_NAME, record.timestampType())}")
+ timestampTypeLogged = true
+ }
+
+ RecordStatus.newStatusWithArrivalTimeMs(record.timestamp())
+ } else {
+ RecordStatus.newStatusWithoutArrivalTime(true)
+ }
+ }
+ RecordStatus.newStatusWithoutArrivalTime(nextRecord.isDefined)
+ }
+
+ override def getOffset(): KafkaSourcePartitionOffset = {
+ KafkaSourcePartitionOffset(offsetRange.topicPartition, nextOffset)
+ }
+
override def get(): UnsafeRow = {
assert(nextRow != null)
nextRow
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
index 7449e9123033..828891f0b498 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP}
+import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP,
TOPIC_PARTITION_OFFSET}
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{InputPartition,
PartitionReaderFactory}
@@ -60,7 +60,11 @@ private[kafka010] class KafkaMicroBatchStream(
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
- extends SupportsTriggerAvailableNow with ReportsSourceMetrics with
MicroBatchStream with Logging {
+ extends SupportsTriggerAvailableNow
+ with SupportsRealTimeMode
+ with ReportsSourceMetrics
+ with MicroBatchStream
+ with Logging {
private[kafka010] val pollTimeoutMs = options.getLong(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
@@ -93,6 +97,11 @@ private[kafka010] class KafkaMicroBatchStream(
private var isTriggerAvailableNow: Boolean = false
+ private var inRealTimeMode = false
+ override def prepareForRealTimeMode(): Unit = {
+ inRealTimeMode = true
+ }
+
/**
* Lazily initialize `initialPartitionOffsets` to make sure that
`KafkaConsumer.poll` is only
* called in StreamExecutionThread. Otherwise, interrupting a thread while
running
@@ -218,6 +227,107 @@ private[kafka010] class KafkaMicroBatchStream(
}.toArray
}
+ override def planInputPartitions(start: Offset): Array[InputPartition] = {
+ // This function is used for Real-Time Mode. Trigger rate-limit options are not supported.
+ if (maxOffsetsPerTrigger.isDefined) {
+ throw new UnsupportedOperationException(
+ "maxOffsetsPerTrigger is not compatible with real time mode")
+ }
+ if (minOffsetPerTrigger.isDefined) {
+ throw new UnsupportedOperationException(
+ "minOffsetsPerTrigger is not compatible with real time mode"
+ )
+ }
+ if (options.containsKey(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) {
+ throw new UnsupportedOperationException(
+ "minpartitions is not compatible with real time mode"
+ )
+ }
+ if (options.containsKey(KafkaSourceProvider.ENDING_TIMESTAMP_OPTION_KEY)) {
+ throw new UnsupportedOperationException(
+ "endingtimestamp is not compatible with real time mode"
+ )
+ }
+ if (options.containsKey(KafkaSourceProvider.MAX_TRIGGER_DELAY)) {
+ throw new UnsupportedOperationException(
+ "maxtriggerdelay is not compatible with real time mode"
+ )
+ }
+
+ // This function is used by Real-Time Mode, where we expect a 1:1 mapping between a
+ // topic partition and an input partition.
+ // We skip the partition range check for performance reasons. We can always do it in the
+ // tasks if needed.
+ val startPartitionOffsets =
start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+
+ // Here we compare the previous topic partitions with the latest partition offsets to see
+ // if we need to update the partition list. The updated topic partition list does not need
+ // to be absolutely up to date, because there may already be minutes of delay since a new
+ // partition was created, and latestPartitionOffsets should have been fetched recently anyway.
+ // If the topic partitions change, we fetch the earliest offsets for all new partitions
+ // and add them to the list.
+ assert(latestPartitionOffsets != null, "latestPartitionOffsets should be
set in latestOffset")
+ val latestTopicPartitions = latestPartitionOffsets.keySet
+ val newStartPartitionOffsets = if (startPartitionOffsets.keySet ==
latestTopicPartitions) {
+ startPartitionOffsets
+ } else {
+ val newPartitions =
latestTopicPartitions.diff(startPartitionOffsets.keySet)
+ // Instead of fetching the earliest offsets, we could fill in offset 0 here and avoid this
+ // extra admin call. But new partitions are rare, and fetching the earliest offsets aligns
+ // with what we do in micro-batch mode and can potentially enable more sanity checks on
+ // the executor side.
+ val newPartitionOffsets =
kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+ assert(
+ newPartitionOffsets.keys.forall(!startPartitionOffsets.contains(_)),
+ "startPartitionOffsets should not contain any key in
newPartitionOffsets")
+
+ logInfo(log"Partitions added: ${MDC(TOPIC_PARTITION_OFFSET,
newPartitionOffsets)}")
+ // Filter out new partition offsets that are not 0 and log a warning
+ val nonZeroNewPartitionOffsets = newPartitionOffsets.filter {
+ case (_, offset) => offset != 0
+ }
+ // Log the non-zero new partition offsets
+ if (nonZeroNewPartitionOffsets.nonEmpty) {
+ logWarning(log"new partitions should start from offset 0: " +
+ log"${MDC(OFFSETS, nonZeroNewPartitionOffsets)}")
+ nonZeroNewPartitionOffsets.foreach {
+ case (p, o) =>
+ reportDataLoss(
+ s"Added partition $p starts from $o instead of 0. Some data may
have been missed",
+ () => KafkaExceptions.addedPartitionDoesNotStartFromZero(p, o))
+ }
+ }
+
+ val deletedPartitions =
startPartitionOffsets.keySet.diff(latestTopicPartitions)
+ if (deletedPartitions.nonEmpty) {
+ reportDataLoss(
+ s"$deletedPartitions are gone. Some data may have been missed",
+ () =>
+ KafkaExceptions.partitionsDeleted(deletedPartitions, None))
+ }
+
+ startPartitionOffsets ++ newPartitionOffsets
+ }
+
+ newStartPartitionOffsets.keySet.toSeq.map { tp =>
+ val fromOffset = newStartPartitionOffsets(tp)
+ KafkaBatchInputPartition(
+ KafkaOffsetRange(tp, fromOffset, Long.MaxValue, preferredLoc = None),
+ executorKafkaParams,
+ pollTimeoutMs,
+ failOnDataLoss,
+ includeHeaders)
+ }.toArray
+ }
+
+ override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+ val mergedMap = offsets.map {
+ case KafkaSourcePartitionOffset(p, o) => (p, o)
+ }.toMap
+ KafkaSourceOffset(mergedMap)
+ }
+
override def createReaderFactory(): PartitionReaderFactory = {
KafkaBatchReaderFactory
}
@@ -235,7 +345,22 @@ private[kafka010] class KafkaMicroBatchStream(
override def toString(): String = s"KafkaV2[$kafkaOffsetReader]"
override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String,
String] = {
- KafkaMicroBatchStream.metrics(latestConsumedOffset, latestPartitionOffsets)
+ val reCalculatedLatestPartitionOffsets =
+ if (inRealTimeMode) {
+ if (!latestConsumedOffset.isPresent) {
+ // this means a batch has no end offsets, which should not happen
+ None
+ } else {
+ Some {
+ kafkaOffsetReader.fetchLatestOffsets(
+
Some(latestConsumedOffset.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets))
+ }
+ }
+ } else {
+ Some(latestPartitionOffsets)
+ }
+
+ KafkaMicroBatchStream.metrics(latestConsumedOffset,
reCalculatedLatestPartitionOffsets)
}
/**
@@ -386,13 +511,14 @@ object KafkaMicroBatchStream extends Logging {
*/
def metrics(
latestConsumedOffset: Optional[Offset],
- latestAvailablePartitionOffsets: PartitionOffsetMap): ju.Map[String,
String] = {
+ latestAvailablePartitionOffsets: Option[PartitionOffsetMap]):
ju.Map[String, String] = {
val offset = Option(latestConsumedOffset.orElse(null))
- if (offset.nonEmpty && latestAvailablePartitionOffsets != null) {
+ if (offset.nonEmpty && latestAvailablePartitionOffsets.isDefined) {
val consumedPartitionOffsets =
offset.map(KafkaSourceOffset(_)).get.partitionToOffsets
- val offsetsBehindLatest = latestAvailablePartitionOffsets
- .map(partitionOffset => partitionOffset._2 -
consumedPartitionOffsets(partitionOffset._1))
+ val offsetsBehindLatest = latestAvailablePartitionOffsets.get
+ .map(partitionOffset => partitionOffset._2 -
+ consumedPartitionOffsets.getOrElse(partitionOffset._1, 0L))
if (offsetsBehindLatest.nonEmpty) {
val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble /
offsetsBehindLatest.size
return Map[String, String](
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
index 450039ab6800..af4e5bab2947 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
@@ -63,6 +63,13 @@ private[kafka010] class InternalKafkaConsumer(
private[consumer] var kafkaParamsWithSecurity: ju.Map[String, Object] = _
private val consumer = createConsumer()
+ def poll(pollTimeoutMs: Long): ju.List[ConsumerRecord[Array[Byte],
Array[Byte]]] = {
+ val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
+ val r = p.records(topicPartition)
+ logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
+ r
+ }
+
/**
* Poll messages from Kafka starting from `offset` and returns a pair of
"list of consumer record"
* and "offset after poll". The list of consumer record may be empty if the
Kafka consumer fetches
@@ -131,7 +138,7 @@ private[kafka010] class InternalKafkaConsumer(
c
}
- private def seek(offset: Long): Unit = {
+ def seek(offset: Long): Unit = {
logDebug(s"Seeking to $groupId $topicPartition $offset")
consumer.seek(topicPartition, offset)
}
@@ -228,6 +235,19 @@ private[consumer] case class FetchedRecord(
}
}
+/**
+ * This iterator keeps returning the next record. If no new record is available, it keeps
+ * polling until the timeout. It is used by KafkaBatchPartitionReader.nextWithTimeout() to
+ * reduce seeking overhead in Real-Time Mode.
+ */
+private[sql] trait KafkaDataConsumerIterator {
+ /**
+ * Return the next record
+ * @return None if no new record is available after `timeoutMs`.
+ */
+ def nextWithTimeout(timeoutMs: Long): Option[ConsumerRecord[Array[Byte],
Array[Byte]]]
+}
+
/**
* This class helps caller to read from Kafka leveraging consumer pool as well
as fetched data pool.
* This class throws error when data loss is detected while reading from Kafka.
@@ -272,6 +292,82 @@ private[kafka010] class KafkaDataConsumer(
// Starting timestamp when the consumer is created.
private var startTimestampNano: Long = System.nanoTime()
+ /**
+ * Get an iterator that can return the next entry. It is used exclusively for Real-Time
+ * Mode.
+ *
+ * It is called by KafkaBatchPartitionReader.nextWithTimeout(). Unlike get(), there is no
+ * out-of-bounds check in this function. Since no endOffset is given, we assume any record
+ * is valid to return as long as it is at or after `startOffset`.
+ *
+ * @param startOffset the starting position to read from, inclusive.
+ */
+ def getIterator(startOffset: Long): KafkaDataConsumerIterator = {
+ new KafkaDataConsumerIterator {
+ private var fetchedRecordList
+ : Option[ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]]]
= None
+ private val consumer = getOrRetrieveConsumer()
+ private var firstRecord = true
+ private var _currentOffset: Long = startOffset - 1
+
+ private def fetchedRecordListHasNext(): Boolean = {
+ fetchedRecordList.map(_.hasNext).getOrElse(false)
+ }
+
+ override def nextWithTimeout(
+ timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]] =
{
+ var timeLeftMs = timeoutMs
+
+ def timeAndDeductFromTimeLeftMs[T](body: => T): Unit = {
+ // To avoid timing the same operation twice, we reuse the timing results for
+ // totalTimeReadNanos and for timeoutMs.
+ val prevTime = totalTimeReadNanos
+ timeNanos {
+ body
+ }
+ timeLeftMs -= (totalTimeReadNanos - prevTime) / 1000000
+ }
+
+ if (firstRecord) {
+ timeAndDeductFromTimeLeftMs {
+ consumer.seek(startOffset)
+ firstRecord = false
+ }
+ }
+ while (!fetchedRecordListHasNext() && timeLeftMs > 0) {
+ timeAndDeductFromTimeLeftMs {
+ try {
+ val records = consumer.poll(timeLeftMs)
+ numPolls += 1
+ if (!records.isEmpty) {
+ numRecordsPolled += records.size
+ fetchedRecordList = Some(records.listIterator)
+ }
+ } catch {
+ case ex: OffsetOutOfRangeException =>
+ if (_currentOffset != -1) {
+ throw ex
+ } else {
+ Thread.sleep(10) // retry until the source partition is
populated
+ assert(startOffset == 0)
+ consumer.seek(startOffset)
+ }
+ }
+ }
+ }
+ if (fetchedRecordListHasNext()) {
+ totalRecordsRead += 1
+ val nextRecord = fetchedRecordList.get.next()
+ assert(nextRecord.offset > _currentOffset, "Kafka offset should be
incremental.")
+ _currentOffset = nextRecord.offset
+ Some(nextRecord)
+ } else {
+ None
+ }
+ }
+ }
+ }
+
/**
* Get the record for the given offset if available.
*
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 3da5b0c039bb..e619adfce17b 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -1838,20 +1838,20 @@ abstract class KafkaMicroBatchV2SourceSuite extends
KafkaMicroBatchSourceSuiteBa
val latestOffset = Map[TopicPartition, Long]((topicPartition1, 3L),
(topicPartition2, 6L))
// test empty offset.
- assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null),
latestOffset).isEmpty)
+ assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null),
Some(latestOffset)).isEmpty)
// test valid offsetsBehindLatest
val offset = KafkaSourceOffset(
Map[TopicPartition, Long]((topicPartition1, 1L), (topicPartition2, 2L)))
assert(
- KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), latestOffset)
===
+ KafkaMicroBatchStream.metrics(Optional.ofNullable(offset),
Some(latestOffset)) ===
Map[String, String](
"minOffsetsBehindLatest" -> "2",
"maxOffsetsBehindLatest" -> "4",
"avgOffsetsBehindLatest" -> "3.0").asJava)
// test null latestAvailablePartitionOffsets
- assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset),
null).isEmpty)
+ assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset),
None).isEmpty)
}
}
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala
new file mode 100644
index 000000000000..a359dc355478
--- /dev/null
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.nio.file.Files
+import java.util.Properties
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+import org.apache.kafka.clients.producer.{KafkaProducer, Producer,
ProducerRecord}
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkContext, ThreadAudit}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.execution.streaming.RealTimeTrigger
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{OutputMode, ResultsCollector,
StreamingQuery, StreamRealTimeModeE2ESuiteBase, StreamRealTimeModeSuiteBase}
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class KafkaRealTimeModeE2ESuite extends KafkaSourceTest with
StreamRealTimeModeE2ESuiteBase {
+
+ override protected val defaultTrigger: RealTimeTrigger =
RealTimeTrigger.apply("5 seconds")
+
+ override protected def createSparkSession =
+ new TestSparkSession(
+ new SparkContext(
+ "local[15]",
+ "streaming-key-cuj"
+ )
+ )
+
+ override def beforeEach(): Unit = {
+ super[KafkaSourceTest].beforeEach()
+ super[StreamRealTimeModeE2ESuiteBase].beforeEach()
+ }
+
+ def getKafkaConsumerProperties: Properties = {
+ val props: Properties = new Properties()
+ props.put("bootstrap.servers", testUtils.brokerAddress)
+ props.put("key.serializer",
"org.apache.kafka.common.serialization.StringSerializer")
+ props.put("value.serializer",
"org.apache.kafka.common.serialization.StringSerializer")
+ props.put("compression.type", "snappy")
+
+ props
+ }
+
+ test("Union two kafka streams, for each write to sink") {
+ var q: StreamingQuery = null
+ try {
+ val topic1 = newTopic()
+ val topic2 = newTopic()
+ testUtils.createTopic(topic1, partitions = 2)
+ testUtils.createTopic(topic2, partitions = 2)
+
+ val props: Properties = getKafkaConsumerProperties
+ val producer1: Producer[String, String] = new KafkaProducer[String,
String](props)
+ val producer2: Producer[String, String] = new KafkaProducer[String,
String](props)
+
+ val readStream1 = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic1)
+ .load()
+
+ val readStream2 = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic2)
+ .load()
+
+ val df = readStream1
+ .union(readStream2)
+ .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS
value")
+ .selectExpr("key || ',' || value")
+ .toDF()
+
+ q = runStreamingQuery("union-kafka", df)
+
+ waitForTasksToStart(4)
+
+ val expectedResults = new mutable.ListBuffer[String]()
+ for (batch <- 0 until 3) {
+ (1 to 100).foreach(i => {
+ producer1
+ .send(
+ new ProducerRecord[String, String](
+ topic1,
+ java.lang.Long.toString(i),
+ s"input1-${batch}-${i}"
+ )
+ )
+ .get()
+ producer2
+ .send(
+ new ProducerRecord[String, String](
+ topic2,
+ java.lang.Long.toString(i),
+ s"input2-${batch}-${i}"
+ )
+ )
+ .get()
+ })
+ producer1.flush()
+ producer2.flush()
+
+ expectedResults ++= (1 to 100)
+ .flatMap(v => {
+ Seq(
+ s"${v},input1-${batch}-${v}",
+ s"${v},input2-${batch}-${v}"
+ )
+ })
+ .toList
+
+ eventually(timeout(60.seconds)) {
+ ResultsCollector
+ .get(sinkName)
+ .toArray(new Array[String](ResultsCollector.get(sinkName).size()))
+ .toList
+ .sorted should equal(expectedResults.sorted)
+ }
+ }
+ } finally {
+ if (q != null) {
+ q.stop()
+ }
+ }
+ }
+}
+
+
+/**
+ * Kafka Real-Time integration test suite.
+ * Tests against a distributed Spark cluster with separate executor processes deployed.
+ */
+class KafkaRealTimeIntegrationSuite
+ extends KafkaSourceTest
+ with StreamRealTimeModeSuiteBase
+ with ThreadAudit
+ with BeforeAndAfterEach
+ with Matchers {
+
+ override protected def createSparkSession =
+ new TestSparkSession(
+ new SparkContext(
+ "local-cluster[3, 5, 1024]", // Ensure we have enough for both stages.
+ "microbatch-context",
+ sparkConf
+ .set("spark.sql.testkey", "true")
+ .set("spark.scheduler.mode", "FAIR")
+ .set("spark.executor.extraJavaOptions",
"-Dio.netty.leakDetection.level=paranoid")
+ )
+ )
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ // testing to make sure the cluster is usable
+ testUtils.createTopic("_test")
+ testUtils.sendMessage(new ProducerRecord[String, String]("_test", "", ""))
+ testUtils.deleteTopic("_test")
+ logInfo("Kafka cluster setup complete....")
+
+ eventually(timeout(10.seconds)) {
+ val executors = sparkContext.getExecutorIds()
+ assert(executors.size == 3, s"executors: ${executors}")
+ }
+ }
+
+ test("e2e stateless") {
+ var query: StreamingQuery = null
+ try {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 5)
+
+ val outputTopic = newTopic()
+ testUtils.createTopic(outputTopic, partitions = 5)
+
+ val props: Properties = new Properties()
+
+ props.put("bootstrap.servers", testUtils.brokerAddress)
+ props.put("key.serializer",
"org.apache.kafka.common.serialization.StringSerializer");
+ props.put("value.serializer",
"org.apache.kafka.common.serialization.StringSerializer");
+ props.put("compression.type", "snappy")
+
+ val producer: Producer[String, String] = new KafkaProducer[String,
String](props)
+
+ query = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("startingOffsets", "earliest")
+ .option("subscribe", inputTopic)
+ .option("kafka.fetch.max.wait.ms", 10)
+ .load()
+ .withColumn("value", substring(col("value"), 0, 500 * 1000))
+ .withColumn("value", base64(col("value")))
+ .withColumn(
+ "headers",
+ array(
+ struct(
+ lit("source-timestamp") as "key",
+ unix_millis(col("timestamp")).cast("STRING").cast("BINARY") as
"value"
+ )
+ )
+ )
+ .drop(col("timestamp"))
+ .writeStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("topic", outputTopic)
+ .option("checkpointLocation",
Files.createTempDirectory("some-prefix").toFile.getName)
+ .option("kafka.max.block.ms", "100")
+ .trigger(RealTimeTrigger.apply("5 minutes"))
+ .outputMode(OutputMode.Update())
+ .start()
+
+ waitForTasksToStart(5)
+
+ var expectedResults: ListBuffer[GenericRowWithSchema] = new ListBuffer
+ for (i <- 0 until 3) {
+ (1 to 100).foreach(i => {
+ producer
+ .send(
+ new ProducerRecord[String, String](
+ inputTopic,
+ java.lang.Long.toString(i),
+ s"payload-${i}"
+ )
+ )
+ .get()
+ })
+
+ producer.flush()
+
+ val kafkaSinkData = spark.read
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", outputTopic)
+ .option("includeHeaders", "true")
+ .option("startingOffsets", "earliest")
+ .load()
+ .withColumn("value", unbase64(col("value")).cast("STRING"))
+ .withColumn("headers-map", map_from_entries(col("headers")))
+ .withColumn("source-timestamp",
conv(hex(col("headers-map.source-timestamp")), 16, 10))
+ .withColumn("sink-timestamp", unix_millis(col("timestamp")))
+
+ // Check the answers
+ val newResults = (1 to 100)
+ .map(v => {
+ new GenericRowWithSchema(
+ Array(s"payload-${v}"),
+ schema = new StructType().add(StructField("value", StringType))
+ )
+ })
+ .toList
+
+ expectedResults ++= newResults
+ expectedResults =
+ expectedResults.sorted((x: GenericRowWithSchema, y:
GenericRowWithSchema) => {
+ x.getString(0).compareTo(y.getString(0))
+ })
+
+ eventually(timeout(1.minute)) {
+ checkAnswer(kafkaSinkData.select("value"), expectedResults.toSeq)
+ }
+ }
+ } finally {
+ if (query != null) {
+ query.stop()
+ }
+ }
+ }
+}
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
new file mode 100644
index 000000000000..83aae64d84f7
--- /dev/null
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkIllegalStateException}
+import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
+import org.apache.spark.sql.streaming.OutputMode.Update
+import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.SystemClock
+
+class KafkaRealTimeModeSuite
+ extends KafkaSourceTest
+ with Matchers {
+
+ override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
+
+ override protected def sparkConf: SparkConf = {
+ // Should switch to StreamingShuffleManager when it is ready.
+ super.sparkConf
+ .set("spark.databricks.streaming.realTimeMode.enabled", "true")
+ .set(
+ SQLConf.STATE_STORE_PROVIDER_CLASS,
+ classOf[RocksDBStateStoreProvider].getName)
+ }
+
+ override protected def createSparkSession = new TestSparkSession(
+ new SparkContext(
+ "local[8]", // Ensure enough number of cores to ensure concurrent
schedule of all tasks.
+ "streaming-rtm-context",
+ sparkConf.set("spark.sql.testkey", "true")))
+
+
+ import testImplicits._
+
+ val sleepOneSec = new ExternalAction() {
+ override def runAction(): Unit = {
+ Thread.sleep(1000)
+ }
+ }
+
+ var clock = new GlobalSingletonManualClock()
+
+ private def advanceRealTimeClock(timeMs: Int) = new ExternalAction {
+ override def runAction(): Unit = {
+ clock.advance(timeMs)
+ }
+
+ override def toString(): String = {
+ s"advanceRealTimeClock($timeMs)"
+ }
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(
+ SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION,
+ defaultTrigger.batchDurationMs
+ )
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ GlobalSingletonManualClock.reset()
+ }
+
+ override def afterEach(): Unit = {
+ LowLatencyClock.setClock(new SystemClock)
+ super.afterEach()
+ }
+
+ def waitUntilBatchStartedOrProcessed(q: StreamingQuery, batchId: Long): Unit
= {
+ eventually(timeout(60.seconds)) {
+ val tasksRunning =
+
spark.sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum
+ val lastBatch = {
+ if (q.lastProgress == null) {
+ -1
+ } else {
+ q.lastProgress.batchId
+ }
+ }
+ val batchStarted = tasksRunning >= 1 && lastBatch >= batchId - 1
+ val batchProcessed = lastBatch >= batchId
+ assert(batchStarted || batchProcessed,
+ s"tasksRunning: ${tasksRunning} lastBatch: ${lastBatch}")
+ }
+ }
+
+ // A simple unit test that reads from the Kafka source, does a simple map and writes to
+ // the memory sink.
+ test("simple map") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4),
+ sleepOneSec,
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("7"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("8"), Some(0))
+ testUtils.sendMessages(topic, Array("9"), Some(1))
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ WaitUntilCurrentBatchProcessed)
+ }
+
+ // A simple unit test that reads from the Kafka source, does a simple map and writes to
+ // the memory sink. Make sure there is no data for a whole batch. Also, after restart the
+ // first batch has no data.
+ test("simple map with empty batch") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val reader = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ WaitUntilBatchProcessed(0),
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("1"), Some(0))
+ testUtils.sendMessages(topic, Array("2"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3),
+ WaitUntilCurrentBatchProcessed,
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+ }
+ },
+ WaitUntilCurrentBatchProcessed,
+ CheckAnswerWithTimeout(5000, 2, 3, 4),
+ StopStream,
+ StartStream(),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4"), Some(0))
+ testUtils.sendMessages(topic, Array("5"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6),
+ WaitUntilCurrentBatchProcessed
+ )
+ }
+
+ // A simple unit test that reads from the Kafka source, does a simple map and writes to
+ // the memory sink, verifying that newly added partitions are picked up.
+ test("add partition") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 1)
+
+ testUtils.sendMessages(topic, Array("1", "2", "3"), Some(0))
+
+ val reader = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4),
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.addPartitions(topic, 2)
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.addPartitions(topic, 4)
+ testUtils.sendMessages(topic, Array("7"), Some(2))
+ }
+ },
+ CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("8"), Some(3))
+ testUtils.sendMessages(topic, Array("9"), Some(2))
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ WaitUntilCurrentBatchProcessed
+ )
+ }
+
+ test("Real-Time Mode fetches latestOffset again at end of the batch") {
+ // LowLatencyClock does not affect the wait time of the Kafka iterator, so advancing the
+ // clock does not affect the test finish time. The purpose of using it is to make the query
+ // start time consistent, so the test behaves the same way across runs.
+ LowLatencyClock.setClock(clock)
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 1)
+
+ val reader = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ // extra large values to make sure the fetch does
+ // not return within the batch duration
+ .option("kafka.fetch.max.wait.ms", "20000000")
+ .option("kafka.fetch.min.bytes", "20000000")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(Trigger.RealTime(10000)),
+ advanceRealTimeClock(2000),
+ Execute { q =>
+ waitUntilBatchStartedOrProcessed(q, 0)
+ testUtils.sendMessages(topic, Array("1"), Some(0))
+ },
+ advanceRealTimeClock(8000),
+ WaitUntilBatchProcessed(0),
+ Execute { q =>
+ val expectedMetrics = Map(
+ "minOffsetsBehindLatest" -> "1",
+ "maxOffsetsBehindLatest" -> "1",
+ "avgOffsetsBehindLatest" -> "1.0",
+ "estimatedTotalBytesBehindLatest" -> null
+ )
+ eventually(timeout(60.seconds)) {
+ expectedMetrics.foreach { case (metric, expectedValue) =>
+ assert(q.lastProgress.sources(0).metrics.get(metric) ===
expectedValue)
+ }
+ }
+ },
+ advanceRealTimeClock(2000),
+ Execute { q =>
+ waitUntilBatchStartedOrProcessed(q, 1)
+ testUtils.sendMessages(topic, Array("2", "3"), Some(0))
+ },
+ advanceRealTimeClock(8000),
+ WaitUntilBatchProcessed(1),
+ Execute { q =>
+ val expectedMetrics = Map(
+ "minOffsetsBehindLatest" -> "3",
+ "maxOffsetsBehindLatest" -> "3",
+ "avgOffsetsBehindLatest" -> "3.0",
+ "estimatedTotalBytesBehindLatest" -> null
+ )
+ eventually(timeout(60.seconds)) {
+ expectedMetrics.foreach { case (metric, expectedValue) =>
+ assert(q.lastProgress.sources(0).metrics.get(metric) ===
expectedValue)
+ }
+ }
+ }
+ )
+ }
+
+ // Validate that the query fails when any of these incompatible options is set.
+ Seq(
+ "maxoffsetspertrigger",
+ "minoffsetspertrigger",
+ "minpartitions",
+ "endingtimestamp",
+ "maxtriggerdelay").foreach { opt =>
+ test(s"$opt incompatible") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val reader = spark.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .option(opt, "5")
+ .load()
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ ExpectFailure[UnsupportedOperationException] { (t: Throwable) => {
+ assert(t.getMessage.toLowerCase().contains(opt))
+ }
+ }
+ )
+ }
+ }
+
+ test("union 2 dataframes after projection") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val topic1 = newTopic()
+ testUtils.createTopic(topic1, partitions = 2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+
+ testUtils.sendMessages(topic1, Array("11", "12"), Some(0))
+ testUtils.sendMessages(topic1, Array("13"), Some(1))
+
+ val reader1 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ val reader2 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic1)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ val unionedReader = reader1.union(reader2)
+
+ testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14),
+ sleepOneSec,
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ testUtils.sendMessages(topic1, Array("14", "15"), Some(0))
+ testUtils.sendMessages(topic1, Array("16"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("7"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17,
8),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("8"), Some(0))
+ testUtils.sendMessages(topic, Array("9"), Some(1))
+ testUtils.sendMessages(topic1, Array("19"), Some(1))
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17,
8, 9, 10, 20),
+ WaitUntilCurrentBatchProcessed)
+ }
+
+ test("union 3 dataframes with and without maxPartitions") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val topic1 = newTopic()
+ testUtils.createTopic(topic1, partitions = 2)
+
+ val topic2 = newTopic()
+ testUtils.createTopic(topic2, partitions = 2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+
+ testUtils.sendMessages(topic1, Array("11", "12"), Some(0))
+ testUtils.sendMessages(topic1, Array("13"), Some(1))
+
+ testUtils.sendMessages(topic2, Array("21", "22"), Some(0))
+ testUtils.sendMessages(topic2, Array("23"), Some(1))
+
+ val reader1 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .option("maxPartitions", "1")
+ .load()
+
+ val reader2 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic1)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val reader3 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic2)
+ .option("startingOffsets", "earliest")
+ .option("maxPartitions", "3")
+ .load()
+
+ val unionedReader = reader1.union(reader2).union(reader3)
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(10000, 2, 3, 4, 12, 13, 14, 22, 23, 24),
+ sleepOneSec,
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4"), Some(0))
+ testUtils.sendMessages(topic, Array("5"), Some(1))
+ testUtils.sendMessages(topic1, Array("14", "15"), Some(0))
+ testUtils.sendMessages(topic2, Array("24"), Some(0))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15,
16, 25),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ testUtils.sendMessages(topic2, Array("25"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15,
16, 25, 7, 26),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic1, Array("16"), Some(1))
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15,
16, 25, 7, 26, 17),
+ WaitUntilCurrentBatchProcessed)
+ }
+
+ test("self union workaround") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+
+ val reader1 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+
+ val reader2 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val unionedReader = reader1.union(reader2)
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4, 2, 3, 4),
+ sleepOneSec,
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("7"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("8"), Some(0))
+ testUtils.sendMessages(topic, Array("9"), Some(1))
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8,
9, 10, 9, 10),
+ WaitUntilCurrentBatchProcessed)
+ }
+
+ test("union 2 different sources - Kafka and LowLatencyMemoryStream") {
+ import testImplicits._
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val memoryStreamRead = LowLatencyMemoryStream[String](2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+ memoryStreamRead.addData("11", "12", "13")
+
+ val reader1 = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+
+
+ val reader2 = memoryStreamRead.toDF()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+
+ val unionedReader = reader1.union(reader2)
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14),
+ sleepOneSec,
+ sleepOneSec,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ memoryStreamRead.addData("14")
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15),
+ WaitUntilCurrentBatchProcessed,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("7"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8),
+ WaitUntilCurrentBatchProcessed,
+ StopStream,
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ testUtils.sendMessages(topic, Array("8"), Some(0))
+ testUtils.sendMessages(topic, Array("9"), Some(1))
+ memoryStreamRead.addData("15", "16", "17")
+ }
+ },
+ StartStream(),
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8, 9, 10,
16, 17, 18),
+ WaitUntilCurrentBatchProcessed)
+ }
+
+ test("self union - not allowed") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val unionedReader = reader.union(reader)
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ testStream(unionedReader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ ExpectFailure[SparkIllegalStateException] { ex =>
+ checkErrorMatchPVals(
+ ex.asInstanceOf[SparkIllegalStateException],
+
"STREAMING_REAL_TIME_MODE.IDENTICAL_SOURCES_IN_UNION_NOT_SUPPORTED",
+ parameters =
+ Map("sources" -> "(?s).*")
+ )
+ }
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
index 5bb01bdea26e..9199580f6587 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
@@ -17,11 +17,18 @@
package org.apache.spark.sql.streaming
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+
+import scala.collection.mutable
+
import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.SpanSugar._
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
-import org.apache.spark.sql.execution.streaming.RealTimeTrigger
+import org.apache.spark.sql.execution.streaming.{LowLatencyMemoryStream,
RealTimeTrigger}
+import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
import org.apache.spark.sql.test.TestSparkSession
@@ -43,6 +50,124 @@ trait StreamRealTimeModeSuiteBase extends StreamTest with
Matchers {
"local[10]", // Ensure enough number of cores to ensure concurrent
schedule of all tasks.
"streaming-rtm-context",
sparkConf.set("spark.sql.testkey", "true")))
+
+ /**
+ * Should only be used in Real-Time Mode, where the batch duration is long enough that
+ * `eventually` does not miss the batch due to its long refresh interval.
+ */
+ def waitForTasksToStart(numTasks: Int): Unit = {
+ eventually(timeout(60.seconds)) {
+ val tasksRunning = spark.sparkContext.statusTracker
+ .getExecutorInfos.map(_.numRunningTasks()).sum
+ assert(tasksRunning == numTasks, s"tasksRunning: ${tasksRunning}")
+ }
+ }
+}
+
+/**
+ * Must be a singleton object to ensure serializability when used in a ForeachWriter.
+ * Users must make sure different test suites use different sink names to avoid race
+ * conditions.
+ */
+object ResultsCollector extends ConcurrentHashMap[String,
ConcurrentLinkedQueue[String]] {
+ def reset(): Unit = {
+ clear()
+ }
+}
+
+/**
+ * Base class that contains helper methods to test Real-Time Mode streaming
queries.
+ *
+ * The general procedure to use this suite is as follows:
+ * 1. Call createMemoryStream to create a memory stream with manual clock.
+ * 2. Call runStreamingQuery to start a streaming query with custom logic.
+ * 3. Call processBatches to add data to the memory stream and validate
results.
+ *
+ * It uses foreach to collect results into [[ResultsCollector]]. It also tests whether
+ * results are emitted in real time by having batch durations longer than the waiting time.
+ */
+trait StreamRealTimeModeE2ESuiteBase extends StreamRealTimeModeSuiteBase {
+ import testImplicits._
+
+ override protected val defaultTrigger = RealTimeTrigger.apply("300 seconds")
+
+ protected final def sinkName: String = getClass.getName + "Sink"
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ ResultsCollector.reset()
+ }
+
+ // Create a ForeachWriter that collects results into ResultsCollector.
+ def foreachWriter(sinkName: String): ForeachWriter[String] = new
ForeachWriter[String] {
+ override def open(partitionId: Long, epochId: Long): Boolean = {
+ true
+ }
+
+ override def process(value: String): Unit = {
+ val collector =
+ ResultsCollector.computeIfAbsent(sinkName, (_) => new
ConcurrentLinkedQueue[String]())
+ collector.add(value)
+ }
+
+ override def close(errorOrNull: Throwable): Unit = {}
+ }
+
+ def createMemoryStream(numPartitions: Int = 5)
+ : (LowLatencyMemoryStream[(String, Int)], GlobalSingletonManualClock) = {
+ val clock = new GlobalSingletonManualClock()
+ LowLatencyClock.setClock(clock)
+ val read = LowLatencyMemoryStream[(String, Int)](numPartitions)
+ (read, clock)
+ }
+
+ def runStreamingQuery(queryName: String, df:
org.apache.spark.sql.DataFrame): StreamingQuery = {
+ df.as[String]
+ .writeStream
+ .outputMode(OutputMode.Update())
+ .foreach(foreachWriter(sinkName))
+ .queryName(queryName)
+ .trigger(defaultTrigger)
+ .start()
+ }
+
+ // Add test data to the memory source and validate results
+ def processBatches(
+ query: StreamingQuery,
+ read: LowLatencyMemoryStream[(String, Int)],
+ clock: GlobalSingletonManualClock,
+ numRowsPerBatch: Int,
+ numBatches: Int,
+ expectedResultsGenerator: (String, Int) => Array[String]): Unit = {
+ val expectedResults = mutable.ListBuffer[String]()
+ for (i <- 0 until numBatches) {
+ for (key <- List("a", "b", "c")) {
+ for (j <- 1 to numRowsPerBatch) {
+ val value = i * numRowsPerBatch + j
+ read.addData((key, value))
+ expectedResults ++= expectedResultsGenerator(key, value)
+ }
+ }
+
+ eventually(timeout(60.seconds)) {
+ ResultsCollector
+ .get(sinkName)
+ .toArray(new Array[String](ResultsCollector.get(sinkName).size()))
+ .toList
+ .sorted should equal(expectedResults.sorted)
+ }
+
+ clock.advance(defaultTrigger.batchDurationMs)
+
+ eventually(timeout(60.seconds)) {
+ query
+ .asInstanceOf[StreamingQueryWrapper]
+ .streamingQuery
+ .getLatestExecutionContext()
+ .batchId should be(i + 1)
+ query.lastProgress.sources(0).numInputRows should be(numRowsPerBatch *
3)
+ }
+ }
+ }
}
abstract class StreamRealTimeModeManualClockSuiteBase extends
StreamRealTimeModeSuiteBase {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 3ff8cab64d65..e57c4e1e665c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -284,6 +284,8 @@ trait StreamTest extends QueryTest with SharedSparkSession
with TimeLimits with
case class WaitUntilBatchProcessed(batchId: Long) extends StreamAction with
StreamMustBeRunning
+ case object WaitUntilCurrentBatchProcessed extends StreamAction with
StreamMustBeRunning
+
/**
* Signals that a failure is expected and should not kill the test.
*
@@ -659,6 +661,20 @@ trait StreamTest extends QueryTest with SharedSparkSession
with TimeLimits with
throw currentStream.exception.get
}
+ case WaitUntilCurrentBatchProcessed =>
+ if (currentStream.exception.isDefined) {
+ throw currentStream.exception.get
+ }
+ val currBatch =
currentStream.commitLog.getLatestBatchId().getOrElse(-1L)
+ eventually("Current batch never finishes") {
+ assert(currentStream.commitLog.getLatestBatchId() != None
+ && currentStream.commitLog.getLatestBatchId().get > currBatch)
+
+ // See WaitUntilBatchProcessed for an explanation of why we wait
for the progress
+ val latestProgressBatchId =
+
currentStream.recentProgress.lastOption.map(_.batchId).getOrElse(-1L)
+ assert(latestProgressBatchId >= currBatch)
+ }
case StopStream =>
verify(currentStream != null, "can not stop a stream that is not
running")
try failAfter(streamingTimeout) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]