This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch dev-1.0.1 in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
commit 314a6dceb9a6decab6604cbcbc4920f796ff6451 Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Tue Mar 29 18:18:06 2022 +0800 [Vectorized][refactor] refactor stddev/variance agg functions (#8660) * [Vectorized][refactor] refactor stddev agg functions --- .../vec/aggregate_functions/aggregate_function.h | 5 - .../aggregate_functions/aggregate_function_null.h | 11 +- .../aggregate_function_simple_factory.cpp | 6 +- .../aggregate_function_stddev.cpp | 61 ++++++---- .../aggregate_function_stddev.h | 126 +++++++++++---------- 5 files changed, 110 insertions(+), 99 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index c3b5072..5cf529c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -109,11 +109,6 @@ public: */ virtual bool is_state() const { return false; } - /// if return false, during insert_result_into function, you colud get nullable result column, - /// so could insert to null value by yourself, rather than by AggregateFunctionNullBase; - /// because you maybe be calculate a invalid value, but want to use null replace it; - virtual bool insert_to_null_default() const { return true; } - /** Contains a loop with calls to "add" function. You can collect arguments into array "places" * and do a single call to "add_batch" for devirtualization and inlining. */ diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 83cae6f..55e9100 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -144,14 +144,9 @@ public: if constexpr (result_is_nullable) { ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to); if (get_flag(place)) { - if (nested_function->insert_to_null_default()) { - nested_function->insert_result_into(nested_place(place), - to_concrete.get_nested_column()); - to_concrete.get_null_map_data().push_back(0); - } else { - nested_function->insert_result_into( - nested_place(place), to); //want to insert into null value by self - } + nested_function->insert_result_into(nested_place(place), + to_concrete.get_nested_column()); + to_concrete.get_null_map_data().push_back(0); } else { to_concrete.insert_default(); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index d578eef..3be7d18 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -37,7 +37,8 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory); @@ -56,7 +57,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_combinator_distinct(instance); register_aggregate_function_reader(instance); // register aggregate function for agg reader register_aggregate_function_window_rank(instance); - register_aggregate_function_stddev_variance(instance); + register_aggregate_function_stddev_variance_pop(instance); register_aggregate_function_topn(instance); register_aggregate_function_approx_count_distinct(instance); register_aggregate_function_group_concat(instance); @@ -65,6 +66,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line register_aggregate_function_combinator_null(instance); + register_aggregate_function_stddev_variance_samp(instance); register_aggregate_function_reader_no_spread(instance); register_aggregate_function_window_lead_lag(instance); register_aggregate_function_HLL_union_agg(instance); diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp index 2b06423..31f4556 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp @@ -23,8 +23,9 @@ #include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { -template <template <typename> class AggregateFunctionTemplate, template <typename> class NameData, - template <typename, typename> class Data, bool is_stddev> +template <template <typename, bool> class AggregateFunctionTemplate, + template <typename> class NameData, template <typename, typename> class Data, + bool is_stddev, bool is_nullable = false> static IAggregateFunction* create_function_single_value(const String& name, const DataTypes& argument_types, const Array& parameters) { @@ -32,40 +33,42 @@ static IAggregateFunction* create_function_single_value(const String& name, if (type->is_nullable()) { type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get(); } + WhichDataType which(*type); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return new AggregateFunctionTemplate<NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>>( \ - argument_types); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return new AggregateFunctionTemplate<NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, \ + is_nullable>(argument_types); + FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH if (which.is_decimal()) { - return new AggregateFunctionTemplate< - NameData<Data<Decimal128, BaseDatadecimal<is_stddev>>>>(argument_types); + return new AggregateFunctionTemplate<NameData<Data<Decimal128, BaseDatadecimal<is_stddev>>>, + is_nullable>(argument_types); } DCHECK(false) << "with unknowed type, failed in create_aggregate_function_stddev_variance"; return nullptr; } -template <bool is_stddev> +template <bool is_stddev, bool is_nullable> AggregateFunctionPtr create_aggregate_function_variance_samp(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { return AggregateFunctionPtr( - create_function_single_value<AggregateFunctionStddevSamp, VarianceSampData, SampData, - is_stddev>(name, argument_types, parameters)); + create_function_single_value<AggregateFunctionSamp, VarianceSampName, SampData, + is_stddev, is_nullable>(name, argument_types, parameters)); } -template <bool is_stddev> +template <bool is_stddev, bool is_nullable> AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { return AggregateFunctionPtr( - create_function_single_value<AggregateFunctionStddevSamp, StddevSampData, SampData, - is_stddev>(name, argument_types, parameters)); + create_function_single_value<AggregateFunctionSamp, StddevSampName, SampData, is_stddev, + is_nullable>(name, argument_types, parameters)); } template <bool is_stddev> @@ -74,8 +77,8 @@ AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string& n const Array& parameters, const bool result_is_nullable) { return AggregateFunctionPtr( - create_function_single_value<AggregateFunctionStddevSamp, VarianceData, PopData, - is_stddev>(name, argument_types, parameters)); + create_function_single_value<AggregateFunctionPop, VarianceName, PopData, is_stddev>( + name, argument_types, parameters)); } template <bool is_stddev> @@ -84,21 +87,29 @@ AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& nam const Array& parameters, const bool result_is_nullable) { return AggregateFunctionPtr( - create_function_single_value<AggregateFunctionStddevSamp, StddevData, PopData, - is_stddev>(name, argument_types, parameters)); + create_function_single_value<AggregateFunctionPop, StddevName, PopData, is_stddev>( + name, argument_types, parameters)); } -void register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& factory) { - factory.register_function("variance_samp", create_aggregate_function_variance_samp<false>); - factory.register_function("variance_samp", create_aggregate_function_variance_samp<false>, true); - factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true>); - factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true>, true); - factory.register_alias("variance_samp", "var_samp"); - +void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory) { factory.register_function("variance", create_aggregate_function_variance_pop<false>); factory.register_alias("variance", "var_pop"); factory.register_alias("variance", "variance_pop"); factory.register_function("stddev", create_aggregate_function_stddev_pop<true>); factory.register_alias("stddev", "stddev_pop"); } + +void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory) { + // _samp<bool, bool>: first indicate is stddev or variance function + // second indicate is arg nullable column + factory.register_function("variance_samp", + create_aggregate_function_variance_samp<false, false>, false); + factory.register_function("variance_samp", create_aggregate_function_variance_samp<false, true>, + true); + factory.register_alias("variance_samp", "var_samp"); + factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, false>, + false); + factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, true>, + true); +} } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index 83c4041..8821232 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -69,10 +69,6 @@ struct BaseData { return get_result(res); } - static const DataTypePtr get_return_type() { - return std::make_shared<DataTypeNumber<Float64>>(); - } - void merge(const BaseData& rhs) { if (rhs.count == 0) { return; @@ -84,8 +80,8 @@ struct BaseData { count = sum_count; } - virtual void add(const IColumn** columns, size_t row_num) { - const auto& sources = static_cast<const ColumnVector<T>&>(*columns[0]); + void add(const IColumn* column, size_t row_num) { + const auto& sources = static_cast<const ColumnVector<T>&>(*column); double source_data = sources.get_data()[row_num]; double delta = source_data - mean; @@ -146,10 +142,6 @@ struct BaseDatadecimal { return get_result(res); } - static const DataTypePtr get_return_type() { - return std::make_shared<DataTypeDecimal<Decimal128>>(27, 9); - } - void merge(const BaseDatadecimal& rhs) { if (rhs.count == 0) { return; @@ -166,9 +158,9 @@ struct BaseDatadecimal { count += rhs.count; } - virtual void add(const IColumn** columns, size_t row_num) { + void add(const IColumn* column, size_t row_num) { DecimalV2Value source_data = DecimalV2Value(); - const auto& sources = static_cast<const ColumnDecimal<Decimal128>&>(*columns[0]); + const auto& sources = static_cast<const ColumnDecimal<Decimal128>&>(*column); source_data = (DecimalV2Value)sources.get_data()[row_num]; DecimalV2Value new_count = DecimalV2Value(); @@ -202,13 +194,33 @@ struct PopData : Data { } }; +template <typename Data> +struct StddevName : Data { + static const char* name() { return "stddev"; } +}; + +template <typename Data> +struct VarianceName : Data { + static const char* name() { return "variance"; } +}; + +template <typename Data> +struct VarianceSampName : Data { + static const char* name() { return "variance_samp"; } +}; + +template <typename Data> +struct StddevSampName : Data { + static const char* name() { return "stddev_samp"; } +}; + template <typename T, typename Data> struct SampData : Data { using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>; void insert_result_into(IColumn& to) const { ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to); - if (this->count == 1) { + if (this->count == 1 || this->count == 0) { nullable_column.insert_default(); } else { auto& col = static_cast<ColVecResult&>(nullable_column.get_nested_column()); @@ -220,61 +232,40 @@ struct SampData : Data { nullable_column.get_null_map_data().push_back(0); } } - - static const DataTypePtr get_return_type() { - return make_nullable(Data::get_return_type()); - } - - void add(const IColumn** columns, size_t row_num) override { - if (columns[0]->is_nullable()) { - const auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]); - if (!nullable_column.is_null_at(row_num)) { - const IColumn* new_columns[1]; - new_columns[0] = &nullable_column.get_nested_column(); - Data::add(new_columns, row_num); - } - } else { - Data::add(columns, row_num); - } - } - -}; - -template <typename Data> -struct StddevData : Data { - static const char* name() { return "stddev"; } -}; - -template <typename Data> -struct VarianceData : Data { - static const char* name() { return "variance"; } -}; - -template <typename Data> -struct VarianceSampData : Data { - static const char* name() { return "variance_samp"; } -}; - -template <typename Data> -struct StddevSampData : Data { - static const char* name() { return "stddev_samp"; } }; -template <typename Data> -class AggregateFunctionStddevSamp final - : public IAggregateFunctionDataHelper<Data, AggregateFunctionStddevSamp<Data>> { +template <bool is_pop, typename Data, bool is_nullable> +class AggregateFunctionSampVariance + : public IAggregateFunctionDataHelper<Data, AggregateFunctionSampVariance<is_pop, Data, is_nullable>> { public: - AggregateFunctionStddevSamp(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper<Data, AggregateFunctionStddevSamp<Data>>(argument_types_, - {}) {} + AggregateFunctionSampVariance(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<Data, AggregateFunctionSampVariance<is_pop, Data, is_nullable>>( + argument_types_, {}) {} String get_name() const override { return Data::name(); } - DataTypePtr get_return_type() const override { return Data::get_return_type(); } + DataTypePtr get_return_type() const override { + if constexpr (is_pop) { + return std::make_shared<DataTypeFloat64>(); + } else { + return make_nullable(std::make_shared<DataTypeFloat64>()); + } + } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - this->data(place).add(columns, row_num); + if constexpr (is_pop) { + this->data(place).add(columns[0], row_num); + } else { + if constexpr (is_nullable) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[0]); + if (!nullable_column->is_null_at(row_num)) { + this->data(place).add(&nullable_column->get_nested_column(), row_num); + } + } else { + this->data(place).add(columns[0], row_num); + } + } } void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } @@ -298,4 +289,21 @@ public: } }; +//samp function it's always nullables, it's need to handle nullable column +//so return type and add function should processing null values +template <typename Data, bool is_nullable> +class AggregateFunctionSamp final: public AggregateFunctionSampVariance<false, Data, is_nullable> { +public: + AggregateFunctionSamp(const DataTypes& argument_types_) + : AggregateFunctionSampVariance<false, Data, is_nullable>(argument_types_) {} +}; + +//pop function have use AggregateFunctionNullBase function, so needn't processing null values +template <typename Data, bool is_nullable> +class AggregateFunctionPop final: public AggregateFunctionSampVariance<true, Data, is_nullable> { +public: + AggregateFunctionPop(const DataTypes& argument_types_) + : AggregateFunctionSampVariance<true, Data, is_nullable>(argument_types_) {} +}; + } // namespace doris::vectorized --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org