This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git

commit 314a6dceb9a6decab6604cbcbc4920f796ff6451
Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com>
AuthorDate: Tue Mar 29 18:18:06 2022 +0800

    [Vectorized][refactor] refactor stddev/variance agg functions (#8660)
    
    * [Vectorized][refactor] refactor stddev agg functions
---
 .../vec/aggregate_functions/aggregate_function.h   |   5 -
 .../aggregate_functions/aggregate_function_null.h  |  11 +-
 .../aggregate_function_simple_factory.cpp          |   6 +-
 .../aggregate_function_stddev.cpp                  |  61 ++++++----
 .../aggregate_function_stddev.h                    | 126 +++++++++++----------
 5 files changed, 110 insertions(+), 99 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function.h 
b/be/src/vec/aggregate_functions/aggregate_function.h
index c3b5072..5cf529c 100644
--- a/be/src/vec/aggregate_functions/aggregate_function.h
+++ b/be/src/vec/aggregate_functions/aggregate_function.h
@@ -109,11 +109,6 @@ public:
       */
     virtual bool is_state() const { return false; }
 
-    /// if return false, during insert_result_into function, you colud get 
nullable result column,
-    /// so could insert to null value by yourself, rather than by 
AggregateFunctionNullBase;
-    /// because you maybe be calculate a invalid value, but want to use null 
replace it;
-    virtual bool insert_to_null_default() const { return true; }
-
     /** Contains a loop with calls to "add" function. You can collect 
arguments into array "places"
       *  and do a single call to "add_batch" for devirtualization and inlining.
       */
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h 
b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 83cae6f..55e9100 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -144,14 +144,9 @@ public:
         if constexpr (result_is_nullable) {
             ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
             if (get_flag(place)) {
-                if (nested_function->insert_to_null_default()) {
-                    nested_function->insert_result_into(nested_place(place),
-                                                        
to_concrete.get_nested_column());
-                    to_concrete.get_null_map_data().push_back(0);
-                } else {
-                    nested_function->insert_result_into(
-                            nested_place(place), to); //want to insert into 
null value by self
-                }
+                nested_function->insert_result_into(nested_place(place),
+                                                    
to_concrete.get_nested_column());
+                to_concrete.get_null_map_data().push_back(0);
             } else {
                 to_concrete.insert_default();
             }
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index d578eef..3be7d18 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -37,7 +37,8 @@ void 
register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
 void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& 
factory);
 void 
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& 
factory);
-void 
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& 
factory);
+void 
register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& 
factory);
+void 
register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory&
 factory);
 void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
 void 
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
 factory);
 void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& 
factory);
@@ -56,7 +57,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_combinator_distinct(instance);
         register_aggregate_function_reader(instance); // register aggregate 
function for agg reader
         register_aggregate_function_window_rank(instance);
-        register_aggregate_function_stddev_variance(instance);
+        register_aggregate_function_stddev_variance_pop(instance);
         register_aggregate_function_topn(instance);
         register_aggregate_function_approx_count_distinct(instance);
         register_aggregate_function_group_concat(instance);
@@ -65,6 +66,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         // if you only register function with no nullable, and wants to add 
nullable automatically, you should place function above this line
         register_aggregate_function_combinator_null(instance);
 
