This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 8194398028c [fix](round) Fix incorrect decimal scale inference in 
round functions (#34471)
8194398028c is described below

commit 8194398028c5a6731858dd96b0036f23bc3a4800
Author: zhiqiang <seuhezhiqi...@163.com>
AuthorDate: Fri May 10 16:09:46 2024 +0800

    [fix](round) Fix incorrect decimal scale inference in round functions 
(#34471)
    
    * Fix decimal result-scale inference for round-family functions (BE and FE)
    
    * Format code
    
    * Format code
    
    * Fix regression tests
---
 be/src/vec/functions/round.h                       | 114 ++++++++++++-------
 .../functions/ComputePrecisionForRound.java        |   7 +-
 .../sql_functions/math_functions/test_round.out    | 123 +++++++++++++++++++++
 .../sql_functions/math_functions/test_round.groovy |  35 +++++-
 4 files changed, 237 insertions(+), 42 deletions(-)

diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index 97a81f644ed..a17865914c4 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -21,13 +21,17 @@
 #pragma once
 
 #include <cstddef>
+#include <memory>
 
 #include "common/exception.h"
 #include "common/status.h"
 #include "vec/columns/column_const.h"
 #include "vec/columns/columns_number.h"
 #include "vec/common/assert_cast.h"
+#include "vec/core/column_with_type_and_name.h"
 #include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
 #include "vec/functions/function.h"
 #if defined(__SSE4_1__) || defined(__aarch64__)
 #include "util/sse_util.hpp"
@@ -430,7 +434,10 @@ struct Dispatcher {
                     FloatRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>,
                     IntegerRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>>>;
 
-    static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 
scale_arg) {
+    // scale_arg: scale for function computation
+    // result_scale: scale of the result decimal, as determined by the planner
+    static ColumnPtr apply_vec_const(const IColumn* col_general, const Int16 
scale_arg,
+                                     [[maybe_unused]] Int16 result_scale) {
         if constexpr (IsNumber<T>) {
             const auto* const col = 
check_and_get_column<ColumnVector<T>>(col_general);
             auto col_res = ColumnVector<T>::create();
@@ -457,10 +464,7 @@ struct Dispatcher {
         } else if constexpr (IsDecimalNumber<T>) {
             const auto* const decimal_col = 
check_and_get_column<ColumnDecimal<T>>(col_general);
             const auto& vec_src = decimal_col->get_data();
-
-            UInt32 result_scale =
-                    std::min(static_cast<UInt32>(std::max(scale_arg, 
static_cast<Int16>(0))),
-                             decimal_col->get_scale());
+            const size_t input_rows_count = vec_src.size();
             auto col_res = ColumnDecimal<T>::create(vec_src.size(), 
result_scale);
             auto& vec_res = col_res->get_data();
 
@@ -468,6 +472,27 @@ struct Dispatcher {
                 FunctionRoundingImpl<ScaleMode::Negative>::apply(
                         decimal_col->get_data(), decimal_col->get_scale(), 
vec_res, scale_arg);
             }
+            // Always make sure the result decimal's scale is exactly the one chosen by the plan,
+            // so append enough trailing zeros to the result.
+
+
+            // Case 0: scale_arg <= -(integer part digits count)
+            //      do nothing, because result is 0
+            // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits 
count)
+            //      the decimal part has been erased, so add it back by multiplying by 10^(result_scale)
+            // Case 2: scale_arg > 0 && scale_arg < result_scale
+            //      the decimal part now has scale_arg digits, so multiply by 10^(result_scale - scale_arg)
+            // Case 3: scale_arg >= result_scale
+            //      do nothing
+
+            if (scale_arg <= 0) {
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    vec_res[i].value *= int_exp10(result_scale);
+                }
+            } else if (scale_arg > 0 && scale_arg < result_scale) {
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    vec_res[i].value *= int_exp10(result_scale - scale_arg);
+                }
+            }
 
             return col_res;
         } else {
@@ -477,7 +502,9 @@ struct Dispatcher {
         }
     }
 
-    static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* 
col_scale) {
+    // result_scale: scale of the result decimal, as determined by the planner
+    static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* 
col_scale,
+                                   [[maybe_unused]] Int16 result_scale) {
         const auto& col_scale_i32 = assert_cast<const 
ColumnInt32&>(*col_scale);
         const size_t input_row_count = col_scale_i32.size();
         for (size_t i = 0; i < input_row_count; ++i) {
@@ -515,10 +542,8 @@ struct Dispatcher {
             return col_res;
         } else if constexpr (IsDecimalNumber<T>) {
             const auto* decimal_col = assert_cast<const 
ColumnDecimal<T>*>(col_general);
-
-            // ALWAYS use SAME scale with source Decimal column
             const Int32 input_scale = decimal_col->get_scale();
-            auto col_res = ColumnDecimal<T>::create(input_row_count, 
input_scale);
+            auto col_res = ColumnDecimal<T>::create(input_row_count, 
result_scale);
 
             for (size_t i = 0; i < input_row_count; ++i) {
                 DecimalRoundingImpl<T, rounding_mode, 
tie_breaking_mode>::apply(
@@ -534,15 +559,15 @@ struct Dispatcher {
                 //      do nothing, because result is 0
                 // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits 
count)
                 //      decimal parts has been erased, so add them back by 
multiply 10^(scale_arg)
-                // Case 2: scale_arg > 0 && scale_arg < decimal part digits 
count
-                //      decimal part now has scale_arg digits, so multiply 
10^(input_scale - scal_arg)
+                // Case 2: scale_arg > 0 && scale_arg < result_scale
+                //      the decimal part now has scale_arg digits, so multiply by 10^(result_scale - scale_arg)
                 // Case 3: scale_arg >= input_scale
                 //      do nothing
                 const Int32 scale_arg = col_scale_i32.get_data()[i];
                 if (scale_arg <= 0) {
-                    col_res->get_element(i).value *= int_exp10(input_scale);
-                } else if (scale_arg > 0 && scale_arg < input_scale) {
-                    col_res->get_element(i).value *= int_exp10(input_scale - 
scale_arg);
+                    col_res->get_element(i).value *= int_exp10(result_scale);
+                } else if (scale_arg > 0 && scale_arg < result_scale) {
+                    col_res->get_element(i).value *= int_exp10(result_scale - 
scale_arg);
                 }
             }
 
@@ -554,8 +579,9 @@ struct Dispatcher {
         }
     }
 
-    static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
-                                     const IColumn* col_scale) {
+    // result_scale: scale of the result decimal, as determined by the planner
+    static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, 
const IColumn* col_scale,
+                                     [[maybe_unused]] Int16 result_scale) {
         const auto& col_scale_i32 = assert_cast<const 
ColumnInt32&>(*col_scale);
         const size_t input_rows_count = col_scale->size();
 
@@ -575,8 +601,7 @@ struct Dispatcher {
                     assert_cast<const 
ColumnDecimal<T>&>(const_col_general->get_data_column());
             const T& general_val = data_col_general.get_data()[0];
             Int32 input_scale = data_col_general.get_scale();
-
-            auto col_res = ColumnDecimal<T>::create(input_rows_count, 
input_scale);
+            auto col_res = ColumnDecimal<T>::create(input_rows_count, 
result_scale);
 
             for (size_t i = 0; i < input_rows_count; ++i) {
                 DecimalRoundingImpl<T, rounding_mode, 
tie_breaking_mode>::apply(
@@ -592,15 +617,15 @@ struct Dispatcher {
                 //      do nothing, because result is 0
                 // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits 
count)
                 //      decimal parts has been erased, so add them back by 
multiply 10^(scale_arg)
-                // Case 2: scale_arg > 0 && scale_arg < decimal part digits 
count
-                //      decimal part now has scale_arg digits, so multiply 
10^(input_scale - scal_arg)
+                // Case 2: scale_arg > 0 && scale_arg < result_scale
+                //      the decimal part now has scale_arg digits, so multiply by 10^(result_scale - scale_arg)
                 // Case 3: scale_arg >= input_scale
                 //      do nothing
                 const Int32 scale_arg = col_scale_i32.get_data()[i];
                 if (scale_arg <= 0) {
-                    col_res->get_element(i).value *= int_exp10(input_scale);
-                } else if (scale_arg > 0 && scale_arg < input_scale) {
-                    col_res->get_element(i).value *= int_exp10(input_scale - 
scale_arg);
+                    col_res->get_element(i).value *= int_exp10(result_scale);
+                } else if (scale_arg > 0 && scale_arg < result_scale) {
+                    col_res->get_element(i).value *= int_exp10(result_scale - 
scale_arg);
                 }
             }
 
@@ -679,26 +704,23 @@ public:
         return Status::OK();
     }
 
-    /// SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
-    /// should NOT behave like two column arguments, so we can not use const 
column default implementation
-    bool use_default_implementation_for_constants() const override { return 
false; }
+    bool use_default_implementation_for_constants() const override { return 
true; }
 
-    //// We moved and optimized the execute_impl logic of function_truncate.h 
from PR#32746,
-    //// as well as make it suitable for all functions.
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         const ColumnWithTypeAndName& column_general = 
block.get_by_position(arguments[0]);
+        ColumnWithTypeAndName& column_result = block.get_by_position(result);
+        const DataTypePtr result_type = block.get_by_position(result).type;
         const bool is_col_general_const = 
is_column_const(*column_general.column);
         const auto* col_general = is_col_general_const
                                           ? assert_cast<const 
ColumnConst&>(*column_general.column)
                                                     .get_data_column_ptr()
                                           : column_general.column.get();
-
         ColumnPtr res;
 
         /// potential argument types:
         /// if the SECOND argument is MISSING(would be considered as ZERO 
const) or CONST, then we have the following type:
-        ///    1. func(Column), func(ColumnConst), func(Column, ColumnConst), 
func(ColumnConst, ColumnConst)
+        ///    1. func(Column), func(Column, ColumnConst)
         /// otherwise, the SECOND arugment is COLUMN, we have another type:
         ///    2. func(Column, Column), func(ColumnConst, Column)
 
@@ -706,6 +728,23 @@ public:
             using Types = std::decay_t<decltype(types)>;
             using DataType = typename Types::LeftType;
 
+            // For decimal, always make sure the result Decimal's scale matches exactly
+            // the scale of the result type chosen by the query plan.
+            Int16 result_scale = 0;
+            if constexpr (IsDataTypeDecimal<DataType>) {
+                if (column_result.type->get_type_id() == TypeIndex::Nullable) {
+                    if (auto nullable_type = std::dynamic_pointer_cast<const 
DataTypeNullable>(
+                                column_result.type)) {
+                        result_scale = 
nullable_type->get_nested_type()->get_scale();
+                    } else {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+                                               "Illegal nullable column");
+                    }
+                } else {
+                    result_scale = column_result.type->get_scale();
+                }
+            }
+
             if constexpr (IsDataTypeNumber<DataType> || 
IsDataTypeDecimal<DataType>) {
                 using FieldType = typename DataType::FieldType;
                 if (arguments.size() == 1 ||
@@ -718,23 +757,20 @@ public:
                     }
 
                     res = Dispatcher<FieldType, rounding_mode, 
tie_breaking_mode>::apply_vec_const(
-                            col_general, scale_arg);
-
-                    if (is_col_general_const) {
-                        // Important, make sure the result column has the same 
size as the input column
-                        res = ColumnConst::create(std::move(res), 
input_rows_count);
-                    }
+                            col_general, scale_arg, result_scale);
                 } else {
                     // the SECOND arugment is COLUMN
                     if (is_col_general_const) {
                         res = Dispatcher<FieldType, rounding_mode, 
tie_breaking_mode>::
                                 apply_const_vec(
                                         &assert_cast<const 
ColumnConst&>(*column_general.column),
-                                        
block.get_by_position(arguments[1]).column.get());
+                                        
block.get_by_position(arguments[1]).column.get(),
+                                        result_scale);
                     } else {
                         res = Dispatcher<FieldType, rounding_mode, 
tie_breaking_mode>::
                                 apply_vec_vec(col_general,
-                                              
block.get_by_position(arguments[1]).column.get());
+                                              
block.get_by_position(arguments[1]).column.get(),
+                                              result_scale);
                     }
                 }
                 return true;
