This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 7076744de8a [opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic by distributing types during the open phase. (#50082) 7076744de8a is described below commit 7076744de8af3d8e2a00a53c221529d333d7541e Author: Mryange <yanxuech...@selectdb.com> AuthorDate: Thu Apr 17 11:35:17 2025 +0800 [opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic by distributing types during the open phase. (#50082) Previously, the argument and result types were resolved in the execute phase, on every call. cast_type_to_either tests the candidate types one at a time in the order listed below, so types near the end of the list (for example, the Decimal types) went through many failed checks on every execution. This change resolves the types once in open() and stores the matching implementation in the function state, so execute only invokes it. ```C++ template <typename F> static bool cast_type(const IDataType* type, F&& f) { return cast_type_to_either<DataTypeUInt8, DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, DataTypeInt128, DataTypeFloat32, DataTypeFloat64, DataTypeDecimal<Decimal32>, DataTypeDecimal<Decimal64>, DataTypeDecimal<Decimal128V2>, DataTypeDecimal<Decimal128V3>, DataTypeDecimal<Decimal256>>(type, std::forward<F>(f)); } ``` --- be/src/vec/functions/function_binary_arithmetic.h | 149 +++++++++++++--------- 1 file changed, 88 insertions(+), 61 deletions(-) diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 13efdf9ddbd..666ee6471f2 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -20,6 +20,8 @@ #pragma once +#include <functional> +#include <memory> #include <type_traits> #include "common/exception.h" @@ -34,6 +36,7 @@ #include "vec/core/types.h" #include "vec/core/wide_integer.h" #include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_factory.hpp" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" #include "vec/data_types/number_traits.h" @@ -792,13 +795,13 @@ struct BinaryOperationTraits {
DataTypeFromFieldType<typename Op::ResultType>>>; }; -template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType, +template <typename LeftDataType, typename RightDataType, typename FEResultDataType, template <typename, typename> class Operation, typename Name, bool is_to_null_type, bool check_overflow_for_decimal> struct ConstOrVectorAdapter { static constexpr bool result_is_decimal = IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>; - using ResultDataType = ExpectedResultDataType; + using ResultDataType = FEResultDataType; using ResultType = typename ResultDataType::FieldType; using A = typename LeftDataType::FieldType; using B = typename RightDataType::FieldType; @@ -931,6 +934,13 @@ private: } }; +struct BinaryArithmeticState { + std::function<Status(FunctionContext*, Block&, const ColumnNumbers&, uint32_t, size_t)> impl; + DataTypePtr left_type; + DataTypePtr right_type; + DataTypePtr result_type; +}; + template <template <typename, typename> class Operation, typename Name, bool is_to_null_type> class FunctionBinaryArithmetic : public IFunction { using OpTraits = OperationTraits<Operation>; @@ -1032,91 +1042,108 @@ public: return type_res; } - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - uint32_t result, size_t input_rows_count) const override { - auto* left_generic = block.get_by_position(arguments[0]).type.get(); - auto* right_generic = block.get_by_position(arguments[1]).type.get(); - auto* result_generic = block.get_by_position(result).type.get(); - if (left_generic->is_nullable()) { - left_generic = - static_cast<const DataTypeNullable*>(left_generic)->get_nested_type().get(); - } - if (right_generic->is_nullable()) { - right_generic = - static_cast<const DataTypeNullable*>(right_generic)->get_nested_type().get(); - } - if (result_generic->is_nullable()) { - result_generic = - static_cast<const DataTypeNullable*>(result_generic)->get_nested_type().get(); + 
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { + if (scope == FunctionContext::THREAD_LOCAL) { + return Status::OK(); } - - bool check_overflow_for_decimal = context->check_overflow_for_decimal(); - Status status; + std::shared_ptr<BinaryArithmeticState> state = std::make_shared<BinaryArithmeticState>(); + context->set_function_state(scope, state); + + state->left_type = + DataTypeFactory::instance().create_data_type(*context->get_arg_type(0), false); + state->right_type = + DataTypeFactory::instance().create_data_type(*context->get_arg_type(1), false); + state->result_type = + DataTypeFactory::instance().create_data_type(context->get_return_type(), false); + const auto* left_generic = state->left_type.get(); + const auto* right_generic = state->right_type.get(); + const auto* result_generic = state->result_type.get(); + + const bool check_overflow_for_decimal = context->check_overflow_for_decimal(); bool valid = cast_both_types( left_generic, right_generic, result_generic, [&](const auto& left, const auto& right, const auto& res) { using LeftDataType = std::decay_t<decltype(left)>; using RightDataType = std::decay_t<decltype(right)>; - using ExpectedResultDataType = std::decay_t<decltype(res)>; - using ResultDataType = + using FEResultDataType = std::decay_t<decltype(res)>; + using BEResultDataType = typename BinaryOperationTraits<Operation, LeftDataType, RightDataType>::ResultDataType; if constexpr ( - !std::is_same_v<ResultDataType, InvalidType> && - (IsDataTypeDecimal<ExpectedResultDataType> == + (!std::is_same_v<BEResultDataType, + InvalidType> /* Cannot be InvalidType */) && + (IsDataTypeDecimal<FEResultDataType> == IsDataTypeDecimal< - ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> == - (IsDataTypeDecimal<LeftDataType> || - IsDataTypeDecimal<RightDataType>))) { + BEResultDataType> /* The type planned by FE and the type planned by BE must both be Decimal or not */) && + 
(IsDataTypeDecimal<FEResultDataType> == + (IsDataTypeDecimal<LeftDataType> || + IsDataTypeDecimal< + RightDataType>)/* Only when at least one of left or right is Decimal, the return value can be Decimal */)) { if (check_overflow_for_decimal) { // !is_to_null_type: plus, minus, multiply, // pow, bitxor, bitor, bitand // if check_overflow and params are decimal types: // for functions pow, bitxor, bitor, bitand, return error - if constexpr (IsDataTypeDecimal<ResultDataType> && !is_to_null_type && - !OpTraits::is_multiply && !OpTraits::is_plus_minus) { - status = Status::Error<ErrorCode::NOT_IMPLEMENTED_ERROR>( - "cannot check overflow with decimal for function {}", name); - return false; - } - auto column_result = ConstOrVectorAdapter< - LeftDataType, RightDataType, - std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>, - ExpectedResultDataType, ResultDataType>, - Operation, Name, is_to_null_type, - true>::execute(block.get_by_position(arguments[0]).column, - block.get_by_position(arguments[1]).column, left, - right, - remove_nullable( - block.get_by_position(result).type)); - block.replace_by_position(result, std::move(column_result)); + static_assert( + !(IsDataTypeDecimal<BEResultDataType> && !is_to_null_type && + !OpTraits::is_multiply && !OpTraits::is_plus_minus), + "cannot check overflow with decimal for function"); + + state->impl = execute_with_type<LeftDataType, RightDataType, + FEResultDataType, true>; } else { - auto column_result = ConstOrVectorAdapter< - LeftDataType, RightDataType, - std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>, - ExpectedResultDataType, ResultDataType>, - Operation, Name, is_to_null_type, - false>::execute(block.get_by_position(arguments[0]).column, - block.get_by_position(arguments[1]).column, - left, right, - remove_nullable( - block.get_by_position(result).type)); - block.replace_by_position(result, std::move(column_result)); + state->impl = execute_with_type<LeftDataType, RightDataType, + FEResultDataType, 
false>; } + return true; } return false; }); if (!valid) { - if (status.ok()) { - return Status::RuntimeError("{}'s arguments do not match the expected data types", - get_name()); - } - return status; + return Status::RuntimeError("{}'s arguments do not match the expected data types", + get_name()); } return Status::OK(); } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + auto* state = reinterpret_cast<BinaryArithmeticState*>( + context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); + if (!state || !state->impl) { + return Status::RuntimeError("function context for function '{}' must have state;", + get_name()); + } + return state->impl(context, block, arguments, result, input_rows_count); + } + + template <typename LeftDataType, typename RightDataType, typename FEResultDataType, + bool check_overflow_for_decimal> + static Status execute_with_type(FunctionContext* context, Block& block, + const ColumnNumbers& arguments, uint32_t result, + size_t input_rows_count) { + const auto& left_type = + assert_cast<const LeftDataType&>(*block.get_by_position(arguments[0]).type); + const auto& right_type = + assert_cast<const RightDataType&>(*block.get_by_position(arguments[1]).type); + + using BEResultDataType = typename BinaryOperationTraits<Operation, LeftDataType, + RightDataType>::ResultDataType; + + using ExpectedResultDataType = std::conditional_t<IsDataTypeDecimal<FEResultDataType>, + FEResultDataType, BEResultDataType>; + auto column_result = + ConstOrVectorAdapter<LeftDataType, RightDataType, ExpectedResultDataType, Operation, + Name, is_to_null_type, check_overflow_for_decimal>:: + execute(block.get_by_position(arguments[0]).column, + block.get_by_position(arguments[1]).column, left_type, right_type, + remove_nullable(block.get_by_position(result).type)); + block.replace_by_position(result, std::move(column_result)); + + return Status::OK(); + } }; } 
// namespace doris::vectorized --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org