github-actions[bot] commented on code in PR #33265:
URL: https://github.com/apache/doris/pull/33265#discussion_r1552806789


##########
be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h:
##########
@@ -0,0 +1,529 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp
+// and modified by Doris
+
+#include <cassert>
+#include <memory>
+
+#include "exprs/hybrid_set.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_array.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+#include "vec/data_types/data_type_time_v2.h"
+#include "vec/io/io_helper.h"
+#include "vec/io/var_int.h"
+
+namespace doris::vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+} // namespace doris::vectorized
+
+namespace doris::vectorized {
+
+/// Maps a numeric or Date(DateTime)V2 element type T to the corresponding PrimitiveType,
+/// so that the result can be used as the type parameter of HybridSet.
+///
+/// UInt8 and Int8 both map to TYPE_TINYINT. Any type not listed below reaches the final
+/// branch, which throws Exception with ErrorCode::INVALID_ARGUMENT.
+template <typename T>
+constexpr PrimitiveType TypeToPrimitiveType() {
+    if constexpr (std::is_same_v<T, UInt8> || std::is_same_v<T, Int8>) {
+        return TYPE_TINYINT;
+    } else if constexpr (std::is_same_v<T, Int16>) {
+        return TYPE_SMALLINT;
+    } else if constexpr (std::is_same_v<T, Int32>) {
+        return TYPE_INT;
+    } else if constexpr (std::is_same_v<T, Int64>) {
+        return TYPE_BIGINT;
+    } else if constexpr (std::is_same_v<T, Int128>) {
+        return TYPE_LARGEINT;
+    } else if constexpr (std::is_same_v<T, Float32>) {
+        return TYPE_FLOAT;
+    } else if constexpr (std::is_same_v<T, Float64>) {
+        return TYPE_DOUBLE;
+    } else if constexpr (std::is_same_v<T, DateV2>) {
+        return TYPE_DATEV2;
+    } else if constexpr (std::is_same_v<T, DateTimeV2>) {
+        return TYPE_DATETIMEV2;
+    } else {
+        // The message is kept on a single line: a string literal may not span a line break.
+        throw Exception(ErrorCode::INVALID_ARGUMENT,
+                        "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType");
+    }
+}
+
+/// HybridSet over numeric or Date(DateTime)V2 elements of type T. The PrimitiveType
+/// parameter and the container's value type are both derived from TypeToPrimitiveType<T>().
+/// The constructor sets _null_aware to true.
+template <typename T>
+class NullableNumericOrDateSet
+        : public HybridSet<
+                  TypeToPrimitiveType<T>(),
+                  DynamicContainer<
+                          typename PrimitiveTypeTraits<TypeToPrimitiveType<T>()>::CppType>> {
+public:
+    NullableNumericOrDateSet() { this->_null_aware = true; }
+
+    /// Sets the _contains_null flag to target_value.
+    void change_contains_null_value(bool target_value) {
+        this->_contains_null = target_value;
+    }
+};
+
+/// Aggregation state of group_array_intersect for numeric / Date(DateTime)V2 elements.
+template <typename T>
+struct AggregateFunctionGroupArrayIntersectData {
+    using NullableNumericOrDateSetType = NullableNumericOrDateSet<T>;
+    using Set = std::shared_ptr<NullableNumericOrDateSetType>;
+
+    // A constructor that has a member-initializer list cannot be "= default", so it keeps
+    // an (empty) body. The set is created with the element-typed set class that `Set`
+    // points to.
+    AggregateFunctionGroupArrayIntersectData()
+            : value(std::make_shared<NullableNumericOrDateSetType>()) {}
+
+    /// The set of elements accumulated so far.
+    Set value;
+    /// Starts at 0 and is bumped by add()/merge(); 0 means `value` has not been seeded yet.
+    UInt64 version = 0;
+};
+
+/// Aggregate function that intersects its array argument across all aggregated rows and
+/// returns an array of the elements present in every row (a NULL element is kept only
+/// when every row contains one). Implemented for numeric types.
+template <typename T>
+class AggregateFunctionGroupArrayIntersect
+        : public IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
+                                              AggregateFunctionGroupArrayIntersect<T>> {
+private:
+    using State = AggregateFunctionGroupArrayIntersectData<T>;
+    /// Type of the first argument; also reported as the return type.
+    DataTypePtr argument_type;
+
+public:
+    AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
+                                           AggregateFunctionGroupArrayIntersect<T>>(
+                      argument_types_),
+              argument_type(argument_types_[0]) {}
+
+    /// Same as the single-argument constructor; result_is_nullable is accepted but unused.
+    AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_,
+                                         const bool result_is_nullable)
+            : IAggregateFunctionDataHelper<AggregateFunctionGroupArrayIntersectData<T>,
+                                           AggregateFunctionGroupArrayIntersect<T>>(
+                      argument_types_),
+              argument_type(argument_types_[0]) {}
+
+    String get_name() const override { return "group_array_intersect"; }
+
+    DataTypePtr get_return_type() const override { return argument_type; }
+
+    bool allocates_memory_in_arena() const override { return false; }
+
+    /// Folds the array at row_num of columns[0] into the state: the first row seeds the
+    /// set, every later row replaces the set with its intersection with that row.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena*) const override {
+        auto& data = this->data(place);
+        auto& version = data.version;
+        auto& set = data.value;
+
+        // Unwrap an outer Nullable(Array(...)) to reach the array column itself.
+        const bool col_is_nullable = (*columns[0]).is_nullable();
+        const ColumnArray& column =
+                col_is_nullable ? assert_cast<const ColumnArray&>(
+                                          assert_cast<const ColumnNullable&>(*columns[0])
+                                                  .get_nested_column())
+                                : assert_cast<const ColumnArray&>(*columns[0]);
+
+        // [offset, offset + arr_size) is the slice of the flattened data for this row.
+        const auto& offsets = column.get_offsets();
+        const size_t offset = offsets[static_cast<ssize_t>(row_num) - 1];
+        const auto arr_size = offsets[row_num] - offset;
+
+        using ColVecType = ColumnVector<T>;
+        const auto& column_data = column.get_data();
+
+        const bool is_column_data_nullable = column_data.is_nullable();
+        const ColumnNullable* col_null = nullptr;
+        const ColVecType* nested_column_data = nullptr;
+
+        if (is_column_data_nullable) {
+            // Only const accessors are used below, so no const_cast is needed.
+            col_null = static_cast<const ColumnNullable*>(&column_data);
+            nested_column_data = &assert_cast<const ColVecType&>(col_null->get_nested_column());
+        } else {
+            nested_column_data = &static_cast<const ColVecType&>(column_data);
+        }
+
+        // Pointer to the i-th element of this row's array, or nullptr for a NULL element.
+        auto element_at = [&](size_t i) -> const T* {
+            if (is_column_data_nullable && col_null->is_null_at(offset + i)) {
+                return nullptr;
+            }
+            return &(nested_column_data->get_element(offset + i));
+        };
+
+        ++version;
+        if (version == 1) {
+            // First row: seed the set with every element of the array.
+            for (size_t i = 0; i < arr_size; ++i) {
+                set->insert(element_at(i));
+            }
+        } else if (set->size() != 0 || set->contain_null()) {
+            // Later rows: keep only elements already in the set. NULL is tracked by
+            // contain_null() rather than size(), hence the separate checks.
+            typename State::Set new_set =
+                    std::make_shared<typename State::NullableNumericOrDateSetType>();
+
+            for (size_t i = 0; i < arr_size; ++i) {
+                const T* src_data = element_at(i);
+                if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) {
+                    new_set->insert(src_data);
+                }
+            }
+            set = std::move(new_set);
+        }
+    }
+
+    /// Intersects the state at rhs into the state at place.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        auto& data = this->data(place);
+        auto& set = data.value;
+        auto& rhs_set = this->data(rhs).value;
+
+        // rhs has not been seeded: there is nothing to intersect with.
+        if (this->data(rhs).version == 0) {
+            return;
+        }
+
+        UInt64 version = data.version++;
+        if (version == 0) {
+            // This state has not been seeded yet: adopt a copy of rhs.
+            set->change_contains_null_value(rhs_set->contain_null());
+            HybridSetBase::IteratorBase* it = rhs_set->begin();
+            while (it->has_next()) {
+                const void* value = it->get_value();
+                set->insert(value);
+                it->next();
+            }
+            return;
+        }
+
+        // Same condition as add(): a set that holds only NULL has size() == 0 but must
+        // still be intersected, otherwise its NULL would survive an rhs without NULL.
+        if (set->size() != 0 || set->contain_null()) {
+            // Iterates lhs_val and keeps what rhs_val also has; NULL is kept only if both
+            // sides contain it.
+            auto create_new_set = [](auto& lhs_val, auto& rhs_val) {
+                typename State::Set new_set =
+                        std::make_shared<typename State::NullableNumericOrDateSetType>();
+                HybridSetBase::IteratorBase* it = lhs_val->begin();
+                while (it->has_next()) {
+                    const void* value = it->get_value();
+                    if ((rhs_val->find(value))) {
+                        new_set->insert(value);
+                    }
+                    it->next();
+                }
+                new_set->change_contains_null_value(lhs_val->contain_null() &&
+                                                    rhs_val->contain_null());
+                return new_set;
+            };
+            // Iterate the smaller set and probe the larger one.
+            auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set)
+                                                         : create_new_set(set, rhs_set);
+            set = std::move(new_set);
+        }
+    }
+
+    /// Writes, in order: the contains-null flag, the version, the element count, and
+    /// then each element.
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        auto& data = this->data(place);
+        auto& set = data.value;
+        auto version = data.version;
+
+        bool is_set_contains_null = set->contain_null();
+
+        write_pod_binary(is_set_contains_null, buf);
+
+        write_var_uint(version, buf);
+        write_var_uint(set->size(), buf);
+        HybridSetBase::IteratorBase* it = set->begin();
+
+        while (it->has_next()) {
+            const T* value_ptr = static_cast<const T*>(it->get_value());
+            write_int_binary((*value_ptr), buf);
+            it->next();
+        }
+    }
+
+    /// Reads back the fields in the order serialize() wrote them.
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        auto& data = this->data(place);
+        bool is_set_contains_null = false;
+
+        read_pod_binary(is_set_contains_null, buf);
+        data.value->change_contains_null_value(is_set_contains_null);
+        read_var_uint(data.version, buf);
+        size_t size = 0;
+        read_var_uint(size, buf);
+
+        T element;
+        for (size_t i = 0; i < size; ++i) {
+            read_int_binary(element, buf);
+            data.value->insert(static_cast<void*>(&element));
+        }
+    }
+
+    /// Appends the set as one new array row of `to`; when the destination's nested column
+    /// is nullable and the set contains NULL, a NULL element is written first.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        ColumnArray& arr_to = assert_cast<ColumnArray&>(to);
+        ColumnArray::Offsets64& offsets_to = arr_to.get_offsets();
+
+        auto& to_nested_col = arr_to.get_data();
+        using ElementType = T;
+        using ColVecType = ColumnVector<ElementType>;
+
+        bool is_nullable = to_nested_col.is_nullable();
+
+        auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false,
+                                ColumnNullable* col_null = nullptr) {
+            size_t old_size = nested_col.get_data().size();
+            size_t res_size = set->size();
+            size_t i = 0;
+
+            if (is_nullable && set->contain_null()) {
+                // The NULL element takes the first slot; values start at index 1.
+                col_null->insert_data(nullptr, 0);
+                res_size += 1;
+                i = 1;
+            }
+
+            nested_col.get_data().resize(old_size + res_size);
+
+            HybridSetBase::IteratorBase* it = set->begin();
+            while (it->has_next()) {
+                ElementType value = *static_cast<const ElementType*>(it->get_value());
+                nested_col.get_data()[old_size + i] = value;
+                if (is_nullable) {
+                    col_null->get_null_map_data().push_back(0);
+                }
+                it->next();
+                ++i;
+            }
+        };
+
+        const auto& set = this->data(place).value;
+        if (is_nullable) {
+            auto col_null = static_cast<ColumnNullable*>(&to_nested_col);
+            auto& nested_col = assert_cast<ColVecType&>(col_null->get_nested_column());
+            // size() does not count NULL, so it is added to the row length separately.
+            offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0));
+            insert_values(nested_col, set, true, col_null);
+        } else {
+            auto& nested_col = static_cast<ColVecType&>(to_nested_col);
+            offsets_to.push_back(offsets_to.back() + set->size());
+            insert_values(nested_col, set);
+        }
+    }
+};
+
+/// Generic implementation, it uses serialized representation as object descriptor.
+/// A StringValueSet over StringRef whose constructor sets _null_aware to true.
+class NullableStringSet : public StringValueSet<DynamicContainer<StringRef>> {
+public:
+    NullableStringSet() { this->_null_aware = true; }
+
+    /// Sets the _contains_null flag to target_value.
+    void change_contains_null_value(bool target_value) {
+        this->_contains_null = target_value;
+    }
+};
+
+struct AggregateFunctionGroupArrayIntersectGenericData {
+    using Set = std::shared_ptr<NullableStringSet>;
+
+    AggregateFunctionGroupArrayIntersectGenericData()

Review Comment:
   warning: use '= default' to define a trivial default constructor 
[modernize-use-equals-default]
   
   
be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h:329:
   ```diff
   -             : value(std::make_shared<NullableStringSet>()) {}
   +             : value(std::make_shared<NullableStringSet>()) = default;
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to