This is an automated email from the ASF dual-hosted git repository.

zhangstar333 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new a36e088e781 [enhancement](function truncate) truncate can use column 
as scale argument (#32746)
a36e088e781 is described below

commit a36e088e781f2fd47f4c3672b314f82a23d6e16f
Author: zhiqiang <seuhezhiqi...@163.com>
AuthorDate: Tue Apr 2 14:56:26 2024 +0800

    [enhancement](function truncate) truncate can use column as scale argument 
(#32746)
    
    Co-authored-by: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
---
 be/src/vec/functions/function_truncate.h           | 245 ++++++++++++++
 be/src/vec/functions/math.cpp                      |  23 +-
 be/src/vec/functions/round.h                       | 224 ++++++++++++-
 .../function/function_truncate_decimal_test.cpp    | 370 +++++++++++++++++++++
 .../apache/doris/analysis/FunctionCallExpr.java    |  32 +-
 .../functions/ComputePrecisionForRound.java        |  40 ++-
 .../math_functions/test_function_truncate.out      | 101 ++++++
 .../math_functions/test_function_truncate.groovy   | 132 ++++++++
 8 files changed, 1136 insertions(+), 31 deletions(-)

diff --git a/be/src/vec/functions/function_truncate.h 
b/be/src/vec/functions/function_truncate.h
new file mode 100644
index 00000000000..e29bc99c041
--- /dev/null
+++ b/be/src/vec/functions/function_truncate.h
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstddef>
+#include <functional>
+#include <type_traits>
+#include <utility>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "olap/olap_common.h"
+#include "round.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_const.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/call_on_type_index.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+
+namespace doris::vectorized {
+
+/// Registration tag for the one-argument floating point overload: truncate(Float64).
+struct TruncateFloatOneArgImpl {
+    static constexpr const char* name = "truncate";
+
+    /// Argument types used to register this overload.
+    static DataTypes get_variadic_argument_types() {
+        DataTypes signature;
+        signature.emplace_back(std::make_shared<DataTypeFloat64>());
+        return signature;
+    }
+};
+
+/// Registration tag for the two-argument floating point overload: truncate(Float64, Int32).
+struct TruncateFloatTwoArgImpl {
+    static constexpr const char* name = "truncate";
+
+    /// Argument types used to register this overload: the value, then the scale.
+    static DataTypes get_variadic_argument_types() {
+        DataTypes signature;
+        signature.emplace_back(std::make_shared<DataTypeFloat64>());
+        signature.emplace_back(std::make_shared<DataTypeInt32>());
+        return signature;
+    }
+};
+
+/// Registration tag for the one-argument decimal overload: truncate(Decimal).
+struct TruncateDecimalOneArgImpl {
+    static constexpr const char* name = "truncate";
+
+    /// Argument types used to register this overload.
+    static DataTypes get_variadic_argument_types() {
+        // All Decimal types are named Decimal, and real scale will be passed as type argument
+        // for execute function, so we can just register Decimal32 here.
+        DataTypes signature;
+        signature.emplace_back(std::make_shared<DataTypeDecimal<Decimal32>>(9, 0));
+        return signature;
+    }
+};
+
+/// Registration tag for the two-argument decimal overload: truncate(Decimal, Int32).
+struct TruncateDecimalTwoArgImpl {
+    static constexpr const char* name = "truncate";
+
+    /// Argument types used to register this overload: the value, then the scale.
+    /// Decimal32(9, 0) is the same placeholder the one-argument decimal overload registers.
+    static DataTypes get_variadic_argument_types() {
+        DataTypes signature;
+        signature.emplace_back(std::make_shared<DataTypeDecimal<Decimal32>>(9, 0));
+        signature.emplace_back(std::make_shared<DataTypeInt32>());
+        return signature;
+    }
+};
+
+/// Vectorized `truncate(value[, scale])` whose scale may be a constant or an Int32 column.
+///
+/// Rounding is fixed to RoundingMode::Trunc / TieBreakingMode::Auto through the base class
+/// template arguments; `Impl` supplies the function name and the registered signature.
+template <typename Impl>
+class FunctionTruncate : public FunctionRounding<Impl, RoundingMode::Trunc, TieBreakingMode::Auto> {
+public:
+    static FunctionPtr create() { return std::make_shared<FunctionTruncate>(); }
+
+    // No argument has to be constant: the scale may be a regular column.
+    ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; }
+
+    // SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
+    // should NOT behave like two column arguments, so we can not use const column default
+    // implementation
+    bool use_default_implementation_for_constants() const override { return false; }
+
+    /// Computes truncate for `arguments` and stores the result column at position `result`.
+    ///
+    /// Returns InvalidArgument when the first argument is neither a number nor a decimal, or
+    /// whatever error get_scale_arg() reports for a constant scale.
+    ///
+    /// The previous revision carried four copies of a `fesetround(FE_TONEAREST)` block guarded
+    /// by `if constexpr (rounding_mode == RoundingMode::Round)`. `rounding_mode` is a template
+    /// parameter of the base class and is not visible here, so those blocks did not compile on
+    /// targets without SSE4.1/aarch64; since this class is fixed to RoundingMode::Trunc the
+    /// guard could never be taken anyway, so the blocks are dropped.
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) const override {
+        const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]);
+        const bool general_is_const = is_column_const(*column_general.column);
+        const bool scale_is_const =
+                arguments.size() == 2 &&
+                is_column_const(*block.get_by_position(arguments[1]).column);
+        ColumnPtr res;
+
+        // Invokes `call` with the concrete data type of the first argument; `call` returns
+        // false for unsupported types, which is turned into an InvalidArgument status.
+        auto dispatch = [&](auto& call) -> Status {
+            if (!call_on_index_and_data_type<void>(column_general.type->get_type_id(), call)) {
+                return Status::InvalidArgument("Invalid argument type {} for function {}",
+                                               column_general.type->get_name(), "truncate");
+            }
+            return Status::OK();
+        };
+
+        // potential argument types:
+        // 0. truncate(ColumnConst, ColumnConst)
+        // 1. truncate(Column), truncate(Column, ColumnConst)
+        // 2. truncate(Column, Column)
+        // 3. truncate(ColumnConst, Column)
+
+        if (general_is_const && scale_is_const) {
+            // truncate(ColumnConst, ColumnConst): compute on the nested data column of the
+            // constant, then wrap the result back into a constant column.
+            auto col_general =
+                    assert_cast<const ColumnConst&>(*column_general.column).get_data_column_ptr();
+            Int16 scale_arg = 0;
+            RETURN_IF_ERROR(FunctionTruncate<Impl>::get_scale_arg(
+                    block.get_by_position(arguments[1]), &scale_arg));
+
+            auto call = [&](const auto& types) -> bool {
+                using Types = std::decay_t<decltype(types)>;
+                using DataType = typename Types::LeftType;
+
+                if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                    using FieldType = typename DataType::FieldType;
+                    res = Dispatcher<FieldType, RoundingMode::Trunc,
+                                     TieBreakingMode::Auto>::apply_vec_const(col_general,
+                                                                             scale_arg);
+                    return true;
+                }
+
+                return false;
+            };
+            RETURN_IF_ERROR(dispatch(call));
+
+            // Important, make sure the result column has the same size as the input column
+            res = ColumnConst::create(std::move(res), input_rows_count);
+        } else if (arguments.size() == 1 || scale_is_const) {
+            // truncate(Column) or truncate(Column, ColumnConst); the scale defaults to 0 when
+            // it is not given.
+            Int16 scale_arg = 0;
+            if (arguments.size() == 2) {
+                RETURN_IF_ERROR(FunctionTruncate<Impl>::get_scale_arg(
+                        block.get_by_position(arguments[1]), &scale_arg));
+            }
+
+            auto call = [&](const auto& types) -> bool {
+                using Types = std::decay_t<decltype(types)>;
+                using DataType = typename Types::LeftType;
+
+                if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                    using FieldType = typename DataType::FieldType;
+                    res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>::
+                            apply_vec_const(column_general.column.get(), scale_arg);
+                    return true;
+                }
+
+                return false;
+            };
+            RETURN_IF_ERROR(dispatch(call));
+        } else if (general_is_const) {
+            // truncate(ColumnConst, Column)
+            const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]);
+            const ColumnConst& const_col_general =
+                    assert_cast<const ColumnConst&>(*column_general.column);
+
+            auto call = [&](const auto& types) -> bool {
+                using Types = std::decay_t<decltype(types)>;
+                using DataType = typename Types::LeftType;
+
+                if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                    using FieldType = typename DataType::FieldType;
+                    res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>::
+                            apply_const_vec(&const_col_general, column_scale.column.get());
+                    return true;
+                }
+
+                return false;
+            };
+            RETURN_IF_ERROR(dispatch(call));
+        } else {
+            // truncate(Column, Column)
+            const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]);
+
+            auto call = [&](const auto& types) -> bool {
+                using Types = std::decay_t<decltype(types)>;
+                using DataType = typename Types::LeftType;
+
+                if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                    using FieldType = typename DataType::FieldType;
+                    res = Dispatcher<FieldType, RoundingMode::Trunc, TieBreakingMode::Auto>::
+                            apply_vec_vec(column_general.column.get(), column_scale.column.get());
+                    return true;
+                }
+                return false;
+            };
+            RETURN_IF_ERROR(dispatch(call));
+        }
+
+        block.replace_by_position(result, std::move(res));
+        return Status::OK();
+    }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp
index dc815cf74e5..c0dfe761576 100644
--- a/be/src/vec/functions/math.cpp
+++ b/be/src/vec/functions/math.cpp
@@ -46,6 +46,7 @@
 #include "vec/functions/function_math_unary.h"
 #include "vec/functions/function_string.h"
 #include "vec/functions/function_totype.h"
