This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git

commit 9e6a213039cdb9ab2b3251c6a865d87d167f8b6e
Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com>
AuthorDate: Tue Mar 29 14:47:39 2022 +0800

    [Vectorized][Bug] fix percentile_approx function to always return nullable (#8572)
---
 .../aggregate_function_percentile_approx.cpp       | 21 +++--
 .../aggregate_function_percentile_approx.h         | 96 ++++++++++++++++++----
 .../aggregate_function_simple_factory.cpp          |  5 +-
 .../apache/doris/catalog/AggregateFunction.java    |  2 +-
 4 files changed, 100 insertions(+), 24 deletions(-)

diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
index 976565f..0a5ffda 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
@@ -24,17 +24,20 @@
 
 namespace doris::vectorized {
 
+template <bool is_nullable>
 AggregateFunctionPtr create_aggregate_function_percentile_approx(const 
std::string& name,
                                                                  const 
DataTypes& argument_types,
                                                                  const Array& 
parameters,
                                                                  const bool 
result_is_nullable) {
-
     if (argument_types.size() == 1) {
-        return 
std::make_shared<AggregateFunctionPercentileApproxMerge>(argument_types);
+        return 
std::make_shared<AggregateFunctionPercentileApproxMerge<is_nullable>>(
+                argument_types);
     } else if (argument_types.size() == 2) {
-        return 
std::make_shared<AggregateFunctionPercentileApproxTwoParams>(argument_types);
+        return 
std::make_shared<AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
+                argument_types);
     } else if (argument_types.size() == 3) {
-        return 
std::make_shared<AggregateFunctionPercentileApproxThreeParams>(argument_types);
+        return 
std::make_shared<AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
+                argument_types);
     }
     LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate 
function {}",
                                 argument_types.size(), name);
@@ -50,8 +53,14 @@ AggregateFunctionPtr 
create_aggregate_function_percentile(const std::string& nam
     return std::make_shared<AggregateFunctionPercentile>(argument_types);
 }
 
-void 
register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& 
factory) {
+void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& 
factory) {
     factory.register_function("percentile", 
create_aggregate_function_percentile);
-    factory.register_function("percentile_approx", 
create_aggregate_function_percentile_approx);
+}
+
+void 
register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& 
factory) {
+    factory.register_function("percentile_approx",
+                              
create_aggregate_function_percentile_approx<false>, false);
+    factory.register_function("percentile_approx",
+                              
create_aggregate_function_percentile_approx<true>, true);
 }
 } // namespace doris::vectorized
\ No newline at end of file
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h 
b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
index f7b620b..3e5576b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
@@ -42,8 +42,11 @@ struct PercentileApproxState {
 
     void write(BufferWritable& buf) const {
         write_binary(init_flag, buf);
-        write_binary(target_quantile, buf);
+        if (!init_flag) {
+            return;
+        }
 
+        write_binary(target_quantile, buf);
         uint32_t serialize_size = digest->serialized_size();
         std::string result(serialize_size, '0');
         DCHECK(digest.get() != nullptr);
@@ -54,17 +57,29 @@ struct PercentileApproxState {
 
     void read(BufferReadable& buf) {
         read_binary(init_flag, buf);
-        read_binary(target_quantile, buf);
+        if (!init_flag) {
+            return;
+        }
 
+        read_binary(target_quantile, buf);
         std::string str;
         read_binary(str, buf);
         digest.reset(new TDigest());
         digest->unserialize((uint8_t*)str.c_str());
     }
 
-    double get() const { return digest->quantile(target_quantile); }
+    double get() const {
+        if (init_flag) {
+            return digest->quantile(target_quantile);
+        } else {
+            return std::nan("");
+        }
+    }
 
     void merge(const PercentileApproxState& rhs) {
+        if (!rhs.init_flag) {
+            return;
+        }
         if (init_flag) {
             DCHECK(digest.get() != nullptr);
             digest->merge(rhs.digest.get());
@@ -90,7 +105,7 @@ struct PercentileApproxState {
     }
 
     bool init_flag = false;
-    std::unique_ptr<TDigest> digest;
+    std::unique_ptr<TDigest> digest = nullptr;
     double target_quantile = INIT_QUANTILE;
 };
 
@@ -105,8 +120,6 @@ public:
 
     String get_name() const override { return "percentile_approx"; }
 
-    bool insert_to_null_default() const override { return false; }
-
     DataTypePtr get_return_type() const override {
         return make_nullable(std::make_shared<DataTypeFloat64>());
     }
@@ -142,6 +155,7 @@ public:
 };
 
 // only for merge
+template <bool is_nullable>
 class AggregateFunctionPercentileApproxMerge : public 
AggregateFunctionPercentileApprox {
 public:
     AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_)
@@ -152,32 +166,84 @@ public:
     }
 };
 
+template <bool is_nullable>
 class AggregateFunctionPercentileApproxTwoParams : public 
AggregateFunctionPercentileApprox {
 public:
     AggregateFunctionPercentileApproxTwoParams(const DataTypes& 
argument_types_)
             : AggregateFunctionPercentileApprox(argument_types_) {}
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
              Arena*) const override {
-        const auto& sources = static_cast<const 
ColumnVector<Float64>&>(*columns[0]);
-        const auto& quantile = static_cast<const 
ColumnVector<Float64>&>(*columns[1]);
+        if constexpr (is_nullable) {
+            double column_data[2] = {0, 0};
+
+            for (int i = 0; i < 2; ++i) {
+                const auto* nullable_column = 
check_and_get_column<ColumnNullable>(columns[i]);
+                if (nullable_column == nullptr) { //Not Nullable column
+                    const auto& column = static_cast<const 
ColumnVector<Float64>&>(*columns[i]);
+                    column_data[i] = column.get_float64(row_num);
+                } else if (!nullable_column->is_null_at(
+                                   row_num)) { // Nullable column && Not null 
data
+                    const auto& column = static_cast<const 
ColumnVector<Float64>&>(
+                            nullable_column->get_nested_column());
+                    column_data[i] = column.get_float64(row_num);
+                } else { // Nullable column && null data
+                    if (i == 0) {
+                        return;
+                    }
+                }
+            }
+
+            this->data(place).init();
+            this->data(place).add(column_data[0], column_data[1]);
+
+        } else {
+            const auto& sources = static_cast<const 
ColumnVector<Float64>&>(*columns[0]);
+            const auto& quantile = static_cast<const 
ColumnVector<Float64>&>(*columns[1]);
 
-        this->data(place).init();
-        this->data(place).add(sources.get_float64(row_num), 
quantile.get_float64(row_num));
+            this->data(place).init();
+            this->data(place).add(sources.get_float64(row_num), 
quantile.get_float64(row_num));
+        }
     }
 };
 
