This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push: new 4a31fc4e096 [Bug](fix) fix the percentile func result do not equal the percentile array rewrite result (#49379) 4a31fc4e096 is described below commit 4a31fc4e09611322494548b50bd0dca5a6ade467 Author: HappenLee <happen...@selectdb.com> AuthorDate: Sat Mar 29 08:56:24 2025 +0800 [Bug](fix) fix the percentile func result do not equal the percentile array rewrite result (#49379) cherry pick https://github.com/apache/doris/pull/49351 --- be/src/util/counts.h | 29 +++++---- .../aggregate_function_percentile_approx.cpp | 7 ++- .../aggregate_function_percentile_approx.h | 28 +++++---- be/test/util/counts_test.cpp | 6 +- .../doris/catalog/BuiltinAggregateFunctions.java | 2 +- .../java/org/apache/doris/catalog/FunctionSet.java | 9 +++ .../expressions/functions/agg/Percentile.java | 13 +++- .../expressions/functions/agg/PercentileArray.java | 20 ++++++- .../nereids/rules/analysis/GenerateFunction.java | 1 + .../test_aggregate_window_functions.out | Bin 21064 -> 21094 bytes .../data/query_p0/aggregate/aggregate.out | Bin 9762 -> 9841 bytes .../mv_p0/mv_percentile/mv_percentile.groovy | 66 --------------------- .../suites/query_p0/aggregate/aggregate.groovy | 40 +++++++++++++ 13 files changed, 118 insertions(+), 103 deletions(-) diff --git a/be/src/util/counts.h b/be/src/util/counts.h index fec18cedcd6..5dcd14f310e 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -26,6 +26,7 @@ namespace doris { +template <typename T> class Counts { public: Counts() = default; @@ -40,7 +41,7 @@ public: } } - void increment(int64_t key, uint32_t i) { + void increment(T key, uint32_t i) { auto item = _counts.find(key); if (item != _counts.end()) { item->second += i; @@ -50,8 +51,7 @@ public: } uint32_t serialized_size() const { - return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() + - sizeof(uint32_t) * _counts.size(); + return sizeof(uint32_t) + sizeof(T) * _counts.size() + sizeof(uint32_t) * _counts.size(); } void serialize(uint8_t* writer) const { @@ -59,8 +59,8 @@ public: memcpy(writer, &size, sizeof(uint32_t)); writer += sizeof(uint32_t); for (auto& cell : _counts) { - memcpy(writer, &cell.first, sizeof(int64_t)); - writer += sizeof(int64_t); + memcpy(writer, &cell.first, sizeof(T)); + writer += sizeof(T); memcpy(writer, &cell.second, sizeof(uint32_t)); writer += sizeof(uint32_t); } @@ -71,18 +71,17 @@ public: memcpy(&size, type_reader, sizeof(uint32_t)); type_reader += sizeof(uint32_t); for (uint32_t i = 0; i < size; ++i) { - int64_t key; + T key; uint32_t count; - memcpy(&key, type_reader, sizeof(int64_t)); - type_reader += sizeof(int64_t); + memcpy(&key, type_reader, sizeof(T)); + type_reader += sizeof(T); memcpy(&count, type_reader, sizeof(uint32_t)); type_reader += sizeof(uint32_t); _counts.emplace(std::make_pair(key, count)); } } - double get_percentile(std::vector<std::pair<int64_t, uint32_t>>& counts, - double position) const { + double get_percentile(std::vector<std::pair<T, uint32_t>>& counts, double position) const { long lower = long(std::floor(position)); long higher = long(std::ceil(position)); @@ -90,7 +89,7 @@ public: for (; iter != counts.end() && iter->second < lower + 1; ++iter) ; - int64_t lower_key = iter->first; + T lower_key = iter->first; if (higher == lower) { return lower_key; } @@ -99,7 +98,7 @@ public: iter++; } - int64_t higher_key = iter->first; + T higher_key = iter->first; if (lower_key == higher_key) { return lower_key; } @@ -114,9 +113,9 @@ public: return 0.0; } - std::vector<std::pair<int64_t, uint32_t>> elems(_counts.begin(), _counts.end()); + std::vector<std::pair<T, uint32_t>> elems(_counts.begin(), _counts.end()); sort(elems.begin(), elems.end(), - [](const std::pair<int64_t, uint32_t> l, const std::pair<int64_t, uint32_t> r) { + [](const std::pair<T, uint32_t> l, const std::pair<T, uint32_t> r) { return l.first < r.first; }); @@ -132,7 +131,7 @@ public: } private: - std::unordered_map<int64_t, uint32_t> _counts; + std::unordered_map<T, uint32_t> _counts; }; } // namespace doris diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp index 05e36a8f72b..4cbe8a06900 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp @@ -50,9 +50,10 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx( void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("percentile", - creator_without_type::creator<AggregateFunctionPercentile>); - factory.register_function_both("percentile_array", - creator_without_type::creator<AggregateFunctionPercentileArray>); + creator_with_numeric_type::creator<AggregateFunctionPercentile>); + factory.register_function_both( + "percentile_array", + creator_with_numeric_type::creator<AggregateFunctionPercentileArray>); } void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h index 0f24aef6dbb..87e55604660 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h @@ -288,8 +288,9 @@ public: } }; +template <typename T> struct PercentileState { - std::vector<Counts> vec_counts; + std::vector<Counts<T>> vec_counts; std::vector<double> vec_quantile {-1}; bool inited_flag = false; @@ -327,7 +328,7 @@ struct PercentileState { } } - void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int arg_size) { + void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) { if (!inited_flag) { vec_counts.resize(arg_size); vec_quantile.resize(arg_size, -1); @@ -376,11 +377,12 @@ struct PercentileState { } }; +template <typename T> class AggregateFunctionPercentile final - : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile> { + : public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>> { public: AggregateFunctionPercentile(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile>( + : IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>( argument_types_) {} String get_name() const override { return "percentile"; } @@ -389,10 +391,10 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]); const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); - AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(), - 1); + AggregateFunctionPercentile::data(place).add(sources.get_element(row_num), + quantile.get_data(), 1); } void reset(AggregateDataPtr __restrict place) const override { @@ -419,11 +421,13 @@ public: } }; +template <typename T> class AggregateFunctionPercentileArray final - : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> { + : public IAggregateFunctionDataHelper<PercentileState<T>, + AggregateFunctionPercentileArray<T>> { public: AggregateFunctionPercentileArray(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>( + : IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>( argument_types_) {} String get_name() const override { return "percentile_array"; } @@ -434,7 +438,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& sources = assert_cast<const ColumnVector<T>&>(*columns[0]); const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); const auto& nested_column = @@ -442,7 +446,7 @@ public: const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column); AggregateFunctionPercentileArray::data(place).add( - sources.get_int(row_num), nested_column_data.get_data(), + sources.get_element(row_num), nested_column_data.get_data(), offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]); } @@ -480,4 +484,4 @@ public: } }; -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/test/util/counts_test.cpp b/be/test/util/counts_test.cpp index 908bbcefd58..42370f8057d 100644 --- a/be/test/util/counts_test.cpp +++ b/be/test/util/counts_test.cpp @@ -27,7 +27,7 @@ namespace doris { class TCountsTest : public testing::Test {}; TEST_F(TCountsTest, TotalTest) { - Counts counts; + Counts<int64_t> counts; // 1 1 1 2 5 7 7 9 9 19 // >>> import numpy as np // >>> a = np.array([1,1,1,2,5,7,7,9,9,19]) @@ -46,12 +46,12 @@ TEST_F(TCountsTest, TotalTest) { uint8_t* type_reader = writer; counts.serialize(writer); - Counts other; + Counts<int64_t> other; other.unserialize(type_reader); double result1 = other.terminate(0.2); EXPECT_EQ(result, result1); - Counts other1; + Counts<int64_t> other1; other1.increment(1, 1); other1.increment(100, 3); other1.increment(50, 3); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index d8f95b2f6bc..09acb21f47f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -131,7 +131,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"), agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"), agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), - agg(Percentile.class, "percentile"), + agg(Percentile.class, "percentile", "percentile_cont"), agg(PercentileApprox.class, "percentile_approx"), agg(PercentileArray.class, "percentile_array"), agg(QuantileUnion.class, "quantile_union"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 2db943993dd..d943ad4f6ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1430,6 +1430,15 @@ public class FunctionSet<T> { "", false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("percentile_cont", + Lists.newArrayList(Type.BIGINT, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, + "", + "", + "", + "", + "", + false, true, false, true)); + addBuiltin(AggregateFunction.createBuiltin("percentile_approx", Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, "", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java index 31ab925ca6c..fd3ba4890d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java @@ -25,6 +25,11 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -38,7 +43,13 @@ public class Percentile extends NullableAggregateFunction implements BinaryExpression, ExplicitlyCastableSignature { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( - FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java index d4d8ed6c39a..c97b617e616 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java @@ -26,6 +26,11 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -40,8 +45,19 @@ public class PercentileArray extends AggregateFunction public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) - .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)) - ); + .args(DoubleType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(FloatType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(LargeIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(IntegerType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(SmallIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(TinyIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE))); /** * constructor with 2 arguments. diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java index 2d010df4c51..105ab00f395 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/GenerateFunction.java @@ -182,6 +182,7 @@ public class GenerateFunction { .put("any", "any_value") .put("char_length", "character_length") .put("stddev_pop", "stddev") + .put("percentile_cont", "percentile") .put("var_pop", "variance") .put("variance_pop", "variance") .put("var_samp", "variance_samp") diff --git a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out index 6729ea26bc1..03569f1aedf 100644 Binary files a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out and b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_window_functions.out differ diff --git a/regression-test/data/query_p0/aggregate/aggregate.out b/regression-test/data/query_p0/aggregate/aggregate.out index ffd37904994..f17c690ec49 100644 Binary files a/regression-test/data/query_p0/aggregate/aggregate.out and b/regression-test/data/query_p0/aggregate/aggregate.out differ diff --git a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy deleted file mode 100644 index dd6cb453305..00000000000 --- a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -import org.codehaus.groovy.runtime.IOGroovyMethods - -suite ("mv_percentile") { - sql "set enable_fallback_to_original_planner = false" - - sql """DROP TABLE IF EXISTS d_table;""" - - sql """ - create table d_table( - k1 int null, - k2 int not null, - k3 decimal(28,6) null, - k4 varchar(100) null - ) - duplicate key (k1,k2,k3) - distributed BY hash(k1) buckets 3 - properties("replication_num" = "1"); - """ - - sql "insert into d_table select 1,1,1,'a';" - sql "insert into d_table select 2,2,2,'b';" - sql "insert into d_table select 3,-3,null,'c';" - - createMV("create materialized view kp as select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2;") - - sql "insert into d_table select -4,-4,-4,'d';" - sql "insert into d_table(k4,k2) values('d',4);" - - qt_select_star "select * from d_table order by k1;" - - explain { - sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;") - contains "(kp)" - } - qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by k1,k2 order by k1,k2;" - - explain { - sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2;") - contains "(kp)" - } - qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table group by grouping sets((k1),(k1,k2),()) order by 1,2,3;" - - - explain { - sql("select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;") - contains "(kp)" - } - qt_select_mv "select percentile(k3, 0.1) from d_table group by grouping sets((k1),()) order by 1;" -} diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy index b611ff92b0e..9836b8a2f57 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate.groovy +++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy @@ -308,4 +308,44 @@ suite("aggregate") { qt_aggregate_limit_contain_null """ select count(), cast(k12 as int) as t from baseall group by t limit 1; """ + + // Test case for percentile function with sales data + sql """ DROP TABLE IF EXISTS sales_data """ + sql """ + CREATE TABLE sales_data ( + product_id INT, + sale_price DECIMAL(10, 2) + ) DUPLICATE KEY(`product_id`) + DISTRIBUTED BY HASH(`product_id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ) + """ + + sql """ + INSERT INTO sales_data VALUES + (1, 10.00), + (1, 15.00), + (1, 20.00), + (1, 25.00), + (1, 30.25), + (1, 35.00), + (1, 40.00), + (1, 45.00), + (1, 50.00), + (1, 100.00) + """ + + qt_aggregate35 """ + SELECT + percentile(sale_price, 0.05) as median_price_05, + percentile(sale_price, 0.5) as median_price, + percentile(sale_price, 0.75) as p75_price, + percentile(sale_price, 0.90) as p90_price, + percentile(sale_price, 0.95) as p95_price, + percentile(null, 0.99) as p99_null + FROM sales_data + """ + + sql """ DROP TABLE IF EXISTS sales_data """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org