zhiqiang-hhhh commented on code in PR #41240:
URL: https://github.com/apache/doris/pull/41240#discussion_r1778272833


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java:
##########
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import 
org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * AggregateFunction 'regr_intercept'.
+ */
+public class RegrIntercept extends NullableAggregateFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, 
SupportWindowAnalytic {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, 
TinyIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, 
SmallIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, 
IntegerType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, 
BigIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, 
LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, 
DoubleType.INSTANCE),
+            
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, 
DecimalV3Type.WILDCARD));

Review Comment:
   If you use decimal as the return type, it will be difficult to infer the 
scale of the resulting decimal. So just remove the direct support for decimal 
and let the planner add a cast expression (e.g. cast(decimal as double)) for us.



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp:
##########
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_regr_intercept.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+
+namespace doris::vectorized {
+
+template <typename TX, typename TY>
+AggregateFunctionPtr type_dispatch_for_aggregate_function_regr_intercept(const 
DataTypes& argument_types,
+                                                                         const 
bool& result_is_nullable,
+                                                                         bool 
nullable_input) {
+    using StatFunctionTemplate = RegrInterceptFuncTwoArg<TX, TY>;
+    if (nullable_input) {
+        return creator_without_type::create_ignore_nullable<
+                AggregateFunctionRegrInterceptSimple<StatFunctionTemplate, 
true>>(
+                argument_types, result_is_nullable);
+    } else {
+        return creator_without_type::create_ignore_nullable<
+                AggregateFunctionRegrInterceptSimple<StatFunctionTemplate, 
false>>(
+                argument_types, result_is_nullable);
+    }
+}
+
+AggregateFunctionPtr create_aggregate_function_regr_intercept(const 
std::string& name,
+                                                              const DataTypes& 
argument_types,
+                                                              const bool 
result_is_nullable) {
+    if (argument_types.size() != 2) {
+        LOG(WARNING) << "aggregate function " << name << " requires exactly 2 
arguments";
+        return nullptr;
+    }
+    if (!result_is_nullable) {
+        LOG(WARNING) << "aggregate function " << name << " requires nullable 
result type";
+        return nullptr;
+    }
+    const bool nullable_input = argument_types[0]->is_nullable() || 
argument_types[1]->is_nullable();
+    WhichDataType x_type(remove_nullable(argument_types[0]));
+    WhichDataType y_type(remove_nullable(argument_types[1]));
+
+#define DISPATCH(TX, TY)                                                       
                            \
+    if (x_type.idx == TypeIndex::TX && y_type.idx == TypeIndex::TY)            
                            \
+        return type_dispatch_for_aggregate_function_regr_intercept<TX, 
TY>(argument_types, result_is_nullable, \
+                                                                           
nullable_input);
+#define FOR_ALL_NUMERIC_TYPE_PAIRS(M) \

Review Comment:
   In your implementation, the FE signatures are only a subset of these 
combinations, so there is no need for this duplicated code. Using 
`FOR_NUMERIC_TYPES` is enough; just require the two arguments to have the same type.



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>
+struct RegrInterceptFuncTwoArg {
+    using TypeX = TX;
+    using TypeY = TY;
+    using Data = AggregateFunctionRegrInterceptData<Float64>;
+};
+
+template <typename StatFunc, bool NullableInput>
+class AggregateFunctionRegrInterceptSimple
+        : public IAggregateFunctionDataHelper<
+                  typename StatFunc::Data,
+                  AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>> {
+public:
+    using TX = typename StatFunc::TypeX;
+    using TY = typename StatFunc::TypeY;
+    using XInputCol = ColumnVector<TX>;
+    using YInputCol = ColumnVector<TY>;
+    using ResultCol = ColumnVector<Float64>;
+
+    explicit AggregateFunctionRegrInterceptSimple(const DataTypes& 
argument_types_)
+            : IAggregateFunctionDataHelper<
+                      typename StatFunc::Data,
+                      AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>>(argument_types_) {
+        DCHECK(!argument_types_.empty());
+    }
+
+    String get_name() const override { return "regr_intercept"; }
+
+    DataTypePtr get_return_type() const override {
+        return make_nullable(std::make_shared<DataTypeFloat64>());
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
+             Arena*) const override {
+        if constexpr (NullableInput) {
+            const ColumnNullable& y_column_nullable =
+                    assert_cast<const ColumnNullable&>(*columns[0]);

Review Comment:
   Add `TypeCheckOnRelease::Disable` here, since the `add` function is called 
for every row.



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>

Review Comment:
   One type parameter is enough:
   `template <typename T>`



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>
+struct RegrInterceptFuncTwoArg {
+    using TypeX = TX;
+    using TypeY = TY;
+    using Data = AggregateFunctionRegrInterceptData<Float64>;
+};
+
+template <typename StatFunc, bool NullableInput>
+class AggregateFunctionRegrInterceptSimple
+        : public IAggregateFunctionDataHelper<
+                  typename StatFunc::Data,
+                  AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>> {
+public:
+    using TX = typename StatFunc::TypeX;
+    using TY = typename StatFunc::TypeY;
+    using XInputCol = ColumnVector<TX>;
+    using YInputCol = ColumnVector<TY>;
+    using ResultCol = ColumnVector<Float64>;
+
+    explicit AggregateFunctionRegrInterceptSimple(const DataTypes& 
argument_types_)
+            : IAggregateFunctionDataHelper<
+                      typename StatFunc::Data,
+                      AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>>(argument_types_) {
+        DCHECK(!argument_types_.empty());
+    }
+
+    String get_name() const override { return "regr_intercept"; }
+
+    DataTypePtr get_return_type() const override {
+        return make_nullable(std::make_shared<DataTypeFloat64>());
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
+             Arena*) const override {
+        if constexpr (NullableInput) {
+            const ColumnNullable& y_column_nullable =
+                    assert_cast<const ColumnNullable&>(*columns[0]);
+            const ColumnNullable& x_column_nullable =
+                    assert_cast<const ColumnNullable&>(*columns[1]);

Review Comment:
   `TypeCheckOnRelease::Disable`



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>
+struct RegrInterceptFuncTwoArg {
+    using TypeX = TX;
+    using TypeY = TY;
+    using Data = AggregateFunctionRegrInterceptData<Float64>;

Review Comment:
   Maybe `using Data = AggregateFunctionRegrInterceptData<T>;` instead.



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>
+struct RegrInterceptFuncTwoArg {
+    using TypeX = TX;
+    using TypeY = TY;

Review Comment:
   `using TypeY = TX` or `using Type = T`



##########
be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h:
##########
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "common/exception.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+namespace doris::vectorized {
+
+template <typename T>
+struct AggregateFunctionRegrInterceptData {
+    UInt64 count = 0;
+    double sum_x {};
+    double sum_y {};
+    double sum_of_x_mul_y {};
+    double sum_of_x_squared {};
+
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    double get_intercept_result() const {
+        double denominator = count * sum_of_x_squared - sum_x * sum_x;
+        if (count < 2 || denominator == 0.0) {
+            return std::numeric_limits<double>::quiet_NaN();
+        }
+        double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
+        return (sum_y - slope * sum_x) / count;
+    }
+
+    void merge(const AggregateFunctionRegrInterceptData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        sum_x += value_x;
+        sum_y += value_y;
+        sum_of_x_mul_y += value_x * value_y;
+        sum_of_x_squared += value_x * value_x;
+        count += 1;
+    }
+};
+
+template <typename TX, typename TY>
+struct RegrInterceptFuncTwoArg {
+    using TypeX = TX;
+    using TypeY = TY;
+    using Data = AggregateFunctionRegrInterceptData<Float64>;
+};
+
+template <typename StatFunc, bool NullableInput>
+class AggregateFunctionRegrInterceptSimple
+        : public IAggregateFunctionDataHelper<
+                  typename StatFunc::Data,
+                  AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>> {
+public:
+    using TX = typename StatFunc::TypeX;
+    using TY = typename StatFunc::TypeY;
+    using XInputCol = ColumnVector<TX>;
+    using YInputCol = ColumnVector<TY>;
+    using ResultCol = ColumnVector<Float64>;
+
+    explicit AggregateFunctionRegrInterceptSimple(const DataTypes& 
argument_types_)
+            : IAggregateFunctionDataHelper<
+                      typename StatFunc::Data,
+                      AggregateFunctionRegrInterceptSimple<StatFunc, 
NullableInput>>(argument_types_) {
+        DCHECK(!argument_types_.empty());
+    }
+
+    String get_name() const override { return "regr_intercept"; }
+
+    DataTypePtr get_return_type() const override {
+        return make_nullable(std::make_shared<DataTypeFloat64>());
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
ssize_t row_num,
+             Arena*) const override {
+        if constexpr (NullableInput) {
+            const ColumnNullable& y_column_nullable =
+                    assert_cast<const ColumnNullable&>(*columns[0]);
+            const ColumnNullable& x_column_nullable =
+                    assert_cast<const ColumnNullable&>(*columns[1]);
+            bool y_null = y_column_nullable.is_null_at(row_num);
+            bool x_null = x_column_nullable.is_null_at(row_num);
+            if (y_null || x_null) {
+                return;
+            } else {
+                TY y_value = assert_cast<const 
YInputCol&>(y_column_nullable.get_nested_column())
+                                        .get_data()[row_num];
+                TX x_value = assert_cast<const 
XInputCol&>(x_column_nullable.get_nested_column())
+                                        .get_data()[row_num];
+                this->data(place).add(static_cast<Float64>(y_value), 
static_cast<Float64>(x_value));

Review Comment:
   1. There is no need for `static_cast` once the type argument is changed.
   2. Why has the order of the y and x arguments been swapped?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to