This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 527eb5b059 [Enchancement](function) nullable inline refactor of min_max_by/bitmap && add register_functio… (#17228) 527eb5b059 is described below commit 527eb5b059e15c22411ba44f2434b47aace0dc9f Author: Pxl <pxl...@qq.com> AuthorDate: Thu Mar 2 00:00:01 2023 +0800 [Enchancement](function) nullable inline refactor of min_max_by/bitmap && add register_functio… (#17228) 1. nullable inline refactor of min_max_by/bitmap/group_concat/histogram/topn 2. add register_function_both method 3. add datetimev2 type creator of min_max_by 4. remove uint16/32/64 in FOR_INTEGER_TYPES --- .../aggregate_functions/aggregate_function_avg.cpp | 3 +- .../aggregate_functions/aggregate_function_bit.cpp | 21 +--- .../aggregate_function_bitmap.cpp | 58 +++++----- .../aggregate_function_group_concat.cpp | 14 ++- .../aggregate_function_histogram.cpp | 46 +++----- .../aggregate_function_min_max.cpp | 11 +- .../aggregate_function_min_max_by.cpp | 124 ++++++++++----------- .../aggregate_function_min_max_by.h | 16 +-- .../aggregate_function_orthogonal_bitmap.cpp | 28 ++--- .../aggregate_function_simple_factory.h | 5 + .../aggregate_functions/aggregate_function_sum.cpp | 5 +- .../aggregate_function_topn.cpp | 81 ++++++++------ .../aggregate_function_uniq.cpp | 3 +- .../aggregate_function_window.cpp | 12 +- be/src/vec/aggregate_functions/helpers.h | 119 +++++--------------- be/src/vec/core/types.h | 9 ++ 16 files changed, 229 insertions(+), 326 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 8bda389f4a..4f493c9529 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -52,7 +52,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name, } void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { - factory.register_function("avg", create_aggregate_function_avg); - factory.register_function("avg", create_aggregate_function_avg, true); + factory.register_function_both("avg", create_aggregate_function_avg); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp index 379df49559..6b9be5c92c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp @@ -47,21 +47,12 @@ AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name, } void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_bit_or", - createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>); - - factory.register_function( - "group_bit_or", createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>, true); - factory.register_function("group_bit_and", - createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>, - true); - factory.register_function("group_bit_xor", - createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>, - true); + factory.register_function_both("group_bit_or", + createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>); + factory.register_function_both( + "group_bit_and", createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>); + factory.register_function_both( + "group_bit_xor", createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index eb9a8fb35c..e2dd7e309d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -18,50 +18,45 @@ #include "vec/aggregate_functions/aggregate_function_bitmap.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { template <bool nullable, template <bool, typename> class AggregateFunctionTemplate> -static IAggregateFunction* createWithIntDataType(const DataTypes& argument_type) { - auto type = argument_type[0].get(); - if (type->is_nullable()) { - type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get(); - } +static IAggregateFunction* create_with_int_data_type(const DataTypes& argument_type) { + auto type = remove_nullable(argument_type[0]); WhichDataType which(type); - if (which.idx == TypeIndex::Int8) { - return new AggregateFunctionTemplate<nullable, ColumnVector<Int8>>(argument_type); - } - if (which.idx == TypeIndex::Int16) { - return new AggregateFunctionTemplate<nullable, ColumnVector<Int16>>(argument_type); - } - if (which.idx == TypeIndex::Int32) { - return new AggregateFunctionTemplate<nullable, ColumnVector<Int32>>(argument_type); - } - if (which.idx == TypeIndex::Int64) { - return new AggregateFunctionTemplate<nullable, ColumnVector<Int64>>(argument_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) { \ + return new AggregateFunctionTemplate<nullable, ColumnVector<TYPE>>(argument_type); \ } + FOR_INTEGER_TYPES(DISPATCH) +#undef DISPATCH return nullptr; } AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>( - argument_types); + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_group_bitmap_xor(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return std::make_shared<AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>( - argument_types); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>( + result_is_nullable, argument_types)); } AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name, @@ -81,22 +76,19 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::strin const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::shared_ptr<IAggregateFunction>( - createWithIntDataType<true, AggregateFunctionBitmapCount>(argument_types)); + create_with_int_data_type<true, AggregateFunctionBitmapCount>(argument_types)); } else { return std::shared_ptr<IAggregateFunction>( - createWithIntDataType<false, AggregateFunctionBitmapCount>(argument_types)); + create_with_int_data_type<false, AggregateFunctionBitmapCount>(argument_types)); } } void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory) { - factory.register_function("bitmap_union", create_aggregate_function_bitmap_union); - factory.register_function("bitmap_intersect", create_aggregate_function_bitmap_intersect); - factory.register_function("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count); - factory.register_function("bitmap_union_count", create_aggregate_function_bitmap_union_count, - true); - - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int); - factory.register_function("bitmap_union_int", create_aggregate_function_bitmap_union_int, true); + factory.register_function_both("bitmap_union", create_aggregate_function_bitmap_union); + factory.register_function_both("bitmap_intersect", create_aggregate_function_bitmap_intersect); + factory.register_function_both("group_bitmap_xor", create_aggregate_function_group_bitmap_xor); + factory.register_function_both("bitmap_union_count", + create_aggregate_function_bitmap_union_count); + factory.register_function_both("bitmap_union_int", create_aggregate_function_bitmap_union_int); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index bcd7becc5e..5bd070ada3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -17,6 +17,8 @@ #include "vec/aggregate_functions/aggregate_function_group_concat.h" +#include "vec/aggregate_functions/helpers.h" + namespace doris::vectorized { const std::string AggregateFunctionGroupConcatImplStr::separator = ", "; @@ -26,12 +28,14 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n const bool result_is_nullable) { if (argument_types.size() == 1) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>( + result_is_nullable, argument_types)); } else if (argument_types.size() == 2) { return AggregateFunctionPtr( - new AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>( - argument_types)); + creator_without_type::create< + AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>>( + result_is_nullable, argument_types)); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", @@ -40,6 +44,6 @@ AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& n } void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory) { - factory.register_function("group_concat", create_aggregate_function_group_concat); + factory.register_function_both("group_concat", create_aggregate_function_group_concat); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 81dece0c95..77e67ab29a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -23,56 +23,46 @@ namespace doris::vectorized { template <typename T> -AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types) { +AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_types, + const bool result_is_nullable) { bool has_input_param = (argument_types.size() == 3); if (has_input_param) { return AggregateFunctionPtr( - new AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, true>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, true>>( + result_is_nullable, argument_types)); } else { return AggregateFunctionPtr( - new AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, false>( - argument_types)); + creator_without_type::create< + AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, false>>( + result_is_nullable, argument_types)); } } AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - WhichDataType type(argument_types[0]); + WhichDataType type(remove_nullable(argument_types[0])); - LOG(INFO) << fmt::format("supported input type {} for aggregate function {}", - argument_types[0]->get_name(), name); - -#define DISPATCH(TYPE) \ - if (type.idx == TypeIndex::TYPE) return create_agg_function_histogram<TYPE>(argument_types); +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return create_agg_function_histogram<TYPE>(argument_types, result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) + FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH if (type.idx == TypeIndex::String) { - return create_agg_function_histogram<String>(argument_types); + return create_agg_function_histogram<String>(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTime || type.idx == TypeIndex::Date) { - return create_agg_function_histogram<Int64>(argument_types); + return create_agg_function_histogram<Int64>(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateV2) { - return create_agg_function_histogram<UInt32>(argument_types); + return create_agg_function_histogram<UInt32>(argument_types, result_is_nullable); } if (type.idx == TypeIndex::DateTimeV2) { - return create_agg_function_histogram<UInt64>(argument_types); - } - if (type.idx == TypeIndex::Decimal32) { - return create_agg_function_histogram<Decimal32>(argument_types); - } - if (type.idx == TypeIndex::Decimal64) { - return create_agg_function_histogram<Decimal64>(argument_types); - } - if (type.idx == TypeIndex::Decimal128) { - return create_agg_function_histogram<Decimal128>(argument_types); - } - if (type.idx == TypeIndex::Decimal128I) { - return create_agg_function_histogram<Decimal128I>(argument_types); + return create_agg_function_histogram<UInt64>(argument_types, result_is_nullable); } LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", @@ -81,7 +71,7 @@ AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name } void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory) { - factory.register_function("histogram", create_aggregate_function_histogram); + factory.register_function_both("histogram", create_aggregate_function_histogram); factory.register_alias("histogram", "hist"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index 882b532c7e..46606142b2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -97,14 +97,9 @@ AggregateFunctionPtr create_aggregate_function_any(const std::string& name, } void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory) { - factory.register_function("max", create_aggregate_function_max); - factory.register_function("min", create_aggregate_function_min); - factory.register_function("any", create_aggregate_function_any); - - factory.register_function("max", create_aggregate_function_max, true); - factory.register_function("min", create_aggregate_function_min, true); - factory.register_function("any", create_aggregate_function_any, true); - + factory.register_function_both("max", create_aggregate_function_max); + factory.register_function_both("min", create_aggregate_function_min); + factory.register_function_both("any", create_aggregate_function_any); factory.register_alias("any", "any_value"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp index 8a4ad945f9..2252da7721 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp @@ -26,101 +26,95 @@ namespace doris::vectorized { /// min_by, max_by -template <template <typename, bool> class AggregateFunctionTemplate, +template <template <typename> class AggregateFunctionTemplate, template <typename, typename> class Data, typename VT> static IAggregateFunction* create_aggregate_function_min_max_by_impl( - const DataTypes& argument_types) { - const DataTypePtr& value_arg_type = argument_types[0]; - const DataTypePtr& key_arg_type = argument_types[1]; + const DataTypes& argument_types, const bool result_is_nullable) { + WhichDataType which(remove_nullable(argument_types[1])); - WhichDataType which(key_arg_type); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<TYPE>>, false>( \ - value_arg_type, key_arg_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return creator_without_type::create< \ + AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<TYPE>>>>( \ + result_is_nullable, argument_types); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH + +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return creator_without_type::create< \ + AggregateFunctionTemplate<Data<VT, SingleValueDataDecimal<TYPE>>>>( \ + result_is_nullable, argument_types); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH + if (which.idx == TypeIndex::String) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataString>, false>(value_arg_type, - key_arg_type); + return creator_without_type::create< + AggregateFunctionTemplate<Data<VT, SingleValueDataString>>>(result_is_nullable, + argument_types); } if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<Int64>>, false>( - value_arg_type, key_arg_type); + return creator_without_type::create< + AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<Int64>>>>( + result_is_nullable, argument_types); } if (which.idx == TypeIndex::DateV2) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<UInt32>>, false>( - value_arg_type, key_arg_type); - } - if (which.idx == TypeIndex::Decimal32) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataDecimal<Decimal32>>, false>( - value_arg_type, key_arg_type); - } - if (which.idx == TypeIndex::Decimal64) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataDecimal<Decimal64>>, false>( - value_arg_type, key_arg_type); + return creator_without_type::create< + AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<UInt32>>>>( + result_is_nullable, argument_types); } - if (which.idx == TypeIndex::Decimal128) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataDecimal<Decimal128>>, false>( - value_arg_type, key_arg_type); - } - if (which.idx == TypeIndex::Decimal128I) { - return new AggregateFunctionTemplate<Data<VT, SingleValueDataDecimal<Decimal128I>>, false>( - value_arg_type, key_arg_type); + if (which.idx == TypeIndex::DateTimeV2) { + return creator_without_type::create< + AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<UInt64>>>>( + result_is_nullable, argument_types); } return nullptr; } /// min_by, max_by -template <template <typename, bool> class AggregateFunctionTemplate, +template <template <typename> class AggregateFunctionTemplate, template <typename, typename> class Data> static IAggregateFunction* create_aggregate_function_min_max_by(const String& name, - const DataTypes& argument_types) { + const DataTypes& argument_types, + const bool result_is_nullable) { assert_binary(name, argument_types); - const DataTypePtr& value_arg_type = argument_types[0]; - - WhichDataType which(value_arg_type); + WhichDataType which(remove_nullable(argument_types[0])); #define DISPATCH(TYPE) \ if (which.idx == TypeIndex::TYPE) \ return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \ SingleValueDataFixed<TYPE>>( \ - argument_types); + argument_types, result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH + +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \ + SingleValueDataDecimal<TYPE>>( \ + argument_types, result_is_nullable); + FOR_DECIMAL_TYPES(DISPATCH) +#undef DISPATCH + if (which.idx == TypeIndex::String) { return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, - SingleValueDataString>(argument_types); + SingleValueDataString>(argument_types, + result_is_nullable); } if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) { return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, SingleValueDataFixed<Int64>>( - argument_types); + argument_types, result_is_nullable); } if (which.idx == TypeIndex::DateV2) { return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, SingleValueDataFixed<UInt32>>( - argument_types); - } - if (which.idx == TypeIndex::Decimal128) { - return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, - SingleValueDataDecimal<Decimal128>>( - argument_types); - } - if (which.idx == TypeIndex::Decimal32) { - return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, - SingleValueDataDecimal<Decimal32>>( - argument_types); - } - if (which.idx == TypeIndex::Decimal64) { - return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, - SingleValueDataDecimal<Decimal64>>( - argument_types); + argument_types, result_is_nullable); } - if (which.idx == TypeIndex::Decimal128I) { + if (which.idx == TypeIndex::DateTimeV2) { return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, - SingleValueDataDecimal<Decimal128I>>( - argument_types); + SingleValueDataFixed<UInt64>>( + argument_types, result_is_nullable); } return nullptr; } @@ -128,22 +122,22 @@ static IAggregateFunction* create_aggregate_function_min_max_by(const String& na AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return AggregateFunctionPtr( - create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy, - AggregateFunctionMaxByData>(name, argument_types)); + return AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy, + AggregateFunctionMaxByData>( + name, argument_types, result_is_nullable)); } AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - return AggregateFunctionPtr( - create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy, - AggregateFunctionMinByData>(name, argument_types)); + return AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy, + AggregateFunctionMinByData>( + name, argument_types, result_is_nullable)); } void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory) { - factory.register_function("max_by", create_aggregate_function_max_by); - factory.register_function("min_by", create_aggregate_function_min_by); + factory.register_function_both("max_by", create_aggregate_function_max_by); + factory.register_function_both("min_by", create_aggregate_function_min_by); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h index b25e771862..28133dbb5d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h @@ -95,24 +95,22 @@ struct AggregateFunctionMinByData : public AggregateFunctionMinMaxByBaseData<VT, static const char* name() { return "min_by"; } }; -template <typename Data, bool AllocatesMemoryInArena> +template <typename Data> class AggregateFunctionsMinMaxBy final - : public IAggregateFunctionDataHelper< - Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>> { + : public IAggregateFunctionDataHelper<Data, AggregateFunctionsMinMaxBy<Data>> { private: DataTypePtr& value_type; DataTypePtr& key_type; public: - AggregateFunctionsMinMaxBy(const DataTypePtr& value_type_, const DataTypePtr& key_type_) - : IAggregateFunctionDataHelper< - Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>>( - {value_type_, key_type_}), + AggregateFunctionsMinMaxBy(const DataTypes& arguments) + : IAggregateFunctionDataHelper<Data, AggregateFunctionsMinMaxBy<Data>>( + {arguments[0], arguments[1]}), value_type(this->argument_types[0]), key_type(this->argument_types[1]) { if (StringRef(Data::name()) == StringRef("min_by") || StringRef(Data::name()) == StringRef("max_by")) { - CHECK(key_type_->is_comparable()); + CHECK(key_type->is_comparable()); } } @@ -141,8 +139,6 @@ public: this->data(place).read(buf); } - bool allocates_memory_in_arena() const override { return AllocatesMemoryInArena; } - void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { this->data(place).insert_result_into(to); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp index 579fe930cf..d894a5be06 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp @@ -85,26 +85,12 @@ AggregateFunctionPtr create_aggregate_function_orthogonal_bitmap_union_count( } void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory) { - factory.register_function("orthogonal_bitmap_intersect", - create_aggregate_function_orthogonal_bitmap_intersect); - - factory.register_function("orthogonal_bitmap_intersect_count", - create_aggregate_function_orthogonal_bitmap_intersect_count); - - factory.register_function("orthogonal_bitmap_union_count", - create_aggregate_function_orthogonal_bitmap_union_count); - - factory.register_function("intersect_count", create_aggregate_function_intersect_count); - - factory.register_function("orthogonal_bitmap_intersect", - create_aggregate_function_orthogonal_bitmap_intersect, true); - - factory.register_function("orthogonal_bitmap_intersect_count", - create_aggregate_function_orthogonal_bitmap_intersect_count, true); - - factory.register_function("orthogonal_bitmap_union_count", - create_aggregate_function_orthogonal_bitmap_union_count, true); - - factory.register_function("intersect_count", create_aggregate_function_intersect_count, true); + factory.register_function_both("orthogonal_bitmap_intersect", + create_aggregate_function_orthogonal_bitmap_intersect); + factory.register_function_both("orthogonal_bitmap_intersect_count", + create_aggregate_function_orthogonal_bitmap_intersect_count); + factory.register_function_both("orthogonal_bitmap_union_count", + create_aggregate_function_orthogonal_bitmap_union_count); + factory.register_function_both("intersect_count", create_aggregate_function_intersect_count); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 12298d8aa8..4ebc804d5d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -107,6 +107,11 @@ public: } } + void register_function_both(const std::string& name, const Creator& creator) { + register_function(name, creator, false); + register_function(name, creator, true); + } + void register_alias(const std::string& name, const std::string& alias) { function_alias[alias] = name; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp index 7a08be0c1c..0f7b47193a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp @@ -73,9 +73,8 @@ AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& nam } void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) { - factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>); - factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>, - true); + factory.register_function_both("sum", + create_aggregate_function_sum<AggregateFunctionSumSimple>); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp index cb4224a322..c57ec934e5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp @@ -26,10 +26,12 @@ AggregateFunctionPtr create_aggregate_function_topn(const std::string& name, const bool result_is_nullable) { if (argument_types.size() == 2) { return AggregateFunctionPtr( - new AggregateFunctionTopN<AggregateFunctionTopNImplInt>(argument_types)); + creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplInt>>( + result_is_nullable, argument_types)); } else if (argument_types.size() == 3) { - return AggregateFunctionPtr( - new AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>(argument_types)); + return AggregateFunctionPtr(creator_without_type::create< + AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>>( + result_is_nullable, argument_types)); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", @@ -39,44 +41,47 @@ AggregateFunctionPtr create_aggregate_function_topn(const std::string& name, template <template <typename, bool> class AggregateFunctionTemplate, bool has_default_param, bool is_weighted> -AggregateFunctionPtr create_topn_array(const DataTypes& argument_types) { - auto type = argument_types[0].get(); - if (type->is_nullable()) { - type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get(); - } +AggregateFunctionPtr create_topn_array(const DataTypes& argument_types, + const bool result_is_nullable) { + WhichDataType which(remove_nullable(argument_types[0])); - WhichDataType which(*type); - -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return AggregateFunctionPtr( \ - new AggregateFunctionTopNArray<AggregateFunctionTemplate<TYPE, has_default_param>, \ - TYPE, is_weighted>(argument_types)); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return AggregateFunctionPtr( \ + creator_without_type::create<AggregateFunctionTopNArray< \ + AggregateFunctionTemplate<TYPE, has_default_param>, TYPE, is_weighted>>( \ + result_is_nullable, argument_types)); FOR_NUMERIC_TYPES(DISPATCH) + FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH + if (which.is_string_or_fixed_string()) { - return AggregateFunctionPtr(new AggregateFunctionTopNArray< - AggregateFunctionTemplate<std::string, has_default_param>, - std::string, is_weighted>(argument_types)); - } - if (which.is_decimal()) { - return AggregateFunctionPtr(new AggregateFunctionTopNArray< - AggregateFunctionTemplate<Decimal128, has_default_param>, - Decimal128, is_weighted>(argument_types)); + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunctionTopNArray< + AggregateFunctionTemplate<std::string, has_default_param>, std::string, + is_weighted>>(result_is_nullable, argument_types)); } - if (which.is_date_or_datetime() || which.is_date_time_v2()) { + if (which.is_date_or_datetime()) { return AggregateFunctionPtr( - new AggregateFunctionTopNArray<AggregateFunctionTemplate<Int64, has_default_param>, - Int64, is_weighted>(argument_types)); + creator_without_type::create<AggregateFunctionTopNArray< + AggregateFunctionTemplate<Int64, has_default_param>, Int64, is_weighted>>( + result_is_nullable, argument_types)); } if (which.is_date_v2()) { return AggregateFunctionPtr( - new AggregateFunctionTopNArray<AggregateFunctionTemplate<UInt32, has_default_param>, - UInt32, is_weighted>(argument_types)); + creator_without_type::create<AggregateFunctionTopNArray< + AggregateFunctionTemplate<UInt32, has_default_param>, UInt32, is_weighted>>( + result_is_nullable, argument_types)); + } + if (which.is_date_time_v2()) { + return AggregateFunctionPtr( + creator_without_type::create<AggregateFunctionTopNArray< + AggregateFunctionTemplate<UInt64, has_default_param>, UInt64, is_weighted>>( + result_is_nullable, argument_types)); } LOG(WARNING) << fmt::format("Illegal argument type for aggregate function topn_array is: {}", - type->get_name()); + remove_nullable(argument_types[0])->get_name()); return nullptr; } @@ -85,9 +90,11 @@ AggregateFunctionPtr create_aggregate_function_topn_array(const std::string& nam const bool result_is_nullable) { bool has_default_param = (argument_types.size() == 3); if (has_default_param) { - return create_topn_array<AggregateFunctionTopNImplArray, true, false>(argument_types); + return create_topn_array<AggregateFunctionTopNImplArray, true, false>(argument_types, + result_is_nullable); } else { - return create_topn_array<AggregateFunctionTopNImplArray, false, false>(argument_types); + return create_topn_array<AggregateFunctionTopNImplArray, false, false>(argument_types, + result_is_nullable); } } @@ -96,16 +103,18 @@ AggregateFunctionPtr create_aggregate_function_topn_weighted(const std::string& const bool result_is_nullable) { bool has_default_param = (argument_types.size() == 4); if (has_default_param) { - return create_topn_array<AggregateFunctionTopNImplWeight, true, true>(argument_types); + return create_topn_array<AggregateFunctionTopNImplWeight, true, true>(argument_types, + result_is_nullable); } else { - return create_topn_array<AggregateFunctionTopNImplWeight, false, true>(argument_types); + return create_topn_array<AggregateFunctionTopNImplWeight, false, true>(argument_types, + result_is_nullable); } } void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory) { - factory.register_function("topn", create_aggregate_function_topn); - factory.register_function("topn_array", create_aggregate_function_topn_array); - factory.register_function("topn_weighted", create_aggregate_function_topn_weighted); + factory.register_function_both("topn", create_aggregate_function_topn); + factory.register_function_both("topn_array", create_aggregate_function_topn_array); + factory.register_function_both("topn_weighted", create_aggregate_function_topn_weighted); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp index 399fdb6317..18bd119a21 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp @@ -70,8 +70,7 @@ AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name, void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = create_aggregate_function_uniq<AggregateFunctionUniqExactData>; - factory.register_function("multi_distinct_count", creator); - factory.register_function("multi_distinct_count", creator, true); + factory.register_function_both("multi_distinct_count", creator); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp index 7bb9e94524..a36b9601c2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp @@ -114,14 +114,10 @@ void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac void register_aggregate_function_window_lead_lag_first_last( AggregateFunctionSimpleFactory& factory) { - factory.register_function("lead", create_aggregate_function_window_lead); - factory.register_function("lead", create_aggregate_function_window_lead, true); - factory.register_function("lag", create_aggregate_function_window_lag); - factory.register_function("lag", create_aggregate_function_window_lag, true); - factory.register_function("first_value", create_aggregate_function_window_first); - factory.register_function("first_value", create_aggregate_function_window_first, true); - factory.register_function("last_value", create_aggregate_function_window_last); - factory.register_function("last_value", create_aggregate_function_window_last, true); + factory.register_function_both("lead", create_aggregate_function_window_lead); + factory.register_function_both("lag", create_aggregate_function_window_lag); + factory.register_function_both("first_value", create_aggregate_function_window_first); + factory.register_function_both("last_value", create_aggregate_function_window_last); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h index e67e0976d7..7a811871ed 100644 --- a/be/src/vec/aggregate_functions/helpers.h +++ b/be/src/vec/aggregate_functions/helpers.h @@ -25,12 +25,8 @@ #include "vec/data_types/data_type.h" #include "vec/utils/template_helpers.hpp" -// TODO: Should we support decimal in numeric types? #define FOR_INTEGER_TYPES(M) \ M(UInt8) \ - M(UInt16) \ - M(UInt32) \ - M(UInt64) \ M(Int8) \ M(Int16) \ M(Int32) \ @@ -53,49 +49,50 @@ namespace doris::vectorized { -/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument. - */ -template <template <typename> class AggregateFunctionTemplate, typename Type> -struct BuilderDirect { - using T = AggregateFunctionTemplate<Type>; -}; -template <template <typename> class AggregateFunctionTemplate, template <typename> class Data, - typename Type> -struct BuilderData { - using T = AggregateFunctionTemplate<Data<Type>>; -}; -template <template <typename> class AggregateFunctionTemplate, template <typename> class Data, - template <typename> class Impl, typename Type> -struct BuilderDataImpl { - using T = AggregateFunctionTemplate<Data<Impl<Type>>>; -}; -template <template <typename, typename> class AggregateFunctionTemplate, - template <typename> class Data, typename Type> -struct BuilderDirectAndData { - using T = AggregateFunctionTemplate<Type, Data<Type>>; +struct creator_without_type { + template <bool multi_arguments, bool f, typename T> + using NullableT = std::conditional_t<multi_arguments, AggregateFunctionNullVariadicInline<T, f>, + AggregateFunctionNullUnaryInline<T, f>>; + + template <typename AggregateFunctionTemplate, typename... TArgs> + static IAggregateFunction* create(const bool result_is_nullable, + const DataTypes& argument_types, TArgs&&... args) { + IAggregateFunction* result(new AggregateFunctionTemplate(std::forward<TArgs>(args)..., + remove_nullable(argument_types))); + if (have_nullable(argument_types)) { + std::visit( + [&](auto multi_arguments, auto result_is_nullable) { + result = new NullableT<multi_arguments, result_is_nullable, + AggregateFunctionTemplate>(result, argument_types); + }, + make_bool_variant(argument_types.size() > 1), + make_bool_variant(result_is_nullable)); + } + return result; + } }; template <template <typename> class AggregateFunctionTemplate> struct CurryDirect { template <typename Type> - using Builder = BuilderDirect<AggregateFunctionTemplate, Type>; + using T = AggregateFunctionTemplate<Type>; }; template <template <typename> class AggregateFunctionTemplate, template <typename> class Data> struct CurryData { template <typename Type> - using Builder = BuilderData<AggregateFunctionTemplate, Data, Type>; + using T = AggregateFunctionTemplate<Data<Type>>; }; template <template <typename> class AggregateFunctionTemplate, template <typename> class Data, template <typename> class Impl> struct CurryDataImpl { template <typename Type> - using Builder = BuilderDataImpl<AggregateFunctionTemplate, Data, Impl, Type>; + using T = AggregateFunctionTemplate<Data<Impl<Type>>>; }; template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data> struct CurryDirectAndData { template <typename Type> - using Builder = BuilderDirectAndData<AggregateFunctionTemplate, Data, Type>; + using T = AggregateFunctionTemplate<Type, Data<Type>>; }; template <bool allow_integer, bool allow_float, bool allow_decimal, int define_index = 0> @@ -104,35 +101,10 @@ struct creator_with_type_base { static IAggregateFunction* create_base(const bool result_is_nullable, const DataTypes& argument_types, TArgs&&... args) { WhichDataType which(remove_nullable(argument_types[define_index])); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) { \ - using T = typename Class::template Builder<TYPE>::T; \ - if (have_nullable(argument_types)) { \ - IAggregateFunction* result = nullptr; \ - if (argument_types.size() > 1) { \ - std::visit( \ - [&](auto result_is_nullable) { \ - result = new AggregateFunctionNullVariadicInline<T, \ - result_is_nullable>( \ - new T(std::forward<TArgs>(args)..., \ - remove_nullable(argument_types)), \ - argument_types); \ - }, \ - make_bool_variant(result_is_nullable)); \ - } else { \ - std::visit( \ - [&](auto result_is_nullable) { \ - result = new AggregateFunctionNullUnaryInline<T, result_is_nullable>( \ - new T(std::forward<TArgs>(args)..., \ - remove_nullable(argument_types)), \ - argument_types); \ - }, \ - make_bool_variant(result_is_nullable)); \ - } \ - return result; \ - } else { \ - return new T(std::forward<TArgs>(args)..., argument_types); \ - } \ +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) { \ + return creator_without_type::create<typename Class::template T<TYPE>>( \ + result_is_nullable, argument_types, std::forward<TArgs>(args)...); \ } if constexpr (allow_integer) { @@ -180,37 +152,4 @@ using creator_with_numeric_type = creator_with_type_base<true, true, false>; using creator_with_decimal_type = creator_with_type_base<false, false, true>; using creator_with_type = creator_with_type_base<true, true, true>; -struct creator_without_type { - template <typename AggregateFunctionTemplate, typename... TArgs> - static IAggregateFunction* create(const bool result_is_nullable, - const DataTypes& argument_types, TArgs&&... args) { - if (have_nullable(argument_types)) { - IAggregateFunction* result = nullptr; - if (argument_types.size() > 1) { - std::visit( - [&](auto result_is_nullable) { - result = new AggregateFunctionNullVariadicInline< - AggregateFunctionTemplate, result_is_nullable>( - new AggregateFunctionTemplate(std::forward<TArgs>(args)..., - remove_nullable(argument_types)), - argument_types); - }, - make_bool_variant(result_is_nullable)); - } else { - std::visit( - [&](auto result_is_nullable) { - result = new AggregateFunctionNullUnaryInline<AggregateFunctionTemplate, - result_is_nullable>( - new AggregateFunctionTemplate(std::forward<TArgs>(args)..., - remove_nullable(argument_types)), - argument_types); - }, - make_bool_variant(result_is_nullable)); - } - return result; - } else { - return new AggregateFunctionTemplate(std::forward<TArgs>(args)..., argument_types); - } - } -}; } // namespace doris::vectorized diff --git a/be/src/vec/core/types.h b/be/src/vec/core/types.h index f8f0203208..177d2166ed 100644 --- a/be/src/vec/core/types.h +++ b/be/src/vec/core/types.h @@ -626,6 +626,15 @@ struct std::hash<doris::vectorized::Decimal128> { } }; +template <> +struct std::hash<doris::vectorized::Decimal128I> { + size_t operator()(const doris::vectorized::Decimal<doris::vectorized::Int128I>& x) const { + return std::hash<doris::vectorized::Int64>()(x.value >> 64) ^ + std::hash<doris::vectorized::Int64>()( + x.value & std::numeric_limits<doris::vectorized::UInt64>::max()); + } +}; + constexpr bool typeindex_is_int(doris::vectorized::TypeIndex index) { using TypeIndex = doris::vectorized::TypeIndex; switch (index) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org