gstvg commented on code in PR #21679:
URL: https://github.com/apache/datafusion/pull/21679#discussion_r3144259358


##########
datafusion/expr/src/higher_order_function.rs:
##########
@@ -0,0 +1,771 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`HigherOrderUDF`]: User Defined Higher Order Functions
+
+use crate::expr::schema_name_from_exprs_comma_separated_without_space;
+use crate::{ColumnarValue, Documentation, Expr};
+use arrow::array::{ArrayRef, RecordBatch};
+use arrow::datatypes::{DataType, FieldRef, Schema};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err};
+use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::signature::Volatility;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::cmp::Ordering;
+use std::fmt::Debug;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// The types of arguments for which a function has implementations.
+///
+/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user 
query could call the
+/// function with. DataFusion will automatically coerce (cast) argument types 
to
+/// one of the supported function signatures, if possible.
+///
+/// # Overview
+/// Functions typically provide implementations for a small number of different
+/// argument [`DataType`]s, rather than all possible combinations. If a user
+/// calls a function with arguments that do not match any of the declared 
types,
+/// DataFusion will attempt to automatically coerce (add casts to) function
+/// arguments so they match the [`HigherOrderTypeSignature`]. See the 
[`type_coercion`] module
+/// for more details
+///
+/// [`type_coercion`]: crate::type_coercion
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum HigherOrderTypeSignature {
+    /// The acceptable signature and coercions rules are special for this
+    /// function.
+    ///
+    /// If this signature is specified,
+    /// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare 
argument types.
+    UserDefined,
+    /// One or more lambdas or arguments with arbitrary types
+    VariadicAny,
+    /// The specified number of lambdas or arguments with arbitrary types.
+    Any(usize),
+}
+
+/// Provides information necessary for calling a higher order function.
+///
+/// - [`HigherOrderTypeSignature`] defines the argument types that a function 
has implementations
+///   for.
+///
+/// - [`Volatility`] defines how the output of the function changes with the 
input.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub struct HigherOrderSignature {
+    /// The data types that the function accepts. See 
[HigherOrderTypeSignature] for more information.
+    pub type_signature: HigherOrderTypeSignature,
+    /// The volatility of the function. See [Volatility] for more information.
+    pub volatility: Volatility,
+    /// Whether [HigherOrderUDF::coerce_values_for_lambdas] should be called
+    pub coerce_values_for_lambdas: bool,
+}
+
+impl HigherOrderSignature {
+    /// Creates a new `HigherOrderSignature` from a given type signature and 
volatility.
+    pub fn new(type_signature: HigherOrderTypeSignature, volatility: 
Volatility) -> Self {
+        HigherOrderSignature {
+            type_signature,
+            volatility,
+            coerce_values_for_lambdas: false,
+        }
+    }
+
+    /// User-defined coercion rules for the function.
+    pub fn user_defined(volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::UserDefined,
+            volatility,
+            coerce_values_for_lambdas: false,
+        }
+    }
+
+    /// An arbitrary number of lambdas or arguments of any type.
+    pub fn variadic_any(volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::VariadicAny,
+            volatility,
+            coerce_values_for_lambdas: false,
+        }
+    }
+
+    /// A specified number of arguments of any type
+    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::Any(arg_count),
+            volatility,
+            coerce_values_for_lambdas: false,
+        }
+    }
+
+    /// Set [Self::coerce_values_for_lambdas] to true to indicate that 
[HigherOrderUDF::coerce_values_for_lambdas]
+    /// should be called
+    pub fn with_coerce_values_for_lambdas(mut self) -> Self {
+        self.coerce_values_for_lambdas = true;
+
+        self
+    }
+}
+
+impl PartialEq for dyn HigherOrderUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.dyn_eq(other as _)
+    }
+}
+
+impl PartialOrd for dyn HigherOrderUDF {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        let mut cmp = self.name().cmp(other.name());
+        if cmp == Ordering::Equal {
+            cmp = self.signature().partial_cmp(other.signature())?;
+        }
+        if cmp == Ordering::Equal {
+            cmp = self.aliases().partial_cmp(other.aliases())?;
+        }
+        // Contract for PartialOrd and PartialEq consistency requires that
+        // a == b if and only if partial_cmp(a, b) == Some(Equal).
+        if cmp == Ordering::Equal && self != other {
+            // Functions may have other properties besides name and signature
+            // that differentiate two instances (e.g. type, or arbitrary 
parameters).
+            // We cannot return Some(Equal) in such case.
+            return None;
+        }
+        debug_assert!(
+            cmp == Ordering::Equal || self != other,
+            "Detected incorrect implementation of PartialEq when comparing 
functions: '{}' and '{}'. \
+            The functions compare as equal, but they are not equal based on 
general properties that \
+            the PartialOrd implementation observes,",
+            self.name(),
+            other.name()
+        );
+        Some(cmp)
+    }
+}
+
+impl Eq for dyn HigherOrderUDF {}
+
+impl Hash for dyn HigherOrderUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.dyn_hash(state)
+    }
+}
+
+/// Arguments passed to [`HigherOrderUDF::invoke_with_args`] when invoking a
+/// higher order function.
+#[derive(Debug, Clone)]
+pub struct HigherOrderFunctionArgs {
+    /// The evaluated arguments and lambdas to the function
+    pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
+    /// Field associated with each arg, if it exists
+    /// For lambdas, it will be the field of the result of
+    /// the lambda if evaluated with the parameters
+    /// returned from [`HigherOrderUDF::lambda_parameters`]
+    pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
+    /// The number of rows in record batch being evaluated
+    pub number_rows: usize,
+    /// The return field of the higher order function returned
+    /// (from `return_field_from_args`) when creating the
+    /// physical expression from the logical expression
+    pub return_field: FieldRef,
+    /// The config options at execution time
+    pub config_options: Arc<ConfigOptions>,
+}
+
+impl HigherOrderFunctionArgs {
+    /// The return type of the function. See [`Self::return_field`] for more
+    /// details.
+    pub fn return_type(&self) -> &DataType {
+        self.return_field.data_type()
+    }
+}
+
+/// A lambda argument to a HigherOrderFunction
+#[derive(Clone, Debug)]
+pub struct LambdaArgument {
+    /// The parameters defined in this lambda
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be `vec![Field::new("v", DataType::Int32, true)]`
+    params: Vec<FieldRef>,
+    /// The body of the lambda
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be the physical expression of `-v`
+    body: Arc<dyn PhysicalExpr>,
+    /// A RecordBatch with the captured columns inside the lambda body, if any
+    ///
+    /// For example, for `array_transform([2], v -> v + a + b)`,
+    /// this will be a `RecordBatch` with columns `a` and `b`
+    captures: Option<RecordBatch>,
+}
+
+impl LambdaArgument {
+    /// Create a new LambdaArgument
+    ///
+    /// Note that capture is not supported yet and must be `None` for now,
+    /// otherwise [LambdaArgument::evaluate] will fail
+    pub fn new(
+        params: Vec<FieldRef>,
+        body: Arc<dyn PhysicalExpr>,
+        captures: Option<RecordBatch>,
+    ) -> Self {
+        Self {
+            params,
+            body,
+            captures,
+        }
+    }
+
+    /// Evaluate this lambda
+    /// `args` should evaluate to the value of each parameter
+    /// of the corresponding lambda returned in 
[HigherOrderUDF::lambda_parameters].
+    ///
+    /// `adjust` should adjust the length of captured columns of this
+    /// lambda relative to its parameters
+    pub fn evaluate(
+        &self,
+        args: &[&dyn Fn() -> Result<ArrayRef>],
+        _adjust: impl FnOnce(&[ArrayRef]) -> Result<Vec<ArrayRef>>,
+    ) -> Result<ColumnarValue> {
+        if self.captures.is_some() {
+            return exec_err!("lambda column capture is not supported yet");
+        }
+
+        let columns = args
+            .iter()
+            .take(self.params.len())
+            .map(|arg| arg())
+            .collect::<Result<_>>()?;
+
+        let schema = Arc::new(Schema::new(self.params.clone()));
+
+        let batch = RecordBatch::try_new(schema, columns)?;
+
+        self.body.evaluate(&batch)
+    }
+}
+
+/// Information about arguments passed to the function
+///
+/// This structure contains metadata about how the function was called
+/// such as the type of the arguments, any scalar arguments and if the
+/// arguments can (ever) be null
+///
+/// See [`HigherOrderUDF::return_field_from_args`] for more information
+#[derive(Clone, Debug)]
+pub struct HigherOrderReturnFieldArgs<'a> {
+    /// The data types of the arguments to the function
+    ///
+    /// If argument `i` to the function is a lambda, it will be the field of 
the result of the
+    /// lambda if evaluated with the parameters returned from 
[`HigherOrderUDF::lambda_parameters`]
+    ///
+    /// For example, with `array_transform([1], v -> v == 5)`
+    /// this field will be
+    /// ```ignore
+    /// [
+    ///     ValueOrLambda::Value(Field::new("", 
DataType::new_list(DataType::Int32, true), true)),
+    ///     ValueOrLambda::Lambda(Field::new("", DataType::Boolean, true))
+    /// ]
+    /// ```
+    pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
+    /// Is argument `i` to the function a scalar (constant)?
+    ///
+    /// If the argument `i` is not a scalar, it will be None
+    ///
+    /// For example, if a function is called like `array_transform([1], v -> v 
== 5)`
+    /// this field will be `[Some(ScalarValue::List(...)), None]`
+    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
+}
+
+/// An argument to a higher order function
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ValueOrLambda<V, L> {
+    /// A value with associated data
+    Value(V),
+    /// A lambda with associated data
+    Lambda(L),
+}
+
+/// The return of [HigherOrderUDF::lambda_parameters]
+pub enum LambdaParametersProgress {
+    Partial(Vec<Option<Vec<FieldRef>>>),
+    Complete(Vec<Vec<FieldRef>>),
+}
+
+/// Trait for implementing user defined higher order functions.
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`array_transform.rs`] for a commented complete implementation
+///
+/// [`array_transform.rs`]: 
https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs
+pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Aliases can be used to invoke the same function using different names.
+    /// For example in some databases `now()` and `current_timestamp()` are
+    /// aliases for the same function. This behavior can be obtained by
+    /// returning `current_timestamp` as an alias for the `now` function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+
+    /// Returns the name of the column this expression would create
+    ///
+    /// See [`Expr::schema_name`] for details
+    fn schema_name(&self, args: &[Expr]) -> Result<String> {
+        Ok(format!(
+            "{}({})",
+            self.name(),
+            schema_name_from_exprs_comma_separated_without_space(args)?
+        ))
+    }
+
+    /// Returns a [`HigherOrderSignature`] describing the argument types for 
which this
+    /// function has an implementation, and the function's [`Volatility`].
+    ///
+    /// See [`HigherOrderSignature`] for more details on argument type handling
+    /// and [`Self::return_field_from_args`] for computing the return type.
+    ///
+    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
+    fn signature(&self) -> &HigherOrderSignature;
+
+    /// Return the field of all the parameters supported by the lambdas in 
`fields`.
+    /// If a lambda supports multiple parameters, all should be returned, 
regardless of
+    /// whether they are used or not on a particular invocation
+    ///
+    /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call 
the helper
+    /// [`HigherOrderFunction::lambda_parameters`] instead of this method 
directly
+    ///
+    /// The names of the returned fields are ignored.
+    ///
+    /// This function is repeatedly called until 
[LambdaParametersProgress::Complete] is returned, with
+    /// `step` increased by one at each invocation, starting at 0.
+    ///
+    /// For functions whose lambda parameters all depend only on the fields of 
their value arguments,
+    /// this can return [LambdaParametersProgress::Complete] at step 0. Taking 
as an example a strict
+    /// array_reduce with the signature `(arr: [V], initial_value: I, (I, V) 
-> I, (I) -> O) -> O`, which
+    /// requires its initial value to be the exact same type as its merge 
output, which is also the
+    /// parameter of its finish lambda, the expression
+    ///
+    /// `array_reduce([1.2, 2.1], 0.0, (acc, v) -> acc + v + 1.5, v -> v > 
5.1)`
+    ///
+    ///  would result in this function being called as the following:
+    ///
+    /// ```ignore
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     0,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0.0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::Float32, true))),
+    ///         // the Field of the output of the merge lambda, which is 
unknown at this point because it depends
+    ///         // on the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    ///         // the Field of the output of the finish lambda, unknown for 
the same reason as above
+    ///         ValueOrLambda::Lambda(None),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      LambdaParametersProgress::Complete(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         vec![
+    ///             // the accumulator which is the field of the initial value
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///         // the finish lambda's supported parameters
+    ///         vec![
+    ///             // the reduced value which is the field of the initial 
value
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///         ],
+    ///      ])
+    /// );
+    /// ```
+    ///
+    /// For functions whose lambda parameters depend on the output of other 
lambdas, or on their own lambda,
+    /// this can return [LambdaParametersProgress::Partial] until all 
dependencies are met. Note that for
+    /// lambdas with cyclic dependencies, you likely want to use 
[HigherOrderUDF::coerce_values_for_lambdas] too.
+    /// Take as an example a flexible array_reduce with the signature `(arr: 
[V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`.
+    /// It has a cyclic dependency in the merge lambda, and a dependency of 
the finish lambda on the merge lambda,
+    /// and only requires the initial value to be *coercible* to the output of 
the merge lambda, which is defined by
+    /// its [HigherOrderUDF::coerce_values_for_lambdas] implementation. The 
expression
+    ///
+    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
+    ///
+    /// would result in this function being called as the following:
+    ///
+    /// ```ignore
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     0,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, 
true))),
+    ///         // the Field of the output of the merge lambda, which is 
unknown at this point because it depends on
+    ///         // the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    ///         // the Field of the output of the finish lambda, unknown for 
the same reason as above
+    ///         ValueOrLambda::Lambda(None),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      LambdaParametersProgress::Partial(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         Some(vec![
+    ///             // at step 0, use the field of the initial value
+    ///             Arc::new(Field::new("ignored_name", DataType::Int32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ]),
+    ///         // the finish lambda's supported parameters, unknown at this 
point due to dependency on the merge output
+    ///         None,
+    ///      ])
+    /// );
+    ///
+    /// let lambda_parameters = array_reduce.lambda_parameters(
+    ///     1,
+    ///     &[
+    ///         // the Field of the literal `[1.2, 2.1]`, the array being 
reduced
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", 
DataType::new_list(DataType::Float32, true), true))),
+    ///         // the Field of the literal `0`, the initial value
+    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, 
true))),
+    ///         // the Field of the output of the merge lambda, which could be 
inferred to be a Float32 based on the
+    ///         // returned values of the previous step
+    ///         ValueOrLambda::Lambda(Some(Arc::new(Field::new("", 
DataType::Float32, true)))),
+    ///         // the Field of the output of the finish lambda, which is 
unknown at this point because it depends
+    ///         // on the return of this call
+    ///         ValueOrLambda::Lambda(None),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      LambdaParametersProgress::Complete(vec![
+    ///         // the merge lambda's supported parameters, regardless of how 
many are actually used
+    ///         vec![
+    ///             // the merge lambda's own output, now used as its 
accumulator
+    ///             Arc::new(Field::new("ignored_name", DataType::Float32, 
true)),
+    ///             // the array values being reduced
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///         // the finish lambda's supported parameters, which is the output 
of the merge lambda,
+    ///         vec![
+    ///             // the output of the merge lambda
+    ///             Arc::new(Field::new("", DataType::Float32, true)),
+    ///         ],
+    ///      ])
+    /// );
+    ///
+    /// let coerce_to = array_reduce.coerce_values_for_lambdas(&[
+    ///     // the literal `[1.2, 2.1]` data type, the array being reduced
+    ///     ValueOrLambda::Value(DataType::new_list(DataType::Float32, true)),
+    ///     // the literal `0` data type, the initial value
+    ///     ValueOrLambda::Value(DataType::Int32),
+    ///     // the output data type of the merge lambda
+    ///     ValueOrLambda::Lambda(DataType::Float32),
+    ///     // the output data type of the finish lambda
+    ///     ValueOrLambda::Lambda(DataType::Boolean),
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///     coerce_to,
+    ///     vec![
+    ///         // return the same type for the array being reduced
+    ///         DataType::new_list(DataType::Float32, true),
+    ///         // coerce the initial value to the output of the merge lambda
+    ///         DataType::Float32,
+    ///     ]
+    /// );
+    ///
+    /// ```
+    ///
+    /// Note this may also be called at step 0 with all lambda outputs already 
set, and in that case,
+    /// [LambdaParametersProgress::Complete] must be returned
+    ///
+    /// The implementation can assume that some other part of the code has 
coerced
+    /// the actual argument types to match [`Self::signature`], except the 
coercion defined by
+    /// [Self::coerce_values_for_lambdas], if applicable.
+    ///
+    /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
+    /// [`HigherOrderFunction::lambda_parameters`]: 
crate::expr::HigherOrderFunction::lambda_parameters
+    fn lambda_parameters(
+        &self,
+        step: usize,
+        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
+    ) -> Result<LambdaParametersProgress>;
+
+    /// Coerce value arguments of a function call to types that the function 
can evaluate, also taking into
+    /// account the *output type of its lambdas*. This differs from 
[HigherOrderUDF::coerce_value_types],
+    /// which only has access to the types of its value arguments. For this 
method to be called, the
+    /// function must have its 
[HigherOrderSignature::coerce_values_for_lambdas] set to true
+    ///
+    /// See the [type coercion module](crate::type_coercion)
+    /// documentation for more details on type coercion
+    ///
+    /// # Parameters
+    /// * `fields`: The argument types of the value arguments of this 
function, or the output type of lambdas
+    ///
+    /// # Return value
+    /// A Vec with one element per [ValueOrLambda::Value] in `fields`. 
DataFusion will `CAST` the
+    /// function call arguments to these specific types.
+    ///
+    /// For example, a flexible array_reduce implementation (see 
[Self::lambda_parameters] docs), when working
+    /// with the expression below, may want to coerce its initial value 
argument, the *integer* `0`,
+    /// to match the output of its merge function, which is a *float*:
+    ///
+    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
+    fn coerce_values_for_lambdas(
+        &self,
+        _fields: &[ValueOrLambda<DataType, DataType>],
+    ) -> Result<Vec<DataType>> {
+        not_impl_err!(
+            "{} coerce_values_for_lambdas is not implemented",
+            self.name()
+        )
+    }

Review Comment:
   @rluvaton @comphead @LiaCastaneda @pepijnve 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to