This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 8194398028c [fix](round) Fix incorrect decimal scale inference in round functions (#34471) 8194398028c is described below commit 8194398028c5a6731858dd96b0036f23bc3a4800 Author: zhiqiang <seuhezhiqi...@163.com> AuthorDate: Fri May 10 16:09:46 2024 +0800 [fix](round) Fix incorrect decimal scale inference in round functions (#34471) * FIX NEEDED * FORMAT * FORMAT * FIX TEST --- be/src/vec/functions/round.h | 114 ++++++++++++------- .../functions/ComputePrecisionForRound.java | 7 +- .../sql_functions/math_functions/test_round.out | 123 +++++++++++++++++++++ .../sql_functions/math_functions/test_round.groovy | 35 +++++- 4 files changed, 237 insertions(+), 42 deletions(-) diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 97a81f644ed..a17865914c4 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -21,13 +21,17 @@ #pragma once #include <cstddef> +#include <memory> #include "common/exception.h" #include "common/status.h" #include "vec/columns/column_const.h" #include "vec/columns/columns_number.h" #include "vec/common/assert_cast.h" +#include "vec/core/column_with_type_and_name.h" #include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" #include "vec/functions/function.h" #if defined(__SSE4_1__) || defined(__aarch64__) #include "util/sse_util.hpp" @@ -430,7 +434,10 @@ struct Dispatcher { FloatRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>, IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>; - static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 scale_arg) { + // scale_arg: scale for function computation + // result_scale: scale for result decimal, this scale is got from planner + static ColumnPtr apply_vec_const(const IColumn* col_general, const Int16 scale_arg, + [[maybe_unused]] Int16 result_scale) { if constexpr (IsNumber<T>) { const auto* const col = check_and_get_column<ColumnVector<T>>(col_general); auto col_res = ColumnVector<T>::create(); @@ -457,10 +464,7 @@ struct Dispatcher { } else if constexpr (IsDecimalNumber<T>) { const auto* const decimal_col = check_and_get_column<ColumnDecimal<T>>(col_general); const auto& vec_src = decimal_col->get_data(); - - UInt32 result_scale = - std::min(static_cast<UInt32>(std::max(scale_arg, static_cast<Int16>(0))), - decimal_col->get_scale()); + const size_t input_rows_count = vec_src.size(); auto col_res = ColumnDecimal<T>::create(vec_src.size(), result_scale); auto& vec_res = col_res->get_data(); @@ -468,6 +472,27 @@ struct Dispatcher { FunctionRoundingImpl<ScaleMode::Negative>::apply( decimal_col->get_data(), decimal_col->get_scale(), vec_res, scale_arg); } + // We need to always make sure result decimal's scale is as expected as its in plan + // So we need to append enough zero to result. + + // Case 0: scale_arg <= -(integer part digits count) + // do nothing, because result is 0 + // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) + // decimal parts has been erased, so add them back by multiply 10^(result_scale) + // Case 2: scale_arg > 0 && scale_arg < result_scale + // decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg) + // Case 3: scale_arg >= input_scale + // do nothing + + if (scale_arg <= 0) { + for (size_t i = 0; i < input_rows_count; ++i) { + vec_res[i].value *= int_exp10(result_scale); + } + } else if (scale_arg > 0 && scale_arg < result_scale) { + for (size_t i = 0; i < input_rows_count; ++i) { + vec_res[i].value *= int_exp10(result_scale - scale_arg); + } + } return col_res; } else { @@ -477,7 +502,9 @@ struct Dispatcher { } } - static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) { + // result_scale: scale for result decimal, this scale is got from planner + static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale, + [[maybe_unused]] Int16 result_scale) { const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale); const size_t input_row_count = col_scale_i32.size(); for (size_t i = 0; i < input_row_count; ++i) { @@ -515,10 +542,8 @@ struct Dispatcher { return col_res; } else if constexpr (IsDecimalNumber<T>) { const auto* decimal_col = assert_cast<const ColumnDecimal<T>*>(col_general); - - // ALWAYS use SAME scale with source Decimal column const Int32 input_scale = decimal_col->get_scale(); - auto col_res = ColumnDecimal<T>::create(input_row_count, input_scale); + auto col_res = ColumnDecimal<T>::create(input_row_count, result_scale); for (size_t i = 0; i < input_row_count; ++i) { DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply( @@ -534,15 +559,15 @@ struct Dispatcher { // do nothing, because result is 0 // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) // decimal parts has been erased, so add them back by multiply 10^(scale_arg) - // Case 2: scale_arg > 0 && scale_arg < decimal part digits count - // decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg) + // Case 2: scale_arg > 0 && scale_arg < result_scale + // decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg) // Case 3: scale_arg >= input_scale // do nothing const Int32 scale_arg = col_scale_i32.get_data()[i]; if (scale_arg <= 0) { - col_res->get_element(i).value *= int_exp10(input_scale); - } else if (scale_arg > 0 && scale_arg < input_scale) { - col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + col_res->get_element(i).value *= int_exp10(result_scale); + } else if (scale_arg > 0 && scale_arg < result_scale) { + col_res->get_element(i).value *= int_exp10(result_scale - scale_arg); } } @@ -554,8 +579,9 @@ struct Dispatcher { } } - static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, - const IColumn* col_scale) { + // result_scale: scale for result decimal, this scale is got from planner + static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale, + [[maybe_unused]] Int16 result_scale) { const auto& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale); const size_t input_rows_count = col_scale->size(); @@ -575,8 +601,7 @@ struct Dispatcher { assert_cast<const ColumnDecimal<T>&>(const_col_general->get_data_column()); const T& general_val = data_col_general.get_data()[0]; Int32 input_scale = data_col_general.get_scale(); - - auto col_res = ColumnDecimal<T>::create(input_rows_count, input_scale); + auto col_res = ColumnDecimal<T>::create(input_rows_count, result_scale); for (size_t i = 0; i < input_rows_count; ++i) { DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply( @@ -592,15 +617,15 @@ struct Dispatcher { // do nothing, because result is 0 // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits count) // decimal parts has been erased, so add them back by multiply 10^(scale_arg) - // Case 2: scale_arg > 0 && scale_arg < decimal part digits count - // decimal part now has scale_arg digits, so multiply 10^(input_scale - scal_arg) + // Case 2: scale_arg > 0 && scale_arg < result_scale + // decimal part now has scale_arg digits, so multiply 10^(result_scale - scal_arg) // Case 3: scale_arg >= input_scale // do nothing const Int32 scale_arg = col_scale_i32.get_data()[i]; if (scale_arg <= 0) { - col_res->get_element(i).value *= int_exp10(input_scale); - } else if (scale_arg > 0 && scale_arg < input_scale) { - col_res->get_element(i).value *= int_exp10(input_scale - scale_arg); + col_res->get_element(i).value *= int_exp10(result_scale); + } else if (scale_arg > 0 && scale_arg < result_scale) { + col_res->get_element(i).value *= int_exp10(result_scale - scale_arg); } } @@ -679,26 +704,23 @@ public: return Status::OK(); } - /// SELECT number, truncate(123.345, 1) FROM number("numbers"="10") - /// should NOT behave like two column arguments, so we can not use const column default implementation - bool use_default_implementation_for_constants() const override { return false; } + bool use_default_implementation_for_constants() const override { return true; } - //// We moved and optimized the execute_impl logic of function_truncate.h from PR#32746, - //// as well as make it suitable for all functions. Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + ColumnWithTypeAndName& column_result = block.get_by_position(result); + const DataTypePtr result_type = block.get_by_position(result).type; const bool is_col_general_const = is_column_const(*column_general.column); const auto* col_general = is_col_general_const ? assert_cast<const ColumnConst&>(*column_general.column) .get_data_column_ptr() : column_general.column.get(); - ColumnPtr res; /// potential argument types: /// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have the following type: - /// 1. func(Column), func(ColumnConst), func(Column, ColumnConst), func(ColumnConst, ColumnConst) + /// 1. func(Column), func(Column, ColumnConst) /// otherwise, the SECOND arugment is COLUMN, we have another type: /// 2. func(Column, Column), func(ColumnConst, Column) @@ -706,6 +728,23 @@ public: using Types = std::decay_t<decltype(types)>; using DataType = typename Types::LeftType; + // For decimal, we will always make sure result Decimal has exactly same precision and scale with + // arguments from query plan. + Int16 result_scale = 0; + if constexpr (IsDataTypeDecimal<DataType>) { + if (column_result.type->get_type_id() == TypeIndex::Nullable) { + if (auto nullable_type = std::dynamic_pointer_cast<const DataTypeNullable>( + column_result.type)) { + result_scale = nullable_type->get_nested_type()->get_scale(); + } else { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "Illegal nullable column"); + } + } else { + result_scale = column_result.type->get_scale(); + } + } + if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) { using FieldType = typename DataType::FieldType; if (arguments.size() == 1 || @@ -718,23 +757,20 @@ public: } res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply_vec_const( - col_general, scale_arg); - - if (is_col_general_const) { - // Important, make sure the result column has the same size as the input column - res = ColumnConst::create(std::move(res), input_rows_count); - } + col_general, scale_arg, result_scale); } else { // the SECOND arugment is COLUMN if (is_col_general_const) { res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>:: apply_const_vec( &assert_cast<const ColumnConst&>(*column_general.column), - block.get_by_position(arguments[1]).column.get()); + block.get_by_position(arguments[1]).column.get(), + result_scale); } else { res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>:: apply_vec_vec(col_general, - block.get_by_position(arguments[1]).column.get()); + block.get_by_position(arguments[1]).column.get(), + result_scale); } } return true; @@ -758,7 +794,7 @@ public: column_general.type->get_name(), name); } - block.replace_by_position(result, std::move(res)); + column_result.column = std::move(res); return Status::OK(); } }; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index eedbfea6df9..b47804e23ff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -37,9 +37,12 @@ public interface ComputePrecisionForRound extends ComputePrecision { Expression floatLength = getArgument(1); int scale; - if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() + // If scale arg is an integer literal, or it is a cast(Integer as Integer) + // then we will try to use its value as result scale + // In any other cases, we will make sure result decimal has same scale with input. + if ((floatLength.isLiteral() && floatLength.getDataType() instanceof Int32OrLessType) + || (floatLength instanceof Cast && floatLength.child(0).isLiteral() && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { - // Scale argument is a literal or cast from other literal if (floatLength instanceof Cast) { scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); } else { diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out index 1ebc9cf5b89..ccdd9551f80 100644 --- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out @@ -1,4 +1,115 @@ -- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +123.100 + +-- !select -- +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 + +-- !select -- +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 + +-- !select -- +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 + +-- !select -- +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 + +-- !select -- +123.200 +123.200 +123.200 +123.200 +123.200 +123.200 +123.200 +123.200 +123.200 +123.200 + +-- !select -- +130.000 +130.000 +130.000 +130.000 +130.000 +130.000 +130.000 +130.000 +130.000 +130.000 + +-- !select -- +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 +123.100 + +-- !select -- +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 +120.000 + +-- !select -- +4434.41 + +-- !select -- +0 + +-- !select -- +false \N 4434 + +-- !select -- +0 + -- !select -- 10 @@ -97,6 +208,18 @@ -- !select -- 16.025 16.02500 16.02500 +-- !select_fix -- +16.025 16.02500 16.02500 + +-- !select_fix -- +16.025 16.02500 16.02500 + +-- !select_fix -- +16.025 16.02500 16.02500 + +-- !select_fix -- +16.025 16.02500 16.02500 + -- !nereids_round_arg1 -- 10 diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy index 1d8bbb9df49..da361e15938 100644 --- a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy @@ -15,7 +15,35 @@ // specific language governing permissions and limitations // under the License. - suite("test_round") { +suite("test_round") { + sql "set enable_fold_constant_by_be=false;" + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + qt_select "SELECT round(123.123, 1.123);" + qt_select """SELECT round(123.123, 1.123) FROM numbers("number"="10");""" + qt_select """SELECT round(123.123, -1.123) FROM numbers("number"="10");""" + qt_select """SELECT truncate(123.123, 1.123) FROM numbers("number"="10");""" + qt_select """SELECT truncate(123.123, -1.123) FROM numbers("number"="10");""" + qt_select """SELECT ceil(123.123, 1.123) FROM numbers("number"="10");""" + qt_select """SELECT ceil(123.123, -1.123) FROM numbers("number"="10");""" + qt_select """SELECT round_bankers(123.123, 1.123) FROM numbers("number"="10");""" + qt_select """SELECT round_bankers(123.123, -1.123) FROM numbers("number"="10");""" + sql """drop table if exists test_round_1; """ + sql """ + create table test_round_1(big_key bigint not NULL) + DISTRIBUTED BY HASH(big_key) BUCKETS 1 PROPERTIES ("replication_num" = "1"); + """ + qt_select """SELECT truncate(cast(round(8990.65 - 4556.2354, 2.4652) as Decimal(9,4)), 2);""" + qt_select """SELECT cast(round(round(465.56,min(-5.987)),2) as DECIMAL)""" + qt_select """ + SELECT truncate(100,2)<-2308.57 , cast(round(round(465.56,min(-5.987)),2) as DECIMAL) , cast(truncate(round(8990.65-4556.2354,2.4652),2)as DECIMAL) from test_round_1; + """ + + qt_select """ + SELECT truncate(123456789.123456789, -9); + """ + qt_select "SELECT round(10.12345)" qt_select "SELECT round(10.12345, 2)" qt_select "SELECT round_bankers(10.12345)" @@ -62,6 +90,11 @@ qt_select """ SELECT truncate(col1, 7), truncate(col2, 7), truncate(col3, 7) FROM `${tableName}`; """ qt_select """ SELECT round_bankers(col1, 7), round_bankers(col2, 7), round_bankers(col3, 7) FROM `${tableName}`; """ + qt_select_fix """ SELECT round(col1, 6.234), round(col2, 6.234), round(col3, 6.234) FROM `${tableName}`; """ + qt_select_fix """ SELECT floor(col1, 6.234), floor(col2, 6.234), floor(col3, 6.234) FROM `${tableName}`; """ + qt_select_fix """ SELECT truncate(col1, 6.234), truncate(col2, 6.234), truncate(col3, 6.234) FROM `${tableName}`; """ + qt_select_fix """ SELECT round_bankers(col1, 6.234), round_bankers(col2, 6.234), round_bankers(col3, 6.234) FROM `${tableName}`; """ + sql """ DROP TABLE IF EXISTS `${tableName}` """ sql "SET enable_nereids_planner=true" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org