This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 81d9d1f82850 [SPARK-53847] Add ContinuousMemorySink for Real-time Mode testing
81d9d1f82850 is described below

commit 81d9d1f82850bd7a37810331487bfdb77ca67ea9
Author: Jerry Peng <[email protected]>
AuthorDate: Sun Oct 12 13:59:44 2025 -0700

    [SPARK-53847] Add ContinuousMemorySink for Real-time Mode testing
    
    ### What changes were proposed in this pull request?
    
    Add a new in-memory sink called "ContinuousMemorySink" to facilitate real-time mode (RTM)
    testing. This sink differs from the existing MemorySink in that it sends output back to the
    driver as soon as the output is generated, rather than only at the end of the batch as the
    current MemorySink does.
    
    ### Why are the changes needed?
    
    To facilitate real-time mode (RTM) testing.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a simple test. Many more RTM-related tests will be added in future PRs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #52550 from jerrypeng/SPARK-53847.
    
    Authored-by: Jerry Peng <[email protected]>
    Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
 .../scala/org/apache/spark/util/RpcUtils.scala     |   9 ++
 .../streaming/sources/ContinuousMemory.scala       | 176 +++++++++++++++++++++
 .../sources/RealTimeRowWriterFactory.scala         |  75 +++++++++
 3 files changed, 260 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index 30f5fced5a8b..58fa4df3b72d 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -36,6 +36,15 @@ private[spark] object RpcUtils {
     rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name)
   }
 
