alamb commented on code in PR #21328:
URL: https://github.com/apache/datafusion/pull/21328#discussion_r3033479310
##########
datafusion/physical-plan/src/sorts/sort_preserving_merge.rs:
##########
@@ -433,6 +422,99 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}
+/// A stream that lazily spawns input partition tasks and builds the streaming
+/// merge on first poll, rather than eagerly in `execute()`.
+struct SortPreservingMergeExecStream {
+ schema: SchemaRef,
+ input: Arc<dyn ExecutionPlan>,
+ context: Arc<TaskContext>,
+ expr: LexOrdering,
+ metrics: BaselineMetrics,
+ batch_size: usize,
+ fetch: Option<usize>,
+ reservation: datafusion_execution::memory_pool::MemoryReservation,
+ enable_round_robin_repartition: bool,
+ state: SPMStreamState,
+}
+
+enum SPMStreamState {
+ /// Tasks have not been spawned yet.
+ Pending,
+ /// The streaming merge has been built and is running.
+ Running(SendableRecordBatchStream),
+ /// Initialization failed.
+ Failed,
+}
+
+impl SortPreservingMergeExecStream {
+ fn start(&mut self) -> Result<&mut SendableRecordBatchStream> {
+ let input_partitions =
self.input.output_partitioning().partition_count();
+
+ let receivers = (0..input_partitions)
+ .map(|partition| {
+ let stream = self.input.execute(partition,
Arc::clone(&self.context))?;
+ Ok(spawn_buffered(stream, 1))
+ })
+ .collect::<Result<_>>()?;
+
+ debug!("Done setting up sender-receiver for
SortPreservingMergeExec::execute");
+
+ // Take reservation out of self via mem::replace to pass ownership
+ let reservation = std::mem::replace(
+ &mut self.reservation,
+ MemoryConsumer::new("empty")
+ .register(&self.context.runtime_env().memory_pool),
+ );
+
+ let result = StreamingMergeBuilder::new()
+ .with_streams(receivers)
+ .with_schema(Arc::clone(&self.schema))
+ .with_expressions(&self.expr)
+ .with_metrics(self.metrics.clone())
+ .with_batch_size(self.batch_size)
+ .with_fetch(self.fetch)
+ .with_reservation(reservation)
+ .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
+ .build()?;
+
+ debug!("Got stream result from
SortPreservingMergeStream::new_from_receivers");
+
+ self.state = SPMStreamState::Running(result);
+ match &mut self.state {
+ SPMStreamState::Running(s) => Ok(s),
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl RecordBatchStream for SortPreservingMergeExecStream {
+ fn schema(&self) -> SchemaRef {
+ Arc::clone(&self.schema)
+ }
+}
+
+impl Stream for SortPreservingMergeExecStream {
+ type Item = Result<arrow::array::RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let stream = match &mut self.state {
+ SPMStreamState::Running(s) => s,
+ SPMStreamState::Failed => return Poll::Ready(None),
+ SPMStreamState::Pending => match self.start() {
+ Ok(s) => s,
+ Err(e) => {
+ self.state = SPMStreamState::Failed;
+ return Poll::Ready(Some(Err(e)));
+ }
+ },
+ };
+ stream.poll_next_unpin(cx)
+ }
+}
+
#[cfg(test)]
Review Comment:
I think it would be good to add a test to this PR to avoid regressions.
Specifically, a test that calls `execute(0, ...)` without polling the returned
stream and verifies that no child stream or task has been started yet.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]