This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 345e7da1cb0b [SPARK-53530][SS] Clean up the useless code related to `TransformWithStateInPySparkStateServer`
345e7da1cb0b is described below
commit 345e7da1cb0b764d2eba9ae0620714a103670696
Author: yangjie01 <[email protected]>
AuthorDate: Wed Sep 10 12:07:33 2025 +0800
[SPARK-53530][SS] Clean up the useless code related to `TransformWithStateInPySparkStateServer`
### What changes were proposed in this pull request?
This PR performs the following cleanup on the code related to `TransformWithStateInPySparkStateServer` (a brief before/after sketch of the server construction follows the list):
- Removed the `private` function `sendIteratorForListState` from
`TransformWithStateInPySparkStateServer`, as it is no longer used after
SPARK-51891.
- Removed the function `sendIteratorAsArrowBatches` from
`TransformWithStateInPySparkStateServer`, as it is no longer used after
SPARK-52333.
- Removed the input parameters `timeZoneId`, `errorOnDuplicatedFieldNames`,
`largeVarTypes`, and `arrowStreamWriterForTest` from the constructor of
`TransformWithStateInPySparkStateServer`, as they are no longer used after the
cleanup of `sendIteratorAsArrowBatches`.
- Removed the input parameter `timeZoneId` from the constructor of
`TransformWithStateInPySparkPythonPreInitRunner`, as it was only used for
constructing `TransformWithStateInPySparkStateServer`.
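For reference, a minimal before/after sketch of how the state server is constructed once these parameters are dropped. The argument names are taken verbatim from the `TransformWithStateInPySparkPythonRunner.scala` hunk below; this is illustrative only, not a standalone compilable snippet:
```scala
// Before: timeZoneId, errorOnDuplicatedFieldNames and largeVarTypes were threaded through
// solely for the Arrow batch path that has since been removed.
new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandle,
  groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
  sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch,
  batchTimestampMs, eventTimeWatermarkForEviction)

// After: only the socket, handle, grouping key schema, records-per-batch limit
// and the optional timestamps remain.
new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandle,
  groupingKeySchema,
  sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch,
  batchTimestampMs, eventTimeWatermarkForEviction)
```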
### Why are the changes needed?
Code cleanup.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Pass GitHub Actions
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52279 from LuciferYang/TransformWithStateInPySparkStateServer.
Lead-authored-by: yangjie01 <[email protected]>
Co-authored-by: YangJie <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../TransformWithStateInPySparkExec.scala | 1 -
.../TransformWithStateInPySparkPythonRunner.scala | 6 +-
.../TransformWithStateInPySparkStateServer.scala | 70 ----------------------
...arkTransformWithStateInPySparkStateServer.scala | 3 -
...ansformWithStateInPySparkStateServerSuite.scala | 33 +++++-----
5 files changed, 18 insertions(+), 95 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index f8390b7d878f..c10d21933c2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -154,7 +154,6 @@ case class TransformWithStateInPySparkExec(
val runner = new TransformWithStateInPySparkPythonPreInitRunner(
pythonFunction,
"pyspark.sql.streaming.transform_with_state_driver_worker",
- sessionLocalTimeZone,
groupingKeySchema,
driverProcessorHandle
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 51dc179c901a..329bd4335265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -220,7 +220,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
executionContext.execute(
new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandle,
- groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
+ groupingKeySchema,
sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch,
batchTimestampMs, eventTimeWatermarkForEviction))
@@ -245,7 +245,6 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
class TransformWithStateInPySparkPythonPreInitRunner(
func: PythonFunction,
workerModule: String,
- timeZoneId: String,
groupingKeySchema: StructType,
processorHandleImpl: DriverStatefulProcessorHandleImpl)
extends StreamingPythonRunner(func, "", "", workerModule)
@@ -299,8 +298,7 @@ class TransformWithStateInPySparkPythonPreInitRunner(
override def run(): Unit = {
try {
new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandleImpl,
- groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames = true,
- largeVarTypes = sqlConf.arrowUseLargeVarTypes,
+ groupingKeySchema,
sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch).run()
} catch {
case e: Exception =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index 4edeae132b47..59acf434035e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -25,15 +25,12 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters._
import com.google.protobuf.ByteString
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkEnv
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.api.python.PythonSQLUtils
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType
@@ -43,8 +40,6 @@ import org.apache.spark.sql.execution.streaming.state.StateMessage.KeyAndValuePa
import org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponseWithListGet
import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.util.Utils
/**
* This class is used to handle the state requests from the Python side. It runs on a separate
@@ -60,16 +55,12 @@ class TransformWithStateInPySparkStateServer(
stateServerSocket: ServerSocketChannel,
statefulProcessorHandle: StatefulProcessorHandleImplBase,
groupingKeySchema: StructType,
- timeZoneId: String,
- errorOnDuplicatedFieldNames: Boolean,
- largeVarTypes: Boolean,
arrowTransformWithStateInPySparkMaxRecordsPerBatch: Int,
batchTimestampMs: Option[Long] = None,
eventTimeWatermarkForEviction: Option[Long] = None,
outputStreamForTest: DataOutputStream = null,
valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null,
deserializerForTest: TransformWithStateInPySparkDeserializer = null,
- arrowStreamWriterForTest: BaseStreamingArrowWriter = null,
listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
mapStatesMapForTest : mutable.HashMap[String, MapStateInfo] = null,
@@ -533,28 +524,6 @@ class TransformWithStateInPySparkStateServer(
}
}
- private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
- // Only write a single batch in each GET request. Stops writing row if rowCount reaches
- // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
- // when there are multiple state variables, user tries to access a different state variable
- // while the current state variable is not exhausted yet.
- var rowCount = 0
- while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
- val data = iter.next()
-
- // Serialize the value row as a byte array
- val valueBytes = PythonSQLUtils.toPyRow(data)
- val lenBytes = valueBytes.length
-
- outputStream.writeInt(lenBytes)
- outputStream.write(valueBytes)
-
- rowCount += 1
- }
- outputStream.writeInt(-1)
- outputStream.flush()
- }
-
private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
val stateName = message.getStateName
if (!mapStates.contains(stateName)) {
@@ -939,45 +908,6 @@ class TransformWithStateInPySparkStateServer(
outputStream.write(responseMessageBytes)
}
- def sendIteratorAsArrowBatches[T](
- iter: Iterator[T],
- outputSchema: StructType,
- arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T => InternalRow): Unit = {
- outputStream.flush()
- val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
- errorOnDuplicatedFieldNames, largeVarTypes)
- val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- s"stdout writer for transformWithStateInPySpark state socket", 0, Long.MaxValue)
- val root = VectorSchemaRoot.create(arrowSchema, allocator)
- val writer = new ArrowStreamWriter(root, null, outputStream)
- val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
- arrowStreamWriterForTest
- } else {
- new BaseStreamingArrowWriter(root, writer,
- arrowTransformWithStateInPySparkMaxRecordsPerBatch)
- }
- // Only write a single batch in each GET request. Stops writing row if rowCount reaches
- // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This is to handle a case
- // when there are multiple state variables, user tries to access a different state variable
- // while the current state variable is not exhausted yet.
- var rowCount = 0
- while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
- val data = iter.next()
- val internalRow = func(data)
- arrowStreamWriter.writeRow(internalRow)
- rowCount += 1
- }
- arrowStreamWriter.finalizeCurrentArrowBatch()
- Utils.tryWithSafeFinally {
- // end writes footer to the output stream and doesn't clean any resources.
- // It could throw exception if the output stream is closed, so it should be
- // in the try block.
- writer.end()
- } {
- root.close()
- allocator.close()
- }
- }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
index 5dc7d9733dcd..91162c7b02f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
@@ -351,9 +351,6 @@ object BenchmarkTransformWithStateInPySparkStateServer extends App {
serverSocketChannel,
stateHandleImpl,
groupingKeySchema,
- timeZoneId,
- errorOnDuplicatedFieldNames,
- largeVarTypes,
arrowTransformWithStateInPySparkMaxRecordsPerBatch
)
// scalastyle:off println
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
index ff99b4ee280d..013aa375c308 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
@@ -100,9 +100,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
batchTimestampMs = mock(classOf[Option[Long]])
eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
+ statefulProcessorHandle, groupingKeySchema, 2,
batchTimestampMs, eventTimeWatermarkForEviction,
- outputStream, valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+ outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, expiryTimerIter, listTimerMap)
when(transformWithStateInPySparkDeserializer.readArrowBatches(any))
.thenReturn(Seq(getIntegerRow(1)))
@@ -278,9 +278,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4)))
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
- valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+ valueStateMap, transformWithStateInPySparkDeserializer,
listStateMap, iteratorMap)
// First call should send 2 records.
stateServer.handleListStateRequest(message)
@@ -307,9 +307,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
- valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+ valueStateMap, transformWithStateInPySparkDeserializer,
listStateMap, iteratorMap)
when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
stateServer.handleListStateRequest(message)
@@ -419,9 +419,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)),
(getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4))))
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
- valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+ valueStateMap, transformWithStateInPySparkDeserializer,
listStateMap, null, mapStateMap, keyValueIteratorMap)
// First call should send 2 records.
stateServer.handleMapStateRequest(message)
@@ -448,10 +448,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
.setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build()
val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap()
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
- arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
+ listStateMap, null, mapStateMap, keyValueIteratorMap)
when(mapState.iterator()).thenReturn(Iterator((getIntegerRow(1), getIntegerRow(1)),
(getIntegerRow(2), getIntegerRow(2)), (getIntegerRow(3), getIntegerRow(3))))
stateServer.handleMapStateRequest(message)
@@ -481,10 +481,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
.setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
- arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
+ listStateMap, iteratorMap, mapStateMap)
when(mapState.keys()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
stateServer.handleMapStateRequest(message)
verify(mapState).keys()
@@ -513,10 +513,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
.setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPySparkDeserializer,
- arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
+ listStateMap, iteratorMap, mapStateMap)
when(mapState.values()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2),
getIntegerRow(3)))
stateServer.handleMapStateRequest(message)
@@ -611,10 +611,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
.build()
).build()
stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
- statefulProcessorHandle, groupingKeySchema, "", false, false,
+ statefulProcessorHandle, groupingKeySchema,
2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
- valueStateMap, transformWithStateInPySparkDeserializer,
- arrowStreamWriter, listStateMap, null, mapStateMap, null,
+ valueStateMap, transformWithStateInPySparkDeserializer, listStateMap, null, mapStateMap, null,
null, listTimerMap)
when(statefulProcessorHandle.listTimers()).thenReturn(Iterator(1))
stateServer.handleStatefulProcessorCall(message)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]