+        register_aggregate_function_stddev_variance_samp(instance);
         register_aggregate_function_reader_no_spread(instance);
         register_aggregate_function_window_lead_lag(instance);
         register_aggregate_function_HLL_union_agg(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
index 2b06423..31f4556 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
@@ -23,8 +23,9 @@
 #include "vec/aggregate_functions/helpers.h"
 namespace doris::vectorized {
 
-template <template <typename> class AggregateFunctionTemplate, template 
<typename> class NameData,
-          template <typename, typename> class Data, bool is_stddev>
+template <template <typename, bool> class AggregateFunctionTemplate,
+          template <typename> class NameData, template <typename, typename> 
class Data,
+          bool is_stddev, bool is_nullable = false>
 static IAggregateFunction* create_function_single_value(const String& name,
                                                         const DataTypes& 
argument_types,
                                                         const Array& 
parameters) {
@@ -32,40 +33,42 @@ static IAggregateFunction* 
create_function_single_value(const String& name,
     if (type->is_nullable()) {
         type = assert_cast<const 
DataTypeNullable*>(type)->get_nested_type().get();
     }
+
     WhichDataType which(*type);
 
-#define DISPATCH(TYPE)                                                         
                \
-    if (which.idx == TypeIndex::TYPE)                                          
                \
-        return new AggregateFunctionTemplate<NameData<Data<TYPE, 
BaseData<TYPE, is_stddev>>>>( \
-                argument_types);
+#define DISPATCH(TYPE)                                                         
               \
+    if (which.idx == TypeIndex::TYPE)                                          
               \
+        return new AggregateFunctionTemplate<NameData<Data<TYPE, 
BaseData<TYPE, is_stddev>>>, \
+                                             is_nullable>(argument_types);
+
     FOR_NUMERIC_TYPES(DISPATCH)
 #undef DISPATCH
     if (which.is_decimal()) {
-        return new AggregateFunctionTemplate<
-                NameData<Data<Decimal128, 
BaseDatadecimal<is_stddev>>>>(argument_types);
+        return new AggregateFunctionTemplate<NameData<Data<Decimal128, 
BaseDatadecimal<is_stddev>>>,
+                                             is_nullable>(argument_types);
     }
     DCHECK(false) << "with unknowed type, failed in  
create_aggregate_function_stddev_variance";
     return nullptr;
 }
 
-template <bool is_stddev>
+template <bool is_stddev, bool is_nullable>
 AggregateFunctionPtr create_aggregate_function_variance_samp(const 
std::string& name,
                                                              const DataTypes& 
argument_types,
                                                              const Array& 
parameters,
                                                              const bool 
result_is_nullable) {
     return AggregateFunctionPtr(
-            create_function_single_value<AggregateFunctionStddevSamp, 
VarianceSampData, SampData,
-                                         is_stddev>(name, argument_types, 
parameters));
+            create_function_single_value<AggregateFunctionSamp, 
VarianceSampName, SampData,
+                                         is_stddev, is_nullable>(name, 
argument_types, parameters));
 }
 
-template <bool is_stddev>
+template <bool is_stddev, bool is_nullable>
 AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string& 
name,
                                                            const DataTypes& 
argument_types,
                                                            const Array& 
parameters,
                                                            const bool 
result_is_nullable) {
     return AggregateFunctionPtr(
-            create_function_single_value<AggregateFunctionStddevSamp, 
StddevSampData, SampData,
-                                         is_stddev>(name, argument_types, 
parameters));
+            create_function_single_value<AggregateFunctionSamp, 
StddevSampName, SampData, is_stddev,
+                                         is_nullable>(name, argument_types, 
parameters));
 }
 
 template <bool is_stddev>
@@ -74,8 +77,8 @@ AggregateFunctionPtr 
create_aggregate_function_variance_pop(const std::string& n
                                                             const Array& 
parameters,
                                                             const bool 
result_is_nullable) {
     return AggregateFunctionPtr(
-            create_function_single_value<AggregateFunctionStddevSamp, 
VarianceData, PopData,
-                                         is_stddev>(name, argument_types, 
parameters));
+            create_function_single_value<AggregateFunctionPop, VarianceName, 
PopData, is_stddev>(
+                    name, argument_types, parameters));
 }
 
 template <bool is_stddev>
@@ -84,21 +87,29 @@ AggregateFunctionPtr 
create_aggregate_function_stddev_pop(const std::string& nam
                                                           const Array& 
parameters,
                                                           const bool 
result_is_nullable) {
     return AggregateFunctionPtr(
-            create_function_single_value<AggregateFunctionStddevSamp, 
StddevData, PopData,
-                                         is_stddev>(name, argument_types, 
parameters));
+            create_function_single_value<AggregateFunctionPop, StddevName, 
PopData, is_stddev>(
+                    name, argument_types, parameters));
 }
 