+template <bool is_nullable>
 class AggregateFunctionPercentileApproxThreeParams : public 
AggregateFunctionPercentileApprox {
 public:
     AggregateFunctionPercentileApproxThreeParams(const DataTypes& 
argument_types_)
             : AggregateFunctionPercentileApprox(argument_types_) {}
     void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
              Arena*) const override {
-        const auto& sources = static_cast<const 
ColumnVector<Float64>&>(*columns[0]);
-        const auto& quantile = static_cast<const 
ColumnVector<Float64>&>(*columns[1]);
-        const auto& compression = static_cast<const 
ColumnVector<Float64>&>(*columns[2]);
+        if constexpr (is_nullable) {
+            double column_data[3] = {0, 0, 0};
+
+            for (int i = 0; i < 3; ++i) {
+                const auto* nullable_column = 
check_and_get_column<ColumnNullable>(columns[i]);
+                if (nullable_column == nullptr) { //Not Nullable column
+                    const auto& column = static_cast<const 
ColumnVector<Float64>&>(*columns[i]);
+                    column_data[i] = column.get_float64(row_num);
+                } else if (!nullable_column->is_null_at(
+                                   row_num)) { // Nullable column && Not null 
data
+                    const auto& column = static_cast<const 
ColumnVector<Float64>&>(
+                            nullable_column->get_nested_column());
+                    column_data[i] = column.get_float64(row_num);
+                } else { // Nullable column && null data
+                    if (i == 0) {
+                        return;
+                    }
+                }
+            }
+
+            this->data(place).init(column_data[2]);
+            this->data(place).add(column_data[0], column_data[1]);
 
-        this->data(place).init(compression.get_float64(row_num));
-        this->data(place).add(sources.get_float64(row_num), 
quantile.get_float64(row_num));
+        } else {
+            const auto& sources = static_cast<const 
ColumnVector<Float64>&>(*columns[0]);
+            const auto& quantile = static_cast<const 
ColumnVector<Float64>&>(*columns[1]);
+            const auto& compression = static_cast<const 
ColumnVector<Float64>&>(*columns[2]);
+
+            this->data(place).init(compression.get_float64(row_num));
+            this->data(place).add(sources.get_float64(row_num), 
quantile.get_float64(row_num));
+        }
     }
 };
 
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index c153d32..d578eef 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -41,7 +41,7 @@ void 
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory&
 void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
 void 
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
 factory);
 void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& 
factory);
-
+void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& 
factory);
 void 
register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& 
factory);
 AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
     static std::once_flag oc;
@@ -60,7 +60,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_topn(instance);
         register_aggregate_function_approx_count_distinct(instance);
         register_aggregate_function_group_concat(instance);
-        register_aggregate_function_percentile_approx(instance);
+        register_aggregate_function_percentile(instance);
 
         // if you only register function with no nullable, and wants to add 
nullable automatically, you should place function above this line
         register_aggregate_function_combinator_null(instance);
@@ -68,6 +68,7 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_reader_no_spread(instance);
         register_aggregate_function_window_lead_lag(instance);
         register_aggregate_function_HLL_union_agg(instance);
+        register_aggregate_function_percentile_approx(instance);
     });
     return instance;
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index da8dc10..fba3617 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -52,7 +52,7 @@ public class AggregateFunction extends Function {
             ImmutableSet.of("row_number", "rank", "dense_rank", 
"hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", 
FunctionSet.COUNT, "ndv", FunctionSet.BITMAP_UNION_INT, 
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize");
 
     public static ImmutableSet<String> 
ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
-            ImmutableSet.of("stddev_samp", "variance_samp", "var_samp");
+            ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", 
"percentile_approx");
 
     // Set if different from retType_, null otherwise.
     private Type intermediateType;

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to