This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2c155a45803 branch-3.0: [Fix](bug) Percentile* functions core dump when the 
percent argument is a negative number #47068 (#47219)
2c155a45803 is described below

commit 2c155a45803a1fde38df2ac76a7ad6e7fff19e88
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Fri Feb 7 12:45:12 2025 +0800

    branch-3.0: [Fix](bug) Percentile* functions core dump when the percent argument 
is a negative number #47068 (#47219)
    
    Cherry-picked from #47068
    
    Co-authored-by: HappenLee <happen...@selectdb.com>
---
 .../aggregate_function_percentile.h                |  56 +++++++++++++--------
 .../aggregate_function_simple_factory.h            |   9 ----
 .../expressions/functions/agg/PercentileArray.java |   9 ++++
 .../functions/combinator/ForEachCombinator.java    |  29 +++++++++++
 .../data/function_p0/test_agg_foreach.out          | Bin 1945 -> 1865 bytes
 .../data/function_p0/test_agg_foreach_notnull.out  | Bin 1945 -> 1865 bytes
 .../test_aggregate_all_functions.out               | Bin 2675 -> 2765 bytes
 .../suites/function_p0/test_agg_foreach.groovy     |  26 ++++++----
 .../function_p0/test_agg_foreach_notnull.groovy    |  30 ++++++-----
 .../suites/query_p0/aggregate/aggregate.groovy     |  17 +++++++
 .../test_aggregate_all_functions.groovy            |  24 +++++++++
 11 files changed, 149 insertions(+), 51 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h 
b/be/src/vec/aggregate_functions/aggregate_function_percentile.h
index 0766c59f3de..dbd52af923f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h
@@ -51,12 +51,20 @@ namespace doris::vectorized {
 class Arena;
 class BufferReadable;
 
+inline void check_quantile(double quantile) {
+    if (quantile < 0 || quantile > 1) {
+        throw Exception(ErrorCode::INVALID_ARGUMENT,
+                        "quantile in func percentile should in [0, 1], but 
real data is:" +
+                                std::to_string(quantile));
+    }
+}
+
 struct PercentileApproxState {
     static constexpr double INIT_QUANTILE = -1.0;
     PercentileApproxState() = default;
     ~PercentileApproxState() = default;
 
-    void init(double compression = 10000) {
+    void init(double quantile, double compression = 10000) {
         if (!init_flag) {
             
//https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
             //The compression parameter setting range is [2048, 10000].
@@ -66,6 +74,8 @@ struct PercentileApproxState {
                 compression = 10000;
             }
             digest = TDigest::create_unique(compression);
+            check_quantile(quantile);
+            target_quantile = quantile;
             compressions = compression;
             init_flag = true;
         }
@@ -126,18 +136,14 @@ struct PercentileApproxState {
         }
     }
 
-    void add(double source, double quantile) {
-        digest->add(source);
-        target_quantile = quantile;
-    }
+    void add(double source) { digest->add(source); }
 
-    void add_with_weight(double source, double weight, double quantile) {
+    void add_with_weight(double source, double weight) {
         // the weight should be positive num, as have check the value valid 
use DCHECK_GT(c._weight, 0);
         if (weight <= 0) {
             return;
         }
         digest->add(source, weight);
-        target_quantile = quantile;
     }
 
     void reset() {
@@ -192,8 +198,8 @@ public:
                 assert_cast<const ColumnFloat64&, 
TypeCheckOnRelease::DISABLE>(*columns[0]);
         const auto& quantile =
                 assert_cast<const ColumnFloat64&, 
TypeCheckOnRelease::DISABLE>(*columns[1]);
-        this->data(place).init();
-        this->data(place).add(sources.get_element(row_num), 
quantile.get_element(row_num));
+        this->data(place).init(quantile.get_element(0));
+        this->data(place).add(sources.get_element(row_num));
     }
 
     DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeFloat64>(); }
@@ -223,8 +229,8 @@ public:
         const auto& compression =
                 assert_cast<const ColumnFloat64&, 
TypeCheckOnRelease::DISABLE>(*columns[2]);
 
-        this->data(place).init(compression.get_element(row_num));
-        this->data(place).add(sources.get_element(row_num), 
quantile.get_element(row_num));
+        this->data(place).init(quantile.get_element(0), 
compression.get_element(0));
+        this->data(place).add(sources.get_element(row_num));
     }
 
     DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeFloat64>(); }
@@ -256,9 +262,9 @@ public:
         const auto& quantile =
                 assert_cast<const ColumnVector<Float64>&, 
TypeCheckOnRelease::DISABLE>(*columns[2]);
 
-        this->data(place).init();
-        this->data(place).add_with_weight(sources.get_element(row_num), 
weight.get_element(row_num),
-                                          quantile.get_element(row_num));
+        this->data(place).init(quantile.get_element(0));
+        this->data(place).add_with_weight(sources.get_element(row_num),
+                                          weight.get_element(row_num));
     }
 
     DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeFloat64>(); }
