This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 76ea2ddef perf: Coalesce broadcast exchange batches before
broadcasting (#3703)
76ea2ddef is described below
commit 76ea2ddef267ebc7c40bde02a7fa0d146f94ff6b
Author: Matt Butrovich <[email protected]>
AuthorDate: Mon Mar 16 18:40:12 2026 -0400
perf: Coalesce broadcast exchange batches before broadcasting (#3703)
---
.../org/apache/spark/sql/comet/util/Utils.scala | 111 ++++++++++++++++++++-
.../sql/comet/CometBroadcastExchangeExec.scala | 17 +++-
.../org/apache/comet/exec/CometExecSuite.scala | 12 ++-
.../org/apache/comet/exec/CometJoinSuite.scala | 84 +++++++++++++---
4 files changed, 205 insertions(+), 19 deletions(-)
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 6eaa9cad4..78f2e81c7 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -26,13 +26,15 @@ import java.nio.channels.Channels
import scala.jdk.CollectionConverters._
import org.apache.arrow.c.CDataDictionaryProvider
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector,
DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector,
IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector,
TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector,
VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector._
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.dictionary.DictionaryProvider
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.types._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.util.VectorSchemaRootAppender
import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
import org.apache.comet.shims.CometTypeShim
import org.apache.comet.vector.CometVector
-object Utils extends CometTypeShim {
+object Utils extends CometTypeShim with Logging {
def getConfPath(confFileName: String): String = {
sys.env
.get(COMET_CONF_DIR_ENV)
@@ -232,6 +234,7 @@ object Utils extends CometTypeShim {
/**
* Decodes the byte arrays back to ColumnarBatches and puts them into the buffer.
+ *
* @param bytes
* the serialized batches
* @param source
@@ -252,6 +255,108 @@ object Utils extends CometTypeShim {
new ArrowReaderIterator(Channels.newChannel(ins), source)
}
+ /**
+ * Coalesces many small Arrow IPC batches into a single batch for
broadcasting.
+ *
+ * Why this is necessary: The broadcast exchange collects shuffle output by
calling
+ * getByteArrayRdd, which serializes each ColumnarBatch independently into
its own
+ * ChunkedByteBuffer. The shuffle reader (CometBlockStoreShuffleReader)
produces one
+ * ColumnarBatch per shuffle block, and there is one block per writer task
per output partition.
+ * So with W writer tasks and P output partitions, the broadcast collects up
to W * P tiny
+ * batches. For example, with 400 writer tasks and 500 partitions, 1M rows
would arrive as ~200K
+ * batches of ~5 rows each.
+ *
+ * Without coalescing, every consumer task in the broadcast join would
independently deserialize
+ * all of these tiny Arrow IPC streams, paying per-stream overhead (schema
parsing, buffer
+ * allocation) for each one. With coalescing, we decode and append all
batches into one
+ * VectorSchemaRoot on the driver, then re-serialize once. Each consumer
task then deserializes
+ * a single Arrow IPC stream.
+ */
+ def coalesceBroadcastBatches(
+ input: Iterator[ChunkedByteBuffer]): (Array[ChunkedByteBuffer], Long,
Long) = {
+ val buffers = input.filterNot(_.size == 0).toArray
+ if (buffers.isEmpty) {
+ return (Array.empty, 0L, 0L)
+ }
+
+ val allocator = org.apache.comet.CometArrowAllocator
+ .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
+ try {
+ var targetRoot: VectorSchemaRoot = null
+ var totalRows = 0L
+ var batchCount = 0
+
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ try {
+ for (bytes <- buffers) {
+ val compressedInputStream =
+ new
DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
+ val reader =
+ new ArrowStreamReader(Channels.newChannel(compressedInputStream),
allocator)
+ try {
+ // Comet decodes dictionaries during execution, so this shouldn't
happen.
+ // If it does, fall back to the original uncoalesced buffers
because each
+ // partition can have a different dictionary, and appending index
vectors
+ // would silently mix indices from incompatible dictionaries.
+ if (!reader.getDictionaryVectors.isEmpty) {
+ logWarning(
+ "Unexpected dictionary-encoded column during BroadcastExchange
coalescing; " +
+ "skipping coalesce")
+ reader.close()
+ if (targetRoot != null) {
+ targetRoot.close()
+ targetRoot = null
+ }
+ return (buffers, 0L, 0L)
+ }
+ while (reader.loadNextBatch()) {
+ val sourceRoot = reader.getVectorSchemaRoot
+ if (targetRoot == null) {
+ targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema,
allocator)
+ targetRoot.allocateNew()
+ }
+ VectorSchemaRootAppender.append(targetRoot, sourceRoot)
+ totalRows += sourceRoot.getRowCount
+ batchCount += 1
+ }
+ } finally {
+ reader.close()
+ }
+ }
+
+ if (targetRoot == null) {
+ return (Array.empty, 0L, 0L)
+ }
+
+ assert(
+ targetRoot.getRowCount.toLong == totalRows,
+ s"Row count mismatch after coalesce: ${targetRoot.getRowCount} !=
$totalRows")
+
+ logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows
rows)")
+
+ val outputStream = new ChunkedByteBufferOutputStream(1024 * 1024,
ByteBuffer.allocate)
+ val compressedOutputStream =
+ new DataOutputStream(codec.compressedOutputStream(outputStream))
+ val writer =
+ new ArrowStreamWriter(targetRoot, null,
Channels.newChannel(compressedOutputStream))
+ try {
+ writer.start()
+ writer.writeBatch()
+ } finally {
+ writer.close()
+ }
+
+ (Array(outputStream.toChunkedByteBuffer), batchCount.toLong, totalRows)
+ } finally {
+ if (targetRoot != null) {
+ targetRoot.close()
+ }
+ }
+ } finally {
+ allocator.close()
+ }
+ }
+
def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f40e05ea0..4a323e575 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -77,7 +77,13 @@ case class CometBroadcastExchangeExec(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"),
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to
collect"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to
build"),
- "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to
broadcast"))
+ "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to
broadcast"),
+ "numCoalescedBatches" -> SQLMetrics.createMetric(
+ sparkContext,
+ "number of coalesced batches for broadcast"),
+ "numCoalescedRows" -> SQLMetrics.createMetric(
+ sparkContext,
+ "number of coalesced rows for broadcast"))
override def doCanonicalize(): SparkPlan = {
CometBroadcastExchangeExec(null, null, mode, child.canonicalized)
@@ -155,7 +161,14 @@ case class CometBroadcastExchangeExec(
val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild -
beforeCollect)
- val batches = input.toArray
+ // Coalesce the many small per-shuffle-block buffers into a single
buffer.
+ // Without this, each consumer task deserializes one Arrow IPC stream
per
+ // shuffle block (one per writer task per partition), which is very
expensive
+ // when there are hundreds of writer tasks and partitions. See the
scaladoc
+ // on coalesceBroadcastBatches for details.
+ val (batches, coalescedBatches, coalescedRows) =
Utils.coalesceBroadcastBatches(input)
+ longMetric("numCoalescedBatches") += coalescedBatches
+ longMetric("numCoalescedRows") += coalescedRows
val dataSize = batches.map(_.size).sum
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bf9bbc95..aff181626 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression,
ExpressionInfo, He
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode,
BloomFilterAggregate}
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec,
SparkPlan, SQLExecution, UnionExec}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
BroadcastQueryStageExec}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
@@ -474,6 +474,10 @@ class CometExecSuite extends CometTestBase {
val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i =>
i + 1)).sorted
assert(rowContents === expected)
+
+ val metrics = nativeBroadcast.metrics
+ assert(metrics("numCoalescedBatches").value == 5L)
+ assert(metrics("numCoalescedRows").value == 5L)
}
}
}
@@ -493,6 +497,10 @@ class CometExecSuite extends CometTestBase {
}.get.asInstanceOf[CometBroadcastExchangeExec]
val rows = nativeBroadcast.executeCollect()
assert(rows.isEmpty)
+
+ val metrics = nativeBroadcast.metrics
+ assert(metrics("numCoalescedBatches").value == 0L)
+ assert(metrics("numCoalescedRows").value == 0L)
}
}
}
@@ -712,7 +720,7 @@ class CometExecSuite extends CometTestBase {
assert(metrics.contains("build_time"))
assert(metrics("build_time").value > 1L)
assert(metrics.contains("build_input_batches"))
- assert(metrics("build_input_batches").value == 25L)
+ assert(metrics("build_input_batches").value == 5L)
assert(metrics.contains("build_mem_used"))
assert(metrics("build_mem_used").value > 1L)
assert(metrics.contains("build_input_rows"))
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index d5a8387be..49fbe10c3 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.comet.CometConf
class CometJoinSuite extends CometTestBase {
+
import testImplicits._
override protected def test(testName: String, testTags: Tag*)(testFun: =>
Any)(implicit
@@ -359,28 +360,87 @@ class CometJoinSuite extends CometTestBase {
checkSparkAnswer(left.join(right, ($"left.N" === $"right.N") &&
($"right.N" =!= 3), "full"))
checkSparkAnswer(sql("""
- |SELECT l.a, count(*)
- |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
- |GROUP BY l.a
+ |SELECT l.a, count(*)
+ |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+ |GROUP BY l.a
""".stripMargin))
checkSparkAnswer(sql("""
- |SELECT r.N, count(*)
- |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
- |GROUP BY r.N
+ |SELECT r.N, count(*)
+ |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+ |GROUP BY r.N
""".stripMargin))
checkSparkAnswer(sql("""
- |SELECT l.N, count(*)
- |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
- |GROUP BY l.N
+ |SELECT l.N, count(*)
+ |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+ |GROUP BY l.N
""".stripMargin))
checkSparkAnswer(sql("""
- |SELECT r.a, count(*)
- |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
- |GROUP BY r.a
+ |SELECT r.a, count(*)
+ |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+ |GROUP BY r.a
""".stripMargin))
}
}
+
+ test("Broadcast hash join build-side batch coalescing") {
+ // Use many shuffle partitions to produce many small broadcast batches,
+ // then verify that coalescing reduces the build-side batch count to 1 per
task.
+ val numPartitions = 512
+ withSQLConf(
+ CometConf.COMET_BATCH_SIZE.key -> "100",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+ "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
+ withParquetTable((0 until 10000).map(i => (i, i % 5)), "tbl_a") {
+ withParquetTable((0 until 10000).map(i => (i % 10, i + 2)), "tbl_b") {
+ // Force a shuffle on tbl_a before broadcast so the broadcast source
has
+ // numPartitions partitions, not just the number of parquet files.
+ val query =
+ s"""SELECT /*+ BROADCAST(a) */ *
+ |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a
+ |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin
+
+ val (_, cometPlan) = checkSparkAnswerAndOperator(
+ sql(query),
+ Seq(classOf[CometBroadcastExchangeExec],
classOf[CometBroadcastHashJoinExec]))
+
+ val joins = collect(cometPlan) { case j: CometBroadcastHashJoinExec
=>
+ j
+ }
+ assert(joins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan")
+
+ val join = joins.head
+ val buildBatches = join.metrics("build_input_batches").value
+
+ // Without coalescing, build_input_batches would be ~numPartitions
per task,
+ // totaling ~numPartitions * numPartitions across all tasks.
+ // With coalescing, each task gets 1 batch, so total ≈ numPartitions.
+ assert(
+ buildBatches <= numPartitions,
+ s"Expected at most $numPartitions build batches (1 per task), got
$buildBatches. " +
+ "Broadcast batch coalescing may not be working.")
+
+ val broadcasts = collect(cometPlan) { case b:
CometBroadcastExchangeExec =>
+ b
+ }
+ assert(broadcasts.nonEmpty, "Expected CometBroadcastExchangeExec in
plan")
+
+ val broadcast = broadcasts.head
+ val coalescedBatches = broadcast.metrics("numCoalescedBatches").value
+ val coalescedRows = broadcast.metrics("numCoalescedRows").value
+
+ assert(
+ coalescedBatches >= numPartitions,
+ s"Expected at least $numPartitions coalesced batches, got
$coalescedBatches")
+ assert(coalescedRows == 10000, s"Expected 10000 coalesced rows, got
$coalescedRows")
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]