Kontinuation commented on code in PR #3845:
URL: https://github.com/apache/datafusion-comet/pull/3845#discussion_r3031129141
##########
docs/source/contributor-guide/native_shuffle.md:
##########
@@ -81,10 +81,18 @@ Native shuffle (`CometExchange`) is selected when all of
the following condition
└─────────────────────────────────────────────────────────────────────────────┘
│ │
▼ ▼
-┌───────────────────────────────────┐ ┌───────────────────────────────────┐
-│ MultiPartitionShuffleRepartitioner │ │ SinglePartitionShufflePartitioner │
-│ (hash/range partitioning) │ │ (single partition case) │
-└───────────────────────────────────┘ └───────────────────────────────────┘
+┌───────────────────────────────────────────────────────────────────────┐
+│ Partitioner Selection │
+│ Controlled by spark.comet.exec.shuffle.partitionerMode │
+├───────────────────────────┬───────────────────────────────────────────┤
+│ immediate (default) │ buffered │
+│ ImmediateModePartitioner │ MultiPartitionShuffleRepartitioner │
+│ (hash/range/round-robin) │ (hash/range/round-robin) │
+│ Writes IPC blocks as │ Buffers all rows in memory │
+│ batches arrive │ before writing │
+├───────────────────────────┴───────────────────────────────────────────┤
Review Comment:
```suggestion
│ Writes IPC blocks as │ Buffers all rows in memory │
│ batches arrive │ before writing │
├───────────────────────────┴───────────────────────────────────────────┤
```
##########
native/shuffle/src/partitioners/immediate_mode.rs:
##########
@@ -0,0 +1,1089 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::metrics::ShufflePartitionerMetrics;
+use crate::partitioners::ShufflePartitioner;
+use crate::{comet_partitioning, CometPartitioning, CompressionCodec};
+use arrow::array::builder::{
+ make_builder, ArrayBuilder, BinaryBuilder, BinaryViewBuilder,
BooleanBuilder,
+ LargeBinaryBuilder, LargeStringBuilder, NullBuilder, PrimitiveBuilder,
StringBuilder,
+ StringViewBuilder,
+};
+use arrow::array::{
+ Array, ArrayRef, AsArray, BinaryViewArray, RecordBatch, StringViewArray,
UInt32Array,
+};
+use arrow::compute::take;
+use arrow::datatypes::{
+ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float32Type, Float64Type,
+ Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, TimeUnit,
TimestampMicrosecondType,
+ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type, UInt32Type,
+ UInt64Type, UInt8Type,
+};
+use arrow::ipc::writer::StreamWriter;
+use datafusion::common::{DataFusionError, Result};
+use datafusion::execution::memory_pool::{MemoryConsumer, MemoryLimit,
MemoryReservation};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes;
+use std::fs::{File, OpenOptions};
+use std::io::{BufWriter, Seek, Write};
+use std::sync::Arc;
+use tokio::time::Instant;
+
/// Appends the rows of a variable-width byte array (`Utf8`, `LargeUtf8`,
/// `Binary`, `LargeBinary`) selected by `$indices` onto a typed builder.
///
/// `$cast` is the `AsArray` accessor (`as_string` / `as_binary`) used to view
/// `$source` with offset type `$offset_type`, and `$builder_type` is the
/// concrete builder `$builder` is downcast to. Panics with
/// "builder type mismatch" when that downcast fails.
macro_rules! scatter_byte_array {
    ($builder:expr, $source:expr, $indices:expr, $offset_type:ty, $builder_type:ty, $cast:ident) => {{
        let values = $source.$cast::<$offset_type>();
        let target = $builder
            .as_any_mut()
            .downcast_mut::<$builder_type>()
            .expect("builder type mismatch");
        if values.null_count() > 0 {
            // Source has nulls: check validity for every selected row.
            $indices.iter().for_each(|&row| {
                if values.is_null(row) {
                    target.append_null();
                } else {
                    target.append_value(values.value(row));
                }
            });
        } else {
            // No nulls anywhere in the source, so skip the validity lookups.
            $indices
                .iter()
                .for_each(|&row| target.append_value(values.value(row)));
        }
    }};
}
+
/// Appends the rows of a view-layout array (`Utf8View` / `BinaryView`)
/// selected by `$indices` onto a typed view builder.
///
/// `$source` is downcast to `$array_type` (panicking with
/// "array type mismatch" on failure) and `$builder` to `$builder_type`
/// (panicking with "builder type mismatch" on failure).
macro_rules! scatter_byte_view {
    ($builder:expr, $source:expr, $indices:expr, $array_type:ty, $builder_type:ty) => {{
        let values = $source
            .as_any()
            .downcast_ref::<$array_type>()
            .expect("array type mismatch");
        let target = $builder
            .as_any_mut()
            .downcast_mut::<$builder_type>()
            .expect("builder type mismatch");
        if values.null_count() > 0 {
            // Source has nulls: check validity for every selected row.
            $indices.iter().for_each(|&row| {
                if values.is_null(row) {
                    target.append_null();
                } else {
                    target.append_value(values.value(row));
                }
            });
        } else {
            // No nulls anywhere in the source, so skip the validity lookups.
            $indices
                .iter()
                .for_each(|&row| target.append_value(values.value(row)));
        }
    }};
}
+
/// Appends the rows of a primitive array selected by `$indices` onto a
/// `PrimitiveBuilder<$arrow_type>`.
///
/// Panics with "builder type mismatch" when `$builder` is not a
/// `PrimitiveBuilder<$arrow_type>`.
macro_rules! scatter_primitive {
    ($builder:expr, $source:expr, $indices:expr, $arrow_type:ty) => {{
        let values = $source.as_primitive::<$arrow_type>();
        let target = $builder
            .as_any_mut()
            .downcast_mut::<PrimitiveBuilder<$arrow_type>>()
            .expect("builder type mismatch");
        if values.null_count() > 0 {
            // Source has nulls: check validity for every selected row.
            $indices.iter().for_each(|&row| {
                if values.is_null(row) {
                    target.append_null();
                } else {
                    target.append_value(values.value(row));
                }
            });
        } else {
            // No nulls anywhere in the source, so skip the validity lookups.
            $indices
                .iter()
                .for_each(|&row| target.append_value(values.value(row)));
        }
    }};
}
+
+/// Scatter-append selected rows from `source` into `builder`.
+fn scatter_append(
+ builder: &mut dyn ArrayBuilder,
+ source: &dyn Array,
+ indices: &[usize],
+) -> Result<()> {
+ use DataType::*;
+ match source.data_type() {
+ Boolean => {
+ let src = source.as_boolean();
+ let dst = builder
+ .as_any_mut()
+ .downcast_mut::<BooleanBuilder>()
+ .unwrap();
+ if src.null_count() == 0 {
+ for &idx in indices {
+ dst.append_value(src.value(idx));
+ }
+ } else {
+ for &idx in indices {
+ dst.append_option(src.is_valid(idx).then(||
src.value(idx)));
+ }
+ }
+ }
+ Int8 => scatter_primitive!(builder, source, indices, Int8Type),
+ Int16 => scatter_primitive!(builder, source, indices, Int16Type),
+ Int32 => scatter_primitive!(builder, source, indices, Int32Type),
+ Int64 => scatter_primitive!(builder, source, indices, Int64Type),
+ UInt8 => scatter_primitive!(builder, source, indices, UInt8Type),
+ UInt16 => scatter_primitive!(builder, source, indices, UInt16Type),
+ UInt32 => scatter_primitive!(builder, source, indices, UInt32Type),
+ UInt64 => scatter_primitive!(builder, source, indices, UInt64Type),
+ Float32 => scatter_primitive!(builder, source, indices, Float32Type),
+ Float64 => scatter_primitive!(builder, source, indices, Float64Type),
+ Date32 => scatter_primitive!(builder, source, indices, Date32Type),
+ Date64 => scatter_primitive!(builder, source, indices, Date64Type),
+ Timestamp(TimeUnit::Second, _) => {
+ scatter_primitive!(builder, source, indices, TimestampSecondType)
+ }
+ Timestamp(TimeUnit::Millisecond, _) => {
+ scatter_primitive!(builder, source, indices,
TimestampMillisecondType)
+ }
+ Timestamp(TimeUnit::Microsecond, _) => {
+ scatter_primitive!(builder, source, indices,
TimestampMicrosecondType)
+ }
+ Timestamp(TimeUnit::Nanosecond, _) => {
+ scatter_primitive!(builder, source, indices,
TimestampNanosecondType)
+ }
+ Decimal128(_, _) => scatter_primitive!(builder, source, indices,
Decimal128Type),
+ Decimal256(_, _) => scatter_primitive!(builder, source, indices,
Decimal256Type),
+ Utf8 => scatter_byte_array!(builder, source, indices, i32,
StringBuilder, as_string),
+ LargeUtf8 => {
+ scatter_byte_array!(builder, source, indices, i64,
LargeStringBuilder, as_string)
+ }
+ Binary => scatter_byte_array!(builder, source, indices, i32,
BinaryBuilder, as_binary),
+ LargeBinary => {
+ scatter_byte_array!(builder, source, indices, i64,
LargeBinaryBuilder, as_binary)
+ }
+ Utf8View => {
+ scatter_byte_view!(builder, source, indices, StringViewArray,
StringViewBuilder)
+ }
+ BinaryView => {
+ scatter_byte_view!(builder, source, indices, BinaryViewArray,
BinaryViewBuilder)
+ }
+ Null => {
+ let dst =
builder.as_any_mut().downcast_mut::<NullBuilder>().unwrap();
+ dst.append_nulls(indices.len());
+ }
+ dt => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Scatter append not implemented for {dt}"
+ )));
+ }
+ }
+ Ok(())
+}
+
+/// Per-column strategy: scatter-write via builder for primitive/string types,
+/// or accumulate taken sub-arrays for complex types (List, Map, Struct, etc.).
+enum ColumnBuffer {
+ /// Fast path: direct scatter into a pre-allocated builder.
+ Builder(Box<dyn ArrayBuilder>),
+ /// Fallback for complex types: accumulate `take`-produced sub-arrays,
+ /// concatenate at flush time.
+ Accumulator(Vec<ArrayRef>),
+}
+
+/// Returns true if `scatter_append` can handle this data type directly.
+fn has_scatter_support(dt: &DataType) -> bool {
+ use DataType::*;
+ matches!(
+ dt,
+ Boolean
+ | Int8
+ | Int16
+ | Int32
+ | Int64
+ | UInt8
+ | UInt16
+ | UInt32
+ | UInt64
+ | Float32
+ | Float64
+ | Date32
+ | Date64
+ | Timestamp(_, _)
+ | Decimal128(_, _)
+ | Decimal256(_, _)
+ | Utf8
+ | LargeUtf8
+ | Binary
+ | LargeBinary
+ | Utf8View
+ | BinaryView
+ | Null
+ )
+}
+
+struct PartitionBuffer {
+ columns: Vec<ColumnBuffer>,
+ schema: SchemaRef,
+ num_rows: usize,
+ target_batch_size: usize,
+}
+
+impl PartitionBuffer {
+ fn new(schema: &SchemaRef, target_batch_size: usize) -> Self {
+ let columns = schema
+ .fields()
+ .iter()
+ .map(|f| {
+ if has_scatter_support(f.data_type()) {
+ ColumnBuffer::Builder(make_builder(f.data_type(),
target_batch_size))
+ } else {
+ ColumnBuffer::Accumulator(Vec::new())
+ }
+ })
+ .collect();
+ Self {
+ columns,
+ schema: Arc::clone(schema),
+ num_rows: 0,
+ target_batch_size,
+ }
+ }
+
+ fn is_full(&self) -> bool {
+ self.num_rows >= self.target_batch_size
+ }
+
+ /// Finish all columns into a RecordBatch. Builders are reset (retaining
+ /// capacity); accumulators are concatenated and cleared.
+ fn flush(&mut self) -> Result<RecordBatch> {
+ let arrays: Vec<ArrayRef> = self
+ .columns
+ .iter_mut()
+ .map(|col| match col {
+ ColumnBuffer::Builder(b) => b.finish(),
+ ColumnBuffer::Accumulator(chunks) => {
+ let refs: Vec<&dyn Array> = chunks.iter().map(|a|
a.as_ref()).collect();
+ let result = arrow::compute::concat(&refs)
+ .expect("concat failed for accumulated arrays");
+ chunks.clear();
+ result
+ }
+ })
+ .collect();
+ let batch = RecordBatch::try_new(Arc::clone(&self.schema), arrays)
+ .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
+ self.num_rows = 0;
+ Ok(batch)
+ }
+
+ fn has_data(&self) -> bool {
+ self.num_rows > 0
+ }
+}
+
+pub(crate) struct PartitionOutputStream {
+ schema: SchemaRef,
+ codec: CompressionCodec,
+ buffer: Vec<u8>,
+}
+
+impl PartitionOutputStream {
+ pub(crate) fn try_new(schema: SchemaRef, codec: CompressionCodec) ->
Result<Self> {
+ Ok(Self {
+ schema,
+ codec,
+ buffer: Vec::new(),
+ })
+ }
+
+ fn write_ipc_block(&mut self, batch: &RecordBatch) -> Result<usize> {
+ let start_pos = self.buffer.len();
+
+ self.buffer.extend_from_slice(&0u64.to_le_bytes());
+ let field_count = self.schema.fields().len();
+ self.buffer
+ .extend_from_slice(&(field_count as u64).to_le_bytes());
+ let codec_tag: &[u8; 4] = match &self.codec {
+ CompressionCodec::Snappy => b"SNAP",
+ CompressionCodec::Lz4Frame => b"LZ4_",
+ CompressionCodec::Zstd(_) => b"ZSTD",
+ CompressionCodec::None => b"NONE",
+ };
+ self.buffer.extend_from_slice(codec_tag);
+
+ match &self.codec {
+ CompressionCodec::None => {
+ let mut w = StreamWriter::try_new(&mut self.buffer,
&batch.schema())?;
+ w.write(batch)?;
+ w.finish()?;
+ w.into_inner()?;
+ }
+ CompressionCodec::Lz4Frame => {
+ let mut wtr = lz4_flex::frame::FrameEncoder::new(&mut
self.buffer);
+ let mut w = StreamWriter::try_new(&mut wtr, &batch.schema())?;
+ w.write(batch)?;
+ w.finish()?;
+ wtr.finish().map_err(|e| {
+ DataFusionError::Execution(format!("lz4 compression error:
{e}"))
+ })?;
+ }
+ CompressionCodec::Zstd(level) => {
+ let enc = zstd::Encoder::new(&mut self.buffer, *level)?;
+ let mut w = StreamWriter::try_new(enc, &batch.schema())?;
+ w.write(batch)?;
+ w.finish()?;
+ w.into_inner()?.finish()?;
+ }
+ CompressionCodec::Snappy => {
+ let mut wtr = snap::write::FrameEncoder::new(&mut self.buffer);
+ let mut w = StreamWriter::try_new(&mut wtr, &batch.schema())?;
+ w.write(batch)?;
+ w.finish()?;
+ wtr.into_inner().map_err(|e| {
+ DataFusionError::Execution(format!("snappy compression
error: {e}"))
+ })?;
+ }
+ }
+
+ let end_pos = self.buffer.len();
+ let ipc_length = (end_pos - start_pos - 8) as u64;
+ if ipc_length > i32::MAX as u64 {
+ return Err(DataFusionError::Execution(format!(
+ "Shuffle block size {ipc_length} exceeds maximum size of {}",
+ i32::MAX
+ )));
+ }
+ self.buffer[start_pos..start_pos +
8].copy_from_slice(&ipc_length.to_le_bytes());
+
+ Ok(end_pos - start_pos)
+ }
+
+ fn drain_buffer(&mut self) -> Vec<u8> {
+ std::mem::take(&mut self.buffer)
+ }
+
+ #[cfg(test)]
+ fn finish(self) -> Result<Vec<u8>> {
+ Ok(self.buffer)
+ }
+}
+
+struct SpillFile {
+ _temp_file: datafusion::execution::disk_manager::RefCountedTempFile,
+ file: File,
+}
+
+/// A partitioner that scatter-writes incoming rows directly into pre-allocated
+/// per-partition column builders. When a partition's builders reach
+/// `target_batch_size`, the batch is flushed to a compressed IPC block.
+/// No intermediate sub-batches or coalescers are created.
+pub(crate) struct ImmediateModePartitioner {
+ output_data_file: String,
+ output_index_file: String,
+ partition_buffers: Vec<PartitionBuffer>,
+ streams: Vec<PartitionOutputStream>,
+ spill_files: Vec<Option<SpillFile>>,
+ partitioning: CometPartitioning,
+ runtime: Arc<RuntimeEnv>,
+ reservation: MemoryReservation,
+ metrics: ShufflePartitionerMetrics,
+ hashes_buf: Vec<u32>,
+ partition_ids: Vec<u32>,
+ /// Reusable per-partition row index scratch space.
+ partition_row_indices: Vec<Vec<usize>>,
+ /// Maximum bytes this partitioner will reserve from the memory pool.
+ /// Computed as memory_pool_size * memory_fraction at construction.
+ memory_limit: usize,
+}
+
+impl ImmediateModePartitioner {
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn try_new(
+ partition: usize,
+ output_data_file: String,
+ output_index_file: String,
+ schema: SchemaRef,
+ partitioning: CometPartitioning,
+ metrics: ShufflePartitionerMetrics,
+ runtime: Arc<RuntimeEnv>,
+ batch_size: usize,
+ codec: CompressionCodec,
+ ) -> Result<Self> {
+ let num_output_partitions = partitioning.partition_count();
+
+ let partition_buffers = (0..num_output_partitions)
+ .map(|_| PartitionBuffer::new(&schema, batch_size))
+ .collect();
+
+ let streams = (0..num_output_partitions)
+ .map(|_| PartitionOutputStream::try_new(Arc::clone(&schema),
codec.clone()))
+ .collect::<Result<Vec<_>>>()?;
+
+ let spill_files: Vec<Option<SpillFile>> =
+ (0..num_output_partitions).map(|_| None).collect();
+
+ let hashes_buf = match &partitioning {
+ CometPartitioning::Hash(_, _) | CometPartitioning::RoundRobin(_,
_) => {
+ vec![0u32; batch_size]
+ }
+ _ => vec![],
+ };
+
+ let memory_limit = match runtime.memory_pool.memory_limit() {
+ MemoryLimit::Finite(pool_size) => pool_size,
+ _ => usize::MAX,
+ };
+
+ let reservation =
MemoryConsumer::new(format!("ImmediateModePartitioner[{partition}]"))
+ .with_can_spill(true)
+ .register(&runtime.memory_pool);
+
+ let partition_row_indices = (0..num_output_partitions).map(|_|
Vec::new()).collect();
+
+ Ok(Self {
+ output_data_file,
+ output_index_file,
+ partition_buffers,
+ streams,
+ spill_files,
+ partitioning,
+ runtime,
+ reservation,
+ metrics,
+ hashes_buf,
+ partition_ids: vec![0u32; batch_size],
+ partition_row_indices,
+ memory_limit,
+ })
+ }
+
+ fn compute_partition_ids(&mut self, batch: &RecordBatch) -> Result<usize> {
+ let num_rows = batch.num_rows();
+
+ // Ensure scratch buffers are large enough for this batch
+ if self.hashes_buf.len() < num_rows {
+ self.hashes_buf.resize(num_rows, 0);
+ }
+ if self.partition_ids.len() < num_rows {
+ self.partition_ids.resize(num_rows, 0);
+ }
+
+ match &self.partitioning {
+ CometPartitioning::Hash(exprs, num_output_partitions) => {
+ let num_output_partitions = *num_output_partitions;
+ let arrays = exprs
+ .iter()
+ .map(|expr| expr.evaluate(batch)?.into_array(num_rows))
+ .collect::<Result<Vec<_>>>()?;
+ let hashes_buf = &mut self.hashes_buf[..num_rows];
+ hashes_buf.fill(42_u32);
+ create_murmur3_hashes(&arrays, hashes_buf)?;
+ let partition_ids = &mut self.partition_ids[..num_rows];
+ for (idx, hash) in hashes_buf.iter().enumerate() {
+ partition_ids[idx] =
+ comet_partitioning::pmod(*hash, num_output_partitions)
as u32;
+ }
+ Ok(num_output_partitions)
+ }
+ CometPartitioning::RoundRobin(num_output_partitions,
max_hash_columns) => {
+ let num_output_partitions = *num_output_partitions;
+ let max_hash_columns = *max_hash_columns;
+ let num_columns_to_hash = if max_hash_columns == 0 {
+ batch.num_columns()
+ } else {
+ max_hash_columns.min(batch.num_columns())
+ };
+ let columns_to_hash: Vec<ArrayRef> = (0..num_columns_to_hash)
+ .map(|i| Arc::clone(batch.column(i)))
+ .collect();
+ let hashes_buf = &mut self.hashes_buf[..num_rows];
+ hashes_buf.fill(42_u32);
+ create_murmur3_hashes(&columns_to_hash, hashes_buf)?;
+ let partition_ids = &mut self.partition_ids[..num_rows];
+ for (idx, hash) in hashes_buf.iter().enumerate() {
+ partition_ids[idx] =
+ comet_partitioning::pmod(*hash, num_output_partitions)
as u32;
+ }
+ Ok(num_output_partitions)
+ }
+ CometPartitioning::RangePartitioning(
+ lex_ordering,
+ num_output_partitions,
+ row_converter,
+ bounds,
+ ) => {
+ let num_output_partitions = *num_output_partitions;
+ let arrays = lex_ordering
+ .iter()
+ .map(|expr|
expr.expr.evaluate(batch)?.into_array(num_rows))
+ .collect::<Result<Vec<_>>>()?;
+ let row_batch =
row_converter.convert_columns(arrays.as_slice())?;
+ let partition_ids = &mut self.partition_ids[..num_rows];
+ for (row_idx, row) in row_batch.iter().enumerate() {
+ partition_ids[row_idx] = bounds
+ .as_slice()
+ .partition_point(|bound| bound.row() <= row)
+ as u32;
+ }
+ Ok(num_output_partitions)
+ }
+ other => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported shuffle partitioning scheme {other:?}"
+ ))),
+ }
Review Comment:
I suggest that we move partition ID computation to a separate utility to
avoid repeating the same logic in multi partition mode and immediate mode.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]