gstvg commented on code in PR #21679:
URL: https://github.com/apache/datafusion/pull/21679#discussion_r3152294069


##########
datafusion/expr/src/higher_order_function.rs:
##########
@@ -333,51 +353,210 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send 
+ Sync + Any {
     /// [`Volatility`]: datafusion_expr_common::signature::Volatility
     fn signature(&self) -> &HigherOrderSignature;
 
-    /// Return the field of all the parameters supported by all the supported 
lambdas of this function
-    /// based on the field of the value arguments. If a lambda support 
multiple parameters, or if multiple
-    /// lambdas are supported and some are optional, all should be returned,
-    /// regardless of whether they are used on a particular invocation
+    /// Return the field of all the parameters supported by the lambdas in 
`fields`.
+    /// If a lambda supports multiple parameters, all should be returned, 
regardless of
+    /// whether they are used or not on a particular invocation
     ///
     /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call 
the helper
     /// [`HigherOrderFunction::lambda_parameters`] instead of this method 
directly
     ///
     /// The name of the returned [`Field`]'s are ignored.
     ///
-    /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
-    /// [`HigherOrderFunction::lambda_parameters`]: 
crate::expr::HigherOrderFunction::lambda_parameters
+    /// This function is repeatedly called until 
[LambdaParametersProgress::Complete] is returned, with
+    /// `step` increased by one at each invocation, starting at 0.
     ///
-    /// Example for array_transform:
+    /// For functions whose lambda parameters all depend only on the fields of 
their value arguments,
+    /// this can return [LambdaParametersProgress::Complete] at step 0. Taking 
as an example a strict
+    /// array_reduce with the signature `(arr: [V], initial_value: I, (I, V) 
-> I, (I) -> O) -> O`, which
+    /// requires its initial value to be the exact same type as its merge 
output, which is also the
+    /// parameter of its finish lambda, the expression
     ///
-    /// `array_transform([2.0, 8.0], v -> v > 4.0)`
+    /// `array_reduce([1.2, 2.1], 0.0, (acc, v) -> acc + v + 1.5, v -> v > 
5.1)`
+    ///
+    ///  would result in this function being called as the following:
     ///
     /// ```ignore
-    /// let lambda_parameters = array_transform.lambda_parameters(&[
-    ///      Arc::new(Field::new("", DataType::new_list(DataType::Float32, 
false))), // the Field of the literal `[2, 8]`
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     0,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0.0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::Float32, true))),
+    ///         // the Field of the output of the merge lambda, which is 
unknown at this point because it depends
+    ///         // on the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    ///         // the Field of the output of the finish lambda, unknown for 
the same reason as above
+    ///         ValueOrLambda::Lambda(None),
     /// ])?;
     ///
     /// assert_eq!(
     ///      lambda_parameters,
-    ///      vec![
-    ///         // the lambda supported parameters, regardless of how many are 
actually used
+    ///      LambdaParametersProgress::Complete(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         vec![
+    ///             // the accumulator which is the field of the initial value
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///         // the finish lambda's supported parameters
     ///         vec![
-    ///             // the value being transformed
-    ///             Field::new("ignored_name", DataType::Float32, false),
-    ///             // the 1-based index being transformed, not used on the 
example above,
-    ///             //but implementations doesn't need to care about it
-    ///             Field::new("", DataType::Int32, false),
-    ///         ]
-    ///      ]
-    /// )
+    ///             // the reduced value which is the field of the initial 
value
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///         ],
+    ///      ])
+    /// );
     /// ```
     ///
