This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 0553ce2 [feature](vectorization) support function topn && remove some unused code (#7793)
0553ce2 is described below
commit 0553ce294449bcc1f025e22928c154d71c1cfe5d
Author: Pxl <[email protected]>
AuthorDate: Wed Feb 9 13:05:31 2022 +0800
[feature](vectorization) support function topn && remove some unused code (#7793)
---
be/src/exprs/topn_function.cpp | 40 ----
be/src/vec/CMakeLists.txt | 1 +
.../vec/aggregate_functions/aggregate_function.h | 39 +---
.../aggregate_functions/aggregate_function_avg.h | 8 +-
.../aggregate_function_bitmap.h | 18 +-
.../aggregate_functions/aggregate_function_count.h | 32 ++-
.../aggregate_function_distinct.h | 12 +-
.../aggregate_function_hll_union_agg.h | 15 +-
.../aggregate_function_min_max.h | 27 ++-
.../aggregate_function_nothing.h | 4 +-
.../aggregate_functions/aggregate_function_null.h | 20 +-
.../aggregate_function_simple_factory.cpp | 6 +-
.../aggregate_function_stddev.h | 8 +-
.../aggregate_functions/aggregate_function_sum.h | 24 +--
.../aggregate_function_topn.cpp | 63 ++++++
.../aggregate_functions/aggregate_function_topn.h | 233 +++++++++++++++++++++
.../aggregate_functions/aggregate_function_uniq.h | 14 +-
.../aggregate_function_window.h | 10 +-
be/test/exprs/topn_function_test.cpp | 98 +++------
be/test/vec/aggregate_functions/agg_test.cpp | 46 +++-
.../java/org/apache/doris/catalog/FunctionSet.java | 93 +++-----
21 files changed, 498 insertions(+), 313 deletions(-)
diff --git a/be/src/exprs/topn_function.cpp b/be/src/exprs/topn_function.cpp
index e956e5e..544f980 100644
--- a/be/src/exprs/topn_function.cpp
+++ b/be/src/exprs/topn_function.cpp
@@ -91,50 +91,10 @@ StringVal TopNFunctions::topn_finalize(FunctionContext*
ctx, const StringVal& sr
return result;
}
-template void TopNFunctions::topn_update(FunctionContext*, const BooleanVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const TinyIntVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const SmallIntVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const IntVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const BigIntVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const FloatVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const DoubleVal&,
const IntVal&,
- StringVal*);
template void TopNFunctions::topn_update(FunctionContext*, const StringVal&,
const IntVal&,
StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const DateTimeVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const LargeIntVal&,
const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const
DecimalV2Val&, const IntVal&,
- StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const BooleanVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const TinyIntVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const SmallIntVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const IntVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const BigIntVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const FloatVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const DoubleVal&,
const IntVal&,
- const IntVal&, StringVal*);
template void TopNFunctions::topn_update(FunctionContext*, const StringVal&,
const IntVal&,
const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const DateTimeVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const LargeIntVal&,
const IntVal&,
- const IntVal&, StringVal*);
-template void TopNFunctions::topn_update(FunctionContext*, const
DecimalV2Val&, const IntVal&,
- const IntVal&, StringVal*);
} // namespace doris
\ No newline at end of file
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 6201c67..5a9391e 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -32,6 +32,7 @@ set(VEC_FILES
aggregate_functions/aggregate_function_reader.cpp
aggregate_functions/aggregate_function_window.cpp
aggregate_functions/aggregate_function_stddev.cpp
+ aggregate_functions/aggregate_function_topn.cpp
aggregate_functions/aggregate_function_simple_factory.cpp
columns/collator.cpp
columns/column.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function.h
b/be/src/vec/aggregate_functions/aggregate_function.h
index 4c2ef36..c3b5072 100644
--- a/be/src/vec/aggregate_functions/aggregate_function.h
+++ b/be/src/vec/aggregate_functions/aggregate_function.h
@@ -20,13 +20,6 @@
#pragma once
-#include <cstddef>
-#include <istream>
-#include <memory>
-#include <ostream>
-#include <type_traits>
-#include <vector>
-
#include "vec/common/exception.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
@@ -95,13 +88,15 @@ public:
Arena* arena) const = 0;
/// Merges state (on which place points to) with other state of current
aggregation function.
- virtual void merge(AggregateDataPtr __restrict place,
ConstAggregateDataPtr rhs, Arena* arena) const = 0;
+ virtual void merge(AggregateDataPtr __restrict place,
ConstAggregateDataPtr rhs,
+ Arena* arena) const = 0;
/// Serializes state (to transmit it over the network, for example).
virtual void serialize(ConstAggregateDataPtr __restrict place,
BufferWritable& buf) const = 0;
/// Deserializes state. This function is called only for empty (just
created) states.
- virtual void deserialize(AggregateDataPtr __restrict place,
BufferReadable& buf, Arena* arena) const = 0;
+ virtual void deserialize(AggregateDataPtr __restrict place,
BufferReadable& buf,
+ Arena* arena) const = 0;
/// Returns true if a function requires Arena to handle own states (see
add(), merge(), deserialize()).
virtual bool allocates_memory_in_arena() const { return false; }
@@ -114,21 +109,11 @@ public:
*/
virtual bool is_state() const { return false; }
- /// if return false, during insert_result_into function, you colud get
nullable result column,
+ /// if return false, during insert_result_into function, you colud get
nullable result column,
/// so could insert to null value by yourself, rather than by
AggregateFunctionNullBase;
/// because you maybe be calculate a invalid value, but want to use null
replace it;
virtual bool insert_to_null_default() const { return true; }
- /** The inner loop that uses the function pointer is better than using the
virtual function.
- * The reason is that in the case of virtual functions GCC 5.1.2
generates code,
- * which, at each iteration of the loop, reloads the function address
(the offset value in the virtual function table) from memory to the register.
- * This gives a performance drop on simple queries around 12%.
- * After the appearance of better compilers, the code can be removed.
- */
- using AddFunc = void (*)(const IAggregateFunction*, AggregateDataPtr,
const IColumn**, size_t,
- Arena*);
- virtual AddFunc get_address_of_add_function() const = 0;
-
/** Contains a loop with calls to "add" function. You can collect
arguments into array "places"
* and do a single call to "add_batch" for devirtualization and inlining.
*/
@@ -150,12 +135,6 @@ public:
AggregateDataPtr place, const
IColumn** columns,
Arena* arena) const = 0;
- /** This is used for runtime code generation to determine, which header
files to include in generated source.
- * Always implement it as
- * const char * get_header_file_path() const override { return __FILE__; }
- */
- virtual const char* get_header_file_path() const = 0;
-
const DataTypes& get_argument_types() const { return argument_types; }
const Array& get_parameters() const { return parameters; }
@@ -167,18 +146,10 @@ protected:
/// Implement method to obtain an address of 'add' function.
template <typename Derived>
class IAggregateFunctionHelper : public IAggregateFunction {
-private:
- static void add_free(const IAggregateFunction* that, AggregateDataPtr
place,
- const IColumn** columns, size_t row_num, Arena*
arena) {
- static_cast<const Derived&>(*that).add(place, columns, row_num, arena);
- }
-
public:
IAggregateFunctionHelper(const DataTypes& argument_types_, const Array&
parameters_)
: IAggregateFunction(argument_types_, parameters_) {}
- AddFunc get_address_of_add_function() const override { return &add_free; }
-
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t
place_offset,
const IColumn** columns, Arena* arena) const override {
for (size_t i = 0; i < batch_size; ++i)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h
b/be/src/vec/aggregate_functions/aggregate_function_avg.h
index 18584ee..7b40f95 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h
@@ -101,7 +101,8 @@ public:
this->data(place).count = 0;
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).sum += this->data(rhs).sum;
this->data(place).count += this->data(rhs).count;
}
@@ -110,7 +111,8 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
@@ -119,8 +121,6 @@ public:
column.get_data().push_back(this->data(place).template
result<ResultType>());
}
- const char* get_header_file_path() const override { return __FILE__; }
-
private:
UInt32 scale;
};
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
index a2e43e5..4d72f07 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
@@ -16,8 +16,6 @@
// under the License.
#pragma once
-#include <istream>
-#include <ostream>
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_complex.h"
@@ -91,7 +89,8 @@ public:
this->data(place).add(column.get_data()[row_num]);
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).merge(
const_cast<AggregateFunctionBitmapData<Op>&>(this->data(rhs)).get());
}
@@ -100,7 +99,8 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
@@ -109,8 +109,6 @@ public:
column.get_data().push_back(
const_cast<AggregateFunctionBitmapData<Op>&>(this->data(place)).get());
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
template <bool nullable, typename ColVecType>
@@ -146,7 +144,8 @@ public:
}
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).merge(const_cast<AggFunctionData&>(this->data(rhs)).get());
}
@@ -154,7 +153,8 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
@@ -163,8 +163,6 @@ public:
auto& column = static_cast<ColVecResult&>(to);
column.get_data().push_back(value_data.cardinality());
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string&
name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h
b/be/src/vec/aggregate_functions/aggregate_function_count.h
index fd096dc..3d8ab79 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_count.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_count.h
@@ -50,11 +50,10 @@ public:
++data(place).count;
}
- void reset(AggregateDataPtr place) const override {
- this->data(place).count = 0;
- }
+ void reset(AggregateDataPtr place) const override {
this->data(place).count = 0; }
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
data(place).count += data(rhs).count;
}
@@ -62,15 +61,14 @@ public:
write_var_uint(data(place).count, buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
read_var_uint(data(place).count, buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
/// Simply count number of not-NULL values.
@@ -90,11 +88,10 @@ public:
data(place).count += !assert_cast<const
ColumnNullable&>(*columns[0]).is_null_at(row_num);
}
- void reset(AggregateDataPtr place) const override {
- data(place).count = 0;
- }
+ void reset(AggregateDataPtr place) const override { data(place).count = 0;
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
data(place).count += data(rhs).count;
}
@@ -102,21 +99,22 @@ public:
write_var_uint(data(place).count, buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
read_var_uint(data(place).count, buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
if (to.is_nullable()) {
- auto& null_column = assert_cast<ColumnNullable &>(to);
+ auto& null_column = assert_cast<ColumnNullable&>(to);
null_column.get_null_map_data().push_back(0);
- assert_cast<ColumnInt64
&>(null_column.get_nested_column()).get_data().push_back(data(place).count);
+ assert_cast<ColumnInt64&>(null_column.get_nested_column())
+ .get_data()
+ .push_back(data(place).count);
} else {
- assert_cast<ColumnInt64
&>(to).get_data().push_back(data(place).count);
+
assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
}
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.h
b/be/src/vec/aggregate_functions/aggregate_function_distinct.h
index e100cf3..81f33bd 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_distinct.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.h
@@ -69,13 +69,12 @@ struct AggregateFunctionDistinctGenericData {
Set::LookupResult it;
bool inserted;
for (const auto& elem : rhs.set)
- set.emplace(ArenaKeyHolder{elem.get_value(), *arena}, it,
inserted);
+ set.emplace(ArenaKeyHolder {elem.get_value(), *arena}, it,
inserted);
}
void serialize(BufferWritable& buf) const {
write_var_uint(set.size(), buf);
- for (const auto& elem : set)
- write_string_binary(elem.get_value(), buf);
+ for (const auto& elem : set) write_string_binary(elem.get_value(),
buf);
}
void deserialize(BufferReadable& buf, Arena* arena) {
@@ -121,7 +120,7 @@ struct AggregateFunctionDistinctMultipleGenericData :
public AggregateFunctionDi
Set::LookupResult it;
bool inserted;
- auto key_holder = SerializedKeyHolder{value, *arena};
+ auto key_holder = SerializedKeyHolder {value, *arena};
set.emplace(key_holder, it, inserted);
}
@@ -180,7 +179,8 @@ public:
this->data(place).serialize(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena* arena) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena* arena) const override {
this->data(place).deserialize(buf, arena);
}
@@ -219,8 +219,6 @@ public:
DataTypePtr get_return_type() const override { return
nested_func->get_return_type(); }
bool allocates_memory_in_arena() const override { return true; }
-
- const char* get_header_file_path() const override { return __FILE__; }
};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
index f71a1f5..612b552 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
@@ -17,10 +17,6 @@
#pragma once
-#include <istream>
-#include <ostream>
-#include <type_traits>
-
#include "exprs/hll_function.h"
#include "olap/hll.h"
#include "util/slice.h"
@@ -86,7 +82,8 @@ public:
this->data(place).add(column.get_data_at(row_num));
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).merge(this->data(rhs));
}
@@ -94,16 +91,16 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
- virtual void insert_result_into(ConstAggregateDataPtr __restrict place,
IColumn& to) const override {
+ virtual void insert_result_into(ConstAggregateDataPtr __restrict place,
+ IColumn& to) const override {
auto& column = static_cast<ColumnVector<Int64>&>(to);
column.get_data().push_back(this->data(place).get_cardinality());
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
class AggregateFunctionHLLUnion final : public AggregateFunctionHLLUnionAgg {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
index 17d6823..9c2f097 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
@@ -35,7 +35,8 @@ struct SingleValueDataFixed {
private:
using Self = SingleValueDataFixed;
- bool has_value = false; /// We need to remember if at least one value has
been passed. This is necessary for AggregateFunctionIf.
+ bool has_value =
+ false; /// We need to remember if at least one value has been
passed. This is necessary for AggregateFunctionIf.
T value;
public:
@@ -50,7 +51,7 @@ public:
void reset() {
if (has()) {
- has_value = false;
+ has_value = false;
}
}
@@ -166,10 +167,10 @@ public:
void reset() {
if (has()) {
- has_value = false;
+ has_value = false;
}
}
-
+
void write(BufferWritable& buf) const {
write_binary(has(), buf);
if (has()) write_binary(value, buf);
@@ -297,13 +298,13 @@ public:
void reset() {
if (size != -1) {
- size = -1;
- capacity = 0;
+ size = -1;
+ capacity = 0;
delete large_data;
large_data = nullptr;
}
}
-
+
void write(BufferWritable& buf) const {
write_binary(size, buf);
if (has()) buf.write(get_data(), size);
@@ -497,11 +498,10 @@ public:
this->data(place).change_if_better(*columns[0], row_num, arena);
}
- void reset(AggregateDataPtr place) const override {
- this->data(place).reset();
- }
+ void reset(AggregateDataPtr place) const override {
this->data(place).reset(); }
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena* arena) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena* arena) const override {
this->data(place).change_if_better(this->data(rhs), arena);
}
@@ -509,7 +509,8 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
@@ -518,8 +519,6 @@ public:
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
this->data(place).insert_result_into(to);
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
AggregateFunctionPtr create_aggregate_function_max(const std::string& name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_nothing.h
b/be/src/vec/aggregate_functions/aggregate_function_nothing.h
index 4c7b193..c0ae740 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_nothing.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_nothing.h
@@ -54,7 +54,7 @@ public:
void add(AggregateDataPtr, const IColumn**, size_t, Arena*) const override
{}
void reset(AggregateDataPtr place) const override {}
-
+
void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena*) const override
{}
void serialize(ConstAggregateDataPtr, BufferWritable& buf) const override
{}
@@ -64,8 +64,6 @@ public:
void insert_result_into(ConstAggregateDataPtr, IColumn& to) const override
{
to.insert_default();
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h
b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 9458d7d..83cae6f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -62,13 +62,11 @@ protected:
}
static void init_flag(AggregateDataPtr __restrict place) noexcept {
- if constexpr (result_is_nullable)
- place[0] = 0;
+ if constexpr (result_is_nullable) place[0] = 0;
}
static void set_flag(AggregateDataPtr __restrict place) noexcept {
- if constexpr (result_is_nullable)
- place[0] = 1;
+ if constexpr (result_is_nullable) place[0] = 1;
}
static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
@@ -117,7 +115,8 @@ public:
size_t align_of_data() const override { return
nested_function->align_of_data(); }
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena* arena) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena* arena) const override {
if (result_is_nullable && get_flag(rhs)) set_flag(place);
nested_function->merge(nested_place(place), nested_place(rhs), arena);
@@ -131,7 +130,8 @@ public:
}
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena* arena) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena* arena) const override {
bool flag = true;
if (result_is_nullable) read_binary(flag, buf);
if (flag) {
@@ -145,10 +145,12 @@ public:
ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
if (get_flag(place)) {
if (nested_function->insert_to_null_default()) {
- nested_function->insert_result_into(nested_place(place),
to_concrete.get_nested_column());
+ nested_function->insert_result_into(nested_place(place),
+
to_concrete.get_nested_column());
to_concrete.get_null_map_data().push_back(0);
} else {
- nested_function->insert_result_into(nested_place(place),
to); //want to insert into null value by self
+ nested_function->insert_result_into(
+ nested_place(place), to); //want to insert into
null value by self
}
} else {
to_concrete.insert_default();
@@ -163,8 +165,6 @@ public:
}
bool is_state() const override { return nested_function->is_state(); }
-
- const char* get_header_file_path() const override { return __FILE__; }
};
/** There are two cases: for single argument and variadic.
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index ba1b2ba..4844000 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -36,6 +36,7 @@ void
register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory)
void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory&
factory);
void
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory&
factory);
void
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
static AggregateFunctionSimpleFactory instance;
@@ -51,10 +52,11 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_reader(instance); // register aggregate
function for agg reader
register_aggregate_function_window_rank(instance);
register_aggregate_function_stddev_variance(instance);
-
+ register_aggregate_function_topn(instance);
+
// if you only register function with no nullable, and wants to add
nullable automatically, you should place function above this line
register_aggregate_function_combinator_null(instance);
-
+
register_aggregate_function_reader_no_spread(instance);
register_aggregate_function_window_lead_lag(instance);
});
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
index 6cdba200..82e8718 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
@@ -188,7 +188,8 @@ struct BaseDatadecimal {
template <typename T, typename Data>
struct PopData : Data {
- using ColVecResult = std::conditional_t<IsDecimalNumber<T>,
ColumnDecimal<Decimal128>, ColumnVector<Float64>>;
+ using ColVecResult = std::conditional_t<IsDecimalNumber<T>,
ColumnDecimal<Decimal128>,
+ ColumnVector<Float64>>;
void insert_result_into(IColumn& to) const {
ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
auto& col =
static_cast<ColVecResult&>(nullable_column.get_nested_column());
@@ -203,7 +204,8 @@ struct PopData : Data {
template <typename T, typename Data>
struct SampData : Data {
- using ColVecResult = std::conditional_t<IsDecimalNumber<T>,
ColumnDecimal<Decimal128>, ColumnVector<Float64>>;
+ using ColVecResult = std::conditional_t<IsDecimalNumber<T>,
ColumnDecimal<Decimal128>,
+ ColumnVector<Float64>>;
void insert_result_into(IColumn& to) const {
ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
if (this->count == 1) {
@@ -278,8 +280,6 @@ public:
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
this->data(place).insert_result_into(to);
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h
b/be/src/vec/aggregate_functions/aggregate_function_sum.h
index 402af36..eceaf03 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h
@@ -20,10 +20,6 @@
#pragma once
-#include <istream>
-#include <ostream>
-#include <type_traits>
-
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type_decimal.h"
@@ -82,12 +78,11 @@ public:
const auto& column = static_cast<const ColVecType&>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
}
-
- void reset(AggregateDataPtr place) const override {
- this->data(place).sum = {};
- }
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void reset(AggregateDataPtr place) const override { this->data(place).sum
= {}; }
+
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).merge(this->data(rhs));
}
@@ -95,7 +90,8 @@ public:
this->data(place).write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).read(buf);
}
@@ -104,15 +100,13 @@ public:
column.get_data().push_back(this->data(place).get());
}
- const char* get_header_file_path() const override { return __FILE__; }
-
private:
UInt32 scale;
};
AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string&
name,
- const DataTypes&
argument_types,
- const Array& parameters,
- const bool
result_is_nullable);
+ const DataTypes&
argument_types,
+ const Array&
parameters,
+ const bool
result_is_nullable);
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
new file mode 100644
index 0000000..a8347bf
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <vec/aggregate_functions/aggregate_function_topn.h>
+
+namespace doris::vectorized {
+
+template <template <typename DataHelper> typename Impl>
+struct currying_function_topn {
+ template <typename T>
+ using Function = AggregateFunctionTopN<AggregateFunctionTopNData, Impl<T>>;
+
+ AggregateFunctionPtr operator()(const std::string& name, const DataTypes&
argument_types) {
+ AggregateFunctionPtr res = nullptr;
+ res.reset(new Function<StringDataImplTopN>(argument_types));
+
+ if (!res) {
+ LOG(WARNING) << fmt::format("Illegal type {} of argument for
aggregate function {}",
+ argument_types[0]->get_name(), name);
+ }
+
+ return res;
+ }
+};
+
+// Factory entry point for the vectorized `topn` aggregate, dispatched on the argument count:
+//   1 argument  -> instance built on AggregateFunctionTopNImplMerge, whose add() is fatal
+//                  (it is only meant to merge serialized partial states);
+//   2 arguments -> topn(value, top_num);
+//   3 arguments -> topn(value, top_num, space_expand_rate).
+// Any other arity is logged and rejected with nullptr. `parameters` and `result_is_nullable`
+// belong to the factory signature but are not used by this function.
+AggregateFunctionPtr create_aggregate_function_topn(const std::string& name,
+                                                    const DataTypes& argument_types,
+                                                    const Array& parameters,
+                                                    const bool result_is_nullable) {
+    switch (argument_types.size()) {
+    case 1:
+        return AggregateFunctionPtr(
+                new AggregateFunctionTopN<AggregateFunctionTopNData,
+                                          AggregateFunctionTopNImplMerge>(argument_types));
+    case 2:
+        return currying_function_topn<AggregateFunctionTopNImplInt>()(name, argument_types);
+    case 3:
+        return currying_function_topn<AggregateFunctionTopNImplIntInt>()(name, argument_types);
+    default:
+        LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}",
+                                    argument_types.size(), name);
+        return nullptr;
+    }
+}
+
+// Makes the vectorized `topn` aggregate resolvable by name through the factory.
+void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory) {
+    factory.register_function("topn", &create_aggregate_function_topn);
+}
+
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h
b/be/src/vec/aggregate_functions/aggregate_function_topn.h
new file mode 100644
index 0000000..0bfc721
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -0,0 +1,233 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <parallel_hashmap/phmap.h>
+#include <rapidjson/stringbuffer.h>
+#include <rapidjson/writer.h>
+
+#include <unordered_map>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/data_types/data_type_date_time.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+// Per-group state for `topn`: an approximate frequency counter in the style of the
+// space-saving algorithm. While rows are added every distinct value is counted exactly;
+// the state is cut down to `capacity` entries only when serialized (write() keeps the
+// `capacity` most frequent ones), and merge() compensates for entries a full partial state
+// may have dropped by crediting the other side with that state's minimum count.
+struct AggregateFunctionTopNData {
+    // Sets how many results get() reports and how many entries write() keeps
+    // (`top_num * space_expand_rate`). The misspelled name is kept because callers use it.
+    void set_paramenters(int input_top_num, int space_expand_rate = 50) {
+        top_num = input_top_num;
+        capacity = (uint64_t)top_num * space_expand_rate;
+    }
+
+    // Counts one occurrence of `value` (operator[] value-initializes a missing count to 0).
+    void add(const std::string& value) { ++counter_map[value]; }
+
+    // Merges another partial state into this one. When `rhs` is at capacity it may have
+    // dropped entries (write() keeps only `capacity` of them), so each of our values is
+    // credited with rhs's minimum count; symmetrically, values present only in `rhs` are
+    // credited with our minimum when we are at capacity.
+    void merge(const AggregateFunctionTopNData& rhs) {
+        // An empty partial state (e.g. one that never received a row) carries no counters
+        // and no meaningful parameters. Merging it must neither clobber our parameters nor
+        // add an empty map's "minimum" (UINT64_MAX) to our counts, which would wrap them.
+        if (rhs.counter_map.empty()) {
+            return;
+        }
+        top_num = rhs.top_num;
+        capacity = rhs.capacity;
+
+        const bool lhs_full = !counter_map.empty() && counter_map.size() >= capacity;
+        const bool rhs_full = rhs.counter_map.size() >= capacity;
+
+        const uint64_t lhs_min = lhs_full ? min_count(counter_map) : 0;
+        const uint64_t rhs_min = rhs_full ? min_count(rhs.counter_map) : 0;
+
+        if (rhs_full) {
+            for (auto& entry : counter_map) {
+                entry.second += rhs_min;
+            }
+        }
+
+        for (const auto& rhs_entry : rhs.counter_map) {
+            auto lhs_it = counter_map.find(rhs_entry.first);
+            if (lhs_it != counter_map.end()) {
+                // rhs_min was already added above; add only the remainder of rhs's count.
+                lhs_it->second += rhs_entry.second - rhs_min;
+            } else {
+                counter_map.insert({rhs_entry.first, rhs_entry.second + lhs_min});
+            }
+        }
+    }
+
+    // Returns every tracked (count, value) pair, most frequent first. Ties are ordered by
+    // value, descending, because std::greater compares the whole pair.
+    std::vector<std::pair<uint64_t, std::string>> get_remain_vector() const {
+        std::vector<std::pair<uint64_t, std::string>> counter_vector;
+        counter_vector.reserve(counter_map.size());
+        for (const auto& entry : counter_map) {
+            counter_vector.emplace_back(entry.second, entry.first);
+        }
+        std::sort(counter_vector.begin(), counter_vector.end(),
+                  std::greater<std::pair<uint64_t, std::string>>());
+        return counter_vector;
+    }
+
+    // Serializes the parameters, then at most `capacity` entries (the most frequent ones),
+    // each written as value followed by count.
+    void write(BufferWritable& buf) const {
+        write_binary(top_num, buf);
+        write_binary(capacity, buf);
+
+        uint64_t element_number = std::min(capacity, (uint64_t)counter_map.size());
+        write_binary(element_number, buf);
+
+        auto counter_vector = get_remain_vector();
+
+        for (uint64_t i = 0; i < element_number; i++) {
+            const auto& element = counter_vector[i];
+            write_binary(element.second, buf);
+            write_binary(element.first, buf);
+        }
+    }
+
+    // Inverse of write(): replaces the current state with the serialized one.
+    void read(BufferReadable& buf) {
+        read_binary(top_num, buf);
+        read_binary(capacity, buf);
+
+        uint64_t element_number = 0;
+        read_binary(element_number, buf);
+
+        counter_map.clear();
+        std::pair<std::string, uint64_t> element;
+        for (uint64_t i = 0; i < element_number; i++) {
+            read_binary(element.first, buf);
+            read_binary(element.second, buf);
+            counter_map.insert(element);
+        }
+    }
+
+    // Renders the `top_num` most frequent values as a JSON object {"value": count, ...}.
+    std::string get() const {
+        auto counter_vector = get_remain_vector();
+
+        rapidjson::StringBuffer buffer;
+        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+
+        writer.StartObject();
+        for (int i = 0; i < std::min((int)counter_vector.size(), top_num); i++) {
+            const auto& element = counter_vector[i];
+            // Pass the explicit length so values containing '\0' are not truncated.
+            writer.Key(element.second.data(),
+                       static_cast<rapidjson::SizeType>(element.second.size()));
+            writer.Uint64(element.first);
+        }
+        writer.EndObject();
+
+        return buffer.GetString();
+    }
+
+    void reset() { counter_map.clear(); }
+
+    // Default-initialized so a state that never saw a row (e.g. an empty group) serializes
+    // and renders deterministically instead of reading indeterminate values.
+    int top_num = 0;
+    uint64_t capacity = 0;
+    phmap::flat_hash_map<std::string, uint64_t> counter_map;
+
+private:
+    // Smallest count stored in `map`; only meaningful when `map` is non-empty.
+    static uint64_t min_count(const phmap::flat_hash_map<std::string, uint64_t>& map) {
+        uint64_t result = UINT64_MAX;
+        for (const auto& entry : map) {
+            result = std::min(result, entry.second);
+        }
+        return result;
+    }
+};
+
+// Value extractor for string inputs: copies row `row_num` of a DataTypeString column into
+// an owning std::string so it can be used as a counter key.
+struct StringDataImplTopN {
+    using DataType = DataTypeString;
+    static std::string to_string(const IColumn& column, size_t row_num) {
+        const auto& string_column = static_cast<const DataType::ColumnType&>(column);
+        const StringRef value = string_column.get_data_at(row_num);
+        return std::string(value.data, value.size);
+    }
+};
+
+// Row-wise adder for topn(value, top_num): column 1 holds top_num as Int32, and the
+// default space expand rate of set_paramenters() is used.
+template <typename DataHelper>
+struct AggregateFunctionTopNImplInt {
+    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
+                    size_t row_num) {
+        const auto& top_num_column = *static_cast<const ColumnInt32*>(columns[1]);
+        place.set_paramenters(top_num_column.get_element(row_num));
+        place.add(DataHelper::to_string(*columns[0], row_num));
+    }
+};
+
+// Row-wise adder for topn(value, top_num, space_expand_rate): both parameters are read
+// from Int32 columns 1 and 2 on every row before the value is counted.
+template <typename DataHelper>
+struct AggregateFunctionTopNImplIntInt {
+    static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns,
+                    size_t row_num) {
+        const auto* top_num_column = static_cast<const ColumnInt32*>(columns[1]);
+        const auto* expand_rate_column = static_cast<const ColumnInt32*>(columns[2]);
+        place.set_paramenters(top_num_column->get_element(row_num),
+                              expand_rate_column->get_element(row_num));
+        place.add(DataHelper::to_string(*columns[0], row_num));
+    }
+};
+
+// Impl for the single-argument topn instance, which per its original note is only used at
+// AGGREGATE (merge finalize). Feeding it raw rows is a programming error, so add() is fatal.
+struct AggregateFunctionTopNImplMerge {
+    static void add(AggregateFunctionTopNData& __restrict /*place*/,
+                    const IColumn** /*columns*/, size_t /*row_num*/) {
+        LOG(FATAL) << "AggregateFunctionTopNImplMerge do not support add()";
+    }
+};
+
+// Vectorized `topn` aggregate. The per-group state is `Data` (AggregateFunctionTopNData);
+// `Impl` decides how each input row updates it. The result is the JSON text produced by
+// Data::get(), appended to a string column.
+template <typename Data, typename Impl>
+class AggregateFunctionTopN final
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionTopN<Data, Impl>> {
+public:
+    // `explicit` so a DataTypes list is never implicitly converted into an aggregate function.
+    explicit AggregateFunctionTopN(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionTopN<Data, Impl>>(
+                      argument_types_, {}) {}
+
+    String get_name() const override { return "topn"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena*) const override {
+        Impl::add(this->data(place), columns, row_num);
+    }
+
+    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        const std::string result = this->data(place).get();
+        static_cast<ColumnString&>(to).insert_data(result.data(), result.size());
+    }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
index ea32404..c717307 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
@@ -23,7 +23,6 @@
#include <type_traits>
#include "gutil/hash/city.h"
-
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_decimal.h"
#include "vec/common/aggregation_common.h"
@@ -81,8 +80,9 @@ struct OneAdder {
hash.get128(key.low, key.high);
data.set.insert(key);
- } else if constexpr(std::is_same_v<T, Decimal128>) {
- data.set.insert(assert_cast<const
ColumnDecimal<Decimal128>&>(column).get_data()[row_num]);
+ } else if constexpr (std::is_same_v<T, Decimal128>) {
+ data.set.insert(
+ assert_cast<const
ColumnDecimal<Decimal128>&>(column).get_data()[row_num]);
} else {
data.set.insert(assert_cast<const
ColumnVector<T>&>(column).get_data()[row_num]);
}
@@ -109,7 +109,8 @@ public:
detail::OneAdder<T, Data>::add(this->data(place), *columns[0],
row_num);
}
- void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
this->data(place).set.merge(this->data(rhs).set);
}
@@ -117,15 +118,14 @@ public:
this->data(place).set.write(buf);
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
this->data(place).set.read(buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
assert_cast<ColumnInt64&>(to).get_data().push_back(this->data(place).set.size());
}
-
- const char* get_header_file_path() const override { return __FILE__; }
};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h
b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 071c2ea..c6171e2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -20,10 +20,6 @@
#pragma once
-#include <istream>
-#include <ostream>
-#include <type_traits>
-
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type_decimal.h"
@@ -66,7 +62,6 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*)
const override {}
void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const
override {}
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*)
const override {}
- const char* get_header_file_path() const override { return __FILE__; }
};
struct RankData {
@@ -112,7 +107,6 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*)
const override {}
void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const
override {}
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*)
const override {}
- const char* get_header_file_path() const override { return __FILE__; }
};
struct DenseRankData {
@@ -154,7 +148,6 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*)
const override {}
void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const
override {}
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*)
const override {}
- const char* get_header_file_path() const override { return __FILE__; }
};
struct Value {
@@ -233,7 +226,7 @@ public:
return;
}
if constexpr (is_string) {
- const auto *sources = check_and_get_column<ColumnString>(
+ const auto* sources = check_and_get_column<ColumnString>(
nullable_column->get_nested_column_ptr().get());
_data_value.set_value(sources->get_data_at(pos));
} else {
@@ -410,7 +403,6 @@ public:
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*)
const override {
LOG(FATAL) << "WindowFunctionData do not support deserialize";
}
- const char* get_header_file_path() const override { return __FILE__; }
private:
DataTypePtr _argument_type;
diff --git a/be/test/exprs/topn_function_test.cpp
b/be/test/exprs/topn_function_test.cpp
index 2f0e201..1ba2be4 100644
--- a/be/test/exprs/topn_function_test.cpp
+++ b/be/test/exprs/topn_function_test.cpp
@@ -15,17 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-#include "exprs/anyval_util.h"
-#include "exprs/expr_context.h"
#include "exprs/topn_function.h"
-#include "util/topn_counter.h"
-#include "testutil/function_utils.h"
-#include "zipf_distribution.h"
-#include "test_util/test_util.h"
#include <gtest/gtest.h>
+
#include <unordered_map>
+#include "exprs/anyval_util.h"
+#include "exprs/expr_context.h"
+#include "test_util/test_util.h"
+#include "testutil/function_utils.h"
+#include "util/topn_counter.h"
+#include "zipf_distribution.h"
namespace doris {
@@ -34,13 +35,14 @@ static const uint32_t TOTAL_RECORDS =
LOOP_LESS_OR_MORE(1000, 1000000);
static const uint32_t PARALLEL = 10;
std::string gen_random(const int len) {
- std::string possible_characters =
"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+ std::string possible_characters =
+ "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
std::random_device rd;
std::mt19937 generator(rd());
- std::uniform_int_distribution<> dist(0, possible_characters.size()-1);
+ std::uniform_int_distribution<> dist(0, possible_characters.size() - 1);
std::string rand_str(len, '\0');
- for(auto& dis: rand_str) {
+ for (auto& dis : rand_str) {
dis = possible_characters[dist(generator)];
}
return rand_str;
@@ -55,16 +57,15 @@ public:
ctx = utils->get_fn_ctx();
}
- void TearDown() {
- delete utils;
- }
+ void TearDown() { delete utils; }
private:
- FunctionUtils *utils;
- FunctionContext *ctx;
+ FunctionUtils* utils;
+ FunctionContext* ctx;
};
-void update_accuracy_map(const std::string& item,
std::unordered_map<std::string, uint32_t>& accuracy_map) {
+void update_accuracy_map(const std::string& item,
+ std::unordered_map<std::string, uint32_t>&
accuracy_map) {
if (accuracy_map.find(item) != accuracy_map.end()) {
++accuracy_map[item];
} else {
@@ -72,14 +73,18 @@ void update_accuracy_map(const std::string& item,
std::unordered_map<std::string
}
}
-void topn_single(FunctionContext* ctx, std::string& random_str, StringVal&
dst, std::unordered_map<std::string, uint32_t>& accuracy_map){
- TopNFunctions::topn_update(ctx, StringVal(((uint8_t*) random_str.data()),
random_str.length()), TOPN_NUM, &dst);
+void topn_single(FunctionContext* ctx, std::string& random_str, StringVal& dst,
+ std::unordered_map<std::string, uint32_t>& accuracy_map) {
+ TopNFunctions::topn_update(ctx, StringVal(((uint8_t*)random_str.data()),
random_str.length()),
+ TOPN_NUM, &dst);
update_accuracy_map(random_str, accuracy_map);
}
-void test_topn_accuracy(FunctionContext* ctx, int key_space, int
space_expand_rate, double zipf_distribution_exponent) {
- LOG(INFO) << "topn accuracy : " << "key space : " << key_space << " ,
space_expand_rate : " << space_expand_rate <<
- " , zf exponent : " << zipf_distribution_exponent;
+void test_topn_accuracy(FunctionContext* ctx, int key_space, int
space_expand_rate,
+ double zipf_distribution_exponent) {
+ LOG(INFO) << "topn accuracy : "
+ << "key space : " << key_space << " , space_expand_rate : " <<
space_expand_rate
+ << " , zf exponent : " << zipf_distribution_exponent;
std::unordered_map<std::string, uint32_t> accuracy_map;
// prepare random data
std::vector<std::string> random_strs(key_space);
@@ -110,7 +115,7 @@ void test_topn_accuracy(FunctionContext* ctx, int
key_space, int space_expand_ra
std::random_device random_rd;
std::mt19937 random_gen(random_rd());
- std::uniform_int_distribution<> dist(0, PARALLEL-1);
+ std::uniform_int_distribution<> dist(0, PARALLEL - 1);
for (uint32_t i = 0; i < TOTAL_RECORDS; ++i) {
// generate zipf_distribution
uint32_t index = zf(gen);
@@ -119,13 +124,14 @@ void test_topn_accuracy(FunctionContext* ctx, int
key_space, int space_expand_ra
}
for (uint32_t i = 0; i < PARALLEL; ++i) {
- StringVal serialized_str = TopNFunctions::topn_serialize(ctx,
single_dst_str[i]);
+ StringVal serialized_str = TopNFunctions::topn_serialize(ctx,
single_dst_str[i]);
TopNFunctions::topn_merge(ctx, serialized_str, &dst);
}
// get accuracy result
std::vector<Counter> accuracy_sort_vec;
- for(std::unordered_map<std::string, uint32_t >::const_iterator it =
accuracy_map.begin(); it != accuracy_map.end(); ++it) {
+ for (std::unordered_map<std::string, uint32_t>::const_iterator it =
accuracy_map.begin();
+ it != accuracy_map.end(); ++it) {
accuracy_sort_vec.emplace_back(it->first, it->second);
}
std::sort(accuracy_sort_vec.begin(), accuracy_sort_vec.end(),
TopNComparator());
@@ -143,8 +149,10 @@ void test_topn_accuracy(FunctionContext* ctx, int
key_space, int space_expand_ra
if (accuracy_counter.get_count() != topn_counter.get_count()) {
++error;
LOG(INFO) << "Failed";
- LOG(INFO) << "accuracy counter : (" << accuracy_counter.get_item()
<< ", " << accuracy_counter.get_count() << ")";
- LOG(INFO) << "topn counter : (" << topn_counter.get_item() << ", "
<< topn_counter.get_count() << ")";
+ LOG(INFO) << "accuracy counter : (" << accuracy_counter.get_item()
<< ", "
+ << accuracy_counter.get_count() << ")";
+ LOG(INFO) << "topn counter : (" << topn_counter.get_item() << ", "
+ << topn_counter.get_count() << ")";
}
}
error += std::abs((int32_t)(accuracy_sort_vec.size() -
topn_sort_vec.size()));
@@ -165,7 +173,6 @@ TEST_F(TopNFunctionsTest, topn_accuracy) {
}
}
}
-
}
TEST_F(TopNFunctionsTest, topn_update) {
@@ -220,44 +227,7 @@ TEST_F(TopNFunctionsTest, topn_merge) {
ASSERT_EQ(expected, result);
}
-TEST_F(TopNFunctionsTest, test_null_value) {
- StringVal dst1;
- TopNFunctions::topn_init(ctx, &dst1);
-
- for (uint32_t i = 0; i < 10; ++i) {
- TopNFunctions::topn_update(ctx, IntVal::null(), 2, &dst1);
- }
- StringVal serialized = TopNFunctions::topn_serialize(ctx, dst1);
-
- StringVal dst2;
- TopNFunctions::topn_init(ctx, &dst2);
- TopNFunctions::topn_merge(ctx, serialized, &dst2);
- StringVal result = TopNFunctions::topn_finalize(ctx, dst2);
- StringVal expected("{}");
- ASSERT_EQ(expected, result);
-}
-
-TEST_F(TopNFunctionsTest, test_date_type) {
- StringVal dst1;
- TopNFunctions::topn_init(ctx, &dst1);
-
- DateTimeValue dt(20201001000000);
- doris_udf::DateTimeVal dt_val;
- dt.to_datetime_val(&dt_val);
- for (uint32_t i = 0; i < 10; ++i) {
- TopNFunctions::topn_update(ctx, dt_val, 1, &dst1);
- }
- StringVal serialized = TopNFunctions::topn_serialize(ctx, dst1);
-
- StringVal dst2;
- TopNFunctions::topn_init(ctx, &dst2);
- TopNFunctions::topn_merge(ctx, serialized, &dst2);
- StringVal result = TopNFunctions::topn_finalize(ctx, dst2);
- StringVal expected("{\"2020-10-01 00:00:00\":10}");
- ASSERT_EQ(expected, result);
-}
-
-}
+} // namespace doris
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
diff --git a/be/test/vec/aggregate_functions/agg_test.cpp
b/be/test/vec/aggregate_functions/agg_test.cpp
index 33e6c43..0b00f2f 100644
--- a/be/test/vec/aggregate_functions/agg_test.cpp
+++ b/be/test/vec/aggregate_functions/agg_test.cpp
@@ -20,17 +20,22 @@
#include "gtest/gtest.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_topn.h"
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+
+const int agg_test_batch_size = 4096;
namespace doris::vectorized {
// declare function
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
TEST(AggTest, basic_test) {
auto column_vector_int32 = ColumnVector<Int32>::create();
- for (int i = 0; i < 4096; i++) {
+ for (int i = 0; i < agg_test_batch_size; i++) {
column_vector_int32->insert(cast_to_nearest_field_type(i));
}
// test implement interface
@@ -40,14 +45,14 @@ TEST(AggTest, basic_test) {
DataTypes data_types = {data_type};
Array array;
auto agg_function = factory.get("sum", data_types, array);
- AggregateDataPtr place = (char*)malloc(sizeof(uint64_t) * 4096);
+ AggregateDataPtr place = new char[agg_function->size_of_data()];
agg_function->create(place);
const IColumn* column[1] = {column_vector_int32.get()};
- for (int i = 0; i < 4096; i++) {
+ for (int i = 0; i < agg_test_batch_size; i++) {
agg_function->add(place, column, i, nullptr);
}
int ans = 0;
- for (int i = 0; i < 4096; i++) {
+ for (int i = 0; i < agg_test_batch_size; i++) {
ans += i;
}
ASSERT_EQ(ans, *(int32_t*)place);
@@ -56,6 +61,39 @@ TEST(AggTest, basic_test) {
free(place);
}
}
+
+// Feeds topn(value, 10) a column where value = batch / (row + 1), so small values repeat
+// most often, and checks the JSON rendering of the ten most frequent values.
+TEST(AggTest, topn_test) {
+    MutableColumns datas(2);
+    datas[0] = ColumnString::create();
+    datas[1] = ColumnInt32::create();
+    int top = 10;
+
+    for (int i = 0; i < agg_test_batch_size; i++) {
+        std::string str = std::to_string(agg_test_batch_size / (i + 1));
+        datas[0]->insert_data(str.c_str(), str.length());
+        datas[1]->insert_data(reinterpret_cast<char*>(&top), sizeof(top));
+    }
+
+    AggregateFunctionSimpleFactory factory;
+    register_aggregate_function_topn(factory);
+    DataTypes data_types = {std::make_shared<DataTypeString>(),
+                            std::make_shared<DataTypeInt32>()};
+    Array array;
+
+    auto agg_function = factory.get("topn", data_types, array);
+    // Raw storage for one aggregate state; released with delete[] after destroy() below.
+    AggregateDataPtr place = new char[agg_function->size_of_data()];
+    agg_function->create(place);
+
+    IColumn* columns[2] = {datas[0].get(), datas[1].get()};
+
+    for (int i = 0; i < agg_test_batch_size; i++) {
+        agg_function->add(place, const_cast<const IColumn**>(columns), i, nullptr);
+    }
+
+    std::string result = reinterpret_cast<AggregateFunctionTopNData*>(place)->get();
+    // Tear the state down before asserting so a failed assertion cannot leak it.
+    agg_function->destroy(place);
+    delete[] place;
+
+    std::string expect_result =
+            "{\"1\":2048,\"2\":683,\"3\":341,\"4\":205,\"5\":137,\"6\":97,\"7\":73,\"8\":57,"
+            "\"9\":46,\"10\":37}";
+    ASSERT_EQ(expect_result, result);
+}
} // namespace doris::vectorized
int main(int argc, char** argv) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 3a731d4..5f435d1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1008,66 +1008,22 @@ public class
FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
private static final Map<Type, String> TOPN_UPDATE_SYMBOL =
ImmutableMap.<Type, String>builder()
- .put(Type.BOOLEAN,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.TINYINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.SMALLINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.INT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf6IntValEEEvPNS2_15FunctionContextERKT_RKS3_PNS2_9StringValE")
- .put(Type.BIGINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.FLOAT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf8FloatValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.DOUBLE,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9DoubleValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
.put(Type.CHAR,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPS3_")
.put(Type.VARCHAR,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPS3_")
.put(Type.STRING,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPS3_")
- .put(Type.DATE,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.DATETIME,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.DECIMALV2,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
- .put(Type.LARGEINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11LargeIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValEPNS2_9StringValE")
.build();
private static final Map<Type, String> TOPN_UPDATE_MORE_PARAM_SYMBOL =
ImmutableMap.<Type, String>builder()
- .put(Type.BOOLEAN,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.TINYINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.SMALLINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.INT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf6IntValEEEvPNS2_15FunctionContextERKT_RKS3_SA_PNS2_9StringValE")
- .put(Type.BIGINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.FLOAT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf8FloatValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.DOUBLE,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9DoubleValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
.put(Type.CHAR,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PS3_")
.put(Type.VARCHAR,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PS3_")
.put(Type.STRING,
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PS3_")
- .put(Type.DATE,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.DATETIME,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.DECIMALV2,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
- .put(Type.LARGEINT,
-
"_ZN5doris13TopNFunctions11topn_updateIN9doris_udf11LargeIntValEEEvPNS2_15FunctionContextERKT_RKNS2_6IntValESB_PNS2_9StringValE")
.build();
public Function getFunction(Function desc, Function.CompareMode mode) {
@@ -1631,22 +1587,39 @@ public class
FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
// TopN
if (TOPN_UPDATE_SYMBOL.containsKey(t)) {
- addBuiltin(AggregateFunction.createBuiltin("topn",
- Lists.newArrayList(t, Type.INT), Type.VARCHAR,
Type.VARCHAR,
-
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
- TOPN_UPDATE_SYMBOL.get(t),
-
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
- true, false, true));
- addBuiltin(AggregateFunction.createBuiltin("topn",
- Lists.newArrayList(t, Type.INT, Type.INT),
Type.VARCHAR, Type.VARCHAR,
-
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
- TOPN_UPDATE_MORE_PARAM_SYMBOL.get(t),
-
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
- true, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("topn",
Lists.newArrayList(t, Type.INT), Type.VARCHAR,
+ Type.VARCHAR,
+
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ TOPN_UPDATE_SYMBOL.get(t),
+
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("topn",
Lists.newArrayList(t, Type.INT, Type.INT),
+ Type.VARCHAR, Type.VARCHAR,
+
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ TOPN_UPDATE_MORE_PARAM_SYMBOL.get(t),
+
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true));
+ // vectorized
+ addBuiltin(AggregateFunction.createBuiltin("topn",
Lists.newArrayList(t, Type.INT), Type.VARCHAR,
+ Type.VARCHAR,
+
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ TOPN_UPDATE_SYMBOL.get(t),
+
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true, true));
+ addBuiltin(AggregateFunction.createBuiltin("topn",
Lists.newArrayList(t, Type.INT, Type.INT),
+ Type.VARCHAR, Type.VARCHAR,
+
"_ZN5doris13TopNFunctions9topn_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ TOPN_UPDATE_MORE_PARAM_SYMBOL.get(t),
+
"_ZN5doris13TopNFunctions10topn_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris13TopNFunctions14topn_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris13TopNFunctions13topn_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true, true));
}
if (STDDEV_UPDATE_SYMBOL.containsKey(t)) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]