This is an automated email from the ASF dual-hosted git repository. gabriellee pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ee4196d9d23 [Improvement](agg) Improve count distinct distribute keys (#33167) ee4196d9d23 is described below commit ee4196d9d23c252fc35a287c04adf9d705e7637f Author: Gabriel <gabrielleeb...@gmail.com> AuthorDate: Fri Apr 26 18:31:11 2024 +0800 [Improvement](agg) Improve count distinct distribute keys (#33167) --- .../aggregate_function_simple_factory.cpp | 2 + .../aggregate_functions/aggregate_function_uniq.h | 2 +- .../aggregate_function_uniq_distribute_key.cpp | 73 ++++++ .../aggregate_function_uniq_distribute_key.h | 253 +++++++++++++++++++++ 4 files changed, 329 insertions(+), 1 deletion(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 00597b212be..d95d0ce6ccb 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -40,6 +40,7 @@ void register_aggregate_function_count(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_count_by_enum(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_quantile_state(AggregateFunctionSimpleFactory& factory); @@ -80,6 +81,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_count(instance); register_aggregate_function_count_by_enum(instance); register_aggregate_function_uniq(instance); + register_aggregate_function_uniq_distribute_key(instance); register_aggregate_function_bit(instance); register_aggregate_function_bitmap(instance); register_aggregate_function_group_array_intersect(instance); diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h b/be/src/vec/aggregate_functions/aggregate_function_uniq.h index 2e8855134eb..58abd3842c2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h @@ -75,7 +75,7 @@ struct AggregateFunctionUniqExactData { Set set; - static String get_name() { return "uniqExact"; } + static String get_name() { return "multi_distinct"; } }; namespace detail { diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp new file mode 100644 index 00000000000..3bf979483b5 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/aggregate_functions/aggregate_function_uniq_distribute_key.h" + +#include <string> + +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" + +namespace doris::vectorized { + +template <template <typename> class Data> +AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + if (argument_types.size() == 1) { + const IDataType& argument_type = *remove_nullable(argument_types[0]); + WhichDataType which(argument_type); + + AggregateFunctionPtr res( + creator_with_numeric_type::create<AggregateFunctionUniqDistributeKey, Data>( + argument_types, result_is_nullable)); + if (res) { + return res; + } else if (which.is_decimal32()) { + return creator_without_type::create< + AggregateFunctionUniqDistributeKey<Decimal32, Data<Int32>>>(argument_types, + result_is_nullable); + } else if (which.is_decimal64()) { + return creator_without_type::create< + AggregateFunctionUniqDistributeKey<Decimal64, Data<Int64>>>(argument_types, + result_is_nullable); + } else if (which.is_decimal128v3()) { + return creator_without_type::create< + AggregateFunctionUniqDistributeKey<Decimal128V3, Data<Int128>>>( + argument_types, result_is_nullable); + } else if (which.is_decimal128v2() || which.is_decimal128v3()) { + return creator_without_type::create< + AggregateFunctionUniqDistributeKey<Decimal128V2, Data<Int128>>>( + argument_types, result_is_nullable); + } else if (which.is_string_or_fixed_string()) { + return creator_without_type::create< + AggregateFunctionUniqDistributeKey<String, Data<String>>>(argument_types, + result_is_nullable); + } + } + + return nullptr; +} + +void register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory& factory) { + AggregateFunctionCreator creator = + create_aggregate_function_uniq<AggregateFunctionUniqDistributeKeyData>; + factory.register_function_both("multi_distinct_count_distribute_key", creator); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h new file mode 100644 index 00000000000..0fa66e34230 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionUniq.h +// and modified by Doris + +#pragma once + +#include <stddef.h> + +#include <algorithm> +#include <boost/iterator/iterator_facade.hpp> +#include <memory> +#include <vector> + +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_uniq.h" +#include "vec/columns/column.h" +#include "vec/columns/column_fixed_length_object.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_fixed_length_object.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/var_int.h" + +namespace doris { +namespace vectorized { +class Arena; +class BufferReadable; +class BufferWritable; +} // namespace vectorized +} // namespace doris +template <typename T> +struct HashCRC32; +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionUniqDistributeKeyData { + static constexpr bool is_string_key = std::is_same_v<T, String>; + using Key = std::conditional_t<is_string_key, UInt128, T>; + using Hash = std::conditional_t<is_string_key, UInt128TrivialHash, HashCRC32<Key>>; + + using Set = flat_hash_set<Key, Hash>; + + // TODO: replace SipHash with xxhash to speed up + static UInt128 ALWAYS_INLINE get_key(const StringRef& value) { + auto hash_value = XXH_INLINE_XXH128(value.data, value.size, 0); + return UInt128 {hash_value.high64, hash_value.low64}; + } + + Set set; + UInt64 count = 0; +}; + +template <typename T, typename Data> +class AggregateFunctionUniqDistributeKey final + : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>> { +public: + using KeyType = std::conditional_t<std::is_same_v<T, String>, UInt128, T>; + AggregateFunctionUniqDistributeKey(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>>( + argument_types_) {} + + String get_name() const override { return "multi_distinct_distribute_key"; } + + DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num); + } + + static ALWAYS_INLINE const KeyType* get_keys(std::vector<KeyType>& keys_container, + const IColumn& column, size_t batch_size) { + if constexpr (std::is_same_v<T, String>) { + keys_container.resize(batch_size); + for (size_t i = 0; i != batch_size; ++i) { + StringRef value = column.get_data_at(i); + keys_container[i] = Data::get_key(value); + } + return keys_container.data(); + } else { + using ColumnType = + std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; + return assert_cast<const ColumnType&>(column).get_data().data(); + } + } + + void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, + const IColumn** columns, Arena* arena, bool /*agg_many*/) const override { + std::vector<KeyType> keys_container; + const KeyType* keys = get_keys(keys_container, *columns[0], batch_size); + + std::vector<typename Data::Set*> array_of_data_set(batch_size); + + for (size_t i = 0; i != batch_size; ++i) { + array_of_data_set[i] = &(this->data(places[i] + place_offset).set); + } + + for (size_t i = 0; i != batch_size; ++i) { + if (i + HASH_MAP_PREFETCH_DIST < batch_size) { + array_of_data_set[i + HASH_MAP_PREFETCH_DIST]->prefetch( + keys[i + HASH_MAP_PREFETCH_DIST]); + } + + array_of_data_set[i]->insert(keys[i]); + } + } + + void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + std::vector<KeyType> keys_container; + const KeyType* keys = get_keys(keys_container, *columns[0], batch_size); + auto& set = this->data(place).set; + + for (size_t i = 0; i != batch_size; ++i) { + if (i + HASH_MAP_PREFETCH_DIST < batch_size) { + set.prefetch(keys[i + HASH_MAP_PREFETCH_DIST]); + } + set.insert(keys[i]); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).count += this->data(rhs).count; + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + write_var_uint(this->data(place).set.size(), buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + read_var_uint(this->data(place).count, buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + assert_cast<ColumnInt64&>(to).get_data().push_back(this->data(place).count); + } + + void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena, + size_t num_rows) const override { + auto data = reinterpret_cast<const UInt64*>( + assert_cast<const ColumnFixedLengthObject&>(column).get_data().data()); + for (size_t i = 0; i != num_rows; ++i) { + auto rhs_place = places + sizeof(Data) * i; + this->create(rhs_place); + (reinterpret_cast<Data*>(rhs_place))->count = data[i]; + } + } + + void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset, + MutableColumnPtr& dst, const size_t num_rows) const override { + auto& col = assert_cast<ColumnFixedLengthObject&>(*dst); + CHECK(col.item_size() == sizeof(UInt64)) + << "size is not equal: " << col.item_size() << " " << sizeof(UInt64); + col.resize(num_rows); + auto* data = reinterpret_cast<UInt64*>(col.get_data().data()); + for (size_t i = 0; i != num_rows; ++i) { + data[i] = this->data(places[i] + offset).set.size(); + } + } + + void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, + const size_t num_rows, Arena* arena) const override { + auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst); + CHECK(dst_col.item_size() == sizeof(UInt64)) + << "size is not equal: " << dst_col.item_size() << " " << sizeof(UInt64); + dst_col.resize(num_rows); + auto* data = reinterpret_cast<UInt64*>(dst_col.get_data().data()); + for (size_t i = 0; i != num_rows; ++i) { + data[i] = 1; + } + } + + void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, + Arena* arena) const override { + auto& col = assert_cast<const ColumnFixedLengthObject&>(column); + const size_t num_rows = column.size(); + auto* data = reinterpret_cast<const UInt64*>(col.get_data().data()); + for (size_t i = 0; i != num_rows; ++i) { + AggregateFunctionUniqDistributeKey::data(place).count += data[i]; + } + } + + void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, + const IColumn& column, size_t begin, size_t end, + Arena* arena) const override { + CHECK(end <= column.size() && begin <= end) + << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size(); + auto& col = assert_cast<const ColumnFixedLengthObject&>(column); + auto* data = reinterpret_cast<const UInt64*>(col.get_data().data()); + for (size_t i = begin; i <= end; ++i) { + this->data(place).count += data[i]; + } + } + + void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset, + AggregateDataPtr rhs, const ColumnString* column, Arena* arena, + const size_t num_rows) const override { + this->deserialize_from_column(rhs, *column, arena, num_rows); + DEFER({ this->destroy_vec(rhs, num_rows); }); + this->merge_vec(places, offset, rhs, arena, num_rows); + } + + void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset, + AggregateDataPtr rhs, const ColumnString* column, + Arena* arena, const size_t num_rows) const override { + this->deserialize_from_column(rhs, *column, arena, num_rows); + DEFER({ this->destroy_vec(rhs, num_rows); }); + this->merge_vec_selected(places, offset, rhs, arena, num_rows); + } + + void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, + IColumn& to) const override { + auto& col = assert_cast<ColumnFixedLengthObject&>(to); + CHECK(col.item_size() == sizeof(UInt64)) + << "size is not equal: " << col.item_size() << " " << sizeof(UInt64); + size_t old_size = col.size(); + col.resize(old_size + 1); + *reinterpret_cast<UInt64*>(col.get_data().data() + old_size) = + AggregateFunctionUniqDistributeKey::data(place).set.size(); + } + + MutableColumnPtr create_serialize_column() const override { + return ColumnFixedLengthObject::create(sizeof(UInt64)); + } + + DataTypePtr get_serialized_type() const override { + return std::make_shared<DataTypeFixedLengthObject>(); + } +}; + +} // namespace doris::vectorized --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org