-void 
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& 
factory) {
-    factory.register_function("variance_samp", 
create_aggregate_function_variance_samp<false>);
-    factory.register_function("variance_samp", 
create_aggregate_function_variance_samp<false>, true);
-    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true>);
-    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true>, true);
-    factory.register_alias("variance_samp", "var_samp");
-
+void 
register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function("variance", 
create_aggregate_function_variance_pop<false>);
     factory.register_alias("variance", "var_pop");
     factory.register_alias("variance", "variance_pop");
     factory.register_function("stddev", 
create_aggregate_function_stddev_pop<true>);
     factory.register_alias("stddev", "stddev_pop");
 }
+
+void 
register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory&
 factory) {
+    // _samp<bool, bool>: the first flag selects stddev (true) or variance (false);
+    //                    the second flag tells whether the argument column is nullable
+    factory.register_function("variance_samp",
+                              create_aggregate_function_variance_samp<false, 
false>, false);
+    factory.register_function("variance_samp", 
create_aggregate_function_variance_samp<false, true>,
+                              true);
+    factory.register_alias("variance_samp", "var_samp");
+    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true, false>,
+                              false);
+    factory.register_function("stddev_samp", 
create_aggregate_function_stddev_samp<true, true>,
+                              true);
+}
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h 
b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
index 83c4041..8821232 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
@@ -69,10 +69,6 @@ struct BaseData {
         return get_result(res);
     }
 
