This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 7076744de8a [opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic by distributing types during the open phase. (#50082)
7076744de8a is described below

commit 7076744de8af3d8e2a00a53c221529d333d7541e
Author: Mryange <yanxuech...@selectdb.com>
AuthorDate: Thu Apr 17 11:35:17 2025 +0800

    [opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic by distributing types during the open phase. (#50082)
    
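    Previously, the argument and result types were resolved in the execute
    phase, on every call. Resolution tests the candidate types one after
    another (see `cast_type` below), so a type near the end of the list was
    compared against every earlier candidate each time the function executed.
    This change resolves the types once, in `open()`, and stores the resulting
    type-specialized implementation in the function state, so execution calls
    it directly.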
    
    ```C++
        template <typename F>
        static bool cast_type(const IDataType* type, F&& f) {
            return cast_type_to_either<DataTypeUInt8, DataTypeInt8, 
DataTypeInt16, DataTypeInt32,
                                       DataTypeInt64, DataTypeInt128, 
DataTypeFloat32, DataTypeFloat64,
                                       DataTypeDecimal<Decimal32>, 
DataTypeDecimal<Decimal64>,
                                       DataTypeDecimal<Decimal128V2>, 
DataTypeDecimal<Decimal128V3>,
                                       DataTypeDecimal<Decimal256>>(type, 
std::forward<F>(f));
        }
    
    ```
---
 be/src/vec/functions/function_binary_arithmetic.h | 149 +++++++++++++---------
 1 file changed, 88 insertions(+), 61 deletions(-)

diff --git a/be/src/vec/functions/function_binary_arithmetic.h 
b/be/src/vec/functions/function_binary_arithmetic.h
index 13efdf9ddbd..666ee6471f2 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -20,6 +20,8 @@
 
 #pragma once
 
+#include <functional>
+#include <memory>
 #include <type_traits>
 
 #include "common/exception.h"
@@ -34,6 +36,7 @@
 #include "vec/core/types.h"
 #include "vec/core/wide_integer.h"
 #include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_factory.hpp"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/data_types/data_type_number.h"
 #include "vec/data_types/number_traits.h"
@@ -792,13 +795,13 @@ struct BinaryOperationTraits {
                  DataTypeFromFieldType<typename Op::ResultType>>>;
 };
 
-template <typename LeftDataType, typename RightDataType, typename 
ExpectedResultDataType,
+template <typename LeftDataType, typename RightDataType, typename 
FEResultDataType,
           template <typename, typename> class Operation, typename Name, bool 
is_to_null_type,
           bool check_overflow_for_decimal>
 struct ConstOrVectorAdapter {
     static constexpr bool result_is_decimal =
             IsDataTypeDecimal<LeftDataType> || 
IsDataTypeDecimal<RightDataType>;
-    using ResultDataType = ExpectedResultDataType;
+    using ResultDataType = FEResultDataType;
     using ResultType = typename ResultDataType::FieldType;
     using A = typename LeftDataType::FieldType;
     using B = typename RightDataType::FieldType;
@@ -931,6 +934,13 @@ private:
     }
 };
 
+struct BinaryArithmeticState {
+    std::function<Status(FunctionContext*, Block&, const ColumnNumbers&, 
uint32_t, size_t)> impl;
+    DataTypePtr left_type;
+    DataTypePtr right_type;
+    DataTypePtr result_type;
+};
+
 template <template <typename, typename> class Operation, typename Name, bool 
is_to_null_type>
 class FunctionBinaryArithmetic : public IFunction {
     using OpTraits = OperationTraits<Operation>;
@@ -1032,91 +1042,108 @@ public:
         return type_res;
     }
 
-    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
-                        uint32_t result, size_t input_rows_count) const 
override {
-        auto* left_generic = block.get_by_position(arguments[0]).type.get();
-        auto* right_generic = block.get_by_position(arguments[1]).type.get();
-        auto* result_generic = block.get_by_position(result).type.get();
-        if (left_generic->is_nullable()) {
-            left_generic =
-                    static_cast<const 
DataTypeNullable*>(left_generic)->get_nested_type().get();
-        }
-        if (right_generic->is_nullable()) {
-            right_generic =
-                    static_cast<const 
DataTypeNullable*>(right_generic)->get_nested_type().get();
-        }
-        if (result_generic->is_nullable()) {
-            result_generic =
-                    static_cast<const 
DataTypeNullable*>(result_generic)->get_nested_type().get();
+    Status open(FunctionContext* context, FunctionContext::FunctionStateScope 
scope) override {
+        if (scope == FunctionContext::THREAD_LOCAL) {
+            return Status::OK();
         }
-
-        bool check_overflow_for_decimal = 
context->check_overflow_for_decimal();
-        Status status;
+        std::shared_ptr<BinaryArithmeticState> state = 
std::make_shared<BinaryArithmeticState>();
+        context->set_function_state(scope, state);
+
+        state->left_type =
+                
DataTypeFactory::instance().create_data_type(*context->get_arg_type(0), false);
+        state->right_type =
+                
DataTypeFactory::instance().create_data_type(*context->get_arg_type(1), false);
+        state->result_type =
+                
DataTypeFactory::instance().create_data_type(context->get_return_type(), false);
+        const auto* left_generic = state->left_type.get();
+        const auto* right_generic = state->right_type.get();
+        const auto* result_generic = state->result_type.get();
+
+        const bool check_overflow_for_decimal = 
context->check_overflow_for_decimal();
         bool valid = cast_both_types(
                 left_generic, right_generic, result_generic,
                 [&](const auto& left, const auto& right, const auto& res) {
                     using LeftDataType = std::decay_t<decltype(left)>;
                     using RightDataType = std::decay_t<decltype(right)>;
-                    using ExpectedResultDataType = std::decay_t<decltype(res)>;
-                    using ResultDataType =
+                    using FEResultDataType = std::decay_t<decltype(res)>;
+                    using BEResultDataType =
                             typename BinaryOperationTraits<Operation, 
LeftDataType,
                                                            
RightDataType>::ResultDataType;
                     if constexpr (
-                            !std::is_same_v<ResultDataType, InvalidType> &&
-                            (IsDataTypeDecimal<ExpectedResultDataType> ==
+                            (!std::is_same_v<BEResultDataType,
+                                             InvalidType> /* Cannot be 
InvalidType */) &&
+                            (IsDataTypeDecimal<FEResultDataType> ==
                              IsDataTypeDecimal<
-                                     
ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> ==
-                                                        
(IsDataTypeDecimal<LeftDataType> ||
-                                                         
IsDataTypeDecimal<RightDataType>))) {
+                                     BEResultDataType> /* The type planned by 
FE and the type planned by BE must both be Decimal or not */) &&
+                            (IsDataTypeDecimal<FEResultDataType> ==
+                             (IsDataTypeDecimal<LeftDataType> ||
+                              IsDataTypeDecimal<
+                                      RightDataType>)/* Only when at least one 
of left or right is Decimal, the return value can be Decimal */)) {
                         if (check_overflow_for_decimal) {
                             // !is_to_null_type: plus, minus, multiply,
                             //                   pow, bitxor, bitor, bitand
                             // if check_overflow and params are decimal types:
                             //   for functions pow, bitxor, bitor, bitand, 
return error
-                            if constexpr (IsDataTypeDecimal<ResultDataType> && 
!is_to_null_type &&
-                                          !OpTraits::is_multiply && 
!OpTraits::is_plus_minus) {
-                                status = 
Status::Error<ErrorCode::NOT_IMPLEMENTED_ERROR>(
-                                        "cannot check overflow with decimal 
for function {}", name);
-                                return false;
-                            }
-                            auto column_result = ConstOrVectorAdapter<
-                                    LeftDataType, RightDataType,
-                                    
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
-                                                       ExpectedResultDataType, 
ResultDataType>,
-                                    Operation, Name, is_to_null_type,
-                                    
true>::execute(block.get_by_position(arguments[0]).column,
-                                                   
block.get_by_position(arguments[1]).column, left,
-                                                   right,
-                                                   remove_nullable(
-                                                           
block.get_by_position(result).type));
-                            block.replace_by_position(result, 
std::move(column_result));
+                            static_assert(
+                                    !(IsDataTypeDecimal<BEResultDataType> && 
!is_to_null_type &&
+                                      !OpTraits::is_multiply && 
!OpTraits::is_plus_minus),
+                                    "cannot check overflow with decimal for 
function");
+
+                            state->impl = execute_with_type<LeftDataType, 
RightDataType,
+                                                            FEResultDataType, 
true>;
                         } else {
-                            auto column_result = ConstOrVectorAdapter<
-                                    LeftDataType, RightDataType,
-                                    
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
-                                                       ExpectedResultDataType, 
ResultDataType>,
-                                    Operation, Name, is_to_null_type,
-                                    
false>::execute(block.get_by_position(arguments[0]).column,
-                                                    
block.get_by_position(arguments[1]).column,
-                                                    left, right,
-                                                    remove_nullable(
-                                                            
block.get_by_position(result).type));
-                            block.replace_by_position(result, 
std::move(column_result));
+                            state->impl = execute_with_type<LeftDataType, 
RightDataType,
+                                                            FEResultDataType, 
false>;
                         }
+
                         return true;
                     }
                     return false;
                 });
         if (!valid) {
-            if (status.ok()) {
-                return Status::RuntimeError("{}'s arguments do not match the 
expected data types",
-                                            get_name());
-            }
-            return status;
+            return Status::RuntimeError("{}'s arguments do not match the 
expected data types",
+                                        get_name());
         }
 
         return Status::OK();
     }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const 
override {
+        auto* state = reinterpret_cast<BinaryArithmeticState*>(
+                context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
+        if (!state || !state->impl) {
+            return Status::RuntimeError("function context for function '{}' 
must have state;",
+                                        get_name());
+        }
+        return state->impl(context, block, arguments, result, 
input_rows_count);
+    }
+
+    template <typename LeftDataType, typename RightDataType, typename 
FEResultDataType,
+              bool check_overflow_for_decimal>
+    static Status execute_with_type(FunctionContext* context, Block& block,
+                                    const ColumnNumbers& arguments, uint32_t 
result,
+                                    size_t input_rows_count) {
+        const auto& left_type =
+                assert_cast<const 
LeftDataType&>(*block.get_by_position(arguments[0]).type);
+        const auto& right_type =
+                assert_cast<const 
RightDataType&>(*block.get_by_position(arguments[1]).type);
+
+        using BEResultDataType = typename BinaryOperationTraits<Operation, 
LeftDataType,
+                                                                
RightDataType>::ResultDataType;
+
+        using ExpectedResultDataType = 
std::conditional_t<IsDataTypeDecimal<FEResultDataType>,
+                                                          FEResultDataType, 
BEResultDataType>;
+        auto column_result =
+                ConstOrVectorAdapter<LeftDataType, RightDataType, 
ExpectedResultDataType, Operation,
+                                     Name, is_to_null_type, 
check_overflow_for_decimal>::
+                        execute(block.get_by_position(arguments[0]).column,
+                                block.get_by_position(arguments[1]).column, 
left_type, right_type,
+                                
remove_nullable(block.get_by_position(result).type));
+        block.replace_by_position(result, std::move(column_result));
+
+        return Status::OK();
+    }
 };
 
 } // namespace doris::vectorized


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to