This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new f58a071605 [Bug][Function] pass intermediate argument list to be 
(#10650)
f58a071605 is described below

commit f58a071605a1aaa7a68a99cdd2f098a5868787e4
Author: Pxl <952130...@qq.com>
AuthorDate: Fri Jul 8 20:50:05 2022 +0800

    [Bug][Function] pass intermediate argument list to be (#10650)
---
 .../aggregate_function_orthogonal_bitmap.cpp       |  2 --
 .../aggregate_function_topn.cpp                    |  5 +----
 .../aggregate_functions/aggregate_function_topn.h  |  8 --------
 be/src/vec/data_types/data_type_factory.hpp        |  4 ++++
 be/src/vec/exprs/vectorized_agg_fn.cpp             | 22 ++++++++++------------
 be/src/vec/exprs/vectorized_agg_fn.h               |  2 +-
 .../org/apache/doris/analysis/AggregateInfo.java   | 15 ++++-----------
 .../apache/doris/analysis/FunctionCallExpr.java    | 20 +++++++++++++++-----
 .../org/apache/doris/analysis/FunctionParams.java  | 15 +++++++++++++++
 gensrc/thrift/Exprs.thrift                         |  1 +
 gensrc/thrift/Types.thrift                         |  1 +
 11 files changed, 52 insertions(+), 43 deletions(-)

diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
index 470a6c8388..9794a72090 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
@@ -34,8 +34,6 @@ AggregateFunctionPtr 
create_aggregate_function_orthogonal(const std::string& nam
         LOG(WARNING) << "Incorrect number of arguments for aggregate function 
" << name;
         return nullptr;
     } else if (argument_types.size() == 1) {
-        // only used at AGGREGATE (merge finalize) for variadic function
-        // and for orthogonal_bitmap_union_count function
         return 
std::make_shared<AggFunctionOrthBitmapFunc<Impl<StringValue>>>(argument_types);
     } else {
         const IDataType& argument_type = *argument_types[1];
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
index 04df93ce67..19f52fbff8 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -23,10 +23,7 @@ AggregateFunctionPtr create_aggregate_function_topn(const 
std::string& name,
                                                     const DataTypes& 
argument_types,
                                                     const Array& parameters,
                                                     const bool 
result_is_nullable) {
-    if (argument_types.size() == 1) {
-        return AggregateFunctionPtr(
-                new 
AggregateFunctionTopN<AggregateFunctionTopNImplEmpty>(argument_types));
-    } else if (argument_types.size() == 2) {
+    if (argument_types.size() == 2) {
         return AggregateFunctionPtr(
                 new 
AggregateFunctionTopN<AggregateFunctionTopNImplInt<StringDataImplTopN>>(
                         argument_types));
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h 
b/be/src/vec/aggregate_functions/aggregate_function_topn.h
index 97ac5c7cba..ae9fdf322d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -168,14 +168,6 @@ struct StringDataImplTopN {
     }
 };
 
-struct AggregateFunctionTopNImplEmpty {
-    // only used at AGGREGATE (merge finalize)
-    static void add(AggregateFunctionTopNData& __restrict place, const 
IColumn** columns,
-                    size_t row_num) {
-        LOG(FATAL) << "AggregateFunctionTopNImplEmpty do not support add()";
-    }
-};
-
 template <typename DataHelper>
 struct AggregateFunctionTopNImplInt {
     static void add(AggregateFunctionTopNData& __restrict place, const 
IColumn** columns,
diff --git a/be/src/vec/data_types/data_type_factory.hpp 
b/be/src/vec/data_types/data_type_factory.hpp
index 08dc6a9f31..59740debd3 100644
--- a/be/src/vec/data_types/data_type_factory.hpp
+++ b/be/src/vec/data_types/data_type_factory.hpp
@@ -102,6 +102,10 @@ public:
 
     DataTypePtr create_data_type(const arrow::DataType* type, bool 
is_nullable);
 
+    DataTypePtr create_data_type(const TTypeDesc& raw_type) {
+        return create_data_type(TypeDescriptor::from_thrift(raw_type), 
raw_type.is_nullable);
+    }
+
 private:
     DataTypePtr _create_primitive_data_type(const FieldType& type) const;
 
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp 
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index ad7066a9a4..b7e14817f1 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -33,7 +33,6 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
         : _fn(desc.fn),
           _is_merge(desc.agg_expr.is_merge_agg),
           _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)),
-          
_intermediate_type(TypeDescriptor::from_thrift(desc.fn.aggregate_fn.intermediate_type)),
           _intermediate_slot_desc(nullptr),
           _output_slot_desc(nullptr),
           _exec_timer(nullptr),
@@ -44,6 +43,11 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
         nullable = desc.is_nullable;
     }
     _data_type = DataTypeFactory::instance().create_data_type(_return_type, 
nullable);
+
+    auto& param_types = desc.agg_expr.param_types;
+    for (auto raw_type : param_types) {
+        
_argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
+    }
 }
 
 Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, 
AggFnEvaluator** result) {
@@ -55,7 +59,7 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& 
desc, AggFnEvaluato
         VExpr* expr = nullptr;
         VExprContext* ctx = nullptr;
         RETURN_IF_ERROR(
-                VExpr::create_tree_from_thrift(pool, desc.nodes, NULL, 
&node_idx, &expr, &ctx));
+                VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, 
&node_idx, &expr, &ctx));
         agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
     }
     return Status::OK();
