This is an automated email from the ASF dual-hosted git repository. gabriellee pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 520b6d7910 [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463) 520b6d7910 is described below commit 520b6d791044425c211fea54c088d5f10893b060 Author: Gabriel <gabrielleeb...@gmail.com> AuthorDate: Fri Dec 30 14:02:24 2022 +0800 [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463) --- be/src/runtime/runtime_state.h | 5 + be/src/udf/udf_internal.h | 8 ++ be/src/vec/data_types/data_type_decimal.h | 30 ++++- be/src/vec/exprs/vexpr_context.cpp | 2 + be/src/vec/functions/function_binary_arithmetic.h | 142 ++++++++++++++------- be/src/vec/functions/function_cast.h | 124 +++++++++++------- .../org/apache/doris/analysis/ArithmeticExpr.java | 19 +-- .../main/java/org/apache/doris/analysis/Expr.java | 21 +++ .../java/org/apache/doris/qe/SessionVariable.java | 10 ++ gensrc/thrift/PaloInternalService.thrift | 1 + .../data/datatype_p0/decimalv3/test_overflow.out | 19 +++ .../datatype_p0/decimalv3/test_overflow.groovy | 56 ++++++++ 12 files changed, 335 insertions(+), 102 deletions(-) diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index dedef5340d..e9650d0702 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -135,6 +135,11 @@ public: _query_options.enable_function_pushdown; } + bool check_overflow_for_decimal() const { + return _query_options.__isset.check_overflow_for_decimal && + _query_options.check_overflow_for_decimal; + } + // Create a codegen object in _codegen. No-op if it has already been called. // If codegen is enabled for the query, this is created when the runtime // state is created. If codegen is disabled for the query, this is created diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h index 1bc4fefd0b..67a8ec60e7 100644 --- a/be/src/udf/udf_internal.h +++ b/be/src/udf/udf_internal.h @@ -109,6 +109,12 @@ public: const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; } + const bool check_overflow_for_decimal() const { return _check_overflow_for_decimal; } + + bool set_check_overflow_for_decimal(bool check_overflow_for_decimal) { + return _check_overflow_for_decimal = check_overflow_for_decimal; + } + private: friend class doris_udf::FunctionContext; friend class ExprContext; @@ -181,6 +187,8 @@ private: // call that passes the correct AnyVal subclass pointer type. std::vector<doris_udf::AnyVal*> _staging_input_vals; + bool _check_overflow_for_decimal = false; + // Indicates whether this context has been closed. Used for verification/debugging. bool _closed; diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index c8e08303a1..2213093104 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -71,6 +71,11 @@ constexpr Int128 max_decimal_value<Decimal128>() { return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll + static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll; } +template <> +constexpr Int128 max_decimal_value<Decimal128I>() { + return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll + + static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll; +} DataTypePtr create_decimal(UInt64 precision, UInt64 scale, bool use_v2); @@ -291,8 +296,8 @@ constexpr bool IsDataTypeDecimalOrNumber = template <typename FromDataType, typename ToDataType> inline std::enable_if_t<IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>, typename ToDataType::FieldType> -convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from, - UInt32 scale_to) { +convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from, UInt32 scale_to, + UInt8* overflow_flag = nullptr) { using FromFieldType = typename FromDataType::FieldType; using ToFieldType = typename ToDataType::FieldType; using MaxFieldType = @@ -310,6 +315,9 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_to - scale_from); if (common::mul_overflow(static_cast<MaxNativeType>(value), converted_value, converted_value)) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow"; return converted_value < 0 ? std::numeric_limits<typename ToFieldType::NativeType>::min() @@ -322,10 +330,16 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType)) { if (converted_value < std::numeric_limits<typename ToFieldType::NativeType>::min()) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow"; return std::numeric_limits<typename ToFieldType::NativeType>::min(); } if (converted_value > std::numeric_limits<typename ToFieldType::NativeType>::max()) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow"; return std::numeric_limits<typename ToFieldType::NativeType>::max(); } @@ -381,12 +395,16 @@ convert_from_decimal(const typename FromDataType::FieldType& value, UInt32 scale template <typename FromDataType, typename ToDataType> inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>, typename ToDataType::FieldType> -convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) { +convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale, + UInt8* overflow_flag) { using FromFieldType = typename FromDataType::FieldType; using ToNativeType = typename ToDataType::FieldType::NativeType; if constexpr (std::is_floating_point_v<FromFieldType>) { if (!std::isfinite(value)) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow. Cannot convert infinity or NaN to decimal"; return value < 0 ? std::numeric_limits<ToNativeType>::min() : std::numeric_limits<ToNativeType>::max(); @@ -395,10 +413,16 @@ convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) FromFieldType out; out = value * ToDataType::get_scale_multiplier(scale); if (out <= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::min())) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range"; return std::numeric_limits<ToNativeType>::min(); } if (out >= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::max())) { + if (overflow_flag) { + *overflow_flag = 1; + } VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range"; return std::numeric_limits<ToNativeType>::max(); } diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp index ccb1045cb1..9033245202 100644 --- a/be/src/vec/exprs/vexpr_context.cpp +++ b/be/src/vec/exprs/vexpr_context.cpp @@ -110,6 +110,8 @@ int VExprContext::register_func(RuntimeState* state, const FunctionContext::Type int varargs_buffer_size) { _fn_contexts.push_back(FunctionContextImpl::create_context( state, _pool.get(), return_type, arg_types, varargs_buffer_size, false)); + _fn_contexts.back()->impl()->set_check_overflow_for_decimal( + state->check_overflow_for_decimal()); return _fn_contexts.size() - 1; } diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 2a8da748e3..d1c0375a55 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -23,6 +23,7 @@ #include <type_traits> #include "runtime/decimalv2_value.h" +#include "udf/udf_internal.h" #include "vec/columns/column_const.h" #include "vec/columns/column_decimal.h" #include "vec/columns/column_nullable.h" @@ -216,7 +217,8 @@ struct BinaryOperationImpl { /// * no agrs scale. ScaleR = Scale1 + Scale2; /// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()). template <typename A, typename B, template <typename, typename> typename Operation, - typename ResultType, bool is_to_null_type, bool check_overflow = true> + typename ResultType, bool is_to_null_type, bool return_nullable_type, + bool check_overflow = true> struct DecimalBinaryOperation { using OpTraits = OperationTraits<Operation>; @@ -249,12 +251,14 @@ struct DecimalBinaryOperation { for (size_t i = 0; i < size; ++i) { c[i] = apply(a[i], b[i], null_map[i]); } - } else { - if constexpr (OpTraits::is_division && IsDecimalNumber<B>) { - for (size_t i = 0; i < size; ++i) { - c[i] = apply_scaled_div(a[i], b[i], null_map[i]); - } - return; + } else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) { + for (size_t i = 0; i < size; ++i) { + c[i] = apply_scaled_div(a[i], b[i], null_map[i]); + } + } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && + (IsDecimalNumber<B> || IsDecimalNumber<A>)) { + for (size_t i = 0; i < size; ++i) { + null_map[i] = apply_op_safely(a[i], b[i], c[i].value); } } } @@ -281,21 +285,21 @@ struct DecimalBinaryOperation { for (size_t i = 0; i < size; ++i) { c[i] = apply_scaled_div(a[i], b, null_map[i]); } - return; - } - - for (size_t i = 0; i < size; ++i) { - c[i] = apply(a[i], b, null_map[i]); + } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && + (IsDecimalNumber<B> || IsDecimalNumber<A>)) { + for (size_t i = 0; i < size; ++i) { + null_map[i] = apply_op_safely(a[i], b, c[i].value); + } + } else { + for (size_t i = 0; i < size; ++i) { + c[i] = apply(a[i], b, null_map[i]); + } } } static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c) { size_t size = b.size(); - if constexpr (OpTraits::is_division && IsDecimalNumber<B>) { - for (size_t i = 0; i < size; ++i) { - c[i] = apply_scaled_div(a, b[i]); - } - } else if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) { + if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) { DecimalV2Value da(a); for (size_t i = 0; i < size; ++i) { c[i] = Op::template apply(da, DecimalV2Value(b[i])).value(); @@ -314,33 +318,43 @@ struct DecimalBinaryOperation { for (size_t i = 0; i < size; ++i) { c[i] = apply_scaled_div(a, b[i], null_map[i]); } - return; - } - - for (size_t i = 0; i < size; ++i) { - c[i] = apply(a, b[i], null_map[i]); + } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && + (IsDecimalNumber<B> || IsDecimalNumber<A>)) { + for (size_t i = 0; i < size; ++i) { + null_map[i] = apply_op_safely(a, b[i], c[i].value); + } + } else { + for (size_t i = 0; i < size; ++i) { + c[i] = apply(a, b[i], null_map[i]); + } } } - static ResultType constant_constant(A a, B b) { - if constexpr (OpTraits::is_division && IsDecimalNumber<B>) { - return apply_scaled_div(a, b); - } - return apply(a, b); - } + static ResultType constant_constant(A a, B b) { return apply(a, b); } static ResultType constant_constant(A a, B b, UInt8& is_null) { if constexpr (OpTraits::is_division && IsDecimalNumber<B>) { return apply_scaled_div(a, b, is_null); + } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) && + (IsDecimalNumber<B> || IsDecimalNumber<A>)) { + NativeResultType res; + is_null = apply_op_safely(a, b, res); + return res; + } else { + return apply(a, b, is_null); } - return apply(a, b, is_null); } static ColumnPtr adapt_decimal_constant_constant(A a, B b, DataTypePtr res_data_type) { auto column_result = ColumnDecimal<ResultType>::create( 1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale()); - if constexpr (is_to_null_type) { + if constexpr (return_nullable_type && !is_to_null_type && + ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> || + IsDecimalV2<B>)) { + LOG(FATAL) << "Invalid function type!"; + return column_result; + } else if constexpr (return_nullable_type || is_to_null_type) { auto null_map = ColumnUInt8::create(1, 0); column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0)); return ColumnNullable::create(std::move(column_result), std::move(null_map)); @@ -358,7 +372,12 @@ struct DecimalBinaryOperation { assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale()); DCHECK(column_left_ptr != nullptr); - if constexpr (is_to_null_type) { + if constexpr (return_nullable_type && !is_to_null_type && + ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> || + IsDecimalV2<B>)) { + LOG(FATAL) << "Invalid function type!"; + return column_result; + } else if constexpr (return_nullable_type || is_to_null_type) { auto null_map = ColumnUInt8::create(column_left->size(), 0); vector_constant(column_left_ptr->get_data(), b, column_result->get_data(), null_map->get_data()); @@ -377,7 +396,12 @@ struct DecimalBinaryOperation { assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale()); DCHECK(column_right_ptr != nullptr); - if constexpr (is_to_null_type) { + if constexpr (return_nullable_type && !is_to_null_type && + ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> || + IsDecimalV2<B>)) { + LOG(FATAL) << "Invalid function type!"; + return column_result; + } else if constexpr (return_nullable_type || is_to_null_type) { auto null_map = ColumnUInt8::create(column_right->size(), 0); constant_vector(a, column_right_ptr->get_data(), column_result->get_data(), null_map->get_data()); @@ -398,7 +422,12 @@ struct DecimalBinaryOperation { assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale()); DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr); - if constexpr (is_to_null_type) { + if constexpr (return_nullable_type && !is_to_null_type && + ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> || + IsDecimalV2<B>)) { + LOG(FATAL) << "Invalid function type!"; + return column_result; + } else if constexpr (return_nullable_type || is_to_null_type) { auto null_map = ColumnUInt8::create(column_result->size(), 0); vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(), column_result->get_data(), null_map->get_data()); @@ -483,6 +512,12 @@ private: UInt8& is_null) { return apply(a, b, is_null); } + + static UInt8 apply_op_safely(NativeResultType a, NativeResultType b, NativeResultType& c) { + if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) { + return Op::template apply(a, b, c); + } + } }; /// Used to indicate undefined operation @@ -568,7 +603,8 @@ struct BinaryOperationTraits { }; template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType, - template <typename, typename> class Operation, bool is_to_null_type> + template <typename, typename> class Operation, bool is_to_null_type, + bool return_nullable_type> struct ConstOrVectorAdapter { static constexpr bool result_is_decimal = IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>; @@ -580,7 +616,8 @@ struct ConstOrVectorAdapter { using OperationImpl = std::conditional_t< IsDataTypeDecimal<ResultDataType>, - DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type>, + DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type, + return_nullable_type>, BinaryOperationImpl<A, B, Operation<A, B>, is_to_null_type, ResultType>>; static ColumnPtr execute(ColumnPtr column_left, ColumnPtr column_right, @@ -774,6 +811,7 @@ public: right_generic = static_cast<const DataTypeNullable*>(right_generic)->get_nested_type().get(); } + bool result_is_nullable = context->impl()->check_overflow_for_decimal(); if (result_generic->is_nullable()) { result_generic = static_cast<const DataTypeNullable*>(result_generic)->get_nested_type().get(); @@ -795,15 +833,31 @@ public: ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> == (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>))) { - auto column_result = ConstOrVectorAdapter< - LeftDataType, RightDataType, - std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>, - ExpectedResultDataType, ResultDataType>, - Operation, is_to_null_type>:: - execute(block.get_by_position(arguments[0]).column, - block.get_by_position(arguments[1]).column, left, right, - remove_nullable(block.get_by_position(result).type)); - block.replace_by_position(result, std::move(column_result)); + if (result_is_nullable) { + auto column_result = ConstOrVectorAdapter< + LeftDataType, RightDataType, + std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>, + ExpectedResultDataType, ResultDataType>, + Operation, is_to_null_type, + true>::execute(block.get_by_position(arguments[0]).column, + block.get_by_position(arguments[1]).column, left, + right, + remove_nullable( + block.get_by_position(result).type)); + block.replace_by_position(result, std::move(column_result)); + } else { + auto column_result = ConstOrVectorAdapter< + LeftDataType, RightDataType, + std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>, + ExpectedResultDataType, ResultDataType>, + Operation, is_to_null_type, + false>::execute(block.get_by_position(arguments[0]).column, + block.get_by_position(arguments[1]).column, + left, right, + remove_nullable( + block.get_by_position(result).type)); + block.replace_by_position(result, std::move(column_result)); + } return true; } return false; diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h index 90edd3906b..e3baaecdd2 100644 --- a/be/src/vec/functions/function_cast.h +++ b/be/src/vec/functions/function_cast.h @@ -22,6 +22,7 @@ #include <fmt/format.h> +#include "udf/udf_internal.h" #include "vec/columns/column_array.h" #include "vec/columns/column_const.h" #include "vec/columns/column_nullable.h" @@ -72,7 +73,7 @@ struct ConvertImpl { template <typename Additions = void*> static Status execute(Block& block, const ColumnNumbers& arguments, size_t result, - size_t /*input_rows_count*/, + size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false, Additions additions [[maybe_unused]] = Additions()) { const ColumnWithTypeAndName& named_from = block.get_by_position(arguments[0]); @@ -96,37 +97,50 @@ struct ConvertImpl { if constexpr (IsDataTypeDecimal<ToDataType>) { UInt32 scale = additions; col_to = ColVecTo::create(0, scale); - } else + } else { col_to = ColVecTo::create(); + } const auto& vec_from = col_from->get_data(); auto& vec_to = col_to->get_data(); size_t size = vec_from.size(); vec_to.resize(size); - for (size_t i = 0; i < size; ++i) { - if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) { - if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>) + if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) { + ColumnUInt8::MutablePtr col_null_map_to = nullptr; + UInt8* vec_null_map_to = nullptr; + if (check_overflow) { + col_null_map_to = ColumnUInt8::create(size, 0); + vec_null_map_to = col_null_map_to->get_data().data(); + } + for (size_t i = 0; i < size; ++i) { + if constexpr (IsDataTypeDecimal<FromDataType> && + IsDataTypeDecimal<ToDataType>) { vec_to[i] = convert_decimals<FromDataType, ToDataType>( - vec_from[i], vec_from.get_scale(), vec_to.get_scale()); - else if constexpr (IsDataTypeDecimal<FromDataType> && - IsDataTypeNumber<ToDataType>) + vec_from[i], vec_from.get_scale(), vec_to.get_scale(), + vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to); + } else if constexpr (IsDataTypeDecimal<FromDataType> && + IsDataTypeNumber<ToDataType>) { vec_to[i] = convert_from_decimal<FromDataType, ToDataType>( vec_from[i], vec_from.get_scale()); - else if constexpr (IsDataTypeNumber<FromDataType> && - IsDataTypeDecimal<ToDataType>) + } else if constexpr (IsDataTypeNumber<FromDataType> && + IsDataTypeDecimal<ToDataType>) { vec_to[i] = convert_to_decimal<FromDataType, ToDataType>( - vec_from[i], vec_to.get_scale()); - else if constexpr (IsTimeType<FromDataType> && IsDataTypeDecimal<ToDataType>) { + vec_from[i], vec_to.get_scale(), + vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to); + } else if constexpr (IsTimeType<FromDataType> && + IsDataTypeDecimal<ToDataType>) { vec_to[i] = convert_to_decimal<DataTypeInt64, ToDataType>( reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64(), - vec_to.get_scale()); + vec_to.get_scale(), + vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to); } else if constexpr (IsDateV2Type<FromDataType> && IsDataTypeDecimal<ToDataType>) { vec_to[i] = convert_to_decimal<DataTypeUInt32, ToDataType>( reinterpret_cast<const DateV2Value<DateV2ValueType>&>(vec_from[i]) .to_date_int_val(), - vec_to.get_scale()); + vec_to.get_scale(), + vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to); } else if constexpr (IsDateTimeV2Type<FromDataType> && IsDataTypeDecimal<ToDataType>) { // TODO: should we consider the scale of datetimev2? @@ -134,9 +148,21 @@ struct ConvertImpl { reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>( vec_from[i]) .to_date_int_val(), - vec_to.get_scale()); + vec_to.get_scale(), + vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to); } - } else if constexpr (IsTimeType<FromDataType>) { + } + if (check_overflow) { + block.replace_by_position( + result, + ColumnNullable::create(std::move(col_to), std::move(col_null_map_to))); + } else { + block.replace_by_position(result, std::move(col_to)); + } + + return Status::OK(); + } else if constexpr (IsTimeType<FromDataType>) { + for (size_t i = 0; i < size; ++i) { if constexpr (IsTimeType<ToDataType>) { vec_to[i] = static_cast<ToFieldType>(vec_from[i]); if constexpr (IsDateTimeType<ToDataType>) { @@ -152,7 +178,9 @@ struct ConvertImpl { vec_to[i] = reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64(); } - } else if constexpr (IsTimeV2Type<FromDataType>) { + } + } else if constexpr (IsTimeV2Type<FromDataType>) { + for (size_t i = 0; i < size; ++i) { if constexpr (IsTimeV2Type<ToDataType>) { if constexpr (IsDateTimeV2Type<ToDataType> && IsDateV2Type<FromDataType>) { DataTypeDateV2::cast_to_date_time_v2(vec_from[i], vec_to[i]); @@ -189,7 +217,9 @@ struct ConvertImpl { .to_int64(); } } - } else { + } + } else { + for (size_t i = 0; i < size; ++i) { vec_to[i] = static_cast<ToFieldType>(vec_from[i]); } } @@ -547,7 +577,7 @@ struct ConvertImpl<DataTypeString, ToDataType, Name> { template <typename Additions = void*> static Status execute(Block& block, const ColumnNumbers& arguments, size_t result, - size_t /*input_rows_count*/, + size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false, Additions additions [[maybe_unused]] = Additions()) { return Status::RuntimeError("not support convert from string"); } @@ -832,19 +862,6 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) override { - return executeInternal(block, arguments, result, input_rows_count); - } - - bool has_information_about_monotonicity() const override { return Monotonic::has(); } - - Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left, - const Field& right) const override { - return Monotonic::get(type, left, right); - } - -private: - Status executeInternal(Block& block, const ColumnNumbers& arguments, size_t result, - size_t input_rows_count) { if (!arguments.size()) { return Status::RuntimeError("Function {} expects at least 1 arguments", get_name()); } @@ -873,13 +890,15 @@ private: UInt32 scale = extract_to_decimal_scale(scale_column); ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute( - block, arguments, result, input_rows_count, scale); + block, arguments, result, input_rows_count, + context->impl()->check_overflow_for_decimal(), scale); } else if constexpr (IsDataTypeDateTimeV2<RightDataType>) { const ColumnWithTypeAndName& scale_column = block.get_by_position(result); auto type = check_and_get_data_type<DataTypeDateTimeV2>(scale_column.type.get()); ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute( - block, arguments, result, input_rows_count, type->get_scale()); + block, arguments, result, input_rows_count, + context->impl()->check_overflow_for_decimal(), type->get_scale()); } else { ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute( block, arguments, result, input_rows_count); @@ -896,6 +915,13 @@ private: return ret_status; } } + + bool has_information_about_monotonicity() const override { return Monotonic::has(); } + + Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left, + const Field& right) const override { + return Monotonic::get(type, left, right); + } }; using FunctionToUInt8 = FunctionConvert<DataTypeUInt8, NameToUInt8, ToNumberMonotonicity<UInt8>>; @@ -1055,7 +1081,7 @@ struct ConvertThroughParsing { template <typename Additions = void*> static Status execute(Block& block, const ColumnNumbers& arguments, size_t result, - size_t input_rows_count, + size_t input_rows_count, bool check_overflow [[maybe_unused]] = false, Additions additions [[maybe_unused]] = Additions()) { using ColVecTo = std::conditional_t<IsDecimalNumber<ToFieldType>, ColumnDecimal<ToFieldType>, ColumnVector<ToFieldType>>; @@ -1254,7 +1280,8 @@ public: const ColumnNumbers& /*arguments*/, size_t /*result*/) const override { return std::make_shared<PreparedFunctionCast>( - prepare_unpack_dictionaries(get_argument_types()[0], get_return_type()), name); + prepare_unpack_dictionaries(context, get_argument_types()[0], get_return_type()), + name); } String get_name() const override { return name; } @@ -1347,7 +1374,8 @@ private: using RightDataType = typename Types::RightType; ConvertImpl<LeftDataType, RightDataType, NameCast>::execute( - block, arguments, result, input_rows_count, scale); + block, arguments, result, input_rows_count, + context->impl()->check_overflow_for_decimal(), scale); return true; }); @@ -1396,7 +1424,7 @@ private: return create_unsupport_wrapper(error_msg); } - WrapperType create_array_wrapper(const DataTypePtr& from_type_untyped, + WrapperType create_array_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped, const DataTypeArray& to_type) const { /// Conversion from String through parsing. if (check_and_get_data_type<DataTypeString>(from_type_untyped.get())) { @@ -1425,7 +1453,8 @@ private: const DataTypePtr& to_nested_type = to_type.get_nested_type(); /// Prepare nested type conversion - const auto nested_function = prepare_unpack_dictionaries(from_nested_type, to_nested_type); + const auto nested_function = + prepare_unpack_dictionaries(context, from_nested_type, to_nested_type); return [nested_function, from_nested_type, to_nested_type]( FunctionContext* context, Block& block, const ColumnNumbers& arguments, @@ -1513,7 +1542,7 @@ private: } } - WrapperType prepare_unpack_dictionaries(const DataTypePtr& from_type, + WrapperType prepare_unpack_dictionaries(FunctionContext* context, const DataTypePtr& from_type, const DataTypePtr& to_type) const { const auto& from_nested = from_type; const auto& to_nested = to_type; @@ -1534,18 +1563,20 @@ private: constexpr bool skip_not_null_check = false; - auto wrapper = prepare_remove_nullable(from_nested, to_nested, skip_not_null_check); + auto wrapper = + prepare_remove_nullable(context, from_nested, to_nested, skip_not_null_check); return wrapper; } - WrapperType prepare_remove_nullable(const DataTypePtr& from_type, const DataTypePtr& to_type, + WrapperType prepare_remove_nullable(FunctionContext* context, const DataTypePtr& from_type, + const DataTypePtr& to_type, bool skip_not_null_check) const { /// Determine whether pre-processing and/or post-processing must take place during conversion. bool source_is_nullable = from_type->is_nullable(); bool result_is_nullable = to_type->is_nullable(); - auto wrapper = prepare_impl(remove_nullable(from_type), remove_nullable(to_type), + auto wrapper = prepare_impl(context, remove_nullable(from_type), remove_nullable(to_type), result_is_nullable); if (result_is_nullable) { @@ -1620,8 +1651,8 @@ private: /// 'from_type' and 'to_type' are nested types in case of Nullable. /// 'requested_result_is_nullable' is true if CAST to Nullable type is requested. - WrapperType prepare_impl(const DataTypePtr& from_type, const DataTypePtr& to_type, - bool requested_result_is_nullable) const { + WrapperType prepare_impl(FunctionContext* context, const DataTypePtr& from_type, + const DataTypePtr& to_type, bool requested_result_is_nullable) const { if (from_type->equals(*to_type)) return create_identity_wrapper(from_type); else if (WhichDataType(from_type).is_nothing()) @@ -1679,7 +1710,8 @@ private: case TypeIndex::String: return create_string_wrapper(from_type); case TypeIndex::Array: - return create_array_wrapper(from_type, static_cast<const DataTypeArray&>(*to_type)); + return create_array_wrapper(context, from_type, + static_cast<const DataTypeArray&>(*to_type)); default: break; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java index e5502fe635..99b6e39e3b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java @@ -21,6 +21,7 @@ package org.apache.doris.analysis; import org.apache.doris.catalog.Function; +import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.catalog.PrimitiveType; import org.apache.doris.catalog.ScalarFunction; @@ -107,12 +108,13 @@ public class ArithmeticExpr extends Expr { public static void initBuiltins(FunctionSet functionSet) { for (Type t : Type.getNumericTypes()) { + NullableMode mode = t.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT; functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( - Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t)); + Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t, mode)); functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( - Operator.ADD.getName(), Lists.newArrayList(t, t), t)); + Operator.ADD.getName(), Lists.newArrayList(t, t), t, mode)); functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( - Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t)); + Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t, mode)); } functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( Operator.DIVIDE.getName(), @@ -173,15 +175,14 @@ public class ArithmeticExpr extends Expr { for (int j = 0; j < Type.getNumericTypes().size(); j++) { Type t2 = Type.getNumericTypes().get(j); + Type retType = Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false)); + NullableMode mode = retType.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT; functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( - Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2), - Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false)))); + Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2), retType, mode)); functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( - Operator.ADD.getName(), Lists.newArrayList(t1, t2), - Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false)))); + Operator.ADD.getName(), Lists.newArrayList(t1, t2), retType, mode)); functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator( - Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2), - Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false)))); + Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2), retType, mode)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index 79df8f3395..9c29c7ffee 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -20,6 +20,7 @@ package org.apache.doris.analysis; +import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.Function; import org.apache.doris.catalog.FunctionSet; @@ -31,6 +32,7 @@ import org.apache.doris.common.Config; import org.apache.doris.common.TreeNode; import org.apache.doris.common.io.Writable; import org.apache.doris.common.util.VectorizedUtil; +import org.apache.doris.qe.ConnectContext; import org.apache.doris.statistics.ExprStats; import org.apache.doris.thrift.TExpr; import org.apache.doris.thrift.TExprNode; @@ -2036,6 +2038,25 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl if (fn.functionName().equalsIgnoreCase("concat_ws")) { return children.get(0).isNullable(); } + if (fn.functionName().equalsIgnoreCase(Operator.MULTIPLY.getName()) + && fn.getReturnType().isDecimalV3()) { + if (ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) { + return true; + } else { + return hasNullableChild(); + } + } + if ((fn.functionName().equalsIgnoreCase(Operator.ADD.getName()) + || fn.functionName().equalsIgnoreCase(Operator.SUBTRACT.getName())) + && fn.getReturnType().isDecimalV3()) { + if (ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) { + return true; + } else { + return hasNullableChild(); + } + } return true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 7faa492295..dcea629834 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -189,6 +189,8 @@ public class SessionVariable implements Serializable, Writable { public static final String ENABLE_PROJECTION = "enable_projection"; + public static final String CHECK_OVERFLOW_FOR_DECIMAL = "check_overflow_for_decimal"; + public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY = "trim_tailing_spaces_for_external_table_query"; @@ -542,6 +544,9 @@ public class SessionVariable implements Serializable, Writable { @VariableMgr.VarAttr(name = ENABLE_PROJECTION) private boolean enableProjection = true; + @VariableMgr.VarAttr(name = CHECK_OVERFLOW_FOR_DECIMAL) + private boolean checkOverflowForDecimal = false; + /** * as the new optimizer is not mature yet, use this var * to control whether to use new optimizer, remove it when @@ -1235,6 +1240,10 @@ public class SessionVariable implements Serializable, Writable { return enableProjection; } + public boolean checkOverflowForDecimal() { + return checkOverflowForDecimal; + } + public boolean isTrimTailingSpacesForExternalTableQuery() { return trimTailingSpacesForExternalTableQuery; } @@ -1368,6 +1377,7 @@ public class SessionVariable implements Serializable, Writable { } tResult.setEnableFunctionPushdown(enableFunctionPushdown); + tResult.setCheckOverflowForDecimal(checkOverflowForDecimal); tResult.setFragmentTransmissionCompressionCodec(fragmentTransmissionCompressionCodec); tResult.setEnableLocalExchange(enableLocalExchange); tResult.setEnableNewShuffleHashMethod(enableNewShuffleHashMethod); diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift index 32d6721bb3..bb241c4aa0 100644 --- a/gensrc/thrift/PaloInternalService.thrift +++ b/gensrc/thrift/PaloInternalService.thrift @@ -187,6 +187,7 @@ struct TQueryOptions { 55: optional bool enable_pipeline_engine = false 56: optional i32 repeat_max_num = 0 + 57: optional bool check_overflow_for_decimal = false } diff --git a/regression-test/data/datatype_p0/decimalv3/test_overflow.out b/regression-test/data/datatype_p0/decimalv3/test_overflow.out new file mode 100644 index 0000000000..c9b9873cd7 --- /dev/null +++ b/regression-test/data/datatype_p0/decimalv3/test_overflow.out @@ -0,0 +1,19 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_all -- +11111111111111111111.100000000000000000 11111111111111111111.200000000000000000 11111111111111111111.300000000000000000 1.1000000000000000000000000000000000000 1.2000000000000000000000000000000000000 1.3000000000000000000000000000000000000 9 + +-- !select_check_overflow1 -- +\N \N \N 99999999999999999999.900000000000000000 \N + +-- !select_check_overflow2 -- +1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 \N + +-- !select_check_overflow3 -- +11111111111111111111.100000000000000000 \N + +-- !select_not_check_overflow1 -- +99.999999999999999999999999999999999999 99.999999999999999999999999999999999999 1.1111111111111111E21 99999999999999999999.900000000000000000 99999999999999999999.999999999999999999 + +-- !select_not_check_overflow2 -- +1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 -15.9141183460469231731687303715884105728 + diff --git a/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy new file mode 100644 index 0000000000..01de2ea498 --- /dev/null +++ b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_overflow") { + + def table1 = "test_overflow" + + sql "drop table if exists ${table1}" + + sql """ + CREATE TABLE IF NOT EXISTS test_overflow ( + `k1` decimalv3(38, 18) NULL COMMENT "", + `k2` decimalv3(38, 18) NULL COMMENT "", + `k3` decimalv3(38, 18) NULL COMMENT "", + `v1` decimalv3(38, 37) NULL COMMENT "", + `v2` decimalv3(38, 37) NULL COMMENT "", + `v3` decimalv3(38, 37) NULL COMMENT "", + `v4` INT NULL COMMENT "" + ) ENGINE=OLAP + COMMENT "OLAP" + DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "in_memory" = "false", + "storage_format" = "V2" + ) + """ + + sql """insert into test_overflow values(11111111111111111111.1,11111111111111111111.2,11111111111111111111.3, 1.1,1.2,1.3,9) + """ + qt_select_all "select * from test_overflow order by k1" + + sql " SET check_overflow_for_decimal = true; " + qt_select_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;" + qt_select_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow" + qt_select_check_overflow3 "select `k1`, cast (`k1` as DECIMALV3(38, 36)) from test_overflow;" + + sql " SET check_overflow_for_decimal = false; " + qt_select_not_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;" + qt_select_not_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow" + sql "drop table if exists ${table1}" +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org