This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new f58a071605 [Bug][Function] pass intermediate argument list to be (#10650) f58a071605 is described below commit f58a071605a1aaa7a68a99cdd2f098a5868787e4 Author: Pxl <952130...@qq.com> AuthorDate: Fri Jul 8 20:50:05 2022 +0800 [Bug][Function] pass intermediate argument list to be (#10650) --- .../aggregate_function_orthogonal_bitmap.cpp | 2 -- .../aggregate_function_topn.cpp | 5 +---- .../aggregate_functions/aggregate_function_topn.h | 8 -------- be/src/vec/data_types/data_type_factory.hpp | 4 ++++ be/src/vec/exprs/vectorized_agg_fn.cpp | 22 ++++++++++------------ be/src/vec/exprs/vectorized_agg_fn.h | 2 +- .../org/apache/doris/analysis/AggregateInfo.java | 15 ++++----------- .../apache/doris/analysis/FunctionCallExpr.java | 20 +++++++++++++++----- .../org/apache/doris/analysis/FunctionParams.java | 15 +++++++++++++++ gensrc/thrift/Exprs.thrift | 1 + gensrc/thrift/Types.thrift | 1 + 11 files changed, 52 insertions(+), 43 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp index 470a6c8388..9794a72090 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp @@ -34,8 +34,6 @@ AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& nam LOG(WARNING) << "Incorrect number of arguments for aggregate function " << name; return nullptr; } else if (argument_types.size() == 1) { - // only used at AGGREGATE (merge finalize) for variadic function - // and for orthogonal_bitmap_union_count function return std::make_shared<AggFunctionOrthBitmapFunc<Impl<StringValue>>>(argument_types); } else { const IDataType& argument_type = *argument_types[1]; diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp index 04df93ce67..19f52fbff8 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp @@ -23,10 +23,7 @@ AggregateFunctionPtr create_aggregate_function_topn(const std::string& name, const DataTypes& argument_types, const Array& parameters, const bool result_is_nullable) { - if (argument_types.size() == 1) { - return AggregateFunctionPtr( - new AggregateFunctionTopN<AggregateFunctionTopNImplEmpty>(argument_types)); - } else if (argument_types.size() == 2) { + if (argument_types.size() == 2) { return AggregateFunctionPtr( new AggregateFunctionTopN<AggregateFunctionTopNImplInt<StringDataImplTopN>>( argument_types)); diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h index 97ac5c7cba..ae9fdf322d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.h +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h @@ -168,14 +168,6 @@ struct StringDataImplTopN { } }; -struct AggregateFunctionTopNImplEmpty { - // only used at AGGREGATE (merge finalize) - static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns, - size_t row_num) { - LOG(FATAL) << "AggregateFunctionTopNImplEmpty do not support add()"; - } -}; - template <typename DataHelper> struct AggregateFunctionTopNImplInt { static void add(AggregateFunctionTopNData& __restrict place, const IColumn** columns, diff --git a/be/src/vec/data_types/data_type_factory.hpp b/be/src/vec/data_types/data_type_factory.hpp index 08dc6a9f31..59740debd3 100644 --- a/be/src/vec/data_types/data_type_factory.hpp +++ b/be/src/vec/data_types/data_type_factory.hpp @@ -102,6 +102,10 @@ public: DataTypePtr create_data_type(const arrow::DataType* type, bool is_nullable); + DataTypePtr create_data_type(const TTypeDesc& raw_type) { + return create_data_type(TypeDescriptor::from_thrift(raw_type), raw_type.is_nullable); + } + private: DataTypePtr _create_primitive_data_type(const FieldType& type) const; diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index ad7066a9a4..b7e14817f1 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -33,7 +33,6 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)), - _intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)), _intermediate_slot_desc(nullptr), _output_slot_desc(nullptr), _exec_timer(nullptr), @@ -44,6 +43,11 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) nullable = desc.is_nullable; } _data_type = DataTypeFactory::instance().create_data_type(_return_type, nullable); + + auto& param_types = desc.agg_expr.param_types; + for (auto raw_type : param_types) { + _argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type)); + } } Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result) { @@ -55,7 +59,7 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluato VExpr* expr = nullptr; VExprContext* ctx = nullptr; RETURN_IF_ERROR( - VExpr::create_tree_from_thrift(pool, desc.nodes, NULL, &node_idx, &expr, &ctx)); + VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, &node_idx, &expr, &ctx)); agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx); } return Status::OK(); @@ -65,25 +69,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M const SlotDescriptor* intermediate_slot_desc, const SlotDescriptor* output_slot_desc, const std::shared_ptr<MemTracker>& mem_tracker) { - DCHECK(pool != NULL); - DCHECK(intermediate_slot_desc != NULL); - DCHECK(_intermediate_slot_desc == NULL); + DCHECK(pool != nullptr); + DCHECK(intermediate_slot_desc != nullptr); + DCHECK(_intermediate_slot_desc == nullptr); _output_slot_desc = output_slot_desc; _intermediate_slot_desc = intermediate_slot_desc; Status status = VExpr::prepare(_input_exprs_ctxs, state, desc, mem_tracker); RETURN_IF_ERROR(status); - DataTypes argument_types; - argument_types.reserve(_input_exprs_ctxs.size()); - std::vector<std::string_view> child_expr_name; - doris::vectorized::Array params; // prepare for argument for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { - auto data_type = _input_exprs_ctxs[i]->root()->data_type(); - argument_types.emplace_back(data_type); child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name()); } @@ -95,7 +93,7 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M #endif } else { _function = AggregateFunctionSimpleFactory::instance().get( - _fn.name.function_name, argument_types, params, _data_type->is_nullable()); + _fn.name.function_name, _argument_types, {}, _data_type->is_nullable()); } if (_function == nullptr) { return Status::InternalError("Agg Function {} is not implemented", _fn.name.function_name); diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 0f1f145ced..9a1dbafdcc 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -78,8 +78,8 @@ private: void _calc_argment_columns(Block* block); + DataTypes _argument_types; const TypeDescriptor _return_type; - const TypeDescriptor _intermediate_type; const SlotDescriptor* _intermediate_slot_desc; const SlotDescriptor* _output_slot_desc; diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java index a855b53de2..a0152a8e8a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java @@ -491,17 +491,10 @@ public final class AggregateInfo extends AggregateInfoBase { for (int i = 0; i < getAggregateExprs().size(); ++i) { FunctionCallExpr inputExpr = getAggregateExprs().get(i); Preconditions.checkState(inputExpr.isAggregateFunction()); - List<Expr> paramExprs = new ArrayList<>(); - // TODO(zhannngchen), change intermediate argument to a list, and remove this - // ad-hoc logic - if (inputExpr.fn.functionName().equals("max_by") - || inputExpr.fn.functionName().equals("min_by")) { - paramExprs.addAll(inputExpr.getFnParams().exprs()); - } else { - paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()))); - } + Expr aggExprParam = + new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())); FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall( - inputExpr, paramExprs); + inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs()); aggExpr.analyzeNoThrow(analyzer); aggExprs.add(aggExpr); } @@ -623,7 +616,7 @@ public final class AggregateInfo extends AggregateInfoBase { Expr aggExprParam = new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())); FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall( - inputExpr, Lists.newArrayList(aggExprParam)); + inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs()); secondPhaseAggExprs.add(aggExpr); } Preconditions.checkState( diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 637bf6c158..543c74776f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -37,7 +37,6 @@ import org.apache.doris.common.ErrorReport; import org.apache.doris.common.util.VectorizedUtil; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.thrift.TAggregateExpr; import org.apache.doris.thrift.TExprNode; import org.apache.doris.thrift.TExprNodeType; @@ -69,6 +68,9 @@ public class FunctionCallExpr extends Expr { // private BuiltinAggregateFunction.Operator aggOp; private FunctionParams fnParams; + // represent original parament from aggregate function + private FunctionParams aggFnParams; + // check analytic function private boolean isAnalyticFnCall = false; // check table function @@ -92,6 +94,10 @@ public class FunctionCallExpr extends Expr { private boolean isRewrote = false; + public void setAggFnParams(FunctionParams aggFnParams) { + this.aggFnParams = aggFnParams; + } + public void setIsAnalyticFnCall(boolean v) { isAnalyticFnCall = v; } @@ -153,6 +159,7 @@ public class FunctionCallExpr extends Expr { // aggOp = e.aggOp; isAnalyticFnCall = e.isAnalyticFnCall; fnParams = params; + aggFnParams = e.aggFnParams; // Just inherit the function object from 'e'. fn = e.fn; this.isMergeAggFn = e.isMergeAggFn; @@ -175,6 +182,7 @@ public class FunctionCallExpr extends Expr { } else { fnParams = new FunctionParams(other.fnParams.isDistinct(), children); } + aggFnParams = other.aggFnParams; this.isMergeAggFn = other.isMergeAggFn; fn = other.fn; this.isTableFnCall = other.isTableFnCall; @@ -428,9 +436,10 @@ public class FunctionCallExpr extends Expr { // except in test cases that do it explicitly. if (isAggregate() || isAnalyticFnCall) { msg.node_type = TExprNodeType.AGG_EXPR; - if (!isAnalyticFnCall) { - msg.setAggExpr(new TAggregateExpr(isMergeAggFn)); + if (aggFnParams == null) { + aggFnParams = fnParams; } + msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn)); } else { msg.node_type = TExprNodeType.FUNCTION_CALL; } @@ -1143,14 +1152,15 @@ public class FunctionCallExpr extends Expr { } public static FunctionCallExpr createMergeAggCall( - FunctionCallExpr agg, List<Expr> params) { + FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr> realParams) { Preconditions.checkState(agg.isAnalyzed); Preconditions.checkState(agg.isAggregateFunction()); FunctionCallExpr result = new FunctionCallExpr( - agg.fnName, new FunctionParams(false, params), true); + agg.fnName, new FunctionParams(false, intermediateParams), true); // Inherit the function object from 'agg'. result.fn = agg.fn; result.type = agg.type; + result.setAggFnParams(new FunctionParams(false, realParams)); return result; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java index 32cfba0351..3b77ec52b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java @@ -21,12 +21,15 @@ package org.apache.doris.analysis; import org.apache.doris.common.io.Writable; +import org.apache.doris.thrift.TAggregateExpr; +import org.apache.doris.thrift.TTypeDesc; import com.google.common.collect.Lists; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -62,6 +65,18 @@ public class FunctionParams implements Writable { return new FunctionParams(); } + public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) { + List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>(); + if (exprs != null) { + for (Expr expr : exprs) { + TTypeDesc desc = expr.getType().toThrift(); + desc.setIsNullable(expr.isNullable()); + paramTypes.add(desc); + } + } + return new TAggregateExpr(isMergeAggFn, paramTypes); + } + public boolean isStar() { return isStar; } diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift index 450148f381..50c9119410 100644 --- a/gensrc/thrift/Exprs.thrift +++ b/gensrc/thrift/Exprs.thrift @@ -73,6 +73,7 @@ enum TExprNodeType { struct TAggregateExpr { // Indicates whether this expr is the merge() of an aggregation. 1: required bool is_merge_agg + 2: required list<Types.TTypeDesc> param_types } struct TBoolLiteral { 1: required bool value diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 381f2879e9..a7212e0476 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -145,6 +145,7 @@ struct TTypeNode { // to TTypeDesc. In future, we merge these two to one struct TTypeDesc { 1: list<TTypeNode> types + 2: optional bool is_nullable } enum TAggregationType { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org