This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch dev-1.0.1 in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
commit 9e6a213039cdb9ab2b3251c6a865d87d167f8b6e Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Tue Mar 29 14:47:39 2022 +0800 [Vectorized][Bug] fix percentile_approx function to return always nullable (#8572) --- .../aggregate_function_percentile_approx.cpp | 21 +++-- .../aggregate_function_percentile_approx.h | 96 ++++++++++++++++++---- .../aggregate_function_simple_factory.cpp | 5 +- .../apache/doris/catalog/AggregateFunction.java | 2 +- 4 files changed, 100 insertions(+), 24 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp index 976565f..0a5ffda 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp @@ -24,17 +24,20 @@ namespace doris::vectorized { +template <bool is_nullable> AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { - if (argument_types.size() == 1) { - return std::make_shared<AggregateFunctionPercentileApproxMerge>(argument_types); + return std::make_shared<AggregateFunctionPercentileApproxMerge<is_nullable>>( + argument_types); } else if (argument_types.size() == 2) { - return std::make_shared<AggregateFunctionPercentileApproxTwoParams>(argument_types); + return std::make_shared<AggregateFunctionPercentileApproxTwoParams<is_nullable>>( + argument_types); } else if (argument_types.size() == 3) { - return std::make_shared<AggregateFunctionPercentileApproxThreeParams>(argument_types); + return std::make_shared<AggregateFunctionPercentileApproxThreeParams<is_nullable>>( + argument_types); } LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}", argument_types.size(), name); @@ -50,8 +53,14 @@ AggregateFunctionPtr create_aggregate_function_percentile(const std::string& nam return std::make_shared<AggregateFunctionPercentile>(argument_types); } -void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { +void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function("percentile", create_aggregate_function_percentile); - factory.register_function("percentile_approx", create_aggregate_function_percentile_approx); +} + +void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { + factory.register_function("percentile_approx", + create_aggregate_function_percentile_approx<false>, false); + factory.register_function("percentile_approx", + create_aggregate_function_percentile_approx<true>, true); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h index f7b620b..3e5576b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h @@ -42,8 +42,11 @@ struct PercentileApproxState { void write(BufferWritable& buf) const { write_binary(init_flag, buf); - write_binary(target_quantile, buf); + if (!init_flag) { + return; + } + write_binary(target_quantile, buf); uint32_t serialize_size = digest->serialized_size(); std::string result(serialize_size, '0'); DCHECK(digest.get() != nullptr); @@ -54,17 +57,29 @@ struct PercentileApproxState { void read(BufferReadable& buf) { read_binary(init_flag, buf); - read_binary(target_quantile, buf); + if (!init_flag) { + return; + } + read_binary(target_quantile, buf); std::string str; read_binary(str, buf); digest.reset(new TDigest()); digest->unserialize((uint8_t*)str.c_str()); } - double get() const { return digest->quantile(target_quantile); } + double get() const { + if (init_flag) { + return digest->quantile(target_quantile); + } else { + return std::nan(""); + } + } void merge(const PercentileApproxState& rhs) { + if (!rhs.init_flag) { + return; + } if (init_flag) { DCHECK(digest.get() != nullptr); digest->merge(rhs.digest.get()); @@ -90,7 +105,7 @@ struct PercentileApproxState { } bool init_flag = false; - std::unique_ptr<TDigest> digest; + std::unique_ptr<TDigest> digest = nullptr; double target_quantile = INIT_QUANTILE; }; @@ -105,8 +120,6 @@ public: String get_name() const override { return "percentile_approx"; } - bool insert_to_null_default() const override { return false; } - DataTypePtr get_return_type() const override { return make_nullable(std::make_shared<DataTypeFloat64>()); } @@ -142,6 +155,7 @@ public: }; // only for merge +template <bool is_nullable> class AggregateFunctionPercentileApproxMerge : public AggregateFunctionPercentileApprox { public: AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_) @@ -152,32 +166,84 @@ public: } }; +template <bool is_nullable> class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPercentileApprox { public: AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_) : AggregateFunctionPercentileApprox(argument_types_) {} void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); - const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); + if constexpr (is_nullable) { + double column_data[2] = {0, 0}; + + for (int i = 0; i < 2; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = static_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + + this->data(place).init(); + this->data(place).add(column_data[0], column_data[1]); + + } else { + const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); - this->data(place).init(); - this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + this->data(place).init(); + this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + } } }; +template <bool is_nullable> class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPercentileApprox { public: AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_) : AggregateFunctionPercentileApprox(argument_types_) {} void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); - const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); - const auto& compression = static_cast<const ColumnVector<Float64>&>(*columns[2]); + if constexpr (is_nullable) { + double column_data[3] = {0, 0, 0}; + + for (int i = 0; i < 3; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = static_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = static_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + + this->data(place).init(column_data[2]); + this->data(place).add(column_data[0], column_data[1]); - this->data(place).init(compression.get_float64(row_num)); - this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + } else { + const auto& sources = static_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = static_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& compression = static_cast<const ColumnVector<Float64>&>(*columns[2]); + + this->data(place).init(compression.get_float64(row_num)); + this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + } } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index c153d32..d578eef 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -41,7 +41,7 @@ void register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory& void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory); - +void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; @@ -60,7 +60,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_topn(instance); register_aggregate_function_approx_count_distinct(instance); register_aggregate_function_group_concat(instance); - register_aggregate_function_percentile_approx(instance); + register_aggregate_function_percentile(instance); // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line register_aggregate_function_combinator_null(instance); @@ -68,6 +68,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_reader_no_spread(instance); register_aggregate_function_window_lead_lag(instance); register_aggregate_function_HLL_union_agg(instance); + register_aggregate_function_percentile_approx(instance); }); return instance; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index da8dc10..fba3617 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -52,7 +52,7 @@ public class AggregateFunction extends Function { ImmutableSet.of("row_number", "rank", "dense_rank", "hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", FunctionSet.COUNT, "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize"); public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = - ImmutableSet.of("stddev_samp", "variance_samp", "var_samp"); + ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx"); // Set if different from retType_, null otherwise. private Type intermediateType; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org