This is an automated email from the ASF dual-hosted git repository. zhangstar333 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new a36e088e781 [enhancement](function truncate) truncate can use column as scale argument (#32746) a36e088e781 is described below commit a36e088e781f2fd47f4c3672b314f82a23d6e16f Author: zhiqiang <seuhezhiqi...@163.com> AuthorDate: Tue Apr 2 14:56:26 2024 +0800 [enhancement](function truncate) truncate can use column as scale argument (#32746) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- be/src/vec/functions/function_truncate.h | 245 ++++++++++++++ be/src/vec/functions/math.cpp | 23 +- be/src/vec/functions/round.h | 224 ++++++++++++- .../function/function_truncate_decimal_test.cpp | 370 +++++++++++++++++++++ .../apache/doris/analysis/FunctionCallExpr.java | 32 +- .../functions/ComputePrecisionForRound.java | 40 ++- .../math_functions/test_function_truncate.out | 101 ++++++ .../math_functions/test_function_truncate.groovy | 132 ++++++++ 8 files changed, 1136 insertions(+), 31 deletions(-) diff --git a/be/src/vec/functions/function_truncate.h b/be/src/vec/functions/function_truncate.h new file mode 100644 index 00000000000..e29bc99c041 --- /dev/null +++ b/be/src/vec/functions/function_truncate.h @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cstddef> +#include <functional> +#include <type_traits> +#include <utility> + +#include "common/exception.h" +#include "common/status.h" +#include "olap/olap_common.h" +#include "round.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/call_on_type_index.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" + +namespace doris::vectorized { + +struct TruncateFloatOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeFloat64>()}; } +}; + +struct TruncateFloatTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeFloat64>(), std::make_shared<DataTypeInt32>()}; + } +}; + +struct TruncateDecimalOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function + // So we can just register Decimal32 here + return {std::make_shared<DataTypeDecimal<Decimal32>>(9, 0)}; + } +}; + +struct TruncateDecimalTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeDecimal<Decimal32>>(9, 0), + std::make_shared<DataTypeInt32>()}; + } +}; + +template <typename Impl> +class FunctionTruncate : public FunctionRounding<Impl, RoundingMode::Trunc, TieBreakingMode::Auto> { +public: + static FunctionPtr create() { return std::make_shared<FunctionTruncate>(); } + + ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } + // SELECT number, truncate(123.345, 1) FROM number("numbers"="10") + // should NOT behave like two column arguments, so we can not use const column default implementation + bool use_default_implementation_for_constants() const override { return false; } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + ColumnPtr res; + + // potential argument types: + // 0. truncate(ColumnConst, ColumnConst) + // 1. truncate(Column), truncate(Column, ColumnConst) + // 2. truncate(Column, Column) + // 3. truncate(ColumnConst, Column) + + if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && + is_column_const(*block.get_by_position(arguments[1]).column)) { + // truncate(ColumnConst, ColumnConst) + auto col_general = + assert_cast<const ColumnConst&>(*column_general.column).get_data_column_ptr(); + Int16 scale_arg = 0; + RETURN_IF_ERROR(FunctionTruncate<Impl>::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t<decltype(types)>; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { + using FieldType = typename DataType::FieldType; + res = Dispatcher<FieldType, RoundingMode::Trunc, + TieBreakingMode::Auto>::apply_vec_const(col_general, + scale_arg); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + // Important, make sure the result column has the same size as the input column + res = ColumnConst::create(std::move(res), input_rows_count); + } else if (arguments.size() == 1 || + (arguments.size() == 2 && + is_column_const(*block.get_by_position(arguments[1]).column))) { + // truncate(Column) or truncate(Column, ColumnConst) + Int16 scale_arg = 0; + if (arguments.size() == 2) { + RETURN_IF_ERROR(FunctionTruncate<Impl>::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + } + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t<decltype(types)>; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { + using FieldType = typename DataType::FieldType; + res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>:: + apply_vec_const(column_general.column.get(), scale_arg); + return true; + } + + return false; + }; +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + + } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { + // truncate(ColumnConst, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + const ColumnConst& const_col_general = + assert_cast<const ColumnConst&>(*column_general.column); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t<decltype(types)>; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { + using FieldType = typename DataType::FieldType; + res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>:: + apply_const_vec(&const_col_general, column_scale.column.get()); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + } else { + // truncate(Column, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t<decltype(types)>; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { + using FieldType = typename DataType::FieldType; + res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>:: + apply_vec_vec(column_general.column.get(), column_scale.column.get()); + return true; + } + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "truncate"); + } + } + + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index dc815cf74e5..c0dfe761576 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -46,6 +46,7 @@ #include "vec/functions/function_math_unary.h" #include "vec/functions/function_string.h" #include "vec/functions/function_totype.h" +#include "vec/functions/function_truncate.h" #include "vec/functions/function_unary_arithmetic.h" #include "vec/functions/round.h" #include "vec/functions/simple_function_factory.h" @@ -392,16 +393,14 @@ struct DecimalRoundOneImpl { // TODO: Now math may cause one thread compile time too long, because the function in math // so mush. Split it to speed up compile time in the future void register_function_math(SimpleFunctionFactory& factory) { -#define REGISTER_ROUND_FUNCTIONS(IMPL) \ - factory.register_function< \ - FunctionRounding<IMPL<RoundName>, RoundingMode::Round, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding<IMPL<FloorName>, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding<IMPL<TruncateName>, RoundingMode::Trunc, TieBreakingMode::Auto>>(); \ - factory.register_function<FunctionRounding<IMPL<RoundBankersName>, RoundingMode::Round, \ +#define REGISTER_ROUND_FUNCTIONS(IMPL) \ + factory.register_function< \ + FunctionRounding<IMPL<RoundName>, RoundingMode::Round, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding<IMPL<FloorName>, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ + factory.register_function<FunctionRounding<IMPL<RoundBankersName>, RoundingMode::Round, \ TieBreakingMode::Bankers>>(); REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl) @@ -445,5 +444,9 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function<FunctionRadians>(); factory.register_function<FunctionDegrees>(); factory.register_function<FunctionBin>(); + factory.register_function<FunctionTruncate<TruncateFloatOneArgImpl>>(); + factory.register_function<FunctionTruncate<TruncateFloatTwoArgImpl>>(); + factory.register_function<FunctionTruncate<TruncateDecimalOneArgImpl>>(); + factory.register_function<FunctionTruncate<TruncateDecimalTwoArgImpl>>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 7e48b8e9306..a9d1e7a019c 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -20,8 +20,15 @@ #pragma once +#include <cstddef> +#include <cstdint> + +#include "common/exception.h" +#include "common/status.h" #include "vec/columns/column_const.h" #include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" #include "vec/functions/function.h" #if defined(__SSE4_1__) || defined(__aarch64__) #include "util/sse_util.hpp" @@ -176,6 +183,23 @@ public: memcpy(out.data(), in.data(), in.size() * sizeof(T)); } } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! + static NO_INLINE void apply(const NativeType& in, UInt32 in_scale, NativeType& out, + Int16 out_scale) { + Int16 scale_arg = in_scale - out_scale; + if (scale_arg > 0) { + size_t scale = int_exp10(scale_arg); + if (out_scale < 0) { + Op::compute(&in, scale, &out, int_exp10(-out_scale)); + } else { + Op::compute(&in, scale, &out, 1); + } + } else { + memcpy(&out, &in, sizeof(NativeType)); + } + } }; template <TieBreakingMode tie_breaking_mode> @@ -314,6 +338,11 @@ public: memcpy(p_out, &tmp_dst, tail_size_bytes); } } + + static NO_INLINE void apply(const T& in, size_t scale, T& out) { + auto mm_scale = Op::prepare(scale); + Op::compute(&in, mm_scale, &out); + } }; template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, @@ -386,6 +415,10 @@ public: __builtin_unreachable(); } } + + static NO_INLINE void apply(const T& in, size_t scale, T& out) { + Op::compute(&in, scale, &out, 1); + } }; /** Select the appropriate processing algorithm depending on the scale. @@ -400,7 +433,7 @@ struct Dispatcher { FloatRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>, IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>; - static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) { + static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 scale_arg) { if constexpr (IsNumber<T>) { const auto* const col = check_and_get_column<ColumnVector<T>>(col_general); auto col_res = ColumnVector<T>::create(); @@ -446,6 +479,179 @@ struct Dispatcher { return nullptr; } } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! + static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) { + if constexpr (rounding_mode != RoundingMode::Trunc) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Using column as scale is only supported for function truncate"); + } + + const ColumnInt32& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale); + const size_t input_row_count = col_scale_i32.size(); + for (size_t i = 0; i < input_row_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg > std::numeric_limits<Int16>::max() || + scale_arg < std::numeric_limits<Int16>::min()) { + throw doris::Exception(ErrorCode::OUT_OF_BOUND, + "Scale argument for function is out of bound: {}", + scale_arg); + } + } + + if constexpr (IsNumber<T>) { + const auto* col = assert_cast<const ColumnVector<T>*>(col_general); + auto col_res = ColumnVector<T>::create(); + typename ColumnVector<T>::Container& vec_res = col_res->get_data(); + vec_res.resize(input_row_count); + + for (size_t i = 0; i < input_row_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg == 0) { + size_t scale = 1; + FunctionRoundingImpl<ScaleMode::Zero>::apply(col->get_data()[i], scale, + vec_res[i]); + } else if (scale_arg > 0) { + size_t scale = int_exp10(scale_arg); + FunctionRoundingImpl<ScaleMode::Positive>::apply(col->get_data()[i], scale, + vec_res[i]); + } else { + size_t scale = int_exp10(-scale_arg); + FunctionRoundingImpl<ScaleMode::Negative>::apply(col->get_data()[i], scale, + vec_res[i]); + } + } + return col_res; + } else if constexpr (IsDecimalNumber<T>) { + const auto* decimal_col = assert_cast<const ColumnDecimal<T>*>(col_general); + + // For truncate, ALWAYS use SAME scale with source Decimal column + const Int32 input_scale = decimal_col->get_scale(); + auto col_res = ColumnDecimal<T>::create(input_row_count, input_scale); + + for (size_t i = 0; i < input_row_count; ++i) { + DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply( + decimal_col->get_element(i).value, input_scale, + col_res->get_element(i).value, col_scale_i32.get_data()[i]); + } + + for (size_t i = 0; i < input_row_count; ++i) { + // For truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // So we need this check to make sure the result have correct digits count + // + // Case 0: scale_arg <= -(integer part digits count) + // do nothing, because result is 0 + // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) + // decimal parts has been erased, so add them back by multiply 10^(scale_arg) + // Case 2: scale_arg > 0 && scale_arg < decimal part digits count + // decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg) + // Case 3: scale_arg >= input_scale + // do nothing + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg <= 0) { + col_res->get_element(i).value *= int_exp10(input_scale); + } else if (scale_arg > 0 && scale_arg < input_scale) { + col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + } + } + + return col_res; + } else { + LOG(FATAL) << "__builtin_unreachable"; + __builtin_unreachable(); + return nullptr; + } + } + + // NOTE: This function is only tested for truncate + // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! only test for truncate + static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, + const IColumn* col_scale) { + if constexpr (rounding_mode != RoundingMode::Trunc) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Using column as scale is only supported for function truncate"); + } + + const ColumnInt32& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale); + const size_t input_rows_count = col_scale->size(); + + for (size_t i = 0; i < input_rows_count; ++i) { + const Int32 scale_arg = col_scale_i32.get_data()[i]; + + if (scale_arg > std::numeric_limits<Int16>::max() || + scale_arg < std::numeric_limits<Int16>::min()) { + throw doris::Exception(ErrorCode::OUT_OF_BOUND, + "Scale argument for function is out of bound: {}", + scale_arg); + } + } + + if constexpr (IsDecimalNumber<T>) { + const ColumnDecimal<T>& data_col_general = + assert_cast<const ColumnDecimal<T>&>(const_col_general->get_data_column()); + const T& general_val = data_col_general.get_data()[0]; + Int32 input_scale = data_col_general.get_scale(); + + auto col_res = ColumnDecimal<T>::create(input_rows_count, input_scale); + + for (size_t i = 0; i < input_rows_count; ++i) { + DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply( + general_val, input_scale, col_res->get_element(i).value, + col_scale_i32.get_data()[i]); + } + + for (size_t i = 0; i < input_rows_count; ++i) { + // For truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // So we need this check to make sure the result have correct digits count + // + // Case 0: scale_arg <= -(integer part digits count) + // do nothing, because result is 0 + // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) + // decimal parts has been erased, so add them back by multiply 10^(scale_arg) + // Case 2: scale_arg > 0 && scale_arg < decimal part digits count + // decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg) + // Case 3: scale_arg >= input_scale + // do nothing + const Int32 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg <= 0) { + col_res->get_element(i).value *= int_exp10(input_scale); + } else if (scale_arg > 0 && scale_arg < input_scale) { + col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + } + } + + return col_res; + } else if constexpr (IsNumber<T>) { + const ColumnVector<T>& data_col_general = + assert_cast<const ColumnVector<T>&>(const_col_general->get_data_column()); + const T& general_val = data_col_general.get_data()[0]; + auto col_res = ColumnVector<T>::create(input_rows_count); + typename ColumnVector<T>::Container& vec_res = col_res->get_data(); + + for (size_t i = 0; i < input_rows_count; ++i) { + const Int16 scale_arg = col_scale_i32.get_data()[i]; + if (scale_arg == 0) { + size_t scale = 1; + FunctionRoundingImpl<ScaleMode::Zero>::apply(general_val, scale, vec_res[i]); + } else if (scale_arg > 0) { + size_t scale = int_exp10(col_scale_i32.get_data()[i]); + FunctionRoundingImpl<ScaleMode::Positive>::apply(general_val, scale, + vec_res[i]); + } else { + size_t scale = int_exp10(-col_scale_i32.get_data()[i]); + FunctionRoundingImpl<ScaleMode::Negative>::apply(general_val, scale, + vec_res[i]); + } + } + + return col_res; + } else { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Unsupported column {} for function truncate", + const_col_general->get_name()); + } + } }; template <typename Impl, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode> @@ -476,17 +682,17 @@ public: static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16* scale) { const IColumn& scale_column = *arguments.column; - Int32 scale64 = static_cast<const ColumnInt32&>( - static_cast<const ColumnConst*>(&scale_column)->get_data_column()) - .get_element(0); + Int32 scale_arg = assert_cast<const ColumnInt32&>( + assert_cast<const ColumnConst*>(&scale_column)->get_data_column()) + .get_element(0); - if (scale64 > std::numeric_limits<Int16>::max() || - scale64 < std::numeric_limits<Int16>::min()) { + if (scale_arg > std::numeric_limits<Int16>::max() || + scale_arg < std::numeric_limits<Int16>::min()) { return Status::InvalidArgument("Scale argument for function {} is out of bound: {}", - name, scale64); + name, scale_arg); } - *scale = scale64; + *scale = scale_arg; return Status::OK(); } @@ -507,7 +713,7 @@ public: if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { using FieldType = typename DataType::FieldType; - res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply( + res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply_vec_const( column.column.get(), scale_arg); return true; } diff --git a/be/test/vec/function/function_truncate_decimal_test.cpp b/be/test/vec/function/function_truncate_decimal_test.cpp new file mode 100644 index 00000000000..36fcaa14e67 --- /dev/null +++ b/be/test/vec/function/function_truncate_decimal_test.cpp @@ -0,0 +1,370 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest-message.h> +#include <gtest/gtest.h> + +#include <climits> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <iomanip> +#include <limits> +#include <map> +#include <memory> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "function_test_util.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/column_numbers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/function_truncate.h" + +namespace doris::vectorized { +// {precision, scale} -> {input, scale_arg, expectation} +using TestDataSet = std::map<std::pair<int, int>, std::vector<std::tuple<Int128, int, Int128>>>; + +const static TestDataSet truncate_decimal32_cases = { + {{1, 0}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{1, 1}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{2, 0}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 10}, + {12, 0, 12}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 1}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 10}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 2}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 0}, + {12, 1, 10}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{9, 0}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 100000000}, + {123456789, -7, 120000000}, {123456789, -6, 123000000}, {123456789, -5, 123400000}, + {123456789, -4, 123450000}, {123456789, -3, 123456000}, {123456789, -2, 123456700}, + {123456789, -1, 123456780}, {123456789, 0, 123456789}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 1}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 100000000}, {123456789, -6, 120000000}, {123456789, -5, 123000000}, + {123456789, -4, 123400000}, {123456789, -3, 123450000}, {123456789, -2, 123456000}, + {123456789, -1, 123456700}, {123456789, 0, 123456780}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 2}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 100000000}, {123456789, -5, 120000000}, + {123456789, -4, 123000000}, {123456789, -3, 123400000}, {123456789, -2, 123450000}, + {123456789, -1, 123456000}, {123456789, 0, 123456700}, {123456789, 1, 123456780}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 3}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 100000000}, + {123456789, -4, 120000000}, {123456789, -3, 123000000}, {123456789, -2, 123400000}, + {123456789, -1, 123450000}, {123456789, 0, 123456000}, {123456789, 1, 123456700}, + {123456789, 2, 123456780}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 4}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 100000000}, {123456789, -3, 120000000}, {123456789, -2, 123000000}, + {123456789, -1, 123400000}, {123456789, 0, 123450000}, {123456789, 1, 123456000}, + {123456789, 2, 123456700}, {123456789, 3, 123456780}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 5}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 100000000}, {123456789, -2, 120000000}, + {123456789, -1, 123000000}, {123456789, 0, 123400000}, {123456789, 1, 123450000}, + {123456789, 2, 123456000}, {123456789, 3, 123456700}, {123456789, 4, 123456780}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 6}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 100000000}, + {123456789, -1, 120000000}, {123456789, 0, 123000000}, {123456789, 1, 123400000}, + {123456789, 2, 123450000}, {123456789, 3, 123456000}, {123456789, 4, 123456700}, + {123456789, 5, 123456780}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 7}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 100000000}, {123456789, 0, 120000000}, {123456789, 1, 123000000}, + {123456789, 2, 123400000}, {123456789, 3, 123450000}, {123456789, 4, 123456000}, + {123456789, 5, 123456700}, {123456789, 6, 123456780}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 8}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 100000000}, {123456789, 1, 120000000}, + {123456789, 2, 123000000}, {123456789, 3, 123400000}, {123456789, 4, 123450000}, + {123456789, 5, 123456000}, {123456789, 6, 123456700}, {123456789, 7, 123456780}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 9}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 0}, {123456789, 1, 100000000}, + {123456789, 2, 120000000}, {123456789, 3, 123000000}, {123456789, 4, 123400000}, + {123456789, 5, 123450000}, {123456789, 6, 123456000}, {123456789, 7, 123456700}, + {123456789, 8, 123456780}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}}; + +const static TestDataSet truncate_decimal64_cases = { + {{10, 0}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 1000000000}, + {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, {1234567891, -6, 1234000000}, + {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000}, {1234567891, -3, 1234567000}, + {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890}, {1234567891, 0, 1234567891}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 1}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, {1234567891, -6, 1230000000}, + {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000}, {1234567891, -3, 1234560000}, + {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800}, {1234567891, 0, 1234567890}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891} + + }}, + {{10, 2}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 1000000000}, {1234567891, -6, 1200000000}, + {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000}, {1234567891, -3, 1234500000}, + {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000}, {1234567891, 0, 1234567800}, + {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 9}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 0}, {1234567891, -6, 0}, + {1234567891, -5, 0}, {1234567891, -4, 0}, {1234567891, -3, 0}, + {1234567891, -2, 0}, {1234567891, -1, 0}, {1234567891, 0, 1000000000}, + {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000}, {1234567891, 3, 1234000000}, + {1234567891, 4, 1234500000}, {1234567891, 5, 1234560000}, {1234567891, 6, 1234567000}, + {1234567891, 7, 1234567800}, {1234567891, 8, 1234567890}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{18, 0}, + {{123456789123456789, -19, 0}, + {123456789123456789, -18, 0}, + {123456789123456789, -17, 100000000000000000}, + {123456789123456789, -16, 120000000000000000}, + {123456789123456789, -15, 123000000000000000}, + {123456789123456789, -14, 123400000000000000}, + {123456789123456789, -13, 123450000000000000}, + {123456789123456789, -12, 123456000000000000}, + {123456789123456789, -11, 123456700000000000}, + {123456789123456789, -10, 123456780000000000}, + {123456789123456789, -9, 123456789000000000}, + {123456789123456789, -8, 123456789100000000}, + {123456789123456789, -7, 123456789120000000}, + {123456789123456789, -6, 123456789123000000}, + {123456789123456789, -5, 123456789123400000}, + {123456789123456789, -4, 123456789123450000}, + {123456789123456789, -3, 123456789123456000}, + {123456789123456789, -2, 123456789123456700}, + {123456789123456789, -1, 123456789123456780}, + {123456789123456789, 0, 123456789123456789}, + {123456789123456789, 1, 123456789123456789}, + {123456789123456789, 2, 123456789123456789}, + {123456789123456789, 3, 123456789123456789}, + {123456789123456789, 4, 123456789123456789}, + {123456789123456789, 5, 123456789123456789}, + {123456789123456789, 6, 123456789123456789}, + {123456789123456789, 7, 123456789123456789}, + {123456789123456789, 8, 123456789123456789}, + {123456789123456789, 18, 123456789123456789}}}, + {{18, 18}, + {{123456789123456789, -1, 0}, + {123456789123456789, 0, 0}, + {123456789123456789, 1, 100000000000000000}, + {123456789123456789, 2, 120000000000000000}, + {123456789123456789, 3, 123000000000000000}, + {123456789123456789, 4, 123400000000000000}, + {123456789123456789, 5, 123450000000000000}, + {123456789123456789, 6, 123456000000000000}, + {123456789123456789, 7, 123456700000000000}, + {123456789123456789, 8, 123456780000000000}, + {123456789123456789, 9, 123456789000000000}, + {123456789123456789, 10, 123456789100000000}, + {123456789123456789, 11, 123456789120000000}, + {123456789123456789, 12, 123456789123000000}, + {123456789123456789, 13, 123456789123400000}, + {123456789123456789, 14, 123456789123450000}, + {123456789123456789, 15, 123456789123456000}, + {123456789123456789, 16, 123456789123456700}, + {123456789123456789, 17, 123456789123456780}, + {123456789123456789, 18, 123456789123456789}, + {123456789123456789, 19, 123456789123456789}, + {123456789123456789, 20, 123456789123456789}, + {123456789123456789, 21, 123456789123456789}, + {123456789123456789, 22, 123456789123456789}, + {123456789123456789, 23, 123456789123456789}, + {123456789123456789, 24, 123456789123456789}, + {123456789123456789, 25, 123456789123456789}, + {123456789123456789, 26, 123456789123456789}}}}; + +template <typename FuncType, typename DecimalType> +static void checker(const TestDataSet& truncate_test_cases, bool decimal_col_is_const) { + static_assert(IsDecimalNumber<DecimalType>); + auto func = std::dynamic_pointer_cast<FuncType>(FuncType::create()); + FunctionContext* context = nullptr; + + for (const auto& test_case : truncate_test_cases) { + Block block; + size_t res_idx = 2; + ColumnNumbers arguments = {0, 1, 2}; + const int precision = test_case.first.first; + const int scale = test_case.first.second; + const size_t input_rows_count = test_case.second.size(); + auto col_general = ColumnDecimal<DecimalType>::create(input_rows_count, scale); + auto col_scale = ColumnInt32::create(); + auto col_res_expected = ColumnDecimal<DecimalType>::create(input_rows_count, scale); + size_t rid = 0; + + for (const auto& test_date : test_case.second) { + auto input = std::get<0>(test_date); + auto scale_arg = std::get<1>(test_date); + auto expectation = std::get<2>(test_date); + col_general->get_element(rid) = DecimalType(input); + col_scale->insert(scale_arg); + col_res_expected->get_element(rid) = DecimalType(expectation); + rid++; + } + + if (decimal_col_is_const) { + block.insert({ColumnConst::create(col_general->clone_resized(1), 1), + std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale), + "col_general_const"}); + } else { + block.insert({col_general->clone(), + std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale), + "col_general"}); + } + + block.insert({col_scale->clone(), std::make_shared<DataTypeInt32>(), "col_scale"}); + block.insert({nullptr, std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale), + "col_res"}); + + auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count); + auto col_res = assert_cast<const ColumnDecimal<DecimalType>&>( + *(block.get_by_position(res_idx).column)); + EXPECT_TRUE(status.ok()); + + for (size_t i = 0; i < input_rows_count; ++i) { + auto res = col_res.get_element(i); + auto res_expected = col_res_expected->get_element(i); + EXPECT_EQ(res, res_expected) + << "precision " << precision << " input_scale " << scale << " input " + << col_general->get_element(i) << " scale_arg " << col_scale->get_element(i) + << " res " << res << " res_expected " << res_expected; + } + } +} +TEST(TruncateFunctionTest, normal_decimal) { + checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, Decimal32>(truncate_decimal32_cases, + false); + checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, Decimal64>(truncate_decimal64_cases, + false); +} + +TEST(TruncateFunctionTest, normal_decimal_const) { + checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, Decimal32>(truncate_decimal32_cases, true); + checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, Decimal64>(truncate_decimal64_cases, true); +} + +} // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index b5184c33fcd..9bc857bacef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -122,7 +122,7 @@ public class FunctionCallExpr extends Expr { Preconditions.checkArgument(children.get(1) instanceof IntLiteral || (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral), - "2nd argument of function round/floor/ceil/truncate must be literal"); + "2nd argument of function round/floor/ceil must be literal"); if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) { children.get(1).getChild(0).setType(children.get(1).getType()); children.set(1, children.get(1).getChild(0)); @@ -136,6 +136,34 @@ public class FunctionCallExpr extends Expr { return returnType; } }; + + java.util.function.BiFunction<ArrayList<Expr>, Type, Type> truncateRule = (children, returnType) -> { + Preconditions.checkArgument(children != null && children.size() > 0); + if (children.size() == 1 && children.get(0).getType().isDecimalV3()) { + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0); + } else if (children.size() == 2) { + Expr scaleExpr = children.get(1); + if (scaleExpr instanceof IntLiteral + || (scaleExpr instanceof CastExpr && scaleExpr.getChild(0) instanceof IntLiteral)) { + if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) { + children.get(1).getChild(0).setType(children.get(1).getType()); + children.set(1, children.get(1).getChild(0)); + } else { + children.get(1).setType(Type.INT); + } + int scaleArg = (int) (((IntLiteral) children.get(1)).getValue()); + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), + Math.min(Math.max(scaleArg, 0), ((ScalarType) children.get(0).getType()).decimalScale())); + } else { + // Scale argument is a Column, always use same scale with input decimal + return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), + ((ScalarType) children.get(0).getType()).decimalScale()); + } + } else { + return returnType; + } + }; + java.util.function.BiFunction<ArrayList<Expr>, Type, Type> arrayDateTimeV2OrDecimalV3Rule = (children, returnType) -> { Preconditions.checkArgument(children != null && children.size() > 0); @@ -239,7 +267,7 @@ public class FunctionCallExpr extends Expr { PRECISION_INFER_RULE.put("dround", roundRule); PRECISION_INFER_RULE.put("dceil", roundRule); PRECISION_INFER_RULE.put("dfloor", roundRule); - PRECISION_INFER_RULE.put("truncate", roundRule); + PRECISION_INFER_RULE.put("truncate", truncateRule); } public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION = new ImmutableSortedSet.Builder( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index 4b57772ed23..6b6308c516c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.Int32OrLessType; @@ -37,19 +38,38 @@ public interface ComputePrecisionForRound extends ComputePrecision { } else if (arity() == 2 && signature.getArgType(0) instanceof DecimalV3Type) { DecimalV3Type decimalV3Type = DecimalV3Type.forType(getArgumentType(0)); Expression floatLength = getArgument(1); - Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType - && (floatLength.isLiteral() || ( - floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)), - "2nd argument of function round/floor/ceil/truncate must be literal"); - int scale; - if (floatLength instanceof Cast) { - scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + + if (this instanceof Truncate) { + if (floatLength.isLiteral() || ( + floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { + // Scale argument is a literal or cast from other literal + if (floatLength instanceof Cast) { + scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + } else { + scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + } + scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); + } else { + // Truncate could use Column as its scale argument. + // Result scale will always same with input Decimal in this situation. + scale = decimalV3Type.getScale(); + } } else { - scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType + && (floatLength.isLiteral() || ( + floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)), + "2nd argument of function round/floor/ceil must be literal"); + if (floatLength instanceof Cast) { + scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); + } else { + scale = ((IntegerLikeLiteral) floatLength).getIntValue(); + } + scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); } - scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); + return signature.withArgumentType(0, decimalV3Type) .withReturnType(DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision(), scale)); } else { diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out new file mode 100644 index 00000000000..24f675ffbe2 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out @@ -0,0 +1,101 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +0 123.3 +1 123.3 +2 123.3 +3 123.3 +4 123.3 +5 123.3 +6 123.3 +7 123.3 +8 123.3 +9 123.3 + +-- !sql -- +0 120 +1 120 +2 120 +3 120 +4 120 +5 120 +6 120 +7 120 +8 120 +9 120 + +-- !sql -- +0 123 +1 123 +2 123 +3 123 +4 123 +5 123 +6 123 +7 123 +8 123 +9 123 + +-- !sql -- +0E-8 + +-- !sql -- +0 0.0 +1 0.0 +2 0.0 +3 0.0 +4 0.0 + +-- !vec_const0 -- +1 12345.0 1.23456789E8 +2 12345.0 1.23456789E8 +3 12345.0 1.23456789E8 +4 0.0 0.0 + +-- !vec_const0 -- +1 12345.1 1.234567891E8 +2 12345.1 1.234567891E8 +3 12345.1 1.234567891E8 +4 0.0 0.0 + +-- !vec_const0 -- +1 12340.0 1.2345678E8 +2 12340.0 1.2345678E8 +3 12340.0 1.2345678E8 +4 0.0 0.0 + +-- !vec_const1 -- +1 123456789 123456789 12345678.1 12345678 0.123456789 0 +2 123456789 123456789 12345678.1 12345678 0.123456789 0 +3 123456789 123456789 12345678.1 12345678 0.123456789 0 +4 0 0 0.0 0 0E-9 0 + +-- !vec_const2 -- +1 123456789 123456789 1.123456789 1 0.1234567890 0 +2 123456789 123456789 1.123456789 1 0.1234567890 0 +3 123456789 123456789 1.123456789 1 0.1234567890 0 +4 0 0 0E-9 0 0E-10 0 + +-- !const_vec1 -- +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 +123456789.123456789 1 123456789.100000000 + +-- !const_vec2 -- +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 +123456789.123456789 -1 123456780.000000000 + +-- !vec_vec0 -- +1 1 12345.1 1.234567891E8 +2 1 12345.1 1.234567891E8 +3 1 12345.1 1.234567891E8 +4 1 0.0 0.0 + +-- !truncate_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891 0.1234567891234567891 0 + +-- !truncate_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891.100000000 0.1234567891234567891 0.1000000000000000000 + diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy new file mode 100644 index 00000000000..767140e7a6f --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_function_truncate") { + qt_sql """ + SELECT number, truncate(123.345 , 1) FROM numbers("number"="10"); + """ + qt_sql """ + SELECT number, truncate(123.123, -1) FROM numbers("number"="10"); + """ + qt_sql """ + SELECT number, truncate(123.123, 0) FROM numbers("number"="10"); + """ + + // const_const, result scale should be 10 + qt_sql """ + SELECT truncate(cast(0 as Decimal(9,8)), 10); + """ + + // const_const, result scale should be 1 + qt_sql """ + SELECT number, truncate(cast(0 as Decimal(9,4)), 1) FROM numbers("number"="5") + """ + + sql """DROP TABLE IF EXISTS test_function_truncate;""" + sql """DROP TABLE IF EXISTS test_function_truncate_dec128;""" + sql """ + CREATE TABLE test_function_truncate ( + rid int, flo float, dou double, + dec90 decimal(9, 0), dec91 decimal(9, 1), dec99 decimal(9, 9), + dec100 decimal(10,0), dec109 decimal(10,9), dec1010 decimal(10,10), + number int DEFAULT 1) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + + sql """ + INSERT INTO test_function_truncate + VALUES + (1, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (2, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (3, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_function_truncate + VALUES + (4, 0, 0, 0, 0.0, 0, 0, 0, 0, 1); + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, 0), truncate(dou, 0) FROM test_function_truncate order by rid; + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, 1), truncate(dou, 1) FROM test_function_truncate order by rid; + """ + qt_vec_const0 """ + SELECT rid, truncate(flo, -1), truncate(dou, -1) FROM test_function_truncate order by rid; + """ + qt_vec_const1 """ + SELECT rid, dec90, truncate(dec90, 0), dec91, truncate(dec91, 0), dec99, truncate(dec99, 0) FROM test_function_truncate order by rid + """ + qt_vec_const2 """ + SELECT rid, dec100, truncate(dec100, 0), dec109, truncate(dec109, 0), dec1010, truncate(dec1010, 0) FROM test_function_truncate order by rid + """ + + + + qt_const_vec1 """ + SELECT 123456789.123456789, number, truncate(123456789.123456789, number) from test_function_truncate; + """ + qt_const_vec2 """ + SELECT 123456789.123456789, -number, truncate(123456789.123456789, -number) from test_function_truncate; + """ + qt_vec_vec0 """ + SELECT rid,number, truncate(flo, number), truncate(dou, number) FROM test_function_truncate order by rid; + """ + + sql """ + CREATE TABLE test_function_truncate_dec128 ( + rid int, dec190 decimal(19,0), dec199 decimal(19,9), dec1919 decimal(19,19), + dec380 decimal(38,0), dec3819 decimal(38,19), dec3838 decimal(38,38), + number int DEFAULT 1 + ) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + sql """ + INSERT INTO test_function_truncate_dec128 + VALUES + (1, 1234567891234567891.0, 1234567891.123456789, 0.1234567891234567891, + 12345678912345678912345678912345678912.0, + 1234567891234567891.1234567891234567891, + 0.12345678912345678912345678912345678912345678912345678912345678912345678912, 1); + """ + qt_truncate_dec128 """ + SELECT rid, dec190, truncate(dec190, 0), dec199, truncate(dec199, 0), dec1919, truncate(dec1919, 0) + FROM test_function_truncate_dec128 order by rid + """ + + qt_truncate_dec128 """ + SELECT rid, dec190, truncate(dec190, number), dec199, truncate(dec199, number), dec1919, truncate(dec1919, number) + FROM test_function_truncate_dec128 order by rid + """ + +} \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org