This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new bf1b47d12c8 [fix](decimal256) support decimal256 for many functions (#42136) bf1b47d12c8 is described below commit bf1b47d12c853e0f5e62527c075eb0fc1ea8cc63 Author: TengJianPing <18241664+jackte...@users.noreply.github.com> AuthorDate: Wed Oct 23 18:14:57 2024 +0800 [fix](decimal256) support decimal256 for many functions (#42136) ## Proposed changes Issue Number: close #xxx Support decimal256 for the following functions: ``` multi_distinct_sum multi_distinct_count array_sum array_avg array_product array_cum_sum ``` --- .../vec/aggregate_functions/aggregate_function.h | 4 + .../aggregate_function_approx_count_distinct.cpp | 3 +- .../aggregate_functions/aggregate_function_avg.cpp | 15 +- .../aggregate_function_bitmap.cpp | 9 +- .../aggregate_function_bitmap_agg.cpp | 3 +- .../aggregate_function_collect.cpp | 3 +- .../aggregate_function_corr.cpp | 3 +- .../aggregate_function_count.cpp | 9 +- .../aggregate_function_count_by_enum.cpp | 3 +- .../aggregate_function_covar.cpp | 6 +- .../aggregate_function_distinct.cpp | 5 +- .../aggregate_function_foreach.cpp | 7 +- .../aggregate_function_group_array_intersect.cpp | 3 +- .../aggregate_function_group_concat.cpp | 3 +- .../aggregate_function_histogram.cpp | 3 +- .../aggregate_function_kurtosis.cpp | 3 +- .../aggregate_function_linear_histogram.cpp | 3 +- .../aggregate_functions/aggregate_function_map.cpp | 3 +- .../aggregate_function_min_max.cpp | 3 +- .../aggregate_function_min_max.h | 3 +- .../aggregate_function_min_max_by.h | 3 +- .../aggregate_function_orthogonal_bitmap.cpp | 3 +- .../aggregate_function_percentile.cpp | 9 +- .../aggregate_function_product.h | 16 +- .../aggregate_function_quantile_state.cpp | 6 +- .../aggregate_function_quantile_state.h | 6 +- .../aggregate_function_reader_first_last.h | 39 +-- .../aggregate_function_regr_union.cpp | 3 +- .../aggregate_function_sequence_match.cpp | 3 +- .../aggregate_function_simple_factory.h | 15 +- .../aggregate_function_skew.cpp | 3 +- .../aggregate_function_stddev.cpp | 12 +- .../aggregate_functions/aggregate_function_sum.cpp | 15 +- .../aggregate_functions/aggregate_function_sum.h | 1 - .../aggregate_function_topn.cpp | 9 +- .../aggregate_function_uniq.cpp | 8 +- .../aggregate_function_uniq_distribute_key.cpp | 3 +- .../aggregate_function_window.cpp | 6 +- .../aggregate_function_window_funnel.cpp | 3 +- be/src/vec/aggregate_functions/helpers.h | 15 +- be/src/vec/core/wide_integer.h | 5 + be/src/vec/core/wide_integer_impl.h | 34 +-- be/src/vec/exec/scan/vfile_scanner.cpp | 5 +- be/src/vec/exprs/vcase_expr.cpp | 6 +- be/src/vec/exprs/vcast_expr.cpp | 6 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 4 +- be/src/vec/exprs/vectorized_fn_call.cpp | 3 +- be/src/vec/exprs/vin_predicate.cpp | 5 +- be/src/vec/exprs/vmatch_predicate.cpp | 6 +- be/src/vec/exprs/vtopn_pred.h | 2 +- .../functions/array/function_array_aggregation.cpp | 100 +++++-- .../vec/functions/array/function_array_cum_sum.cpp | 31 ++- be/src/vec/functions/comparison_equal_for_null.cpp | 11 +- be/src/vec/functions/function.h | 4 + be/src/vec/functions/function_coalesce.cpp | 21 +- be/src/vec/functions/function_ifnull.h | 4 +- be/src/vec/functions/nullif.cpp | 11 +- be/src/vec/functions/simple_function_factory.h | 11 +- .../agg_linear_histogram_test.cpp | 3 +- .../decimalv3}/aggregate_decimal256.out | 8 + .../decimalv3/test_decimal256_array.out | 63 +++++ .../decimalv3/test_decimal256_multi_distinct.out | 33 +++ .../nereids_function_p0/scalar_function/Array.out | 306 +++++++++++++++++++++ .../decimalv3}/aggregate_decimal256.groovy | 4 +- .../decimalv3/test_decimal256_array.groovy | 118 ++++++++ .../test_decimal256_multi_distinct.groovy | 73 +++++ .../scalar_function/Array.groovy | 24 +- 67 files changed, 956 insertions(+), 217 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 05f1bd2a602..cd1f8922e1b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -38,6 +38,10 @@ class Arena; class IColumn; class IDataType; +struct AggregateFunctionAttr { + bool enable_decimal256 {false}; +}; + template <bool nullable, typename ColVecType> class AggregateFunctionBitmapCount; template <typename Op> diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp index 10616be4258..18662bf66cf 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -31,7 +31,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_approx_count_distinct( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType which(remove_nullable(argument_types[0])); #define DISPATCH(TYPE, COLUMN_TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 0f3d0fd3bda..6a6711f90f9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -45,8 +45,17 @@ template <typename T> using AggregateFuncAvgDecimal256 = typename AvgDecimal256<T>::Function; void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("avg", creator_with_type::creator<AggregateFuncAvg>); - factory.register_function_both("avg_decimal256", - creator_with_type::creator<AggregateFuncAvgDecimal256>); + AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + if (attr.enable_decimal256) { + return creator_with_type::creator<AggregateFuncAvgDecimal256>(name, types, + result_is_nullable, attr); + } else { + return creator_with_type::creator<AggregateFuncAvg>(name, types, result_is_nullable, + attr); + } + }; + factory.register_function_both("avg", creator); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index 0676fd5bc27..e9c86d4b955 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -40,9 +40,9 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) { return nullptr; } -AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_bitmap_union_count( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::make_shared<AggregateFunctionBitmapCount<true, ColumnBitmap>>(argument_types); @@ -53,7 +53,8 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::str AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return AggregateFunctionPtr( diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp index b8ae4c6530d..0b95ddfd46f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp @@ -41,7 +41,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types) AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return AggregateFunctionPtr(create_with_int_data_type<true>(argument_types)); diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp index 4fcf09b59b3..d726b7c6355 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp @@ -96,7 +96,8 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { if (name == "array_agg") { return create_aggregate_function_collect_impl<std::false_type, std::true_type>( diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp index a454afb45f2..cdaab6e086f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp @@ -89,7 +89,8 @@ struct CorrMoment { AggregateFunctionPtr create_aggregate_corr_function(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_binary(name, argument_types); return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1], argument_types, result_is_nullable); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.cpp b/be/src/vec/aggregate_functions/aggregate_function_count.cpp index 8c54714b046..5cfe5af4198 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count.cpp @@ -29,15 +29,16 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_arity_at_most<1>(name, argument_types); return std::make_shared<AggregateFunctionCount>(argument_types); } -AggregateFunctionPtr create_aggregate_function_count_not_null_unary(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_count_not_null_unary( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_arity_at_most<1>(name, argument_types); return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp index 1a0bf251820..093b31d57db 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp @@ -29,7 +29,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() < 1) { LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", argument_types.size(), name); diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp index b02d6ae0e12..71d09f61de4 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp @@ -53,14 +53,16 @@ AggregateFunctionPtr create_function_single_value(const String& name, AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>( name, argument_types, result_is_nullable, NOTNULLABLE); } AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionPop, CovarName, PopData>( name, argument_types, result_is_nullable, NOTNULLABLE); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp index 9bb2954207b..fce58b38688 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp @@ -83,7 +83,8 @@ const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_"; void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { // 1. we should get not nullable types; DataTypes nested_types(types.size()); std::transform(types.begin(), types.end(), nested_types.begin(), @@ -92,7 +93,7 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact auto transform_arguments = function_combinator->transform_arguments(nested_types); auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size()); auto nested_function = factory.get(nested_function_name, transform_arguments, false, - BeExecVersionManager::get_newest_version()); + BeExecVersionManager::get_newest_version(), attr); return function_combinator->transform_aggregate_function(nested_function, types, result_is_nullable); }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp index ab6d0142f6a..c1cbcc89996 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp @@ -34,8 +34,9 @@ namespace doris::vectorized { void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) { - AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const bool result_is_nullable) -> AggregateFunctionPtr { + AggregateFunctionCreator creator = + [&](const std::string& name, const DataTypes& types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX; DataTypes transform_arguments; for (const auto& t : types) { @@ -46,7 +47,7 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto auto nested_function_name = name.substr(0, name.size() - suffix.size()); auto nested_function = factory.get(nested_function_name, transform_arguments, result_is_nullable, - BeExecVersionManager::get_newest_version(), false); + BeExecVersionManager::get_newest_version(), attr); if (!nested_function) { throw Exception( ErrorCode::INTERNAL_ERROR, diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index b3b9a8b9af4..24faf58b2e1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -70,7 +70,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl } AggregateFunctionPtr create_aggregate_function_group_array_intersect( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_unary(name, argument_types); const DataTypePtr& argument_type = remove_nullable(argument_types[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index 9661b9c89d5..286795ea2ba 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -28,7 +28,8 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ","; AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { return creator_without_type::create< AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>( diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 5b06af28399..fb2fa9c2513 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -47,7 +47,8 @@ AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_typ AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp index 00ad1893eaf..a763721f3f4 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -45,7 +45,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes& AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() != 1) { LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument"; return nullptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp index 62ce1657526..683cf1a18f7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp @@ -41,7 +41,8 @@ AggregateFunctionPtr create_agg_function_linear_histogram(const DataTypes& argum AggregateFunctionPtr create_aggregate_function_linear_histogram(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.cpp b/be/src/vec/aggregate_functions/aggregate_function_map.cpp index bcf3f2d66df..f289d885f48 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_map.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_map.cpp @@ -32,7 +32,8 @@ AggregateFunctionPtr create_agg_function_map_agg(const DataTypes& argument_types AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { WhichDataType type(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 8aa8850a314..c1a72fd52bd 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -30,7 +30,8 @@ namespace doris::vectorized { template <template <typename> class Data> AggregateFunctionPtr create_aggregate_function_single_value(const String& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { assert_unary(name, argument_types); AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionsSingleValue, Data, diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h b/be/src/vec/aggregate_functions/aggregate_function_min_max.h index 1281e7ca4c4..a5423cd72f5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h @@ -714,5 +714,6 @@ public: template <template <typename> class Data> AggregateFunctionPtr create_aggregate_function_single_value(const String& name, const DataTypes& argument_types, - const bool result_is_nullable); + const bool result_is_nullable, + const AggregateFunctionAttr& attr = {}); } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h index e4693115120..4caded0011a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h @@ -243,7 +243,8 @@ template <template <typename> class AggregateFunctionTemplate, template <typename, typename> class Data> AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() != 2) { return nullptr; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp index 97269ced37e..fe41aba2f0b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp @@ -35,7 +35,8 @@ template <template <typename> class Impl> AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.empty()) { LOG(WARNING) << "Incorrect number of arguments for aggregate function " << name; return nullptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp index 8f528e0121d..bb4e1bd81e3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp @@ -23,9 +23,9 @@ namespace doris::vectorized { -AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_percentile_approx( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const DataTypePtr& argument_type = remove_nullable(argument_types[0]); WhichDataType which(argument_type); if (which.idx != TypeIndex::Float64) { @@ -43,7 +43,8 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri } AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const DataTypePtr& argument_type = remove_nullable(argument_types[0]); WhichDataType which(argument_type); if (which.idx != TypeIndex::Float64) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_product.h b/be/src/vec/aggregate_functions/aggregate_function_product.h index 1ec9a2711ce..82f765a909d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_product.h +++ b/be/src/vec/aggregate_functions/aggregate_function_product.h @@ -76,18 +76,20 @@ struct AggregateFunctionProductData<Decimal128V2> { void reset(Decimal128V2 value) { product = std::move(value); } }; +template <typename T> +concept DecimalTypeConcept = std::is_same_v<T, Decimal128V3> || std::is_same_v<T, Decimal256>; -template <> -struct AggregateFunctionProductData<Decimal128V3> { - Decimal128V3 product {}; +template <DecimalTypeConcept T> +struct AggregateFunctionProductData<T> { + T product {}; template <typename NestedType> - void add(Decimal<NestedType> value, Decimal128V3 multiplier) { + void add(Decimal<NestedType> value, T multiplier) { product *= value; product /= multiplier; } - void merge(const AggregateFunctionProductData& other, Decimal128V3 multiplier) { + void merge(const AggregateFunctionProductData& other, T multiplier) { product *= other.product; product /= multiplier; } @@ -96,9 +98,9 @@ struct AggregateFunctionProductData<Decimal128V3> { void read(BufferReadable& buffer) { read_binary(product, buffer); } - Decimal128V2 get() const { return product; } + T get() const { return product; } - void reset(Decimal128V2 value) { product = value; } + void reset(T value) { product = std::move(value); } }; template <typename T, typename TResult, typename Data> diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp index f50870a277b..128edc59915 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.cpp @@ -24,9 +24,9 @@ namespace doris::vectorized { -AggregateFunctionPtr create_aggregate_function_quantile_state_union(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_quantile_state_union( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::make_shared< diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h index 5954120553e..c48ac920919 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h +++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h @@ -152,8 +152,8 @@ public: void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } }; -AggregateFunctionPtr create_aggregate_function_quantile_state_union(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable); +AggregateFunctionPtr create_aggregate_function_quantile_state_union( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr); } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h index b9d2545e0c0..60ab42b5298 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h +++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h @@ -275,25 +275,26 @@ AggregateFunctionPtr create_function_single_value(const String& name, return nullptr; } -#define CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA) \ - template <bool is_copy> \ - AggregateFunctionPtr CREATE_FUNCTION_NAME( \ - const std::string& name, const DataTypes& argument_types, bool result_is_nullable) { \ - const bool arg_is_nullable = argument_types[0]->is_nullable(); \ - AggregateFunctionPtr res = nullptr; \ - std::visit( \ - [&](auto result_is_nullable, auto arg_is_nullable) { \ - res = AggregateFunctionPtr( \ - create_function_single_value<ReaderFunctionData, FUNCTION_DATA, \ - result_is_nullable, arg_is_nullable, \ - is_copy>(name, argument_types)); \ - }, \ - make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable)); \ - if (!res) { \ - LOG(WARNING) << " failed in create_aggregate_function_" << name \ - << " and type is: " << argument_types[0]->get_name(); \ - } \ - return res; \ +#define CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA) \ + template <bool is_copy> \ + AggregateFunctionPtr CREATE_FUNCTION_NAME( \ + const std::string& name, const DataTypes& argument_types, bool result_is_nullable, \ + const AggregateFunctionAttr& attr) { \ + const bool arg_is_nullable = argument_types[0]->is_nullable(); \ + AggregateFunctionPtr res = nullptr; \ + std::visit( \ + [&](auto result_is_nullable, auto arg_is_nullable) { \ + res = AggregateFunctionPtr( \ + create_function_single_value<ReaderFunctionData, FUNCTION_DATA, \ + result_is_nullable, arg_is_nullable, \ + is_copy>(name, argument_types)); \ + }, \ + make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable)); \ + if (!res) { \ + LOG(WARNING) << " failed in create_aggregate_function_" << name \ + << " and type is: " << argument_types[0]->get_name(); \ + } \ + return res; \ } CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first, ReaderFunctionFirstData); diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp index 738d777441c..c20b5977f21 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp @@ -58,7 +58,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_regr(const DataTypes& template <template <typename> class StatFunctionTemplate> AggregateFunctionPtr create_aggregate_function_regr(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() != 2) { LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments"; return nullptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp index 2953db15c5b..c49ee021dbf 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp @@ -30,7 +30,8 @@ namespace doris::vectorized { template <template <typename, typename> typename AggregateFunction> AggregateFunctionPtr create_aggregate_function_sequence_base(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { const auto arg_count = argument_types.size(); if (arg_count < 4) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 7f6a200bcb2..aa33e7289df 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -38,8 +38,8 @@ namespace doris::vectorized { using DataTypePtr = std::shared_ptr<const IDataType>; using DataTypes = std::vector<DataTypePtr>; -using AggregateFunctionCreator = - std::function<AggregateFunctionPtr(const std::string&, const DataTypes&, const bool)>; +using AggregateFunctionCreator = std::function<AggregateFunctionPtr( + const std::string&, const DataTypes&, const bool, const AggregateFunctionAttr&)>; inline std::string types_name(const DataTypes& types) { std::string name; @@ -119,7 +119,7 @@ public: AggregateFunctionPtr get(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, int be_version, - bool enable_decimal256 = false) { + AggregateFunctionAttr attr = {}) { bool nullable = false; for (const auto& type : argument_types) { if (type->is_nullable()) { @@ -128,11 +128,6 @@ public: } std::string name_str = name; - if (enable_decimal256) { - if (name_str == "sum" || name_str == "avg") { - name_str += "_decimal256"; - } - } temporary_function_update(be_version, name_str); if (function_alias.contains(name)) { @@ -142,12 +137,12 @@ public: return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end() ? nullptr : nullable_aggregate_functions[name_str](name_str, argument_types, - result_is_nullable); + result_is_nullable, attr); } else { return aggregate_functions.find(name_str) == aggregate_functions.end() ? nullptr : aggregate_functions[name_str](name_str, argument_types, - result_is_nullable); + result_is_nullable, attr); } } diff --git a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp index 144e482ad23..af2eb443eb0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp @@ -45,7 +45,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_skew(const DataTypes& AggregateFunctionPtr create_aggregate_function_skew(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() != 1) { LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument"; return nullptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp index 72448a419e9..f9fe2dca748 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp @@ -53,7 +53,8 @@ AggregateFunctionPtr create_function_single_value(const String& name, AggregateFunctionPtr create_aggregate_function_variance_samp(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionSamp, VarianceSampName, SampData, false>( name, argument_types, result_is_nullable, false); } @@ -61,7 +62,8 @@ AggregateFunctionPtr create_aggregate_function_variance_samp(const std::string& template <bool is_stddev> AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionPop, VarianceName, PopData, is_stddev>( name, argument_types, result_is_nullable, false); } @@ -69,14 +71,16 @@ AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string& n template <bool is_stddev> AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionPop, StddevName, PopData, is_stddev>( name, argument_types, result_is_nullable, false); } AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_function_single_value<AggregateFunctionSamp, StddevSampName, SampData, true>( name, argument_types, result_is_nullable, false); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp index e0676957d46..91063c22dc6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp @@ -26,9 +26,18 @@ namespace doris::vectorized { void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("sum", creator_with_type::creator<AggregateFunctionSumSimple>); - factory.register_function_both( - "sum_decimal256", creator_with_type::creator<AggregateFunctionSumSimpleDecimal256>); + AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + if (attr.enable_decimal256) { + return creator_with_type::creator<AggregateFunctionSumSimpleDecimal256>( + name, types, result_is_nullable, attr); + } else { + return creator_with_type::creator<AggregateFunctionSumSimple>(name, types, + result_is_nullable, attr); + } + }; + factory.register_function_both("sum", creator); } void register_aggregate_function_sum0(AggregateFunctionSimpleFactory& factory) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h b/be/src/vec/aggregate_functions/aggregate_function_sum.h index cc05435a950..846104915b1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h @@ -232,7 +232,6 @@ struct SumSimple { template <typename T> using AggregateFunctionSumSimple = typename SumSimple<T, true>::Function; -const static std::string DECIMAL256_SUFFIX {"_decimal256"}; template <typename T, bool level_up> struct SumSimpleDecimal256 { /// @note It uses slow Decimal128 (cause we need such a variant). sumWithOverflow is faster for Decimal32/64 diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp index 20513ef189e..799d8fe1c75 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp @@ -28,7 +28,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_topn(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 2) { return creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplInt>>( argument_types, result_is_nullable); @@ -82,7 +83,8 @@ AggregateFunctionPtr create_topn_array(const DataTypes& argument_types, AggregateFunctionPtr create_aggregate_function_topn_array(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { bool has_default_param = (argument_types.size() == 3); if (has_default_param) { return create_topn_array<AggregateFunctionTopNImplArray, true, false>(argument_types, @@ -95,7 +97,8 @@ AggregateFunctionPtr create_aggregate_function_topn_array(const std::string& nam AggregateFunctionPtr create_aggregate_function_topn_weighted(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { bool has_default_param = (argument_types.size() == 4); if (has_default_param) { return create_topn_array<AggregateFunctionTopNImplWeight, true, true>(argument_types, diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp index 735b8a737c2..25231025416 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp @@ -25,6 +25,7 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" #include "vec/common/hash_table/hash.h" // IWYU pragma: keep +#include "vec/core/wide_integer.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_nullable.h" @@ -33,7 +34,8 @@ namespace doris::vectorized { template <template <typename> class Data> AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { const IDataType& argument_type = *remove_nullable(argument_types[0]); WhichDataType which(argument_type); @@ -51,6 +53,10 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name, } else if (which.is_decimal128v3()) { return creator_without_type::create<AggregateFunctionUniq<Decimal128V3, Data<Int128>>>( argument_types, result_is_nullable); + } else if (which.is_decimal256()) { + return creator_without_type::create< + AggregateFunctionUniq<Decimal256, Data<wide::Int256>>>(argument_types, + result_is_nullable); } else if (which.is_decimal128v2() || which.is_decimal128v3()) { return creator_without_type::create<AggregateFunctionUniq<Decimal128V2, Data<Int128>>>( argument_types, result_is_nullable); diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp index 3bf979483b5..c89c8aa14f0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp @@ -28,7 +28,8 @@ namespace doris::vectorized { template <template <typename> class Data> AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { const IDataType& argument_type = *remove_nullable(argument_types[0]); WhichDataType which(argument_type); diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp index 44575588187..9da838a6b90 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp @@ -68,9 +68,9 @@ AggregateFunctionPtr create_function_lead_lag_first_last(const String& name, #define CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA, \ FUNCTION_IMPL) \ - AggregateFunctionPtr CREATE_FUNCTION_NAME(const std::string& name, \ - const DataTypes& argument_types, \ - const bool result_is_nullable) { \ + AggregateFunctionPtr CREATE_FUNCTION_NAME( \ + const std::string& name, const DataTypes& argument_types, \ + const bool result_is_nullable, const AggregateFunctionAttr& attr) { \ const bool arg_is_nullable = argument_types[0]->is_nullable(); \ AggregateFunctionPtr res = nullptr; \ \ diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp index c16121bad73..f95dccd547e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp @@ -32,7 +32,8 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_window_funnel(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { if (argument_types.size() < 3) { LOG(WARNING) << "window_funnel's argument less than 3."; return nullptr; diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h index 9bea4e74536..34b7e76c2ea 100644 --- a/be/src/vec/aggregate_functions/helpers.h +++ b/be/src/vec/aggregate_functions/helpers.h @@ -107,7 +107,8 @@ struct creator_without_type { template <typename AggregateFunctionTemplate> static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { CHECK_AGG_FUNCTION_SERIALIZED_TYPE(AggregateFunctionTemplate); return create<AggregateFunctionTemplate>(argument_types, result_is_nullable); } @@ -194,7 +195,8 @@ struct creator_with_type_base { template <template <typename> class AggregateFunctionTemplate> static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_base<CurryDirect<AggregateFunctionTemplate>>(argument_types, result_is_nullable); } @@ -206,7 +208,8 @@ struct creator_with_type_base { template <template <typename> class AggregateFunctionTemplate, template <typename> class Data> static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_base<CurryData<AggregateFunctionTemplate, Data>>(argument_types, result_is_nullable); } @@ -221,7 +224,8 @@ struct creator_with_type_base { template <template <typename> class AggregateFunctionTemplate, template <typename> class Data, template <typename> class Impl> static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_base<CurryDataImpl<AggregateFunctionTemplate, Data, Impl>>( argument_types, result_is_nullable); } @@ -236,7 +240,8 @@ struct creator_with_type_base { template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data> static AggregateFunctionPtr creator(const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { return create_base<CurryDirectAndData<AggregateFunctionTemplate, Data>>(argument_types, result_is_nullable); } diff --git a/be/src/vec/core/wide_integer.h b/be/src/vec/core/wide_integer.h index d74aef27a49..ec316dc2a80 100644 --- a/be/src/vec/core/wide_integer.h +++ b/be/src/vec/core/wide_integer.h @@ -79,7 +79,12 @@ public: // ctors constexpr integer() noexcept = default; + template <size_t Bits2, typename Signed2> + constexpr integer(const integer<Bits2, Signed2> rhs) noexcept; + template <typename T> + requires(std::is_arithmetic_v<T> || std::is_same_v<T, __int128> || + std::is_same_v<T, unsigned __int128>) constexpr integer(T rhs) noexcept; template <typename T> diff --git a/be/src/vec/core/wide_integer_impl.h b/be/src/vec/core/wide_integer_impl.h index 2eb9381bd37..e737da96c81 100644 --- a/be/src/vec/core/wide_integer_impl.h +++ b/be/src/vec/core/wide_integer_impl.h @@ -47,10 +47,6 @@ using FromDoubleIntermediateType = long double; using FromDoubleIntermediateType = boost::multiprecision::cpp_bin_float_double_extended; #endif -namespace CityHash_v1_0_2 { -struct uint128; -} - namespace wide { template <typename T> @@ -304,18 +300,6 @@ struct integer<Bits, Signed>::_impl { } } - template <typename CityHashUInt128 = CityHash_v1_0_2::uint128> - constexpr static void wide_integer_from_cityhash_uint128( - integer<Bits, Signed>& self, const CityHashUInt128& value) noexcept { - static_assert(sizeof(item_count) >= 2); - - if constexpr (std::endian::native == std::endian::little) { - wide_integer_from_tuple_like(self, std::make_pair(value.low64, value.high64)); - } else { - wide_integer_from_tuple_like(self, std::make_pair(value.high64, value.low64)); - } - } - /** * N.B. t is constructed from double, so max(t) = max(double) ~ 2^310 * the recursive call happens when t / 2^64 > 2^64, so there won't be more than 5 of them. @@ -1010,15 +994,19 @@ public: // Members +template <size_t Bits, typename Signed> +template <size_t Bits2, typename Signed2> +constexpr integer<Bits, Signed>::integer(const integer<Bits2, Signed2> rhs) noexcept : items {} { + _impl::wide_integer_from_wide_integer(*this, rhs); +} + template <size_t Bits, typename Signed> template <typename T> + requires(std::is_arithmetic_v<T> || std::is_same_v<T, __int128> || + std::is_same_v<T, unsigned __int128>) constexpr integer<Bits, Signed>::integer(T rhs) noexcept : items {} { - if constexpr (IsWideInteger<T>::value) { - _impl::wide_integer_from_wide_integer(*this, rhs); - } else if constexpr (IsTupleLike<T>::value) { + if constexpr (IsTupleLike<T>::value) { _impl::wide_integer_from_tuple_like(*this, rhs); - } else if constexpr (std::is_same_v<std::remove_cvref_t<T>, CityHash_v1_0_2::uint128>) { - _impl::wide_integer_from_cityhash_uint128(*this, rhs); } else { _impl::wide_integer_from_builtin(*this, rhs); } @@ -1032,8 +1020,6 @@ constexpr integer<Bits, Signed>::integer(std::initializer_list<T> il) noexcept : _impl::wide_integer_from_wide_integer(*this, *il.begin()); } else if constexpr (IsTupleLike<T>::value) { _impl::wide_integer_from_tuple_like(*this, *il.begin()); - } else if constexpr (std::is_same_v<std::remove_cvref_t<T>, CityHash_v1_0_2::uint128>) { - _impl::wide_integer_from_cityhash_uint128(*this, *il.begin()); } else { _impl::wide_integer_from_builtin(*this, *il.begin()); } @@ -1065,8 +1051,6 @@ template <typename T> constexpr integer<Bits, Signed>& integer<Bits, Signed>::operator=(T rhs) noexcept { if constexpr (IsTupleLike<T>::value) { _impl::wide_integer_from_tuple_like(*this, rhs); - } else if constexpr (std::is_same_v<std::remove_cvref_t<T>, CityHash_v1_0_2::uint128>) { - _impl::wide_integer_from_cityhash_uint128(*this, rhs); } else { _impl::wide_integer_from_builtin(*this, rhs); } diff --git a/be/src/vec/exec/scan/vfile_scanner.cpp b/be/src/vec/exec/scan/vfile_scanner.cpp index 4fc58177c8a..c3f4d12f9dc 100644 --- a/be/src/vec/exec/scan/vfile_scanner.cpp +++ b/be/src/vec/exec/scan/vfile_scanner.cpp @@ -449,8 +449,9 @@ Status VFileScanner::_cast_to_input_block(Block* block) { remove_nullable(return_type)->get_type_as_type_descriptor()); ColumnsWithTypeAndName arguments { arg, {data_type->create_column(), data_type, slot_desc->col_name()}}; - auto func_cast = - SimpleFunctionFactory::instance().get_function("CAST", arguments, return_type); + auto func_cast = SimpleFunctionFactory::instance().get_function( + "CAST", arguments, return_type, + {.enable_decimal256 = runtime_state()->enable_decimal256()}); idx = _src_block_name_to_idx[slot_desc->col_name()]; RETURN_IF_ERROR( func_cast->execute(nullptr, *_src_block_ptr, {idx}, idx, arg.column->size())); diff --git a/be/src/vec/exprs/vcase_expr.cpp b/be/src/vec/exprs/vcase_expr.cpp index d6573a0e25f..222a8f5629a 100644 --- a/be/src/vec/exprs/vcase_expr.cpp +++ b/be/src/vec/exprs/vcase_expr.cpp @@ -27,6 +27,7 @@ #include <vector> #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/core/block.h" @@ -66,8 +67,9 @@ Status VCaseExpr::prepare(RuntimeState* state, const RowDescriptor& desc, VExprC arguments.emplace_back(child->data_type()); } - _function = SimpleFunctionFactory::instance().get_function(_function_name, argument_template, - _data_type); + _function = SimpleFunctionFactory::instance().get_function( + _function_name, argument_template, _data_type, + {.enable_decimal256 = state->enable_decimal256()}); if (_function == nullptr) { return Status::NotSupported("vcase_expr Function {} is not implemented", _fn.name.function_name); diff --git a/be/src/vec/exprs/vcast_expr.cpp b/be/src/vec/exprs/vcast_expr.cpp index 6cd914080cd..38f861add87 100644 --- a/be/src/vec/exprs/vcast_expr.cpp +++ b/be/src/vec/exprs/vcast_expr.cpp @@ -27,6 +27,7 @@ #include "common/exception.h" #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/core/block.h" #include "vec/core/column_with_type_and_name.h" #include "vec/core/columns_with_type_and_name.h" @@ -64,8 +65,9 @@ doris::Status VCastExpr::prepare(doris::RuntimeState* state, const doris::RowDes argument_template.reserve(2); argument_template.emplace_back(nullptr, child->data_type(), child_name); argument_template.emplace_back(_cast_param, _cast_param_data_type, _target_data_type_name); - _function = SimpleFunctionFactory::instance().get_function(function_name, argument_template, - _data_type); + _function = SimpleFunctionFactory::instance().get_function( + function_name, argument_template, _data_type, + {.enable_decimal256 = state->enable_decimal256()}); if (_function == nullptr) { return Status::NotSupported("Function {} is not implemented", _fn.name.function_name); diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 786986f4c78..45ad573cb5d 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -198,11 +198,11 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, _function = AggregateFunctionSimpleFactory::instance().get( _fn.name.function_name, argument_types, AggregateFunctionSimpleFactory::result_nullable_by_foreach(_data_type), - state->be_exec_version(), state->enable_decimal256()); + state->be_exec_version(), {.enable_decimal256 = state->enable_decimal256()}); } else { _function = AggregateFunctionSimpleFactory::instance().get( _fn.name.function_name, argument_types, _data_type->is_nullable(), - state->be_exec_version(), state->enable_decimal256()); + state->be_exec_version(), {.enable_decimal256 = state->enable_decimal256()}); } } if (_function == nullptr) { diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index cd9138ee971..3192653a816 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -106,7 +106,8 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, } else { // get the function. won't prepare function. _function = SimpleFunctionFactory::instance().get_function( - _fn.name.function_name, argument_template, _data_type, state->be_exec_version()); + _fn.name.function_name, argument_template, _data_type, + {.enable_decimal256 = state->enable_decimal256()}, state->be_exec_version()); } if (_function == nullptr) { return Status::InternalError("Could not find function {}, arg {} return {} ", diff --git a/be/src/vec/exprs/vin_predicate.cpp b/be/src/vec/exprs/vin_predicate.cpp index b85a936ef37..efd757ddd8b 100644 --- a/be/src/vec/exprs/vin_predicate.cpp +++ b/be/src/vec/exprs/vin_predicate.cpp @@ -73,8 +73,9 @@ Status VInPredicate::prepare(RuntimeState* state, const RowDescriptor& desc, if (is_struct(arg_type) || is_array(arg_type) || is_map(arg_type)) { real_function_name = "collection_" + real_function_name; } - _function = SimpleFunctionFactory::instance().get_function(real_function_name, - argument_template, _data_type); + _function = SimpleFunctionFactory::instance().get_function( + real_function_name, argument_template, _data_type, + {.enable_decimal256 = state->enable_decimal256()}); if (_function == nullptr) { return Status::NotSupported("Function {} is not implemented", real_function_name); } diff --git a/be/src/vec/exprs/vmatch_predicate.cpp b/be/src/vec/exprs/vmatch_predicate.cpp index 8a64dec604b..c80933df13c 100644 --- a/be/src/vec/exprs/vmatch_predicate.cpp +++ b/be/src/vec/exprs/vmatch_predicate.cpp @@ -36,6 +36,7 @@ #include "common/status.h" #include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h" #include "olap/rowset/segment_v2/inverted_index_reader.h" +#include "runtime/runtime_state.h" #include "vec/core/block.h" #include "vec/core/column_numbers.h" #include "vec/core/column_with_type_and_name.h" @@ -81,8 +82,9 @@ Status VMatchPredicate::prepare(RuntimeState* state, const RowDescriptor& desc, child_expr_name.emplace_back(child->expr_name()); } - _function = SimpleFunctionFactory::instance().get_function(_fn.name.function_name, - argument_template, _data_type); + _function = SimpleFunctionFactory::instance().get_function( + _fn.name.function_name, argument_template, _data_type, + {.enable_decimal256 = state->enable_decimal256()}); if (_function == nullptr) { std::string type_str; for (auto arg : argument_template) { diff --git a/be/src/vec/exprs/vtopn_pred.h b/be/src/vec/exprs/vtopn_pred.h index 675c8fb293c..044bc28b261 100644 --- a/be/src/vec/exprs/vtopn_pred.h +++ b/be/src/vec/exprs/vtopn_pred.h @@ -71,7 +71,7 @@ public: _function = SimpleFunctionFactory::instance().get_function( _predicate->is_asc() ? "le" : "ge", argument_template, _data_type, - state->be_exec_version()); + {.enable_decimal256 = state->enable_decimal256()}, state->be_exec_version()); if (!_function) { return Status::InternalError("get function failed"); } diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp b/be/src/vec/functions/array/function_array_aggregation.cpp index 18367816bc8..24d82f7894a 100644 --- a/be/src/vec/functions/array/function_array_aggregation.cpp +++ b/be/src/vec/functions/array/function_array_aggregation.cpp @@ -52,7 +52,7 @@ namespace vectorized { enum class AggregateOperation { MIN, MAX, SUM, AVERAGE, PRODUCT }; -template <typename Element, AggregateOperation operation> +template <typename Element, AggregateOperation operation, bool enable_decimal256 = false> struct ArrayAggregateResultImpl; template <typename Element> @@ -70,11 +70,21 @@ struct ArrayAggregateResultImpl<Element, AggregateOperation::AVERAGE> { using Result = DisposeDecimal<Element, Float64>; }; +template <typename Element> +struct ArrayAggregateResultImpl<Element, AggregateOperation::AVERAGE, true> { + using Result = DisposeDecimal256<Element, Float64>; +}; + template <typename Element> struct ArrayAggregateResultImpl<Element, AggregateOperation::PRODUCT> { using Result = DisposeDecimal<Element, Float64>; }; +template <typename Element> +struct ArrayAggregateResultImpl<Element, AggregateOperation::PRODUCT, true> { + using Result = DisposeDecimal256<Element, Float64>; +}; + template <typename Element> struct ArrayAggregateResultImpl<Element, AggregateOperation::SUM> { using Result = DisposeDecimal< @@ -82,13 +92,21 @@ struct ArrayAggregateResultImpl<Element, AggregateOperation::SUM> { std::conditional_t<IsFloatNumber<Element>, Float64, std::conditional_t<std::is_same_v<Element, Int128>, Int128, Int64>>>; }; +template <typename Element> +struct ArrayAggregateResultImpl<Element, AggregateOperation::SUM, true> { + using Result = DisposeDecimal256< + Element, + std::conditional_t<IsFloatNumber<Element>, Float64, + std::conditional_t<std::is_same_v<Element, Int128>, Int128, Int64>>>; +}; -template <typename Element, AggregateOperation operation> -using ArrayAggregateResult = typename ArrayAggregateResultImpl<Element, operation>::Result; +template <typename Element, AggregateOperation operation, bool enable_decimal256 = false> +using ArrayAggregateResult = + typename ArrayAggregateResultImpl<Element, operation, enable_decimal256>::Result; // For MIN/MAX, the type of result is the same as the type of elements, we can omit the // template specialization. -template <AggregateOperation operation> +template <AggregateOperation operation, bool enable_decimal256 = false> struct AggregateFunctionImpl; template <> @@ -100,6 +118,15 @@ struct AggregateFunctionImpl<AggregateOperation::SUM> { using Function = AggregateFunctionSum<Element, ResultType, AggregateDataType>; }; }; +template <> +struct AggregateFunctionImpl<AggregateOperation::SUM, true> { + template <typename Element> + struct TypeTraits { + using ResultType = ArrayAggregateResult<Element, AggregateOperation::SUM, true>; + using AggregateDataType = AggregateFunctionSumData<ResultType>; + using Function = AggregateFunctionSum<Element, ResultType, AggregateDataType>; + }; +}; template <> struct AggregateFunctionImpl<AggregateOperation::AVERAGE> { @@ -113,6 +140,18 @@ struct AggregateFunctionImpl<AggregateOperation::AVERAGE> { }; }; +template <> +struct AggregateFunctionImpl<AggregateOperation::AVERAGE, true> { + template <typename Element> + struct TypeTraits { + using ResultType = ArrayAggregateResult<Element, AggregateOperation::AVERAGE, true>; + using AggregateDataType = AggregateFunctionAvgData<ResultType>; + using Function = AggregateFunctionAvg<Element, AggregateDataType>; + static_assert(std::is_same_v<ResultType, typename Function::ResultType>, + "ResultType doesn't match."); + }; +}; + template <> struct AggregateFunctionImpl<AggregateOperation::PRODUCT> { template <typename Element> @@ -123,6 +162,16 @@ struct AggregateFunctionImpl<AggregateOperation::PRODUCT> { }; }; +template <> +struct AggregateFunctionImpl<AggregateOperation::PRODUCT, true> { + template <typename Element> + struct TypeTraits { + using ResultType = ArrayAggregateResult<Element, AggregateOperation::PRODUCT, true>; + using AggregateDataType = AggregateFunctionProductData<ResultType>; + using Function = AggregateFunctionProduct<Element, ResultType, AggregateDataType>; + }; +}; + template <typename Derived> struct AggregateFunction { template <typename T> @@ -133,7 +182,7 @@ struct AggregateFunction { } }; -template <AggregateOperation operation> +template <AggregateOperation operation, bool enable_decimal256 = false> struct ArrayAggregateImpl { using column_type = ColumnArray; using data_type = DataTypeArray; @@ -143,21 +192,9 @@ struct ArrayAggregateImpl { static size_t _get_number_of_arguments() { return 1; } static DataTypePtr get_return_type(const DataTypes& arguments) { - using Function = AggregateFunction<AggregateFunctionImpl<operation>>; + using Function = AggregateFunction<AggregateFunctionImpl<operation, enable_decimal256>>; const DataTypeArray* data_type_array = static_cast<const DataTypeArray*>(remove_nullable(arguments[0]).get()); - if constexpr (operation != AggregateOperation::MIN && - operation != AggregateOperation::MAX) { - // only array_min and array_max support decimal256 type - if (is_decimal(remove_nullable(data_type_array->get_nested_type()))) { - const auto decimal_type = remove_nullable(data_type_array->get_nested_type()); - if (check_decimal<Decimal256>(*decimal_type)) { - throw doris::Exception( - ErrorCode::INVALID_ARGUMENT, "Unexpected type {} for aggregation {}", - data_type_array->get_nested_type()->get_name(), operation); - } - } - } auto function = Function::create(data_type_array->get_nested_type()); if (function) { return function->get_return_type(); @@ -203,9 +240,9 @@ struct ArrayAggregateImpl { static bool execute_type(ColumnPtr& res_ptr, const DataTypePtr& type, const IColumn* data, const ColumnArray::Offsets64& offsets) { using ColVecType = ColumnVectorOrDecimal<Element>; - using ResultType = ArrayAggregateResult<Element, operation>; + using ResultType = ArrayAggregateResult<Element, operation, enable_decimal256>; using ColVecResultType = ColumnVectorOrDecimal<ResultType>; - using Function = AggregateFunction<AggregateFunctionImpl<operation>>; + using Function = AggregateFunction<AggregateFunctionImpl<operation, enable_decimal256>>; const ColVecType* column = data->is_nullable() @@ -275,34 +312,57 @@ struct AggregateFunction<AggregateFunctionImpl<AggregateOperation::MAX>> { struct NameArraySum { static constexpr auto name = "array_sum"; }; +struct NameArraySumDecimal256 { + static constexpr auto name = "array_sum_decimal256"; +}; struct NameArrayAverage { static constexpr auto name = "array_avg"; }; +struct NameArrayAverageDecimal256 { + static constexpr auto name = "array_avg_decimal256"; +}; struct NameArrayProduct { static constexpr auto name = "array_product"; }; +struct NameArrayProductDecimal256 { + static constexpr auto name = "array_product_decimal256"; +}; + using FunctionArrayMin = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::MIN>, NameArrayMin>; using FunctionArrayMax = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::MAX>, NameArrayMax>; using FunctionArraySum = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::SUM>, NameArraySum>; +using FunctionArraySumDecimal256 = + FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::SUM, true>, + NameArraySumDecimal256>; using FunctionArrayAverage = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::AVERAGE>, NameArrayAverage>; +using FunctionArrayAverageDecimal256 = + FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::AVERAGE, true>, + NameArrayAverageDecimal256>; using FunctionArrayProduct = FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::PRODUCT>, NameArrayProduct>; +using FunctionArrayProductDecimal256 = + FunctionArrayMapped<ArrayAggregateImpl<AggregateOperation::PRODUCT, true>, + NameArrayProductDecimal256>; + using FunctionArrayJoin = FunctionArrayMapped<ArrayJoinImpl, NameArrayJoin>; void register_function_array_aggregation(SimpleFunctionFactory& factory) { factory.register_function<FunctionArrayMin>(); factory.register_function<FunctionArrayMax>(); factory.register_function<FunctionArraySum>(); + factory.register_function<FunctionArraySumDecimal256>(); factory.register_function<FunctionArrayAverage>(); + factory.register_function<FunctionArrayAverageDecimal256>(); factory.register_function<FunctionArrayProduct>(); + factory.register_function<FunctionArrayProductDecimal256>(); factory.register_function<FunctionArrayJoin>(); } diff --git a/be/src/vec/functions/array/function_array_cum_sum.cpp b/be/src/vec/functions/array/function_array_cum_sum.cpp index 970be746632..24750b55f6c 100644 --- a/be/src/vec/functions/array/function_array_cum_sum.cpp +++ b/be/src/vec/functions/array/function_array_cum_sum.cpp @@ -35,13 +35,18 @@ namespace doris::vectorized { // array_cum_sum([1, 2, 3, 4, 5]) -> [1, 3, 6, 10, 15] // array_cum_sum([1, NULL, 3, NULL, 5]) -> [1, NULL, 4, NULL, 9] +template <bool enable_decimal256> class FunctionArrayCumSum : public IFunction { public: using NullMapType = PaddedPODArray<UInt8>; - static constexpr auto name = "array_cum_sum"; + static constexpr auto name = enable_decimal256 ? "array_cum_sum_decimal256" : "array_cum_sum"; - static FunctionPtr create() { return std::make_shared<FunctionArrayCumSum>(); } + static FunctionPtr create() { + return std::make_shared<FunctionArrayCumSum<enable_decimal256>>(); + } + + using DecimalResultType = std::conditional_t<enable_decimal256, Decimal256, Decimal128V3>; String get_name() const override { return name; } @@ -76,8 +81,8 @@ public: return_type = std::make_shared<DataTypeDecimal<Decimal128V2>>( DataTypeDecimal<Decimal128V2>::max_precision(), scale); } else if (which.is_decimal()) { - return_type = std::make_shared<DataTypeDecimal<Decimal128V3>>( - DataTypeDecimal<Decimal128V3>::max_precision(), scale); + return_type = std::make_shared<DataTypeDecimal<DecimalResultType>>( + DataTypeDecimal<DecimalResultType>::max_precision(), scale); } if (return_type) { return std::make_shared<DataTypeArray>(make_nullable(return_type)); @@ -163,14 +168,17 @@ private: res = _execute_number<Float64, Float64>(src_column, src_offsets, src_null_map, res_nested_ptr); } else if (which.is_decimal32()) { - res = _execute_number<Decimal32, Decimal128V3>(src_column, src_offsets, src_null_map, - res_nested_ptr); + res = _execute_number<Decimal32, DecimalResultType>(src_column, src_offsets, + src_null_map, res_nested_ptr); } else if (which.is_decimal64()) { - res = _execute_number<Decimal64, Decimal128V3>(src_column, src_offsets, src_null_map, - res_nested_ptr); + res = _execute_number<Decimal64, DecimalResultType>(src_column, src_offsets, + src_null_map, res_nested_ptr); } else if (which.is_decimal128v3()) { - res = _execute_number<Decimal128V3, Decimal128V3>(src_column, src_offsets, src_null_map, - res_nested_ptr); + res = _execute_number<Decimal128V3, DecimalResultType>(src_column, src_offsets, + src_null_map, res_nested_ptr); + } else if (which.is_decimal256()) { + res = _execute_number<Decimal256, DecimalResultType>(src_column, src_offsets, + src_null_map, res_nested_ptr); } else if (which.is_decimal128v2()) { res = _execute_number<Decimal128V2, Decimal128V2>(src_column, src_offsets, src_null_map, res_nested_ptr); @@ -244,7 +252,8 @@ private: }; void register_function_array_cum_sum(SimpleFunctionFactory& factory) { - factory.register_function<FunctionArrayCumSum>(); + factory.register_function<FunctionArrayCumSum<false>>(FunctionArrayCumSum<false>::name); + factory.register_function<FunctionArrayCumSum<true>>(FunctionArrayCumSum<true>::name); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/functions/comparison_equal_for_null.cpp b/be/src/vec/functions/comparison_equal_for_null.cpp index 49db471f8d1..24c669094a5 100644 --- a/be/src/vec/functions/comparison_equal_for_null.cpp +++ b/be/src/vec/functions/comparison_equal_for_null.cpp @@ -24,6 +24,7 @@ #include <utility> #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_const.h" @@ -180,8 +181,9 @@ public: ""}}; Block temporary_block(eq_columns); - auto func_eq = - SimpleFunctionFactory::instance().get_function("eq", eq_columns, return_type); + auto func_eq = SimpleFunctionFactory::instance().get_function( + "eq", eq_columns, return_type, + {.enable_decimal256 = context ? context->state()->enable_decimal256() : false}); DCHECK(func_eq) << fmt::format("Left type {} right type {} return type {}", col_left.type->get_name(), col_right.type->get_name(), return_type->get_name()); @@ -219,8 +221,9 @@ public: const ColumnsWithTypeAndName eq_columns { ColumnWithTypeAndName {col_left.column, col_left.type, ""}, ColumnWithTypeAndName {col_right.column, col_right.type, ""}}; - auto func_eq = - SimpleFunctionFactory::instance().get_function("eq", eq_columns, return_type); + auto func_eq = SimpleFunctionFactory::instance().get_function( + "eq", eq_columns, return_type, + {.enable_decimal256 = context ? context->state()->enable_decimal256() : false}); DCHECK(func_eq); Block temporary_block(eq_columns); diff --git a/be/src/vec/functions/function.h b/be/src/vec/functions/function.h index 1b0c94771dd..4702a4b7af0 100644 --- a/be/src/vec/functions/function.h +++ b/be/src/vec/functions/function.h @@ -47,6 +47,10 @@ struct FuncExprParams; namespace doris::vectorized { +struct FunctionAttr { + bool enable_decimal256 {false}; +}; + #define RETURN_REAL_TYPE_FOR_DATEV2_FUNCTION(TYPE) \ bool is_nullable = false; \ bool is_datev2 = false; \ diff --git a/be/src/vec/functions/function_coalesce.cpp b/be/src/vec/functions/function_coalesce.cpp index dbe75cf1408..d3450e97e98 100644 --- a/be/src/vec/functions/function_coalesce.cpp +++ b/be/src/vec/functions/function_coalesce.cpp @@ -26,6 +26,7 @@ #include <vector> #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_complex.h" @@ -55,6 +56,7 @@ class FunctionCoalesce : public IFunction { public: static constexpr auto name = "coalesce"; + mutable DataTypePtr result_type; mutable FunctionBasePtr func_is_not_null; static FunctionPtr create() { return std::make_shared<FunctionCoalesce>(); } @@ -68,25 +70,26 @@ public: size_t get_number_of_arguments() const override { return 0; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - DataTypePtr res; for (const auto& arg : arguments) { if (!arg->is_nullable()) { - res = arg; + result_type = arg; break; } } - res = res ? res : arguments[0]; - - const ColumnsWithTypeAndName is_not_null_col {{nullptr, make_nullable(res), ""}}; - func_is_not_null = SimpleFunctionFactory::instance().get_function( - "is_not_null_pred", is_not_null_col, std::make_shared<DataTypeUInt8>()); - - return res; + result_type = result_type ? result_type : arguments[0]; + return result_type; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { + if (!func_is_not_null) [[unlikely]] { + const ColumnsWithTypeAndName is_not_null_col { + {nullptr, make_nullable(result_type), ""}}; + func_is_not_null = SimpleFunctionFactory::instance().get_function( + "is_not_null_pred", is_not_null_col, std::make_shared<DataTypeUInt8>(), + {.enable_decimal256 = context->state()->enable_decimal256()}); + } DCHECK_GE(arguments.size(), 1); DataTypePtr result_type = block.get_by_position(result).type; ColumnNumbers filtered_args; diff --git a/be/src/vec/functions/function_ifnull.h b/be/src/vec/functions/function_ifnull.h index 2a43727f7ad..9deb7f8d71f 100644 --- a/be/src/vec/functions/function_ifnull.h +++ b/be/src/vec/functions/function_ifnull.h @@ -27,6 +27,7 @@ #include <memory> #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_nullable.h" @@ -117,7 +118,8 @@ public: }); auto func_if = SimpleFunctionFactory::instance().get_function( - "if", if_columns, block.get_by_position(result).type); + "if", if_columns, block.get_by_position(result).type, + {.enable_decimal256 = context->state()->enable_decimal256()}); RETURN_IF_ERROR(func_if->execute(context, temporary_block, {0, 1, 2}, 3, input_rows_count)); block.get_by_position(result).column = temporary_block.get_by_position(3).column; return Status::OK(); diff --git a/be/src/vec/functions/nullif.cpp b/be/src/vec/functions/nullif.cpp index 928fb1c0767..f122938df14 100644 --- a/be/src/vec/functions/nullif.cpp +++ b/be/src/vec/functions/nullif.cpp @@ -27,6 +27,7 @@ #include <vector> #include "common/status.h" +#include "runtime/runtime_state.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_const.h" @@ -95,8 +96,9 @@ public: block.get_by_position(arguments[1]), {nullptr, result_type, ""}}); - auto equals_func = - SimpleFunctionFactory::instance().get_function("eq", eq_columns, result_type); + auto equals_func = SimpleFunctionFactory::instance().get_function( + "eq", eq_columns, result_type, + {.enable_decimal256 = context->state()->enable_decimal256()}); DCHECK(equals_func); RETURN_IF_ERROR( equals_func->execute(context, eq_temporary_block, {0, 1}, 2, input_rows_count)); @@ -124,8 +126,9 @@ public: block.get_by_position(arguments[0]), new_result_column}); - auto func_if = SimpleFunctionFactory::instance().get_function("if", if_columns, - new_result_column.type); + auto func_if = SimpleFunctionFactory::instance().get_function( + "if", if_columns, new_result_column.type, + {.enable_decimal256 = context->state()->enable_decimal256()}); DCHECK(func_if); RETURN_IF_ERROR(func_if->execute(context, temporary_block, {0, 1, 2}, 3, input_rows_count)); block.get_by_position(result).column = temporary_block.get_by_position(3).column; diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index b434fcf2ae3..98f2917d163 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -28,6 +28,8 @@ namespace doris::vectorized { +constexpr auto DECIMAL256_FUNCTION_SUFFIX {"_decimal256"}; + class SimpleFunctionFactory; void register_function_size(SimpleFunctionFactory& factory); @@ -154,7 +156,7 @@ public: } FunctionBasePtr get_function(const std::string& name, const ColumnsWithTypeAndName& arguments, - const DataTypePtr& return_type, + const DataTypePtr& return_type, const FunctionAttr& attr = {}, int be_version = BeExecVersionManager::get_newest_version()) { std::string key_str = name; @@ -162,6 +164,13 @@ public: key_str = function_alias[name]; } + if (attr.enable_decimal256) { + if (key_str == "array_sum" || key_str == "array_avg" || key_str == "array_product" || + key_str == "array_cum_sum") { + key_str += DECIMAL256_FUNCTION_SUFFIX; + } + } + temporary_function_update(be_version, key_str); // if function is variadic, added types_str as key diff --git a/be/test/vec/aggregate_functions/agg_linear_histogram_test.cpp b/be/test/vec/aggregate_functions/agg_linear_histogram_test.cpp index 9406e96f4a2..3dbf34a4dcb 100644 --- a/be/test/vec/aggregate_functions/agg_linear_histogram_test.cpp +++ b/be/test/vec/aggregate_functions/agg_linear_histogram_test.cpp @@ -204,7 +204,8 @@ public: << "(" << data_types[0]->get_name() << ")"; AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); - auto agg_function = factory.get("linear_histogram", data_types, false, -1, true); + auto agg_function = + factory.get("linear_histogram", data_types, false, -1, {.enable_decimal256 = true}); EXPECT_NE(agg_function, nullptr); std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); diff --git a/regression-test/data/query_p0/aggregate/aggregate_decimal256.out b/regression-test/data/datatype_p0/decimalv3/aggregate_decimal256.out similarity index 95% rename from regression-test/data/query_p0/aggregate/aggregate_decimal256.out rename to regression-test/data/datatype_p0/decimalv3/aggregate_decimal256.out index 95d30e83374..c61af8c712e 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_decimal256.out +++ b/regression-test/data/datatype_p0/decimalv3/aggregate_decimal256.out @@ -95,3 +95,11 @@ -- !sql_count_3 -- 8 8 +-- !sql_distinct_count_1 -- +1 1 +2 1 + +-- !sql_distinct_count_2 -- +1 1 3 +2 1 3 + diff --git a/regression-test/data/datatype_p0/decimalv3/test_decimal256_array.out b/regression-test/data/datatype_p0/decimalv3/test_decimal256_array.out new file mode 100644 index 00000000000..d3c63e8ae79 --- /dev/null +++ b/regression-test/data/datatype_p0/decimalv3/test_decimal256_array.out @@ -0,0 +1,63 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !decimal256_array_sum -- +6 6.600000000000000000000000000000000000 + +-- !decimal256_array_sum2 -- +6 6.600000000000000000000000000000000000 +15 16.500000000000000000000000000000000000 +24 26.400000000000000000000000000000000000 + +-- !decimal256_array_sum3 -- +15 16.500000000000000000000000000000000000 +24 26.400000000000000000000000000000000000 + +-- !decimal256_array_avg -- +2 2.200000000000000000000000000000000000 +5 5.500000000000000000000000000000000000 +8 8.800000000000000000000000000000000000 + +-- !decimal256_array_avg1 -- + +-- !decimal256_array_avg2 -- +5 5.500000000000000000000000000000000000 +8 8.800000000000000000000000000000000000 + +-- !decimal256_array_avg3 -- +[4, 5, 6] 5 [4.400000000000000000000000000000000000, 5.500000000000000000000000000000000000, 6.600000000000000000000000000000000000] 5.500000000000000000000000000000000000 +[7, 8, 9, null] 8 [7.700000000000000000000000000000000000, 8.800000000000000000000000000000000000, 9.900000000000000000000000000000000000, null] 8.800000000000000000000000000000000000 + +-- !decimal256_array_product -- +6 7.986000000000000000000000000000000000 + +-- !decimal256_array_product2 -- +120 159.720000000000000000000000000000000000 +504 670.824000000000000000000000000000000000 +6 7.986000000000000000000000000000000000 + +-- !decimal256_array_product3 -- +120 159.720000000000000000000000000000000000 +504 670.824000000000000000000000000000000000 + +-- !decimal256_array_cum_sum -- +[1, 2, 3] [1, 3, 6] [1.100000000000000000000000000000000000, 2.200000000000000000000000000000000000, 3.300000000000000000000000000000000000] [1.100000000000000000000000000000000000, 3.300000000000000000000000000000000000, 6.600000000000000000000000000000000000] +[4, 5, 6] [4, 9, 15] [4.400000000000000000000000000000000000, 5.500000000000000000000000000000000000, 6.600000000000000000000000000000000000] [4.400000000000000000000000000000000000, 9.900000000000000000000000000000000000, 16.500000000000000000000000000000000000] +[7, 8, 9, null] [7, 15, 24, null] [7.700000000000000000000000000000000000, 8.800000000000000000000000000000000000, 9.900000000000000000000000000000000000, null] [7.700000000000000000000000000000000000, 16.500000000000000000000000000000000000, 26.400000000000000000000000000000000000, null] + +-- !decimal256_array_cum_sum2 -- +[1, 2, 3] [1, 3, 6] [1.100000000000000000000000000000000000, 2.200000000000000000000000000000000000, 3.300000000000000000000000000000000000] [1.100000000000000000000000000000000000, 3.300000000000000000000000000000000000, 6.600000000000000000000000000000000000] + +-- !decimal256_array_cum_sum3 -- +[4, 5, 6] [4, 9, 15] [4.400000000000000000000000000000000000, 5.500000000000000000000000000000000000, 6.600000000000000000000000000000000000] [4.400000000000000000000000000000000000, 9.900000000000000000000000000000000000, 16.500000000000000000000000000000000000] + +-- !decimal256_sum_foreach -- +[12, 15, 18, null] [13.200000000000000000000000000000000000, 16.500000000000000000000000000000000000, 19.800000000000000000000000000000000000, null] + +-- !decimal256_sum_foreach2 -- +[12, 15, 18, null] [13.200000000000000000000000000000000000, 16.500000000000000000000000000000000000, 19.800000000000000000000000000000000000, null] + +-- !decimal256_sum_foreach3 -- +[12, 15, 18, null] [13.200000000000000000000000000000000000, 16.500000000000000000000000000000000000, 19.800000000000000000000000000000000000, null] + +-- !decimal256_avg_foreach2 -- +[4.400000000000000000000000000000000000, 5.500000000000000000000000000000000000, 6.600000000000000000000000000000000000, null] + diff --git a/regression-test/data/datatype_p0/decimalv3/test_decimal256_multi_distinct.out b/regression-test/data/datatype_p0/decimalv3/test_decimal256_multi_distinct.out new file mode 100644 index 00000000000..d8b7dd273f1 --- /dev/null +++ b/regression-test/data/datatype_p0/decimalv3/test_decimal256_multi_distinct.out @@ -0,0 +1,33 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !decimal256_multi_distinct_sum -- +1 3 3.300000000000000000000000000000000000 +2 5 5.500000000000000000000000000000000000 + +-- !decimal256_multi_distinct_sum2 -- +1 3 3.300000000000000000000000000000000000 + +-- !decimal256_multi_distinct_sum3 -- +2 5 5.500000000000000000000000000000000000 + +-- !decimal256_multi_distinct_count -- +1 2 +2 2 + +-- !decimal256_multi_distinct_count2 -- +1 2 +2 2 + +-- !decimal256_multi_distinct_count3 -- +1 2 2 +2 2 2 + +-- !decimal256_multi_distinct_avg -- +1 1.50000000 1.650000000000000000000000000000000000 +2 2.50000000 2.750000000000000000000000000000000000 + +-- !decimal256_multi_distinct_avg2 -- +1 1.50000000 1.650000000000000000000000000000000000 + +-- !decimal256_multi_distinct_avg3 -- +2 2.50000000 2.750000000000000000000000000000000000 + diff --git a/regression-test/data/nereids_function_p0/scalar_function/Array.out b/regression-test/data/nereids_function_p0/scalar_function/Array.out index cbfbf3a9519..9570ac178d7 100644 --- a/regression-test/data/nereids_function_p0/scalar_function/Array.out +++ b/regression-test/data/nereids_function_p0/scalar_function/Array.out @@ -14658,6 +14658,312 @@ true 9876543210.54321098765432109876543210987654321000000000000000000000 98765432109876543210.65432109876543210987654321098765432100000000000000000000 +-- !sql_array_product_decimal256 -- +-1.11089030379753379825E-36 +-1.15124609027348871047E-36 +-1.20316083385972392568E-36 +-1.21167050052433434998E-36 +-1.33059735498032041369E-36 +-1.33059735498032041369E-36 +-1.82156954541119249033E-36 +-1.82156954541119249033E-36 +-1.85863116714366281287E-36 +-2.37374509709306887541E-36 +-2.41386418671374962135E-36 +-2.45988930136657294736E-36 +-2.46417160445103636709E-36 +-2.56049975893912100561E-36 +-2.75305782044403083264E-36 +-2.79910149694808391403E-36 +-2.79910149694808391403E-36 +-2.86499291870397349138E-36 +-3.33566875141199191869E-36 +-3.5368044514285609949E-37 +-3.5368044514285609949E-37 +-3.67937523172259205538E-36 +-3.68452126603999106590E-36 +-3.71030605689385405875E-36 +-3.78526621062516991080E-36 +-3.79632207995656186667E-36 +-3.81707054517283638897E-36 +-3.9173386463672416041E-37 +-3.96551467931687567991E-36 +-3.988154730296666392E-38 +-3.988154730296666392E-38 +-4.27166987302514918639E-36 +-4.28729427038743394332E-36 +-4.28729427038743394332E-36 +-4.40936518647539351998E-36 +-4.65676335758612546241E-36 +-4.67143416231321849951E-36 +-4.6935900034901270766E-37 +-4.74099232235060022374E-36 +-4.92034358705743192696E-36 +-4.92034358705743192696E-36 +-5.03348501146740850819E-36 +-5.10840633177794202111E-36 +-5.11291847115851716824E-36 +-5.19906267543202124019E-36 +-5.41131577748830400361E-36 +-5.41131577748830400361E-36 +-5.53827487101353220685E-36 +-5.72666878488522669982E-36 +-5.73639862303715509180E-36 +-6.8706967526543668364E-37 +-7.3094606021823512708E-37 +-8.743288118444520341E-38 +-8.7962773677034651067E-37 +-8.854001720170268951E-38 +-9.9156283503028916941E-37 +1.0942410564979919373E-37 +1.14561704288254771601E-36 +1.32496830758937941923E-36 +1.34855780630288153772E-36 +1.36056433544801845896E-36 +1.36056433544801845896E-36 +1.68833162001331406487E-36 +1.85153652587889053561E-36 +1.91579735501810526744E-36 +1.96583563857929146014E-36 +1.98050644330638449725E-36 +1.98050644330638449725E-36 +1.98050644330638449725E-36 +2.07083635956378692614E-36 +2.14710656165020445904E-36 +2.38143936389998857707E-36 +2.62854206325881343579E-36 +2.69186193041133429379E-36 +3.09470099967476574588E-36 +3.25293651086094412867E-36 +3.30160104948051804097E-36 +3.30782999752085701738E-36 +3.37632574098352471282E-36 +3.59036876403224351683E-36 +3.82225275285588660715E-36 +3.85393652698006881922E-36 +4.03213219243240234106E-36 +4.09973770951346675179E-36 +4.7120770865390877655E-37 +4.71878879359359439025E-36 +4.71878879359359439025E-36 +4.71878879359359439025E-36 +4.74320929952406375732E-36 +4.77355332278037270178E-36 +4.77355332278037270178E-36 +5.13827656539993655728E-36 +5.20652113080353816190E-36 +5.27853651944920585633E-36 +5.43410896703278603398E-36 +5.65852420076209230378E-36 +5.75345528002125840233E-36 +5.76732577445675949395E-36 +8.4181413147842139986E-37 +9.1060101600728325984E-37 + +-- !sql_array_avg_decimal256 -- +1.12345678901234567890123456789012345600000000000000000000 +1234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567.43210987654321098765432109876543210900000000000000000000 +123456789012345678.23456789012345678901234567890123456700000000000000000000 +123456789012345678.23456789012345678901234567890123456700000000000000000000 +1234567890123456789.09876543210987654321098765432109876500000000000000000000 +1234567890123456789.09876543210987654321098765432109876500000000000000000000 +1234567890123456789.34567890123456789012345678901234567800000000000000000000 +1234567890123456789.43210987654321098765432109876543210900000000000000000000 +1234567890123456789.54321098765432109876543210987654321000000000000000000000 +1234567890123456789.54321098765432109876543210987654321000000000000000000000 +1234567890123456789.65432109876543210987654321098765432100000000000000000000 +1234567890123456789.65432109876543210987654321098765432100000000000000000000 +1234567890123456789.98765432109876543210987654321098765400000000000000000000 +1234567890123456789.98765432109876543210987654321098765400000000000000000000 +12345678901234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567890.23456789012345678901234567890123456700000000000000000000 +12345678901234567890.54321098765432109876543210987654321000000000000000000000 +12345678901234567890.65432109876543210987654321098765432100000000000000000000 +12345678901234567890.65432109876543210987654321098765432100000000000000000000 +12345678901234567890.98765432109876543210987654321098765400000000000000000000 +2345678901234567.56789012345678901234567890123456789000000000000000000000 +234567890123456789.54321098765432109876543210987654321000000000000000000000 +234567890123456789.65432109876543210987654321098765432100000000000000000000 +2345678901234567890.09876543210987654321098765432109876500000000000000000000 +2345678901234567890.23456789012345678901234567890123456700000000000000000000 +2345678901234567890.54321098765432109876543210987654321000000000000000000000 +2345678901234567890.98765432109876543210987654321098765400000000000000000000 +23456789012345678901.65432109876543210987654321098765432100000000000000000000 +23456789012345678901.65432109876543210987654321098765432100000000000000000000 +23456789012345678901.87654321098765432109876543210987654321000000000000000000 +345678901234567890.23456789012345678901234567890123456700000000000000000000 +345678901234567890.34567890123456789012345678901234567800000000000000000000 +345678901234567890.65432109876543210987654321098765432100000000000000000000 +345678901234567890.98765432109876543210987654321098765400000000000000000000 +3456789012345678901.09876543210987654321098765432109876500000000000000000000 +3456789012345678901.23456789012345678901234567890123456700000000000000000000 +3456789012345678901.34567890123456789012345678901234567800000000000000000000 +3456789012345678901.87654321098765432109876543210987654321000000000000000000 +3456789012345678901.98765432109876543210987654321098765400000000000000000000 +34567890123456789012.09876543210987654321098765432109876500000000000000000000 +34567890123456789012.43210987654321098765432109876543210900000000000000000000 +4567890123456789.09876543210987654321098765432109876500000000000000000000 +4567890123456789012.12345678901234567890123456789012345600000000000000000000 +4567890123456789012.12345678901234567890123456789012345600000000000000000000 +4567890123456789012.23456789012345678901234567890123456700000000000000000000 +4567890123456789012.65432109876543210987654321098765432100000000000000000000 +4567890123456789012.98765432109876543210987654321098765400000000000000000000 +45678901234567890123.12345678901234567890123456789012345600000000000000000000 +567890123456789012.34567890123456789012345678901234567800000000000000000000 +567890123456789012.54321098765432109876543210987654321000000000000000000000 +5678901234567890123.23456789012345678901234567890123456700000000000000000000 +5678901234567890123.43210987654321098765432109876543210900000000000000000000 +5678901234567890123.65432109876543210987654321098765432100000000000000000000 +5678901234567890123.87654321098765432109876543210987654321000000000000000000 +5678901234567890123.98765432109876543210987654321098765400000000000000000000 +56789012345678901234.34567890123456789012345678901234567800000000000000000000 +56789012345678901234.43210987654321098765432109876543210900000000000000000000 +56789012345678901234.54321098765432109876543210987654321000000000000000000000 +67890123456789012.34567890123456789012345678901234567800000000000000000000 +67890123456789012.98765432109876543210987654321098765400000000000000000000 +678901234567890123.09876543210987654321098765432109876500000000000000000000 +678901234567890123.43210987654321098765432109876543210900000000000000000000 +678901234567890123.43210987654321098765432109876543210900000000000000000000 +678901234567890123.54321098765432109876543210987654321000000000000000000000 +6789012345678901234.12345678901234567890123456789012345600000000000000000000 +6789012345678901234.98765432109876543210987654321098765400000000000000000000 +67890123456789012345.12345678901234567890123456789012345600000000000000000000 +67890123456789012345.67890123456789012345678901234567890100000000000000000000 +789012345678901234.12345678901234567890123456789012345600000000000000000000 +789012345678901234.34567890123456789012345678901234567800000000000000000000 +789012345678901234.65432109876543210987654321098765432100000000000000000000 +789012345678901234.87654321098765432109876543210987654321000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.54321098765432109876543210987654321000000000000000000000 +7890123456789012345.98765432109876543210987654321098765400000000000000000000 +78901234567890123456.09876543210987654321098765432109876500000000000000000000 +78901234567890123456.23456789012345678901234567890123456700000000000000000000 +78901234567890123456.65432109876543210987654321098765432100000000000000000000 +89012345678901234.34567890123456789012345678901234567800000000000000000000 +89012345678901234.98765432109876543210987654321098765400000000000000000000 +890123456789012345.09876543210987654321098765432109876500000000000000000000 +890123456789012345.12345678901234567890123456789012345600000000000000000000 +890123456789012345.54321098765432109876543210987654321000000000000000000000 +8901234567890123456.23456789012345678901234567890123456700000000000000000000 +8901234567890123456.34567890123456789012345678901234567800000000000000000000 +8901234567890123456.43210987654321098765432109876543210900000000000000000000 +8901234567890123456.87654321098765432109876543210987654321000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +89012345678901234567.09876543210987654321098765432109876500000000000000000000 +89012345678901234567.43210987654321098765432109876543210900000000000000000000 +89012345678901234567.54321098765432109876543210987654321000000000000000000000 +89012345678901234567.87654321098765432109876543210987654321000000000000000000 +9876543210.54321098765432109876543210987654321000000000000000000000 +98765432109876543210.65432109876543210987654321098765432100000000000000000000 + +-- !sql_array_sum_decimal256 -- +1.12345678901234567890123456789012345600000000000000000000 +1234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567.43210987654321098765432109876543210900000000000000000000 +123456789012345678.23456789012345678901234567890123456700000000000000000000 +123456789012345678.23456789012345678901234567890123456700000000000000000000 +1234567890123456789.09876543210987654321098765432109876500000000000000000000 +1234567890123456789.09876543210987654321098765432109876500000000000000000000 +1234567890123456789.34567890123456789012345678901234567800000000000000000000 +1234567890123456789.43210987654321098765432109876543210900000000000000000000 +1234567890123456789.54321098765432109876543210987654321000000000000000000000 +1234567890123456789.54321098765432109876543210987654321000000000000000000000 +1234567890123456789.65432109876543210987654321098765432100000000000000000000 +1234567890123456789.65432109876543210987654321098765432100000000000000000000 +1234567890123456789.98765432109876543210987654321098765400000000000000000000 +1234567890123456789.98765432109876543210987654321098765400000000000000000000 +12345678901234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567890.09876543210987654321098765432109876500000000000000000000 +12345678901234567890.23456789012345678901234567890123456700000000000000000000 +12345678901234567890.54321098765432109876543210987654321000000000000000000000 +12345678901234567890.65432109876543210987654321098765432100000000000000000000 +12345678901234567890.65432109876543210987654321098765432100000000000000000000 +12345678901234567890.98765432109876543210987654321098765400000000000000000000 +2345678901234567.56789012345678901234567890123456789000000000000000000000 +234567890123456789.54321098765432109876543210987654321000000000000000000000 +234567890123456789.65432109876543210987654321098765432100000000000000000000 +2345678901234567890.09876543210987654321098765432109876500000000000000000000 +2345678901234567890.23456789012345678901234567890123456700000000000000000000 +2345678901234567890.54321098765432109876543210987654321000000000000000000000 +2345678901234567890.98765432109876543210987654321098765400000000000000000000 +23456789012345678901.65432109876543210987654321098765432100000000000000000000 +23456789012345678901.65432109876543210987654321098765432100000000000000000000 +23456789012345678901.87654321098765432109876543210987654321000000000000000000 +345678901234567890.23456789012345678901234567890123456700000000000000000000 +345678901234567890.34567890123456789012345678901234567800000000000000000000 +345678901234567890.65432109876543210987654321098765432100000000000000000000 +345678901234567890.98765432109876543210987654321098765400000000000000000000 +3456789012345678901.09876543210987654321098765432109876500000000000000000000 +3456789012345678901.23456789012345678901234567890123456700000000000000000000 +3456789012345678901.34567890123456789012345678901234567800000000000000000000 +3456789012345678901.87654321098765432109876543210987654321000000000000000000 +3456789012345678901.98765432109876543210987654321098765400000000000000000000 +34567890123456789012.09876543210987654321098765432109876500000000000000000000 +34567890123456789012.43210987654321098765432109876543210900000000000000000000 +4567890123456789.09876543210987654321098765432109876500000000000000000000 +4567890123456789012.12345678901234567890123456789012345600000000000000000000 +4567890123456789012.12345678901234567890123456789012345600000000000000000000 +4567890123456789012.23456789012345678901234567890123456700000000000000000000 +4567890123456789012.65432109876543210987654321098765432100000000000000000000 +4567890123456789012.98765432109876543210987654321098765400000000000000000000 +45678901234567890123.12345678901234567890123456789012345600000000000000000000 +567890123456789012.34567890123456789012345678901234567800000000000000000000 +567890123456789012.54321098765432109876543210987654321000000000000000000000 +5678901234567890123.23456789012345678901234567890123456700000000000000000000 +5678901234567890123.43210987654321098765432109876543210900000000000000000000 +5678901234567890123.65432109876543210987654321098765432100000000000000000000 +5678901234567890123.87654321098765432109876543210987654321000000000000000000 +5678901234567890123.98765432109876543210987654321098765400000000000000000000 +56789012345678901234.34567890123456789012345678901234567800000000000000000000 +56789012345678901234.43210987654321098765432109876543210900000000000000000000 +56789012345678901234.54321098765432109876543210987654321000000000000000000000 +67890123456789012.34567890123456789012345678901234567800000000000000000000 +67890123456789012.98765432109876543210987654321098765400000000000000000000 +678901234567890123.09876543210987654321098765432109876500000000000000000000 +678901234567890123.43210987654321098765432109876543210900000000000000000000 +678901234567890123.43210987654321098765432109876543210900000000000000000000 +678901234567890123.54321098765432109876543210987654321000000000000000000000 +6789012345678901234.12345678901234567890123456789012345600000000000000000000 +6789012345678901234.98765432109876543210987654321098765400000000000000000000 +67890123456789012345.12345678901234567890123456789012345600000000000000000000 +67890123456789012345.67890123456789012345678901234567890100000000000000000000 +789012345678901234.12345678901234567890123456789012345600000000000000000000 +789012345678901234.34567890123456789012345678901234567800000000000000000000 +789012345678901234.65432109876543210987654321098765432100000000000000000000 +789012345678901234.87654321098765432109876543210987654321000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.09876543210987654321098765432109876500000000000000000000 +7890123456789012345.54321098765432109876543210987654321000000000000000000000 +7890123456789012345.98765432109876543210987654321098765400000000000000000000 +78901234567890123456.09876543210987654321098765432109876500000000000000000000 +78901234567890123456.23456789012345678901234567890123456700000000000000000000 +78901234567890123456.65432109876543210987654321098765432100000000000000000000 +89012345678901234.34567890123456789012345678901234567800000000000000000000 +89012345678901234.98765432109876543210987654321098765400000000000000000000 +890123456789012345.09876543210987654321098765432109876500000000000000000000 +890123456789012345.12345678901234567890123456789012345600000000000000000000 +890123456789012345.54321098765432109876543210987654321000000000000000000000 +8901234567890123456.23456789012345678901234567890123456700000000000000000000 +8901234567890123456.34567890123456789012345678901234567800000000000000000000 +8901234567890123456.43210987654321098765432109876543210900000000000000000000 +8901234567890123456.87654321098765432109876543210987654321000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +8901234567890123456.98765432109876543210987654321098765400000000000000000000 +89012345678901234567.09876543210987654321098765432109876500000000000000000000 +89012345678901234567.43210987654321098765432109876543210900000000000000000000 +89012345678901234567.54321098765432109876543210987654321000000000000000000000 +89012345678901234567.87654321098765432109876543210987654321000000000000000000 +9876543210.54321098765432109876543210987654321000000000000000000000 +98765432109876543210.65432109876543210987654321098765432100000000000000000000 + -- !sql_array_overlaps_1 -- \N \N diff --git a/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy b/regression-test/suites/datatype_p0/decimalv3/aggregate_decimal256.groovy similarity index 95% rename from regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy rename to regression-test/suites/datatype_p0/decimalv3/aggregate_decimal256.groovy index 88121ebb145..9fe63131d42 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate_decimal256.groovy +++ b/regression-test/suites/datatype_p0/decimalv3/aggregate_decimal256.groovy @@ -150,6 +150,6 @@ suite("aggregate_decimal256") { qt_sql_count_2 """ select k1, count(cast(v1 as decimalv3(39, 6))), count(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2, 3; """ qt_sql_count_3 """ select count(cast(v1 as decimalv3(39, 6))), count(cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg order by 1, 2; """ - // qt_sql_distinct_count_1 """ select k1, count(distinct cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2;""" - // qt_sql_distinct_count_2 """ select k1, count(distinct cast(v1 as decimalv3(39, 6))), count(distinct cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2, 3;""" + qt_sql_distinct_count_1 """ select k1, count(distinct cast(v1 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2;""" + qt_sql_distinct_count_2 """ select k1, count(distinct cast(v1 as decimalv3(39, 6))), count(distinct cast(v2 as decimalv3(39, 6))) from test_aggregate_decimal256_avg group by k1 order by 1, 2, 3;""" } diff --git a/regression-test/suites/datatype_p0/decimalv3/test_decimal256_array.groovy b/regression-test/suites/datatype_p0/decimalv3/test_decimal256_array.groovy new file mode 100644 index 00000000000..5ec838a8938 --- /dev/null +++ b/regression-test/suites/datatype_p0/decimalv3/test_decimal256_array.groovy @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_decimal256_array") { + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + + sql """ + drop table if exists test_decimal256_array_agg; + """ + sql """ + create table test_decimal256_array_agg( + id int, + a array<decimalv3(10,0)>, + b array<decimalv3(76,36)>) properties('replication_num' = '1'); + """ + sql """ + insert into test_decimal256_array_agg values + (1, [1, 2, 3], [1.1, 2.2, 3.3]), + (2, [4, 5, 6], [4.4, 5.5, 6.6]), + (3, [7, 8, 9, null], [7.7, 8.8, 9.9, null]); + """ + + order_qt_decimal256_array_sum """ + select array_sum(a), array_sum(b) from test_decimal256_array_agg + where array_sum(a) = 6; + """ + qt_decimal256_array_sum2 """ + select array_sum(a), array_sum(b) from test_decimal256_array_agg + where array_sum(a) > 0 order by array_sum(a), array_sum(b); + """ + order_qt_decimal256_array_sum3 """ + select array_sum(a), array_sum(b) from test_decimal256_array_agg + where array_sum(b) >= 16.5; + """ + + order_qt_decimal256_array_avg """ + select array_avg(a), array_avg(b) from test_decimal256_array_agg; + """ + order_qt_decimal256_array_avg1 """ + select array_avg(a), array_avg(b) from test_decimal256_array_agg + where array_avg(a) = 2.2; + """ + order_qt_decimal256_array_avg2 """ + select array_avg(a), array_avg(b) from test_decimal256_array_agg + where array_avg(a) >= 3; + """ + order_qt_decimal256_array_avg3 """ + select a, array_avg(a), b, array_avg(b) from test_decimal256_array_agg + where array_avg(b) >= 3; + """ + + order_qt_decimal256_array_product """ + select array_product(a), array_product(b) from test_decimal256_array_agg + where array_product(a) = 6; + """ + order_qt_decimal256_array_product2 """ + select array_product(a), array_product(b) from test_decimal256_array_agg + where array_product(a) >= 6; + """ + order_qt_decimal256_array_product3 """ + select array_product(a), array_product(b) from test_decimal256_array_agg + where array_product(b) >= 100; + """ + + order_qt_decimal256_array_cum_sum """ + select a, array_cum_sum(a), b, array_cum_sum(b) from test_decimal256_array_agg; + """ + order_qt_decimal256_array_cum_sum2 """ + select a, array_cum_sum(a), b, array_cum_sum(b) from test_decimal256_array_agg + where array_contains(array_cum_sum(a), 6); + """ + order_qt_decimal256_array_cum_sum3 """ + select a, array_cum_sum(a), b, array_cum_sum(b) from test_decimal256_array_agg + where array_contains(array_cum_sum(b), 9.9); + """ + + order_qt_decimal256_sum_foreach """ + select sum_foreach(a), sum_foreach(b) from test_decimal256_array_agg; + """ + order_qt_decimal256_sum_foreach2 """ + select * from ( + select sum_foreach(a) suma, sum_foreach(b) sumb from test_decimal256_array_agg + ) tmpa + where array_contains(suma, 12); + """ + order_qt_decimal256_sum_foreach3 """ + select * from ( + select sum_foreach(a) suma, sum_foreach(b) sumb from test_decimal256_array_agg + ) tmpa + where array_contains(sumb, 13.2); + """ + + // column_type not match data_types in agg node, column_type=Nullable(Array(Nullable(Decimal(76, 4)))), data_types=Nullable(Array(Nullable(Decimal(76, 0)))) + // order_qt_decimal256_avg_foreach """ + // select avg_foreach(a) from test_decimal256_array_agg; + // """ + order_qt_decimal256_avg_foreach2 """ + select avg_foreach(b) from test_decimal256_array_agg; + """ + // order_qt_decimal256_avg_foreach3 """ + // select avg_foreach(a), avg_foreach(b) from test_decimal256_array_agg; + // """ +} \ No newline at end of file diff --git a/regression-test/suites/datatype_p0/decimalv3/test_decimal256_multi_distinct.groovy b/regression-test/suites/datatype_p0/decimalv3/test_decimal256_multi_distinct.groovy new file mode 100644 index 00000000000..9097225a094 --- /dev/null +++ b/regression-test/suites/datatype_p0/decimalv3/test_decimal256_multi_distinct.groovy @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_decimal256_multi_distinct") { + sql "set enable_nereids_planner = true;" + sql "set enable_decimal256 = true;" + + sql """ + drop table if exists test_decimal256_multi_distinct; + """ + sql """ + create table test_decimal256_multi_distinct( + id int, + a decimal(10,0), + b decimal(76,36)) properties('replication_num' = '1'); + """ + sql """ + insert into test_decimal256_multi_distinct values + (1, 1, 1.1), + (1, 1, 2.2), + (1, 2, 2.2), + (2, 2, 2.2), + (2, 2, 2.2), + (2, 3, 3.3); + """ + order_qt_decimal256_multi_distinct_sum """ + select id, sum(distinct a) suma, sum(distinct b) sumb from test_decimal256_multi_distinct group by id; + """ + order_qt_decimal256_multi_distinct_sum2 """ + select id, sum(distinct a) suma, sum(distinct b) sumb from test_decimal256_multi_distinct group by id + having suma = 3; + """ + order_qt_decimal256_multi_distinct_sum3 """ + select id, sum(distinct a) suma, sum(distinct b) sumb from test_decimal256_multi_distinct group by id + having sumb = 5.5; + """ + + order_qt_decimal256_multi_distinct_count """ + select id, count(distinct a) counta from test_decimal256_multi_distinct group by id; + """ + order_qt_decimal256_multi_distinct_count2 """ + select id, count(distinct b) countb from test_decimal256_multi_distinct group by id; + """ + order_qt_decimal256_multi_distinct_count3 """ + select id, count(distinct a) counta, count(distinct b) countb from test_decimal256_multi_distinct group by id; + """ + + order_qt_decimal256_multi_distinct_avg """ + select id, avg(distinct a) avga, avg(distinct b) avgb from test_decimal256_multi_distinct group by id; + """ + order_qt_decimal256_multi_distinct_avg2 """ + select id, avg(distinct a) avga, avg(distinct b) avgb from test_decimal256_multi_distinct group by id + having avga = 1.5; + """ + order_qt_decimal256_multi_distinct_avg3 """ + select id, avg(distinct a) avga, avg(distinct b) avgb from test_decimal256_multi_distinct group by id + having avgb = 2.75; + """ +} \ No newline at end of file diff --git a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy index 8a7f08a883a..defa553279c 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy @@ -1354,27 +1354,9 @@ suite("nereids_scalar_fn_Array") { sql """ set enable_decimal256=true; """ order_qt_sql_array_min_decimal256 "select array_min(c) from fn_test_array_with_large_decimal order by id" order_qt_sql_array_max_decimal256 "select array_max(c) from fn_test_array_with_large_decimal order by id" - test { - sql "select array_product(c) from fn_test_array_with_large_decimal order by id" - check{result, exception, startTime, endTime -> - assertTrue(exception != null) - logger.info(exception.message) - } - } - test { - sql "select array_avg(c) from fn_test_array_with_large_decimal order by id" - check{result, exception, startTime, endTime -> - assertTrue(exception != null) - logger.info(exception.message) - } - } - test { - sql "select array_sum(c) from fn_test_array_with_large_decimal order by id" - check{result, exception, startTime, endTime -> - assertTrue(exception != null) - logger.info(exception.message) - } - } + order_qt_sql_array_product_decimal256 "select array_product(c) from fn_test_array_with_large_decimal order by id" + order_qt_sql_array_avg_decimal256 "select array_avg(c) from fn_test_array_with_large_decimal order by id" + order_qt_sql_array_sum_decimal256 "select array_sum(c) from fn_test_array_with_large_decimal order by id" // array_overlap for type correctness order_qt_sql_array_overlaps_1 """select arrays_overlap(a, b) from fn_test_array_with_large_decimal order by id""" order_qt_sql_array_overlaps_2 """select arrays_overlap(b, a) from fn_test_array_with_large_decimal order by id""" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org