This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 65b8dfc7ff [Enchancement](function) Inline some aggregate function && remove nullable combinator (#17328) 65b8dfc7ff is described below commit 65b8dfc7ff8eb2af77d76631d898a41814774aa7 Author: Pxl <pxl...@qq.com> AuthorDate: Thu Mar 9 10:39:04 2023 +0800 [Enchancement](function) Inline some aggregate function && remove nullable combinator (#17328) 1. Inline some aggregate function 2. remove nullable combinator --- .gitignore | 1 + be/src/vec/CMakeLists.txt | 1 - .../vec/aggregate_functions/aggregate_function.h | 4 +- .../aggregate_function_approx_count_distinct.cpp | 18 +- .../aggregate_function_avg_weighted.cpp | 27 +- .../aggregate_function_avg_weighted.h | 8 +- .../aggregate_functions/aggregate_function_bit.cpp | 6 - .../aggregate_function_collect.cpp | 56 ++-- .../aggregate_function_collect.h | 6 +- .../aggregate_function_hll_union_agg.cpp | 34 +- .../aggregate_function_hll_union_agg.h | 33 +- .../aggregate_function_null.cpp | 99 ------ .../aggregate_functions/aggregate_function_null.h | 368 --------------------- .../aggregate_function_percentile_approx.cpp | 34 +- .../aggregate_function_percentile_approx.h | 34 +- .../aggregate_function_reader.cpp | 18 +- .../aggregate_function_retention.cpp | 5 +- .../aggregate_function_sequence_match.cpp | 20 +- .../aggregate_function_simple_factory.cpp | 4 - .../aggregate_function_stddev.cpp | 83 ++--- .../aggregate_function_stddev.h | 1 + .../aggregate_function_window.cpp | 14 +- .../aggregate_function_window_funnel.cpp | 18 +- .../functions/array/function_array_aggregation.cpp | 30 +- be/src/vec/utils/template_helpers.hpp | 21 +- .../vec/aggregate_functions/agg_histogram_test.cpp | 5 +- 26 files changed, 206 insertions(+), 742 deletions(-) diff --git a/.gitignore b/.gitignore index c6229863aa..7b5868f79b 100644 --- a/.gitignore +++ b/.gitignore @@ -96,5 +96,6 @@ tools/single-node-cluster/fe* # be-ut data_test +lru_cache_test /conf/log4j2-spring.xml diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index d9e897e67b..ac94673475 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -31,7 +31,6 @@ set(VEC_FILES aggregate_functions/aggregate_function_sort.cpp aggregate_functions/aggregate_function_min_max.cpp aggregate_functions/aggregate_function_min_max_by.cpp - aggregate_functions/aggregate_function_null.cpp aggregate_functions/aggregate_function_uniq.cpp aggregate_functions/aggregate_function_hll_union_agg.cpp aggregate_functions/aggregate_function_bit.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index d4a906231f..e86415b729 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -317,13 +317,13 @@ public: void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) const override { - VectorBufferWriter writter(static_cast<ColumnString&>(*dst)); + VectorBufferWriter writter(assert_cast<ColumnString&>(*dst)); streaming_agg_serialize(columns, writter, num_rows, arena); } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, MutableColumnPtr& dst) const override { - VectorBufferWriter writter(static_cast<ColumnString&>(*dst)); + VectorBufferWriter writter(assert_cast<ColumnString&>(*dst)); static_cast<const Derived*>(this)->serialize(place, writter); writter.commit(); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp index 0fa50b5194..2c22586d43 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -17,6 +17,7 @@ #include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h" +#include "vec/aggregate_functions/helpers.h" #include "vec/utils/template_helpers.hpp" namespace doris::vectorized { @@ -24,13 +25,14 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_approx_count_distinct( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { AggregateFunctionPtr res = nullptr; - WhichDataType which(argument_types[0]->is_nullable() - ? reinterpret_cast<const DataTypeNullable*>(argument_types[0].get()) - ->get_nested_type() - : argument_types[0]); + WhichDataType which(remove_nullable(argument_types[0])); - res.reset(create_class_with_type<AggregateFunctionApproxCountDistinct>(*argument_types[0], - argument_types)); +#define DISPATCH(TYPE, COLUMN_TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res.reset(creator_without_type::create<AggregateFunctionApproxCountDistinct<COLUMN_TYPE>>( \ + result_is_nullable, argument_types)); + TYPE_TO_COLUMN_TYPE(DISPATCH) +#undef DISPATCH if (!res) { LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate function {}", @@ -41,8 +43,8 @@ AggregateFunctionPtr create_aggregate_function_approx_count_distinct( } void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory) { - factory.register_function("approx_count_distinct", - create_aggregate_function_approx_count_distinct); + factory.register_function_both("approx_count_distinct", + create_aggregate_function_approx_count_distinct); factory.register_alias("approx_count_distinct", "ndv"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp index ea4e058550..c81bf4b42f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp @@ -19,37 +19,18 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_avg_weight(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - auto type = argument_types[0].get(); - if (type->is_nullable()) { - type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get(); - } - - WhichDataType which(*type); - -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return AggregateFunctionPtr(new AggregateFunctionAvgWeight<TYPE>(argument_types)); - FOR_NUMERIC_TYPES(DISPATCH) -#undef DISPATCH - if (which.is_decimal128()) { - return AggregateFunctionPtr(new AggregateFunctionAvgWeight<Decimal128>(argument_types)); - } - if (which.is_decimal()) { - return AggregateFunctionPtr(new AggregateFunctionAvgWeight<Decimal128I>(argument_types)); - } - - LOG(WARNING) << fmt::format("Illegal argument type for aggregate function topn_array is: {}", - type->get_name()); - return nullptr; + return AggregateFunctionPtr(creator_with_type::create<AggregateFunctionAvgWeight>( + result_is_nullable, argument_types)); } void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& factory) { - factory.register_function("avg_weighted", create_aggregate_function_avg_weight); + factory.register_function_both("avg_weighted", create_aggregate_function_avg_weight); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h index aa3b70d4de..cc14f1e1b3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h @@ -31,7 +31,7 @@ struct AggregateFunctionAvgWeightedData { void add(const T& data_val, double weight_val) { if constexpr (IsDecimalV2<T>) { DecimalV2Value value = binary_cast<Int128, DecimalV2Value>(data_val); - data_sum = data_sum + (static_cast<double>(value) * weight_val); + data_sum = data_sum + (double(value) * weight_val); } else { data_sum = data_sum + (data_val * weight_val); } @@ -81,8 +81,8 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColVecType&>(*columns[0]); - const auto& weight = static_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); + const auto& weight = assert_cast<const ColumnVector<Float64>&>(*columns[1]); this->data(place).add(column.get_data()[row_num], weight.get_element(row_num)); } @@ -103,7 +103,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColumnVector<Float64>&>(to); + auto& column = assert_cast<ColumnVector<Float64>&>(to); column.get_data().push_back(this->data(place).get()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index 6b9be5c92c..bdc51daaf9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -29,12 +29,6 @@ template <template <typename> class Data> AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - if (!argument_types[0]->can_be_used_in_bit_operations()) { - LOG(WARNING) << fmt::format("The type " + argument_types[0]->get_name() + - " of argument for aggregate function " + name + - " is illegal, because it cannot be used in bitwise operations"); - } - AggregateFunctionPtr res(creator_with_integer_type::create<AggregateFunctionBitwise, Data>( result_is_nullable, argument_types)); if (res) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp index d18ae860e5..110618581e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp @@ -22,53 +22,55 @@ namespace doris::vectorized { -#define FOR_DECIMAL_TYPES(M) \ - M(Decimal32) \ - M(Decimal64) \ - M(Decimal128) \ - M(Decimal128I) - template <typename T, typename HasLimit> -AggregateFunctionPtr do_create_agg_function_collect(bool distinct, - const DataTypePtr& argument_type) { +AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const DataTypes& argument_types, + const bool result_is_nullable) { if (distinct) { return AggregateFunctionPtr( - new AggregateFunctionCollect<AggregateFunctionCollectSetData<T, HasLimit>, - HasLimit>(argument_type)); + creator_without_type::create<AggregateFunctionCollect< + AggregateFunctionCollectSetData<T, HasLimit>, HasLimit>>(result_is_nullable, + argument_types)); } else { return AggregateFunctionPtr( - new AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>, - HasLimit>(argument_type)); + creator_without_type::create<AggregateFunctionCollect< + AggregateFunctionCollectListData<T, HasLimit>, HasLimit>>( + result_is_nullable, argument_types)); } } template <typename HasLimit> AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& name, - const DataTypePtr& argument_type) { + const DataTypes& argument_types, + const bool result_is_nullable) { bool distinct = false; if (name == "collect_set") { distinct = true; } - WhichDataType which(argument_type); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return do_create_agg_function_collect<TYPE, HasLimit>(distinct, argument_type); + WhichDataType which(remove_nullable(argument_types[0])); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return do_create_agg_function_collect<TYPE, HasLimit>(distinct, argument_types, \ + result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH if (which.is_date_or_datetime()) { - return do_create_agg_function_collect<Int64, HasLimit>(distinct, argument_type); + return do_create_agg_function_collect<Int64, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_date_v2()) { - return do_create_agg_function_collect<UInt32, HasLimit>(distinct, argument_type); + return do_create_agg_function_collect<UInt32, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_date_time_v2()) { - return do_create_agg_function_collect<UInt64, HasLimit>(distinct, argument_type); + return do_create_agg_function_collect<UInt64, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_string()) { - return do_create_agg_function_collect<StringRef, HasLimit>(distinct, argument_type); + return do_create_agg_function_collect<StringRef, HasLimit>(distinct, argument_types, + result_is_nullable); } LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", - argument_type->get_name(), name); + argument_types[0]->get_name(), name); return nullptr; } @@ -76,10 +78,12 @@ AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { if (argument_types.size() == 1) { - return create_aggregate_function_collect_impl<std::false_type>(name, argument_types[0]); + return create_aggregate_function_collect_impl<std::false_type>(name, argument_types, + result_is_nullable); } if (argument_types.size() == 2) { - return create_aggregate_function_collect_impl<std::true_type>(name, argument_types[0]); + return create_aggregate_function_collect_impl<std::true_type>(name, argument_types, + result_is_nullable); } LOG(WARNING) << fmt::format("number of parameters for aggregate function {}, should be 1 or 2", name); @@ -87,8 +91,8 @@ AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, } void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory) { - factory.register_function("collect_list", create_aggregate_function_collect); - factory.register_function("collect_set", create_aggregate_function_collect); + factory.register_function_both("collect_list", create_aggregate_function_collect); + factory.register_function_both("collect_set", create_aggregate_function_collect); factory.register_alias("collect_list", "group_array"); factory.register_alias("collect_set", "group_uniq_array"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h b/be/src/vec/aggregate_functions/aggregate_function_collect.h index 27d9733bd3..dd4e2eca84 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h @@ -249,11 +249,11 @@ class AggregateFunctionCollect static constexpr bool ENABLE_ARENA = std::is_same_v<Data, GenericType>; public: - AggregateFunctionCollect(const DataTypePtr& argument_type, + AggregateFunctionCollect(const DataTypes& argument_types, UInt64 max_size_ = std::numeric_limits<UInt64>::max()) : IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data, HasLimit>>( - {argument_type}), - return_type(argument_type) {} + {argument_types}), + return_type(argument_types[0]) {} std::string get_name() const override { if constexpr (std::is_same_v<AggregateFunctionCollectListData<typename Data::ElementType, diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp index 8ca12d4e14..dc575e6d25 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp @@ -19,36 +19,26 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { -template <bool is_nullable> -AggregateFunctionPtr create_aggregate_function_HLL_union_agg(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +template <template <typename> class Impl> +AggregateFunctionPtr create_aggregate_function_HLL(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { assert_arity_at_most<1>(name, argument_types); - - return std::make_shared<AggregateFunctionHLLUnion< - AggregateFunctionHLLUnionAggImpl<AggregateFunctionHLLData<is_nullable>>>>( - argument_types); -} - -template <bool is_nullable> -AggregateFunctionPtr create_aggregate_function_HLL_union(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { - assert_arity_at_most<1>(name, argument_types); - - return std::make_shared<AggregateFunctionHLLUnion< - AggregateFunctionHLLUnionImpl<AggregateFunctionHLLData<is_nullable>>>>(argument_types); + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunctionHLLUnion<Impl<AggregateFunctionHLLData>>>( + result_is_nullable, argument_types)); } void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& factory) { - factory.register_function("hll_union_agg", create_aggregate_function_HLL_union_agg<false>); - factory.register_function("hll_union_agg", create_aggregate_function_HLL_union_agg<true>, true); + factory.register_function_both("hll_union_agg", + create_aggregate_function_HLL<AggregateFunctionHLLUnionAggImpl>); - factory.register_function("hll_union", create_aggregate_function_HLL_union<false>); - factory.register_function("hll_union", create_aggregate_function_HLL_union<true>, true); + factory.register_function_both("hll_union", + create_aggregate_function_HLL<AggregateFunctionHLLUnionImpl>); factory.register_alias("hll_union", "hll_raw_agg"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h index 9052d78027..3e8ee3cb29 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h @@ -17,8 +17,9 @@ #pragma once +#include <type_traits> + #include "olap/hll.h" -#include "util/slice.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column_string.h" #include "vec/columns/column_vector.h" @@ -29,7 +30,6 @@ namespace doris::vectorized { -template <bool is_nullable> struct AggregateFunctionHLLData { HyperLogLog dst_hll {}; @@ -55,26 +55,16 @@ struct AggregateFunctionHLLData { void reset() { dst_hll.clear(); } void add(const IColumn* column, size_t row_num) { - if constexpr (is_nullable) { - auto* nullable_column = check_and_get_column<const ColumnNullable>(*column); - if (nullable_column->is_null_at(row_num)) { - return; - } - const auto& sources = - static_cast<const ColumnHLL&>((nullable_column->get_nested_column())); - dst_hll.merge(sources.get_element(row_num)); - - } else { - const auto& sources = static_cast<const ColumnHLL&>(*column); - dst_hll.merge(sources.get_element(row_num)); - } + const auto& sources = static_cast<const ColumnHLL&>(*column); + dst_hll.merge(sources.get_element(row_num)); } }; template <typename Data> struct AggregateFunctionHLLUnionImpl : Data { void insert_result_into(IColumn& to) const { - assert_cast<ColumnHLL&>(to).get_data().emplace_back(this->get()); + ColumnHLL& column = assert_cast<ColumnHLL&>(to); + column.get_data().emplace_back(this->get()); } static DataTypePtr get_return_type() { return std::make_shared<DataTypeHLL>(); } @@ -85,7 +75,8 @@ struct AggregateFunctionHLLUnionImpl : Data { template <typename Data> struct AggregateFunctionHLLUnionAggImpl : Data { void insert_result_into(IColumn& to) const { - assert_cast<ColumnInt64&>(to).get_data().emplace_back(this->get_cardinality()); + ColumnInt64& column = assert_cast<ColumnInt64&>(to); + column.get_data().emplace_back(this->get_cardinality()); } static DataTypePtr get_return_type() { return std::make_shared<DataTypeInt64>(); } @@ -130,9 +121,9 @@ public: void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } }; -template <bool is_nullable = false> -AggregateFunctionPtr create_aggregate_function_HLL_union(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable); +template <template <typename> class Impl> +AggregateFunctionPtr create_aggregate_function_HLL(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable); } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.cpp b/be/src/vec/aggregate_functions/aggregate_function_null.cpp deleted file mode 100644 index 6795adecf9..0000000000 --- a/be/src/vec/aggregate_functions/aggregate_function_null.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -// This file is copied from -// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionNull.cpp -// and modified by Doris - -#include "vec/aggregate_functions/aggregate_function_null.h" - -#include "common/logging.h" -#include "vec/aggregate_functions/aggregate_function.h" -#include "vec/aggregate_functions/aggregate_function_combinator.h" -#include "vec/aggregate_functions/aggregate_function_count.h" -#include "vec/aggregate_functions/aggregate_function_nothing.h" -#include "vec/aggregate_functions/aggregate_function_simple_factory.h" -#include "vec/data_types/data_type_nullable.h" - -namespace doris::vectorized { - -class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinator { -public: - String get_name() const override { return "Null"; } - - bool is_for_internal_usage_only() const override { return true; } - - DataTypes transform_arguments(const DataTypes& arguments) const override { - size_t size = arguments.size(); - DataTypes res(size); - for (size_t i = 0; i < size; ++i) { - res[i] = remove_nullable(arguments[i]); - } - return res; - } - - AggregateFunctionPtr transform_aggregate_function( - const AggregateFunctionPtr& nested_function, const DataTypes& arguments, - const bool result_is_nullable) const override { - if (nested_function == nullptr) { - return nullptr; - } - - bool has_null_types = false; - for (const auto& arg_type : arguments) { - if (arg_type->only_null()) { - has_null_types = true; - break; - } - } - - if (has_null_types) { - return std::make_shared<AggregateFunctionNothing>(arguments); - } - - if (arguments.size() == 1) { - if (result_is_nullable) { - return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function, - arguments); - } else { - return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function, - arguments); - } - } else { - if (result_is_nullable) { - return std::make_shared<AggregateFunctionNullVariadic<true>>(nested_function, - arguments); - } else { - return std::make_shared<AggregateFunctionNullVariadic<false>>(nested_function, - arguments); - } - } - } -}; - -void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory) { - AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const bool result_is_nullable) { - auto function_combinator = std::make_shared<AggregateFunctionCombinatorNull>(); - auto transform_arguments = function_combinator->transform_arguments(types); - auto nested_function = factory.get(name, transform_arguments); - return function_combinator->transform_aggregate_function(nested_function, types, - result_is_nullable); - }; - factory.register_nullable_function_combinator(creator); -} - -} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 0e56b1f306..180a334d19 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -32,374 +32,6 @@ namespace doris::vectorized { -/// This class implements a wrapper around an aggregate function. Despite its name, -/// this is an adapter. It is used to handle aggregate functions that are called with -/// at least one nullable argument. It implements the logic according to which any -/// row that contains at least one NULL is skipped. - -/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter. -/// true - return NULL; false - return value from empty aggregation state of nested function. - -// TODO: only keep class xxxInline after we support all aggregate function -template <bool result_is_nullable, typename Derived> -class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> { -protected: - AggregateFunctionPtr nested_function; - size_t prefix_size; - - /** In addition to data for nested aggregate function, we keep a flag - * indicating - was there at least one non-NULL value accumulated. - * In case of no not-NULL values, the function will return NULL. - * - * We use prefix_size bytes for flag to satisfy the alignment requirement of nested state. - */ - - AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const noexcept { - return place + prefix_size; - } - - ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) const noexcept { - return place + prefix_size; - } - - static void init_flag(AggregateDataPtr __restrict place) noexcept { - if constexpr (result_is_nullable) { - place[0] = false; - } - } - - static void set_flag(AggregateDataPtr __restrict place) noexcept { - if constexpr (result_is_nullable) { - place[0] = true; - } - } - - static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept { - return result_is_nullable ? place[0] : true; - } - -public: - AggregateFunctionNullBase(AggregateFunctionPtr nested_function_, const DataTypes& arguments) - : IAggregateFunctionHelper<Derived>(arguments), nested_function {nested_function_} { - if (result_is_nullable) { - prefix_size = nested_function->align_of_data(); - } else { - prefix_size = 0; - } - } - - String get_name() const override { - /// This is just a wrapper. The function for Nullable arguments is named the same as the nested function itself. - return nested_function->get_name(); - } - - DataTypePtr get_return_type() const override { - return result_is_nullable ? make_nullable(nested_function->get_return_type()) - : nested_function->get_return_type(); - } - - void create(AggregateDataPtr __restrict place) const override { - init_flag(place); - nested_function->create(nested_place(place)); - } - - void destroy(AggregateDataPtr __restrict place) const noexcept override { - nested_function->destroy(nested_place(place)); - } - void reset(AggregateDataPtr place) const override { - init_flag(place); - nested_function->reset(nested_place(place)); - } - - bool has_trivial_destructor() const override { - return nested_function->has_trivial_destructor(); - } - - size_t size_of_data() const override { return prefix_size + nested_function->size_of_data(); } - - size_t align_of_data() const override { return nested_function->align_of_data(); } - - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, - Arena* arena) const override { - if (result_is_nullable && get_flag(rhs)) { - set_flag(place); - } - - nested_function->merge(nested_place(place), nested_place(rhs), arena); - } - - void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { - bool flag = get_flag(place); - if (result_is_nullable) { - write_binary(flag, buf); - } - if (flag) { - nested_function->serialize(nested_place(place), buf); - } - } - - void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, - Arena* arena) const override { - bool flag = true; - if (result_is_nullable) { - read_binary(flag, buf); - } - if (flag) { - set_flag(place); - nested_function->deserialize(nested_place(place), buf, arena); - } - } - - void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf, - Arena* arena) const override { - bool flag = true; - if (result_is_nullable) { - read_binary(flag, buf); - } - if (flag) { - set_flag(place); - nested_function->deserialize_and_merge(nested_place(place), buf, arena); - } - } - - void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, - Arena* arena) const override { - size_t num_rows = column.size(); - for (size_t i = 0; i != num_rows; ++i) { - VectorBufferReader buffer_reader( - (assert_cast<const ColumnString&>(column)).get_data_at(i)); - deserialize_and_merge(place, buffer_reader, arena); - } - } - - void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - if constexpr (result_is_nullable) { - ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to); - if (get_flag(place)) { - nested_function->insert_result_into(nested_place(place), - to_concrete.get_nested_column()); - to_concrete.get_null_map_data().push_back(0); - } else { - to_concrete.insert_default(); - } - } else { - nested_function->insert_result_into(nested_place(place), to); - } - } - - bool allocates_memory_in_arena() const override { - return nested_function->allocates_memory_in_arena(); - } - - bool is_state() const override { return nested_function->is_state(); } -}; - -/** There are two cases: for single argument and variadic. - * Code for single argument is much more efficient. - */ -template <bool result_is_nullable> -class AggregateFunctionNullUnary final - : public AggregateFunctionNullBase<result_is_nullable, - AggregateFunctionNullUnary<result_is_nullable>> { -public: - AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes& arguments) - : AggregateFunctionNullBase<result_is_nullable, - AggregateFunctionNullUnary<result_is_nullable>>( - std::move(nested_function_), arguments) {} - - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, - Arena* arena) const override { - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - if (!column->is_null_at(row_num)) { - this->set_flag(place); - const IColumn* nested_column = &column->get_nested_column(); - this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena); - } - } - - void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns, - size_t row_num, Arena* arena) const { - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - this->set_flag(place); - const IColumn* nested_column = &column->get_nested_column(); - this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena); - } - - void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, - const IColumn** columns, Arena* arena, bool agg_many) const override { - int processed_records_num = 0; - - // we can use column->has_null() to judge whether whole batch of data is null and skip batch, - // but it's maybe too coarse-grained. -#ifdef __AVX2__ - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - // The overhead introduced is negligible here, just an extra memory read from NullMap - const NullMap& null_map_data = column->get_null_map_data(); - - // NullMap use uint8_t type to indicate values is null or not, 1 indicates null, 0 versus. - // It's important to keep consistent with element type size in NullMap - constexpr int simd_batch_size = 256 / (8 * sizeof(uint8_t)); - __m256i all0 = _mm256_setzero_si256(); - auto null_map_start_position = reinterpret_cast<const int8_t*>(null_map_data.data()); - - while (processed_records_num + simd_batch_size <= batch_size) { - // load unaligned data from null_map, 1 means value is null, 0 versus - __m256i f = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( - null_map_start_position + processed_records_num)); - int mask = _mm256_movemask_epi8(_mm256_cmpgt_epi8(f, all0)); - // all data is null - if (mask == 0xffffffff) { - } else if (mask == 0) { // all data is not null - for (size_t i = processed_records_num; i < processed_records_num + simd_batch_size; - i++) { - AggregateFunctionNullUnary::add_not_nullable(places[i] + place_offset, columns, - i, arena); - } - } else { - // data is partly null - for (size_t i = processed_records_num; i < processed_records_num + simd_batch_size; - i++) { - add(places[i] + place_offset, columns, i, arena); - } - } - processed_records_num += simd_batch_size; - } - -#elif __SSE2__ - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - // The overhead introduced is negligible here, just an extra memory read from NullMap - const NullMap& null_map_data = column->get_null_map_data(); - // NullMap use uint8_t type to indicate values is null or not, 1 indicates null, 0 versus. - // It's important to keep consistent with element type size in NullMap - constexpr int simd_batch_size = 128 / (8 * sizeof(uint8_t)); - __m128i all0 = _mm_setzero_si128(); - auto null_map_start_position = reinterpret_cast<const int8_t*>(null_map_data.data()); - while (processed_records_num + simd_batch_size <= batch_size) { - // load unaligned data from null_map, 1 means value is null, 0 versus - __m128i f = _mm_loadu_si128(reinterpret_cast<const __m128i*>(null_map_start_position + - processed_records_num)); - int mask = _mm_movemask_epi8(_mm_cmpgt_epi8(f, all0)); - // all data is null - if (mask == 0xffff) { - } else if (mask == 0) { // all data is not null - for (size_t i = processed_records_num; i < processed_records_num + simd_batch_size; - i++) { - add_not_nullable(places[i] + place_offset, columns, i, arena); - } - } else { - // data is partly null - for (size_t i = processed_records_num; i < processed_records_num + simd_batch_size; - i++) { - add(places[i] + place_offset, columns, i, arena); - } - } - processed_records_num += simd_batch_size; - } -#endif - - for (; processed_records_num < batch_size; ++processed_records_num) { - add(places[processed_records_num] + place_offset, columns, processed_records_num, - arena); - } - } - - void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, - Arena* arena) const override { - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - bool has_null = column->has_null(); - - if (has_null) { - for (size_t i = 0; i < batch_size; ++i) { - this->add(place, columns, i, arena); - } - } else { - this->set_flag(place); - const IColumn* nested_column = &column->get_nested_column(); - this->nested_function->add_batch_single_place(batch_size, this->nested_place(place), - &nested_column, arena); - } - } - - void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place, - const IColumn** columns, Arena* arena, bool has_null) override { - const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]); - - if (has_null) { - for (size_t i = batch_begin; i <= batch_end; ++i) { - this->add(place, columns, i, arena); - } - } else { - this->set_flag(place); - const IColumn* nested_column = &column->get_nested_column(); - this->nested_function->add_batch_range( - batch_begin, batch_end, this->nested_place(place), &nested_column, arena); - } - } -}; - -template <bool result_is_nullable> -class AggregateFunctionNullVariadic final - : public AggregateFunctionNullBase<result_is_nullable, - AggregateFunctionNullVariadic<result_is_nullable>> { -public: - AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes& arguments) - : AggregateFunctionNullBase<result_is_nullable, - AggregateFunctionNullVariadic<result_is_nullable>>( - std::move(nested_function_), arguments), - number_of_arguments(arguments.size()) { - if (number_of_arguments == 1) { - LOG(FATAL) - << "Logical error: single argument is passed to AggregateFunctionNullVariadic"; - } - - if (number_of_arguments > MAX_ARGS) { - LOG(FATAL) << fmt::format( - "Maximum number of arguments for aggregate function with Nullable types is {}", - size_t(MAX_ARGS)); - } - - for (size_t i = 0; i < number_of_arguments; ++i) { - is_nullable[i] = arguments[i]->is_nullable(); - } - } - - void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, - Arena* arena) const override { - /// This container stores the columns we really pass to the nested function. - const IColumn* nested_columns[number_of_arguments]; - - for (size_t i = 0; i < number_of_arguments; ++i) { - if (is_nullable[i]) { - const ColumnNullable& nullable_col = - assert_cast<const ColumnNullable&>(*columns[i]); - if (nullable_col.is_null_at(row_num)) { - /// If at least one column has a null value in the current row, - /// we don't process this row. - return; - } - nested_columns[i] = &nullable_col.get_nested_column(); - } else { - nested_columns[i] = columns[i]; - } - } - - this->set_flag(place); - this->nested_function->add(this->nested_place(place), nested_columns, row_num, arena); - } - - bool allocates_memory_in_arena() const override { - return this->nested_function->allocates_memory_in_arena(); - } - -private: - // The array length is fixed in the implementation of some aggregate functions. - // Therefore we choose 256 as the appropriate maximum length limit. - static const size_t MAX_ARGS = 256; - size_t number_of_arguments = 0; - std::array<char, MAX_ARGS> - is_nullable; /// Plain array is better than std::vector due to one indirection less. -}; - template <typename NestFunction, bool result_is_nullable, typename Derived> class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> { protected: diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp index 9548663b57..3dfe11388b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp @@ -19,6 +19,7 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { @@ -27,14 +28,17 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri const DataTypes& argument_types, const bool result_is_nullable) { if (argument_types.size() == 1) { - return std::make_shared<AggregateFunctionPercentileApproxMerge<is_nullable>>( - argument_types); + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunctionPercentileApproxMerge<is_nullable>>( + result_is_nullable, remove_nullable(argument_types))); } else if (argument_types.size() == 2) { - return std::make_shared<AggregateFunctionPercentileApproxTwoParams<is_nullable>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionPercentileApproxTwoParams<is_nullable>>( + result_is_nullable, remove_nullable(argument_types))); } else if (argument_types.size() == 3) { - return std::make_shared<AggregateFunctionPercentileApproxThreeParams<is_nullable>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionPercentileApproxThreeParams<is_nullable>>( + result_is_nullable, remove_nullable(argument_types))); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", argument_types.size(), name); @@ -44,24 +48,26 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri AggregateFunctionPtr create_aggregate_function_percentile(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionPercentile>(argument_types); + return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentile>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_percentile_array(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionPercentileArray>(argument_types); + return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentileArray>( + result_is_nullable, argument_types)); } void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { - factory.register_function("percentile", create_aggregate_function_percentile); - factory.register_function("percentile_array", create_aggregate_function_percentile_array); + factory.register_function_both("percentile", create_aggregate_function_percentile); + factory.register_function_both("percentile_array", create_aggregate_function_percentile_array); } void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { - factory.register_function("percentile_approx", - create_aggregate_function_percentile_approx<false>, false); - factory.register_function("percentile_approx", - create_aggregate_function_percentile_approx<true>, true); + factory.register_function_both("percentile_approx", + create_aggregate_function_percentile_approx<false>); + factory.register_function_both("percentile_approx", + create_aggregate_function_percentile_approx<true>); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h index c07cd3e2f3..dc3e232850 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h @@ -161,7 +161,7 @@ public: if (std::isnan(result)) { nullable_column.insert_default(); } else { - auto& col = static_cast<ColumnVector<Float64>&>(nullable_column.get_nested_column()); + auto& col = assert_cast<ColumnVector<Float64>&>(nullable_column.get_nested_column()); col.get_data().push_back(result); nullable_column.get_null_map_data().push_back(0); } @@ -193,11 +193,11 @@ public: for (int i = 0; i < 2; ++i) { const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); if (nullable_column == nullptr) { //Not Nullable column - const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]); + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); column_data[i] = column.get_float64(row_num); } else if (!nullable_column->is_null_at( row_num)) { // Nullable column && Not null data - const auto& column = static_cast<const ColumnVector<Float64>&>( + const auto& column = assert_cast<const ColumnVector<Float64>&>( nullable_column->get_nested_column()); column_data[i] = column.get_float64(row_num); } else { // Nullable column && null data @@ -211,8 +211,8 @@ public: this->data(place).add(column_data[0], column_data[1]); } else { - const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); - const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); this->data(place).init(); this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); @@ -233,11 +233,11 @@ public: for (int i = 0; i < 3; ++i) { const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); if (nullable_column == nullptr) { //Not Nullable column - const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]); + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); column_data[i] = column.get_float64(row_num); } else if (!nullable_column->is_null_at( row_num)) { // Nullable column && Not null data - const auto& column = static_cast<const ColumnVector<Float64>&>( + const auto& column = assert_cast<const ColumnVector<Float64>&>( nullable_column->get_nested_column()); column_data[i] = column.get_float64(row_num); } else { // Nullable column && null data @@ -251,9 +251,9 @@ public: this->data(place).add(column_data[0], column_data[1]); } else { - const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); - const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); - const auto& compression = static_cast<const ColumnVector<Float64>&>(*columns[2]); + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& compression = assert_cast<const ColumnVector<Float64>&>(*columns[2]); this->data(place).init(compression.get_float64(row_num)); this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); @@ -342,7 +342,7 @@ struct PercentileState { double get() const { return vec_counts[0].terminate(vec_quantile[0]); } void insert_result_into(IColumn& to) const { - auto& column_data = static_cast<ColumnVector<Float64>&>(to).get_data(); + auto& column_data = assert_cast<ColumnVector<Float64>&>(to).get_data(); for (int i = 0; i < vec_counts.size(); ++i) { column_data.push_back(vec_counts[i].terminate(vec_quantile[i])); } @@ -362,8 +362,8 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& sources = static_cast<const ColumnVector<Int64>&>(*columns[0]); - const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(), 1); } @@ -407,12 +407,12 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& sources = static_cast<const ColumnVector<Int64>&>(*columns[0]); - const auto& quantile_array = static_cast<const ColumnArray&>(*columns[1]); + const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); const auto& nested_column = - static_cast<const ColumnNullable&>(quantile_array.get_data()).get_nested_column(); - const auto& nested_column_data = static_cast<const ColumnVector<Float64>&>(nested_column); + assert_cast<const ColumnNullable&>(quantile_array.get_data()).get_nested_column(); + const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column); AggregateFunctionPercentileArray::data(place).add( sources.get_int(row_num), nested_column_data.get_data(), diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp index 8ff5159615..0d4231e8e7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp @@ -22,16 +22,18 @@ namespace doris::vectorized { // auto spread at nullable condition, null value do not participate aggregate void register_aggregate_function_reader_load(AggregateFunctionSimpleFactory& factory) { // add a suffix to the function name here to distinguish special functions of agg reader - auto register_function = [&](const std::string& name, const AggregateFunctionCreator& creator) { - factory.register_function(name + AGG_READER_SUFFIX, creator, false); - factory.register_function(name + AGG_LOAD_SUFFIX, creator, false); + auto register_function_both = [&](const std::string& name, + const AggregateFunctionCreator& creator) { + factory.register_function_both(name + AGG_READER_SUFFIX, creator); + factory.register_function_both(name + AGG_LOAD_SUFFIX, creator); }; - register_function("sum", create_aggregate_function_sum_reader); - register_function("max", create_aggregate_function_max); - register_function("min", create_aggregate_function_min); - register_function("bitmap_union", create_aggregate_function_bitmap_union); - register_function("hll_union", create_aggregate_function_HLL_union<false>); + register_function_both("sum", create_aggregate_function_sum_reader); + register_function_both("max", create_aggregate_function_max); + register_function_both("min", create_aggregate_function_min); + register_function_both("bitmap_union", create_aggregate_function_bitmap_union); + register_function_both("hll_union", + create_aggregate_function_HLL<AggregateFunctionHLLUnionImpl>); } // only replace function in load/reader do different agg operation. diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.cpp b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp index 44c2baaaaf..c57c1d075c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_retention.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp @@ -26,10 +26,11 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_retention(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionRetention>(argument_types); + return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionRetention>( + result_is_nullable, argument_types)); } void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory) { - factory.register_function("retention", create_aggregate_function_retention, false); + factory.register_function_both("retention", create_aggregate_function_retention); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp index 13ed539f87..ce8db857b3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp @@ -41,24 +41,28 @@ AggregateFunctionPtr create_aggregate_function_sequence_base(const std::string& } if (WhichDataType(remove_nullable(argument_types[1])).is_date_time_v2()) { - return std::make_shared<AggregateFunction<DateV2Value<DateTimeV2ValueType>, UInt64>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunction<DateV2Value<DateTimeV2ValueType>, UInt64>>( + result_is_nullable, argument_types)); } else if (WhichDataType(remove_nullable(argument_types[1])).is_date_time()) { - return std::make_shared<AggregateFunction<VecDateTimeValue, Int64>>(argument_types); + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunction<VecDateTimeValue, Int64>>( + result_is_nullable, argument_types)); } else if (WhichDataType(remove_nullable(argument_types[1])).is_date_v2()) { - return std::make_shared<AggregateFunction<DateV2Value<DateV2ValueType>, UInt32>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunction<DateV2Value<DateV2ValueType>, UInt32>>( + result_is_nullable, argument_types)); } else { - LOG(FATAL) << "Only support Date and DateTime type as timestamp argument!"; + LOG(WARNING) << "Only support Date and DateTime type as timestamp argument!"; return nullptr; } } void register_aggregate_function_sequence_match(AggregateFunctionSimpleFactory& factory) { - factory.register_function( + factory.register_function_both( "sequence_match", create_aggregate_function_sequence_base<AggregateFunctionSequenceMatch>); - factory.register_function( + factory.register_function_both( "sequence_count", create_aggregate_function_sequence_base<AggregateFunctionSequenceCount>); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 400c34ae4d..6a29f581a4 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -28,7 +28,6 @@ class AggregateFunctionSimpleFactory; void register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory); @@ -87,9 +86,6 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_avg_weighted(instance); register_aggregate_function_histogram(instance); - // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line - register_aggregate_function_combinator_null(instance); - register_aggregate_function_stddev_variance_samp(instance); register_aggregate_function_replace_reader_load(instance); register_aggregate_function_window_lead_lag_first_last(instance); diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp index 8373f12383..d549ad2ed0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp @@ -27,43 +27,34 @@ template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class NameData, template <typename, typename> class Data, bool is_stddev, bool is_nullable = false> static IAggregateFunction* create_function_single_value(const String& name, - const DataTypes& argument_types) { - auto type = argument_types[0].get(); - if (type->is_nullable()) { - type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get(); - } - - WhichDataType which(*type); - -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return new AggregateFunctionTemplate<NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, \ - is_nullable>(argument_types); - + const DataTypes& argument_types, + const bool result_is_nullable, + bool custom_nullable) { + IAggregateFunction* res = nullptr; + WhichDataType which(remove_nullable(argument_types[0])); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res = creator_without_type::create<AggregateFunctionTemplate< \ + NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, is_nullable>>( \ + result_is_nullable, \ + custom_nullable ? remove_nullable(argument_types) : argument_types); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH - if (which.is_decimal32()) { - return new AggregateFunctionTemplate< - NameData<Data<Decimal32, BaseDatadecimal<Decimal32, is_stddev>>>, is_nullable>( - argument_types); - } - if (which.is_decimal64()) { - return new AggregateFunctionTemplate< - NameData<Data<Decimal64, BaseDatadecimal<Decimal64, is_stddev>>>, is_nullable>( - argument_types); - } - if (which.is_decimal128()) { - return new AggregateFunctionTemplate< - NameData<Data<Decimal128, BaseDatadecimal<Decimal128, is_stddev>>>, is_nullable>( - argument_types); - } - if (which.is_decimal128i()) { - return new AggregateFunctionTemplate< - NameData<Data<Decimal128I, BaseDatadecimal<Decimal128I, is_stddev>>>, is_nullable>( - argument_types); + +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res = creator_without_type::create<AggregateFunctionTemplate< \ + NameData<Data<TYPE, BaseDatadecimal<TYPE, is_stddev>>>, is_nullable>>( \ + result_is_nullable, \ + custom_nullable ? remove_nullable(argument_types) : argument_types); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH + + if (res == nullptr) { + LOG(WARNING) << fmt::format("create_function_single_value with unknowed type {}", + argument_types[0]->get_name()); } - DCHECK(false) << "with unknowed type, failed in create_aggregate_function_stddev_variance"; - return nullptr; + return res; } template <bool is_stddev, bool is_nullable> @@ -72,16 +63,17 @@ AggregateFunctionPtr create_aggregate_function_variance_samp(const std::string& const bool result_is_nullable) { return AggregateFunctionPtr( create_function_single_value<AggregateFunctionSamp, VarianceSampName, SampData, - is_stddev, is_nullable>(name, argument_types)); + is_stddev, is_nullable>(name, argument_types, + result_is_nullable, true)); } template <bool is_stddev, bool is_nullable> AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return AggregateFunctionPtr( - create_function_single_value<AggregateFunctionSamp, StddevSampName, SampData, is_stddev, - is_nullable>(name, argument_types)); + return AggregateFunctionPtr(create_function_single_value<AggregateFunctionSamp, StddevSampName, + SampData, is_stddev, is_nullable>( + name, argument_types, result_is_nullable, true)); } template <bool is_stddev> @@ -90,7 +82,7 @@ AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string& n const bool result_is_nullable) { return AggregateFunctionPtr( create_function_single_value<AggregateFunctionPop, VarianceName, PopData, is_stddev>( - name, argument_types)); + name, argument_types, result_is_nullable, false)); } template <bool is_stddev> @@ -99,27 +91,24 @@ AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& nam const bool result_is_nullable) { return AggregateFunctionPtr( create_function_single_value<AggregateFunctionPop, StddevName, PopData, is_stddev>( - name, argument_types)); + name, argument_types, result_is_nullable, false)); } void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory) { - factory.register_function("variance", create_aggregate_function_variance_pop<false>); + factory.register_function_both("variance", create_aggregate_function_variance_pop<false>); factory.register_alias("variance", "var_pop"); factory.register_alias("variance", "variance_pop"); - factory.register_function("stddev", create_aggregate_function_stddev_pop<true>); + factory.register_function_both("stddev", create_aggregate_function_stddev_pop<true>); factory.register_alias("stddev", "stddev_pop"); } void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory) { - // _samp<bool, bool>: first indicate is stddev or variance function - // second indicate is arg nullable column factory.register_function("variance_samp", - create_aggregate_function_variance_samp<false, false>, false); + create_aggregate_function_variance_samp<false, false>); factory.register_function("variance_samp", create_aggregate_function_variance_samp<false, true>, true); factory.register_alias("variance_samp", "var_samp"); - factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, false>, - false); + factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, false>); factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, true>, true); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index 63ce35b9c1..625f41888b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -20,6 +20,7 @@ #include "common/status.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" #include "vec/data_types/data_type_decimal.h" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp index a36b9601c2..97c8d18565 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp @@ -22,6 +22,7 @@ #include "common/logging.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" #include "vec/utils/template_helpers.hpp" namespace doris::vectorized { @@ -29,27 +30,30 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_dense_rank(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<WindowFunctionDenseRank>(argument_types); + return AggregateFunctionPtr(creator_without_type::create<WindowFunctionDenseRank>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_rank(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<WindowFunctionRank>(argument_types); + return AggregateFunctionPtr( + creator_without_type::create<WindowFunctionRank>(result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_row_number(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<WindowFunctionRowNumber>(argument_types); + return AggregateFunctionPtr(creator_without_type::create<WindowFunctionRowNumber>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_ntile(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { assert_unary(name, argument_types); - - return std::make_shared<WindowFunctionNTile>(argument_types); + return AggregateFunctionPtr( + creator_without_type::create<WindowFunctionNTile>(result_is_nullable, argument_types)); } template <template <typename> class AggregateFunctionTemplate, diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp index d778ee9c13..7617d30656 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp @@ -18,6 +18,7 @@ #include "vec/aggregate_functions/aggregate_function_window_funnel.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" #include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { @@ -30,18 +31,21 @@ AggregateFunctionPtr create_aggregate_function_window_funnel(const std::string& return nullptr; } if (WhichDataType(remove_nullable(argument_types[2])).is_date_time_v2()) { - return std::make_shared< - AggregateFunctionWindowFunnel<DateV2Value<DateTimeV2ValueType>, UInt64>>( - argument_types); + return AggregateFunctionPtr( + creator_without_type::create< + AggregateFunctionWindowFunnel<DateV2Value<DateTimeV2ValueType>, UInt64>>( + result_is_nullable, argument_types)); } else if (WhichDataType(remove_nullable(argument_types[2])).is_date_time()) { - return std::make_shared<AggregateFunctionWindowFunnel<VecDateTimeValue, Int64>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionWindowFunnel<VecDateTimeValue, Int64>>( + result_is_nullable, argument_types)); } else { - LOG(FATAL) << "Only support DateTime type as window argument!"; + LOG(WARNING) << "Only support DateTime type as window argument!"; + return nullptr; } } void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory) { - factory.register_function("window_funnel", create_aggregate_function_window_funnel, false); + factory.register_function_both("window_funnel", create_aggregate_function_window_funnel); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp b/be/src/vec/functions/array/function_array_aggregation.cpp index 2f2cc63816..bb468c3241 100644 --- a/be/src/vec/functions/array/function_array_aggregation.cpp +++ b/be/src/vec/functions/array/function_array_aggregation.cpp @@ -117,14 +117,8 @@ struct AggregateFunction { using Function = typename Derived::template TypeTraits<T>::Function; static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { - DataTypes data_types = {remove_nullable(data_type_ptr)}; - AggregateFunctionPtr nested_function; - nested_function.reset(creator_with_type::create<Function>(false, data_types)); - - AggregateFunctionPtr function; - function.reset(new AggregateFunctionNullUnary<true>(nested_function, - {make_nullable(data_type_ptr)})); - return function; + return AggregateFunctionPtr(creator_with_type::create<Function>( + true, DataTypes {make_nullable(data_type_ptr)})); } }; @@ -229,14 +223,8 @@ struct NameArrayMin { template <> struct AggregateFunction<AggregateFunctionImpl<AggregateOperation::MIN>> { static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { - DataTypes data_types = {remove_nullable(data_type_ptr)}; - auto nested_function = AggregateFunctionPtr( - create_aggregate_function_min(NameArrayMin::name, data_types, false)); - - AggregateFunctionPtr function; - function.reset(new AggregateFunctionNullUnary<true>(nested_function, - {make_nullable(data_type_ptr)})); - return function; + return AggregateFunctionPtr(create_aggregate_function_min( + NameArrayMin::name, {make_nullable(data_type_ptr)}, true)); } }; @@ -247,14 +235,8 @@ struct NameArrayMax { template <> struct AggregateFunction<AggregateFunctionImpl<AggregateOperation::MAX>> { static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { - DataTypes data_types = {remove_nullable(data_type_ptr)}; - auto nested_function = AggregateFunctionPtr( - create_aggregate_function_max(NameArrayMax::name, data_types, false)); - - AggregateFunctionPtr function; - function.reset(new AggregateFunctionNullUnary<true>(nested_function, - {make_nullable(data_type_ptr)})); - return function; + return AggregateFunctionPtr(create_aggregate_function_max( + NameArrayMax::name, {make_nullable(data_type_ptr)}, true)); } }; diff --git a/be/src/vec/utils/template_helpers.hpp b/be/src/vec/utils/template_helpers.hpp index 265f74d53d..7c8b9bac73 100644 --- a/be/src/vec/utils/template_helpers.hpp +++ b/be/src/vec/utils/template_helpers.hpp @@ -69,17 +69,6 @@ namespace doris::vectorized { -template <template <typename> typename ClassTemplate, typename... TArgs> -IAggregateFunction* create_class_with_type(const IDataType& argument_type, TArgs&&... args) { - WhichDataType which(argument_type); -#define DISPATCH(TYPE, COLUMN_TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return new ClassTemplate<COLUMN_TYPE>(std::forward<TArgs>(args)...); - TYPE_TO_COLUMN_TYPE(DISPATCH) -#undef DISPATCH - return nullptr; -} - template <typename LoopType, LoopType start, LoopType end, template <LoopType> typename Reducer> struct constexpr_loop_match { template <typename... TArgs> @@ -127,10 +116,6 @@ struct constexpr_2_loop_match { } }; -template <template <bool, bool> typename Reducer> -using constexpr_2_bool_match = - constexpr_2_loop_match<bool, false, true, Reducer, constexpr_bool_match>; - template <typename LoopType, LoopType start, LoopType end, template <LoopType, LoopType, LoopType> typename Reducer, template <template <LoopType, LoopType> typename> typename InnerMatch> @@ -153,11 +138,7 @@ struct constexpr_3_loop_match { } }; -template <template <bool, bool, bool> typename Reducer> -using constexpr_3_bool_match = - constexpr_3_loop_match<bool, false, true, Reducer, constexpr_2_bool_match>; - -std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) { +std::variant<std::false_type, std::true_type> inline make_bool_variant(bool condition) { if (condition) { return std::true_type {}; } else { diff --git a/be/test/vec/aggregate_functions/agg_histogram_test.cpp b/be/test/vec/aggregate_functions/agg_histogram_test.cpp index 3ed16f4434..f47cabb50e 100644 --- a/be/test/vec/aggregate_functions/agg_histogram_test.cpp +++ b/be/test/vec/aggregate_functions/agg_histogram_test.cpp @@ -17,17 +17,16 @@ #include <gtest/gtest.h> -#include "common/logging.h" -#include "gtest/gtest.h" #include "vec/aggregate_functions/aggregate_function.h" -#include "vec/aggregate_functions/aggregate_function_histogram.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/common/arena.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_date.h" #include "vec/data_types/data_type_date_time.h" #include "vec/data_types/data_type_decimal.h" #include "vec/data_types/data_type_number.h" #include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" namespace doris::vectorized { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org