github-actions[bot] commented on code in PR #15339:
URL: https://github.com/apache/doris/pull/15339#discussion_r1111501259
##########
be/src/vec/aggregate_functions/aggregate_function_collect.cpp:
##########
@@ -18,78 +18,88 @@
#include "vec/aggregate_functions/aggregate_function_collect.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
-template <typename T>
-AggregateFunctionPtr create_agg_function_collect(bool distinct, const
DataTypes& argument_types) {
+#define FOR_DECIMAL_TYPES(M) \
+ M(Decimal32) \
+ M(Decimal64) \
+ M(Decimal128) \
+ M(Decimal128I)
+
+template <typename T, typename HasLimit, typename... TArgs>
+AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const
DataTypePtr& argument_type,
+ TArgs... args) {
if (distinct) {
return AggregateFunctionPtr(
- new
AggregateFunctionCollect<AggregateFunctionCollectSetData<T>>(argument_types));
+ new
AggregateFunctionCollect<AggregateFunctionCollectSetData<T, HasLimit>,
+ HasLimit>(argument_type,
+
std::forward<TArgs>(args)...));
} else {
return AggregateFunctionPtr(
- new
AggregateFunctionCollect<AggregateFunctionCollectListData<T>>(argument_types));
+ new
AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>,
+ HasLimit>(argument_type,
+
std::forward<TArgs>(args)...));
}
}
-AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- if (argument_types.size() != 1) {
- LOG(WARNING) << fmt::format("Illegal number {} of argument for
aggregate function {}",
- argument_types.size(), name);
- return nullptr;
- }
-
+template <typename HasLimit, typename... TArgs>
+AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string&
name,
+ const DataTypePtr&
argument_type,
+ TArgs... args) {
bool distinct = false;
if (name == "collect_set") {
distinct = true;
}
- WhichDataType type(argument_types[0]);
- if (type.is_uint8()) {
- return create_agg_function_collect<UInt8>(distinct, argument_types);
- } else if (type.is_int8()) {
- return create_agg_function_collect<Int8>(distinct, argument_types);
- } else if (type.is_int16()) {
- return create_agg_function_collect<Int16>(distinct, argument_types);
- } else if (type.is_int32()) {
- return create_agg_function_collect<Int32>(distinct, argument_types);
- } else if (type.is_int64()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_int128()) {
- return create_agg_function_collect<Int128>(distinct, argument_types);
- } else if (type.is_float32()) {
- return create_agg_function_collect<Float32>(distinct, argument_types);
- } else if (type.is_float64()) {
- return create_agg_function_collect<Float64>(distinct, argument_types);
- } else if (type.is_decimal32()) {
- return create_agg_function_collect<Decimal32>(distinct,
argument_types);
- } else if (type.is_decimal64()) {
- return create_agg_function_collect<Decimal64>(distinct,
argument_types);
- } else if (type.is_decimal128()) {
- return create_agg_function_collect<Decimal128>(distinct,
argument_types);
- } else if (type.is_decimal128i()) {
- return create_agg_function_collect<Decimal128I>(distinct,
argument_types);
- } else if (type.is_date()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_date_time()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_date_v2()) {
- return create_agg_function_collect<UInt32>(distinct, argument_types);
- } else if (type.is_date_time_v2()) {
- return create_agg_function_collect<UInt64>(distinct, argument_types);
- } else if (type.is_string()) {
- return create_agg_function_collect<StringRef>(distinct,
argument_types);
+ WhichDataType which(argument_type);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return do_create_agg_function_collect<TYPE, HasLimit>(distinct,
argument_type, \
+
std::forward<TArgs>(args)...);
+ FOR_NUMERIC_TYPES(DISPATCH)
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+ if (which.is_date_or_datetime()) {
+ return do_create_agg_function_collect<Int64, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_date_v2()) {
+ return do_create_agg_function_collect<UInt32, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_date_time_v2()) {
+ return do_create_agg_function_collect<UInt64, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_string()) {
+ return do_create_agg_function_collect<StringRef, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
}
LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
- argument_types[0]->get_name(), name);
+ argument_type->get_name(), name);
+ return nullptr;
+}
+
+AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
+ if (argument_types.size() == 1) {
+ return create_aggregate_function_collect_impl<std::false_type>(name,
argument_types[0],
+
parameters);
+ }
+ if (argument_types.size() == 2) {
+ return create_aggregate_function_collect_impl<std::true_type>(name,
argument_types[0],
+
parameters);
Review Comment:
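warning: use of undeclared identifier 'parameters' [clang-diagnostic-error]
(`create_aggregate_function_collect` only takes `name`, `argument_types`, and `result_is_nullable`, so no `parameters` is in scope at this call and the file does not compile.)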
```cpp
parameters);
^
```
##########
be/src/vec/aggregate_functions/aggregate_function_collect.cpp:
##########
@@ -18,78 +18,88 @@
#include "vec/aggregate_functions/aggregate_function_collect.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
-template <typename T>
-AggregateFunctionPtr create_agg_function_collect(bool distinct, const
DataTypes& argument_types) {
+#define FOR_DECIMAL_TYPES(M) \
+ M(Decimal32) \
+ M(Decimal64) \
+ M(Decimal128) \
+ M(Decimal128I)
+
+template <typename T, typename HasLimit, typename... TArgs>
+AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const
DataTypePtr& argument_type,
+ TArgs... args) {
if (distinct) {
return AggregateFunctionPtr(
- new
AggregateFunctionCollect<AggregateFunctionCollectSetData<T>>(argument_types));
+ new
AggregateFunctionCollect<AggregateFunctionCollectSetData<T, HasLimit>,
+ HasLimit>(argument_type,
+
std::forward<TArgs>(args)...));
} else {
return AggregateFunctionPtr(
- new
AggregateFunctionCollect<AggregateFunctionCollectListData<T>>(argument_types));
+ new
AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>,
+ HasLimit>(argument_type,
+
std::forward<TArgs>(args)...));
}
}
-AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- if (argument_types.size() != 1) {
- LOG(WARNING) << fmt::format("Illegal number {} of argument for
aggregate function {}",
- argument_types.size(), name);
- return nullptr;
- }
-
+template <typename HasLimit, typename... TArgs>
+AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string&
name,
+ const DataTypePtr&
argument_type,
+ TArgs... args) {
bool distinct = false;
if (name == "collect_set") {
distinct = true;
}
- WhichDataType type(argument_types[0]);
- if (type.is_uint8()) {
- return create_agg_function_collect<UInt8>(distinct, argument_types);
- } else if (type.is_int8()) {
- return create_agg_function_collect<Int8>(distinct, argument_types);
- } else if (type.is_int16()) {
- return create_agg_function_collect<Int16>(distinct, argument_types);
- } else if (type.is_int32()) {
- return create_agg_function_collect<Int32>(distinct, argument_types);
- } else if (type.is_int64()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_int128()) {
- return create_agg_function_collect<Int128>(distinct, argument_types);
- } else if (type.is_float32()) {
- return create_agg_function_collect<Float32>(distinct, argument_types);
- } else if (type.is_float64()) {
- return create_agg_function_collect<Float64>(distinct, argument_types);
- } else if (type.is_decimal32()) {
- return create_agg_function_collect<Decimal32>(distinct,
argument_types);
- } else if (type.is_decimal64()) {
- return create_agg_function_collect<Decimal64>(distinct,
argument_types);
- } else if (type.is_decimal128()) {
- return create_agg_function_collect<Decimal128>(distinct,
argument_types);
- } else if (type.is_decimal128i()) {
- return create_agg_function_collect<Decimal128I>(distinct,
argument_types);
- } else if (type.is_date()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_date_time()) {
- return create_agg_function_collect<Int64>(distinct, argument_types);
- } else if (type.is_date_v2()) {
- return create_agg_function_collect<UInt32>(distinct, argument_types);
- } else if (type.is_date_time_v2()) {
- return create_agg_function_collect<UInt64>(distinct, argument_types);
- } else if (type.is_string()) {
- return create_agg_function_collect<StringRef>(distinct,
argument_types);
+ WhichDataType which(argument_type);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return do_create_agg_function_collect<TYPE, HasLimit>(distinct,
argument_type, \
+
std::forward<TArgs>(args)...);
+ FOR_NUMERIC_TYPES(DISPATCH)
+ FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+ if (which.is_date_or_datetime()) {
+ return do_create_agg_function_collect<Int64, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_date_v2()) {
+ return do_create_agg_function_collect<UInt32, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_date_time_v2()) {
+ return do_create_agg_function_collect<UInt64, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
+ } else if (which.is_string()) {
+ return do_create_agg_function_collect<StringRef, HasLimit>(distinct,
argument_type,
+
std::forward<TArgs>(args)...);
}
LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
- argument_types[0]->get_name(), name);
+ argument_type->get_name(), name);
+ return nullptr;
+}
+
+AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
+ if (argument_types.size() == 1) {
+ return create_aggregate_function_collect_impl<std::false_type>(name,
argument_types[0],
+
parameters);
Review Comment:
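warning: use of undeclared identifier 'parameters' [clang-diagnostic-error]
(`create_aggregate_function_collect` only takes `name`, `argument_types`, and `result_is_nullable`, so no `parameters` is in scope at this call and the file does not compile.)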
```cpp
parameters);
^
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]