This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ebf2b88f039 [optimization](agg) add more type in agg percentile (#34423) ebf2b88f039 is described below commit ebf2b88f03955a9bbc35af8ec55293f3104c069f Author: Mryange <59914473+mrya...@users.noreply.github.com> AuthorDate: Wed May 8 14:32:40 2024 +0800 [optimization](agg) add more type in agg percentile (#34423) --- be/src/util/counts.h | 19 ++--- .../aggregate_function_percentile.cpp | 8 +- .../aggregate_function_percentile.h | 35 ++++---- be/test/util/counts_test.cpp | 12 +-- .../expressions/functions/agg/Percentile.java | 11 ++- .../expressions/functions/agg/PercentileArray.java | 15 +++- .../data/mv_p0/mv_percentile/mv_percentile.out | 6 +- .../test_aggregate_percentile_no_cast.out | 36 ++++++++ .../mv_p0/mv_percentile/mv_percentile.groovy | 2 +- .../test_aggregate_percentile_no_cast.groovy | 97 ++++++++++++++++++++++ 10 files changed, 201 insertions(+), 40 deletions(-) diff --git a/be/src/util/counts.h b/be/src/util/counts.h index 70469d6fa72..e479f04c620 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -138,8 +138,7 @@ public: private: std::unordered_map<int64_t, uint32_t> _counts; }; - -// #TODO use template to reduce the Counts memery. Eg: Int do not need use int64_t +template <typename Ty> class Counts { public: Counts() = default; @@ -150,7 +149,7 @@ public: } } - void increment(int64_t key, uint32_t i) { + void increment(Ty key, uint32_t i) { auto old_size = _nums.size(); _nums.resize(_nums.size() + i); for (uint32_t j = 0; j < i; ++j) { @@ -163,7 +162,7 @@ public: pdqsort(_nums.begin(), _nums.end()); size_t size = _nums.size(); write_binary(size, buf); - buf.write(reinterpret_cast<const char*>(_nums.data()), sizeof(int64_t) * size); + buf.write(reinterpret_cast<const char*>(_nums.data()), sizeof(Ty) * size); } else { // convert _sorted_nums_vec to _nums and do seiralize again _convert_sorted_num_vec_to_nums(); @@ -175,7 +174,7 @@ public: size_t size; read_binary(size, buf); _nums.resize(size); - auto buff = buf.read(sizeof(int64_t) * size); + auto buff = buf.read(sizeof(Ty) * size); memcpy(_nums.data(), buff.data, buff.size); } @@ -231,7 +230,7 @@ public: private: struct Node { - int64_t value; + Ty value; int array_index; int64_t element_index; @@ -265,8 +264,8 @@ private: _sorted_nums_vec.clear(); } - std::pair<int64_t, int64_t> _merge_sort_and_get_numbers(int64_t target, bool reverse) { - int64_t first_number = 0, second_number = 0; + std::pair<Ty, Ty> _merge_sort_and_get_numbers(int64_t target, bool reverse) { + Ty first_number = 0, second_number = 0; size_t count = 0; if (reverse) { std::priority_queue<Node> max_heap; @@ -321,8 +320,8 @@ private: return {first_number, second_number}; } - vectorized::PODArray<int64_t> _nums; - std::vector<vectorized::PODArray<int64_t>> _sorted_nums_vec; + vectorized::PODArray<Ty> _nums; + std::vector<vectorized::PODArray<Ty>> _sorted_nums_vec; }; } // namespace doris diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp index 079b1da83ff..afadb5b8dca 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp @@ -19,6 +19,7 @@ #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" +#include "vec/core/types.h" namespace doris::vectorized { @@ -50,9 +51,10 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("percentile", - creator_without_type::creator<AggregateFunctionPercentile>); - factory.register_function_both("percentile_array", - creator_without_type::creator<AggregateFunctionPercentileArray>); + creator_with_integer_type::creator<AggregateFunctionPercentile>); + factory.register_function_both( + "percentile_array", + creator_with_integer_type::creator<AggregateFunctionPercentileArray>); } void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h b/be/src/vec/aggregate_functions/aggregate_function_percentile.h index 6322a80c934..231057158ce 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h @@ -283,8 +283,9 @@ public: } }; +template <typename T> struct PercentileState { - mutable std::vector<Counts> vec_counts; + mutable std::vector<Counts<T>> vec_counts; std::vector<double> vec_quantile {-1}; bool inited_flag = false; @@ -317,7 +318,7 @@ struct PercentileState { } } - void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int arg_size) { + void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) { if (!inited_flag) { vec_counts.resize(arg_size); vec_quantile.resize(arg_size, -1); @@ -346,7 +347,7 @@ struct PercentileState { if (vec_quantile[i] == -1.0) { vec_quantile[i] = rhs.vec_quantile[i]; } - vec_counts[i].merge(const_cast<Counts*>(&(rhs.vec_counts[i]))); + vec_counts[i].merge(const_cast<Counts<T>*>(&(rhs.vec_counts[i]))); } } @@ -366,12 +367,13 @@ struct PercentileState { } }; +template <typename T> class AggregateFunctionPercentile final - : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile> { + : public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>> { public: - AggregateFunctionPercentile(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile>( - argument_types_) {} + using ColVecType = ColumnVector<T>; + using Base = IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>; + AggregateFunctionPercentile(const DataTypes& argument_types_) : Base(argument_types_) {} String get_name() const override { return "percentile"; } @@ -379,10 +381,10 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& sources = assert_cast<const ColVecType&>(*columns[0]); const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); - AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(), - 1); + AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num], + quantile.get_data(), 1); } void reset(AggregateDataPtr __restrict place) const override { @@ -409,12 +411,15 @@ public: } }; +template <typename T> class AggregateFunctionPercentileArray final - : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> { + : public IAggregateFunctionDataHelper<PercentileState<T>, + AggregateFunctionPercentileArray<T>> { public: - AggregateFunctionPercentileArray(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>( - argument_types_) {} + using ColVecType = ColumnVector<T>; + using Base = + IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>; + AggregateFunctionPercentileArray(const DataTypes& argument_types_) : Base(argument_types_) {} String get_name() const override { return "percentile_array"; } @@ -424,7 +429,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& sources = assert_cast<const ColVecType&>(*columns[0]); const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]); const auto& offset_column_data = quantile_array.get_offsets(); const auto& nested_column = diff --git a/be/test/util/counts_test.cpp b/be/test/util/counts_test.cpp index 20d9ea54c97..d60f235e788 100644 --- a/be/test/util/counts_test.cpp +++ b/be/test/util/counts_test.cpp @@ -20,6 +20,8 @@ #include <gtest/gtest-message.h> #include <gtest/gtest-test-part.h> +#include <cstdint> + #include "gtest/gtest_pred_impl.h" namespace doris { @@ -27,7 +29,7 @@ namespace doris { class TCountsTest : public testing::Test {}; TEST_F(TCountsTest, TotalTest) { - Counts counts; + Counts<int64_t> counts; // 1 1 1 2 5 7 7 9 9 19 // >>> import numpy as np // >>> a = np.array([1,1,1,2,5,7,7,9,9,19]) @@ -48,14 +50,14 @@ TEST_F(TCountsTest, TotalTest) { counts.serialize(bw); bw.commit(); - Counts other; + Counts<int64_t> other; StringRef res(cs->get_chars().data(), cs->get_chars().size()); vectorized::BufferReadable br(res); other.unserialize(br); double result1 = other.terminate(0.2); EXPECT_EQ(result, result1); - Counts other1; + Counts<int64_t> other1; other1.increment(1, 1); other1.increment(100, 3); other1.increment(50, 3); @@ -66,11 +68,11 @@ TEST_F(TCountsTest, TotalTest) { cs->clear(); other1.serialize(bw); bw.commit(); - Counts other1_deserialized; + Counts<int64_t> other1_deserialized; vectorized::BufferReadable br1(res); other1_deserialized.unserialize(br1); - Counts merge_res; + Counts<int64_t> merge_res; merge_res.merge(&other); merge_res.merge(&other1_deserialized); // 1 1 1 1 2 5 7 7 9 9 10 19 50 50 50 99 99 100 100 100 diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java index d85b69516a8..abc0498f6e1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java @@ -26,6 +26,10 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -39,7 +43,12 @@ public class Percentile extends AggregateFunction implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( - FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE) + ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java index d4d8ed6c39a..61fcaf3b4c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java @@ -26,6 +26,10 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -40,8 +44,15 @@ public class PercentileArray extends AggregateFunction public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) - .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)) - ); + .args(LargeIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(BigIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(IntegerType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(SmallIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE)), + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(TinyIntType.INSTANCE, ArrayType.of(DoubleType.INSTANCE))); /** * constructor with 2 arguments. diff --git a/regression-test/data/mv_p0/mv_percentile/mv_percentile.out b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out index 32e5595dac7..858d558c346 100644 --- a/regression-test/data/mv_p0/mv_percentile/mv_percentile.out +++ b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out @@ -1,9 +1,9 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !select_star -- \N 4 \N d --4 -4 -4.000000 d -1 1 1.000000 a -2 2 2.000000 b +-4 -4 -4 d +1 1 1 a +2 2 2 b 3 -3 \N c -- !select_mv -- diff --git a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out new file mode 100644 index 00000000000..1764ba21dde --- /dev/null +++ b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out @@ -0,0 +1,36 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +1 10.0 10.0 10.0 [10, 10, 10] +2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33] +3 10.0 10.0 10.0 [10, 10, 10] +5 29.0 29.0 29.0 [29, 29, 29] +6 101.0 101.0 101.0 [101, 101, 101] + +-- !select -- +1 10.0 10.0 10.0 [10, 10, 10] +2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33] +3 10.0 10.0 10.0 [10, 10, 10] +5 29.0 29.0 29.0 [29, 29, 29] +6 101.0 101.0 101.0 [101, 101, 101] + +-- !select -- +1 10.0 10.0 10.0 [10, 10, 10] +2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33] +3 10.0 10.0 10.0 [10, 10, 10] +5 29.0 29.0 29.0 [29, 29, 29] +6 101.0 101.0 101.0 [101, 101, 101] + +-- !select -- +1 10.0 10.0 10.0 [10, 10, 10] +2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33] +3 10.0 10.0 10.0 [10, 10, 10] +5 29.0 29.0 29.0 [29, 29, 29] +6 101.0 101.0 101.0 [101, 101, 101] + +-- !select -- +1 10.0 10.0 10.0 [10, 10, 10] +2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33] +3 10.0 10.0 10.0 [10, 10, 10] +5 29.0 29.0 29.0 [29, 29, 29] +6 101.0 101.0 101.0 [101, 101, 101] + diff --git a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy index dd6cb453305..e4624d29f00 100644 --- a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy +++ b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy @@ -26,7 +26,7 @@ suite ("mv_percentile") { create table d_table( k1 int null, k2 int not null, - k3 decimal(28,6) null, + k3 bigint null, k4 varchar(100) null ) duplicate key (k1,k2,k3) diff --git a/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy new file mode 100644 index 00000000000..ef76aee4405 --- /dev/null +++ b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_aggregate_percentile_no_cast") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + sql "set batch_size = 4096" + + sql "DROP TABLE IF EXISTS percentile_test_db" + sql """ + CREATE TABLE IF NOT EXISTS percentile_test_db ( + id int, + level tinyint + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10) ,(5,29) ,(6,101)" + + qt_select "select id,percentile(level,0.5) , percentile(level,0.55) , percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from percentile_test_db group by id order by id" + + sql "DROP TABLE IF EXISTS percentile_test_db" + sql """ + CREATE TABLE IF NOT EXISTS percentile_test_db ( + id int, + level smallint + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10) ,(5,29) ,(6,101)" + + qt_select "select id,percentile(level,0.5) , percentile(level,0.55) , percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from percentile_test_db group by id order by id" + + sql "DROP TABLE IF EXISTS percentile_test_db" + sql """ + CREATE TABLE IF NOT EXISTS percentile_test_db ( + id int, + level int + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10) ,(5,29) ,(6,101)" + + qt_select "select id,percentile(level,0.5) , percentile(level,0.55) , percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from percentile_test_db group by id order by id" + + sql "DROP TABLE IF EXISTS percentile_test_db" + sql """ + CREATE TABLE IF NOT EXISTS percentile_test_db ( + id int, + level bigint + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10) ,(5,29) ,(6,101)" + qt_select "select id,percentile(level,0.5) , percentile(level,0.55) , percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from percentile_test_db group by id order by id" + + sql "DROP TABLE IF EXISTS percentile_test_db" + sql """ + CREATE TABLE IF NOT EXISTS percentile_test_db ( + id int, + level largeint + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10) ,(5,29) ,(6,101)" + qt_select "select id,percentile(level,0.5) , percentile(level,0.55) , percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from percentile_test_db group by id order by id" + +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org