+    /// For functions whose lambda parameters depend on the output of other 
lambdas, or on their own lambda,
+    /// this can return [LambdaParametersProgress::Partial] until all 
dependencies are met. Note that for
+    /// lambdas with cyclic dependencies, you likely want to use 
[HigherOrderUDF::coerce_values_for_lambdas] too.
+    /// Take as an example a flexible array_reduce with the signature `(arr: 
[V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`.
+    /// It has a cyclic dependency in the merge lambda, and a dependency of 
the finish lambda in the merge lambda,
+    /// and only requires the initial value to be *coercible* to the output of 
the merge lambda, which is defined by
+    /// its [HigherOrderUDF::coerce_values_for_lambdas] implementation. The 
expression
+    ///
+    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
+    ///
+    /// would result in this function being called as the following:
+    ///
+    /// ```ignore
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     0,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, 
true))),
+    ///         // the Field of the output of the merge lambda, which is 
unknown at this point because it depends on
+    ///         // the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    ///         // the Field of the output of the finish lambda, unknown for 
the same reason as above
+    ///         ValueOrLambda::Lambda(None),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      LambdaParametersProgress::Partial(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         Some(vec![
+    ///             // at step 0, use the field of the initial value
+    ///             Arc::new(Field::new("ignored_name", DataType::Int32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ]),
+    ///         // the finish lambda's supported parameters, unknown at this point 
due to dependency on the merge output
+    ///         None,
+    ///      ])
+    /// );
+    ///
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     1,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, 
true))),
+    ///         // the Field of the output of the merge lambda, which could be 
inferred to be a Float32 based on the
+    ///         // returned values of the previous step
+    ///         ValueOrLambda::Lambda(Some(Arc::new(Field::new("", 
DataType::Float32, true)))),
+    ///         // the Field of the output of the finish lambda, which is 
unknown at this point because it depends
+    ///         // on the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      LambdaParametersProgress::Complete(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         vec![
+    ///             // the merge lambda's own output, now used as its 
accumulator
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///         // the finish lambda's supported parameters, which is the output 
of the merge lambda,
+    ///         vec![
+    ///             // the output of the merge lambda
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///      ])
+    /// );
+    ///
+    /// let coerce_to = array_reduce.coerce_values_for_lambdas(&[
+    ///     // the literal `[1.2, 2.1]` data type, the array being reduced
+    ///     ValueOrLambda::Value(DataType::new_list(DataType::Float32, true)),
+    ///     // the literal `0` data type, the initial value
+    ///     ValueOrLambda::Value(DataType::Int32),
+    ///     // the output data type of the merge lambda
+    ///     ValueOrLambda::Lambda(DataType::Float32),
+    ///     // the output data type of the finish lambda
+    ///     ValueOrLambda::Lambda(DataType::Boolean),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///     coerce_to,
+    ///     vec![
+    ///         // return the same type for the array being reduced
+    ///         DataType::new_list(DataType::Float32, true),
+    ///         // coerce the initial value to the output of the merge lambda
+    ///         DataType::Float32,
+    ///     ]
+    /// );
+    ///
+    /// ```
+    ///
+    /// Note this may also be called at step 0 with all lambda outputs already 
set, and in that case,
+    /// [LambdaParametersProgress::Complete] must be returned
+    ///
     /// The implementation can assume that some other part of the code has 
coerced
-    /// the actual argument types to match [`Self::signature`].
-    fn lambda_parameters(&self, value_fields: &[FieldRef]) -> 
Result<Vec<Vec<Field>>>;
+    /// the actual argument types to match [`Self::signature`], except the 
coercion defined by
+    /// [Self::coerce_values_for_lambdas], if applicable.
+    ///
+    /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
+    /// [`HigherOrderFunction::lambda_parameters`]: 
crate::expr::HigherOrderFunction::lambda_parameters
+    fn lambda_parameters(
+        &self,
+        step: usize,
+        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
+    ) -> Result<LambdaParametersProgress>;
+
+    /// Coerce value arguments of a function call to types that the function 
can evaluate, also taking into
+    /// account the *output type of its lambdas*. This differs from 
[HigherOrderUDF::coerce_value_types],
+    /// which only has access to the type of its value arguments. For this 
method to be called, the
+    /// function must have its 
[HigherOrderSignature::coerce_values_for_lambdas] set to true
+    ///
+    /// See the [type coercion module](crate::type_coercion)
+    /// documentation for more details on type coercion
+    ///
+    /// # Parameters
+    /// * `fields`: The argument types of the value arguments of this 
function, or the output type of lambdas
+    ///
+    /// # Return value
+    /// A Vec with the same number of [ValueOrLambda::Value] in `fields`. 
DataFusion will `CAST` the
+    /// function call arguments to these specific types.
+    ///
+    /// For example, a flexible array_reduce implementation (see 
[Self::lambda_parameters] docs), when working
+    /// with the expression below, may want to coerce its initial value 
argument, the *integer* `0`,
+    /// to match the output of its merge function, which is a *float*:
+    ///
+    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
+    fn coerce_values_for_lambdas(
+        &self,
+        _fields: &[ValueOrLambda<DataType, DataType>],
+    ) -> Result<Vec<DataType>> {

Review Comment:
   expanded an existing test to cover this 
https://github.com/apache/datafusion/pull/21679/changes/c93fc811a697249d9c6b374794faf5dcbe67a6d4



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to