@@ -65,25 +69,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const 
RowDescriptor& desc, M
                                const SlotDescriptor* intermediate_slot_desc,
                                const SlotDescriptor* output_slot_desc,
                                const std::shared_ptr<MemTracker>& mem_tracker) 
{
-    DCHECK(pool != NULL);
-    DCHECK(intermediate_slot_desc != NULL);
-    DCHECK(_intermediate_slot_desc == NULL);
+    DCHECK(pool != nullptr);
+    DCHECK(intermediate_slot_desc != nullptr);
+    DCHECK(_intermediate_slot_desc == nullptr);
     _output_slot_desc = output_slot_desc;
     _intermediate_slot_desc = intermediate_slot_desc;
 
     Status status = VExpr::prepare(_input_exprs_ctxs, state, desc, 
mem_tracker);
     RETURN_IF_ERROR(status);
 
-    DataTypes argument_types;
-    argument_types.reserve(_input_exprs_ctxs.size());
-
     std::vector<std::string_view> child_expr_name;
 
-    doris::vectorized::Array params;
     // prepare for argument
     for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
-        auto data_type = _input_exprs_ctxs[i]->root()->data_type();
-        argument_types.emplace_back(data_type);
         
child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name());
     }
 
@@ -95,7 +93,7 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const 
RowDescriptor& desc, M
 #endif
     } else {
         _function = AggregateFunctionSimpleFactory::instance().get(
-                _fn.name.function_name, argument_types, params, 
_data_type->is_nullable());
+                _fn.name.function_name, _argument_types, {}, 
_data_type->is_nullable());
     }
     if (_function == nullptr) {
         return Status::InternalError("Agg Function {} is not implemented", 
_fn.name.function_name);
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h 
b/be/src/vec/exprs/vectorized_agg_fn.h
index 0f1f145ced..9a1dbafdcc 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -78,8 +78,8 @@ private:
 
     void _calc_argment_columns(Block* block);
 
+    DataTypes _argument_types;
     const TypeDescriptor _return_type;
-    const TypeDescriptor _intermediate_type;
 
     const SlotDescriptor* _intermediate_slot_desc;
     const SlotDescriptor* _output_slot_desc;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index a855b53de2..a0152a8e8a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -491,17 +491,10 @@ public final class AggregateInfo extends 
AggregateInfoBase {
         for (int i = 0; i < getAggregateExprs().size(); ++i) {
             FunctionCallExpr inputExpr = getAggregateExprs().get(i);
             Preconditions.checkState(inputExpr.isAggregateFunction());
-            List<Expr> paramExprs = new ArrayList<>();
-            // TODO(zhannngchen), change intermediate argument to a list, and 
remove this
-            // ad-hoc logic
-            if (inputExpr.fn.functionName().equals("max_by")
-                    || inputExpr.fn.functionName().equals("min_by")) {
-                paramExprs.addAll(inputExpr.getFnParams().exprs());
-            } else {
-                paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size())));
-            }
+            Expr aggExprParam =
+                    new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size()));
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
-                    inputExpr, paramExprs);
+                    inputExpr, Lists.newArrayList(aggExprParam), 
inputExpr.getFnParams().exprs());
             aggExpr.analyzeNoThrow(analyzer);
             aggExprs.add(aggExpr);
         }
