zhiqiang-hhhh commented on code in PR #41240: URL: https://github.com/apache/doris/pull/41240#discussion_r1778272833
########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java: ########## @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; 
+import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'regr_intercept'. + */ +public class RegrIntercept extends NullableAggregateFunction + implements BinaryExpression, ExplicitlyCastableSignature, SupportWindowAnalytic { + + public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD)); Review Comment: If you use decimal as the return type, it will be difficult to infer the scale of the decimal result. So just remove the direct support for decimal and let the planner add a cast expression (like cast(decimal as double)) for us. ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp: ########## @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_regr_intercept.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" + +namespace doris::vectorized { + +template <typename TX, typename TY> +AggregateFunctionPtr type_dispatch_for_aggregate_function_regr_intercept(const DataTypes& argument_types, + const bool& result_is_nullable, + bool nullable_input) { + using StatFunctionTemplate = RegrInterceptFuncTwoArg<TX, TY>; + if (nullable_input) { + return creator_without_type::create_ignore_nullable< + AggregateFunctionRegrInterceptSimple<StatFunctionTemplate, true>>( + argument_types, result_is_nullable); + } else { + return creator_without_type::create_ignore_nullable< + AggregateFunctionRegrInterceptSimple<StatFunctionTemplate, false>>( + argument_types, result_is_nullable); + } +} + +AggregateFunctionPtr create_aggregate_function_regr_intercept(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + if (argument_types.size() != 2) { + LOG(WARNING) << "aggregate function " << name << " requires exactly 2 arguments"; + return nullptr; + } + if (!result_is_nullable) { + LOG(WARNING) << "aggregate function " << name << " requires nullable result type"; + return nullptr; + } + const bool nullable_input = argument_types[0]->is_nullable() || 
argument_types[1]->is_nullable(); + WhichDataType x_type(remove_nullable(argument_types[0])); + WhichDataType y_type(remove_nullable(argument_types[1])); + +#define DISPATCH(TX, TY) \ + if (x_type.idx == TypeIndex::TX && y_type.idx == TypeIndex::TY) \ + return type_dispatch_for_aggregate_function_regr_intercept<TX, TY>(argument_types, result_is_nullable, \ + nullable_input); +#define FOR_ALL_NUMERIC_TYPE_PAIRS(M) \ Review Comment: In your implementation, the FE signatures are just a subset of these combinations, so there is no need to duplicate this code. Using `FOR_NUMERIC_TYPES` is enough; just make the two arguments have the same type. ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
+ +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if (count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + 
sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> +struct RegrInterceptFuncTwoArg { + using TypeX = TX; + using TypeY = TY; + using Data = AggregateFunctionRegrInterceptData<Float64>; +}; + +template <typename StatFunc, bool NullableInput> +class AggregateFunctionRegrInterceptSimple + : public IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>> { +public: + using TX = typename StatFunc::TypeX; + using TY = typename StatFunc::TypeY; + using XInputCol = ColumnVector<TX>; + using YInputCol = ColumnVector<TY>; + using ResultCol = ColumnVector<Float64>; + + explicit AggregateFunctionRegrInterceptSimple(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>>(argument_types_) { + DCHECK(!argument_types_.empty()); + } + + String get_name() const override { return "regr_intercept"; } + + DataTypePtr get_return_type() const override { + return make_nullable(std::make_shared<DataTypeFloat64>()); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (NullableInput) { + const ColumnNullable& y_column_nullable = + assert_cast<const ColumnNullable&>(*columns[0]); Review Comment: add `TypeCheckOnRelease::Disable` since add function will be called on each row ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if (count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + 
return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> Review Comment: One type parameter is enough: `template <typename T>` ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
+ +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if (count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + 
sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> +struct RegrInterceptFuncTwoArg { + using TypeX = TX; + using TypeY = TY; + using Data = AggregateFunctionRegrInterceptData<Float64>; +}; + +template <typename StatFunc, bool NullableInput> +class AggregateFunctionRegrInterceptSimple + : public IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>> { +public: + using TX = typename StatFunc::TypeX; + using TY = typename StatFunc::TypeY; + using XInputCol = ColumnVector<TX>; + using YInputCol = ColumnVector<TY>; + using ResultCol = ColumnVector<Float64>; + + explicit AggregateFunctionRegrInterceptSimple(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>>(argument_types_) { + DCHECK(!argument_types_.empty()); + } + + String get_name() const override { return "regr_intercept"; } + + DataTypePtr get_return_type() const override { + return make_nullable(std::make_shared<DataTypeFloat64>()); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (NullableInput) { + const ColumnNullable& y_column_nullable = + assert_cast<const ColumnNullable&>(*columns[0]); + const ColumnNullable& x_column_nullable = + assert_cast<const ColumnNullable&>(*columns[1]); Review Comment: `TypeCheckOnRelease::Disable` ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if 
(count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> +struct RegrInterceptFuncTwoArg { + using TypeX = TX; + using TypeY = TY; + using Data = AggregateFunctionRegrInterceptData<Float64>; Review Comment: maybe `using Data = AggregateFunctionRegrInterceptData<T>;` ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if (count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + 
sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> +struct RegrInterceptFuncTwoArg { + using TypeX = TX; + using TypeY = TY; Review Comment: `using TypeY = TX` or `using Type = T` ########## be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h: ########## @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cmath> +#include <cstdint> +#include <string> +#include <type_traits> + +#include "common/exception.h" +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" +namespace doris::vectorized { + +template <typename T> +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x {}; + double sum_y {}; + double sum_of_x_mul_y {}; + double sum_of_x_squared {}; + + void write(BufferWritable& buf) const { + write_binary(sum_x, buf); + write_binary(sum_y, buf); + write_binary(sum_of_x_mul_y, buf); + write_binary(sum_of_x_squared, buf); + write_binary(count, buf); + } + + void read(BufferReadable& buf) { + read_binary(sum_x, buf); + read_binary(sum_y, buf); + read_binary(sum_of_x_mul_y, buf); + read_binary(sum_of_x_squared, buf); + read_binary(count, buf); + } + + void reset() { + sum_x = {}; + sum_y = {}; + sum_of_x_mul_y = {}; + sum_of_x_squared = {}; + count = 0; + } + + double get_intercept_result() const { + double denominator = count * sum_of_x_squared - sum_x * sum_x; + if (count < 2 || denominator == 0.0) { + return std::numeric_limits<double>::quiet_NaN(); + } + double slope = (count * sum_of_x_mul_y - sum_x * sum_y) / denominator; + return (sum_y - slope * sum_x) / count; + } + + void merge(const AggregateFunctionRegrInterceptData& rhs) { + if (rhs.count == 0) { + return; + } + sum_x += rhs.sum_x; + sum_y += rhs.sum_y; + sum_of_x_mul_y += rhs.sum_of_x_mul_y; + sum_of_x_squared += rhs.sum_of_x_squared; + count += rhs.count; + } + + void add(T value_y, T value_x) { + sum_x += value_x; + sum_y += value_y; + sum_of_x_mul_y += value_x * value_y; + 
sum_of_x_squared += value_x * value_x; + count += 1; + } +}; + +template <typename TX, typename TY> +struct RegrInterceptFuncTwoArg { + using TypeX = TX; + using TypeY = TY; + using Data = AggregateFunctionRegrInterceptData<Float64>; +}; + +template <typename StatFunc, bool NullableInput> +class AggregateFunctionRegrInterceptSimple + : public IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>> { +public: + using TX = typename StatFunc::TypeX; + using TY = typename StatFunc::TypeY; + using XInputCol = ColumnVector<TX>; + using YInputCol = ColumnVector<TY>; + using ResultCol = ColumnVector<Float64>; + + explicit AggregateFunctionRegrInterceptSimple(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionRegrInterceptSimple<StatFunc, NullableInput>>(argument_types_) { + DCHECK(!argument_types_.empty()); + } + + String get_name() const override { return "regr_intercept"; } + + DataTypePtr get_return_type() const override { + return make_nullable(std::make_shared<DataTypeFloat64>()); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (NullableInput) { + const ColumnNullable& y_column_nullable = + assert_cast<const ColumnNullable&>(*columns[0]); + const ColumnNullable& x_column_nullable = + assert_cast<const ColumnNullable&>(*columns[1]); + bool y_null = y_column_nullable.is_null_at(row_num); + bool x_null = x_column_nullable.is_null_at(row_num); + if (y_null || x_null) { + return; + } else { + TY y_value = assert_cast<const YInputCol&>(y_column_nullable.get_nested_column()) + .get_data()[row_num]; + TX x_value = assert_cast<const XInputCol&>(x_column_nullable.get_nested_column()) + .get_data()[row_num]; + this->data(place).add(static_cast<Float64>(y_value), static_cast<Float64>(x_value)); Review Comment: 1. 
No need for static_cast once the type argument is changed. 2. Why were the positions of y and x swapped? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org