github-actions[bot] commented on code in PR #33265: URL: https://github.com/apache/doris/pull/33265#discussion_r1552732592
########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp: ########## @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" + +namespace doris::vectorized { + +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, + const DataTypes& argument_types) { + WhichDataType which(nested_type); + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) Review Comment: warning: statement should be inside braces [readability-braces-around-statements] ```suggestion if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) { ``` be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp:29: ```diff - else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) + } else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) ``` ########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h: ########## @@ -0,0 +1,531 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include <cassert> +#include <memory> + +#include "exprs/hybrid_set.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_array.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" +#include "vec/io/io_helper.h" +#include "vec/io/var_int.h" + +namespace doris { +namespace vectorized { +class Arena; +class BufferReadable; +class BufferWritable; +} // namespace vectorized +} // namespace doris + +namespace doris::vectorized { + +/// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so that to inherit HybridSet +template <typename T> +constexpr PrimitiveType TypeToPrimitiveType() { + if constexpr (std::is_same_v<T, UInt8> || std::is_same_v<T, Int8>) { + return TYPE_TINYINT; + } else if constexpr (std::is_same_v<T, Int16>) { + return TYPE_SMALLINT; + } else if constexpr (std::is_same_v<T, Int32>) { + return TYPE_INT; + } else if constexpr (std::is_same_v<T, Int64>) { + return TYPE_BIGINT; + } else if constexpr (std::is_same_v<T, Int128>) { + return TYPE_LARGEINT; + } else if constexpr (std::is_same_v<T, Float32>) { + return TYPE_FLOAT; + } else if constexpr (std::is_same_v<T, Float64>) { + return TYPE_DOUBLE; + } else if constexpr (std::is_same_v<T, DateV2>) { + return TYPE_DATEV2; + } else if constexpr (std::is_same_v<T, DateTimeV2>) { + return TYPE_DATETIMEV2; + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); + } +} + +template <typename T> +class NullableNumericOrDateSet + : public HybridSet<TypeToPrimitiveType<T>(), DynamicContainer<typename PrimitiveTypeTraits< + TypeToPrimitiveType<T>()>::CppType>> { +public: + NullableNumericOrDateSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + +template <typename T> +struct AggregateFunctionGroupArrayIntersectData { + using NullableNumericOrDateSetType = NullableNumericOrDateSet<T>; + using Set = std::shared_ptr<NullableNumericOrDateSetType>; + + AggregateFunctionGroupArrayIntersectData() + : value(std::make_shared<NullableNumericOrDateSetType>()) {} + + Set value; + UInt64 version = 0; +}; + +/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. +template <typename T> +class AggregateFunctionGroupArrayIntersect + : public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>> { +private: + using State = AggregateFunctionGroupArrayIntersectData<T>; + DataTypePtr argument_type; + +public: + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_, + const bool result_is_nullable) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + DataTypePtr get_return_type() const override { return argument_type; } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + auto& data = this->data(place); + auto& version = data.version; + auto& set = data.value; + + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? assert_cast<const ColumnArray&>( + assert_cast<const ColumnNullable&>(*columns[0]) + .get_nested_column()) + : assert_cast<const ColumnArray&>(*columns[0]); + + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast<ssize_t>(row_num) - 1]; + const auto arr_size = offsets[row_num] - offset; + + using ColVecType = ColumnVector<T>; + const auto& column_data = column.get_data(); + + bool is_column_data_nullable = column_data.is_nullable(); + ColumnNullable* col_null = nullptr; + const ColVecType* nested_column_data = nullptr; + + if (is_column_data_nullable) { + auto const_col_data = const_cast<IColumn*>(&column_data); + col_null = static_cast<ColumnNullable*>(const_col_data); + nested_column_data = &assert_cast<const ColVecType&>(col_null->get_nested_column()); + } else { + nested_column_data = &static_cast<const ColVecType&>(column_data); + } + + ++version; + if (version == 1) { + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + set->insert(src_data); + } + } else if (set->size() != 0 || set->contain_null()) { + typename State::Set new_set = + std::make_shared<typename State::NullableNumericOrDateSetType>(); + + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { + new_set->insert(src_data); + } + } + set = std::move(new_set); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + auto& data = this->data(place); + auto& set = data.value; + auto& rhs_set = this->data(rhs).value; + + if (this->data(rhs).version == 0) { + return; + } + + UInt64 version = data.version++; + if (version == 0) { + set->change_contains_null_value(rhs_set->contain_null()); + HybridSetBase::IteratorBase* it = rhs_set->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + set->insert(value); + it->next(); + } + return; + } + + if (set->size() != 0) { + auto create_new_set = [](auto& lhs_val, auto& rhs_val) { + typename State::Set new_set = + std::make_shared<typename State::NullableNumericOrDateSetType>(); + HybridSetBase::IteratorBase* it = lhs_val->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + if ((rhs_val->find(value))) { + new_set->insert(value); + } + it->next(); + } + new_set->change_contains_null_value(lhs_val->contain_null() && + rhs_val->contain_null()); + return new_set; + }; + auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) + : create_new_set(set, rhs_set); + set = std::move(new_set); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + auto& data = this->data(place); + auto& set = data.value; + auto version = data.version; + + bool is_set_contains_null = set->contain_null(); + + write_pod_binary(is_set_contains_null, buf); + + write_var_uint(version, buf); + write_var_uint(set->size(), buf); + HybridSetBase::IteratorBase* it = set->begin(); + + while (it->has_next()) { + const T* value_ptr = static_cast<const T*>(it->get_value()); + write_int_binary((*value_ptr), buf); + it->next(); + } + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + auto& data = this->data(place); + bool is_set_contains_null; + + read_pod_binary(is_set_contains_null, buf); + data.value->change_contains_null_value(is_set_contains_null); + read_var_uint(data.version, buf); + size_t size; + read_var_uint(size, buf); + + T element; + for (size_t i = 0; i < size; ++i) { + read_int_binary(element, buf); + data.value->insert(static_cast<void*>(&element)); + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + ColumnArray& arr_to = assert_cast<ColumnArray&>(to); + ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); + + auto& to_nested_col = arr_to.get_data(); + using ElementType = T; + using ColVecType = ColumnVector<ElementType>; + + bool is_nullable = to_nested_col.is_nullable(); + + auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false, + ColumnNullable* col_null = nullptr) { + size_t old_size = nested_col.get_data().size(); + size_t res_size = set->size(); + size_t i = 0; + + if (is_nullable && set->contain_null()) { + col_null->insert_data(nullptr, 0); + res_size += 1; + i = 1; + } + + nested_col.get_data().resize(old_size + res_size); + + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + ElementType value = *reinterpret_cast<const ElementType*>(it->get_value()); + nested_col.get_data()[old_size + i] = value; + if (is_nullable) { + col_null->get_null_map_data().push_back(0); + } + it->next(); + ++i; + } + }; + + const auto& set = this->data(place).value; + if (is_nullable) { + auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col); + auto& nested_col = assert_cast<ColVecType&>(col_null->get_nested_column()); + offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0)); + insert_values(nested_col, set, true, col_null); + } else { + auto& nested_col = static_cast<ColVecType&>(to_nested_col); + offsets_to.push_back(offsets_to.back() + set->size()); + insert_values(nested_col, set); + } + } +}; + +/// Generic implementation, it uses serialized representation as object descriptor. +class NullableStringSet : public StringValueSet<DynamicContainer<StringRef>> { +public: + NullableStringSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + +struct AggregateFunctionGroupArrayIntersectGenericData { + using Set = std::shared_ptr<NullableStringSet>; + + AggregateFunctionGroupArrayIntersectGenericData() Review Comment: warning: use '= default' to define a trivial default constructor [modernize-use-equals-default] be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h:331: ```diff - : value(std::make_shared<NullableStringSet>()) {} + : value(std::make_shared<NullableStringSet>()) = default; ``` ########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp: ########## @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" + +namespace doris::vectorized { + +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, + const DataTypes& argument_types) { + WhichDataType which(nested_type); + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) + return new AggregateFunctionGroupArrayIntersect<DateV2>(argument_types); + else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) + return new AggregateFunctionGroupArrayIntersect<DateTimeV2>(argument_types); + else { + /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric + if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) + return new AggregateFunctionGroupArrayIntersectGeneric<true>(argument_types); + else + return new AggregateFunctionGroupArrayIntersectGeneric<false>(argument_types); + } +} + +inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const auto& nested_type = remove_nullable( + dynamic_cast<const DataTypeArray&>(*(argument_types[0])).get_nested_type()); + AggregateFunctionPtr res = nullptr; + + WhichDataType which(nested_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res = creator_without_type::create<AggregateFunctionGroupArrayIntersect<TYPE>>( \ + argument_types, result_is_nullable); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + + if (!res) { + res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); + } + + if (!res) Review Comment: warning: statement should be inside braces [readability-braces-around-statements] ```suggestion if (!res) { ``` be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp:61: ```diff - argument_types[0]->get_name(), name); + argument_types[0]->get_name(), name); + } ``` ########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp: ########## @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" + +namespace doris::vectorized { + +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, + const DataTypes& argument_types) { + WhichDataType which(nested_type); + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) + return new AggregateFunctionGroupArrayIntersect<DateV2>(argument_types); + else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) Review Comment: warning: statement should be inside braces [readability-braces-around-statements] ```suggestion else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) { ``` be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp:31: ```diff - else { + } else { ``` ########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h: ########## @@ -0,0 +1,531 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include <cassert> +#include <memory> + +#include "exprs/hybrid_set.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_array.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" +#include "vec/io/io_helper.h" +#include "vec/io/var_int.h" + +namespace doris { +namespace vectorized { Review Comment: warning: nested namespaces can be concatenated [modernize-concat-nested-namespaces] ```suggestion namespace doris::vectorized { ``` be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h:43: ```diff - } // namespace vectorized - } // namespace doris + } // namespace doris ``` ########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h: ########## @@ -0,0 +1,531 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include <cassert> +#include <memory> + +#include "exprs/hybrid_set.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_array.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" +#include "vec/io/io_helper.h" +#include "vec/io/var_int.h" + +namespace doris { +namespace vectorized { +class Arena; +class BufferReadable; +class BufferWritable; +} // namespace vectorized +} // namespace doris + +namespace doris::vectorized { + +/// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so that to inherit HybridSet +template <typename T> +constexpr PrimitiveType TypeToPrimitiveType() { + if constexpr (std::is_same_v<T, UInt8> || std::is_same_v<T, Int8>) { + return TYPE_TINYINT; + } else if constexpr (std::is_same_v<T, Int16>) { + return TYPE_SMALLINT; + } else if constexpr (std::is_same_v<T, Int32>) { + return TYPE_INT; + } else if constexpr (std::is_same_v<T, Int64>) { + return TYPE_BIGINT; + } else if constexpr (std::is_same_v<T, Int128>) { + return TYPE_LARGEINT; + } else if constexpr (std::is_same_v<T, Float32>) { + return TYPE_FLOAT; + } else if constexpr (std::is_same_v<T, Float64>) { + return TYPE_DOUBLE; + } else if constexpr (std::is_same_v<T, DateV2>) { + return TYPE_DATEV2; + } else if constexpr (std::is_same_v<T, DateTimeV2>) { + return TYPE_DATETIMEV2; + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); + } +} + +template <typename T> +class NullableNumericOrDateSet + : public HybridSet<TypeToPrimitiveType<T>(), DynamicContainer<typename PrimitiveTypeTraits< + TypeToPrimitiveType<T>()>::CppType>> { +public: + NullableNumericOrDateSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + +template <typename T> +struct AggregateFunctionGroupArrayIntersectData { + using NullableNumericOrDateSetType = NullableNumericOrDateSet<T>; + using Set = std::shared_ptr<NullableNumericOrDateSetType>; + + AggregateFunctionGroupArrayIntersectData() + : value(std::make_shared<NullableNumericOrDateSetType>()) {} + + Set value; + UInt64 version = 0; +}; + +/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. +template <typename T> +class AggregateFunctionGroupArrayIntersect + : public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>> { +private: + using State = AggregateFunctionGroupArrayIntersectData<T>; + DataTypePtr argument_type; + +public: + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_, + const bool result_is_nullable) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + DataTypePtr get_return_type() const override { return argument_type; } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + auto& data = this->data(place); + auto& version = data.version; + auto& set = data.value; + + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? assert_cast<const ColumnArray&>( + assert_cast<const ColumnNullable&>(*columns[0]) + .get_nested_column()) + : assert_cast<const ColumnArray&>(*columns[0]); + + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast<ssize_t>(row_num) - 1]; + const auto arr_size = offsets[row_num] - offset; + + using ColVecType = ColumnVector<T>; + const auto& column_data = column.get_data(); + + bool is_column_data_nullable = column_data.is_nullable(); + ColumnNullable* col_null = nullptr; + const ColVecType* nested_column_data = nullptr; + + if (is_column_data_nullable) { + auto const_col_data = const_cast<IColumn*>(&column_data); + col_null = static_cast<ColumnNullable*>(const_col_data); + nested_column_data = &assert_cast<const ColVecType&>(col_null->get_nested_column()); + } else { + nested_column_data = &static_cast<const ColVecType&>(column_data); + } + + ++version; + if (version == 1) { + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + set->insert(src_data); + } + } else if (set->size() != 0 || set->contain_null()) { + typename State::Set new_set = + std::make_shared<typename State::NullableNumericOrDateSetType>(); + + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { + new_set->insert(src_data); + } + } + set = std::move(new_set); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + auto& data = this->data(place); + auto& set = data.value; + auto& rhs_set = this->data(rhs).value; + + if (this->data(rhs).version == 0) { + return; + } + + UInt64 version = data.version++; + if (version == 0) { + set->change_contains_null_value(rhs_set->contain_null()); + HybridSetBase::IteratorBase* it = rhs_set->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + set->insert(value); + it->next(); + } + return; + } + + if (set->size() != 0) { + auto create_new_set = [](auto& lhs_val, auto& rhs_val) { + typename State::Set new_set = + std::make_shared<typename State::NullableNumericOrDateSetType>(); + HybridSetBase::IteratorBase* it = lhs_val->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + if ((rhs_val->find(value))) { + new_set->insert(value); + } + it->next(); + } + new_set->change_contains_null_value(lhs_val->contain_null() && + rhs_val->contain_null()); + return new_set; + }; + auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) + : create_new_set(set, rhs_set); + set = std::move(new_set); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + auto& data = this->data(place); + auto& set = data.value; + auto version = data.version; + + bool is_set_contains_null = set->contain_null(); + + write_pod_binary(is_set_contains_null, buf); + + write_var_uint(version, buf); + write_var_uint(set->size(), buf); + HybridSetBase::IteratorBase* it = set->begin(); + + while (it->has_next()) { + const T* value_ptr = static_cast<const T*>(it->get_value()); + write_int_binary((*value_ptr), buf); + it->next(); + } + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + auto& data = this->data(place); + bool is_set_contains_null; + + read_pod_binary(is_set_contains_null, buf); + data.value->change_contains_null_value(is_set_contains_null); + read_var_uint(data.version, buf); + size_t size; + read_var_uint(size, buf); + + T element; + for (size_t i = 0; i < size; ++i) { + read_int_binary(element, buf); + data.value->insert(static_cast<void*>(&element)); + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + ColumnArray& arr_to = assert_cast<ColumnArray&>(to); + ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); + + auto& to_nested_col = arr_to.get_data(); + using ElementType = T; + using ColVecType = ColumnVector<ElementType>; + + bool is_nullable = to_nested_col.is_nullable(); + + auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false, + ColumnNullable* col_null = nullptr) { + size_t old_size = nested_col.get_data().size(); + size_t res_size = set->size(); + size_t i = 0; + + if (is_nullable && set->contain_null()) { + col_null->insert_data(nullptr, 0); + res_size += 1; + i = 1; + } + + nested_col.get_data().resize(old_size + res_size); + + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + ElementType value = *reinterpret_cast<const ElementType*>(it->get_value()); + nested_col.get_data()[old_size + i] = value; + if (is_nullable) { + col_null->get_null_map_data().push_back(0); + } + it->next(); + ++i; + } + }; + + const auto& set = this->data(place).value; + if (is_nullable) { + auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col); + auto& nested_col = assert_cast<ColVecType&>(col_null->get_nested_column()); + offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0)); + insert_values(nested_col, set, true, col_null); + } else { + auto& nested_col = static_cast<ColVecType&>(to_nested_col); + offsets_to.push_back(offsets_to.back() + set->size()); + insert_values(nested_col, set); + } + } +}; + +/// Generic implementation, it uses serialized representation as object descriptor. +class NullableStringSet : public StringValueSet<DynamicContainer<StringRef>> { +public: + NullableStringSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + +struct AggregateFunctionGroupArrayIntersectGenericData { + using Set = std::shared_ptr<NullableStringSet>; + + AggregateFunctionGroupArrayIntersectGenericData() + : value(std::make_shared<NullableStringSet>()) {} + Set value; + UInt64 version = 0; +}; + +/** Template parameter with true value should be used for columns that store their elements in memory continuously. + * For such columns group_array_intersect() can be implemented more efficiently (especially for small numeric arrays). + */ +template <bool is_plain_column = false> +class AggregateFunctionGroupArrayIntersectGeneric + : public IAggregateFunctionDataHelper< + AggregateFunctionGroupArrayIntersectGenericData, + AggregateFunctionGroupArrayIntersectGeneric<is_plain_column>> { +private: + using State = AggregateFunctionGroupArrayIntersectGenericData; + DataTypePtr input_data_type; + +public: + AggregateFunctionGroupArrayIntersectGeneric(const DataTypes& input_data_type_) + : IAggregateFunctionDataHelper< + AggregateFunctionGroupArrayIntersectGenericData, + AggregateFunctionGroupArrayIntersectGeneric<is_plain_column>>( + input_data_type_), + input_data_type(input_data_type_[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + DataTypePtr get_return_type() const override { return input_data_type; } + + bool allocates_memory_in_arena() const override { return true; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena* arena) const override { + auto& data = this->data(place); + auto& version = data.version; + auto& set = data.value; + + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? assert_cast<const ColumnArray&>( + assert_cast<const ColumnNullable&>(*columns[0]) + .get_nested_column()) + : assert_cast<const ColumnArray&>(*columns[0]); + + const auto nested_column_data = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast<ssize_t>(row_num) - 1]; + const auto arr_size = offsets[row_num] - offset; + const auto& column_data = column.get_data(); + bool is_column_data_nullable = column_data.is_nullable(); + ColumnNullable* col_null = nullptr; + + if (is_column_data_nullable) { + auto const_col_data = const_cast<IColumn*>(&column_data); + col_null = static_cast<ColumnNullable*>(const_col_data); + } + + auto process_element = [&](size_t i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + + StringRef src = StringRef(); + if constexpr (is_plain_column) { + src = nested_column_data->get_data_at(offset + i); + } else { + const char* begin = nullptr; + src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); + } + + src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); + return src; + }; + + ++version; + if (version == 1) { + for (size_t i = 0; i < arr_size; ++i) { + StringRef src = process_element(i); + set->insert((void*)src.data, src.size); + } + } else if (set->size() != 0 || set->contain_null()) { + typename State::Set new_set = std::make_shared<NullableStringSet>(); + + for (size_t i = 0; i < arr_size; ++i) { + StringRef src = process_element(i); + if (set->find(src.data, src.size) || (set->contain_null() && src.data == nullptr)) { + new_set->insert((void*)src.data, src.size); + } + } + set = std::move(new_set); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + auto& data = this->data(place); + auto& set = data.value; + auto& rhs_set = this->data(rhs).value; + + if (this->data(rhs).version == 0) { + return; + } + + UInt64 version = data.version++; + if (version == 0) { + set->change_contains_null_value(rhs_set->contain_null()); + HybridSetBase::IteratorBase* it = rhs_set->begin(); + while (it->has_next()) { + const StringRef* value = reinterpret_cast<const StringRef*>(it->get_value()); Review Comment: warning: use auto when initializing with a cast to avoid duplicating the type name [modernize-use-auto] ```suggestion const auto* value = reinterpret_cast<const StringRef*>(it->get_value()); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org