This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e93c5fbe81d2 [SPARK-49301][SS] Chunk arrow data passed to Python worker
e93c5fbe81d2 is described below

commit e93c5fbe81d21f8bf2ce52867013d06a63c7956e
Author: bogao007 <bo....@databricks.com>
AuthorDate: Thu Aug 22 15:12:23 2024 +0900

    [SPARK-49301][SS] Chunk arrow data passed to Python worker

    ### What changes were proposed in this pull request?
    - Add chunking logic that splits Arrow data into multiple batches whenever the number of input rows exceeds `ARROW_EXECUTION_MAX_RECORDS_PER_BATCH`, and pass the batches to the Python worker to ensure scalability.
    - Add a new class `BaseStreamingArrowWriter` and let `ApplyInPandasWithStateWriter` reuse some of its methods.

    ### Why are the changes needed?
    If the input contains a very large number of records, the system should still be able to handle it at scale. Without this change, we try to send all the rows to the Python worker in a single batch, which may result in an OOM error.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Added a unit test.

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #47804 from bogao007/chunk-arrow-data.

    Authored-by: bogao007 <bo....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../python/ApplyInPandasWithStateWriter.scala      | 36 +++------
 .../python/BaseStreamingArrowWriter.scala          | 85 ++++++++++++++++++++++
 .../TransformWithStateInPandasPythonRunner.scala   | 21 +++---
 .../python/BaseStreamingArrowWriterSuite.scala     | 67 +++++++++++++++++
 4 files changed, 171 insertions(+), 38 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
index 6c9c7e1179b6..db49be7fd99f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -50,7 +50,8 @@ import org.apache.spark.unsafe.types.UTF8String
 class ApplyInPandasWithStateWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
-    arrowMaxRecordsPerBatch: Int) {
+    arrowMaxRecordsPerBatch: Int)
+  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
 
   import ApplyInPandasWithStateWriter._
 
@@ -72,9 +73,9 @@ class ApplyInPandasWithStateWriter(
   // there are at least one data for a grouping key (we ensure this for the case of handling timed
   // out state as well) whereas there is only one state for a grouping key, we have to fill up the
   // empty rows in state side to ensure both have the same number of rows.
-  private val arrowWriterForData = createArrowWriter(
+  override protected val arrowWriterForData: ArrowWriter = createArrowWriter(
     root.getFieldVectors.asScala.toSeq.dropRight(1))
-  private val arrowWriterForState = createArrowWriter(
+  private val arrowWriterForState: ArrowWriter = createArrowWriter(
     root.getFieldVectors.asScala.toSeq.takeRight(1))
 
   // - Bin-packing
@@ -117,12 +118,10 @@ class ApplyInPandasWithStateWriter(
   private var currentGroupState: GroupStateImpl[Row] = _
 
   // variables for tracking the status of current batch
-  private var totalNumRowsForBatch = 0
   private var totalNumStatesForBatch = 0
 
   // variables for tracking the status of current chunk
   private var startOffsetForCurrentChunk = 0
-  private var numRowsForCurrentChunk = 0
 
   /**
@@ -136,26 +135,6 @@ class ApplyInPandasWithStateWriter(
     currentGroupState = groupState
   }
 
-  /**
-   * Indicates writer to write a row in the current group.
-   *
-   * @param dataRow The row to write in the current group.
-   */
-  def writeRow(dataRow: InternalRow): Unit = {
-    // If it exceeds the condition of batch (number of records) and there is more data for the
-    // same group, finalize and construct a new batch.
-
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
-      finalizeCurrentChunk(isLastChunkForGroup = false)
-      finalizeCurrentArrowBatch()
-    }
-
-    arrowWriterForData.write(dataRow)
-
-    numRowsForCurrentChunk += 1
-    totalNumRowsForBatch += 1
-  }
-
   /**
    * Indicates writer that current group has finalized and there will be no further row bound to
    * the current group.
@@ -209,7 +188,7 @@ class ApplyInPandasWithStateWriter(
     new GenericInternalRow(Array[Any](stateUnderlyingRow))
   }
 
-  private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+  override protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
     val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
       startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup)
     arrowWriterForState.write(stateInfoRow)
@@ -221,7 +200,10 @@ class ApplyInPandasWithStateWriter(
     numRowsForCurrentChunk = 0
   }
 
-  private def finalizeCurrentArrowBatch(): Unit = {
+  /**
+   * Finalizes the current batch and writes it to the Arrow stream.
+   */
+  override def finalizeCurrentArrowBatch(): Unit = {
     val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
     (0 until remainingEmptyStateRows).foreach { _ =>
       arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala
new file mode 100644
index 000000000000..303389cee096
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriter.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+
+/**
+ * Base class to handle writing data to Arrow stream to Python workers. When the rows
+ * for a group exceed the maximum number of records per batch, we chunk the data into multiple
+ * batches.
+ */
+class BaseStreamingArrowWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int,
+    arrowWriterForTest: ArrowWriter = null) {
+  protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) {
+    ArrowWriter.create(root)
+  } else {
+    arrowWriterForTest
+  }
+
+  // variables for tracking the status of current batch
+  protected var totalNumRowsForBatch = 0
+
+  // variables for tracking the status of current chunk
+  protected var numRowsForCurrentChunk = 0
+
+  /**
+   * Indicates writer to write a row for current batch.
+   *
+   * @param dataRow The row to write for current batch.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If it exceeds the condition of batch (number of records) and there is more data for the
+    // same group, finalize and construct a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentChunk(isLastChunkForGroup = false)
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+
+    numRowsForCurrentChunk += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Finalizes the current batch and writes it to the Arrow stream.
+   */
+  def finalizeCurrentArrowBatch(): Unit = {
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForData.reset()
+    totalNumRowsForBatch = 0
+  }
+
+  /**
+   * Finalizes the current chunk. We only reset the number of rows for the current chunk here since
+   * not all the writers need this step.
+   */
+  protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+    numRowsForCurrentChunk = 0
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index d549ddba7c8c..7d0c177d1df8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -29,8 +29,6 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow
-import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType}
 import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
@@ -66,8 +64,6 @@ class TransformWithStateInPandasPythonRunner(
   override protected val workerConf: Map[String, String] = initialWorkerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
 
-  private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
-
   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
   // constructor.
   override protected lazy val schema: StructType = _schema
@@ -132,24 +128,27 @@ class TransformWithStateInPandasPythonRunner(
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
   }
 
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
   override protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
      dataOut: DataOutputStream,
      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
     if (inputIterator.hasNext) {
       val startData = dataOut.size()
       val next = inputIterator.next()
-      val nextBatch = next._2
+      val dataIter = next._2
 
-      while (nextBatch.hasNext) {
-        arrowWriter.write(nextBatch.next())
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
       }
-
-      arrowWriter.finish()
-      writer.writeBatch()
-      arrowWriter.reset()
+      pandasWriter.finalizeCurrentArrowBatch()
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
       true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala
new file mode 100644
index 000000000000..0417a839dc6b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BaseStreamingArrowWriterSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.mockito.Mockito.{mock, never, times, verify}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+
+class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+  // Setting the maximum number of records per batch to 2 to make test easier.
+  val arrowMaxRecordsPerBatch = 2
+  var transformWithStateInPandasWriter: BaseStreamingArrowWriter = _
+  var arrowWriter: ArrowWriter = _
+  var writer: ArrowStreamWriter = _
+
+  override def beforeEach(): Unit = {
+    val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+    writer = mock(classOf[ArrowStreamWriter])
+    arrowWriter = mock(classOf[ArrowWriter])
+    transformWithStateInPandasWriter = new BaseStreamingArrowWriter(
+      root, writer, arrowMaxRecordsPerBatch, arrowWriter)
+  }
+
+  test("test writeRow") {
+    val dataRow = mock(classOf[InternalRow])
+    // Write 2 rows first, batch is not finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(2)).write(dataRow)
+    verify(writer, never()).writeBatch()
+    // Write a 3rd row, batch is finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(3)).write(dataRow)
+    verify(writer).writeBatch()
+    // Write 2 more rows, a new batch is finalized.
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    transformWithStateInPandasWriter.writeRow(dataRow)
+    verify(arrowWriter, times(5)).write(dataRow)
+    verify(writer, times(2)).writeBatch()
+  }
+
+  test("test finalizeCurrentArrowBatch") {
+    transformWithStateInPandasWriter.finalizeCurrentArrowBatch()
+    verify(arrowWriter).finish()
+    verify(writer).writeBatch()
+    verify(arrowWriter).reset()
+  }
+}
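For readers skimming the diff, the chunking introduced by BaseStreamingArrowWriter reduces to simple per-batch accounting: buffer rows, cut an Arrow batch as soon as arrowMaxRecordsPerBatch rows have been buffered, and flush whatever remains once the input is drained. The following minimal, self-contained Scala sketch mirrors only that accounting; ChunkedBatchWriter, its emitBatch callback, and the string rows are illustrative stand-ins (not part of the patch) for the real ArrowWriter/ArrowStreamWriter plumbing shown above.

import scala.collection.mutable.ArrayBuffer

object ChunkingSketch {
  // Stand-in for the real writer: the patch finishes an ArrowWriter and calls
  // ArrowStreamWriter.writeBatch(); here we just hand the buffered rows to a callback.
  final class ChunkedBatchWriter(maxRecordsPerBatch: Int, emitBatch: Seq[String] => Unit) {
    private val buffered = ArrayBuffer.empty[String]

    // Mirrors BaseStreamingArrowWriter.writeRow: if the current batch is already
    // full, flush it before buffering the next row.
    def writeRow(row: String): Unit = {
      if (buffered.size >= maxRecordsPerBatch) {
        finalizeCurrentArrowBatch()
      }
      buffered += row
    }

    // Mirrors BaseStreamingArrowWriter.finalizeCurrentArrowBatch: emit whatever is
    // buffered (possibly a partial batch) and reset the per-batch state.
    def finalizeCurrentArrowBatch(): Unit = {
      emitBatch(buffered.toSeq)
      buffered.clear()
    }
  }

  def main(args: Array[String]): Unit = {
    val maxRecordsPerBatch = 2 // stand-in for ARROW_EXECUTION_MAX_RECORDS_PER_BATCH
    var batchIndex = 0
    val writer = new ChunkedBatchWriter(maxRecordsPerBatch, rows => {
      batchIndex += 1
      println(s"batch $batchIndex -> ${rows.size} row(s)")
    })

    // Five rows with a limit of 2 yield batches of 2, 2 and 1 rows, the same
    // pattern exercised in BaseStreamingArrowWriterSuite.
    (1 to 5).foreach(i => writer.writeRow(s"row-$i"))
    writer.finalizeCurrentArrowBatch()
  }
}

Because writeRow flushes a full batch before appending the next row rather than after, the final batch is usually partial, which is why TransformWithStateInPandasPythonRunner calls finalizeCurrentArrowBatch() once the input iterator is drained.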