BiteTheDDDDt commented on code in PR #33265: URL: https://github.com/apache/doris/pull/33265#discussion_r1559166531
########## be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h: ########## @@ -0,0 +1,526 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include <cassert> +#include <memory> + +#include "exprs/hybrid_set.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_array.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" +#include "vec/io/io_helper.h" +#include "vec/io/var_int.h" + +namespace doris::vectorized { +class Arena; +class BufferReadable; +class BufferWritable; +} // namespace doris::vectorized + +namespace doris::vectorized { + +/// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so 
that to inherit HybridSet +template <typename T> +constexpr PrimitiveType TypeToPrimitiveType() { + if constexpr (std::is_same_v<T, UInt8> || std::is_same_v<T, Int8>) { + return TYPE_TINYINT; + } else if constexpr (std::is_same_v<T, Int16>) { + return TYPE_SMALLINT; + } else if constexpr (std::is_same_v<T, Int32>) { + return TYPE_INT; + } else if constexpr (std::is_same_v<T, Int64>) { + return TYPE_BIGINT; + } else if constexpr (std::is_same_v<T, Int128>) { + return TYPE_LARGEINT; + } else if constexpr (std::is_same_v<T, Float32>) { + return TYPE_FLOAT; + } else if constexpr (std::is_same_v<T, Float64>) { + return TYPE_DOUBLE; + } else if constexpr (std::is_same_v<T, DateV2>) { + return TYPE_DATEV2; + } else if constexpr (std::is_same_v<T, DateTimeV2>) { + return TYPE_DATETIMEV2; + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); + } +} + +template <typename T> +class NullableNumericOrDateSet + : public HybridSet<TypeToPrimitiveType<T>(), DynamicContainer<typename PrimitiveTypeTraits< + TypeToPrimitiveType<T>()>::CppType>> { +public: + NullableNumericOrDateSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + +template <typename T> +struct AggregateFunctionGroupArrayIntersectData { + using NullableNumericOrDateSetType = NullableNumericOrDateSet<T>; + using Set = std::unique_ptr<NullableNumericOrDateSetType>; + + AggregateFunctionGroupArrayIntersectData() + : value(std::make_unique<NullableNumericOrDateSetType>()) {} + + Set value; + bool init = false; +}; + +/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. 
+template <typename T> +class AggregateFunctionGroupArrayIntersect + : public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>> { +private: + using State = AggregateFunctionGroupArrayIntersectData<T>; + DataTypePtr argument_type; + +public: + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_, + const bool result_is_nullable) + : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>, + AggregateFunctionGroupArrayIntersect<T>>( + argument_types_), + argument_type(argument_types_[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + DataTypePtr get_return_type() const override { return argument_type; } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + auto& data = this->data(place); + auto& init = data.init; + auto& set = data.value; + + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? 
assert_cast<const ColumnArray&>( + assert_cast<const ColumnNullable&>(*columns[0]) + .get_nested_column()) + : assert_cast<const ColumnArray&>(*columns[0]); + + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast<ssize_t>(row_num) - 1]; + const auto arr_size = offsets[row_num] - offset; + + using ColVecType = ColumnVector<T>; + const auto& column_data = column.get_data(); + + const bool is_column_data_nullable = column_data.is_nullable(); + const ColumnNullable* col_null = nullptr; + const ColVecType* nested_column_data = nullptr; + + if (is_column_data_nullable) { + auto const_col_data = const_cast<IColumn*>(&column_data); + col_null = static_cast<ColumnNullable*>(const_col_data); + nested_column_data = &assert_cast<const ColVecType&>(col_null->get_nested_column()); + } else { + nested_column_data = &static_cast<const ColVecType&>(column_data); + } + Review Comment: This method is too complicated. Maybe some logic can be moved into the member function of AggregateFunctionGroupArrayIntersectData? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org