This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 24ef60b491 [Opt](exec) opt aggregate function performance in nullable
column
24ef60b491 is described below
commit 24ef60b491801e91ff37a6b34642d47f8ab46604
Author: HappenLee <[email protected]>
AuthorDate: Thu Feb 16 22:26:12 2023 +0800
[Opt](exec) opt aggregate function performance in nullable column
---
.../aggregate_functions/aggregate_function_avg.cpp | 21 +-
.../aggregate_functions/aggregate_function_count.h | 3 +-
.../aggregate_function_min_max.cpp | 1 -
.../aggregate_function_null.cpp | 1 -
.../aggregate_functions/aggregate_function_null.h | 245 +++++++++++++++++++++
.../aggregate_functions/aggregate_function_sum.cpp | 24 +-
be/src/vec/aggregate_functions/helpers.h | 149 ++++---------
be/src/vec/data_types/data_type_nullable.cpp | 12 +
be/src/vec/data_types/data_type_nullable.h | 1 +
9 files changed, 339 insertions(+), 118 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index 5875b831f3..7f9295d8e7 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const
std::string& name,
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
- if (is_decimal(data_type)) {
- res.reset(
- create_with_decimal_type<AggregateFuncAvg>(*data_type,
*data_type, argument_types));
+ if (data_type->is_nullable()) {
+ auto no_null_argument_types = remove_nullable(argument_types);
+ if (is_decimal(no_null_argument_types[0])) {
+ res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
+ no_null_argument_types, parameters,
*no_null_argument_types[0],
+ no_null_argument_types));
+ } else {
+ res.reset(create_with_numeric_type_null<AggregateFuncAvg>(
+ no_null_argument_types, parameters,
no_null_argument_types));
+ }
} else {
- res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type,
argument_types));
+ if (is_decimal(data_type)) {
+ res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type,
*data_type,
+
argument_types));
+ } else {
+ res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type,
argument_types));
+ }
}
if (!res) {
@@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const
std::string& name,
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function("avg", create_aggregate_function_avg);
+ factory.register_function("avg", create_aggregate_function_avg, true);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h
b/be/src/vec/aggregate_functions/aggregate_function_count.h
index 960d4111cb..bc87e4bb10 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_count.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_count.h
@@ -121,7 +121,8 @@ public:
DataTypePtr get_serialized_type() const override { return
std::make_shared<DataTypeUInt64>(); }
};
-/// Simply count number of not-NULL values.
+// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of
AggregateFunctionCount
+// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData,
AggregateFunctionCountNotNullUnary> {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
index 83045dbd00..a01e2ce51a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
@@ -25,7 +25,6 @@
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
-
/// min, max, any
template <template <typename, bool> class AggregateFunctionTemplate, template
<typename> class Data>
static IAggregateFunction* create_aggregate_function_single_value(const
String& name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.cpp
b/be/src/vec/aggregate_functions/aggregate_function_null.cpp
index 495cefcb84..8ae2368864 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.cpp
@@ -85,7 +85,6 @@ public:
};
void
register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory&
factory) {
- //
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
AggregateFunctionCreator creator = [&](const std::string& name, const
DataTypes& types,
const Array& params, const bool
result_is_nullable) {
auto function_combinator =
std::make_shared<AggregateFunctionCombinatorNull>();
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h
b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 86fe7734e1..69642d0deb 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -40,6 +40,7 @@ namespace doris::vectorized {
/// If all rows had NULL, the behaviour is determined by "result_is_nullable"
template parameter.
/// true - return NULL; false - return value from empty aggregation state of
nested function.
+// TODO: only keep class xxxInline after we support all aggregate function
template <bool result_is_nullable, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> {
protected:
@@ -409,4 +410,248 @@ private:
is_nullable; /// Plain array is better than std::vector due to one
indirection less.
};
+template <typename NestFunction, bool result_is_nullable, typename Derived>
+class AggregateFunctionNullBaseInline : public
IAggregateFunctionHelper<Derived> {
+protected:
+ std::unique_ptr<NestFunction> nested_function;
+ size_t prefix_size;
+
+ /** In addition to data for nested aggregate function, we keep a flag
+ * indicating - was there at least one non-NULL value accumulated.
+ * In case of no not-NULL values, the function will return NULL.
+ *
+ * We use prefix_size bytes for flag to satisfy the alignment requirement
of nested state.
+ */
+
+ AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const
noexcept {
+ return place + prefix_size;
+ }
+
+ ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place)
const noexcept {
+ return place + prefix_size;
+ }
+
+ static void init_flag(AggregateDataPtr __restrict place) noexcept {
+ if constexpr (result_is_nullable) {
+ place[0] = false;
+ }
+ }
+
+ static void set_flag(AggregateDataPtr __restrict place) noexcept {
+ if constexpr (result_is_nullable) {
+ place[0] = true;
+ }
+ }
+
+ static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
+ return result_is_nullable ? place[0] : true;
+ }
+
+public:
+ AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
+ const DataTypes& arguments, const Array&
params)
+ : IAggregateFunctionHelper<Derived>(arguments, params),
+ nested_function {assert_cast<NestFunction*>(nested_function_)} {
+ if (result_is_nullable) {
+ prefix_size = nested_function->align_of_data();
+ } else {
+ prefix_size = 0;
+ }
+ }
+
+ String get_name() const override {
+ /// This is just a wrapper. The function for Nullable arguments is
named the same as the nested function itself.
+ return nested_function->get_name();
+ }
+
+ DataTypePtr get_return_type() const override {
+ return result_is_nullable ?
make_nullable(nested_function->get_return_type())
+ : nested_function->get_return_type();
+ }
+
+ void create(AggregateDataPtr __restrict place) const override {
+ init_flag(place);
+ nested_function->create(nested_place(place));
+ }
+
+ void destroy(AggregateDataPtr __restrict place) const noexcept override {
+ nested_function->destroy(nested_place(place));
+ }
+ void reset(AggregateDataPtr place) const override {
+ init_flag(place);
+ nested_function->reset(nested_place(place));
+ }
+
+ bool has_trivial_destructor() const override {
+ return nested_function->has_trivial_destructor();
+ }
+
+ size_t size_of_data() const override { return prefix_size +
nested_function->size_of_data(); }
+
+ size_t align_of_data() const override { return
nested_function->align_of_data(); }
+
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena* arena) const override {
+ if (result_is_nullable && get_flag(rhs)) {
+ set_flag(place);
+ }
+
+ nested_function->merge(nested_place(place), nested_place(rhs), arena);
+ }
+
+ void serialize(ConstAggregateDataPtr __restrict place, BufferWritable&
buf) const override {
+ bool flag = get_flag(place);
+ if (result_is_nullable) {
+ write_binary(flag, buf);
+ }
+ if (flag) {
+ nested_function->serialize(nested_place(place), buf);
+ }
+ }
+
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena* arena) const override {
+ bool flag = true;
+ if (result_is_nullable) {
+ read_binary(flag, buf);
+ }
+ if (flag) {
+ set_flag(place);
+ nested_function->deserialize(nested_place(place), buf, arena);
+ }
+ }
+
+ void deserialize_and_merge(AggregateDataPtr __restrict place,
BufferReadable& buf,
+ Arena* arena) const override {
+ bool flag = true;
+ if (result_is_nullable) {
+ read_binary(flag, buf);
+ }
+ if (flag) {
+ set_flag(place);
+ nested_function->deserialize_and_merge(nested_place(place), buf,
arena);
+ }
+ }
+
+ void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
const IColumn& column,
+ Arena* arena) const override {
+ size_t num_rows = column.size();
+ for (size_t i = 0; i != num_rows; ++i) {
+ VectorBufferReader buffer_reader(
+ (assert_cast<const ColumnString&>(column)).get_data_at(i));
+ deserialize_and_merge(place, buffer_reader, arena);
+ }
+ }
+
+ void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
+ if constexpr (result_is_nullable) {
+ ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);
+ if (get_flag(place)) {
+ nested_function->insert_result_into(nested_place(place),
+
to_concrete.get_nested_column());
+ to_concrete.get_null_map_data().push_back(0);
+ } else {
+ to_concrete.insert_default();
+ }
+ } else {
+ nested_function->insert_result_into(nested_place(place), to);
+ }
+ }
+
+ bool allocates_memory_in_arena() const override {
+ return nested_function->allocates_memory_in_arena();
+ }
+
+ bool is_state() const override { return nested_function->is_state(); }
+};
+
+/** There are two cases: for single argument and variadic.
+ * Code for single argument is much more efficient.
+ */
+template <typename NestFuction, bool result_is_nullable>
+class AggregateFunctionNullUnaryInline final
+ : public AggregateFunctionNullBaseInline<
+ NestFuction, result_is_nullable,
+ AggregateFunctionNullUnaryInline<NestFuction,
result_is_nullable>> {
+public:
+ AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_,
+ const DataTypes& arguments, const Array&
params)
+ : AggregateFunctionNullBaseInline<
+ NestFuction, result_is_nullable,
+ AggregateFunctionNullUnaryInline<NestFuction,
result_is_nullable>>(
+ nested_function_, arguments, params) {}
+
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ Arena* arena) const override {
+ const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
+ if (!column->is_null_at(row_num)) {
+ this->set_flag(place);
+ const IColumn* nested_column = &column->get_nested_column();
+ this->nested_function->add(this->nested_place(place),
&nested_column, row_num, arena);
+ }
+ }
+
+ void add_not_nullable(AggregateDataPtr __restrict place, const IColumn**
columns,
+ size_t row_num, Arena* arena) const {
+ const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
+ this->set_flag(place);
+ const IColumn* nested_column = &column->get_nested_column();
+ this->nested_function->add(this->nested_place(place), &nested_column,
row_num, arena);
+ }
+
+ void add_batch(size_t batch_size, AggregateDataPtr* places, size_t
place_offset,
+ const IColumn** columns, Arena* arena, bool agg_many) const
override {
+ const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
+ // The overhead introduced is negligible here, just an extra memory
read from NullMap
+ const auto* __restrict null_map_data =
column->get_null_map_data().data();
+ const IColumn* nested_column = &column->get_nested_column();
+ for (int i = 0; i < batch_size; ++i) {
+ if (!null_map_data[i]) {
+ AggregateDataPtr __restrict place = places[i] + place_offset;
+ this->set_flag(place);
+ this->nested_function->add(this->nested_place(place),
&nested_column, i, arena);
+ }
+ }
+ }
+
+ void add_batch_single_place(size_t batch_size, AggregateDataPtr place,
const IColumn** columns,
+ Arena* arena) const override {
+ const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
+ bool has_null = column->has_null();
+
+ if (has_null) {
+ for (size_t i = 0; i < batch_size; ++i) {
+ if (!column->is_null_at(i)) {
+ this->set_flag(place);
+ this->add(place, columns, i, arena);
+ }
+ }
+ } else {
+ this->set_flag(place);
+ const IColumn* nested_column = &column->get_nested_column();
+ this->nested_function->add_batch_single_place(batch_size,
this->nested_place(place),
+ &nested_column,
arena);
+ }
+ }
+
+ void add_batch_range(size_t batch_begin, size_t batch_end,
AggregateDataPtr place,
+ const IColumn** columns, Arena* arena, bool has_null)
override {
+ const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
+
+ if (has_null) {
+ for (size_t i = batch_begin; i <= batch_end; ++i) {
+ if (!column->is_null_at(i)) {
+ this->set_flag(place);
+ this->add(place, columns, i, arena);
+ }
+ }
+ } else {
+ this->set_flag(place);
+ const IColumn* nested_column = &column->get_nested_column();
+ this->nested_function->add_batch_range(batch_begin, batch_end,
+ this->nested_place(place),
&nested_column, arena,
+ false);
+ }
+ }
+};
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
index ca40e4196c..75d4d36414 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
@@ -25,6 +25,7 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
+#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
@@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const
std::string& name,
const DataTypes&
argument_types,
const Array& parameters,
const bool
result_is_nullable) {
- // assert_no_parameters(name, parameters);
- // assert_unary(name, argument_types);
-
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
- if (is_decimal(data_type)) {
- res.reset(create_with_decimal_type<Function>(*data_type, *data_type,
argument_types));
+ if (data_type->is_nullable()) {
+ auto no_null_argument_types = remove_nullable(argument_types);
+ if (is_decimal(no_null_argument_types[0])) {
+
res.reset(create_with_decimal_type_null<Function>(no_null_argument_types,
parameters,
+
*no_null_argument_types[0],
+
no_null_argument_types));
+ } else {
+
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types,
parameters,
+
no_null_argument_types));
+ }
} else {
- res.reset(create_with_numeric_type<Function>(*data_type,
argument_types));
+ if (is_decimal(data_type)) {
+ res.reset(create_with_decimal_type<Function>(*data_type,
*data_type, argument_types));
+ } else {
+ res.reset(create_with_numeric_type<Function>(*data_type,
argument_types));
+ }
}
if (!res) {
@@ -84,6 +94,8 @@ AggregateFunctionPtr
create_aggregate_function_sum_reader(const std::string& nam
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
factory.register_function("sum",
create_aggregate_function_sum<AggregateFunctionSumSimple>);
+ factory.register_function("sum",
create_aggregate_function_sum<AggregateFunctionSumSimple>,
+ true);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/helpers.h
b/be/src/vec/aggregate_functions/helpers.h
index 36e11f7011..0970a860a6 100644
--- a/be/src/vec/aggregate_functions/helpers.h
+++ b/be/src/vec/aggregate_functions/helpers.h
@@ -21,8 +21,10 @@
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_null.h"
#include "vec/data_types/data_type.h"
+// TODO: Should we support decimal in numeric types?
#define FOR_NUMERIC_TYPES(M) \
M(UInt8) \
M(UInt16) \
@@ -36,6 +38,12 @@
M(Float32) \
M(Float64)
+#define FOR_DECIMAL_TYPES(M) \
+ M(Decimal32) \
+ M(Decimal64) \
+ M(Decimal128) \
+ M(Decimal128I)
+
namespace doris::vectorized {
/** Create an aggregate function with a numeric type in the template
parameter, depending on the type of the argument.
@@ -49,12 +57,20 @@ static IAggregateFunction* create_with_numeric_type(const
IDataType& argument_ty
return new
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return new
AggregateFunctionTemplate<Int8>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return new
AggregateFunctionTemplate<Int16>(std::forward<TArgs>(args)...);
- }
+ return nullptr;
+}
+
+template <template <typename> class AggregateFunctionTemplate, typename...
TArgs>
+static IAggregateFunction* create_with_numeric_type_null(const DataTypes&
argument_types,
+ const Array& params,
TArgs&&... args) {
+ WhichDataType which(argument_types[0]);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return new
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>(
\
+ new
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
+ params);
+ FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
return nullptr;
}
@@ -68,12 +84,6 @@ static IAggregateFunction* create_with_numeric_type(const
IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE,
bool_param>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return new AggregateFunctionTemplate<Int8,
bool_param>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return new AggregateFunctionTemplate<Int16,
bool_param>(std::forward<TArgs>(args)...);
- }
return nullptr;
}
@@ -87,12 +97,6 @@ static IAggregateFunction* create_with_numeric_type(const
IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE,
Data>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return new AggregateFunctionTemplate<Int8,
Data>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return new AggregateFunctionTemplate<Int16,
Data>(std::forward<TArgs>(args)...);
- }
return nullptr;
}
@@ -106,12 +110,6 @@ static IAggregateFunction* create_with_numeric_type(const
IDataType& argument_ty
return new AggregateFunctionTemplate<TYPE,
Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return new AggregateFunctionTemplate<Int8,
Data<Int8>>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return new AggregateFunctionTemplate<Int16,
Data<Int16>>(std::forward<TArgs>(args)...);
- }
return nullptr;
}
@@ -125,70 +123,32 @@ static IAggregateFunction* create_with_numeric_type(const
IDataType& argument_ty
return new
AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- // if (which.idx == TypeIndex::Enum8) return new
AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...);
- // if (which.idx == TypeIndex::Enum16) return new
AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...);
- return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate,
- template <typename> class Data, typename... TArgs>
-static IAggregateFunction* create_with_unsigned_integer_type(const IDataType&
argument_type,
- TArgs&&... args) {
- WhichDataType which(argument_type);
- if (which.idx == TypeIndex::UInt8) {
- return new AggregateFunctionTemplate<UInt8,
Data<UInt8>>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::UInt16) {
- return new AggregateFunctionTemplate<UInt16,
Data<UInt16>>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::UInt32) {
- return new AggregateFunctionTemplate<UInt32,
Data<UInt32>>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::UInt64) {
- return new AggregateFunctionTemplate<UInt64,
Data<UInt64>>(std::forward<TArgs>(args)...);
- }
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename...
TArgs>
-static IAggregateFunction* create_with_numeric_based_type(const IDataType&
argument_type,
- TArgs&&... args) {
- IAggregateFunction* f =
create_with_numeric_type<AggregateFunctionTemplate>(
- argument_type, std::forward<TArgs>(args)...);
- if (f) {
- return f;
- }
-
- /// expects that DataTypeDate based on UInt16, DataTypeDateTime based on
UInt32 and UUID based on UInt128
+static IAggregateFunction* create_with_decimal_type(const IDataType&
argument_type,
+ TArgs&&... args) {
WhichDataType which(argument_type);
- if (which.idx == TypeIndex::Date) {
- return new
AggregateFunctionTemplate<UInt16>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::DateTime) {
- return new
AggregateFunctionTemplate<UInt32>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::UUID) {
- return new
AggregateFunctionTemplate<UInt128>(std::forward<TArgs>(args)...);
- }
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename...
TArgs>
-static IAggregateFunction* create_with_decimal_type(const IDataType&
argument_type,
- TArgs&&... args) {
- WhichDataType which(argument_type);
- if (which.idx == TypeIndex::Decimal32) {
- return new
AggregateFunctionTemplate<Decimal32>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal64) {
- return new
AggregateFunctionTemplate<Decimal64>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal128) {
- return new
AggregateFunctionTemplate<Decimal128>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal128I) {
- return new
AggregateFunctionTemplate<Decimal128I>(std::forward<TArgs>(args)...);
- }
+static IAggregateFunction* create_with_decimal_type_null(const DataTypes&
argument_types,
+ const Array& params,
TArgs&&... args) {
+ WhichDataType which(argument_types[0]);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return new
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>(
\
+ new
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...), argument_types, \
+ params);
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
return nullptr;
}
@@ -197,18 +157,11 @@ template <template <typename, typename> class
AggregateFunctionTemplate, typenam
static IAggregateFunction* create_with_decimal_type(const IDataType&
argument_type,
TArgs&&... args) {
WhichDataType which(argument_type);
- if (which.idx == TypeIndex::Decimal32) {
- return new AggregateFunctionTemplate<Decimal32,
Data>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal64) {
- return new AggregateFunctionTemplate<Decimal64,
Data>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal128) {
- return new AggregateFunctionTemplate<Decimal128,
Data>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Decimal128I) {
- return new AggregateFunctionTemplate<Decimal128I,
Data>(std::forward<TArgs>(args)...);
- }
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new AggregateFunctionTemplate<TYPE,
Data>(std::forward<TArgs>(args)...);
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
return nullptr;
}
@@ -224,12 +177,6 @@ static IAggregateFunction*
create_with_two_numeric_types_second(const IDataType&
return new AggregateFunctionTemplate<FirstType,
TYPE>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return new AggregateFunctionTemplate<FirstType,
Int8>(std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return new AggregateFunctionTemplate<FirstType,
Int16>(std::forward<TArgs>(args)...);
- }
return nullptr;
}
@@ -244,14 +191,6 @@ static IAggregateFunction*
create_with_two_numeric_types(const IDataType& first_
second_type, std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
- if (which.idx == TypeIndex::Enum8) {
- return create_with_two_numeric_types_second<Int8,
AggregateFunctionTemplate>(
- second_type, std::forward<TArgs>(args)...);
- }
- if (which.idx == TypeIndex::Enum16) {
- return create_with_two_numeric_types_second<Int16,
AggregateFunctionTemplate>(
- second_type, std::forward<TArgs>(args)...);
- }
return nullptr;
}
diff --git a/be/src/vec/data_types/data_type_nullable.cpp
b/be/src/vec/data_types/data_type_nullable.cpp
index 6f69145504..e86cf77a79 100644
--- a/be/src/vec/data_types/data_type_nullable.cpp
+++ b/be/src/vec/data_types/data_type_nullable.cpp
@@ -158,4 +158,16 @@ DataTypePtr remove_nullable(const DataTypePtr& type) {
return type;
}
+DataTypes remove_nullable(const DataTypes& types) {
+ DataTypes no_null_types;
+ for (auto& type : types) {
+ if (type->is_nullable()) {
+ no_null_types.push_back(static_cast<const
DataTypeNullable&>(*type).get_nested_type());
+ } else {
+ no_null_types.push_back(type);
+ }
+ }
+ return no_null_types;
+}
+
} // namespace doris::vectorized
diff --git a/be/src/vec/data_types/data_type_nullable.h
b/be/src/vec/data_types/data_type_nullable.h
index 32488ca35e..d8e6bf22b2 100644
--- a/be/src/vec/data_types/data_type_nullable.h
+++ b/be/src/vec/data_types/data_type_nullable.h
@@ -93,5 +93,6 @@ private:
DataTypePtr make_nullable(const DataTypePtr& type);
DataTypePtr remove_nullable(const DataTypePtr& type);
+DataTypes remove_nullable(const DataTypes& types);
} // namespace doris::vectorized
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]