+  /**
+   * Retrieve an `RpcEndpointRef` for the endpoint registered under `name` on the driver that
+   * listens at `driverHost`:`driverPort`.
+   *
+   * @param name the name the target endpoint was registered with
+   * @param driverHost the driver's host; validated with `Utils.checkHost` before use
+   * @param driverPort the port of the driver's RPC environment
+   * @param rpcEnv the local RPC environment used to set up the reference
+   * @return a reference to the named endpoint at the given address
+   */
+  def makeDriverRef(
+      name: String,
+      driverHost: String,
+      driverPort: Int,
+      rpcEnv: RpcEnv): RpcEndpointRef = {
+    Utils.checkHost(driverHost)
+    val driverAddress = RpcAddress(driverHost, driverPort)
+    rpcEnv.setupEndpointRef(driverAddress, name)
+  }
+
   /** Returns the default Spark timeout to use for RPC ask operations. */
   def askRpcTimeout(conf: SparkConf): RpcTimeout = {
     RpcTimeout(conf, Seq(RPC_ASK_TIMEOUT.key, NETWORK_TIMEOUT.key), "120s")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemory.scala
new file mode 100644
index 000000000000..fd9d7bb654bd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemory.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, SparkUnsupportedOperationException}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableCapability}
+import org.apache.spark.sql.connector.write.{
+  LogicalWriteInfo,
+  PhysicalWriteInfo,
+  Write,
+  WriteBuilder,
+  WriterCommitMessage
+}
+import 
org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, 
StreamingWrite}
+import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A sink that stores the results in memory. It is primarily intended for use in unit tests and
+ * does not provide durability.
+ *
+ * This is mostly copied from `MemorySink`, except that rows must become visible after each
+ * write rather than in `commit()`: rows are appended to the buffer by a
+ * `MemoryRealTimeRpcEndpoint` as writers send them to the driver. Consequently there is no
+ * notion of batches here: `latestBatchId` is always `None`, and `latestBatchData`,
+ * `dataSinceBatch` and `write` throw `SparkUnsupportedOperationException`.
+ *
+ * Every access to the row buffer synchronizes on the buffer itself, which is the same lock the
+ * RPC endpoint holds while appending.
+ */
+class ContinuousMemorySink
+    extends MemorySink
+    with SupportsWrite {
+
+  // Rows received so far. Shared with the writes built below; always guard with
+  // `batches.synchronized`.
+  private val batches = new ArrayBuffer[Row]()
+
+  override def name(): String = "ContinuousMemorySink"
+
+  override def schema(): StructType = StructType(Nil)
+
+  override def capabilities(): util.Set[TableCapability] = {
+    util.EnumSet.of(TableCapability.STREAMING_WRITE)
+  }
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    new WriteBuilder with SupportsStreamingUpdateAsAppend {
+      private val inputSchema: StructType = info.schema()
+
+      override def build(): Write = {
+        new ContinuousMemoryWrite(batches, inputSchema)
+      }
+    }
+  }
+
+  /** Returns a copy of all rows that have been received by this sink so far. */
+  override def allData: Seq[Row] = batches.synchronized {
+    // Copy while holding the lock so callers never observe concurrent appends.
+    batches.toList
+  }
+
+  /** Always `None`: this sink does not track batch ids. */
+  override def latestBatchId: Option[Long] = None
+
+  override def latestBatchData: Seq[Row] = {
+    throw new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
+      messageParameters = Map("operation" -> "latestBatchData")
+    )
+  }
+
+  override def dataSinceBatch(sinceBatchId: Long): Seq[Row] = {
+    throw new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
+      messageParameters = Map("operation" -> "dataSinceBatch")
+    )
+  }
+
+  override def toDebugString: String = s"${allData}"
+
+  override def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = {
+    throw new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_OPERATION_FOR_CONTINUOUS_MEMORY_SINK",
+      messageParameters = Map("operation" -> "write")
+    )
+  }
+
+  /**
+   * Removes all received rows. Synchronizes on the buffer (not on the sink) so that it is
+   * mutually exclusive with the RPC endpoint's appends and with `allData`.
+   */
+  override def clear(): Unit = batches.synchronized {
+    batches.clear()
+  }
+
+  override def toString(): String = "ContinuousMemorySink"
+}
+
+/**
+ * A `Write` whose streaming form is a `ContinuousMemoryStreamingWrite` over the shared row
+ * buffer `batches`, decoding incoming rows with `schema`.
+ */
+class ContinuousMemoryWrite(batches: ArrayBuffer[Row], schema: StructType) extends Write {
+  override def toStreaming: StreamingWrite =
+    new ContinuousMemoryStreamingWrite(batches, schema)
+}
+
+/**
+ * An RPC endpoint that receives arrays of `InternalRow`s sent by `RealTimeRowWriter`s and
+ * appends them, converted to external `Row`s, to `batches` as soon as they arrive.
+ *
+ * @param rpcEnv the RPC environment the endpoint is registered with
+ * @param schema schema of the incoming rows, used to build the row deserializer
+ * @param batches the shared buffer that received rows are appended to
+ */
+class MemoryRealTimeRpcEndpoint(
+    override val rpcEnv: RpcEnv,
+    schema: StructType,
+    batches: ArrayBuffer[Row]
+) extends ThreadSafeRpcEndpoint {
+
+  // Turns an InternalRow of `schema` back into an external Row.
+  private val toExternalRow =
+    ExpressionEncoder(schema).resolveAndBind().createDeserializer()
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case rows: Array[InternalRow] =>
+      // ThreadSafeRpcEndpoint already handles one message at a time; the lock is still taken
+      // because readers of `batches` synchronize on it as well.
+      batches.synchronized {
+        batches ++= rows.iterator.map(row => toExternalRow(row))
+      }
+  }
+}
+
+/**
+ * A `StreamingWrite` that registers a `MemoryRealTimeRpcEndpoint` with the driver's RPC
+ * environment and hands writers a `RealTimeRowWriterFactory` pointing at it, so rows reach
+ * `batches` as they are written rather than at commit time.
+ *
+ * The endpoint is registered in `createStreamingWriterFactory` and unregistered in `commit` or
+ * `abort`.
+ *
+ * @param batches the shared buffer the endpoint appends received rows to
+ * @param schema schema of the rows produced by the query
+ */
+class ContinuousMemoryStreamingWrite(val batches: ArrayBuffer[Row], schema: StructType)
+    extends StreamingWrite {
+
+  private val memoryEndpoint =
+    new MemoryRealTimeRpcEndpoint(
+      SparkEnv.get.rpcEnv,
+      schema,
+      batches
+    )
+
+  // Reference to the currently registered endpoint, or null when none is registered.
+  @volatile private var endpointRef: RpcEndpointRef = _
+
+  override def createStreamingWriterFactory(
+      info: PhysicalWriteInfo): StreamingDataWriterFactory = {
+    // A random per-call suffix keeps endpoint names unique across registrations.
+    val endpointName = s"MemoryRealTimeRpcEndpoint-${java.util.UUID.randomUUID()}"
+    endpointRef = memoryEndpoint.rpcEnv.setupEndpoint(endpointName, memoryEndpoint)
+    RealTimeRowWriterFactory(endpointName, endpointRef.address)
+  }
+
+  override def useCommitCoordinator(): Boolean = false
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    // Nothing to commit: rows have already been sent to the driver and appended to `batches`
+    // by the RPC endpoint as they were written.
+    // NOTE(review): writers use one-way `send`; confirm every row has been processed by the
+    // endpoint before it is stopped here, otherwise late-arriving rows may be dropped.
+    stopEndpoint()
+  }
+
+  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    stopEndpoint()
+  }
+
+  /** Unregisters the current endpoint, if any, and clears the reference so it is stopped once. */
+  private def stopEndpoint(): Unit = {
+    val ref = endpointRef
+    if (ref != null) {
+      endpointRef = null
+      memoryEndpoint.rpcEnv.stop(ref)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RealTimeRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RealTimeRowWriterFactory.scala
new file mode 100644
index 000000000000..dec1f47847b8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RealTimeRowWriterFactory.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.RpcAddress
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import 
org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
+import org.apache.spark.util.RpcUtils
+
+/**
+ * A [[StreamingDataWriterFactory]] that creates [[RealTimeRowWriter]]s, which send each row to
+ * the driver through RPC as soon as it is written.
+ *
+ * Because all rows are sent to the driver, this factory is generally unsuitable for
+ * production-quality sinks; it is intended for use in tests.
+ *
+ * @param driverEndpointName name under which the receiving endpoint is registered
+ * @param driverEndpointAddr RPC address of the driver hosting that endpoint
+ */
+case class RealTimeRowWriterFactory(
+    driverEndpointName: String,
+    driverEndpointAddr: RpcAddress
+) extends StreamingDataWriterFactory {
+
+  // Every writer targets the same endpoint; the partition, task and epoch ids are not used.
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] =
+    new RealTimeRowWriter(driverEndpointName, driverEndpointAddr)
+}
+
+/**
+ * A [[DataWriter]] that forwards every row to a driver-side RPC endpoint as soon as it is
+ * written, wrapping each row in a single-element array.
+ *
+ * @param driverEndpointName name of the endpoint that receives the rows
+ * @param driverEndpointAddr RPC address of the driver hosting that endpoint
+ */
+class RealTimeRowWriter(
+    driverEndpointName: String,
+    driverEndpointAddr: RpcAddress
+) extends DataWriter[InternalRow] {
+
+  // Resolved once, when the writer is constructed.
+  private val endpointRef = RpcUtils.makeDriverRef(
+    driverEndpointName,
+    driverEndpointAddr.host,
+    driverEndpointAddr.port,
+    SparkEnv.get.rpcEnv
+  )
+
+  // Spark may reuse the same `InternalRow` instance across calls, so send a copy of the row
+  // instead of the caller's instance.
+  override def write(row: InternalRow): Unit = {
+    val snapshot = row.copy()
+    endpointRef.send(Array(snapshot))
+  }
+
+  // Rows were already delivered in `write`, so there is nothing to report on commit.
+  override def commit(): WriterCommitMessage = null
+
+  override def abort(): Unit = ()
+
+  override def close(): Unit = ()
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to