This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b61b30c2 chore: Extract some tied down logic (#3374)
7b61b30c2 is described below

commit 7b61b30c24a9fcedded76b45666c43df244915fc
Author: Emily Matheys <[email protected]>
AuthorDate: Wed Feb 4 00:17:49 2026 +0200

    chore: Extract some tied down logic (#3374)
---
 native/core/src/execution/shuffle/metrics.rs       |  61 +++++++
 native/core/src/execution/shuffle/mod.rs           |   1 +
 .../core/src/execution/shuffle/shuffle_writer.rs   | 179 ++++++++-------------
 3 files changed, 127 insertions(+), 114 deletions(-)

diff --git a/native/core/src/execution/shuffle/metrics.rs 
b/native/core/src/execution/shuffle/metrics.rs
new file mode 100644
index 000000000..33b51c3cd
--- /dev/null
+++ b/native/core/src/execution/shuffle/metrics.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time,
+};
+
+pub(super) struct ShufflePartitionerMetrics {
+    /// metrics
+    pub(super) baseline: BaselineMetrics,
+
+    /// Time to perform repartitioning
+    pub(super) repart_time: Time,
+
+    /// Time encoding batches to IPC format
+    pub(super) encode_time: Time,
+
+    /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL 
Metrics.
+    pub(super) write_time: Time,
+
+    /// Number of input batches
+    pub(super) input_batches: Count,
+
+    /// count of spills during the execution of the operator
+    pub(super) spill_count: Count,
+
+    /// total spilled bytes during the execution of the operator
+    pub(super) spilled_bytes: Count,
+
+    /// The original size of spilled data. Different to `spilled_bytes` 
because of compression.
+    pub(super) data_size: Count,
+}
+
+impl ShufflePartitionerMetrics {
+    pub(super) fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> 
Self {
+        Self {
+            baseline: BaselineMetrics::new(metrics, partition),
+            repart_time: 
MetricBuilder::new(metrics).subset_time("repart_time", partition),
+            encode_time: 
MetricBuilder::new(metrics).subset_time("encode_time", partition),
+            write_time: MetricBuilder::new(metrics).subset_time("write_time", 
partition),
+            input_batches: 
MetricBuilder::new(metrics).counter("input_batches", partition),
+            spill_count: MetricBuilder::new(metrics).spill_count(partition),
+            spilled_bytes: 
MetricBuilder::new(metrics).spilled_bytes(partition),
+            data_size: MetricBuilder::new(metrics).counter("data_size", 
partition),
+        }
+    }
+}
diff --git a/native/core/src/execution/shuffle/mod.rs 
b/native/core/src/execution/shuffle/mod.rs
index 2e9a08c43..a72258322 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -17,6 +17,7 @@
 
 pub(crate) mod codec;
 mod comet_partitioning;
+mod metrics;
 mod shuffle_writer;
 pub mod spark_unsafe;
 
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs 
b/native/core/src/execution/shuffle/shuffle_writer.rs
index 55d6a9ef9..5c68940b9 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -17,6 +17,7 @@
 
 //! Defines the External shuffle repartition plan.
 
+use crate::execution::shuffle::metrics::ShufflePartitionerMetrics;
 use crate::execution::shuffle::{CometPartitioning, CompressionCodec, 
ShuffleBlockWriter};
 use crate::execution::tracing::{with_trace, with_trace_async};
 use arrow::compute::interleave_record_batch;
@@ -35,9 +36,7 @@ use datafusion::{
         runtime_env::RuntimeEnv,
     },
     physical_plan::{
-        metrics::{
-            BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, 
MetricsSet, Time,
-        },
+        metrics::{ExecutionPlanMetricsSet, MetricsSet, Time},
         stream::RecordBatchStreamAdapter,
         DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, 
SendableRecordBatchStream,
         Statistics,
@@ -185,7 +184,7 @@ impl ExecutionPlan for ShuffleWriterExec {
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let input = self.input.execute(partition, Arc::clone(&context))?;
-        let metrics = ShuffleRepartitionerMetrics::new(&self.metrics, 0);
+        let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0);
 
         Ok(Box::pin(RecordBatchStreamAdapter::new(
             self.schema(),
@@ -216,7 +215,7 @@ async fn external_shuffle(
     output_data_file: String,
     output_index_file: String,
     partitioning: CometPartitioning,
-    metrics: ShuffleRepartitionerMetrics,
+    metrics: ShufflePartitionerMetrics,
     context: Arc<TaskContext>,
     codec: CompressionCodec,
     tracing_enabled: bool,
@@ -268,47 +267,6 @@ async fn external_shuffle(
     .await
 }
 
-struct ShuffleRepartitionerMetrics {
-    /// metrics
-    baseline: BaselineMetrics,
-
-    /// Time to perform repartitioning
-    repart_time: Time,
-
-    /// Time encoding batches to IPC format
-    encode_time: Time,
-
-    /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL 
Metrics.
-    write_time: Time,
-
-    /// Number of input batches
-    input_batches: Count,
-
-    /// count of spills during the execution of the operator
-    spill_count: Count,
-
-    /// total spilled bytes during the execution of the operator
-    spilled_bytes: Count,
-
-    /// The original size of spilled data. Different to `spilled_bytes` 
because of compression.
-    data_size: Count,
-}
-
-impl ShuffleRepartitionerMetrics {
-    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
-        Self {
-            baseline: BaselineMetrics::new(metrics, partition),
-            repart_time: 
MetricBuilder::new(metrics).subset_time("repart_time", partition),
-            encode_time: 
MetricBuilder::new(metrics).subset_time("encode_time", partition),
-            write_time: MetricBuilder::new(metrics).subset_time("write_time", 
partition),
-            input_batches: 
MetricBuilder::new(metrics).counter("input_batches", partition),
-            spill_count: MetricBuilder::new(metrics).spill_count(partition),
-            spilled_bytes: 
MetricBuilder::new(metrics).spilled_bytes(partition),
-            data_size: MetricBuilder::new(metrics).counter("data_size", 
partition),
-        }
-    }
-}
-
 #[async_trait::async_trait]
 trait ShufflePartitioner: Send + Sync {
     /// Insert a batch into the partitioner
@@ -328,7 +286,7 @@ struct MultiPartitionShuffleRepartitioner {
     /// Partitioning scheme to use
     partitioning: CometPartitioning,
     runtime: Arc<RuntimeEnv>,
-    metrics: ShuffleRepartitionerMetrics,
+    metrics: ShufflePartitionerMetrics,
     /// Reused scratch space for computing partition indices
     scratch: ScratchSpace,
     /// The configured batch size
@@ -356,6 +314,54 @@ struct ScratchSpace {
     partition_starts: Vec<u32>,
 }
 
+impl ScratchSpace {
+    fn map_partition_ids_to_starts_and_indices(
+        &mut self,
+        num_output_partitions: usize,
+        num_rows: usize,
+    ) {
+        let partition_ids = &mut self.partition_ids[..num_rows];
+
+        // count each partition size, while leaving the last extra element as 0
+        let partition_counters = &mut self.partition_starts;
+        partition_counters.resize(num_output_partitions + 1, 0);
+        partition_counters.fill(0);
+        partition_ids
+            .iter()
+            .for_each(|partition_id| partition_counters[*partition_id as 
usize] += 1);
+
+        // accumulate partition counters into partition ends
+        // e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7]
+        let partition_ends = partition_counters;
+        let mut accum = 0;
+        partition_ends.iter_mut().for_each(|v| {
+            *v += accum;
+            accum = *v;
+        });
+
+        // calculate partition row indices and partition starts
+        // e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the 
following partition_row_indices
+        // and partition_starts arrays:
+        //
+        //  partition_row_indices: [6, 1, 2, 3, 4, 5, 0]
+        //  partition_starts: [0, 1, 4, 6, 7]
+        //
+        // partition_starts conceptually splits partition_row_indices into 
smaller slices.
+        // Each slice 
partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the
+        // row indices of the input batch that are partitioned into partition 
K. For example,
+        // first partition 0 has one row index [6], partition 1 has row 
indices [1, 2, 3], etc.
+        let partition_row_indices = &mut self.partition_row_indices;
+        partition_row_indices.resize(num_rows, 0);
+        for (index, partition_id) in partition_ids.iter().enumerate().rev() {
+            partition_ends[*partition_id as usize] -= 1;
+            let end = partition_ends[*partition_id as usize];
+            partition_row_indices[end as usize] = index as u32;
+        }
+
+        // after calculating, partition ends become partition starts
+    }
+}
+
 impl MultiPartitionShuffleRepartitioner {
     #[allow(clippy::too_many_arguments)]
     pub fn try_new(
@@ -364,7 +370,7 @@ impl MultiPartitionShuffleRepartitioner {
         output_index_file: String,
         schema: SchemaRef,
         partitioning: CometPartitioning,
-        metrics: ShuffleRepartitionerMetrics,
+        metrics: ShufflePartitionerMetrics,
         runtime: Arc<RuntimeEnv>,
         batch_size: usize,
         codec: CompressionCodec,
@@ -432,52 +438,6 @@ impl MultiPartitionShuffleRepartitioner {
             return Ok(());
         }
 
-        fn map_partition_ids_to_starts_and_indices(
-            scratch: &mut ScratchSpace,
-            num_output_partitions: usize,
-            num_rows: usize,
-        ) {
-            let partition_ids = &mut scratch.partition_ids[..num_rows];
-
-            // count each partition size, while leaving the last extra element 
as 0
-            let partition_counters = &mut scratch.partition_starts;
-            partition_counters.resize(num_output_partitions + 1, 0);
-            partition_counters.fill(0);
-            partition_ids
-                .iter()
-                .for_each(|partition_id| partition_counters[*partition_id as 
usize] += 1);
-
-            // accumulate partition counters into partition ends
-            // e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7]
-            let partition_ends = partition_counters;
-            let mut accum = 0;
-            partition_ends.iter_mut().for_each(|v| {
-                *v += accum;
-                accum = *v;
-            });
-
-            // calculate partition row indices and partition starts
-            // e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the 
following partition_row_indices
-            // and partition_starts arrays:
-            //
-            //  partition_row_indices: [6, 1, 2, 3, 4, 5, 0]
-            //  partition_starts: [0, 1, 4, 6, 7]
-            //
-            // partition_starts conceptually splits partition_row_indices into 
smaller slices.
-            // Each slice 
partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the
-            // row indices of the input batch that are partitioned into 
partition K. For example,
-            // first partition 0 has one row index [6], partition 1 has row 
indices [1, 2, 3], etc.
-            let partition_row_indices = &mut scratch.partition_row_indices;
-            partition_row_indices.resize(num_rows, 0);
-            for (index, partition_id) in 
partition_ids.iter().enumerate().rev() {
-                partition_ends[*partition_id as usize] -= 1;
-                let end = partition_ends[*partition_id as usize];
-                partition_row_indices[end as usize] = index as u32;
-            }
-
-            // after calculating, partition ends become partition starts
-        }
-
         if input.num_rows() > self.batch_size {
             return Err(DataFusionError::Internal(
                 "Input batch size exceeds configured batch size. Call 
`insert_batch` instead."
@@ -524,11 +484,8 @@ impl MultiPartitionShuffleRepartitioner {
 
                     // We now have partition ids for every input row, map that 
to partition starts
                    // and partition indices to eventually write these rows to 
partition buffers.
-                    map_partition_ids_to_starts_and_indices(
-                        &mut scratch,
-                        *num_output_partitions,
-                        num_rows,
-                    );
+                    scratch
+                        
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
 
                     timer.stop();
                     Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -580,11 +537,8 @@ impl MultiPartitionShuffleRepartitioner {
 
                     // We now have partition ids for every input row, map that 
to partition starts
                    // and partition indices to eventually write these rows to 
partition buffers.
-                    map_partition_ids_to_starts_and_indices(
-                        &mut scratch,
-                        *num_output_partitions,
-                        num_rows,
-                    );
+                    scratch
+                        
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
 
                     timer.stop();
                     Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -642,11 +596,8 @@ impl MultiPartitionShuffleRepartitioner {
 
                     // We now have partition ids for every input row, map that 
to partition starts
                     // and partition indices to eventually write these rows to 
partition buffers.
-                    map_partition_ids_to_starts_and_indices(
-                        &mut scratch,
-                        *num_output_partitions,
-                        num_rows,
-                    );
+                    scratch
+                        
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
 
                     timer.stop();
                     Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -923,7 +874,7 @@ struct SinglePartitionShufflePartitioner {
     /// Number of rows in the concatenating batches
     num_buffered_rows: usize,
     /// Metrics for the repartitioner
-    metrics: ShuffleRepartitionerMetrics,
+    metrics: ShufflePartitionerMetrics,
     /// The configured batch size
     batch_size: usize,
 }
@@ -933,7 +884,7 @@ impl SinglePartitionShufflePartitioner {
         output_data_path: String,
         output_index_path: String,
         schema: SchemaRef,
-        metrics: ShuffleRepartitionerMetrics,
+        metrics: ShufflePartitionerMetrics,
         batch_size: usize,
         codec: CompressionCodec,
         write_buffer_size: usize,
@@ -1200,7 +1151,7 @@ impl PartitionWriter {
         &mut self,
         iter: &mut PartitionedBatchIterator,
         runtime: &RuntimeEnv,
-        metrics: &ShuffleRepartitionerMetrics,
+        metrics: &ShufflePartitionerMetrics,
         write_buffer_size: usize,
     ) -> Result<usize> {
         if let Some(batch) = iter.next() {
@@ -1393,7 +1344,7 @@ mod test {
     }
 
     #[tokio::test]
-    async fn shuffle_repartitioner_memory() {
+    async fn shuffle_partitioner_memory() {
         let batch = create_batch(900);
         assert_eq!(8316, batch.get_array_memory_size()); // Not stable across 
Arrow versions
 
@@ -1407,7 +1358,7 @@ mod test {
             "/tmp/index.out".to_string(),
             batch.schema(),
             CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], 
num_partitions),
-            ShuffleRepartitionerMetrics::new(&metrics_set, 0),
+            ShufflePartitionerMetrics::new(&metrics_set, 0),
             runtime_env,
             1024,
             CompressionCodec::Lz4Frame,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to