github-actions[bot] commented on code in PR #34390:
URL: https://github.com/apache/doris/pull/34390#discussion_r1589984710
########## be/src/util/counts.h: ########## @@ -135,4 +139,190 @@ std::unordered_map<int64_t, uint32_t> _counts; }; +// #TODO use template to reduce the Counts memery. Eg: Int do not need use int64_t +class Counts { +public: + Counts() = default; + + void merge(Counts* other) { + if (other != nullptr && !other->_nums.empty()) { + _sorted_nums_vec.emplace_back(std::move(other->_nums)); + } + } + + void increment(int64_t key, uint32_t i) { Review Comment: warning: method 'increment' can be made static [readability-convert-member-functions-to-static] ```suggestion static void increment(int64_t key, uint32_t i) { ``` ########## be/src/util/counts.h: ########## @@ -135,4 +139,190 @@ std::unordered_map<int64_t, uint32_t> _counts; }; +// #TODO use template to reduce the Counts memery. Eg: Int do not need use int64_t +class Counts { +public: + Counts() = default; + + void merge(Counts* other) { + if (other != nullptr && !other->_nums.empty()) { + _sorted_nums_vec.emplace_back(std::move(other->_nums)); + } + } + + void increment(int64_t key, uint32_t i) { + auto old_size = _nums.size(); + _nums.resize(_nums.size() + i); + for (uint32_t j = 0; j < i; ++j) { + _nums[old_size + j] = key; + } + } + + void serialize(vectorized::BufferWritable& buf) { + if (!_nums.empty()) { + pdqsort(_nums.begin(), _nums.end()); + size_t size = _nums.size(); + write_binary(size, buf); + buf.write(reinterpret_cast<const char*>(_nums.data()), sizeof(int64_t) * size); + } else { + // convert _sorted_nums_vec to _nums and do seiralize again + _convert_sorted_num_vec_to_nums(); + serialize(buf); + } + } + + void unserialize(vectorized::BufferReadable& buf) { Review Comment: warning: method 'unserialize' can be made static [readability-convert-member-functions-to-static] ```suggestion static void unserialize(vectorized::BufferReadable& buf) { ``` ########## be/src/util/counts.h: ########## @@ -135,4 +139,190 @@ std::unordered_map<int64_t, uint32_t> _counts; }; +// #TODO use template to reduce the Counts memery. 
Eg: Int do not need use int64_t +class Counts { +public: + Counts() = default; + + void merge(Counts* other) { + if (other != nullptr && !other->_nums.empty()) { + _sorted_nums_vec.emplace_back(std::move(other->_nums)); + } + } + + void increment(int64_t key, uint32_t i) { + auto old_size = _nums.size(); + _nums.resize(_nums.size() + i); + for (uint32_t j = 0; j < i; ++j) { + _nums[old_size + j] = key; + } + } + + void serialize(vectorized::BufferWritable& buf) { + if (!_nums.empty()) { + pdqsort(_nums.begin(), _nums.end()); + size_t size = _nums.size(); + write_binary(size, buf); + buf.write(reinterpret_cast<const char*>(_nums.data()), sizeof(int64_t) * size); + } else { + // convert _sorted_nums_vec to _nums and do seiralize again + _convert_sorted_num_vec_to_nums(); + serialize(buf); + } + } + + void unserialize(vectorized::BufferReadable& buf) { + size_t size; + read_binary(size, buf); + _nums.resize(size); + auto buff = buf.read(sizeof(int64_t) * size); + memcpy(_nums.data(), buff.data, buff.size); + } + + double terminate(double quantile) { + if (_sorted_nums_vec.size() <= 1) { + if (_sorted_nums_vec.size() == 1) { + _nums = std::move(_sorted_nums_vec[0]); + } + + if (_nums.empty()) { + // Although set null here, but the value is 0.0 and the call method just + // get val in aggregate_function_percentile_approx.h + return 0.0; + } + if (quantile == 1 || _nums.size() == 1) { + return _nums.back(); + } + if (UNLIKELY(!std::is_sorted(_nums.begin(), _nums.end()))) { + pdqsort(_nums.begin(), _nums.end()); + } + + double u = (_nums.size() - 1) * quantile; + auto index = static_cast<uint32_t>(u); + return _nums[index] + + (u - static_cast<double>(index)) * (_nums[index + 1] - _nums[index]); + } else { + DCHECK(_nums.empty()); + size_t rows = 0; + for (const auto& i : _sorted_nums_vec) { + rows += i.size(); + } + const bool reverse = quantile > 0.5 && rows > 2; + double u = (rows - 1) * quantile; + auto index = static_cast<uint32_t>(u); + // if reverse, the step of target should start 0 like not reverse + // so here rows need to minus index + 2 + // eg: rows = 10, index = 5 + // if not reverse, so the first number loc is 5, the second number loc is 6 + // if reverse, so the second number is 3, the first number is 4 + // 5 + 4 = 3 + 6 = 9 = rows - 1. + // the rows must GE 2 beacuse `_sorted_nums_vec` size GE 2 + size_t target = reverse ? rows - index - 2 : index; + if (quantile == 1) { + target = 0; + } + auto [first_number, second_number] = _merge_sort_and_get_numbers(target, reverse); + if (quantile == 1) { + return second_number; + } + return first_number + (u - static_cast<double>(index)) * (second_number - first_number); + } + } + +private: + struct Node { + int64_t value; + int array_index; + int64_t element_index; + + std::strong_ordering operator<=>(const Node& other) const { return value <=> other.value; } + }; + + void _convert_sorted_num_vec_to_nums() { Review Comment: warning: method '_convert_sorted_num_vec_to_nums' can be made static [readability-convert-member-functions-to-static] ```suggestion static void _convert_sorted_num_vec_to_nums() { ``` ########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ########## @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <glog/logging.h> +#include <stddef.h> +#include <stdint.h> Review Comment: warning: inclusion of deprecated C++ header 'stdint.h'; consider using 'cstdint' instead [modernize-deprecated-headers] ```suggestion #include <cstdint> ``` ########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ########## @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <glog/logging.h> +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <boost/iterator/iterator_facade.hpp> +#include <cmath> +#include <memory> +#include <ostream> +#include <string> +#include <vector> + +#include "util/counts.h" +#include "util/tdigest.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/common/pod_array_fwd.h" +#include "vec/common/string_ref.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +class Arena; +class BufferReadable; + +struct PercentileApproxState { + static constexpr double INIT_QUANTILE = -1.0; + PercentileApproxState() = default; + ~PercentileApproxState() = default; + + void init(double compression = 10000) { + if (!init_flag) { + //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description + //The compression parameter setting range is [2048, 10000]. 
+ //If the value of compression parameter is not specified set, or is outside the range of [2048, 10000], + //will use the default value of 10000 + if (compression < 2048 || compression > 10000) { + compression = 10000; + } + digest = TDigest::create_unique(compression); + compressions = compression; + init_flag = true; + } + } + + void write(BufferWritable& buf) const { + write_binary(init_flag, buf); + if (!init_flag) { + return; + } + + write_binary(target_quantile, buf); + write_binary(compressions, buf); + uint32_t serialize_size = digest->serialized_size(); + std::string result(serialize_size, '0'); + DCHECK(digest.get() != nullptr); + digest->serialize((uint8_t*)result.c_str()); + + write_binary(result, buf); + } + + void read(BufferReadable& buf) { + read_binary(init_flag, buf); + if (!init_flag) { + return; + } + + read_binary(target_quantile, buf); + read_binary(compressions, buf); + std::string str; + read_binary(str, buf); + digest = TDigest::create_unique(compressions); + digest->unserialize((uint8_t*)str.c_str()); + } + + double get() const { + if (init_flag) { + return digest->quantile(target_quantile); + } else { + return std::nan(""); + } + } + + void merge(const PercentileApproxState& rhs) { + if (!rhs.init_flag) { + return; + } + if (init_flag) { + DCHECK(digest.get() != nullptr); + digest->merge(rhs.digest.get()); + } else { + digest = TDigest::create_unique(compressions); + digest->merge(rhs.digest.get()); + init_flag = true; + } + if (target_quantile == PercentileApproxState::INIT_QUANTILE) { + target_quantile = rhs.target_quantile; + } + } + + void add(double source, double quantile) { + digest->add(source); + target_quantile = quantile; + } + + void reset() { + target_quantile = INIT_QUANTILE; + init_flag = false; + digest = TDigest::create_unique(compressions); + } + + bool init_flag = false; + std::unique_ptr<TDigest> digest; + double target_quantile = INIT_QUANTILE; + double compressions = 10000; +}; + +class AggregateFunctionPercentileApprox + : public IAggregateFunctionDataHelper<PercentileApproxState, + AggregateFunctionPercentileApprox> { +public: + AggregateFunctionPercentileApprox(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<PercentileApproxState, + AggregateFunctionPercentileApprox>(argument_types_) {} + + String get_name() const override { return "percentile_approx"; } + + DataTypePtr get_return_type() const override { + return make_nullable(std::make_shared<DataTypeFloat64>()); + } + + void reset(AggregateDataPtr __restrict place) const override { + AggregateFunctionPercentileApprox::data(place).reset(); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + AggregateFunctionPercentileApprox::data(place).merge( + AggregateFunctionPercentileApprox::data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + AggregateFunctionPercentileApprox::data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + AggregateFunctionPercentileApprox::data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to); + double result = AggregateFunctionPercentileApprox::data(place).get(); + + if (std::isnan(result)) { + nullable_column.insert_default(); + } else { + auto& col = 
assert_cast<ColumnVector<Float64>&>(nullable_column.get_nested_column()); + col.get_data().push_back(result); + nullable_column.get_null_map_data().push_back(0); + } + } +}; + +// only for merge +template <bool is_nullable> +class AggregateFunctionPercentileApproxMerge : public AggregateFunctionPercentileApprox { +public: + AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_) + : AggregateFunctionPercentileApprox(argument_types_) {} + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + LOG(FATAL) << "AggregateFunctionPercentileApproxMerge do not support add()"; + __builtin_unreachable(); + } +}; + +template <bool is_nullable> +class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPercentileApprox { +public: + AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_) + : AggregateFunctionPercentileApprox(argument_types_) {} + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (is_nullable) { + double column_data[2] = {0, 0}; + + for (int i = 0; i < 2; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = assert_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + + this->data(place).init(); + this->data(place).add(column_data[0], column_data[1]); + + } else { + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + + this->data(place).init(); + this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + } + } +}; + +template <bool is_nullable> +class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPercentileApprox { +public: + AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_) + : AggregateFunctionPercentileApprox(argument_types_) {} + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (is_nullable) { + double column_data[3] = {0, 0, 0}; + + for (int i = 0; i < 3; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = assert_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + + this->data(place).init(column_data[2]); + this->data(place).add(column_data[0], column_data[1]); + + } else { + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& compression = 
assert_cast<const ColumnVector<Float64>&>(*columns[2]); + + this->data(place).init(compression.get_float64(row_num)); + this->data(place).add(sources.get_float64(row_num), quantile.get_float64(row_num)); + } + } +}; + +struct PercentileState { + mutable std::vector<Counts> vec_counts; + std::vector<double> vec_quantile {-1}; + bool inited_flag = false; + + void write(BufferWritable& buf) const { + write_binary(inited_flag, buf); + int size_num = vec_quantile.size(); + write_binary(size_num, buf); + for (const auto& quantile : vec_quantile) { + write_binary(quantile, buf); + } + for (auto& counts : vec_counts) { + counts.serialize(buf); + } + } + + void read(BufferReadable& buf) { + read_binary(inited_flag, buf); + int size_num = 0; + read_binary(size_num, buf); + double data = 0.0; + vec_quantile.clear(); + for (int i = 0; i < size_num; ++i) { + read_binary(data, buf); + vec_quantile.emplace_back(data); + } + vec_counts.clear(); + vec_counts.resize(size_num); + for (int i = 0; i < size_num; ++i) { + vec_counts[i].unserialize(buf); + } + } + + void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int arg_size) { + if (!inited_flag) { + vec_counts.resize(arg_size); + vec_quantile.resize(arg_size, -1); + inited_flag = true; + for (int i = 0; i < arg_size; ++i) { + vec_quantile[i] = quantiles[i]; + } + } + for (int i = 0; i < arg_size; ++i) { + vec_counts[i].increment(source, 1); + } + } + + void merge(const PercentileState& rhs) { + if (!rhs.inited_flag) { + return; + } + int size_num = rhs.vec_quantile.size(); + if (!inited_flag) { + vec_counts.resize(size_num); + vec_quantile.resize(size_num, -1); + inited_flag = true; + } + + for (int i = 0; i < size_num; ++i) { + if (vec_quantile[i] == -1.0) { + vec_quantile[i] = rhs.vec_quantile[i]; + } + vec_counts[i].merge(const_cast<Counts*>(&(rhs.vec_counts[i]))); + } + } + + void reset() { + vec_counts.clear(); + vec_quantile.clear(); + inited_flag = false; + } + + double get() const { return vec_counts[0].terminate(vec_quantile[0]); } + + void insert_result_into(IColumn& to) const { + auto& column_data = assert_cast<ColumnVector<Float64>&>(to).get_data(); + for (int i = 0; i < vec_counts.size(); ++i) { + column_data.push_back(vec_counts[i].terminate(vec_quantile[i])); + } + } +}; + +class AggregateFunctionPercentile final + : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile> { +public: + AggregateFunctionPercentile(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentile>( + argument_types_) {} + + String get_name() const override { return "percentile"; } + + DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(), + 1); + } + + void reset(AggregateDataPtr __restrict place) const override { + AggregateFunctionPercentile::data(place).reset(); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + AggregateFunctionPercentile::data(place).merge(AggregateFunctionPercentile::data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + 
AggregateFunctionPercentile::data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + AggregateFunctionPercentile::data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + auto& col = assert_cast<ColumnVector<Float64>&>(to); + col.insert_value(AggregateFunctionPercentile::data(place).get()); + } +}; + +class AggregateFunctionPercentileArray final + : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> { +public: + AggregateFunctionPercentileArray(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>( + argument_types_) {} + + String get_name() const override { return "percentile_array"; } + + DataTypePtr get_return_type() const override { + return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>())); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]); + const auto& quantile_array = assert_cast<const ColumnArray&>(*columns[1]); + const auto& offset_column_data = quantile_array.get_offsets(); + const auto& nested_column = + assert_cast<const ColumnNullable&>(quantile_array.get_data()).get_nested_column(); + const auto& nested_column_data = assert_cast<const ColumnVector<Float64>&>(nested_column); + + AggregateFunctionPercentileArray::data(place).add( + sources.get_int(row_num), nested_column_data.get_data(), + offset_column_data.data()[row_num] - offset_column_data[(ssize_t)row_num - 1]); + } + + void reset(AggregateDataPtr __restrict place) const override { + AggregateFunctionPercentileArray::data(place).reset(); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + AggregateFunctionPercentileArray::data(place).merge( + AggregateFunctionPercentileArray::data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + AggregateFunctionPercentileArray::data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + AggregateFunctionPercentileArray::data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { Review Comment: warning: method 'insert_result_into' can be made static [readability-convert-member-functions-to-static] ```suggestion static void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) override { ``` ########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ########## @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
...
+    double get() const { return vec_counts[0].terminate(vec_quantile[0]); }
+
+    void insert_result_into(IColumn& to) const {

Review Comment:
   warning: method 'insert_result_into' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static void insert_result_into(IColumn& to) {
   ```

##########
be/src/util/counts.h:
##########
@@ -135,4 +139,190 @@
...
+class Counts {
+public:
+    Counts() = default;
+
+    void merge(Counts* other) {

Review Comment:
   warning: method 'merge' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static void merge(Counts* other) {
   ```

##########
be/src/vec/aggregate_functions/aggregate_function_percentile.h:
##########
@@ -0,0 +1,473 @@
...
+#pragma once
+
+#include <glog/logging.h>

Review Comment:
   warning: 'glog/logging.h' file not found [clang-diagnostic-error]
   ```cpp
   #include <glog/logging.h>
            ^
   ```

##########
be/src/vec/aggregate_functions/aggregate_function_percentile.h:
##########
@@ -0,0 +1,473 @@
...
+#include <glog/logging.h>
+#include <stddef.h>

Review Comment:
   warning: inclusion of deprecated C++ header 'stddef.h'; consider using 'cstddef' instead [modernize-deprecated-headers]
   ```suggestion
   #include <cstddef>
   ```

##########
be/src/util/counts.h:
##########
@@ -17,20 +17,24 @@
 #pragma once
 
+#include <pdqsort.h>

Review Comment:
   warning: 'pdqsort.h' file not found [clang-diagnostic-error]
   ```cpp
   #include <pdqsort.h>
            ^
   ```
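A note on the three include warnings above: the two 'file not found' diagnostics most likely mean the clang-tidy invocation is missing the third-party include paths for glog and pdqsort rather than pointing at a problem in the patch, while the modernize-deprecated-headers hits are mechanical to apply. A minimal sketch of the include block with both deprecated-header suggestions (here and the earlier stdint.h one) applied; the surrounding includes are assumed to stay as in the diff:

```cpp
#include <glog/logging.h>   // third-party; "file not found" here is usually a tidy include-path issue

#include <cstddef>          // replaces <stddef.h> per modernize-deprecated-headers
#include <cstdint>          // replaces <stdint.h> per modernize-deprecated-headers

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <vector>
```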
##########
be/src/util/counts.h:
##########
@@ -135,4 +139,190 @@
...
+    double terminate(double quantile) {

Review Comment:
   warning: method 'terminate' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static double terminate(double quantile) {
   ```
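As the hunk earlier in this review shows, Counts::terminate reads and moves the _nums and _sorted_nums_vec members, so it cannot actually be converted to a static member function; the check most likely fires because the header did not parse cleanly (see the 'file not found' errors above). A tiny standalone illustration of why member access rules this out; Example and front_or_zero are invented names for the sketch:

```cpp
#include <cstdint>
#include <vector>

// A member function that touches a data member needs `this`, so it cannot be
// declared static. The same applies to Counts::terminate, which reads and moves
// _nums / _sorted_nums_vec.
struct Example {
    std::vector<int64_t> nums; // stands in for Counts::_nums
    double front_or_zero() {   // uses a member, so `static` would not compile
        return nums.empty() ? 0.0 : static_cast<double>(nums.front());
    }
};
```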
##########
be/src/vec/aggregate_functions/aggregate_function_percentile.h:
##########
@@ -0,0 +1,473 @@
...
+    String get_name() const override { return "percentile_approx"; }
+
+    DataTypePtr get_return_type() const override {

Review Comment:
   warning: method 'get_return_type' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static DataTypePtr get_return_type() override {
   ```

##########
be/src/util/counts.h:
##########
@@ -17,20 +17,24 @@
 #pragma once
 
+#include <pdqsort.h>
+
 #include <algorithm>
 #include <cmath>
-#include <unordered_map>
-#include <vector>
+#include <queue>
 
 #include "udf/udf.h"
+#include "vec/common/pod_array.h"
+#include "vec/common/string_buffer.hpp"
+#include "vec/io/io_helper.h"
 
 namespace doris {
 
-class Counts {
+class OldCounts {
 public:
-    Counts() = default;
+    OldCounts() = default;
 
-    inline void merge(const Counts* other) {
+    inline void merge(const OldCounts* other) {

Review Comment:
   warning: method 'merge' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static inline void merge(const OldCounts* other) {
   ```
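One more caveat on fix-its of the form `static DataTypePtr get_return_type() override`: a static member function cannot be virtual, so combining `static` with `override` does not compile, and these particular suggestions cannot be applied as written. A minimal sketch; Base and Derived are invented names:

```cpp
// `static` and `override` are mutually exclusive in C++.
struct Base {
    virtual ~Base() = default;
    virtual int id() const { return 0; }
};

struct Derived : Base {
    int id() const override { return 1; } // OK: a normal virtual override
    // static int id() override;          // would not compile: static members cannot be virtual
};
```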
##########
be/src/vec/aggregate_functions/aggregate_function_percentile.h:
##########
@@ -0,0 +1,473 @@
...
+    String get_name() const override { return "percentile"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena*) const override {

Review Comment:
   warning: method 'add' can be made static [readability-convert-member-functions-to-static]
   ```suggestion
       static void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
                       Arena*) override {
   ```
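If these convert-member-functions-to-static hits are treated as false positives (the overrides above clearly need `this` and cannot be static), one common option is a per-line suppression rather than changing signatures; whether that fits the Doris clang-tidy configuration, or whether fixing the tidy include paths is preferred, is an assumption here. Sketch with an invented Widget type:

```cpp
// Silencing a single known clang-tidy finding; the comment must name the check.
struct Widget {
    // A helper that does not touch members would normally be flagged by the check;
    // NOLINT keeps the signature as-is when that is intentional.
    int twice(int x) const { return 2 * x; } // NOLINT(readability-convert-member-functions-to-static)
};
```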
[... the remainder of this 473-line hunk is identical to the context quoted above ...]
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {

Review Comment:
warning: method 'insert_result_into' can be made static [readability-convert-member-functions-to-static]

```suggestion
static void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) override {
```
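A note on this suggestion and the similar ones below for this header's `override` methods: as quoted, they combine `static` with `override`, which is not valid C++ (a static member function cannot be virtual, so it cannot override anything). The toy example below uses hypothetical `Base`/`Derived` types, not the Doris classes, purely to sketch why the suggestion cannot be applied verbatim here.

```cpp
// Minimal sketch, not Doris code: a virtual override cannot be turned into a static member.
#include <iostream>

struct Base {
    virtual void insert_result_into(int& to) const = 0;
    virtual ~Base() = default;
};

struct Derived : Base {
    // OK: a normal virtual override.
    void insert_result_into(int& to) const override { to = 42; }

    // Would not compile: 'static' member functions cannot be virtual,
    // so 'override' (and the trailing 'const') are rejected here.
    // static void insert_result_into(int& to) override { to = 42; }
};

int main() {
    Derived d;
    int v = 0;
    static_cast<const Base&>(d).insert_result_into(v);
    std::cout << v << '\n'; // prints 42 via the virtual call
}
```

For the `override` methods flagged in this file, the warning therefore looks like a false positive of readability-convert-member-functions-to-static rather than an actionable change.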
########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ##########
@@ -0,0 +1,473 @@
[... same hunk; context identical to the quote above ...]
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

Review Comment:
warning: method 'get_return_type' can be made static [readability-convert-member-functions-to-static]

```suggestion
static DataTypePtr get_return_type() override { return std::make_shared<DataTypeFloat64>(); }
```
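If the check keeps firing on these virtual members, a targeted suppression is usually preferable to applying the suggestion. The sketch below uses simplified stand-in types (`ToyDataType`, `ToyAggregateFunction`, `ToyPercentile` are hypothetical, not the Doris classes); the only piece taken from clang-tidy itself is the NOLINT comment syntax for this check.

```cpp
// Sketch of a targeted clang-tidy suppression; all types here are stand-ins, not Doris code.
#include <memory>
#include <string>

struct ToyDataType {
    virtual std::string name() const = 0;
    virtual ~ToyDataType() = default;
};

struct ToyFloat64 final : ToyDataType {
    std::string name() const override { return "Float64"; }
};

struct ToyAggregateFunction {
    virtual std::shared_ptr<ToyDataType> get_return_type() const = 0;
    virtual ~ToyAggregateFunction() = default;
};

struct ToyPercentile final : ToyAggregateFunction {
    // Keeps the override intact and silences only this check on only this line.
    // NOLINTNEXTLINE(readability-convert-member-functions-to-static)
    std::shared_ptr<ToyDataType> get_return_type() const override {
        return std::make_shared<ToyFloat64>();
    }
};

int main() {
    ToyPercentile f;
    return f.get_return_type()->name() == "Float64" ? 0 : 1;
}
```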
########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ##########
@@ -0,0 +1,473 @@
[... same hunk; context identical to the quote above ...]
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena*) const override {
+        const auto& sources = assert_cast<const ColumnVector<Int64>&>(*columns[0]);
+        const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
+        AggregateFunctionPercentile::data(place).add(sources.get_int(row_num), quantile.get_data(),
+                                                     1);
+    }
+
+    void reset(AggregateDataPtr __restrict place) const override {
+        AggregateFunctionPercentile::data(place).reset();
+    }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        AggregateFunctionPercentile::data(place).merge(AggregateFunctionPercentile::data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        AggregateFunctionPercentile::data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        AggregateFunctionPercentile::data(place).read(buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {

Review Comment:
warning: method 'insert_result_into' can be made static [readability-convert-member-functions-to-static]

```suggestion
static void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) override {
```
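Separately from the static warning, the hunk quoted in these comments introduces `PercentileState::write`/`read`, whose correctness depends on both sides agreeing on field order: `inited_flag`, then the number of quantiles, then each quantile, then one `Counts` payload per quantile. The sketch below uses a simplified in-memory buffer, not the Doris `BufferWritable`/`BufferReadable` API, purely to illustrate that symmetry.

```cpp
// Sketch with a simplified byte buffer; the real code uses BufferWritable/BufferReadable
// and also serializes one Counts object per quantile after the quantile values.
#include <cassert>
#include <cstring>
#include <vector>

struct ToyBuffer {
    std::vector<char> bytes;
    size_t pos = 0;
    template <typename T>
    void write(const T& v) {
        const char* p = reinterpret_cast<const char*>(&v);
        bytes.insert(bytes.end(), p, p + sizeof(T));
    }
    template <typename T>
    void read(T& v) {
        std::memcpy(&v, bytes.data() + pos, sizeof(T));
        pos += sizeof(T);
    }
};

struct ToyPercentileState {
    bool inited_flag = false;
    std::vector<double> vec_quantile;

    void write(ToyBuffer& buf) const {
        buf.write(inited_flag);
        int size_num = static_cast<int>(vec_quantile.size());
        buf.write(size_num);
        for (double q : vec_quantile) {
            buf.write(q);
        }
        // one Counts payload per quantile would follow here
    }

    void read(ToyBuffer& buf) {
        buf.read(inited_flag);
        int size_num = 0;
        buf.read(size_num);
        vec_quantile.assign(size_num, 0.0);
        for (double& q : vec_quantile) {
            buf.read(q);
        }
        // one Counts payload per quantile would be consumed here
    }
};

int main() {
    ToyPercentileState src;
    src.inited_flag = true;
    src.vec_quantile = {0.5, 0.99};

    ToyBuffer buf;
    src.write(buf);

    ToyPercentileState dst;
    dst.read(buf);
    assert(dst.inited_flag && dst.vec_quantile.size() == 2 && dst.vec_quantile[1] == 0.99);
    return 0;
}
```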
########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ##########
@@ -0,0 +1,473 @@
[... same hunk; context identical to the quote above ...]
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        auto& col = assert_cast<ColumnVector<Float64>&>(to);
+        col.insert_value(AggregateFunctionPercentile::data(place).get());
+    }
+};
+
+class AggregateFunctionPercentileArray final
+        : public IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray> {
+public:
+    AggregateFunctionPercentileArray(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<PercentileState, AggregateFunctionPercentileArray>(
+                      argument_types_) {}
+
+    String get_name() const override { return "percentile_array"; }
+
+    DataTypePtr get_return_type() const override {

Review Comment:
warning: method 'get_return_type' can be made static [readability-convert-member-functions-to-static]

```suggestion
static DataTypePtr get_return_type() override {
```
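The `*TwoParams`/`*ThreeParams` overloads quoted earlier in this hunk all follow the same null-handling branch structure: a column is either a plain `ColumnVector`, a `ColumnNullable` wrapping a non-null value, or a `ColumnNullable` holding NULL (in the quoted code, only a NULL in the first argument makes the row contribute nothing). The sketch below uses hypothetical stand-in column types, not the Doris ones, and simply reports NULL to the caller rather than special-casing the argument position.

```cpp
// Stand-in column types; mirrors the not-nullable / nullable-non-null / nullable-null branches.
#include <optional>
#include <vector>

struct ToyColumnFloat64 {
    std::vector<double> data;
    double get_float64(size_t row) const { return data[row]; }
};

struct ToyColumnNullable {
    ToyColumnFloat64 nested;
    std::vector<bool> null_map;
    bool is_null_at(size_t row) const { return null_map[row]; }
    const ToyColumnFloat64& get_nested_column() const { return nested; }
};

// Returns the value for this row, or std::nullopt when the wrapped value is NULL.
std::optional<double> extract(const ToyColumnFloat64* plain, const ToyColumnNullable* nullable,
                              size_t row) {
    if (nullable == nullptr) { // not a nullable column
        return plain->get_float64(row);
    }
    if (!nullable->is_null_at(row)) { // nullable column, non-null value
        return nullable->get_nested_column().get_float64(row);
    }
    return std::nullopt; // nullable column, NULL value
}

int main() {
    ToyColumnNullable col {{{1.0, 2.0}}, {false, true}};
    std::optional<double> first = extract(nullptr, &col, 0);  // 1.0
    std::optional<double> second = extract(nullptr, &col, 1); // NULL -> empty optional
    return (first.value_or(0) == 1.0 && !second.has_value()) ? 0 : 1;
}
```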
########## be/src/vec/aggregate_functions/aggregate_function_percentile.h: ##########
@@ -0,0 +1,473 @@
[... same hunk; context identical to the quote above ...]
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena*) const override {

Review Comment:
warning: method 'add' can be made static [readability-convert-member-functions-to-static]

```suggestion
static void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) override {
```

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
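One last note on the hunk itself: the `init()` comment quoted several times above describes a fallback for percentile_approx's compression argument, where anything outside [2048, 10000] falls back to 10000. The sketch below is a self-contained restatement of just that rule (plain doubles, no TDigest involved); the constant names are assumptions, not Doris identifiers.

```cpp
// Sketch of the documented compression fallback; constants mirror the quoted comment.
#include <cassert>

constexpr double kMinCompression = 2048;
constexpr double kMaxCompression = 10000;
constexpr double kDefaultCompression = 10000;

double normalize_compression(double compression) {
    if (compression < kMinCompression || compression > kMaxCompression) {
        return kDefaultCompression; // out of range (or unspecified) -> default
    }
    return compression;
}

int main() {
    assert(normalize_compression(100) == kDefaultCompression);   // below the range
    assert(normalize_compression(4096) == 4096);                  // inside the range
    assert(normalize_compression(20000) == kDefaultCompression);  // above the range
    return 0;
}
```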