This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 6af68c04363 [feature](function) support format_round scala function (#49084) 6af68c04363 is described below commit 6af68c04363b3867c6f610f3401110d2f04fb855 Author: Tiewei Fang <fangtie...@selectdb.com> AuthorDate: Wed Apr 9 14:44:27 2025 +0800 [feature](function) support format_round scala function (#49084) --- be/src/vec/functions/function_string.cpp | 4 + be/src/vec/functions/function_string.h | 360 +++++++++++++++++++++ be/test/vec/function/function_math_test.cpp | 59 ++++ .../doris/catalog/BuiltinScalarFunctions.java | 2 + .../expressions/functions/scalar/FormatRound.java | 74 +++++ .../expressions/visitor/ScalarFunctionVisitor.java | 5 + gensrc/script/doris_builtins_functions.py | 6 + .../math_functions/test_format_round.out | Bin 0 -> 681 bytes .../math_functions/test_format_round.groovy | 64 ++++ 9 files changed, 574 insertions(+) diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 3d79bed171a..8a5e9c24496 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -1314,6 +1314,10 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_function<FunctionMoneyFormat<MoneyFormatInt64Impl>>(); factory.register_function<FunctionMoneyFormat<MoneyFormatInt128Impl>>(); factory.register_function<FunctionMoneyFormat<MoneyFormatDecimalImpl>>(); + factory.register_function<FunctionStringFormatRound<FormatRoundDoubleImpl>>(); + factory.register_function<FunctionStringFormatRound<FormatRoundInt64Impl>>(); + factory.register_function<FunctionStringFormatRound<FormatRoundInt128Impl>>(); + factory.register_function<FunctionStringFormatRound<FormatRoundDecimalImpl>>(); factory.register_function<FunctionStringDigestOneArg<SM3Sum>>(); factory.register_function<FunctionStringDigestOneArg<MD5Sum>>(); factory.register_function<FunctionStringDigestSHA1>(); diff --git 
a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 1af2855e8ad..6a7f0446df6 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -24,6 +24,7 @@ #include <array> #include <boost/iterator/iterator_facade.hpp> #include <boost/locale.hpp> +#include <boost/multiprecision/cpp_dec_float.hpp> // NOTE(review): nothing in this patch uses cpp_dec_float; confirm it is needed #include <climits> #include <cmath> #include <cstddef> @@ -1779,6 +1780,44 @@ public: } }; +template <typename Impl> +class FunctionStringFormatRound : public IFunction { // format_round(x, d): x rounded to d decimal places, integer digits grouped with ',' +public: + static constexpr auto name = "format_round"; + static FunctionPtr create() { return std::make_shared<FunctionStringFormatRound>(); } + String get_name() const override { return name; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.size() != 2) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Function {} requires exactly 2 argument", name); + } + return std::make_shared<DataTypeString>(); + } + DataTypes get_variadic_argument_types_impl() const override { + return Impl::get_variadic_argument_types(); + } + size_t get_number_of_arguments() const override { return 2; } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + auto res_column = ColumnString::create(); + // Materialize const inputs (e.g. literal d next to column x): Impl::execute reads row i of both columns + ColumnPtr argument_column = + block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); + ColumnPtr argument_column_2 = + block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); + + auto result_column = assert_cast<ColumnString*>(res_column.get()); + + RETURN_IF_ERROR(Impl::execute(context, result_column, argument_column, argument_column_2, + input_rows_count)); + + block.replace_by_position(result, std::move(res_column)); + return Status::OK(); + } +}; + class FunctionSplitPart : public IFunction { public: static constexpr auto name = "split_part"; @@
-3127,6 +3166,146 @@ static StringRef do_money_format(FunctionContext* context, const string& value) }; } // namespace MoneyFormat + +namespace FormatRound { + +constexpr size_t MAX_FORMAT_LEN_DEC32() { + // Decimal(9, 0) + // Double the size to avoid some unexpected bug. + return 2 * (1 + 9 + (9 / 3) + 3); +} + +constexpr size_t MAX_FORMAT_LEN_DEC64() { + // Decimal(18, 0) + // Double the size to avoid some unexpected bug. + return 2 * (1 + 18 + (18 / 3) + 3); +} + +constexpr size_t MAX_FORMAT_LEN_DEC128V2() { + // DecimalV2 has at most 27 digits + // Double the size to avoid some unexpected bug. + return 2 * (1 + 27 + (27 / 3) + 3); +} + +constexpr size_t MAX_FORMAT_LEN_DEC128V3() { + // Decimal(38, 0) + // Double the size to avoid some unexpected bug. + return 2 * (1 + 39 + (39 / 3) + 3); +} + +constexpr size_t MAX_FORMAT_LEN_INT64() { + // INT64_MIN = -9223372036854775808 + // Double the size to avoid some unexpected bug. + return 2 * (1 + 20 + (20 / 3) + 3); +} + +constexpr size_t MAX_FORMAT_LEN_INT128() { + // INT128_MIN = -170141183460469231731687303715884105728 + return 2 * (1 + 39 + (39 / 3) + 3); +} +
+template <typename T, size_t N> +StringRef do_format_round(FunctionContext* context, UInt32 scale, T int_value, T frac_value, + Int32 decimal_places) { + static_assert(std::is_integral<T>::value); + const bool is_negative = int_value < 0 || frac_value < 0; + // Number of fraction digits held in frac_value; requested places beyond it are padded with '0'. + const Int32 frac_digits = + static_cast<Int32>(scale) < decimal_places ? static_cast<Int32>(scale) : decimal_places; + + // Round half away from zero when the value has more fraction digits than requested. + // decimal_places == 0 must round as well (0.5 -> 1), so it is no longer excluded. + if (static_cast<Int32>(scale) > decimal_places) { + DCHECK(scale <= 38); + // keep decimal_places + 1 digits; the extra digit decides the rounding + auto multiplier = common::exp10_i128(static_cast<Int32>(scale) - (decimal_places + 1)); + // divide first to avoid overflow, and keep the magnitude in T: casting to int + // truncated 64/128-bit fractions (scale >= 10) + frac_value = static_cast<T>(frac_value / multiplier); + frac_value = ((frac_value < 0 ? -frac_value : frac_value) + 5) / 10; + // rounding up to 10^decimal_places carries one unit into the integer part + if (frac_value == static_cast<T>(common::exp10_i128(decimal_places))) { + if (is_negative) { + int_value -= 1; + } else { + int_value += 1; + } + frac_value = 0; + } + } + + bool append_sign_manually = false; + if (is_negative && int_value == 0) { + append_sign_manually = true; + } + + char local[N]; + char* p = SimpleItoaWithCommas(int_value, local, sizeof(local)); + const Int32 integer_str_len = N - (p - local); + const Int32 frac_str_len = decimal_places; + const Int32 whole_decimal_str_len = (append_sign_manually ? 1 : 0) + integer_str_len + + (decimal_places > 0 ? 1 : 0) + frac_str_len; + + StringRef result = context->create_temp_string_val(whole_decimal_str_len); + char* result_data = const_cast<char*>(result.data); + + if (append_sign_manually) { + memset(result_data, '-', 1); + } + + memcpy(result_data + (append_sign_manually ? 1 : 0), p, integer_str_len); + if (decimal_places > 0) { + *(result_data + whole_decimal_str_len - (frac_str_len + 1)) = '.'; + } + + // Write fraction digits right to left: trailing zero padding first, then the held digits. + T remaining_frac = frac_value < 0 ? -frac_value : frac_value; + for (int i = 0; i <= decimal_places - 1; ++i) { + const bool is_padding = i < decimal_places - frac_digits; + *(result_data + whole_decimal_str_len - 1 - i) = + is_padding ? '0' : static_cast<char>('0' + (remaining_frac % 10)); + if (!is_padding) remaining_frac /= 10; + } + return result; +} +
+// Inserts thousands separators into `value`, a decimal string as printed by fmt "{:.Nf}": an optional +// '-', then integer digits, then '.' and decimal_places digits (no '.' when decimal_places == 0). +static StringRef do_format_round(FunctionContext* context, const string& value, + Int32 decimal_places) { + const size_t sign_len = (!value.empty() && value[0] == '-') ? 1 : 0; + const size_t tail_len = decimal_places > 0 ? static_cast<size_t>(decimal_places) + 1 : 0; + // Text that is not "[-]digits[.digits]" (e.g. "inf" or "nan") is returned unchanged. + if (value.size() < sign_len + tail_len + 1 || value[sign_len] < '0' || value[sign_len] > '9') { + StringRef raw = context->create_temp_string_val(value.size()); + memcpy(const_cast<char*>(raw.data), value.data(), value.size()); + return raw; + } + + const size_t int_len = value.size() - sign_len - tail_len; + const size_t comma_count = (int_len - 1) / 3; + const size_t result_len = value.size() + comma_count; + StringRef result = context->create_temp_string_val(result_len); + char* result_data = const_cast<char*>(result.data); + if (sign_len > 0) { + *result_data = '-'; + } + // Copy the integer digits right to left, inserting ',' between groups of three. + size_t out = sign_len + int_len + comma_count; + for (size_t k = 0; k < int_len; ++k) { + if (k > 0 && k % 3 == 0) { + result_data[--out] = ','; + } + result_data[--out] = value[sign_len + int_len - 1 - k]; + } + // The fraction ('.' plus digits) is copied verbatim; tail_len is 0 when decimal_places == 0. + memcpy(result_data + sign_len + int_len + comma_count, value.data() + sign_len + int_len, + tail_len); + return result; +};
+ +} // namespace FormatRound + struct MoneyFormatDoubleImpl { static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeFloat64>()}; } @@ -3267,6 +3446,187 @@ struct MoneyFormatDecimalImpl { } }; +struct FormatRoundDoubleImpl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeFloat64>(), std::make_shared<vectorized::DataTypeInt32>()}; + } + + static Status execute(FunctionContext* context, ColumnString* result_column, + const ColumnPtr col_ptr, ColumnPtr decimal_places_col_ptr, + size_t input_rows_count) { + const auto& arg_column_data_2 = + assert_cast<const ColumnInt32*>(decimal_places_col_ptr.get())->get_data(); + const auto* data_column = assert_cast<const ColumnFloat64*>(col_ptr.get()); + // round in floating point first, then print exactly `decimal_places` fraction digits + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", decimal_places); + } + // round to `decimal_places` decimal places + double value = MathFunctions::my_double_round(data_column->get_element(i),
decimal_places, false, false); + StringRef str = FormatRound::do_format_round( + context, fmt::format("{:.{}f}", value, decimal_places), decimal_places); + result_column->insert_data(str.data, str.size); + } + return Status::OK(); + } +}; + +struct FormatRoundInt64Impl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeInt64>(), std::make_shared<vectorized::DataTypeInt32>()}; + } + + static Status execute(FunctionContext* context, ColumnString* result_column, + const ColumnPtr col_ptr, ColumnPtr decimal_places_col_ptr, + size_t input_rows_count) { + const auto* data_column = assert_cast<const ColumnVector<Int64>*>(col_ptr.get()); + const auto& arg_column_data_2 = + assert_cast<const ColumnInt32*>(decimal_places_col_ptr.get())->get_data(); + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", decimal_places); + } + Int64 value = data_column->get_element(i); + StringRef str = + FormatRound::do_format_round<Int64, FormatRound::MAX_FORMAT_LEN_INT64()>( + context, 0, value, 0, decimal_places); + result_column->insert_data(str.data, str.size); + } + return Status::OK(); + } +}; + +struct FormatRoundInt128Impl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeInt128>(), std::make_shared<vectorized::DataTypeInt32>()}; + } + + static Status execute(FunctionContext* context, ColumnString* result_column, + const ColumnPtr col_ptr, ColumnPtr decimal_places_col_ptr, + size_t input_rows_count) { + const auto* data_column = assert_cast<const ColumnVector<Int128>*>(col_ptr.get()); + const auto& arg_column_data_2 = + assert_cast<const ColumnInt32*>(decimal_places_col_ptr.get())->get_data(); + // SELECT money_format(170141183460469231731687303715884105728/*INT128_MAX + 1*/) will + // get "170,141,183,460,469,231,731,687,303,715,884,105,727.00" in
doris, + // see https://github.com/apache/doris/blob/788abf2d7c3c7c2d57487a9608e889e7662d5fb2/be/src/vec/data_types/data_type_number_base.cpp#L124 + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", decimal_places); + } + Int128 value = data_column->get_element(i); + StringRef str = + FormatRound::do_format_round<Int128, FormatRound::MAX_FORMAT_LEN_INT128()>( + context, 0, value, 0, decimal_places); + result_column->insert_data(str.data, str.size); + } + return Status::OK(); + } +}; + +struct FormatRoundDecimalImpl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared<DataTypeDecimal<Decimal128V2>>(27, 9), + std::make_shared<vectorized::DataTypeInt32>()}; + } + + static Status execute(FunctionContext* context, ColumnString* result_column, ColumnPtr col_ptr, + ColumnPtr decimal_places_col_ptr, size_t input_rows_count) { + const auto& arg_column_data_2 = + assert_cast<const ColumnInt32*>(decimal_places_col_ptr.get())->get_data(); + if (auto* decimalv2_column = check_and_get_column<ColumnDecimal<Decimal128V2>>(*col_ptr)) { + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", + decimal_places); + } + const Decimal128V2& dec128 = decimalv2_column->get_element(i); + DecimalV2Value value = DecimalV2Value(dec128.value); + // DecimalV2 keeps 9 fraction digits; pass all of them (truncating to 3 lost digits when decimal_places > 3) + auto unified_frac_value = value.frac_value(); + StringRef str = + FormatRound::do_format_round<Int128, + FormatRound::MAX_FORMAT_LEN_DEC128V2()>( + context, 9, value.int_value(), unified_frac_value, decimal_places); + + result_column->insert_data(str.data, str.size); + } + } else if (auto* decimal32_column = +
check_and_get_column<ColumnDecimal<Decimal32>>(*col_ptr)) { + const UInt32 scale = decimal32_column->get_scale(); + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", + decimal_places); + } + const Decimal32& frac_part = decimal32_column->get_fractional_part(i); + const Decimal32& whole_part = decimal32_column->get_whole_part(i); + StringRef str = + FormatRound::do_format_round<Int64, FormatRound::MAX_FORMAT_LEN_DEC32()>( + context, scale, static_cast<Int64>(whole_part.value), + static_cast<Int64>(frac_part.value), decimal_places); + + result_column->insert_data(str.data, str.size); + } + } else if (auto* decimal64_column = + check_and_get_column<ColumnDecimal<Decimal64>>(*col_ptr)) { + const UInt32 scale = decimal64_column->get_scale(); + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", + decimal_places); + } + const Decimal64& frac_part = decimal64_column->get_fractional_part(i); + const Decimal64& whole_part = decimal64_column->get_whole_part(i); + + StringRef str = + FormatRound::do_format_round<Int64, FormatRound::MAX_FORMAT_LEN_DEC64()>( + context, scale, whole_part.value, frac_part.value, decimal_places); + + result_column->insert_data(str.data, str.size); + } + } else if (auto* decimal128_column = + check_and_get_column<ColumnDecimal<Decimal128V3>>(*col_ptr)) { + const UInt32 scale = decimal128_column->get_scale(); + for (size_t i = 0; i < input_rows_count; i++) { + int32_t decimal_places = arg_column_data_2[i]; + if (decimal_places < 0) { + return Status::InvalidArgument( + "The second argument is {}, it can not be less than 0.", + decimal_places); + } + const Decimal128V3& frac_part =
decimal128_column->get_fractional_part(i); + const Decimal128V3& whole_part = decimal128_column->get_whole_part(i); + + StringRef str = + FormatRound::do_format_round<Int128, + FormatRound::MAX_FORMAT_LEN_DEC128V3()>( + context, scale, whole_part.value, frac_part.value, decimal_places); + + result_column->insert_data(str.data, str.size); + } + } else { + return Status::InternalError("Not supported input argument type {}", + col_ptr->get_name()); + } + return Status::OK(); + } +}; + class FunctionStringLocatePos : public IFunction { public: static constexpr auto name = "locate"; diff --git a/be/test/vec/function/function_math_test.cpp b/be/test/vec/function/function_math_test.cpp index a0feec51753..c7ffebb8bc1 100644 --- a/be/test/vec/function/function_math_test.cpp +++ b/be/test/vec/function/function_math_test.cpp @@ -585,4 +585,63 @@ TEST(MathFunctionTest, money_format_test) { } } +TEST(MathFunctionTest, format_round_test) { + std::string func_name = "format_round"; + + { + InputTypeSet input_types = {TypeIndex::Int64, TypeIndex::Int32}; + DataSet data_set = {{{Null(), INT(2)}, Null()}, + {{BIGINT(17014116), INT(2)}, VARCHAR("17,014,116.00")}, + {{BIGINT(-17014116), INT(2)}, VARCHAR("-17,014,116.00")}, + {{BIGINT(1), INT(0)}, VARCHAR("1")}, + {{BIGINT(123456), INT(0)}, VARCHAR("123,456")}, + {{BIGINT(123456), INT(3)}, VARCHAR("123,456.000")}, + {{BIGINT(123456), INT(10)}, VARCHAR("123,456.0000000000")}, + {{BIGINT(123456), INT(20)}, VARCHAR("123,456.00000000000000000000")}}; + + static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set)); + } + { + InputTypeSet input_types = {TypeIndex::Int128, TypeIndex::Int32}; + DataSet data_set = { + {{Null(), INT(2)}, Null()}, + {{LARGEINT(17014116), INT(2)}, VARCHAR("17,014,116.00")}, + {{LARGEINT(-17014116), INT(2)}, VARCHAR("-17,014,116.00")}, + {{LARGEINT(1), INT(0)}, VARCHAR("1")}, + {{LARGEINT(123456), INT(0)}, VARCHAR("123,456")}, + {{LARGEINT(123456), INT(3)}, VARCHAR("123,456.000")}, +
{{LARGEINT(123456), INT(10)}, VARCHAR("123,456.0000000000")}, + {{LARGEINT(123456), INT(20)}, VARCHAR("123,456.00000000000000000000")}, + {{LARGEINT(123456789123456789), INT(2)}, VARCHAR("123,456,789,123,456,789.00")}}; + + static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set)); + } + { + InputTypeSet input_types = {TypeIndex::Float64, TypeIndex::Int32}; + DataSet data_set = {{{Null(), INT(2)}, Null()}, + {{DOUBLE(17014116.67), INT(2)}, VARCHAR("17,014,116.67")}, + {{DOUBLE(-17014116.67), INT(2)}, VARCHAR("-17,014,116.67")}, + {{DOUBLE(-123.45), INT(2)}, VARCHAR("-123.45")}}; + + static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set)); + } + { + InputTypeSet input_types = {TypeIndex::Decimal128V2, TypeIndex::Int32}; + DataSet data_set = {{{Null(), INT(2)}, Null()}, + {{DECIMALV2(17014116.67), INT(2)}, VARCHAR("17,014,116.67")}, + {{DECIMALV2(-17014116.67), INT(2)}, VARCHAR("-17,014,116.67")}}; + + static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set)); + } + { + BaseInputTypeSet input_types = {TypeIndex::Decimal64, TypeIndex::Int32}; + DataSet data_set = { + {{Null(), INT(2)}, Null()}, + {{DECIMAL64(17014116, 670000000), INT(2)}, VARCHAR("17,014,116.67")}, + {{DECIMAL64(-17014116, -670000000), INT(2)}, VARCHAR("-17,014,116.67")}}; + + check_function_all_arg_comb<DataTypeString, true>(func_name, input_types, data_set); + } +} + } // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index a621b9a054d..562c8ef3488 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -188,6 +188,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.FindInSet; import
org.apache.doris.nereids.trees.expressions.functions.scalar.FirstSignificantSubdomain; import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor; import org.apache.doris.nereids.trees.expressions.functions.scalar.Fmod; +import org.apache.doris.nereids.trees.expressions.functions.scalar.FormatRound; import org.apache.doris.nereids.trees.expressions.functions.scalar.Fpow; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromBase64; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromDays; @@ -872,6 +873,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(Right.class, "right", "strright"), scalar(Round.class, "round"), scalar(RoundBankers.class, "round_bankers"), + scalar(FormatRound.class, "format_round"), scalar(Rpad.class, "rpad"), scalar(Rtrim.class, "rtrim"), scalar(RtrimIn.class, "rtrim_in"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/FormatRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/FormatRound.java new file mode 100644 index 00000000000..368bcbab296 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/FormatRound.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.StringType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** FormatRound function */ +public class FormatRound extends ScalarFunction + implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { + public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(StringType.INSTANCE).args(BigIntType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(StringType.INSTANCE).args(LargeIntType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(StringType.INSTANCE).args(DoubleType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(StringType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, IntegerType.INSTANCE), + FunctionSignature.ret(StringType.INSTANCE).args(DecimalV3Type.WILDCARD, IntegerType.INSTANCE)); + + /** + * constructor with exactly 2 arguments: the value and the number of decimal places.
+ */ + public FormatRound(Expression arg0, Expression arg1) { + super("format_round", arg0, arg1); + } + + /** + * withChildren. + */ + @Override + public FormatRound withChildren(List<Expression> children) { + Preconditions.checkArgument(children.size() == 2); + return new FormatRound(children.get(0), children.get(1)); + } + + @Override + public List<FunctionSignature> getSignatures() { + return SIGNATURES; + } + + @Override + public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { + return visitor.visitFormatRound(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index d31a80bda42..4e1c1a679a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -196,6 +196,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.FindInSet; import org.apache.doris.nereids.trees.expressions.functions.scalar.FirstSignificantSubdomain; import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor; import org.apache.doris.nereids.trees.expressions.functions.scalar.Fmod; +import org.apache.doris.nereids.trees.expressions.functions.scalar.FormatRound; import org.apache.doris.nereids.trees.expressions.functions.scalar.Fpow; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromBase64; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromDays; @@ -1876,6 +1877,10 @@ public interface ScalarFunctionVisitor<R, C> { return visitScalarFunction(roundBankers, context); } + default R visitFormatRound(FormatRound formatRound, C context) { + return visitScalarFunction(formatRound, context); + } + default R visitRpad(Rpad rpad, C context) { return
visitScalarFunction(rpad, context); } diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 829293bf6c3..a14251a6bf5 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1694,6 +1694,12 @@ visible_functions = { [['repeat'], 'STRING', ['STRING', 'INT'], 'ALWAYS_NULLABLE'], [['lpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'], [['rpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'], + [['format_round'], 'STRING', ['BIGINT', 'INT'], ''], + [['format_round'], 'STRING', ['LARGEINT', 'INT'], ''], + [['format_round'], 'STRING', ['DOUBLE', 'INT'], ''], + [['format_round'], 'STRING', ['DECIMAL32', 'INT'], ''], + [['format_round'], 'STRING', ['DECIMAL64', 'INT'], ''], + [['format_round'], 'STRING', ['DECIMAL128', 'INT'], ''], [['append_trailing_char_if_absent'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['length'], 'INT', ['STRING'], ''], [['crc32'], 'BIGINT', ['STRING'], ''], diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_format_round.out b/regression-test/data/query_p0/sql_functions/math_functions/test_format_round.out new file mode 100644 index 00000000000..42ce5b1ff28 Binary files /dev/null and b/regression-test/data/query_p0/sql_functions/math_functions/test_format_round.out differ diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_format_round.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_format_round.groovy new file mode 100644 index 00000000000..dfd4a057a36 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_format_round.groovy @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_format_round", "p0") { + + order_qt_format_round_0 """ + select format_round(1234567.8910,5), + format_round(-1111.235,2), + format_round(123.49 ,0), + format_round(cast(1.44999 as double),5), + format_round(cast(1.44999 as decimal(20,2)),5); + """ + order_qt_format_round_1 """ select format_round(1, 0) """ + order_qt_format_round_2 """ select format_round(123456, 0) """ + order_qt_format_round_3 """ select format_round(123456, 3) """ + order_qt_format_round_4 """ select format_round(123456, 10) """ + order_qt_format_round_5 """ select format_round(123456.123456, 0) """ + order_qt_format_round_6 """ select format_round(123456.123456, 3) """ + order_qt_format_round_7 """ select format_round(123456.123456, 8) """ + + sql """ DROP TABLE IF EXISTS test_format_round """ + sql """ + CREATE TABLE IF NOT EXISTS test_format_round ( + `user_id` LARGEINT NOT NULL COMMENT "用户id", + `int_col` int COMMENT "", + `bigint_col` bigint COMMENT "", + `largeint_col` largeint COMMENT "", + `double_col` double COMMENT "", + `decimal_col` decimal COMMENT "" + ) + DISTRIBUTED BY HASH(user_id) PROPERTIES("replication_num" = "1"); + """ + + sql """ INSERT INTO test_format_round VALUES + (1, 123, 123456, 123455677788, 123456.1234567, 123456.1234567); + """ + qt_select_default """ SELECT * FROM test_format_round t ORDER BY user_id; """ + + order_qt_format_round_8 """ select format_round(int_col, 6)
from test_format_round""" + order_qt_format_round_9 """ select format_round(bigint_col, 6) from test_format_round""" + order_qt_format_round_10 """ select format_round(largeint_col, 6) from test_format_round""" + order_qt_format_round_12 """ select format_round(double_col, 6) from test_format_round""" + order_qt_format_round_13 """ select format_round(decimal_col, 6) from test_format_round""" + + test { + sql """select format_round(1234567.8910, -1) """ + exception "it can not be less than 0" + } + +} \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org