+#include "vec/functions/function_truncate.h"
 #include "vec/functions/function_unary_arithmetic.h"
 #include "vec/functions/round.h"
 #include "vec/functions/simple_function_factory.h"
@@ -392,16 +393,14 @@ struct DecimalRoundOneImpl {
 // TODO: Now math may cause one thread compile time too long, because the 
function in math
 // so mush. Split it to speed up compile time in the future
 void register_function_math(SimpleFunctionFactory& factory) {
-#define REGISTER_ROUND_FUNCTIONS(IMPL)                                         
                  \
-    factory.register_function<                                                 
                  \
-            FunctionRounding<IMPL<RoundName>, RoundingMode::Round, 
TieBreakingMode::Auto>>();    \
-    factory.register_function<                                                 
                  \
-            FunctionRounding<IMPL<FloorName>, RoundingMode::Floor, 
TieBreakingMode::Auto>>();    \
-    factory.register_function<                                                 
                  \
-            FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil, 
TieBreakingMode::Auto>>();      \
-    factory.register_function<                                                 
                  \
-            FunctionRounding<IMPL<TruncateName>, RoundingMode::Trunc, 
TieBreakingMode::Auto>>(); \
-    factory.register_function<FunctionRounding<IMPL<RoundBankersName>, 
RoundingMode::Round,      \
+#define REGISTER_ROUND_FUNCTIONS(IMPL)                                         
               \
+    factory.register_function<                                                 
               \
+            FunctionRounding<IMPL<RoundName>, RoundingMode::Round, 
TieBreakingMode::Auto>>(); \
+    factory.register_function<                                                 
               \
+            FunctionRounding<IMPL<FloorName>, RoundingMode::Floor, 
TieBreakingMode::Auto>>(); \
+    factory.register_function<                                                 
               \
+            FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil, 
TieBreakingMode::Auto>>();   \
+    factory.register_function<FunctionRounding<IMPL<RoundBankersName>, 
RoundingMode::Round,   \
                                                TieBreakingMode::Bankers>>();
 
     REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl)
@@ -445,5 +444,9 @@ void register_function_math(SimpleFunctionFactory& factory) 
{
     factory.register_function<FunctionRadians>();
     factory.register_function<FunctionDegrees>();
     factory.register_function<FunctionBin>();
+    factory.register_function<FunctionTruncate<TruncateFloatOneArgImpl>>();
+    factory.register_function<FunctionTruncate<TruncateFloatTwoArgImpl>>();
+    factory.register_function<FunctionTruncate<TruncateDecimalOneArgImpl>>();
+    factory.register_function<FunctionTruncate<TruncateDecimalTwoArgImpl>>();
 }
 } // namespace doris::vectorized
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index 7e48b8e9306..a9d1e7a019c 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -20,8 +20,15 @@
 
 #pragma once
 
+#include <cstddef>
+#include <cstdint>
+
+#include "common/exception.h"
+#include "common/status.h"
 #include "vec/columns/column_const.h"
 #include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
 #include "vec/functions/function.h"
 #if defined(__SSE4_1__) || defined(__aarch64__)
 #include "util/sse_util.hpp"
@@ -176,6 +183,23 @@ public:
             memcpy(out.data(), in.data(), in.size() * sizeof(T));
         }
     }
+
+    // NOTE: This function is only tested for truncate
+    // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!!
+    //
+    // Rounds the single raw decimal value `in` (stored with `in_scale` fractional digits) to
+    // `out_scale` digits and writes the raw result to `out`.
+    //  - in_scale > out_scale: `in` is reduced by 10^(in_scale - out_scale); for a negative
+    //    `out_scale` the result is additionally scaled by 10^(-out_scale).
+    //  - otherwise: nothing has to be dropped and `in` is copied to `out` unchanged.
+    static NO_INLINE void apply(const NativeType& in, UInt32 in_scale, NativeType& out,
+                                Int16 out_scale) {
+        // Compute the digit difference in 32 bits. Storing it in an Int16 wraps to a negative
+        // number once in_scale - out_scale exceeds 32767 (e.g. out_scale == -32768), which
+        // would wrongly take the "copy unchanged" branch below.
+        const Int32 scale_arg = static_cast<Int32>(in_scale) - static_cast<Int32>(out_scale);
+        if (scale_arg > 0) {
+            const size_t scale = int_exp10(scale_arg);
+            const size_t target_scale =
+                    out_scale < 0 ? static_cast<size_t>(int_exp10(-out_scale)) : size_t {1};
+            Op::compute(&in, scale, &out, target_scale);
+        } else {
+            memcpy(&out, &in, sizeof(NativeType));
+        }
+    }
 };
 
 template <TieBreakingMode tie_breaking_mode>
@@ -314,6 +338,11 @@ public:
             memcpy(p_out, &tmp_dst, tail_size_bytes);
         }
     }
