zhangstar333 commented on code in PR #40813:
URL: https://github.com/apache/doris/pull/40813#discussion_r1830425662


##########
be/src/vec/aggregate_functions/aggregate_function_approx_topn.h:
##########
@@ -0,0 +1,286 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <rapidjson/encodings.h>
+#include <rapidjson/prettywriter.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+#include <cstdint>
+#include <string>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_approx_top.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_array.h"
+#include "vec/columns/column_string.h"
+#include "vec/columns/column_struct.h"
+#include "vec/columns/column_vector.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/common/space_saving.h"
+#include "vec/common/string_ref.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_ipv4.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_struct.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
+
+struct AggregateFunctionTopKGenericData {
+    using Set = SpaceSaving<StringRef, StringRefHash>;
+
+    Set value;
+};
+
+template <int32_t ArgsSize>
+class AggregateFunctionApproxTopN final
+        : public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
+                                              
AggregateFunctionApproxTopN<ArgsSize>>,
+          AggregateFunctionApproxTop {
+private:
+    using State = AggregateFunctionTopKGenericData;
+
+public:
+    AggregateFunctionApproxTopN(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
+                                           
AggregateFunctionApproxTopN<ArgsSize>>(argument_types_),
+              _column_size(argument_types_.size() - ArgsSize) {}
+
+    String get_name() const override { return "approx_topn"; }
+
+    DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeString>(); }
+
+    // Serializes the aggregate function's state (including the SpaceSaving 
structure and threshold) into a buffer.
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        this->data(place).value.write(buf);
+
+        write_var_uint(_column_size, buf);
+        write_var_uint(_threshold, buf);
+        write_var_uint(_reserved, buf);
+    }
+
+    // Deserializes the aggregate function's state from a buffer (including 
the SpaceSaving structure and threshold).
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena* arena) const override {
+        auto readStringBinaryInto = [](Arena& arena, BufferReadable& buf) {
+            size_t size = 0;
+            read_var_uint(size, buf);
+
+            if (UNLIKELY(size > DEFAULT_MAX_STRING_SIZE)) {
+                throw Exception(ErrorCode::INTERNAL_ERROR, "Too large string 
size.");
+            }
+
+            char* data = arena.alloc(size);
+            buf.read(data, size);
+
+            return StringRef(data, size);
+        };
+
+        auto& set = this->data(place).value;
+        set.clear();
+
+        size_t size = 0;
+        read_var_uint(size, buf);
+        if (UNLIKELY(size > TOP_K_MAX_SIZE)) {
+            throw Exception(ErrorCode::INTERNAL_ERROR,
+                            "Too large size ({}) for aggregate function '{}' 
state (maximum is {})",
+                            size, get_name(), TOP_K_MAX_SIZE);
+        }
+
+        set.resize(size);
+        for (size_t i = 0; i < size; ++i) {
+            auto ref = readStringBinaryInto(*arena, buf);
+            uint64_t count = 0;
+            uint64_t error = 0;
+            read_var_uint(count, buf);
+            read_var_uint(error, buf);
+            set.insert(ref, count, error);
+            arena->rollback(ref.size);
+        }
+
+        set.read_alpha_map(buf);
+
+        read_var_uint(_column_size, buf);
+        read_var_uint(_threshold, buf);
+        read_var_uint(_reserved, buf);
+    }
+
+    // Adds a new row of data to the aggregate function (inserts a new value 
into the SpaceSaving structure).
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
+             Arena* arena) const override {
+        if (!_init_flag) {
+            lazy_init(columns, row_num);
+        }
+
+        auto& set = this->data(place).value;
+        if (set.capacity() != _reserved) {
+            set.resize(_reserved);
+        }
+
+        auto all_serialize_value_into_arena =
+                [](size_t i, size_t keys_size, const IColumn** columns, Arena* 
arena) -> StringRef {
+            const char* begin = nullptr;
+
+            size_t sum_size = 0;
+            for (size_t j = 0; j < keys_size; ++j) {
+                sum_size += columns[j]->serialize_value_into_arena(i, *arena, 
begin).size;
+            }
+
+            return {begin, sum_size};
+        };
+
+        StringRef str_serialized =
+                all_serialize_value_into_arena(row_num, _column_size, columns, 
arena);
+        set.insert(str_serialized);
+        arena->rollback(str_serialized.size);
+    }
+
+    void add_many(AggregateDataPtr __restrict place, const IColumn** columns,
+                  std::vector<int>& rows, Arena* arena) const override {
+        for (auto row : rows) {
+            add(place, columns, row, arena);
+        }
+    }
+
+    void add_range(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t min,
+                   ssize_t max, Arena* arena) const {
+        for (ssize_t row_num = min; row_num < max; ++row_num) {
+            add(place, columns, row_num, arena);
+        }
+    }
+
+    void reset(AggregateDataPtr __restrict place) const override {
+        this->data(place).value.clear();
+    }
+
+    // Merges the state of another aggregate function into the current one 
(merges two SpaceSaving sets).
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        auto& rhs_set = this->data(rhs).value;
+        if (!rhs_set.size()) {
+            return;
+        }
+
+        auto& set = this->data(place).value;
+        if (set.capacity() != _reserved) {
+            set.resize(_reserved);
+        }
+        set.merge(rhs_set);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        auto& data_to = assert_cast<ColumnString&, 
TypeCheckOnRelease::DISABLE>(to);
+
+        const typename State::Set& set = this->data(place).value;
+        auto result_vec = set.top_k(_threshold);
+
+        rapidjson::StringBuffer buffer;
+        rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
+        writer.StartArray();
+        for (auto& result : result_vec) {
+            auto argument_types = this->get_argument_types();
+            MutableColumns argument_columns(_column_size);
+            for (size_t i = 0; i < _column_size; ++i) {
+                argument_columns[i] = argument_types[i]->create_column();
+            }
+
+            rapidjson::StringBuffer sub_buffer;
+            rapidjson::Writer<rapidjson::StringBuffer> sub_writer(sub_buffer);
+            const char* begin = result.key.data;
+            for (size_t i = 0; i < argument_columns.size(); i++) {
+                sub_writer.StartObject();
+                begin = 
argument_columns[i]->deserialize_and_insert_from_arena(begin);
+                std::string row_str = 
argument_types[i]->to_string(*argument_columns[i], 0);
+                sub_writer.Key(row_str.data(), row_str.size());
+            }
+
+            sub_writer.Uint64(result.count);
+            for (size_t i = 0; i < argument_columns.size(); i++) {
+                sub_writer.EndObject();
+            }
+            writer.RawValue(sub_buffer.GetString(), sub_buffer.GetSize(), 
rapidjson::kObjectType);
+        }
+        writer.EndArray();
+        std::string res = buffer.GetString();
+        data_to.insert_data(res.data(), res.size());
+    }
+
+private:
+    void lazy_init(const IColumn** columns, ssize_t row_num) const {
+        auto get_param = [](size_t idx, const DataTypes& data_types,
+                            const IColumn** columns) -> uint64_t {
+            const auto& data_type = data_types.at(idx);
+            const IColumn* column = columns[idx];
+
+            const auto* type = data_type.get();
+            if (type->is_nullable()) {
+                type = assert_cast<const DataTypeNullable*, 
TypeCheckOnRelease::DISABLE>(type)
+                               ->get_nested_type()
+                               .get();
+            }
+            int64_t value = 0;
+            WhichDataType which(type);
+            if (which.idx == TypeIndex::Int8) {
+                value = assert_cast<const ColumnInt8*, 
TypeCheckOnRelease::DISABLE>(column)
+                                ->get_element(0);
+            } else if (which.idx == TypeIndex::Int16) {
+                value = assert_cast<const ColumnInt16*, 
TypeCheckOnRelease::DISABLE>(column)
+                                ->get_element(0);
+            } else if (which.idx == TypeIndex::Int32) {
+                value = assert_cast<const ColumnInt32*, 
TypeCheckOnRelease::DISABLE>(column)
+                                ->get_element(0);
+            }
+            if (value <= 0) {
+                throw Exception(ErrorCode::INVALID_ARGUMENT,
+                                "The parameter cannot be less than or equal to 
0.");
+            }
+            return value;
+        };
+
+        const auto& data_types = this->get_argument_types();
+        if (ArgsSize == 1) {
+            _threshold = std::min(get_param(_column_size, data_types, 
columns), (uint64_t)1000);
+        } else if (ArgsSize == 2) {
+            _threshold = std::min(get_param(_column_size, data_types, 
columns), (uint64_t)1000);
+            _reserved =
+                    std::min(std::max(get_param(_column_size + 1, data_types, 
columns), _threshold),
+                             (uint64_t)1000);
+        }
+
+        if (_threshold == 0 || _reserved == 0 || _threshold > 1000 || 
_reserved > 1000) {
+            throw Exception(ErrorCode::INTERNAL_ERROR,
+                            "approx_topn param error, _threshold: {}, 
_reserved: {}", _threshold,
+                            _reserved);
+        }
+
+        _init_flag = true;
+    }
+
+    mutable size_t _column_size = 0;
+    mutable bool _init_flag = false;
+    mutable uint64_t _threshold = 10;
+    mutable uint64_t _reserved = 300;

Review Comment:
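   It seems these variables could be moved into AggregateFunctionTopKGenericData,
   which would make some of these functions clearer to implement. They are currently
   `mutable` members written from const methods such as `lazy_init` and `deserialize`,
   so they act as state shared by every aggregation place rather than per-place state.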



##########
be/src/vec/aggregate_functions/aggregate_function.h:
##########
@@ -287,8 +288,14 @@ class IAggregateFunctionHelper : public IAggregateFunction 
{
     void add_batch_single_place(size_t batch_size, AggregateDataPtr place, 
const IColumn** columns,
                                 Arena* arena) const override {
         const Derived* derived = assert_cast<const Derived*>(this);
-        for (size_t i = 0; i < batch_size; ++i) {
-            derived->add(place, columns, i, arena);
+
+        if constexpr (is_aggregate_function_multi_top<Derived>::value ||
+                      
is_aggregate_function_multi_top_with_null_variadic_inline<Derived>::value) {
+            derived->add_range(place, columns, 0, batch_size, arena);

Review Comment:
   Would it be possible to override the add_batch_single_place function directly in
   this aggregate function instead? That might make the code clearer.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to