gstvg commented on code in PR #21679:
URL: https://github.com/apache/datafusion/pull/21679#discussion_r3152298644


##########
datafusion/expr/src/higher_order_function.rs:
##########
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`HigherOrderUDF`]: User Defined Higher Order Functions
+
+use crate::expr::schema_name_from_exprs_comma_separated_without_space;
+use crate::{ColumnarValue, Documentation, Expr};
+use arrow::array::{ArrayRef, RecordBatch};
+use arrow::datatypes::{DataType, Field, FieldRef, Schema};
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{Result, ScalarValue, not_impl_err};
+use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
+use datafusion_expr_common::signature::Volatility;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::cmp::Ordering;
+use std::fmt::Debug;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// The types of arguments for which a function has implementations.
+///
+/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user query could call the
+/// function with. DataFusion will automatically coerce (cast) argument types to
+/// one of the supported function signatures, if possible.
+///
+/// # Overview
+/// Functions typically provide implementations for a small number of different
+/// argument [`DataType`]s, rather than all possible combinations. If a user
+/// calls a function with arguments that do not match any of the declared types,
+/// DataFusion will attempt to automatically coerce (add casts to) function
+/// arguments so they match the [`HigherOrderTypeSignature`]. See the
+/// [`type_coercion`] module for more details.
+///
+/// [`type_coercion`]: crate::type_coercion
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum HigherOrderTypeSignature {
+    /// The acceptable signature and coercion rules are special for this
+    /// function.
+    ///
+    /// If this signature is specified, DataFusion will call
+    /// [`HigherOrderUDF::coerce_value_types`] to prepare argument types.
+    UserDefined,
+    /// One or more lambdas or arguments with arbitrary types.
+    VariadicAny,
+    /// The specified number of lambdas or arguments with arbitrary types.
+    Any(usize),
+}
+
+/// Provides information necessary for calling a higher order function.
+///
+/// - [`HigherOrderTypeSignature`] defines the argument types that a function has
+///   implementations for.
+///
+/// - [`Volatility`] defines how the output of the function changes with the input.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub struct HigherOrderSignature {
+    /// The data types that the function accepts. See [`HigherOrderTypeSignature`]
+    /// for more information.
+    pub type_signature: HigherOrderTypeSignature,
+    /// The volatility of the function. See [`Volatility`] for more information.
+    pub volatility: Volatility,
+}
+
+impl HigherOrderSignature {
+    /// Creates a new `HigherOrderSignature` from a given type signature and volatility.
+    pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self {
+        HigherOrderSignature {
+            type_signature,
+            volatility,
+        }
+    }
+
+    /// User-defined coercion rules for the function
+    /// ([`HigherOrderTypeSignature::UserDefined`]) with the given volatility.
+    pub fn user_defined(volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::UserDefined,
+            volatility,
+        }
+    }
+
+    /// An arbitrary number of lambdas or arguments of any type
+    /// ([`HigherOrderTypeSignature::VariadicAny`]) with the given volatility.
+    pub fn variadic_any(volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::VariadicAny,
+            volatility,
+        }
+    }
+
+    /// `arg_count` lambdas or arguments of any type
+    /// ([`HigherOrderTypeSignature::Any`]) with the given volatility.
+    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
+        Self {
+            type_signature: HigherOrderTypeSignature::Any(arg_count),
+            volatility,
+        }
+    }
+}
+
+/// Compares two higher order function trait objects by delegating to
+/// [`DynEq::dyn_eq`].
+impl PartialEq for dyn HigherOrderUDF {
+    fn eq(&self, other: &Self) -> bool {
+        // `as _` coerces `other` to whatever reference type `dyn_eq` expects.
+        self.dyn_eq(other as _)
+    }
+}
+
+/// Orders functions by name, then signature, then aliases.
+///
+/// Returns `None` when the signatures or aliases cannot be compared, or when all
+/// three compare equal but the two functions are not equal according to
+/// [`PartialEq`].
+impl PartialOrd for dyn HigherOrderUDF {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        let mut cmp = self.name().cmp(other.name());
+        if cmp == Ordering::Equal {
+            // `?` propagates `None` if the signatures are not comparable.
+            cmp = self.signature().partial_cmp(other.signature())?;
+        }
+        if cmp == Ordering::Equal {
+            cmp = self.aliases().partial_cmp(other.aliases())?;
+        }
+        // Contract for PartialOrd and PartialEq consistency requires that
+        // a == b if and only if partial_cmp(a, b) == Some(Equal).
+        if cmp == Ordering::Equal && self != other {
+            // Functions may have other properties besides name and signature
+            // that differentiate two instances (e.g. type, or arbitrary parameters).
+            // We cannot return Some(Equal) in such case.
+            return None;
+        }
+        // Converse direction of the contract: functions that are equal per
+        // `PartialEq` must also compare Equal on name, signature and aliases.
+        // NOTE(review): the panic message below ends with ',' — likely meant '.'.
+        debug_assert!(
+            cmp == Ordering::Equal || self != other,
+            "Detected incorrect implementation of PartialEq when comparing 
functions: '{}' and '{}'. \
+            The functions compare as equal, but they are not equal based on 
general properties that \
+            the PartialOrd implementation observes,",
+            self.name(),
+            other.name()
+        );
+        Some(cmp)
+    }
+}
+
+/// Marker impl: equality via [`DynEq::dyn_eq`] is treated as a full equivalence
+/// relation (reflexive), as `Eq` requires.
+impl Eq for dyn HigherOrderUDF {}
+
+/// Hashes higher order function trait objects by delegating to
+/// [`DynHash::dyn_hash`]. Implementations must keep this consistent with
+/// [`PartialEq`]: functions that compare equal must hash equally.
+impl Hash for dyn HigherOrderUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.dyn_hash(state)
+    }
+}
+
+/// Arguments passed to [`HigherOrderUDF::invoke_with_args`] when invoking a
+/// higher order function.
+#[derive(Debug, Clone)]
+pub struct HigherOrderFunctionArgs {
+    /// The evaluated arguments and lambdas to the function.
+    pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
+    /// The field associated with each argument.
+    ///
+    /// For lambdas, it is the field of the result of the lambda if evaluated
+    /// with the parameters returned from [`HigherOrderUDF::lambda_parameters`].
+    pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
+    /// The number of rows in the record batch being evaluated.
+    pub number_rows: usize,
+    /// The return field of the higher order function, as returned by
+    /// `return_field_from_args` when the physical expression was created from
+    /// the logical expression.
+    pub return_field: FieldRef,
+    /// The config options at execution time.
+    pub config_options: Arc<ConfigOptions>,
+}
+
+impl HigherOrderFunctionArgs {
+    /// The return type of the function, i.e. the data type of
+    /// [`Self::return_field`]. See that field for more details.
+    pub fn return_type(&self) -> &DataType {
+        self.return_field.data_type()
+    }
+}
+
+/// A lambda argument to a higher order function, as received by
+/// [`HigherOrderUDF::invoke_with_args`] through [`HigherOrderFunctionArgs::args`].
+#[derive(Clone, Debug)]
+pub struct LambdaArgument {
+    /// The parameters defined in this lambda.
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be `vec![Arc::new(Field::new("v", DataType::Int32, true))]`.
+    params: Vec<FieldRef>,
+    /// The body of the lambda.
+    ///
+    /// For example, for `array_transform([2], v -> -v)`,
+    /// this will be the physical expression of `-v`.
+    body: Arc<dyn PhysicalExpr>,
+}
+
+impl LambdaArgument {
+    /// Creates a lambda argument from its declared parameters and its physical
+    /// body expression.
+    pub fn new(params: Vec<FieldRef>, body: Arc<dyn PhysicalExpr>) -> Self {
+        Self { params, body }
+    }
+
+    /// Evaluate this lambda.
+    ///
+    /// Each closure in `args` should produce the value of one parameter, in the
+    /// order of the corresponding lambda's parameters returned by
+    /// [`HigherOrderUDF::lambda_parameters`].
+    ///
+    /// Only the first `params.len()` closures are called, so parameters beyond
+    /// those this lambda declares are never computed.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if a called closure fails, if the produced columns do not
+    /// match the schema built from `params` (`RecordBatch::try_new` validation,
+    /// e.g. fewer closures than parameters), or if evaluating the body fails.
+    pub fn evaluate(
+        &self,
+        args: &[&dyn Fn() -> Result<ArrayRef>],
+    ) -> Result<ColumnarValue> {
+        // Lazily materialize only the parameters this lambda declares.
+        let columns = args
+            .iter()
+            .take(self.params.len())
+            .map(|arg| arg())
+            .collect::<Result<_>>()?;
+
+        // The declared parameters form the schema the body is evaluated against.
+        let schema = Arc::new(Schema::new(self.params.clone()));
+
+        // NOTE(review): if `params` is empty, no columns are passed and Arrow's
+        // `RecordBatch::try_new` rejects a batch with neither columns nor an
+        // explicit row count; a zero-parameter lambda would need
+        // `RecordBatch::try_new_with_options` with a row count — confirm whether
+        // such lambdas can reach this point.
+        let batch = RecordBatch::try_new(schema, columns)?;
+
+        self.body.evaluate(&batch)
+    }
+}
+
+/// Information about arguments passed to the function.
+///
+/// This structure contains metadata about how the function was called,
+/// such as the type of the arguments, any scalar arguments, and whether the
+/// arguments can (ever) be null.
+///
+/// See [`HigherOrderUDF::return_field_from_args`] for more information.
+#[derive(Clone, Debug)]
+pub struct HigherOrderReturnFieldArgs<'a> {
+    /// The fields of the arguments to the function.
+    ///
+    /// If argument `i` to the function is a lambda, it will be the field of the
+    /// result of the lambda if evaluated with the parameters returned from
+    /// [`HigherOrderUDF::lambda_parameters`].
+    ///
+    /// For example, with `array_transform([1], v -> v == 5)`
+    /// this field will be `[
+    ///     ValueOrLambda::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)),
+    ///     ValueOrLambda::Lambda(Field::new("", DataType::Boolean, false))
+    /// ]`
+    pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
+    /// Is argument `i` to the function a scalar (constant)?
+    ///
+    /// If argument `i` is not a scalar, it will be `None`.
+    ///
+    /// For example, if a function is called like `array_transform([1], v -> v == 5)`
+    /// this field will be `[Some(ScalarValue::List(...)), None]`.
+    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
+}
+
+/// An argument to a higher order function: either a plain value or a lambda.
+///
+/// The type parameters carry per-variant data, so the same shape is used both
+/// for evaluated arguments (`ValueOrLambda<ColumnarValue, LambdaArgument>`) and
+/// for argument fields (`ValueOrLambda<FieldRef, FieldRef>`).
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ValueOrLambda<V, L> {
+    /// A value with associated data.
+    Value(V),
+    /// A lambda with associated data.
+    Lambda(L),
+}
+
+/// Trait for implementing user defined higher order functions.
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`array_transform.rs`] for a commented complete implementation
+///
+/// [`array_transform.rs`]: 
https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs
+pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Aliases can be used to invoke the same function using different names.
+    /// For example in some databases `now()` and `current_timestamp()` are
+    /// aliases for the same function. This behavior can be obtained by
+    /// returning `current_timestamp` as an alias for the `now` function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+
+    /// Returns the name of the column this expression would create
+    ///
+    /// See [`Expr::schema_name`] for details
+    fn schema_name(&self, args: &[Expr]) -> Result<String> {
+        Ok(format!(
+            "{}({})",
+            self.name(),
+            schema_name_from_exprs_comma_separated_without_space(args)?
+        ))
+    }
+
+    /// Returns a [`HigherOrderSignature`] describing the argument types for 
which this
+    /// function has an implementation, and the function's [`Volatility`].
+    ///
+    /// See [`HigherOrderSignature`] for more details on argument type handling
+    /// and [`Self::return_field_from_args`] for computing the return type.
+    ///
+    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
+    fn signature(&self) -> &HigherOrderSignature;
+
+    /// Return the field of all the parameters supported by all the supported 
lambdas of this function
+    /// based on the field of the value arguments. If a lambda support 
multiple parameters, or if multiple
+    /// lambdas are supported and some are optional, all should be returned,
+    /// regardless of whether they are used on a particular invocation
+    ///
+    /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call 
the helper
+    /// [`HigherOrderFunction::lambda_parameters`] instead of this method 
directly
+    ///
+    /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
+    /// [`HigherOrderFunction::lambda_parameters`]: 
crate::expr::HigherOrderFunction::lambda_parameters
+    ///
+    /// Example for array_transform:
+    ///
+    /// `array_transform([2.0, 8.0], v -> v > 4.0)`
+    ///
+    /// ```ignore
+    /// let lambda_parameters = array_transform.lambda_parameters(&[
+    ///      Arc::new(Field::new("", DataType::new_list(DataType::Float32, 
false))), // the Field of the literal `[2, 8]`
+    /// ])?;
+    ///
+    /// assert_eq!(
+    ///      lambda_parameters,
+    ///      vec![
+    ///         // the lambda supported parameters, regardless of how many are 
actually used
+    ///         vec![
+    ///             // the value being transformed
+    ///             Field::new("", DataType::Float32, false),
+    ///             // the 1-based index being transformed, not used on the 
example above,
+    ///             //but implementations doesn't need to care about it
+    ///             Field::new("", DataType::Int32, false),
+    ///         ]
+    ///      ]
+    /// )
+    /// ```
+    ///
+    /// The implementation can assume that some other part of the code has 
coerced
+    /// the actual argument types to match [`Self::signature`].
+    fn lambda_parameters(&self, value_fields: &[FieldRef]) -> 
Result<Vec<Vec<Field>>>;
+
+    /// What type will be returned by this function, given the arguments?
+    ///
+    /// The implementation can assume that some other part of the code has 
coerced
+    /// the actual argument types to match [`Self::signature`].
+    ///
+    /// # Example creating `Field`
+    ///
+    /// Note the name of the [`Field`] is ignored, except for structured types 
such as
+    /// `DataType::Struct`.
+    ///
+    /// ```rust
+    /// # use std::sync::Arc;
+    /// # use arrow::datatypes::{DataType, Field, FieldRef};
+    /// # use datafusion_common::Result;
+    /// # use datafusion_expr::HigherOrderReturnFieldArgs;
+    /// # struct Example{}
+    /// # impl Example {
+    /// fn return_field_from_args(&self, args: HigherOrderReturnFieldArgs) -> 
Result<FieldRef> {
+    ///     let field = Arc::new(Field::new("ignored_name", DataType::Int32, 
true));
+    ///     Ok(field)
+    /// }
+    /// # }
+    /// ```
+    fn return_field_from_args(
+        &self,
+        args: HigherOrderReturnFieldArgs,
+    ) -> Result<FieldRef>;
+
+    /// Whether List, LargeList and FixedSizeList arguments should have it's
+    /// non-empty null sublists cleaned by Datafusion before invoking this 
function
+    ///
+    /// The default implementation always returns true and should only be 
implemented
+    /// if you want to handle non-empty null sublists yourself
+    ///
+    /// fully null fixed size list arrays should always be handled regardless 
of
+    /// the return of this function
+    // todo: extend this to listview and maps when remove_list_null_values 
supports it
+    fn clear_null_values(&self) -> bool {
+        true
+    }
+
+    /// Invoke the function returning the appropriate result.
+    ///
+    /// # Performance
+    ///
+    /// For the best performance, the implementations should handle the common 
case
+    /// when one or more of their arguments are constant values (aka
+    /// [`ColumnarValue::Scalar`]).
+    ///
+    /// [`ColumnarValue::values_to_arrays`] can be used to convert the 
arguments
+    /// to arrays, which will likely be simpler code, but be slower.
+    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> 
Result<ColumnarValue>;
+
+    /// Returns true if some of this `exprs` subexpressions may not be 
evaluated
+    /// and thus any side effects (like divide by zero) may not be encountered.
+    ///
+    /// Setting this to true prevents certain optimizations such as common
+    /// subexpression elimination
+    ///
+    /// When overriding this function to return `true`, 
[HigherOrderUDF::conditional_arguments] can also be
+    /// overridden to report more accurately which arguments are eagerly 
evaluated and which ones
+    /// lazily.
+    fn short_circuits(&self) -> bool {
+        false
+    }
+
+    /// Determines which of the arguments passed to this function are 
evaluated eagerly
+    /// and which may be evaluated lazily.
+    ///
+    /// If this function returns `None`, all arguments are eagerly evaluated.
+    /// Returning `None` is a micro optimization that saves a needless `Vec`
+    /// allocation.
+    ///
+    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
+    /// are the arguments that are always evaluated, and `lazy` are the
+    /// arguments that may be evaluated lazily (i.e. may not be evaluated at 
all
+    /// in some cases).
+    ///
+    /// Implementations must ensure that the two returned `Vec`s are disjunct,
+    /// and that each argument from `args` is present in one the two `Vec`s.
+    ///
+    /// When overriding this function, [HigherOrderUDF::short_circuits] must
+    /// be overridden to return `true`.
+    fn conditional_arguments<'a>(
+        &self,
+        args: &'a [Expr],
+    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
+        if self.short_circuits() {
+            Some((vec![], args.iter().collect()))
+        } else {
+            None
+        }
+    }
+
+    /// Coerce arguments of a function call to types that the function can 
evaluate.

Review Comment:
   I committed the change but forgot to reply here:
https://github.com/apache/datafusion/pull/21679/commits/098fb1eecf0478dfdc47c6edd0fb5882335dae98
 Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to