This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 0b1d06bfd6 [Vectorized] Support order by aggregate function (#11187) 0b1d06bfd6 is described below commit 0b1d06bfd66ec8ac62d94fa4e3a8578f54516f7e Author: HappenLee <happen...@hotmail.com> AuthorDate: Thu Jul 28 09:12:58 2022 +0800 [Vectorized] Support order by aggregate function (#11187) Co-authored-by: lihaopeng <lihaop...@baidu.com> --- .../aggregate_function_simple_factory.cpp | 2 - .../aggregate_function_sort.cpp | 79 +++------------------- .../aggregate_functions/aggregate_function_sort.h | 73 +++++++++----------- be/src/vec/exec/vaggregation_node.cpp | 4 +- be/src/vec/exec/vanalytic_eval_node.cpp | 2 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 36 ++++++++-- be/src/vec/exprs/vectorized_agg_fn.h | 10 ++- fe/fe-core/src/main/cup/sql_parser.cup | 4 +- .../org/apache/doris/analysis/AggregateInfo.java | 15 ++-- .../apache/doris/analysis/FunctionCallExpr.java | 56 +++++++++++++-- .../java/org/apache/doris/analysis/SelectList.java | 7 ++ .../apache/doris/catalog/AggregateFunction.java | 2 + .../org/apache/doris/planner/AggregationNode.java | 14 ++++ gensrc/thrift/PlanNodes.thrift | 1 + .../data/query/group_concat/test_group_concat.out | 20 ++++++ .../query/group_concat/test_group_concat.groovy | 8 +++ 16 files changed, 192 insertions(+), 141 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index c2169eb223..73779f8ffa 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -84,8 +84,6 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_window_lead_lag(instance); register_aggregate_function_HLL_union_agg(instance); register_aggregate_function_percentile_approx(instance); - - register_aggregate_function_combinator_sort(instance); }); return instance; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp index fbdb16df4f..b0566b829f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp @@ -19,84 +19,21 @@ #include "vec/aggregate_functions/aggregate_function_combinator.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" -#include "vec/aggregate_functions/helpers.h" #include "vec/common/typeid_cast.h" #include "vec/data_types/data_type_nullable.h" -#include "vec/utils/template_helpers.hpp" namespace doris::vectorized { -class AggregateFunctionCombinatorSort final : public IAggregateFunctionCombinator { -private: - int _sort_column_number; - -public: - AggregateFunctionCombinatorSort(int sort_column_number) - : _sort_column_number(sort_column_number) {} - - String get_name() const override { return "Sort"; } - - DataTypes transform_arguments(const DataTypes& arguments) const override { - if (arguments.size() < _sort_column_number + 2) { - LOG(FATAL) << "Incorrect number of arguments for aggregate function with Sort, " - << arguments.size() << " less than " << _sort_column_number + 2; - } - - DataTypes nested_types; - nested_types.assign(arguments.begin(), arguments.end() - 1 - _sort_column_number); - return nested_types; +AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function, + const DataTypes& arguments, + const SortDescription& sort_desc) { + DCHECK(nested_function != nullptr); + if (nested_function == nullptr) { + return nullptr; } - template <int sort_column_number> - struct Reducer { - static void run(AggregateFunctionPtr& function, const AggregateFunctionPtr& nested_function, - const DataTypes& arguments) { - function = std::make_shared< - AggregateFunctionSort<sort_column_number, AggregateFunctionSortData>>( - nested_function, arguments); - } - }; - - AggregateFunctionPtr transform_aggregate_function( - const AggregateFunctionPtr& nested_function, const DataTypes& arguments, - const Array& params, const bool result_is_nullable) const override { - DCHECK(nested_function != nullptr); - if (nested_function == nullptr) { - return nullptr; - } - - AggregateFunctionPtr function = nullptr; - constexpr_int_match<1, 3, Reducer>::run(_sort_column_number, function, nested_function, - arguments); - - return function; - } + return std::make_shared<AggregateFunctionSort<AggregateFunctionSortData>>(nested_function, + arguments, sort_desc); }; -const std::string SORT_FUNCTION_PREFIX = "sort_"; - -void register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory& factory) { - AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, - const Array& params, const bool result_is_nullable) { - int sort_column_number = std::stoi(name.substr(SORT_FUNCTION_PREFIX.size(), 2)); - auto nested_function_name = name.substr(SORT_FUNCTION_PREFIX.size() + 2); - - auto function_combinator = - std::make_shared<AggregateFunctionCombinatorSort>(sort_column_number); - - auto transform_arguments = function_combinator->transform_arguments(types); - - auto nested_function = - factory.get(nested_function_name, transform_arguments, params, result_is_nullable); - return function_combinator->transform_aggregate_function(nested_function, types, params, - result_is_nullable); - }; - - for (char c = '1'; c <= '3'; c++) { - factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_", - false); - factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_", - true); - } -} } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.h b/be/src/vec/aggregate_functions/aggregate_function_sort.h index 5cad555c37..cd16d29770 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sort.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sort.h @@ -18,6 +18,7 @@ #pragma once #include <string> +#include <utility> #include "vec/aggregate_functions/aggregate_function.h" #include "vec/aggregate_functions/key_holder_helpers.h" @@ -35,10 +36,17 @@ namespace doris::vectorized { -template <int sort_column_size> struct AggregateFunctionSortData { + const SortDescription sort_desc; + Block block; + + // The construct only support the template compiler, useless + AggregateFunctionSortData() {}; + AggregateFunctionSortData(SortDescription sort_desc, const Block& block) + : sort_desc(std::move(sort_desc)), block(block.clone_empty()) {} + void merge(const AggregateFunctionSortData& rhs) { - if (block.is_empty_column()) { + if (block.rows() == 0) { block = rhs.block; } else { for (size_t i = 0; i < block.columns(); i++) { @@ -78,45 +86,18 @@ struct AggregateFunctionSortData { } } - void sort() { - size_t sort_desc_idx = block.columns() - sort_column_size - 1; - StringRef desc_str = - block.get_by_position(sort_desc_idx).column->assume_mutable()->get_data_at(0); - DCHECK(sort_column_size == desc_str.size); - - SortDescription sort_description(sort_column_size); - for (size_t i = 0; i < sort_column_size; i++) { - sort_description[i].column_number = sort_desc_idx + 1 + i; - sort_description[i].direction = (desc_str.data[i] == '0') ? 1 : -1; - sort_description[i].nulls_direction = sort_description[i].direction; - } - - sort_block(block, sort_description, block.rows()); - } - - void try_init(const DataTypes& _arguments) { - if (!block.is_empty_column()) { - return; - } - - for (auto type : _arguments) { - block.insert({type, ""}); - } - } - - Block block; + void sort() { sort_block(block, sort_desc, block.rows()); } }; -template <int sort_column_size, template <int> typename Data> +template <typename Data> class AggregateFunctionSort - : public IAggregateFunctionDataHelper<Data<sort_column_size>, - AggregateFunctionSort<sort_column_size, Data>> { - using DataReal = Data<sort_column_size>; - + : public IAggregateFunctionDataHelper<Data, AggregateFunctionSort<Data>> { private: - static constexpr auto prefix_size = sizeof(DataReal); + static constexpr auto prefix_size = sizeof(Data); AggregateFunctionPtr _nested_func; DataTypes _arguments; + const SortDescription& _sort_desc; + Block _block; AggregateDataPtr get_nested_place(AggregateDataPtr __restrict place) const noexcept { return place + prefix_size; @@ -127,15 +108,20 @@ private: } public: - AggregateFunctionSort(AggregateFunctionPtr nested_func, const DataTypes& arguments) - : IAggregateFunctionDataHelper<DataReal, AggregateFunctionSort>( + AggregateFunctionSort(const AggregateFunctionPtr& nested_func, const DataTypes& arguments, + const SortDescription& sort_desc) + : IAggregateFunctionDataHelper<Data, AggregateFunctionSort>( arguments, nested_func->get_parameters()), _nested_func(nested_func), - _arguments(arguments) {} + _arguments(arguments), + _sort_desc(sort_desc) { + for (const auto& type : _arguments) { + _block.insert({type, ""}); + } + } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena* arena) const override { - this->data(place).try_init(_arguments); this->data(place).add(columns, _arguments.size(), row_num); } @@ -159,7 +145,7 @@ public: this->data(place).sort(); ColumnRawPtrs arguments_nested; - for (int i = 0; i < _arguments.size() - 1 - sort_column_size; i++) { + for (int i = 0; i < _arguments.size() - _sort_desc.size(); i++) { arguments_nested.emplace_back( this->data(place).block.get_by_position(i).column.get()); } @@ -176,12 +162,12 @@ public: size_t align_of_data() const override { return _nested_func->align_of_data(); } void create(AggregateDataPtr __restrict place) const override { - new (place) DataReal; + new (place) Data(_sort_desc, _block); _nested_func->create(get_nested_place(place)); } void destroy(AggregateDataPtr __restrict place) const noexcept override { - this->data(place).~DataReal(); + this->data(place).~Data(); _nested_func->destroy(get_nested_place(place)); } @@ -190,4 +176,7 @@ public: DataTypePtr get_return_type() const override { return _nested_func->get_return_type(); } }; +AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function, + const DataTypes& arguments, + const SortDescription& sort_desc); } // namespace doris::vectorized diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index e7008957eb..d4325bc17e 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -116,8 +116,8 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) { _aggregate_evaluators.reserve(tnode.agg_node.aggregate_functions.size()); for (int i = 0; i < tnode.agg_node.aggregate_functions.size(); ++i) { AggFnEvaluator* evaluator = nullptr; - RETURN_IF_ERROR( - AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i], &evaluator)); + RETURN_IF_ERROR(AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i], + tnode.agg_node.agg_sort_infos[i], &evaluator)); _aggregate_evaluators.push_back(evaluator); } diff --git a/be/src/vec/exec/vanalytic_eval_node.cpp b/be/src/vec/exec/vanalytic_eval_node.cpp index 3906062cc4..dde7cb8453 100644 --- a/be/src/vec/exec/vanalytic_eval_node.cpp +++ b/be/src/vec/exec/vanalytic_eval_node.cpp @@ -127,7 +127,7 @@ Status VAnalyticEvalNode::init(const TPlanNode& tnode, RuntimeState* state) { AggFnEvaluator* evaluator = nullptr; RETURN_IF_ERROR( - AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], &evaluator)); + AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], {}, &evaluator)); _agg_functions.emplace_back(evaluator); for (size_t j = 0; j < _agg_expr_ctxs[i].size(); ++j) { _agg_intput_columns[i][j] = _agg_expr_ctxs[i][j]->root()->data_type()->create_column(); diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 427b90ec2f..527cb40e18 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -23,11 +23,13 @@ #include "vec/aggregate_functions/aggregate_function_java_udaf.h" #include "vec/aggregate_functions/aggregate_function_rpc.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_sort.h" #include "vec/columns/column_nullable.h" #include "vec/core/materialize_block.h" #include "vec/data_types/data_type_factory.hpp" #include "vec/data_types/data_type_nullable.h" #include "vec/exprs/vexpr.h" + namespace doris::vectorized { AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) @@ -46,12 +48,14 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) _data_type = DataTypeFactory::instance().create_data_type(_return_type, nullable); auto& param_types = desc.agg_expr.param_types; - for (auto raw_type : param_types) { - _argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type)); + for (int i = 0; i < param_types.size(); i++) { + _argument_types_with_sort.push_back( + DataTypeFactory::instance().create_data_type(param_types[i])); } } -Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result) { +Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, + AggFnEvaluator** result) { *result = pool->add(new AggFnEvaluator(desc.nodes[0])); auto& agg_fn_evaluator = *result; int node_idx = 0; @@ -63,6 +67,22 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluato VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, &node_idx, &expr, &ctx)); agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx); } + + auto sort_size = sort_info.ordering_exprs.size(); + auto real_arguments_size = agg_fn_evaluator->_argument_types_with_sort.size() - sort_size; + // Child arguments conatins [real arguments, order by arguments], we pass the arguments + // to the order by functions + for (int i = 0; i < sort_size; ++i) { + agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i, + sort_info.is_asc_order[i] == true, + sort_info.nulls_first[i] == true); + } + + // Pass the real arguments to get functions + for (int i = 0; i < real_arguments_size; ++i) { + agg_fn_evaluator->_real_argument_types.emplace_back( + agg_fn_evaluator->_argument_types_with_sort[i]); + } return Status::OK(); } @@ -87,20 +107,24 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { #ifdef LIBJVM - _function = AggregateJavaUdaf::create(_fn, _argument_types, {}, _data_type); + _function = AggregateJavaUdaf::create(_fn, _real_argument_types, {}, _data_type); #else return Status::InternalError("Java UDAF is disabled since no libjvm is found!"); #endif } else if (_fn.binary_type == TFunctionBinaryType::RPC) { - _function = AggregateRpcUdaf::create(_fn, _argument_types, {}, _data_type); + _function = AggregateRpcUdaf::create(_fn, _real_argument_types, {}, _data_type); } else { _function = AggregateFunctionSimpleFactory::instance().get( - _fn.name.function_name, _argument_types, {}, _data_type->is_nullable()); + _fn.name.function_name, _real_argument_types, {}, _data_type->is_nullable()); } if (_function == nullptr) { return Status::InternalError("Agg Function {} is not implemented", _fn.name.function_name); } + if (!_sort_description.empty()) { + _function = transform_to_sort_agg_function(_function, _argument_types_with_sort, + _sort_description); + } _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name); return Status::OK(); } diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index a541257487..52098f0d8f 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -20,6 +20,7 @@ #include "util/runtime_profile.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/core/block.h" +#include "vec/core/sort_description.h" #include "vec/data_types/data_type.h" #include "vec/exprs/vexpr_context.h" @@ -29,7 +30,8 @@ class SlotDescriptor; namespace vectorized { class AggFnEvaluator { public: - static Status create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result); + static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, + AggFnEvaluator** result); Status prepare(RuntimeState* state, const RowDescriptor& desc, MemPool* pool, const SlotDescriptor* intermediate_slot_desc, @@ -80,7 +82,9 @@ private: void _calc_argment_columns(Block* block); - DataTypes _argument_types; + DataTypes _argument_types_with_sort; + DataTypes _real_argument_types; + const TypeDescriptor _return_type; const SlotDescriptor* _intermediate_slot_desc; @@ -93,6 +97,8 @@ private: // input context std::vector<VExprContext*> _input_exprs_ctxs; + SortDescription _sort_description; + DataTypePtr _data_type; AggregateFunctionPtr _function; diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup index 556ab0785f..1042783da3 100644 --- a/fe/fe-core/src/main/cup/sql_parser.cup +++ b/fe/fe-core/src/main/cup/sql_parser.cup @@ -4950,8 +4950,8 @@ non_pred_expr ::= {: RESULT = new FunctionCallExpr(fn_name, exprs); :} //| function_name:fn_name LPAREN RPAREN //{: RESULT = new FunctionCallExpr(fn_name, new ArrayList<Expr>()); :} - //| function_name:fn_name LPAREN function_params:params RPAREN - //{: RESULT = new FunctionCallExpr(fn_name, params); :} + | function_name:fn_name LPAREN function_params:params order_by_clause:o RPAREN + {: RESULT = new FunctionCallExpr(fn_name, params, o); :} | analytic_expr:e {: RESULT = e; :} /* Since "IF" is a keyword, need to special case this function */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java index a0152a8e8a..79e01e0cae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java @@ -496,6 +496,8 @@ public final class AggregateInfo extends AggregateInfoBase { FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall( inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs()); aggExpr.analyzeNoThrow(analyzer); + // do not need analyze in merge stage, just do mark for BE get right function + aggExpr.setOrderByElements(inputExpr.getOrderByElements()); aggExprs.add(aggExpr); } @@ -621,7 +623,6 @@ public final class AggregateInfo extends AggregateInfoBase { } Preconditions.checkState( secondPhaseAggExprs.size() == aggregateExprs.size() + distinctAggExprs.size()); - for (FunctionCallExpr aggExpr : secondPhaseAggExprs) { aggExpr.analyzeNoThrow(analyzer); Preconditions.checkState(aggExpr.isAggregateFunction()); @@ -649,18 +650,16 @@ public final class AggregateInfo extends AggregateInfoBase { int numDistinctParams = 0; if (!isMultiDistinct) { numDistinctParams = distinctAggExprs.get(0).getChildren().size(); - // If we are counting distinct params of group_concat, we cannot include the custom - // separator since it is not a distinct param. - if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat") - && numDistinctParams == 2) { - --numDistinctParams; - } } else { for (int i = 0; i < distinctAggExprs.size(); i++) { numDistinctParams += distinctAggExprs.get(i).getChildren().size(); } } - + // If we are counting distinct params of group_concat, we cannot include the custom + // separator since it is not a distinct param. + if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat")) { + numDistinctParams = 1; + } int numOrigGroupingExprs = inputAggInfo.getGroupingExprs().size() - numDistinctParams; Preconditions.checkState( slotDescs.size() == numOrigGroupingExprs + distinctAggExprs.size() diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index ee65f35469..6038d5e3f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -90,9 +90,10 @@ public class FunctionCallExpr extends Expr { // private BuiltinAggregateFunction.Operator aggOp; private FunctionParams fnParams; - // represent original parament from aggregate function private FunctionParams aggFnParams; + private List<OrderByElement> orderByElements = Lists.newArrayList(); + // check analytic function private boolean isAnalyticFnCall = false; // check table function @@ -155,6 +156,27 @@ public class FunctionCallExpr extends Expr { this(fnName, params, false); } + public FunctionCallExpr( + FunctionName fnName, FunctionParams params, List<OrderByElement> orderByElements) throws AnalysisException { + this(fnName, params, false); + this.orderByElements = orderByElements; + if (!orderByElements.isEmpty()) { + if (!VectorizedUtil.isVectorized()) { + throw new AnalysisException( + "ORDER BY for arguments only support in vec exec engine"); + } else if (!AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains( + fnName.getFunction().toLowerCase())) { + throw new AnalysisException( + "ORDER BY not support for the function:" + fnName.getFunction().toLowerCase()); + } else if (params.isDistinct()) { + throw new AnalysisException( + "ORDER BY not support for the distinct, support in the furture:" + + fnName.getFunction().toLowerCase()); + } + } + setChildren(); + } + private FunctionCallExpr( FunctionName fnName, FunctionParams params, boolean isMergeAggFn) { super(); @@ -187,6 +209,7 @@ public class FunctionCallExpr extends Expr { protected FunctionCallExpr(FunctionCallExpr other) { super(other); fnName = other.fnName; + orderByElements = other.orderByElements; isAnalyticFnCall = other.isAnalyticFnCall; // aggOp = other.aggOp; // fnParams = other.fnParams; @@ -289,6 +312,8 @@ public class FunctionCallExpr extends Expr { || fnName.getFunction().equalsIgnoreCase("sm4_decrypt") || fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) { result.add("\'***\'"); + } else if (orderByElements.size() > 0 && i == len - orderByElements.size()) { + result.add("ORDER BY " + children.get(i).toSql()); } else { result.add(children.get(i).toSql()); } @@ -503,7 +528,7 @@ public class FunctionCallExpr extends Expr { } if (fnName.getFunction().equalsIgnoreCase("group_concat")) { - if (children.size() > 2 || children.isEmpty()) { + if (children.size() - orderByElements.size() > 2 || children.isEmpty()) { throw new AnalysisException( "group_concat requires one or two parameters: " + this.toSql()); } @@ -514,13 +539,14 @@ public class FunctionCallExpr extends Expr { "group_concat requires first parameter to be of type STRING: " + this.toSql()); } - if (children.size() == 2) { + if (children.size() - orderByElements.size() == 2) { Expr arg1 = getChild(1); if (!arg1.type.isStringType() && !arg1.type.isNull()) { throw new AnalysisException( "group_concat requires second parameter to be of type STRING: " + this.toSql()); } } + return; } @@ -926,6 +952,15 @@ public class FunctionCallExpr extends Expr { childTypes[2] = assignmentCompatibleType; fn = getBuiltinFunction(fnName.getFunction(), childTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); + } else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains( + fnName.getFunction().toLowerCase())) { + // order by elements add as child like windows function. so if we get the + // param of arg, we need remove the order by elements + Type[] childTypes = collectChildReturnTypes(); + Type[] newChildTypes = new Type[children.size() - orderByElements.size()]; + System.arraycopy(childTypes, 0, newChildTypes, 0, newChildTypes.length); + fn = getBuiltinFunction(fnName.getFunction(), newChildTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else { // now first find table function in table function sets if (isTableFnCall) { @@ -1024,7 +1059,7 @@ public class FunctionCallExpr extends Expr { Type[] args = fn.getArgs(); if (args.length > 0) { // Implicitly cast all the children to match the function if necessary - for (int i = 0; i < argTypes.length; ++i) { + for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) { // For varargs, we must compare with the last type in callArgs.argTypes. int ix = Math.min(args.length - 1, i); if (!argTypes[i].matchesType(args[ix]) && Config.use_date_v2_by_default @@ -1327,7 +1362,6 @@ public class FunctionCallExpr extends Expr { return result.toString(); } - @Override public void finalizeImplForNereids() throws AnalysisException { // TODO: support other functions // TODO: Supports type conversion to match the type of the function's parameters @@ -1356,4 +1390,16 @@ public class FunctionCallExpr extends Expr { public void setMergeForNereids(boolean isMergeAggFn) { this.isMergeAggFn = isMergeAggFn; } + + public List<OrderByElement> getOrderByElements() { + return orderByElements; + } + + public void setOrderByElements(List<OrderByElement> orderByElements) { + this.orderByElements = orderByElements; + } + + private void setChildren() { + orderByElements.forEach(o -> addChild(o.getExpr())); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java index 77a2084f79..ee950a032e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java @@ -37,6 +37,7 @@ public class SelectList { private boolean isDistinct; private Map<String, String> optHints; + private List<OrderByElement> orderByElements; // /////////////////////////////////////// // BEGIN: Members that need to be reset() @@ -90,6 +91,12 @@ public class SelectList { } } + public void setOrderByElements(List<OrderByElement> orderByElements) { + if (orderByElements != null) { + this.orderByElements = orderByElements; + } + } + public void reset() { for (SelectListItem item : items) { if (!item.isStar()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 72957c0eff..a58097e9eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -55,6 +55,8 @@ public class AggregateFunction extends Function { public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx"); + public static ImmutableSet<String> SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("group_concat"); + // Set if different from retType_, null otherwise. private Type intermediateType; diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java index c8561b54dc..4cdd1aceee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java @@ -36,6 +36,7 @@ import org.apache.doris.thrift.TExplainLevel; import org.apache.doris.thrift.TExpr; import org.apache.doris.thrift.TPlanNode; import org.apache.doris.thrift.TPlanNodeType; +import org.apache.doris.thrift.TSortInfo; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; @@ -249,14 +250,27 @@ public class AggregationNode extends PlanNode { protected void toThrift(TPlanNode msg) { msg.node_type = TPlanNodeType.AGGREGATION_NODE; List<TExpr> aggregateFunctions = Lists.newArrayList(); + List<TSortInfo> aggSortInfos = Lists.newArrayList(); // only serialize agg exprs that are being materialized for (FunctionCallExpr e : aggInfo.getMaterializedAggregateExprs()) { aggregateFunctions.add(e.treeToThrift()); + List<TExpr> orderingExpr = Lists.newArrayList(); + List<Boolean> isAscs = Lists.newArrayList(); + List<Boolean> nullFirsts = Lists.newArrayList(); + + e.getOrderByElements().forEach(o -> { + orderingExpr.add(o.getExpr().treeToThrift()); + isAscs.add(o.getIsAsc()); + nullFirsts.add(o.getNullsFirstParam()); + }); + aggSortInfos.add(new TSortInfo(orderingExpr, isAscs, nullFirsts)); } + msg.agg_node = new TAggregationNode( aggregateFunctions, aggInfo.getIntermediateTupleId().asInt(), aggInfo.getOutputTupleId().asInt(), needsFinalize); + msg.agg_node.setAggSortInfos(aggSortInfos); msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg); List<Expr> groupingExprs = aggInfo.getGroupingExprs(); if (groupingExprs != null) { diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 3024bac292..b3823f1cf6 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -543,6 +543,7 @@ struct TAggregationNode { // rows have been aggregated, and this node is not an intermediate node. 5: required bool need_finalize 6: optional bool use_streaming_preaggregation + 7: optional list<TSortInfo> agg_sort_infos } struct TRepeatNode { diff --git a/regression-test/data/query/group_concat/test_group_concat.out b/regression-test/data/query/group_concat/test_group_concat.out index 94f73cc536..e61bd4fd4f 100644 --- a/regression-test/data/query/group_concat/test_group_concat.out +++ b/regression-test/data/query/group_concat/test_group_concat.out @@ -5,3 +5,23 @@ false, false -- !select -- false +-- !select -- +\N \N +103 255 +1001 1986, 1989 +1002 1989, 32767 +3021 1991, 1992, 32767 +5014 1985, 1991 +25699 1989 +2147483647 255, 1991, 32767, 32767 + +-- !select -- +\N \N +103 255 +1001 1986:1989 +1002 1989:32767 +3021 1991:1992:32767 +5014 1985:1991 +25699 1989 +2147483647 255:1991:32767:32767 + diff --git a/regression-test/suites/query/group_concat/test_group_concat.groovy b/regression-test/suites/query/group_concat/test_group_concat.groovy index 6bb57dea44..12d420cbe0 100644 --- a/regression-test/suites/query/group_concat/test_group_concat.groovy +++ b/regression-test/suites/query/group_concat/test_group_concat.groovy @@ -23,4 +23,12 @@ suite("test_group_concat", "query") { qt_select """ SELECT group_concat(DISTINCT k6) FROM test_query_db.test where k6='false' """ + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ + + qt_select """ + SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3) + """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org