This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch branch-1.2-lts in repository https://gitbox.apache.org/repos/asf/doris.git
commit e599532588b5ac8d20cfc22efe89c311a0b2af24 Author: HappenLee <happen...@hotmail.com> AuthorDate: Thu Dec 29 15:35:17 2022 +0800 [Bug](Decimalv3) coredump of decimalv3 multiply (#15452) --- be/src/vec/core/decimal_comparison.h | 8 +-- be/src/vec/data_types/data_type_decimal.h | 60 ++++++---------------- be/src/vec/functions/function_binary_arithmetic.h | 7 ++- regression-test/data/decimalv3/test_decimalv3.out | 3 ++ .../suites/decimalv3/test_decimalv3.groovy | 1 + 5 files changed, 29 insertions(+), 50 deletions(-) diff --git a/be/src/vec/core/decimal_comparison.h b/be/src/vec/core/decimal_comparison.h index 1f2bce4200..4c7cfcf765 100644 --- a/be/src/vec/core/decimal_comparison.h +++ b/be/src/vec/core/decimal_comparison.h @@ -142,9 +142,11 @@ private: Shift shift; if (decimal0 && decimal1) { - auto result_type = decimal_result_type(*decimal0, *decimal1, false, false); - shift.a = result_type.scale_factor_for(*decimal0, false); - shift.b = result_type.scale_factor_for(*decimal1, false); + using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>; + auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false); + const DataTypeDecimal<Type>* result_type = check_decimal<Type>(*type_ptr); + shift.a = result_type->scale_factor_for(*decimal0, false); + shift.b = result_type->scale_factor_for(*decimal1, false); } else if (decimal0) { shift.b = decimal0->get_scale_multiplier(); } else if (decimal1) { diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index ea29954293..c8e08303a1 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -219,56 +219,30 @@ private: }; template <typename T, typename U> -typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimal_result_type( - const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply, - bool is_divide) { +DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, + bool is_multiply, bool is_divide, bool is_plus_minus) { + using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>; if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) { - return DataTypeDecimal<T>(max_decimal_precision<T>(), 9); + return std::make_shared<DataTypeDecimal<Type>>((max_decimal_precision<T>(), 9)); } else { - UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); + UInt32 scale = std::max(tx.get_scale(), ty.get_scale()); + auto precision = max_decimal_precision<Type>(); + + size_t multiply_precision = tx.get_precision() + ty.get_precision(); + size_t divide_precision = tx.get_precision() + ty.get_scale(); + size_t plus_minus_precision = + std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) + + scale; if (is_multiply) { scale = tx.get_scale() + ty.get_scale(); + precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>()); } else if (is_divide) { scale = tx.get_scale(); + precision = std::min(divide_precision, max_decimal_precision<Decimal128I>()); + } else if (is_plus_minus) { + precision = std::min(plus_minus_precision, max_decimal_precision<Decimal128I>()); } - return DataTypeDecimal<T>(max_decimal_precision<T>(), scale); - } -} - -template <typename T, typename U> -typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimal_result_type( - const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply, - bool is_divide) { - if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) { - return DataTypeDecimal<U>(max_decimal_precision<U>(), 9); - } else { - UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale()); - if (is_multiply) { - scale = tx.get_scale() + ty.get_scale(); - } else if (is_divide) { - scale = tx.get_scale(); - } - return DataTypeDecimal<U>(max_decimal_precision<U>(), scale); - } -} - -template <typename T, typename U> -const DataTypeDecimal<T> decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeNumber<U>&, - bool, bool) { - if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) { - return DataTypeDecimal<T>(max_decimal_precision<T>(), 9); - } else { - return DataTypeDecimal<T>(max_decimal_precision<T>(), tx.get_scale()); - } -} - -template <typename T, typename U> -const DataTypeDecimal<U> decimal_result_type(const DataTypeNumber<T>&, const DataTypeDecimal<U>& ty, - bool, bool) { - if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) { - return DataTypeDecimal<U>(max_decimal_precision<U>(), 9); - } else { - return DataTypeDecimal<U>(max_decimal_precision<U>(), ty.get_scale()); + return create_decimal(precision, scale, false); } } diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 5c98e72486..2a8da748e3 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -730,10 +730,9 @@ public: if constexpr (!std::is_same_v<ResultDataType, InvalidType>) { if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>) { - ResultDataType result_type = decimal_result_type( - left, right, OpTraits::is_multiply, OpTraits::is_division); - type_res = std::make_shared<ResultDataType>(result_type.get_precision(), - result_type.get_scale()); + type_res = decimal_result_type(left, right, OpTraits::is_multiply, + OpTraits::is_division, + OpTraits::is_plus_minus); } else if constexpr (IsDataTypeDecimal<LeftDataType>) { type_res = std::make_shared<LeftDataType>(left.get_precision(), left.get_scale()); diff --git a/regression-test/data/decimalv3/test_decimalv3.out b/regression-test/data/decimalv3/test_decimalv3.out index 1bb8b045c0..f8d56b4c41 100644 --- a/regression-test/data/decimalv3/test_decimalv3.out +++ b/regression-test/data/decimalv3/test_decimalv3.out @@ -2,3 +2,6 @@ -- !decimalv3 -- 100.000000000000000000 +-- !decimalv3 -- +100.00000000000000000000 + diff --git a/regression-test/suites/decimalv3/test_decimalv3.groovy b/regression-test/suites/decimalv3/test_decimalv3.groovy index 374e554b93..8b8b010240 100644 --- a/regression-test/suites/decimalv3/test_decimalv3.groovy +++ b/regression-test/suites/decimalv3/test_decimalv3.groovy @@ -26,4 +26,5 @@ suite("test_decimalv3") { sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5" qt_decimalv3 "select * from test5_v" + qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org