This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 2c155a45803 branch-3.0: [Fix](bug) Percentile* func core when percent args is negative number #47068 (#47219) 2c155a45803 is described below commit 2c155a45803a1fde38df2ac76a7ad6e7fff19e88 Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> AuthorDate: Fri Feb 7 12:45:12 2025 +0800 branch-3.0: [Fix](bug) Percentile* func core when percent args is negative number #47068 (#47219) Cherry-picked from #47068 Co-authored-by: HappenLee <happen...@selectdb.com> --- .../aggregate_function_percentile.h | 56 +++++++++++++-------- .../aggregate_function_simple_factory.h | 9 ---- .../expressions/functions/agg/PercentileArray.java | 9 ++++ .../functions/combinator/ForEachCombinator.java | 29 +++++++++++ .../data/function_p0/test_agg_foreach.out | Bin 1945 -> 1865 bytes .../data/function_p0/test_agg_foreach_notnull.out | Bin 1945 -> 1865 bytes .../test_aggregate_all_functions.out | Bin 2675 -> 2765 bytes .../suites/function_p0/test_agg_foreach.groovy | 26 ++++++---- .../function_p0/test_agg_foreach_notnull.groovy | 30 ++++++----- .../suites/query_p0/aggregate/aggregate.groovy | 17 +++++++ .../test_aggregate_all_functions.groovy | 24 +++++++++ 11 files changed, 149 insertions(+), 51 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h b/be/src/vec/aggregate_functions/aggregate_function_percentile.h index 0766c59f3de..dbd52af923f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h @@ -51,12 +51,20 @@ namespace doris::vectorized { class Arena; class BufferReadable; +inline void check_quantile(double quantile) { + if (quantile < 0 || quantile > 1) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "quantile in func percentile should in [0, 1], but real data is:" + + std::to_string(quantile)); + } +} + struct PercentileApproxState { static constexpr double INIT_QUANTILE = -1.0; PercentileApproxState() = default; ~PercentileApproxState() = default; - void init(double compression = 10000) { + void init(double quantile, double compression = 10000) { if (!init_flag) { //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description //The compression parameter setting range is [2048, 10000]. @@ -66,6 +74,8 @@ struct PercentileApproxState { compression = 10000; } digest = TDigest::create_unique(compression); + check_quantile(quantile); + target_quantile = quantile; compressions = compression; init_flag = true; } @@ -126,18 +136,14 @@ struct PercentileApproxState { } } - void add(double source, double quantile) { - digest->add(source); - target_quantile = quantile; - } + void add(double source) { digest->add(source); } - void add_with_weight(double source, double weight, double quantile) { + void add_with_weight(double source, double weight) { // the weight should be positive num, as have check the value valid use DCHECK_GT(c._weight, 0); if (weight <= 0) { return; } digest->add(source, weight); - target_quantile = quantile; } void reset() { @@ -192,8 +198,8 @@ public: assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]); const auto& quantile = assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]); - this->data(place).init(); - this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0)); + this->data(place).add(sources.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); } @@ -223,8 +229,8 @@ public: const auto& compression = assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]); - this->data(place).init(compression.get_element(row_num)); - this->data(place).add(sources.get_element(row_num), quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0), compression.get_element(0)); + this->data(place).add(sources.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); } @@ -256,9 +262,9 @@ public: const auto& quantile = assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[2]); - this->data(place).init(); - this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num), - quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0)); + this->data(place).add_with_weight(sources.get_element(row_num), + weight.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); } @@ -291,9 +297,9 @@ public: const auto& compression = assert_cast<const ColumnVector<Float64>&, TypeCheckOnRelease::DISABLE>(*columns[3]); - this->data(place).init(compression.get_element(row_num)); - this->data(place).add_with_weight(sources.get_element(row_num), weight.get_element(row_num), - quantile.get_element(row_num)); + this->data(place).init(quantile.get_element(0), compression.get_element(0)); + this->data(place).add_with_weight(sources.get_element(row_num), + weight.get_element(row_num)); } DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); } @@ -351,12 +357,19 @@ struct PercentileState { } } - void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) { + void add(T source, const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps, + int arg_size) { if (!inited_flag) { vec_counts.resize(arg_size); vec_quantile.resize(arg_size, -1); inited_flag = true; for (int i = 0; i < arg_size; ++i) { + // throw Exception func call percentile_array(id, [1,0,null]) + if (null_maps[i]) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "quantiles in func percentile_array should not have null"); + } + check_quantile(quantiles[i]); vec_quantile[i] = quantiles[i]; } } @@ -429,7 +442,7 @@ public: const auto& quantile = assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]); AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num], - quantile.get_data(), 1); + quantile.get_data(), NullMap(1, 0), 1); } void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, @@ -490,6 +503,9 @@ public: const auto& quantile_array = assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); + const auto& null_maps = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>( + quantile_array.get_data()) + .get_null_map_data(); const auto& nested_column = assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>( quantile_array.get_data()) .get_nested_column(); @@ -497,7 +513,7 @@ public: assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(nested_column); AggregateFunctionPercentileArray::data(place).add( - sources.get_int(row_num), nested_column_data.get_data(), + sources.get_int(row_num), nested_column_data.get_data(), null_maps, offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index aa33e7289df..0a0ec6abe16 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -61,15 +61,6 @@ private: std::unordered_map<std::string, std::string> function_alias; public: - void register_nullable_function_combinator(const Creator& creator) { - for (const auto& entity : aggregate_functions) { - if (nullable_aggregate_functions.find(entity.first) == - nullable_aggregate_functions.end()) { - nullable_aggregate_functions[entity.first] = creator; - } - } - } - static bool is_foreach(const std::string& name) { constexpr std::string_view suffix = "_foreach"; if (name.length() < suffix.length()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java index 4412d96006f..49a0f836aed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; @@ -69,6 +70,14 @@ public class PercentileArray extends NotNullableAggregateFunction super("percentile_array", distinct, arg0, arg1); } + @Override + public void checkLegalityBeforeTypeCoercion() { + if (!getArgument(1).isConstant()) { + throw new AnalysisException( + "percentile_array requires second parameter must be a constant : " + this.toSql()); + } + } + /** * withDistinctAndChildren. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java index a6d011ff0fb..ddd92f894e1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java @@ -30,8 +30,11 @@ import org.apache.doris.nereids.types.DataType; import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; /** * combinator foreach @@ -39,6 +42,15 @@ import java.util.Objects; public class ForEachCombinator extends NullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, Combinator { + public static final Set<String> UNSUPPORTED_AGGREGATE_FUNCTION = Collections.unmodifiableSet(new HashSet<String>() { + { + add("percentile"); + add("percentile_array"); + add("percentile_approx"); + add("percentile_approx_weighted"); + } + }); + private final AggregateFunction nested; /** @@ -48,10 +60,27 @@ public class ForEachCombinator extends NullableAggregateFunction this(arguments, false, nested); } + /** + * Constructs a new instance of {@code ForEachCombinator}. + * + * <p>This constructor initializes a combinator that will iterate over each item in the input list + * and apply the nested aggregate function. + * If the provided aggregate function name is within the list of unsupported functions, + * an {@link UnsupportedOperationException} will be thrown. + * + * @param arguments A list of {@code Expression} objects that serve as parameters to the aggregate function. + * @param alwaysNullable A boolean flag indicating whether this combinator should always return a nullable result. + * @param nested The nested aggregate function to apply to each element. It must not be {@code null}. + * @throws NullPointerException If the provided nested aggregate function is {@code null}. + * @throws UnsupportedOperationException If nested aggregate function is one of the unsupported aggregate functions + */ public ForEachCombinator(List<Expression> arguments, boolean alwaysNullable, AggregateFunction nested) { super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, false, alwaysNullable, arguments); this.nested = Objects.requireNonNull(nested, "nested can not be null"); + if (UNSUPPORTED_AGGREGATE_FUNCTION.contains(nested.getName().toLowerCase())) { + throw new UnsupportedOperationException("Unsupport the func:" + nested.getName() + " use in foreach"); + } } public static ForEachCombinator create(AggregateFunction nested) { diff --git a/regression-test/data/function_p0/test_agg_foreach.out b/regression-test/data/function_p0/test_agg_foreach.out index c45ae9f67a9..693009d3890 100644 Binary files a/regression-test/data/function_p0/test_agg_foreach.out and b/regression-test/data/function_p0/test_agg_foreach.out differ diff --git a/regression-test/data/function_p0/test_agg_foreach_notnull.out b/regression-test/data/function_p0/test_agg_foreach_notnull.out index c45ae9f67a9..693009d3890 100644 Binary files a/regression-test/data/function_p0/test_agg_foreach_notnull.out and b/regression-test/data/function_p0/test_agg_foreach_notnull.out differ diff --git a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out index 75d9a18679f..90953b0a11c 100644 Binary files a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out and b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out differ diff --git a/regression-test/suites/function_p0/test_agg_foreach.groovy b/regression-test/suites/function_p0/test_agg_foreach.groovy index 281fdea6a3b..fad9925af81 100644 --- a/regression-test/suites/function_p0/test_agg_foreach.groovy +++ b/regression-test/suites/function_p0/test_agg_foreach.groovy @@ -87,18 +87,24 @@ suite("test_agg_foreach") { select histogram_foreach(a) from foreach_table; """ - qt_sql """ - select PERCENTILE_foreach(a,a) from foreach_table; - """ + try { + sql "select PERCENTILE_foreach(a,a) from foreach_table;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } - qt_sql """ - select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1; - """ - qt_sql """ - - select PERCENTILE_APPROX_foreach(a,a) from foreach_table; - """ + try { + sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } qt_sql """ select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table; diff --git a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy index 91f4ea902dd..68f27e6d049 100644 --- a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy +++ b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy @@ -85,20 +85,26 @@ suite("test_agg_foreach_not_null") { qt_sql """ select histogram_foreach(a) from foreach_table_not_null; """ - - qt_sql """ - select PERCENTILE_foreach(a,a) from foreach_table_not_null; - """ - - qt_sql """ - select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1; - """ - - qt_sql """ - select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null; - """ + try { + sql "select PERCENTILE_foreach(a,a) from foreach_table_not_null;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where id = 1;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + + try { + sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;" + } catch (Exception ex) { + assert("${ex}".contains("Unsupport the func")) + } + qt_sql """ select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), GROUP_BIT_XOR_foreach(a) from foreach_table_not_null; """ diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy index b611ff92b0e..6079d09577f 100644 --- a/regression-test/suites/query_p0/aggregate/aggregate.groovy +++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy @@ -141,6 +141,23 @@ suite("aggregate") { qt_aggregate32" select topn_weighted(c_string,c_bigint,3) from ${tableName}" qt_aggregate33" select avg_weighted(c_double,c_bigint) from ${tableName};" qt_aggregate34" select percentile_array(c_bigint,[0.2,0.5,0.9]) from ${tableName};" + + try { + sql "select percentile_array(c_bigint,[-1,0.5,0.9]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select percentile_array(c_bigint,[0.5,0.9,3000]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } + try { + sql "select percentile_array(c_bigint,[0.5,0.9,null]) from ${tableName};" + } catch (Exception ex) { + assert("${ex}".contains("null")) + } + qt_aggregate """ SELECT c_bigint, CASE diff --git a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy index cdab9472e27..c64d33e1e82 100644 --- a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy +++ b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy @@ -286,6 +286,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") { qt_select20_1 "select id,percentile(level + 0.1,0.5) from ${tableName_13} group by id order by id" qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} group by id order by id" qt_select22_1 "select id,percentile(level + 0.1,0.805) from ${tableName_13} group by id order by id" + qt_select22_1_1 "select id,percentile(level + 0.1, null) from ${tableName_13} group by id order by id" + + try { + sql "select id,percentile(level + 0.1, -1) from ${tableName_13} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select id,percentile(level + 0.1, 3000) from ${tableName_13} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } sql "DROP TABLE IF EXISTS ${tableName_13}" @@ -313,6 +325,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") { qt_select26 "select id,PERCENTILE_APPROX(level,0.5,2048) from ${tableName_14} group by id order by id" qt_select27 "select id,PERCENTILE_APPROX(level,0.55,2048) from ${tableName_14} group by id order by id" qt_select28 "select id,PERCENTILE_APPROX(level,0.805,2048) from ${tableName_14} group by id order by id" + qt_select28_1 "select id,PERCENTILE_APPROX(level, null ,2048) from ${tableName_14} group by id order by id" + + try { + sql "select id,PERCENTILE_APPROX(level, -1, 2048) from ${tableName_14} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("-1")) + } + try { + sql "select id,PERCENTILE_APPROX(level, 3000 ,2048) from ${tableName_14} group by id order by id" + } catch (Exception ex) { + assert("${ex}".contains("3000")) + } sql "DROP TABLE IF EXISTS ${tableName_14}" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org