@@ -623,7 +616,7 @@ public final class AggregateInfo extends AggregateInfoBase {
             Expr aggExprParam =
                     new SlotRef(inputDesc.getSlots().get(i + 
getGroupingExprs().size()));
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
-                    inputExpr, Lists.newArrayList(aggExprParam));
+                    inputExpr, Lists.newArrayList(aggExprParam), 
inputExpr.getFnParams().exprs());
             secondPhaseAggExprs.add(aggExpr);
         }
         Preconditions.checkState(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 637bf6c158..543c74776f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -37,7 +37,6 @@ import org.apache.doris.common.ErrorReport;
 import org.apache.doris.common.util.VectorizedUtil;
 import org.apache.doris.mysql.privilege.PrivPredicate;
 import org.apache.doris.qe.ConnectContext;
-import org.apache.doris.thrift.TAggregateExpr;
 import org.apache.doris.thrift.TExprNode;
 import org.apache.doris.thrift.TExprNodeType;
 
@@ -69,6 +68,9 @@ public class FunctionCallExpr extends Expr {
     // private BuiltinAggregateFunction.Operator aggOp;
     private FunctionParams fnParams;
 
+    // represent original parament from aggregate function
+    private FunctionParams aggFnParams;
+
     // check analytic function
     private boolean isAnalyticFnCall = false;
     // check table function
@@ -92,6 +94,10 @@ public class FunctionCallExpr extends Expr {
 
     private boolean isRewrote = false;
 
+    public void setAggFnParams(FunctionParams aggFnParams) {
+        this.aggFnParams = aggFnParams;
+    }
+
     public void setIsAnalyticFnCall(boolean v) {
         isAnalyticFnCall = v;
     }
@@ -153,6 +159,7 @@ public class FunctionCallExpr extends Expr {
         // aggOp = e.aggOp;
         isAnalyticFnCall = e.isAnalyticFnCall;
         fnParams = params;
+        aggFnParams = e.aggFnParams;
         // Just inherit the function object from 'e'.
         fn = e.fn;
         this.isMergeAggFn = e.isMergeAggFn;
@@ -175,6 +182,7 @@ public class FunctionCallExpr extends Expr {
         } else {
             fnParams = new FunctionParams(other.fnParams.isDistinct(), 
children);
         }
+        aggFnParams = other.aggFnParams;
         this.isMergeAggFn = other.isMergeAggFn;
         fn = other.fn;
         this.isTableFnCall = other.isTableFnCall;
@@ -428,9 +436,10 @@ public class FunctionCallExpr extends Expr {
         // except in test cases that do it explicitly.
         if (isAggregate() || isAnalyticFnCall) {
             msg.node_type = TExprNodeType.AGG_EXPR;
-            if (!isAnalyticFnCall) {
-                msg.setAggExpr(new TAggregateExpr(isMergeAggFn));
+            if (aggFnParams == null) {
+                aggFnParams = fnParams;
             }
+            msg.setAggExpr(aggFnParams.createTAggregateExpr(isMergeAggFn));
         } else {
             msg.node_type = TExprNodeType.FUNCTION_CALL;
         }
@@ -1143,14 +1152,15 @@ public class FunctionCallExpr extends Expr {
     }
 
     public static FunctionCallExpr createMergeAggCall(
-            FunctionCallExpr agg, List<Expr> params) {
+            FunctionCallExpr agg, List<Expr> intermediateParams, List<Expr> 
realParams) {
         Preconditions.checkState(agg.isAnalyzed);
         Preconditions.checkState(agg.isAggregateFunction());
         FunctionCallExpr result = new FunctionCallExpr(
-                agg.fnName, new FunctionParams(false, params), true);
+                agg.fnName, new FunctionParams(false, intermediateParams), 
true);
         // Inherit the function object from 'agg'.
         result.fn = agg.fn;
         result.type = agg.type;
+        result.setAggFnParams(new FunctionParams(false, realParams));
         return result;
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
index 32cfba0351..3b77ec52b6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionParams.java
@@ -21,12 +21,15 @@
 package org.apache.doris.analysis;
 
 import org.apache.doris.common.io.Writable;
+import org.apache.doris.thrift.TAggregateExpr;
+import org.apache.doris.thrift.TTypeDesc;
 
 import com.google.common.collect.Lists;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 
@@ -62,6 +65,18 @@ public class FunctionParams implements Writable {
         return new FunctionParams();
     }
 
+    public TAggregateExpr createTAggregateExpr(boolean isMergeAggFn) {
+        List<TTypeDesc> paramTypes = new ArrayList<TTypeDesc>();
+        if (exprs != null) {
+            for (Expr expr : exprs) {
+                TTypeDesc desc = expr.getType().toThrift();
+                desc.setIsNullable(expr.isNullable());
+                paramTypes.add(desc);
+            }
+        }
+        return new TAggregateExpr(isMergeAggFn, paramTypes);
+    }
+
     public boolean isStar() {
         return isStar;
     }
diff --git a/gensrc/thrift/Exprs.thrift b/gensrc/thrift/Exprs.thrift
index 450148f381..50c9119410 100644
--- a/gensrc/thrift/Exprs.thrift
+++ b/gensrc/thrift/Exprs.thrift
@@ -73,6 +73,7 @@ enum TExprNodeType {
 struct TAggregateExpr {
   // Indicates whether this expr is the merge() of an aggregation.
   1: required bool is_merge_agg
+  2: required list<Types.TTypeDesc> param_types
 }
 struct TBoolLiteral {
   1: required bool value
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 381f2879e9..a7212e0476 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -145,6 +145,7 @@ struct TTypeNode {
 // to TTypeDesc. In future, we merge these two to one
 struct TTypeDesc {
     1: list<TTypeNode> types
+    2: optional bool is_nullable
 }
 
 enum TAggregationType {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to