kosiew commented on code in PR #21175:
URL: https://github.com/apache/datafusion/pull/21175#discussion_r3050715672


##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########
@@ -4131,16 +4114,180 @@ mod tests {
 
         let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
 
+        Ok(result)
+    }
+
+    /// Tests that PartialReduce mode:
+    /// 1. Accepts state as input (like Final)
+    /// 2. Produces state as output (like Partial)
+    /// 3. Can be followed by a Final stage to get the correct result
+    ///
+    /// This simulates a tree-reduce pattern:
+    ///   Partial -> PartialReduce -> Final
+    #[tokio::test]
+    async fn test_partial_reduce_mode() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        // Produce two partitions of input data
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
+            ],
+        )?;
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
+            ],
+        )?;
+
+        let groups =
+            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, 
"a".to_string())]);
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
+            AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("SUM(b)")
+                .build()?,
+        )];
+
+        let result =
+            evaluate_partial_reduce(groups, aggregates, [vec![batch1], 
vec![batch2]])
+                .await?;
+
         // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 
30+60=90
         assert_snapshot!(batches_to_sort_string(&result), @r"
-            +---+--------+
-            | a | SUM(b) |
-            +---+--------+
-            | 1 | 50.0   |
-            | 2 | 70.0   |
-            | 3 | 90.0   |
-            +---+--------+
-        ");
+                        +---+--------+
+                        | a | SUM(b) |
+                        +---+--------+
+                        | 1 | 50.0   |
+                        | 2 | 70.0   |
+                        | 3 | 90.0   |
+                        +---+--------+
+                    ");
+
+        Ok(())
+    }
+
+    /// Tests that PartialReduce mode:
+    /// 1. Accepts state as input (like Final)
+    /// 2. Produces state as output (like Partial)
+    /// 3. Can be followed by a Final stage to get the correct result
+    ///
+    /// This simulates a tree-reduce pattern:
+    ///   Partial -> PartialReduce -> Final
+    #[tokio::test]
+    async fn 
test_partial_reduce_mode_on_aggregate_that_have_more_than_1_state_fields_and_input_argument()
+    -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        // Produce two partitions of input data
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
+            ],
+        )?;
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
+            ],
+        )?;
+
+        let groups =
+            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, 
"a".to_string())]);
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
+            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("AVG(b)")
+                .build()?,
+        )];
+
+        let result =
+            evaluate_partial_reduce(groups, aggregates, [vec![batch1], 
vec![batch2]])
+                .await?;
+
+        assert_snapshot!(batches_to_sort_string(&result), @r"
+                        +---+--------+
+                        | a | AVG(b) |
+                        +---+--------+
+                        | 1 | 25.0   |
+                        | 2 | 35.0   |
+                        | 3 | 45.0   |
+                        +---+--------+
+                    ");
+
+        Ok(())
+    }
+
+    /// Tests that PartialReduce mode:
+    /// 1. Accepts state as input (like Final)
+    /// 2. Produces state as output (like Partial)
+    /// 3. Can be followed by a Final stage to get the correct result
+    ///
+    /// This simulates a tree-reduce pattern:
+    ///   Partial -> PartialReduce -> Final
+    #[tokio::test]
+    async fn 
test_partial_reduce_mode_on_aggregate_that_have_more_than_state_fields_than_input_arguments()

Review Comment:
   This new `approx_percentile_cont` case is helpful, but it is still only an 
indirect proxy for the regression mentioned in the PR description.
   
   Right now it shows that one concrete aggregate can round-trip through 
`Partial -> PartialReduce -> Final`, but it would not fail if `PartialReduce` 
accidentally passed state-field types into `AccumulatorArgs::expr_fields`.
   
   It would be great to add a small test-only aggregate stub that asserts, when 
its accumulator is created, that `AccumulatorArgs::expr_fields` carries the 
original input field types (not the state-field types). That would directly pin 
the intended contract and avoid giving false confidence in cases where the 
first state field happens to remain type-compatible with the input.



##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########
@@ -4021,41 +4022,23 @@ mod tests {
     ///
     /// This simulates a tree-reduce pattern:
     ///   Partial -> PartialReduce -> Final
-    #[tokio::test]
-    async fn test_partial_reduce_mode() -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", DataType::UInt32, false),
-            Field::new("b", DataType::Float64, false),
-        ]));
-
-        // Produce two partitions of input data
-        let batch1 = RecordBatch::try_new(
-            Arc::clone(&schema),
-            vec![
-                Arc::new(UInt32Array::from(vec![1, 2, 3])),
-                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
-            ],
-        )?;
-        let batch2 = RecordBatch::try_new(
-            Arc::clone(&schema),
-            vec![
-                Arc::new(UInt32Array::from(vec![1, 2, 3])),
-                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
-            ],
-        )?;
+    async fn evaluate_partial_reduce(

Review Comment:
   `evaluate_partial_reduce` is a nice extraction 👍
   
   One thing I noticed is that the three callers still repeat the same schema 
and batch setup. You might consider table-driving the UDAF, alias, and expected 
snapshot so the test focuses more on which aggregate shape is being exercised 
rather than re-encoding the same fixture each time.



##########
datafusion/functions-aggregate/src/approx_percentile_cont.rs:
##########
@@ -297,6 +297,9 @@ impl AggregateUDFImpl for ApproxPercentileCont {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if arg_types.len() > 3 {

Review Comment:
   Since the public signature already restricts `approx_percentile_cont` to 2 
or 3 arguments, this extra arity check reads more like a defensive internal 
invariant than user-facing validation.
   
   A short comment explaining that it protects aggregate planning from 
accidentally feeding state-field types back into `return_type` would help 
clarify the intent for future readers.



##########
datafusion/physical-plan/src/aggregates/mod.rs:
##########


Review Comment:
   This helper currently hardcodes `vec![None]` for the filter list, which 
quietly assumes there is only a single aggregate expression.
   
   Using `vec![None; aggregates.len()]` would make it more reusable if we add 
more aggregate cases later and also makes the helper's contract a bit clearer.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to