-    static const DataTypePtr get_return_type() {
-        return std::make_shared<DataTypeNumber<Float64>>();
-    }
-
     void merge(const BaseData& rhs) {
         if (rhs.count == 0) {
             return;
@@ -84,8 +80,8 @@ struct BaseData {
         count = sum_count;
     }
 
-    virtual void add(const IColumn** columns, size_t row_num) {
-        const auto& sources = static_cast<const ColumnVector<T>&>(*columns[0]);
+    void add(const IColumn* column, size_t row_num) {
+        const auto& sources = static_cast<const ColumnVector<T>&>(*column);
         double source_data = sources.get_data()[row_num];
 
         double delta = source_data - mean;
@@ -146,10 +142,6 @@ struct BaseDatadecimal {
         return get_result(res);
     }
 
-    static const DataTypePtr get_return_type() {
-        return std::make_shared<DataTypeDecimal<Decimal128>>(27, 9);
-    }
-
     void merge(const BaseDatadecimal& rhs) {
         if (rhs.count == 0) {
             return;
@@ -166,9 +158,9 @@ struct BaseDatadecimal {
         count += rhs.count;
     }
 
-    virtual void add(const IColumn** columns, size_t row_num) {
+    void add(const IColumn* column, size_t row_num) {
         DecimalV2Value source_data = DecimalV2Value();
-        const auto& sources = static_cast<const 
ColumnDecimal<Decimal128>&>(*columns[0]);
+        const auto& sources = static_cast<const 
ColumnDecimal<Decimal128>&>(*column);
         source_data = (DecimalV2Value)sources.get_data()[row_num];
 
         DecimalV2Value new_count = DecimalV2Value();
@@ -202,13 +194,33 @@ struct PopData : Data {
     }
 };
 
+template <typename Data>
+struct StddevName : Data {
+    static const char* name() { return "stddev"; }
+};
+
+template <typename Data>
+struct VarianceName : Data {
+    static const char* name() { return "variance"; }
+};
+
+template <typename Data>
+struct VarianceSampName : Data {
+    static const char* name() { return "variance_samp"; }
+};
+
+template <typename Data>
+struct StddevSampName : Data {
+    static const char* name() { return "stddev_samp"; }
+};
+
 template <typename T, typename Data>
 struct SampData : Data {
     using ColVecResult = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<Decimal128>,
                                             ColumnVector<Float64>>;
     void insert_result_into(IColumn& to) const {
         ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
-        if (this->count == 1) {
+        if (this->count == 1 || this->count == 0) {
             nullable_column.insert_default();
         } else {
             auto& col = 
static_cast<ColVecResult&>(nullable_column.get_nested_column());
@@ -220,61 +232,40 @@ struct SampData : Data {
             nullable_column.get_null_map_data().push_back(0);
         }
     }
-
-    static const DataTypePtr get_return_type() {
-        return make_nullable(Data::get_return_type());
-    }
-
-    void add(const IColumn** columns, size_t row_num) override {
-        if (columns[0]->is_nullable()) {
-            const auto& nullable_column = assert_cast<const 
ColumnNullable&>(*columns[0]);
-            if (!nullable_column.is_null_at(row_num)) {
-                const IColumn* new_columns[1];
-                new_columns[0] = &nullable_column.get_nested_column();
-                Data::add(new_columns, row_num);
-            }
-        } else {
-            Data::add(columns, row_num);
-        }
-    }
-
-};
-
-template <typename Data>
-struct StddevData : Data {
-    static const char* name() { return "stddev"; }
-};
-
-template <typename Data>
-struct VarianceData : Data {
-    static const char* name() { return "variance"; }
-};
-
-template <typename Data>
-struct VarianceSampData : Data {
-    static const char* name() { return "variance_samp"; }
-};
-
-template <typename Data>
-struct StddevSampData : Data {
-    static const char* name() { return "stddev_samp"; }
 };
 
-template <typename Data>
-class AggregateFunctionStddevSamp final
-        : public IAggregateFunctionDataHelper<Data, 
AggregateFunctionStddevSamp<Data>> {
+template <bool is_pop, typename Data, bool is_nullable>
+class AggregateFunctionSampVariance
+        : public IAggregateFunctionDataHelper<Data, 
AggregateFunctionSampVariance<is_pop, Data, is_nullable>> {
 public:
-    AggregateFunctionStddevSamp(const DataTypes& argument_types_)
-            : IAggregateFunctionDataHelper<Data, 
AggregateFunctionStddevSamp<Data>>(argument_types_,
-                                                                               
     {}) {}
+    AggregateFunctionSampVariance(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, 
AggregateFunctionSampVariance<is_pop, Data, is_nullable>>(
+                      argument_types_, {}) {}
 
     String get_name() const override { return Data::name(); }
 
-    DataTypePtr get_return_type() const override { return 
Data::get_return_type(); }
+    DataTypePtr get_return_type() const override {
+        if constexpr (is_pop) {
+            return std::make_shared<DataTypeFloat64>();
+        } else {
+            return make_nullable(std::make_shared<DataTypeFloat64>());
+        }
+    }
 
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
              Arena*) const override {
-        this->data(place).add(columns, row_num);
+        if constexpr (is_pop) {
+            this->data(place).add(columns[0], row_num);
+        } else {
+            if constexpr (is_nullable) {
+                const auto* nullable_column = 
check_and_get_column<ColumnNullable>(columns[0]);
+                if (!nullable_column->is_null_at(row_num)) {
+                    
this->data(place).add(&nullable_column->get_nested_column(), row_num);
+                }
+            } else {
+                this->data(place).add(columns[0], row_num);
+            }
+        }
     }
 
     void reset(AggregateDataPtr __restrict place) const override { 
this->data(place).reset(); }
@@ -298,4 +289,21 @@ public:
     }
 };
 
+// The samp functions always return a nullable result and must handle nullable argument
+// columns themselves, so both the return type and add() have to deal with null values.
+template <typename Data, bool is_nullable>
+class AggregateFunctionSamp final: public AggregateFunctionSampVariance<false, 
Data, is_nullable> {
+public:
+    AggregateFunctionSamp(const DataTypes& argument_types_)
+            : AggregateFunctionSampVariance<false, Data, 
is_nullable>(argument_types_) {}
+};
+
+// The pop functions rely on AggregateFunctionNullBase for null handling, so they need not 
process null values themselves.
+template <typename Data, bool is_nullable>
+class AggregateFunctionPop final: public AggregateFunctionSampVariance<true, 
Data, is_nullable> {
+public:
+    AggregateFunctionPop(const DataTypes& argument_types_)
+            : AggregateFunctionSampVariance<true, Data, 
is_nullable>(argument_types_) {}
+};
+
 } // namespace doris::vectorized

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to