@@ -758,7 +794,7 @@ public:
                                            column_general.type->get_name(), 
name);
         }
 
-        block.replace_by_position(result, std::move(res));
+        column_result.column = std::move(res);
         return Status::OK();
     }
 };
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
index eedbfea6df9..b47804e23ff 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
@@ -37,9 +37,12 @@ public interface ComputePrecisionForRound extends 
ComputePrecision {
             Expression floatLength = getArgument(1);
             int scale;
 
-            if (floatLength.isLiteral() || (floatLength instanceof Cast && 
floatLength.child(0).isLiteral()
+            // If the scale arg is an integer literal (INT or narrower), or a cast of such a literal,
+            // then use its value as the result scale.
+            // In all other cases, the result decimal keeps the same scale as the input.
+            if ((floatLength.isLiteral() && floatLength.getDataType() 
instanceof Int32OrLessType)
+                    || (floatLength instanceof Cast && 
floatLength.child(0).isLiteral()
                     && floatLength.child(0).getDataType() instanceof 
Int32OrLessType)) {
-                // Scale argument is a literal or cast from other literal
                 if (floatLength instanceof Cast) {
                     scale = ((IntegerLikeLiteral) 
floatLength.child(0)).getIntValue();
                 } else {
diff --git 
a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out 
b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
index 1ebc9cf5b89..ccdd9551f80 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
@@ -1,4 +1,115 @@
 -- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select --
+123.100
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+
+-- !select --
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+4434.41
+
+-- !select --
+0
+
+-- !select --
+false  \N      4434
+
+-- !select --
+0
+
 -- !select --
 10
 
@@ -97,6 +208,18 @@
 -- !select --
 16.025 16.02500        16.02500
 
+-- !select_fix --
+16.025 16.02500        16.02500
+
+-- !select_fix --
+16.025 16.02500        16.02500
+
+-- !select_fix --
+16.025 16.02500        16.02500
+
+-- !select_fix --
+16.025 16.02500        16.02500
+
 -- !nereids_round_arg1 --
 10
 
diff --git 
a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
 
b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
index 1d8bbb9df49..da361e15938 100644
--- 
a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
@@ -15,7 +15,35 @@
 // specific language governing permissions and limitations
 // under the License.
 
-    suite("test_round") {
+suite("test_round") {
+    sql "set enable_fold_constant_by_be=false;"
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+
+    qt_select "SELECT round(123.123, 1.123);"
+    qt_select """SELECT round(123.123, 1.123) FROM numbers("number"="10");"""
+    qt_select """SELECT round(123.123, -1.123) FROM numbers("number"="10");"""
+    qt_select """SELECT truncate(123.123, 1.123) FROM 
numbers("number"="10");"""
+    qt_select """SELECT truncate(123.123, -1.123) FROM 
numbers("number"="10");"""
+    qt_select """SELECT ceil(123.123, 1.123) FROM numbers("number"="10");"""
+    qt_select """SELECT ceil(123.123, -1.123) FROM numbers("number"="10");"""
+    qt_select """SELECT round_bankers(123.123, 1.123) FROM 
numbers("number"="10");"""
+    qt_select """SELECT round_bankers(123.123, -1.123) FROM 
numbers("number"="10");"""
+    sql """drop table if exists test_round_1; """
+    sql """
+        create table test_round_1(big_key bigint not NULL)
+                DISTRIBUTED BY HASH(big_key) BUCKETS 1 PROPERTIES 
("replication_num" = "1");
+    """
+    qt_select """SELECT truncate(cast(round(8990.65 - 4556.2354, 2.4652) as 
Decimal(9,4)), 2);"""
+    qt_select """SELECT cast(round(round(465.56,min(-5.987)),2) as DECIMAL)"""
+    qt_select """
+        SELECT truncate(100,2)<-2308.57 , 
cast(round(round(465.56,min(-5.987)),2) as DECIMAL) , 
cast(truncate(round(8990.65-4556.2354,2.4652),2)as DECIMAL) from test_round_1;
+    """
+
+    qt_select """
+        SELECT truncate(123456789.123456789, -9);
+    """
+
     qt_select "SELECT round(10.12345)"
     qt_select "SELECT round(10.12345, 2)"
     qt_select "SELECT round_bankers(10.12345)"
@@ -62,6 +90,11 @@
     qt_select """ SELECT truncate(col1, 7), truncate(col2, 7), truncate(col3, 
7) FROM `${tableName}`; """
     qt_select """ SELECT round_bankers(col1, 7), round_bankers(col2, 7), 
round_bankers(col3, 7) FROM `${tableName}`; """
 
+    qt_select_fix """ SELECT round(col1, 6.234), round(col2, 6.234), 
round(col3, 6.234) FROM `${tableName}`; """
+    qt_select_fix """ SELECT floor(col1, 6.234), floor(col2, 6.234), 
floor(col3, 6.234) FROM `${tableName}`; """
+    qt_select_fix """ SELECT truncate(col1, 6.234), truncate(col2, 6.234), 
truncate(col3, 6.234) FROM `${tableName}`; """
+    qt_select_fix """ SELECT round_bankers(col1, 6.234), round_bankers(col2, 
6.234), round_bankers(col3, 6.234) FROM `${tableName}`; """
+
     sql """ DROP TABLE IF EXISTS `${tableName}` """
 
     sql "SET enable_nereids_planner=true"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to