+
+    // Rounds the single value `in` with the given `scale` and writes the result to `out`.
+    //
+    // The value is staged through local buffers instead of handing `&in` / `&out` directly to
+    // Op::compute, mirroring how the column overload stages its tail through temporaries.
+    // NOTE(review): assumes Op::compute reads and writes Op::data_count elements per call and
+    // that Op::data_count is declared by the computation class - confirm against its definition.
+    static NO_INLINE void apply(const T& in, size_t scale, T& out) {
+        auto mm_scale = Op::prepare(scale);
+        T staged_in[Op::data_count] = {};
+        T staged_out[Op::data_count];
+        staged_in[0] = in;
+        Op::compute(staged_in, mm_scale, staged_out);
+        out = staged_out[0];
+    }
 };
 
 template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
@@ -386,6 +415,10 @@ public:
             __builtin_unreachable();
         }
     }
+
+    // Rounds the single value `in` with the given `scale` into `out`; the target scale passed
+    // to the computation is always 1.
+    static NO_INLINE void apply(const T& in, size_t scale, T& out) {
+        constexpr size_t target_scale = 1;
+        Op::compute(&in, scale, &out, target_scale);
+    }
 };
 
 /** Select the appropriate processing algorithm depending on the scale.
@@ -400,7 +433,7 @@ struct Dispatcher {
                     FloatRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>,
                     IntegerRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>>>;
 
-    static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
+    static ColumnPtr apply_vec_const(const IColumn* col_general, Int16 
scale_arg) {
         if constexpr (IsNumber<T>) {
             const auto* const col = 
check_and_get_column<ColumnVector<T>>(col_general);
             auto col_res = ColumnVector<T>::create();
@@ -446,6 +479,179 @@ struct Dispatcher {
             return nullptr;
         }
     }
+
+    // NOTE: This function is only tested for truncate
+    // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!!
+    //
+    // Row-wise rounding where both the value and the scale come from full columns: row i of
+    // the result is row i of `col_general` processed with the scale found in row i of
+    // `col_scale`. `col_scale` must be a ColumnInt32.
+    // Throws INVALID_ARGUMENT for any rounding mode other than Trunc, and OUT_OF_BOUND (before
+    // any row is processed) when a scale does not fit into Int16.
+    // Decimal results keep the scale of the input column.
+    static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) {
+        if constexpr (rounding_mode != RoundingMode::Trunc) {
+            throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+                                   "Using column as scale is only supported for function truncate");
+        }
+
+        const ColumnInt32& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
+        const auto& scales = col_scale_i32.get_data();
+        const size_t input_row_count = col_scale_i32.size();
+        for (size_t i = 0; i < input_row_count; ++i) {
+            const Int32 scale_arg = scales[i];
+            if (scale_arg > std::numeric_limits<Int16>::max() ||
+                scale_arg < std::numeric_limits<Int16>::min()) {
+                throw doris::Exception(ErrorCode::OUT_OF_BOUND,
+                                       "Scale argument for function is out of bound: {}",
+                                       scale_arg);
+            }
+        }
+
+        if constexpr (IsNumber<T>) {
+            const auto* col = assert_cast<const ColumnVector<T>*>(col_general);
+            const auto& values = col->get_data();
+            auto col_res = ColumnVector<T>::create();
+            typename ColumnVector<T>::Container& vec_res = col_res->get_data();
+            vec_res.resize(input_row_count);
+
+            // The scale mode is a template argument, so it is selected per row from the sign
+            // of that row's scale.
+            for (size_t i = 0; i < input_row_count; ++i) {
+                const Int32 scale_arg = scales[i];
+                if (scale_arg == 0) {
+                    const size_t scale = 1;
+                    FunctionRoundingImpl<ScaleMode::Zero>::apply(values[i], scale, vec_res[i]);
+                } else if (scale_arg > 0) {
+                    const size_t scale = int_exp10(scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Positive>::apply(values[i], scale,
+                                                                     vec_res[i]);
+                } else {
+                    const size_t scale = int_exp10(-scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Negative>::apply(values[i], scale,
+                                                                     vec_res[i]);
+                }
+            }
+            return col_res;
+        } else if constexpr (IsDecimalNumber<T>) {
+            const auto* decimal_col = assert_cast<const ColumnDecimal<T>*>(col_general);
+
+            // For truncate, ALWAYS use SAME scale with source Decimal column
+            const Int32 input_scale = decimal_col->get_scale();
+            auto col_res = ColumnDecimal<T>::create(input_row_count, input_scale);
+            // Loop invariant: factor that restores all `input_scale` fractional digits.
+            const auto full_rescale = int_exp10(input_scale);
+
+            for (size_t i = 0; i < input_row_count; ++i) {
+                const Int32 scale_arg = scales[i];
+                auto& out_value = col_res->get_element(i).value;
+                DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
+                        decimal_col->get_element(i).value, input_scale, out_value, scale_arg);
+
+                // The result column keeps the scale of the source Decimal column, so the raw
+                // value produced above has to be scaled back to `input_scale` digits:
+                //
+                // scale_arg <= 0:
+                //      all fractional digits were dropped (the value is 0 when scale_arg
+                //      removes every integer digit as well), so multiply by 10^(input_scale)
+                // 0 < scale_arg < input_scale:
+                //      scale_arg fractional digits are left, so multiply by
+                //      10^(input_scale - scale_arg)
+                // scale_arg >= input_scale:
+                //      nothing was dropped, do nothing
+                if (scale_arg <= 0) {
+                    out_value *= full_rescale;
+                } else if (scale_arg < input_scale) {
+                    out_value *= int_exp10(input_scale - scale_arg);
+                }
+            }
+
+            return col_res;
+        } else {
+            LOG(FATAL) << "__builtin_unreachable";
+            __builtin_unreachable();
+            return nullptr;
+        }
+    }
+
+    // NOTE: This function is only tested for truncate
+    // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!!
+    //
+    // Rounds one constant value with a different scale per row: row i of the result is the
+    // value held by `const_col_general` processed with the scale found in row i of
+    // `col_scale`. `col_scale` must be a ColumnInt32; the result has as many rows as it.
+    // Throws INVALID_ARGUMENT for any rounding mode other than Trunc or for an unsupported
+    // value type, and OUT_OF_BOUND (before any row is processed) when a scale does not fit
+    // into Int16.
+    // Decimal results keep the scale of the input column.
+    static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
+                                     const IColumn* col_scale) {
+        if constexpr (rounding_mode != RoundingMode::Trunc) {
+            throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+                                   "Using column as scale is only supported for function truncate");
+        }
+
+        const ColumnInt32& col_scale_i32 = assert_cast<const ColumnInt32&>(*col_scale);
+        const auto& scales = col_scale_i32.get_data();
+        const size_t input_rows_count = col_scale->size();
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            const Int32 scale_arg = scales[i];
+
+            if (scale_arg > std::numeric_limits<Int16>::max() ||
+                scale_arg < std::numeric_limits<Int16>::min()) {
+                throw doris::Exception(ErrorCode::OUT_OF_BOUND,
+                                       "Scale argument for function is out of bound: {}",
+                                       scale_arg);
+            }
+        }
+
+        if constexpr (IsDecimalNumber<T>) {
+            const ColumnDecimal<T>& data_col_general =
+                    assert_cast<const ColumnDecimal<T>&>(const_col_general->get_data_column());
+            const T& general_val = data_col_general.get_data()[0];
+            const Int32 input_scale = data_col_general.get_scale();
+
+            auto col_res = ColumnDecimal<T>::create(input_rows_count, input_scale);
+            // Loop invariant: factor that restores all `input_scale` fractional digits.
+            const auto full_rescale = int_exp10(input_scale);
+
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                const Int32 scale_arg = scales[i];
+                auto& out_value = col_res->get_element(i).value;
+                DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(
+                        general_val, input_scale, out_value, scale_arg);
+
+                // The result column keeps the scale of the source Decimal column, so the raw
+                // value produced above has to be scaled back to `input_scale` digits:
+                //
+                // scale_arg <= 0:
+                //      all fractional digits were dropped (the value is 0 when scale_arg
+                //      removes every integer digit as well), so multiply by 10^(input_scale)
+                // 0 < scale_arg < input_scale:
+                //      scale_arg fractional digits are left, so multiply by
+                //      10^(input_scale - scale_arg)
+                // scale_arg >= input_scale:
+                //      nothing was dropped, do nothing
+                if (scale_arg <= 0) {
+                    out_value *= full_rescale;
+                } else if (scale_arg < input_scale) {
+                    out_value *= int_exp10(input_scale - scale_arg);
+                }
+            }
+
+            return col_res;
+        } else if constexpr (IsNumber<T>) {
+            const ColumnVector<T>& data_col_general =
+                    assert_cast<const ColumnVector<T>&>(const_col_general->get_data_column());
+            const T& general_val = data_col_general.get_data()[0];
+            auto col_res = ColumnVector<T>::create(input_rows_count);
+            typename ColumnVector<T>::Container& vec_res = col_res->get_data();
+
+            // The scale mode is a template argument, so it is selected per row from the sign
+            // of that row's scale. Every scale was verified above to fit into Int16.
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                const Int16 scale_arg = scales[i];
+                if (scale_arg == 0) {
+                    const size_t scale = 1;
+                    FunctionRoundingImpl<ScaleMode::Zero>::apply(general_val, scale, vec_res[i]);
+                } else if (scale_arg > 0) {
+                    const size_t scale = int_exp10(scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Positive>::apply(general_val, scale,
+                                                                     vec_res[i]);
+                } else {
+                    const size_t scale = int_exp10(-scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Negative>::apply(general_val, scale,
+                                                                     vec_res[i]);
+                }
+            }
+
+            return col_res;
+        } else {
+            throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
+                                   "Unsupported column {} for function truncate",
+                                   const_col_general->get_name());
+        }
+    }
 };
 
 template <typename Impl, RoundingMode rounding_mode, TieBreakingMode 
tie_breaking_mode>
@@ -476,17 +682,17 @@ public:
     static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16* 
scale) {
         const IColumn& scale_column = *arguments.column;
 
-        Int32 scale64 = static_cast<const ColumnInt32&>(
-                                static_cast<const 
ColumnConst*>(&scale_column)->get_data_column())
-                                .get_element(0);
+        Int32 scale_arg = assert_cast<const ColumnInt32&>(
+                                  assert_cast<const 
ColumnConst*>(&scale_column)->get_data_column())
+                                  .get_element(0);
 
-        if (scale64 > std::numeric_limits<Int16>::max() ||
-            scale64 < std::numeric_limits<Int16>::min()) {
+        if (scale_arg > std::numeric_limits<Int16>::max() ||
+            scale_arg < std::numeric_limits<Int16>::min()) {
             return Status::InvalidArgument("Scale argument for function {} is 
out of bound: {}",
-                                           name, scale64);
+                                           name, scale_arg);
         }
 
-        *scale = scale64;
+        *scale = scale_arg;
         return Status::OK();
     }
 
@@ -507,7 +713,7 @@ public:
 
             if constexpr (IsDataTypeNumber<DataType> || 
IsDataTypeDecimal<DataType>) {
                 using FieldType = typename DataType::FieldType;
-                res = Dispatcher<FieldType, rounding_mode, 
tie_breaking_mode>::apply(
+                res = Dispatcher<FieldType, rounding_mode, 
tie_breaking_mode>::apply_vec_const(
                         column.column.get(), scale_arg);
                 return true;
             }
diff --git a/be/test/vec/function/function_truncate_decimal_test.cpp 
b/be/test/vec/function/function_truncate_decimal_test.cpp
new file mode 100644
index 00000000000..36fcaa14e67
--- /dev/null
+++ b/be/test/vec/function/function_truncate_decimal_test.cpp
@@ -0,0 +1,370 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest-message.h>
+#include <gtest/gtest.h>
+
+#include <climits>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <iomanip>
+#include <limits>
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "function_test_util.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_const.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/column_numbers.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/functions/function_truncate.h"
+
+namespace doris::vectorized {
+// Truncate test fixtures: {precision, scale} -> {input, scale_arg, expectation}.
+// The key is the decimal type of the input column. `input` and `expectation` are
+// raw (unscaled) decimal values: with key {2, 1} the input 12 stands for 1.2, and
+// truncating it with scale_arg 0 is expected to give 10, i.e. 1.0.
+using TestDataSet = std::map<std::pair<int, int>, 
std::vector<std::tuple<Int128, int, Int128>>>;
+
+const static TestDataSet truncate_decimal32_cases = {
+        {{1, 0},
+         {
+                 {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, 
{1, -5, 0},
+                 {1, -4, 0},  {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1},  
{1, 1, 1},
+                 {1, 2, 1},   {1, 3, 1},  {1, 4, 1},  {1, 5, 1},  {1, 6, 1},  
{1, 7, 1},
+                 {1, 8, 1},   {1, 9, 1},  {1, 10, 1},
+         }},
+        {{1, 1},
+         {
+                 {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, 
{1, -5, 0},
+                 {1, -4, 0},  {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0},  
{1, 1, 1},
+                 {1, 2, 1},   {1, 3, 1},  {1, 4, 1},  {1, 5, 1},  {1, 6, 1},  
{1, 7, 1},
+                 {1, 8, 1},   {1, 9, 1},  {1, 10, 1},
+         }},
+        {{2, 0},
+         {
+                 {12, -4, 0},
+                 {12, -3, 0},
+                 {12, -2, 0},
+                 {12, -1, 10},
+                 {12, 0, 12},
+                 {12, 1, 12},
+                 {12, 2, 12},
+                 {12, 3, 12},
+                 {12, 4, 12},
+         }},
+        {{2, 1},
+         {
+                 {12, -4, 0},
+                 {12, -3, 0},
+                 {12, -2, 0},
+                 {12, -1, 0},
+                 {12, 0, 10},
+                 {12, 1, 12},
+                 {12, 2, 12},
+                 {12, 3, 12},
+                 {12, 4, 12},
+         }},
+        {{2, 2},
+         {
+                 {12, -4, 0},
+                 {12, -3, 0},
+                 {12, -2, 0},
+                 {12, -1, 0},
+                 {12, 0, 0},
+                 {12, 1, 10},
+                 {12, 2, 12},
+                 {12, 3, 12},
+                 {12, 4, 12},
+         }},
+        {{9, 0},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 100000000},
+                 {123456789, -7, 120000000}, {123456789, -6, 123000000}, 
{123456789, -5, 123400000},
+                 {123456789, -4, 123450000}, {123456789, -3, 123456000}, 
{123456789, -2, 123456700},
+                 {123456789, -1, 123456780}, {123456789, 0, 123456789},  
{123456789, 1, 123456789},
+                 {123456789, 2, 123456789},  {123456789, 3, 123456789},  
{123456789, 4, 123456789},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 1},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 0},
+                 {123456789, -7, 100000000}, {123456789, -6, 120000000}, 
{123456789, -5, 123000000},
+                 {123456789, -4, 123400000}, {123456789, -3, 123450000}, 
{123456789, -2, 123456000},
+                 {123456789, -1, 123456700}, {123456789, 0, 123456780},  
{123456789, 1, 123456789},
+                 {123456789, 2, 123456789},  {123456789, 3, 123456789},  
{123456789, 4, 123456789},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 2},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 100000000}, 
{123456789, -5, 120000000},
+                 {123456789, -4, 123000000}, {123456789, -3, 123400000}, 
{123456789, -2, 123450000},
+                 {123456789, -1, 123456000}, {123456789, 0, 123456700},  
{123456789, 1, 123456780},
+                 {123456789, 2, 123456789},  {123456789, 3, 123456789},  
{123456789, 4, 123456789},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 3},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 0},         
{123456789, -5, 100000000},
+                 {123456789, -4, 120000000}, {123456789, -3, 123000000}, 
{123456789, -2, 123400000},
+                 {123456789, -1, 123450000}, {123456789, 0, 123456000},  
{123456789, 1, 123456700},
+                 {123456789, 2, 123456780},  {123456789, 3, 123456789},  
{123456789, 4, 123456789},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 4},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 0},         
{123456789, -5, 0},
+                 {123456789, -4, 100000000}, {123456789, -3, 120000000}, 
{123456789, -2, 123000000},
+                 {123456789, -1, 123400000}, {123456789, 0, 123450000},  
{123456789, 1, 123456000},
+                 {123456789, 2, 123456700},  {123456789, 3, 123456780},  
{123456789, 4, 123456789},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 5},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},         
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 0},         
{123456789, -5, 0},
+                 {123456789, -4, 0},         {123456789, -3, 100000000}, 
{123456789, -2, 120000000},
+                 {123456789, -1, 123000000}, {123456789, 0, 123400000},  
{123456789, 1, 123450000},
+                 {123456789, 2, 123456000},  {123456789, 3, 123456700},  
{123456789, 4, 123456780},
+                 {123456789, 5, 123456789},  {123456789, 6, 123456789},  
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789},  
{123456789, 10, 123456789},
+         }},
+        {{9, 6},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},        
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 0},        
{123456789, -5, 0},
+                 {123456789, -4, 0},         {123456789, -3, 0},        
{123456789, -2, 100000000},
+                 {123456789, -1, 120000000}, {123456789, 0, 123000000}, 
{123456789, 1, 123400000},
+                 {123456789, 2, 123450000},  {123456789, 3, 123456000}, 
{123456789, 4, 123456700},
+                 {123456789, 5, 123456780},  {123456789, 6, 123456789}, 
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789}, 
{123456789, 10, 123456789},
+         }},
+        {{9, 7},
+         {
+                 {123456789, -10, 0},        {123456789, -9, 0},        
{123456789, -8, 0},
+                 {123456789, -7, 0},         {123456789, -6, 0},        
{123456789, -5, 0},
+                 {123456789, -4, 0},         {123456789, -3, 0},        
{123456789, -2, 0},
+                 {123456789, -1, 100000000}, {123456789, 0, 120000000}, 
{123456789, 1, 123000000},
+                 {123456789, 2, 123400000},  {123456789, 3, 123450000}, 
{123456789, 4, 123456000},
+                 {123456789, 5, 123456700},  {123456789, 6, 123456780}, 
{123456789, 7, 123456789},
+                 {123456789, 8, 123456789},  {123456789, 9, 123456789}, 
{123456789, 10, 123456789},
+         }},
+        {{9, 8},
+         {
+                 {123456789, -10, 0},       {123456789, -9, 0},        
{123456789, -8, 0},
+                 {123456789, -7, 0},        {123456789, -6, 0},        
{123456789, -5, 0},
+                 {123456789, -4, 0},        {123456789, -3, 0},        
{123456789, -2, 0},
+                 {123456789, -1, 0},        {123456789, 0, 100000000}, 
{123456789, 1, 120000000},
+                 {123456789, 2, 123000000}, {123456789, 3, 123400000}, 
{123456789, 4, 123450000},
+                 {123456789, 5, 123456000}, {123456789, 6, 123456700}, 
{123456789, 7, 123456780},
+                 {123456789, 8, 123456789}, {123456789, 9, 123456789}, 
{123456789, 10, 123456789},
+         }},
+        {{9, 9},
+         {
+                 {123456789, -10, 0},       {123456789, -9, 0},        
{123456789, -8, 0},
+                 {123456789, -7, 0},        {123456789, -6, 0},        
{123456789, -5, 0},
+                 {123456789, -4, 0},        {123456789, -3, 0},        
{123456789, -2, 0},
+                 {123456789, -1, 0},        {123456789, 0, 0},         
{123456789, 1, 100000000},
+                 {123456789, 2, 120000000}, {123456789, 3, 123000000}, 
{123456789, 4, 123400000},
+                 {123456789, 5, 123450000}, {123456789, 6, 123456000}, 
{123456789, 7, 123456700},
+                 {123456789, 8, 123456780}, {123456789, 9, 123456789}, 
{123456789, 10, 123456789},
+         }}};
+
+const static TestDataSet truncate_decimal64_cases = {
+        {{10, 0},
+         {{1234567891, -11, 0},         {1234567891, -10, 0},         
{1234567891, -9, 1000000000},
+          {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, 
{1234567891, -6, 1234000000},
+          {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000}, 
{1234567891, -3, 1234567000},
+          {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890}, 
{1234567891, 0, 1234567891},
+          {1234567891, 1, 1234567891},  {1234567891, 2, 1234567891},  
{1234567891, 3, 1234567891},
+          {1234567891, 4, 1234567891},  {1234567891, 5, 1234567891},  
{1234567891, 6, 1234567891},
+          {1234567891, 7, 1234567891},  {1234567891, 8, 1234567891},  
{1234567891, 9, 1234567891},
+          {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+        {{10, 1},
+         {{1234567891, -11, 0},         {1234567891, -10, 0},         
{1234567891, -9, 0},
+          {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, 
{1234567891, -6, 1230000000},
+          {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000}, 
{1234567891, -3, 1234560000},
+          {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800}, 
{1234567891, 0, 1234567890},
+          {1234567891, 1, 1234567891},  {1234567891, 2, 1234567891},  
{1234567891, 3, 1234567891},
+          {1234567891, 4, 1234567891},  {1234567891, 5, 1234567891},  
{1234567891, 6, 1234567891},
+          {1234567891, 7, 1234567891},  {1234567891, 8, 1234567891},  
{1234567891, 9, 1234567891},
+          {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}
+
+         }},
+        {{10, 2},
+         {{1234567891, -11, 0},         {1234567891, -10, 0},         
{1234567891, -9, 0},
+          {1234567891, -8, 0},          {1234567891, -7, 1000000000}, 
{1234567891, -6, 1200000000},
+          {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000}, 
{1234567891, -3, 1234500000},
+          {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000}, 
{1234567891, 0, 1234567800},
+          {1234567891, 1, 1234567890},  {1234567891, 2, 1234567891},  
{1234567891, 3, 1234567891},
+          {1234567891, 4, 1234567891},  {1234567891, 5, 1234567891},  
{1234567891, 6, 1234567891},
+          {1234567891, 7, 1234567891},  {1234567891, 8, 1234567891},  
{1234567891, 9, 1234567891},
+          {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+        {{10, 9},
+         {{1234567891, -11, 0},         {1234567891, -10, 0},        
{1234567891, -9, 0},
+          {1234567891, -8, 0},          {1234567891, -7, 0},         
{1234567891, -6, 0},
+          {1234567891, -5, 0},          {1234567891, -4, 0},         
{1234567891, -3, 0},
+          {1234567891, -2, 0},          {1234567891, -1, 0},         
{1234567891, 0, 1000000000},
+          {1234567891, 1, 1200000000},  {1234567891, 2, 1230000000}, 
{1234567891, 3, 1234000000},
+          {1234567891, 4, 1234500000},  {1234567891, 5, 1234560000}, 
{1234567891, 6, 1234567000},
+          {1234567891, 7, 1234567800},  {1234567891, 8, 1234567890}, 
{1234567891, 9, 1234567891},
+          {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}},
+        {{18, 0},
+         {{123456789123456789, -19, 0},
+          {123456789123456789, -18, 0},
+          {123456789123456789, -17, 100000000000000000},
+          {123456789123456789, -16, 120000000000000000},
+          {123456789123456789, -15, 123000000000000000},
+          {123456789123456789, -14, 123400000000000000},
+          {123456789123456789, -13, 123450000000000000},
+          {123456789123456789, -12, 123456000000000000},
+          {123456789123456789, -11, 123456700000000000},
+          {123456789123456789, -10, 123456780000000000},
+          {123456789123456789, -9, 123456789000000000},
+          {123456789123456789, -8, 123456789100000000},
+          {123456789123456789, -7, 123456789120000000},
+          {123456789123456789, -6, 123456789123000000},
+          {123456789123456789, -5, 123456789123400000},
+          {123456789123456789, -4, 123456789123450000},
+          {123456789123456789, -3, 123456789123456000},
+          {123456789123456789, -2, 123456789123456700},
+          {123456789123456789, -1, 123456789123456780},
+          {123456789123456789, 0, 123456789123456789},
+          {123456789123456789, 1, 123456789123456789},
+          {123456789123456789, 2, 123456789123456789},
+          {123456789123456789, 3, 123456789123456789},
+          {123456789123456789, 4, 123456789123456789},
+          {123456789123456789, 5, 123456789123456789},
+          {123456789123456789, 6, 123456789123456789},
+          {123456789123456789, 7, 123456789123456789},
+          {123456789123456789, 8, 123456789123456789},
+          {123456789123456789, 18, 123456789123456789}}},
+        {{18, 18},
+         {{123456789123456789, -1, 0},
+          {123456789123456789, 0, 0},
+          {123456789123456789, 1, 100000000000000000},
+          {123456789123456789, 2, 120000000000000000},
+          {123456789123456789, 3, 123000000000000000},
+          {123456789123456789, 4, 123400000000000000},
+          {123456789123456789, 5, 123450000000000000},
+          {123456789123456789, 6, 123456000000000000},
+          {123456789123456789, 7, 123456700000000000},
+          {123456789123456789, 8, 123456780000000000},
+          {123456789123456789, 9, 123456789000000000},
+          {123456789123456789, 10, 123456789100000000},
+          {123456789123456789, 11, 123456789120000000},
+          {123456789123456789, 12, 123456789123000000},
+          {123456789123456789, 13, 123456789123400000},
+          {123456789123456789, 14, 123456789123450000},
+          {123456789123456789, 15, 123456789123456000},
+          {123456789123456789, 16, 123456789123456700},
+          {123456789123456789, 17, 123456789123456780},
+          {123456789123456789, 18, 123456789123456789},
+          {123456789123456789, 19, 123456789123456789},
+          {123456789123456789, 20, 123456789123456789},
+          {123456789123456789, 21, 123456789123456789},
+          {123456789123456789, 22, 123456789123456789},
+          {123456789123456789, 23, 123456789123456789},
+          {123456789123456789, 24, 123456789123456789},
+          {123456789123456789, 25, 123456789123456789},
+          {123456789123456789, 26, 123456789123456789}}}};
+
+// Runs FuncType over every {precision, scale} group of `truncate_test_cases` and
+// compares each output row with the expected value stored in the fixture.
+//
+// FuncType:             function class under test, built through FuncType::create().
+// DecimalType:          decimal type of the input, expected and result columns.
+// truncate_test_cases:  fixtures keyed by the input decimal's {precision, scale}.
+// decimal_col_is_const: when true the decimal argument is passed as a ColumnConst
+//                       wrapping the first row, otherwise as a full column.
+//
+// Setup and execution failures use ASSERT_* so that a broken run is reported as a
+// test failure instead of dereferencing a missing function or result column. Note
+// that a fatal assertion only returns from this helper, not from the calling TEST.
+template <typename FuncType, typename DecimalType>
+static void checker(const TestDataSet& truncate_test_cases, bool decimal_col_is_const) {
+    static_assert(IsDecimalNumber<DecimalType>);
+    auto func = std::dynamic_pointer_cast<FuncType>(FuncType::create());
+    ASSERT_TRUE(func != nullptr);
+    FunctionContext* context = nullptr;
+
+    for (const auto& test_case : truncate_test_cases) {
+        Block block;
+        size_t res_idx = 2;
+        ColumnNumbers arguments = {0, 1, 2};
+        const int precision = test_case.first.first;
+        const int scale = test_case.first.second;
+        const size_t input_rows_count = test_case.second.size();
+        auto col_general = ColumnDecimal<DecimalType>::create(input_rows_count, scale);
+        auto col_scale = ColumnInt32::create();
+        auto col_res_expected = ColumnDecimal<DecimalType>::create(input_rows_count, scale);
+        size_t rid = 0;
+
+        // Fill the input, scale and expected columns row by row from the fixture.
+        for (const auto& test_data : test_case.second) {
+            auto input = std::get<0>(test_data);
+            auto scale_arg = std::get<1>(test_data);
+            auto expectation = std::get<2>(test_data);
+            col_general->get_element(rid) = DecimalType(input);
+            col_scale->insert(scale_arg);
+            col_res_expected->get_element(rid) = DecimalType(expectation);
+            rid++;
+        }
+
+        if (decimal_col_is_const) {
+            // Only the first row is wrapped; every row of a fixture group shares the
+            // same input value, so the expectations still apply.
+            // NOTE(review): the const column is created with size 1 while
+            // input_rows_count rows are processed - confirm this is intended.
+            block.insert({ColumnConst::create(col_general->clone_resized(1), 1),
+                          std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                          "col_general_const"});
+        } else {
+            block.insert({col_general->clone(),
+                          std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                          "col_general"});
+        }
+
+        block.insert({col_scale->clone(), std::make_shared<DataTypeInt32>(), "col_scale"});
+        block.insert({nullptr, std::make_shared<DataTypeDecimal<DecimalType>>(precision, scale),
+                      "col_res"});
+
+        auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count);
+        // Check the status BEFORE touching the result: the result slot was inserted
+        // as nullptr and may still be empty when execution fails.
+        ASSERT_TRUE(status.ok()) << "precision " << precision << " input_scale " << scale
+                                 << " status " << status.to_string();
+        const auto& res_column = block.get_by_position(res_idx).column;
+        ASSERT_TRUE(res_column.get() != nullptr);
+        // Bind by reference; `auto` would copy the whole result column.
+        const auto& col_res = assert_cast<const ColumnDecimal<DecimalType>&>(*res_column);
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            auto res = col_res.get_element(i);
+            auto res_expected = col_res_expected->get_element(i);
+            EXPECT_EQ(res, res_expected)
+                    << "precision " << precision << " input_scale " << scale << " input "
+                    << col_general->get_element(i) << " scale_arg " << col_scale->get_element(i)
+                    << " res " << res << " res_expected " << res_expected;
+        }
+    }
+}
+// Decimal argument passed as a full column (decimal_col_is_const == false),
+// checked against the Decimal32 and the Decimal64 fixture tables.
+TEST(TruncateFunctionTest, normal_decimal) {
+    checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, 
Decimal32>(truncate_decimal32_cases,
+                                                                    false);
+    checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, 
Decimal64>(truncate_decimal64_cases,
+                                                                    false);
+}
+
+// Same fixture tables as normal_decimal, but the decimal argument is passed as a
+// constant column (decimal_col_is_const == true).
+TEST(TruncateFunctionTest, normal_decimal_const) {
+    checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, 
Decimal32>(truncate_decimal32_cases, true);
+    checker<FunctionTruncate<TruncateDecimalTwoArgImpl>, 
Decimal64>(truncate_decimal64_cases, true);
+}
+
+} // namespace doris::vectorized
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index b5184c33fcd..9bc857bacef 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -122,7 +122,7 @@ public class FunctionCallExpr extends Expr {
                 Preconditions.checkArgument(children.get(1) instanceof 
IntLiteral
                         || (children.get(1) instanceof CastExpr
                                 && children.get(1).getChild(0) instanceof 
IntLiteral),
-                        "2nd argument of function round/floor/ceil/truncate 
must be literal");
+                        "2nd argument of function round/floor/ceil must be 
literal");
                 if (children.get(1) instanceof CastExpr && 
children.get(1).getChild(0) instanceof IntLiteral) {
                     
children.get(1).getChild(0).setType(children.get(1).getType());
                     children.set(1, children.get(1).getChild(0));
@@ -136,6 +136,34 @@ public class FunctionCallExpr extends Expr {
                 return returnType;
             }
         };
+
+        // Infers the DecimalV3 return type of truncate():
+        //   - truncate(decimal): same precision, scale 0;
+        //   - truncate(x, literal): same precision, scale = literal clamped to [0, scale of x];
+        //   - truncate(x, non-literal): same precision and scale as x;
+        //   - any other arity: the declared return type is kept.
+        java.util.function.BiFunction<ArrayList<Expr>, Type, Type> truncateRule = (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() > 0);
+            if (children.size() == 1 && children.get(0).getType().isDecimalV3()) {
+                return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0);
+            }
+            if (children.size() != 2) {
+                return returnType;
+            }
+
+            Expr scaleExpr = children.get(1);
+            boolean isCastOfIntLiteral = scaleExpr instanceof CastExpr
+                    && scaleExpr.getChild(0) instanceof IntLiteral;
+            if (!(scaleExpr instanceof IntLiteral) && !isCastOfIntLiteral) {
+                // The scale argument is not a literal (e.g. a column), so the result
+                // keeps the scale of the input decimal.
+                return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(),
+                        ((ScalarType) children.get(0).getType()).decimalScale());
+            }
+
+            if (isCastOfIntLiteral) {
+                // Unwrap the cast: the literal takes over the cast's type and replaces
+                // the cast as the second child.
+                Expr literal = scaleExpr.getChild(0);
+                literal.setType(scaleExpr.getType());
+                children.set(1, literal);
+            } else {
+                scaleExpr.setType(Type.INT);
+            }
+            int scaleArg = (int) (((IntLiteral) children.get(1)).getValue());
+            return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(),
+                    Math.min(Math.max(scaleArg, 0), ((ScalarType) children.get(0).getType()).decimalScale()));
+        };
+
         java.util.function.BiFunction<ArrayList<Expr>, Type, Type> 
arrayDateTimeV2OrDecimalV3Rule
                 = (children, returnType) -> {
                     Preconditions.checkArgument(children != null && 
children.size() > 0);
@@ -239,7 +267,7 @@ public class FunctionCallExpr extends Expr {
         PRECISION_INFER_RULE.put("dround", roundRule);
         PRECISION_INFER_RULE.put("dceil", roundRule);
         PRECISION_INFER_RULE.put("dfloor", roundRule);
-        PRECISION_INFER_RULE.put("truncate", roundRule);
+        PRECISION_INFER_RULE.put("truncate", truncateRule);
     }
 
     public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION = 
new ImmutableSortedSet.Builder(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
index 4b57772ed23..6b6308c516c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions;
 import org.apache.doris.catalog.FunctionSignature;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.coercion.Int32OrLessType;
@@ -37,19 +38,38 @@ public interface ComputePrecisionForRound extends 
ComputePrecision {
         } else if (arity() == 2 && signature.getArgType(0) instanceof 
DecimalV3Type) {
             DecimalV3Type decimalV3Type = 
DecimalV3Type.forType(getArgumentType(0));
             Expression floatLength = getArgument(1);
-            Preconditions.checkArgument(floatLength.getDataType() instanceof 
Int32OrLessType
-                    && (floatLength.isLiteral() || (
-                            floatLength instanceof Cast && 
floatLength.child(0).isLiteral()
-                                    && floatLength.child(0).getDataType() 
instanceof Int32OrLessType)),
-                    "2nd argument of function round/floor/ceil/truncate must 
be literal");
-
             int scale;
-            if (floatLength instanceof Cast) {
-                scale = ((IntegerLikeLiteral) 
floatLength.child(0)).getIntValue();
+
+            if (this instanceof Truncate) {
+                if (floatLength.isLiteral() || (
+                        floatLength instanceof Cast && 
floatLength.child(0).isLiteral()
+                                && floatLength.child(0).getDataType() 
instanceof Int32OrLessType)) {
+                    // Scale argument is a literal or cast from other literal
+                    if (floatLength instanceof Cast) {
+                        scale = ((IntegerLikeLiteral) 
floatLength.child(0)).getIntValue();
+                    } else {
+                        scale = ((IntegerLikeLiteral) 
floatLength).getIntValue();
+                    }
+                    scale = Math.min(Math.max(scale, 0), 
decimalV3Type.getScale());
+                } else {
+                    // Truncate accepts a column as its scale argument.
+                    // In that case the result scale is always the same as the input Decimal's scale.
+                    scale = decimalV3Type.getScale();
+                }
             } else {
-                scale = ((IntegerLikeLiteral) floatLength).getIntValue();
+                Preconditions.checkArgument(floatLength.getDataType() 
instanceof Int32OrLessType
+                                && (floatLength.isLiteral() || (
+                                floatLength instanceof Cast && 
floatLength.child(0).isLiteral()
+                                        && floatLength.child(0).getDataType() 
instanceof Int32OrLessType)),
+                        "2nd argument of function round/floor/ceil must be 
literal");
+                if (floatLength instanceof Cast) {
+                    scale = ((IntegerLikeLiteral) 
floatLength.child(0)).getIntValue();
+                } else {
+                    scale = ((IntegerLikeLiteral) floatLength).getIntValue();
+                }
+                scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale());
             }
-            scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale());
+
             return signature.withArgumentType(0, decimalV3Type)
                     
.withReturnType(DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision(), 
scale));
         } else {
diff --git 
a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
 
b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
new file mode 100644
index 00000000000..24f675ffbe2
--- /dev/null
+++ 
b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out
@@ -0,0 +1,101 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !sql --
+0      123.3
+1      123.3
+2      123.3
+3      123.3
+4      123.3
+5      123.3
+6      123.3
+7      123.3
+8      123.3
+9      123.3
+
+-- !sql --
+0      120
+1      120
+2      120
+3      120
+4      120
+5      120
+6      120
+7      120
+8      120
+9      120
+
+-- !sql --
+0      123
+1      123
+2      123
+3      123
+4      123
+5      123
+6      123
+7      123
+8      123
+9      123
+
+-- !sql --
+0E-8
+
+-- !sql --
+0      0.0
+1      0.0
+2      0.0
+3      0.0
+4      0.0
+
+-- !vec_const0 --
+1      12345.0 1.23456789E8
+2      12345.0 1.23456789E8
+3      12345.0 1.23456789E8
+4      0.0     0.0
+
+-- !vec_const0 --
+1      12345.1 1.234567891E8
+2      12345.1 1.234567891E8
+3      12345.1 1.234567891E8
+4      0.0     0.0
+
+-- !vec_const0 --
+1      12340.0 1.2345678E8
+2      12340.0 1.2345678E8
+3      12340.0 1.2345678E8
+4      0.0     0.0
+
+-- !vec_const1 --
+1      123456789       123456789       12345678.1      12345678        
0.123456789     0
+2      123456789       123456789       12345678.1      12345678        
0.123456789     0
+3      123456789       123456789       12345678.1      12345678        
0.123456789     0
+4      0       0       0.0     0       0E-9    0
+
+-- !vec_const2 --
+1      123456789       123456789       1.123456789     1       0.1234567890    0
+2      123456789       123456789       1.123456789     1       0.1234567890    0
+3      123456789       123456789       1.123456789     1       0.1234567890    0
+4      0       0       0E-9    0       0E-10   0
+
+-- !const_vec1 --
+123456789.123456789    1       123456789.100000000
+123456789.123456789    1       123456789.100000000
+123456789.123456789    1       123456789.100000000
+123456789.123456789    1       123456789.100000000
+
+-- !const_vec2 --
+123456789.123456789    -1      123456780.000000000
+123456789.123456789    -1      123456780.000000000
+123456789.123456789    -1      123456780.000000000
+123456789.123456789    -1      123456780.000000000
+
+-- !vec_vec0 --
+1      1       12345.1 1.234567891E8
+2      1       12345.1 1.234567891E8
+3      1       12345.1 1.234567891E8
+4      1       0.0     0.0
+
+-- !truncate_dec128 --
+1      1234567891234567891     1234567891234567891     1234567891.123456789    
1234567891      0.1234567891234567891   0
+
+-- !truncate_dec128 --
+1      1234567891234567891     1234567891234567891     1234567891.123456789    
1234567891.100000000    0.1234567891234567891   0.1000000000000000000
+
diff --git 
a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
 
b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
new file mode 100644
index 00000000000..767140e7a6f
--- /dev/null
+++ 
b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_function_truncate") {
+    qt_sql """
+        SELECT number, truncate(123.345 , 1) FROM numbers("number"="10");
+    """
+    qt_sql """
+        SELECT number, truncate(123.123, -1) FROM numbers("number"="10");
+    """
+    qt_sql """
+        SELECT number, truncate(123.123, 0) FROM numbers("number"="10");
+    """
+
+    // const_const: value and scale arguments are both constants; result scale should be 10
+    qt_sql """
+        SELECT truncate(cast(0 as Decimal(9,8)), 10);
+    """  
+
+    // const_const: value and scale arguments are both constants; result scale should be 1
+    qt_sql """
+        SELECT number, truncate(cast(0 as Decimal(9,4)), 1) FROM 
numbers("number"="5")
+    """
+
+    sql """DROP TABLE IF EXISTS test_function_truncate;"""
+    sql """DROP TABLE IF EXISTS test_function_truncate_dec128;"""
+    sql """
+        CREATE TABLE test_function_truncate (
+            rid int, flo float, dou double,
+            dec90 decimal(9, 0), dec91 decimal(9, 1), dec99 decimal(9, 9),
+            dec100 decimal(10,0), dec109 decimal(10,9), dec1010 decimal(10,10),
+            number int DEFAULT 1)
+        DISTRIBUTED BY HASH(rid)
+        PROPERTIES("replication_num" = "1" );
+        """
+
+    sql """
+        INSERT INTO test_function_truncate
+        VALUES
+            (1, 12345.123, 123456789.123456789,
+             123456789, 12345678.1, 0.123456789,
+             123456789.1, 1.123456789, 0.123456789, 1);
+    """
+    sql """
+        INSERT INTO test_function_truncate
+            VALUES
+        (2, 12345.123, 123456789.123456789,
+             123456789, 12345678.1, 0.123456789,
+             123456789.1, 1.123456789, 0.123456789, 1);
+    """
+    sql """
+        INSERT INTO test_function_truncate
+            VALUES
+        (3, 12345.123, 123456789.123456789,
+             123456789, 12345678.1, 0.123456789,
+             123456789.1, 1.123456789, 0.123456789, 1);
+    """
+    sql """
+        INSERT INTO test_function_truncate
+            VALUES
+        (4, 0, 0, 0, 0.0, 0, 0, 0, 0, 1);
+    """
+    qt_vec_const0 """
+        SELECT rid, truncate(flo, 0), truncate(dou, 0) FROM 
test_function_truncate order by rid;
+    """
+    qt_vec_const0 """
+        SELECT rid, truncate(flo, 1), truncate(dou, 1) FROM 
test_function_truncate order by rid;
+    """
+    qt_vec_const0 """
+        SELECT rid, truncate(flo, -1), truncate(dou, -1) FROM 
test_function_truncate order by rid;
+    """
+    qt_vec_const1 """
+        SELECT rid, dec90, truncate(dec90, 0), dec91, truncate(dec91, 0), 
dec99, truncate(dec99, 0) FROM test_function_truncate order by rid
+    """
+    qt_vec_const2 """
+        SELECT rid, dec100, truncate(dec100, 0), dec109, truncate(dec109, 0), 
dec1010, truncate(dec1010, 0) FROM test_function_truncate order by rid
+    """
+
+    
+
+    qt_const_vec1 """
+        SELECT 123456789.123456789, number, truncate(123456789.123456789, 
number) from test_function_truncate;
+    """
+    qt_const_vec2 """
+        SELECT 123456789.123456789, -number, truncate(123456789.123456789, 
-number) from test_function_truncate;
+    """
+    qt_vec_vec0 """
+        SELECT rid,number, truncate(flo, number), truncate(dou, number) FROM 
test_function_truncate order by rid;
+    """
+
+    sql """
+        CREATE TABLE test_function_truncate_dec128 (
+            rid int, dec190 decimal(19,0), dec199 decimal(19,9), dec1919 
decimal(19,19),
+                     dec380 decimal(38,0), dec3819 decimal(38,19), dec3838 
decimal(38,38),
+                     number int DEFAULT 1
+        )
+        DISTRIBUTED BY HASH(rid)
+        PROPERTIES("replication_num" = "1" );
+    """
+    sql """
+        INSERT INTO test_function_truncate_dec128
+            VALUES
+        (1, 1234567891234567891.0, 1234567891.123456789, 0.1234567891234567891,
+            12345678912345678912345678912345678912.0, 
+            1234567891234567891.1234567891234567891,
+            
0.12345678912345678912345678912345678912345678912345678912345678912345678912, 
1);
+    """
+    qt_truncate_dec128 """
+        SELECT rid, dec190, truncate(dec190, 0), dec199, truncate(dec199, 0), 
dec1919, truncate(dec1919, 0)
+            FROM test_function_truncate_dec128 order by rid
+    """
+
+    qt_truncate_dec128 """
+        SELECT rid, dec190, truncate(dec190, number), dec199, truncate(dec199, 
number), dec1919, truncate(dec1919, number)
+            FROM test_function_truncate_dec128 order by rid
+    """
+
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to