avantgardnerio commented on code in PR #23026: URL: https://github.com/apache/datafusion/pull/23026#discussion_r3454319602
########## datafusion/physical-plan/src/range_repartition.rs: ########## @@ -0,0 +1,636 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Range-partition an input stream on a single Int64 order-key into N +//! output partitions, with halo overlap for bounded RANGE-frame window +//! functions sitting above it. +//! +//! `execute()`'s first call spawns a coordinator that: +//! 1. opens `child.execute(k)` for every input partition `k`, +//! 2. drives each stream to its first batch (which makes the pipeline- +//! breaking sort child populate its `PartitionExtremes` slot), +//! 3. reads `child.runtime_partition_extremes(k)` per input, +//! 4. lex-reduces those into a single global [`PartitionExtremes`], derives +//! `N` equal-width Int64 bucket boundaries from `[global.min, +//! global.max]`, and computes per-bucket expanded ranges by +//! extending each primary [b_i, b_{i+1}) outward by +//! `halo_preceding` / `halo_following`, +//! 5. then for every batch flowing out of every input stream, splits +//! the batch into per-bucket pieces (rows whose order key lies in +//! bucket `b`'s expanded range), and sends each piece into bucket +//! `b`'s output channel. +//! +//! 
Halo rows therefore appear in *two* output partitions (their primary +//! bucket and the neighbor whose expanded range reaches them). That's +//! correct for letting the per-bucket window operator compute frame +//! values at the seams — but it also means rows are duplicated in the +//! merged output until a future `HaloDropExec` strips halo rows after +//! the window. + +use std::sync::{Arc, Mutex}; + +use arrow::array::{Array, Int64Array, RecordBatch, UInt32Array}; +use arrow::compute::take_arrays; +use arrow::datatypes::SchemaRef; +use datafusion_common::{DataFusionError, Result, internal_datafusion_err}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::OrderingRequirements; +use futures::StreamExt; +use log::info; +use tokio::sync::{mpsc, oneshot}; + +use datafusion_common::ScalarValue; + +use crate::sorts::sort::lex_compare; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, + PartitionExtremes, PlanProperties, SendableRecordBatchStream, +}; + +#[derive(Debug)] +pub struct RangeRepartitionExec { + input: Arc<dyn ExecutionPlan>, + cache: Arc<PlanProperties>, + /// Required input ordering — passed down from the consumer (window + /// operator) so EnsureRequirements inserts the pipeline-breaking sort + /// *below* us, not above. Same key feeds the routing decision. + ordering: LexOrdering, + /// Halo distance preceding each bucket's primary range, in + /// leading-sort-key units. Carried over from the window frame at plan + /// time so the coordinator can derive per-bucket expanded ranges. + halo_preceding: i64, + /// Halo distance following each bucket's primary range. + halo_following: i64, + state: Arc<Mutex<State>>, + /// Per-output-partition primary `[lo, hi_exclusive)` ranges, filled + /// by the coordinator before any batch is routed. 
Surfaced through + /// `runtime_partition_extremes(partition)` so downstream operators + /// (e.g. HaloDropExec) can read each bucket's intended primary + /// range without needing the global extremes. + bucket_primary_ranges: Arc<Mutex<Option<Vec<(i64, i64)>>>>, +} + +struct State { + initialized: bool, + /// One `oneshot::Receiver` per output partition, populated when the + /// coordinator hands off this partition's data. `take()`n by the + /// corresponding `execute(partition)` call. + handoffs: Vec<Option<oneshot::Receiver<Result<PartitionData>>>>, +} + +/// Per-output-partition payload the coordinator hands to its stream. +/// Once the coordinator has computed boundaries it starts router tasks +/// that funnel routed batches into bucket-keyed mpsc channels. Each +/// output partition's stream drains its receiver. +struct PartitionData { + rx: mpsc::Receiver<Result<RecordBatch>>, +} + +impl std::fmt::Debug for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("State") + .field("initialized", &self.initialized) + .field("handoffs", &self.handoffs.len()) + .finish() + } +} + +impl RangeRepartitionExec { + pub fn new( + input: Arc<dyn ExecutionPlan>, + ordering: LexOrdering, + halo_preceding: i64, + halo_following: i64, + ) -> Self { + let n = input.output_partitioning().partition_count(); + let cache = Arc::clone(input.properties()); + Self { + input, + cache, + ordering, + halo_preceding, + halo_following, + state: Arc::new(Mutex::new(State { + initialized: false, + handoffs: (0..n).map(|_| None).collect(), + })), + bucket_primary_ranges: Arc::new(Mutex::new(None)), + } + } + + pub fn input(&self) -> &Arc<dyn ExecutionPlan> { + &self.input + } +} + +impl DisplayAs for RangeRepartitionExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "RangeRepartitionExec") + } +} + +impl ExecutionPlan for RangeRepartitionExec { + fn name(&self) -> &'static str { 
+ "RangeRepartitionExec" + } + + fn properties(&self) -> &Arc<PlanProperties> { + &self.cache + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![&self.input] + } + + fn with_new_children( + self: Arc<Self>, + mut children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(Self::new( + children.swap_remove(0), + self.ordering.clone(), + self.halo_preceding, + self.halo_following, + ))) + } + + fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> { + vec![Some(OrderingRequirements::from(self.ordering.clone()))] + } + + fn maintains_input_order(&self) -> Vec<bool> { + vec![true] + } + + /// Returns each output partition's *intended primary range* as + /// inclusive `[min, max]` — not the actual range of routed data + /// (which is wider, by `halo_preceding`/`halo_following`). This is a + /// "useful lie" the downstream `HaloDropExec` consumes to filter + /// halo rows back out. + /// + /// Returns `Ok(None)` if the coordinator hasn't computed boundaries + /// yet — callers must drive the input stream to first batch before + /// reading, per the trait contract on `runtime_partition_extremes`. + fn runtime_partition_extremes( + &self, + partition: usize, + ) -> Result<Option<PartitionExtremes>> { + let guard = self.bucket_primary_ranges.lock().map_err(|_| { + internal_datafusion_err!( + "RangeRepartitionExec bucket_primary_ranges mutex poisoned" + ) + })?; + let Some(ranges) = guard.as_ref() else { + return Ok(None); + }; + let &(lo, hi_excl) = &ranges[partition]; + // Convert [lo, hi_exclusive) → inclusive [min, max]. 
+ let max = hi_excl.saturating_sub(1); + Ok(Some(PartitionExtremes { + min: vec![ScalarValue::Int64(Some(lo))], + max: vec![ScalarValue::Int64(Some(max))], + row_count: 0, // not tracked; consumers shouldn't rely on it + })) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let mut state = self.state.lock().map_err(|_| { + internal_datafusion_err!("RangeRepartitionExec mutex poisoned") + })?; + if !state.initialized { + state.initialized = true; + let n = state.handoffs.len(); + let mut senders = Vec::with_capacity(n); + for slot in state.handoffs.iter_mut() { + let (tx, rx) = oneshot::channel(); + senders.push(tx); + *slot = Some(rx); + } + let child = Arc::clone(&self.input); + let ctx = Arc::clone(&context); + let halo_preceding = self.halo_preceding; + let halo_following = self.halo_following; + let primaries = Arc::clone(&self.bucket_primary_ranges); + tokio::spawn(coordinator( + child, + ctx, + senders, + halo_preceding, + halo_following, + primaries, + )); + } + let rx = state + .handoffs + .get_mut(partition) + .and_then(Option::take) + .ok_or_else(|| { + internal_datafusion_err!("partition {partition} already taken") + })?; + drop(state); + + let schema = self.schema(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + partition_stream(schema, rx), + ))) + } +} + +/// Stream that awaits the coordinator's handoff for one output partition, +/// then drains the bucket-keyed mpsc receiver router tasks are pushing +/// into. If the coordinator drops the sender (e.g. setup failed) the +/// stream surfaces an error. 
+fn partition_stream( + _schema: SchemaRef, + rx: oneshot::Receiver<Result<PartitionData>>, +) -> impl futures::Stream<Item = Result<RecordBatch>> + Send { + use futures::stream::{TryStreamExt, once}; + once(async move { + let data = rx + .await + .map_err(|_| internal_datafusion_err!("coordinator dropped"))??; + let mut bucket_rx = data.rx; + let inner = futures::stream::poll_fn(move |cx| bucket_rx.poll_recv(cx)); + Ok::<_, DataFusionError>(inner) + }) + .try_flatten() +} + +/// Coordinator task: drive every input partition to first batch, gather +/// runtime extremes, log the lex-reduced global, then hand off per-input +/// payloads to their corresponding output partition. +async fn coordinator( + child: Arc<dyn ExecutionPlan>, + ctx: Arc<TaskContext>, + mut senders: Vec<oneshot::Sender<Result<PartitionData>>>, + halo_preceding: i64, + halo_following: i64, + bucket_primary_ranges: Arc<Mutex<Option<Vec<(i64, i64)>>>>, +) { + let n = senders.len(); + + // Phase 1: open every input stream and pull the first batch from each. + let mut firsts: Vec<(Option<RecordBatch>, SendableRecordBatchStream)> = + Vec::with_capacity(n); + for k in 0..n { + let mut stream = match child.execute(k, Arc::clone(&ctx)) { + Ok(s) => s, + Err(e) => { + let msg = format!("input {k} open failed: {e}"); + for tx in senders.drain(..) { + let _ = tx.send(Err(internal_datafusion_err!("{msg}"))); + } + return; + } + }; + let first = match stream.next().await { + Some(Ok(batch)) => Some(batch), + Some(Err(e)) => { + let msg = format!("first batch from input {k} failed: {e}"); + for tx in senders.drain(..) { + let _ = tx.send(Err(internal_datafusion_err!("{msg}"))); + } + return; + } + None => None, + }; + firsts.push((first, stream)); + } + + // Phase 2: collect per-input runtime extremes. 
+ let per_input: Vec<Option<PartitionExtremes>> = (0..n) + .map(|k| child.runtime_partition_extremes(k).ok().flatten()) + .collect(); + + // Phase 3: lex-reduce per-input → global, using the input's declared + // output ordering so direction and null ordering are honored. + let ordering: Option<LexOrdering> = child.output_ordering().cloned(); + let global = ordering + .as_ref() + .and_then(|o| reduce_global_extremes(&per_input, o)); + + info!( + "RangeRepartitionExec: coordinator gathered {} input partitions; \ + global extremes = {:?}", + n, global + ); + + // Phase 4: derive bucket boundaries from the global extremes. v1 is Review Comment: This repartitions dynamically without a true AQE (adaptive query execution) boundary. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