@@ -291,9 +297,9 @@ public:
         const auto& compression =
                 assert_cast<const ColumnVector<Float64>&, 
TypeCheckOnRelease::DISABLE>(*columns[3]);
 
-        this->data(place).init(compression.get_element(row_num));
-        this->data(place).add_with_weight(sources.get_element(row_num), 
weight.get_element(row_num),
-                                          quantile.get_element(row_num));
+        this->data(place).init(quantile.get_element(0), 
compression.get_element(0));
+        this->data(place).add_with_weight(sources.get_element(row_num),
+                                          weight.get_element(row_num));
     }
 
     DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeFloat64>(); }
@@ -351,12 +357,19 @@ struct PercentileState {
         }
     }
 
-    void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) 
{
+    void add(T source, const PaddedPODArray<Float64>& quantiles, const 
NullMap& null_maps,
+             int arg_size) {
         if (!inited_flag) {
             vec_counts.resize(arg_size);
             vec_quantile.resize(arg_size, -1);
             inited_flag = true;
             for (int i = 0; i < arg_size; ++i) {
+                // throw Exception func call percentile_array(id, [1,0,null])
+                if (null_maps[i]) {
+                    throw Exception(ErrorCode::INVALID_ARGUMENT,
+                                    "quantiles in func percentile_array should 
not have null");
+                }
+                check_quantile(quantiles[i]);
                 vec_quantile[i] = quantiles[i];
             }
         }
@@ -429,7 +442,7 @@ public:
         const auto& quantile =
                 assert_cast<const ColumnFloat64&, 
TypeCheckOnRelease::DISABLE>(*columns[1]);
         
AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num],
-                                                     quantile.get_data(), 1);
+                                                     quantile.get_data(), 
NullMap(1, 0), 1);
     }
 
     void add_batch_single_place(size_t batch_size, AggregateDataPtr place, 
const IColumn** columns,
@@ -490,6 +503,9 @@ public:
         const auto& quantile_array =
                 assert_cast<const ColumnArray&, 
TypeCheckOnRelease::DISABLE>(*columns[1]);
         const auto& offset_column_data = quantile_array.get_offsets();
+        const auto& null_maps = assert_cast<const ColumnNullable&, 
TypeCheckOnRelease::DISABLE>(
+                                        quantile_array.get_data())
+                                        .get_null_map_data();
         const auto& nested_column = assert_cast<const ColumnNullable&, 
TypeCheckOnRelease::DISABLE>(
                                             quantile_array.get_data())
                                             .get_nested_column();
@@ -497,7 +513,7 @@ public:
                 assert_cast<const ColumnFloat64&, 
TypeCheckOnRelease::DISABLE>(nested_column);
 
         AggregateFunctionPercentileArray::data(place).add(
-                sources.get_int(row_num), nested_column_data.get_data(),
+                sources.get_int(row_num), nested_column_data.get_data(), 
null_maps,
                 offset_column_data.data()[row_num] - 
offset_column_data[(ssize_t)row_num - 1]);
     }
 
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index aa33e7289df..0a0ec6abe16 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -61,15 +61,6 @@ private:
     std::unordered_map<std::string, std::string> function_alias;
 
 public:
-    void register_nullable_function_combinator(const Creator& creator) {
-        for (const auto& entity : aggregate_functions) {
-            if (nullable_aggregate_functions.find(entity.first) ==
-                nullable_aggregate_functions.end()) {
-                nullable_aggregate_functions[entity.first] = creator;
-            }
-        }
-    }
-
     static bool is_foreach(const std::string& name) {
         constexpr std::string_view suffix = "_foreach";
         if (name.length() < suffix.length()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
index 4412d96006f..49a0f836aed 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.trees.expressions.functions.agg;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
 import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
@@ -69,6 +70,14 @@ public class PercentileArray extends 
NotNullableAggregateFunction
         super("percentile_array", distinct, arg0, arg1);
     }
 
+    @Override
+    public void checkLegalityBeforeTypeCoercion() {
+        if (!getArgument(1).isConstant()) {
+            throw new AnalysisException(
+                    "percentile_array requires second parameter must be a 
constant : " + this.toSql());
+        }
+    }
+
     /**
      * withDistinctAndChildren.
      */
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
index a6d011ff0fb..ddd92f894e1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
@@ -30,8 +30,11 @@ import org.apache.doris.nereids.types.DataType;
 
 import com.google.common.collect.ImmutableList;
 
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * combinator foreach
@@ -39,6 +42,15 @@ import java.util.Objects;
 public class ForEachCombinator extends NullableAggregateFunction
         implements UnaryExpression, ExplicitlyCastableSignature, Combinator {
 
+    public static final Set<String> UNSUPPORTED_AGGREGATE_FUNCTION = 
Collections.unmodifiableSet(new HashSet<String>() {
+        {
+            add("percentile");
+            add("percentile_array");
+            add("percentile_approx");
+            add("percentile_approx_weighted");
+        }
+    });
+
     private final AggregateFunction nested;
 
     /**
@@ -48,10 +60,27 @@ public class ForEachCombinator extends 
NullableAggregateFunction
         this(arguments, false, nested);
     }
 
+    /**
+     * Constructs a new instance of {@code ForEachCombinator}.
+     *
+     * <p>This constructor initializes a combinator that will iterate over 
each item in the input list
+     * and apply the nested aggregate function.
+     * If the provided aggregate function name is within the list of 
unsupported functions,
+     * an {@link UnsupportedOperationException} will be thrown.
+     *
+     * @param arguments A list of {@code Expression} objects that serve as 
parameters to the aggregate function.
+     * @param alwaysNullable A boolean flag indicating whether this combinator 
should always return a nullable result.
+     * @param nested The nested aggregate function to apply to each element. 
It must not be {@code null}.
+     * @throws NullPointerException If the provided nested aggregate function 
is {@code null}.
+     * @throws UnsupportedOperationException If nested aggregate function is 
one of the unsupported aggregate functions
+     */
     public ForEachCombinator(List<Expression> arguments, boolean 
alwaysNullable, AggregateFunction nested) {
         super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, 
false, alwaysNullable, arguments);
 
         this.nested = Objects.requireNonNull(nested, "nested can not be null");
+        if 
(UNSUPPORTED_AGGREGATE_FUNCTION.contains(nested.getName().toLowerCase())) {
+            throw new UnsupportedOperationException("Unsupport the func:" + 
nested.getName() + " use in foreach");
+        }
     }
 
     public static ForEachCombinator create(AggregateFunction nested) {
diff --git a/regression-test/data/function_p0/test_agg_foreach.out 
b/regression-test/data/function_p0/test_agg_foreach.out
index c45ae9f67a9..693009d3890 100644
Binary files a/regression-test/data/function_p0/test_agg_foreach.out and 
b/regression-test/data/function_p0/test_agg_foreach.out differ
diff --git a/regression-test/data/function_p0/test_agg_foreach_notnull.out 
b/regression-test/data/function_p0/test_agg_foreach_notnull.out
index c45ae9f67a9..693009d3890 100644
Binary files a/regression-test/data/function_p0/test_agg_foreach_notnull.out 
and b/regression-test/data/function_p0/test_agg_foreach_notnull.out differ
diff --git 
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out
 
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out
index 75d9a18679f..90953b0a11c 100644
Binary files 
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out
 and 
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out
 differ
diff --git a/regression-test/suites/function_p0/test_agg_foreach.groovy 
b/regression-test/suites/function_p0/test_agg_foreach.groovy
index 281fdea6a3b..fad9925af81 100644
--- a/regression-test/suites/function_p0/test_agg_foreach.groovy
+++ b/regression-test/suites/function_p0/test_agg_foreach.groovy
@@ -87,18 +87,24 @@ suite("test_agg_foreach") {
    select histogram_foreach(a) from foreach_table;
    """
    
-   qt_sql """
-      select PERCENTILE_foreach(a,a)  from foreach_table;
-   """
+   try {
+        sql "select PERCENTILE_foreach(a,a)  from foreach_table;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
   
-   qt_sql """
-      select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;
-   """
 
-   qt_sql """
-
-   select PERCENTILE_APPROX_foreach(a,a) from foreach_table;
-   """
+   try {
+        sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id 
= 1;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
+
+   try {
+       sql "select PERCENTILE_APPROX_foreach(a,a) from foreach_table;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
 
    qt_sql """
    select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), 
GROUP_BIT_XOR_foreach(a)  from foreach_table;
diff --git a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy 
b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy
index 91f4ea902dd..68f27e6d049 100644
--- a/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy
+++ b/regression-test/suites/function_p0/test_agg_foreach_notnull.groovy
@@ -85,20 +85,26 @@ suite("test_agg_foreach_not_null") {
    qt_sql """
    select histogram_foreach(a) from foreach_table_not_null;
    """
-   
-   qt_sql """
-      select PERCENTILE_foreach(a,a)  from foreach_table_not_null;
-   """
-  
-   qt_sql """
-      select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null where 
id = 1;
-   """
-
-   qt_sql """
 
-   select PERCENTILE_APPROX_foreach(a,a) from foreach_table_not_null;
-   """
+   try {
+        sql "select PERCENTILE_foreach(a,a)  from foreach_table_not_null;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
+  
 
+   try {
+        sql "select PERCENTILE_ARRAY_foreach(a,b) from foreach_table_not_null 
where id = 1;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
+
+   try {
+        sql "select PERCENTILE_APPROX_foreach(a,a) from 
foreach_table_not_null;"
+   } catch (Exception ex) {
+        assert("${ex}".contains("Unsupport the func"))
+   }
+ 
    qt_sql """
    select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a), 
GROUP_BIT_XOR_foreach(a)  from foreach_table_not_null;
    """
diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy 
b/regression-test/suites/query_p0/aggregate/aggregate.groovy
index b611ff92b0e..6079d09577f 100644
--- a/regression-test/suites/query_p0/aggregate/aggregate.groovy
+++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy
@@ -141,6 +141,23 @@ suite("aggregate") {
     qt_aggregate32" select topn_weighted(c_string,c_bigint,3) from 
${tableName}"
     qt_aggregate33" select avg_weighted(c_double,c_bigint) from ${tableName};"
     qt_aggregate34" select percentile_array(c_bigint,[0.2,0.5,0.9]) from 
${tableName};"
+    
+    try {
+        sql "select percentile_array(c_bigint,[-1,0.5,0.9]) from ${tableName};"
+    } catch (Exception ex) {
+        assert("${ex}".contains("-1"))
+    }
+    try {
+        sql "select percentile_array(c_bigint,[0.5,0.9,3000]) from 
${tableName};"
+    } catch (Exception ex) {
+        assert("${ex}".contains("3000"))
+    }
+    try {
+        sql "select percentile_array(c_bigint,[0.5,0.9,null]) from 
${tableName};"
+    } catch (Exception ex) {
+        assert("${ex}".contains("null"))
+    }
+
     qt_aggregate """
                 SELECT c_bigint,  
                     CASE
diff --git 
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy
 
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy
index cdab9472e27..c64d33e1e82 100644
--- 
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy
@@ -286,6 +286,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") {
     qt_select20_1 "select id,percentile(level + 0.1,0.5) from ${tableName_13} 
group by id order by id"
     qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} 
group by id order by id"
     qt_select22_1 "select id,percentile(level + 0.1,0.805) from 
${tableName_13} group by id order by id"
+    qt_select22_1_1 "select id,percentile(level + 0.1, null) from 
${tableName_13} group by id order by id"
+
+    try {
+        sql "select id,percentile(level + 0.1, -1) from ${tableName_13} group 
by id order by id"
+    } catch (Exception ex) {
+        assert("${ex}".contains("-1"))
+    }
+    try {
+        sql "select id,percentile(level + 0.1, 3000) from ${tableName_13} 
group by id order by id"
+    } catch (Exception ex) {
+        assert("${ex}".contains("3000"))
+    }
 
     sql "DROP TABLE IF EXISTS ${tableName_13}"
 
@@ -313,6 +325,18 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") {
     qt_select26 "select id,PERCENTILE_APPROX(level,0.5,2048) from 
${tableName_14} group by id order by id"
     qt_select27 "select id,PERCENTILE_APPROX(level,0.55,2048) from 
${tableName_14} group by id order by id"
     qt_select28 "select id,PERCENTILE_APPROX(level,0.805,2048) from 
${tableName_14} group by id order by id"
+    qt_select28_1 "select id,PERCENTILE_APPROX(level, null ,2048) from 
${tableName_14} group by id order by id"
+
+    try {
+        sql "select id,PERCENTILE_APPROX(level, -1, 2048) from ${tableName_14} 
group by id order by id"
+    } catch (Exception ex) {
+        assert("${ex}".contains("-1"))
+    }
+    try {
+        sql "select id,PERCENTILE_APPROX(level, 3000 ,2048) from 
${tableName_14} group by id order by id"
+    } catch (Exception ex) {
+        assert("${ex}".contains("3000"))
+    }
 
     sql "DROP TABLE IF EXISTS ${tableName_14}"
     


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to