This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e93c5fbe81d2 [SPARK-49301][SS] Chunk arrow data passed to Python worker
e93c5fbe81d2 is described below

commit e93c5fbe81d21f8bf2ce52867013d06a63c7956e
Author: bogao007 <bo....@databricks.com>
AuthorDate: Thu Aug 22 15:12:23 2024 +0900

    [SPARK-49301][SS] Chunk arrow data passed to Python worker
    
    ### What changes were proposed in this pull request?
    
    - Add chunking logic to split the Arrow data into multiple batches when the number of input rows exceeds `ARROW_EXECUTION_MAX_RECORDS_PER_BATCH`, and pass the batches to the Python worker to ensure scalability (see the sketch below).
    - Added a new class `BaseStreamingArrowWriter` and let `ApplyInPandasWithStateWriter` reuse some of its methods.
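    
    For illustration, a minimal, self-contained sketch of the chunking flow. The names `RowSink` and `ChunkingWriter` are hypothetical stand-ins for the Arrow writer plumbing; the actual implementation is `BaseStreamingArrowWriter` in the diff below.
    
    ```scala
    // Hypothetical stand-in for the ArrowWriter/ArrowStreamWriter pair.
    trait RowSink {
      def write(row: String): Unit // stand-in for ArrowWriter.write(InternalRow)
      def flushBatch(): Unit       // stand-in for finish() + writeBatch() + reset()
    }
    
    // Simplified version of the chunking logic: never let a single Arrow batch
    // grow beyond maxRecordsPerBatch rows.
    class ChunkingWriter(sink: RowSink, maxRecordsPerBatch: Int) {
      private var numRowsInBatch = 0
    
      def writeRow(row: String): Unit = {
        // If the current batch is already full, flush it before writing the next row.
        if (numRowsInBatch >= maxRecordsPerBatch) {
          finalizeCurrentBatch()
        }
        sink.write(row)
        numRowsInBatch += 1
      }
    
      def finalizeCurrentBatch(): Unit = {
        sink.flushBatch()
        numRowsInBatch = 0
      }
    }
    ```
    
    In the committed code, `finalizeCurrentArrowBatch()` plays the role of `finalizeCurrentBatch()` above and is also called once after the input iterator is drained in `TransformWithStateInPandasPythonRunner`.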
    
    ### Why are the changes needed?
    
    If the input contains a very large number of records, the system should still be able to handle it at scale. Without this change, all rows would be sent to the Python worker in a single batch, which may result in an OOM error.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47804 from bogao007/chunk-arrow-data.
    
    Authored-by: bogao007 <bo....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../python/ApplyInPandasWithStateWriter.scala      | 36 +++------
 .../python/BaseStreamingArrowWriter.scala          | 85 ++++++++++++++++++++++
 .../TransformWithStateInPandasPythonRunner.scala   | 21 +++---
 .../python/BaseStreamingArrowWriterSuite.scala     | 67 +++++++++++++++++
 4 files changed, 171 insertions(+), 38 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
index 6c9c7e1179b6..db49be7fd99f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -50,7 +50,8 @@ import org.apache.spark.unsafe.types.UTF8String
 class ApplyInPandasWithStateWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
-    arrowMaxRecordsPerBatch: Int) {
+    arrowMaxRecordsPerBatch: Int)
+  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
 
   import ApplyInPandasWithStateWriter._
 
@@ -72,9 +73,9 @@ class ApplyInPandasWithStateWriter(
   // there are at least one data for a grouping key (we ensure this for the case of handling timed
   // out state as well) whereas there is only one state for a grouping key, we have to fill up the
   // empty rows in state side to ensure both have the same number of rows.
-  private val arrowWriterForData = createArrowWriter(
+  override protected val arrowWriterForData: ArrowWriter = createArrowWriter(
     root.getFieldVectors.asScala.toSeq.dropRight(1))
-  private val arrowWriterForState = createArrowWriter(
+  private val arrowWriterForState: ArrowWriter = createArrowWriter(
     root.getFieldVectors.asScala.toSeq.takeRight(1))
 
   // - Bin-packing
@@ -117,12 +118,10 @@ class ApplyInPandasWithStateWriter(
   private var currentGroupState: GroupStateImpl[Row] = _
 
   // variables for tracking the status of current batch
-  private var totalNumRowsForBatch = 0
   private var totalNumStatesForBatch = 0
 
   // variables for tracking the status of current chunk
   private var startOffsetForCurrentChunk = 0
-  private var numRowsForCurrentChunk = 0
 
 
   /**
@@ -136,26 +135,6 @@ class ApplyInPandasWithStateWriter(
     currentGroupState = groupState
   }
 
-  /**
-   * Indicates writer to write a row in the current group.
-   *
-   * @param dataRow The row to write in the current group.
-   */
-  def writeRow(dataRow: InternalRow): Unit = {
-    // If it exceeds the condition of batch (number of records) and there is more data for the
-    // same group, finalize and construct a new batch.
-
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
-      finalizeCurrentChunk(isLastChunkForGroup = false)
-      finalizeCurrentArrowBatch()
-    }
-
-    arrowWriterForData.write(dataRow)
-
-    numRowsForCurrentChunk += 1
-    totalNumRowsForBatch += 1
-  }
-
   /**
    * Indicates writer that current group has finalized and there will be no further row bound to
    * the current group.
@@ -209,7 +188,7 @@ class ApplyInPandasWithStateWriter(
     new GenericInternalRow(Array[Any](stateUnderlyingRow))
   }
 
-  private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+  override protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
     val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
       startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup)
     arrowWriterForState.write(stateInfoRow)
@@ -221,7 +200,10 @@ class ApplyInPandasWithStateWriter(
     numRowsForCurrentChunk = 0
   }
 
-  private def finalizeCurrentArrowBatch(): Unit = {
+  /**
+   * Finalizes the current batch and writes it to the Arrow stream.
+   */
+  override def finalizeCurrentArrowBatch(): Unit = {
     val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
     (0 until remainingEmptyStateRows).foreach { _ =>
       arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala
new file mode 100644
index 000000000000..303389cee096
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+
+/**
+ * Base class to handle writing data to Arrow stream to Python workers. When the rows
+ * for a group exceed the maximum number of records per batch, we chunk the data into multiple
+ * batches.
+ */
+class BaseStreamingArrowWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int,
+    arrowWriterForTest: ArrowWriter = null) {
+  protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) {
+    ArrowWriter.create(root)
+  } else {
+    arrowWriterForTest
+  }
+
+  // variables for tracking the status of current batch
+  protected var totalNumRowsForBatch = 0
+
+  // variables for tracking the status of current chunk
+  protected var numRowsForCurrentChunk = 0
+
+  /**
+   * Indicates writer to write a row for current batch.
+   *
+   * @param dataRow The row to write for current batch.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If it exceeds the condition of batch (number of records) and there is more data for the
+    // same group, finalize and construct a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentChunk(isLastChunkForGroup = false)
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+
+    numRowsForCurrentChunk += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Finalizes the current batch and writes it to the Arrow stream.
+   */
+  def finalizeCurrentArrowBatch(): Unit = {
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForData.reset()
+    totalNumRowsForBatch = 0
+  }
+
+  /**
+   * Finalizes the current chunk. We only reset the number of rows for the current chunk here since
+   * not all the writers need this step.
+   */
+  protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+    numRowsForCurrentChunk = 0
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index d549ddba7c8c..7d0c177d1df8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -29,8 +29,6 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow
-import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType}
 import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
@@ -66,8 +64,6 @@ class TransformWithStateInPandasPythonRunner(
   override protected val workerConf: Map[String, String] = initialWorkerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
 
-  private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
-
   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
   // constructor.
   override protected lazy val schema: StructType = _schema
@@ -132,24 +128,27 @@ class TransformWithStateInPandasPythonRunner(
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
   }
 
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
   override protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
       inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
 
     if (inputIterator.hasNext) {
       val startData = dataOut.size()
       val next = inputIterator.next()
-      val nextBatch = next._2
+      val dataIter = next._2
 
-      while (nextBatch.hasNext) {
-        arrowWriter.write(nextBatch.next())
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
       }
-
-      arrowWriter.finish()
-      writer.writeBatch()
-      arrowWriter.reset()
+      pandasWriter.finalizeCurrentArrowBatch()
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
       true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala
new file mode 100644
index 000000000000..0417a839dc6b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.mockito.Mockito.{mock, never, times, verify}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+
+class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+  // Setting the maximum number of records per batch to 2 to make test easier.
+  val arrowMaxRecordsPerBatch = 2
+  var transformWithStateInPandasWriter: BaseStreamingArrowWriter = _
+  var arrowWriter: ArrowWriter = _
+  var writer: ArrowStreamWriter = _
+
+  override def beforeEach(): Unit = {
+    val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+    writer = mock(classOf[ArrowStreamWriter])
+    arrowWriter = mock(classOf[ArrowWriter])
+    transformWithStateInPandasWriter = new BaseStreamingArrowWriter(
+      root, writer, arrowMaxRecordsPerBatch, arrowWriter)
+  }
+
+  test("test writeRow") {
+    val dataRow = mock(classOf[InternalRow])
+    // Write 2 rows first, batch is not finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(2)).write(dataRow)
+    verify(writer, never()).writeBatch()
+    // Write a 3rd row, batch is finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(3)).write(dataRow)
+    verify(writer).writeBatch()
+    // Write 2 more rows, a new batch is finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(5)).write(dataRow)
+    verify(writer, times(2)).writeBatch()
+  }
+
+  test("test finalizeCurrentArrowBatch") {
+    transformWithStateInPandasWriter.finalizeCurrentArrowBatch()
+    verify(arrowWriter).finish()
+    verify(writer).writeBatch()
+    verify(arrowWriter).reset()
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
