HappenLee commented on code in PR #50181: URL: https://github.com/apache/doris/pull/50181#discussion_r2057587756
########## be/src/vec/aggregate_functions/aggregate_function_collect.cpp: ########## @@ -17,125 +17,94 @@ #include "vec/aggregate_functions/aggregate_function_collect.h" -#include <fmt/format.h> - -#include <boost/iterator/iterator_facade.hpp> #include <type_traits> +#include "common/exception.h" +#include "common/status.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" namespace doris::vectorized { #include "common/compile_check_begin.h" -template <typename T, typename HasLimit, typename ShowNull> +template <typename T, typename HasLimit> AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const DataTypes& argument_types, const bool result_is_nullable) { - if (argument_types[0]->is_nullable()) { - if constexpr (ShowNull::value) { - return creator_without_type::create_ignore_nullable<AggregateFunctionCollect< - AggregateFunctionArrayAggData<T>, std::false_type, std::true_type>>( - argument_types, result_is_nullable); - } - } - - if constexpr (!std::is_same_v<T, void>) { - if (distinct) { - return creator_without_type::create<AggregateFunctionCollect< - AggregateFunctionCollectSetData<T, HasLimit>, HasLimit, std::false_type>>( - argument_types, result_is_nullable); + if (distinct) { + if constexpr (std::is_same_v<T, void>) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "unexpected type for collect, please check the input"); } else { return creator_without_type::create<AggregateFunctionCollect< - AggregateFunctionCollectListData<T, HasLimit>, HasLimit, std::false_type>>( - argument_types, result_is_nullable); + AggregateFunctionCollectSetData<T, HasLimit>, HasLimit>>(argument_types, + result_is_nullable); } - } else if (!distinct) { - // void type means support array/map/struct type for collect_list - return creator_without_type::create<AggregateFunctionCollect< - AggregateFunctionCollectListData<void, HasLimit>, HasLimit, std::false_type>>( + } else { + return creator_without_type::create< + AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>, HasLimit>>( argument_types, result_is_nullable); } - return nullptr; } -template <typename HasLimit, typename ShowNull> +template <typename HasLimit> AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - bool distinct = false; - if (name == "collect_set") { - distinct = true; - } + bool distinct = name == "collect_set"; WhichDataType which(remove_nullable(argument_types[0])); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return do_create_agg_function_collect<TYPE, HasLimit, ShowNull>(distinct, argument_types, \ - result_is_nullable); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + return do_create_agg_function_collect<TYPE, HasLimit>(distinct, argument_types, \ + result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) FOR_DECIMAL_TYPES(DISPATCH) #undef DISPATCH if (which.is_date_or_datetime()) { - return do_create_agg_function_collect<Int64, HasLimit, ShowNull>(distinct, argument_types, - result_is_nullable); + return do_create_agg_function_collect<Int64, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_date_v2()) { - return do_create_agg_function_collect<UInt32, HasLimit, ShowNull>(distinct, argument_types, - result_is_nullable); + return do_create_agg_function_collect<UInt32, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_date_time_v2()) { - return do_create_agg_function_collect<UInt64, HasLimit, ShowNull>(distinct, argument_types, - result_is_nullable); + return do_create_agg_function_collect<UInt64, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_ipv6()) { - return do_create_agg_function_collect<IPv6, HasLimit, ShowNull>(distinct, argument_types, - result_is_nullable); + return do_create_agg_function_collect<IPv6, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_ipv4()) { - return do_create_agg_function_collect<IPv4, HasLimit, ShowNull>(distinct, argument_types, - result_is_nullable); + return do_create_agg_function_collect<IPv4, HasLimit>(distinct, argument_types, + result_is_nullable); } else if (which.is_string()) { - return do_create_agg_function_collect<StringRef, HasLimit, ShowNull>( - distinct, argument_types, result_is_nullable); + return do_create_agg_function_collect<StringRef, HasLimit>(distinct, argument_types, + result_is_nullable); } else { - // generic serialize which will not use specializations, ShowNull::value always means array_agg - if constexpr (ShowNull::value) { - return do_create_agg_function_collect<void, HasLimit, ShowNull>( - distinct, argument_types, result_is_nullable); - } else { - return do_create_agg_function_collect<void, HasLimit, ShowNull>( - distinct, argument_types, result_is_nullable); - } + // generic serialize which will not use specializations::value always means array_agg + return do_create_agg_function_collect<void, HasLimit>(distinct, argument_types, + result_is_nullable); } - - LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", - argument_types[0]->get_name(), name); - return nullptr; } AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { - if (name == "array_agg") { - return create_aggregate_function_collect_impl<std::false_type, std::true_type>( - name, argument_types, result_is_nullable); - } else { - return create_aggregate_function_collect_impl<std::false_type, std::false_type>( - name, argument_types, result_is_nullable); - } + return create_aggregate_function_collect_impl<std::false_type>(name, argument_types, + result_is_nullable); } if (argument_types.size() == 2) { - return create_aggregate_function_collect_impl<std::true_type, std::false_type>( - name, argument_types, result_is_nullable); + return create_aggregate_function_collect_impl<std::true_type>(name, argument_types, + result_is_nullable); } - LOG(WARNING) << fmt::format("number of parameters for aggregate function {}, should be 1 or 2", - name); - return nullptr; + throw Exception(ErrorCode::INTERNAL_ERROR, + "unexpected type for collect, please check the input"); } void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory) { // notice: array_agg only differs from collect_list in that array_agg will show null elements in array Review Comment: remove the comment -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org