This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 1a11cefb8118bc7cd5ab46c088163bf99e926f14
Author: nanfeng <nanfeng_...@163.com>
AuthorDate: Wed Feb 7 08:32:06 2024 +0800

    [feature](agg-func) support corr function #30822
---
 .../aggregate_function_binary.h                    | 130 +++++++++++++++++++++
 .../aggregate_function_corr.cpp                    |  92 +++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   3 +
 .../sql-functions/aggregate-functions/corr.md      |  49 ++++++++
 .../sql-functions/aggregate-functions/corr.md      |  50 ++++++++
 .../doris/catalog/BuiltinAggregateFunctions.java   |   2 +
 .../java/org/apache/doris/catalog/FunctionSet.java |  25 ++++
 .../trees/expressions/functions/agg/Corr.java      |  85 ++++++++++++++
 .../visitor/AggregateFunctionVisitor.java          |   5 +
 .../nereids_function_p0/agg_function/test_corr.out |  13 +++
 .../agg_function/test_corr.groovy                  |  85 ++++++++++++++
 11 files changed, 539 insertions(+)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h 
b/be/src/vec/aggregate_functions/aggregate_function_binary.h
new file mode 100644
index 00000000000..422919c52af
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <cmath>
+
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/arithmetic_overflow.h"
+#include "vec/common/string_buffer.hpp"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+/// Compile-time description of a two-argument statistical aggregate.
+///
+/// Type1 / Type2 are the element types of the two input columns, ResultType is
+/// the type used both for accumulation and for the returned value, and Data is
+/// the accumulator state (Moments instantiated with ResultType).
+template <typename T1, typename T2, template <typename> typename Moments>
+struct StatFunc {
+    using Type1 = T1;
+    using Type2 = T2;
+    // Always accumulate and return Float64. The FE declares DOUBLE as the
+    // return type of every signature, including (FLOAT, FLOAT), so picking
+    // Float32 for float inputs would make the BE return type disagree with the
+    // FE declaration and would also lose precision in the running sums.
+    using ResultType = Float64;
+    using Data = Moments<ResultType>;
+};
+
+/// Aggregate function over two input columns.
+///
+/// The element types of the inputs, the result type and the accumulator state
+/// all come from the StatFunc policy; this class only adapts that state to the
+/// IAggregateFunction interface. Each input value is converted to ResultType
+/// before it is handed to the state.
+template <typename StatFunc>
+struct AggregateFunctionBinary
+        : public IAggregateFunctionDataHelper<typename StatFunc::Data,
+                                              AggregateFunctionBinary<StatFunc>> {
+    using ResultType = typename StatFunc::ResultType;
+
+    using ColVecT1 = ColumnVectorOrDecimal<typename StatFunc::Type1>;
+    using ColVecT2 = ColumnVectorOrDecimal<typename StatFunc::Type2>;
+    using ColVecResult = ColumnVector<ResultType>;
+    static constexpr UInt32 num_args = 2;
+
+    AggregateFunctionBinary(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<typename StatFunc::Data,
+                                           AggregateFunctionBinary<StatFunc>>(argument_types_) {}
+
+    /// The function name is supplied by the accumulator state type.
+    String get_name() const override { return StatFunc::Data::name(); }
+
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeNumber<ResultType>>();
+    }
+
+    bool allocates_memory_in_arena() const override { return false; }
+
+    /// Feeds row `row_num` of the two argument columns into the state at `place`.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena*) const override {
+        const auto& first_column = static_cast<const ColVecT1&>(*columns[0]);
+        const auto& second_column = static_cast<const ColVecT2&>(*columns[1]);
+        const auto x = static_cast<ResultType>(first_column.get_data()[row_num]);
+        const auto y = static_cast<ResultType>(second_column.get_data()[row_num]);
+        this->data(place).add(x, y);
+    }
+
+    /// Folds the state at `rhs` into the state at `place`.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        const auto& other = this->data(rhs);
+        this->data(place).merge(other);
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    /// Appends the final value of the state at `place` to the result column `to`.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        auto& result_column = static_cast<ColVecResult&>(to);
+        result_column.get_data().push_back(this->data(place).get());
+    }
+};
+
+/// Second stage of the two-argument type dispatch.
+///
+/// FirstType is already fixed by the caller; this resolves the element type of
+/// `second_type` (after stripping nullability) against the types enumerated by
+/// FOR_NUMERIC_TYPES and creates the matching AggregateFunctionBinary,
+/// forwarding `args` to creator_without_type::create. Returns nullptr when no
+/// enumerated type matches.
+template <template <typename> typename Moments, typename FirstType, typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types_second(const DataTypePtr& second_type,
+                                                                TArgs&&... args) {
+    const WhichDataType second_which(remove_nullable(second_type));
+#define CREATE_IF_SECOND_IS(TYPE)                                                     \
+    if (second_which.idx == TypeIndex::TYPE) {                                        \
+        using Function = AggregateFunctionBinary<StatFunc<FirstType, TYPE, Moments>>; \
+        return creator_without_type::create<Function>(std::forward<TArgs>(args)...);  \
+    }
+    FOR_NUMERIC_TYPES(CREATE_IF_SECOND_IS)
+#undef CREATE_IF_SECOND_IS
+    return nullptr;
+}
+
+/// First stage of the two-argument type dispatch.
+///
+/// Resolves the element type of `first_type` (after stripping nullability)
+/// against the types enumerated by FOR_NUMERIC_TYPES, then hands over to
+/// create_with_two_basic_numeric_types_second to resolve `second_type`.
+/// Returns nullptr when the first type matches none of the enumerated types.
+template <template <typename> typename Moments, typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types(const DataTypePtr& first_type,
+                                                         const DataTypePtr& second_type,
+                                                         TArgs&&... args) {
+    const WhichDataType first_which(remove_nullable(first_type));
+#define CREATE_IF_FIRST_IS(TYPE)                                          \
+    if (first_which.idx == TypeIndex::TYPE) {                             \
+        return create_with_two_basic_numeric_types_second<Moments, TYPE>( \
+                second_type, std::forward<TArgs>(args)...);               \
+    }
+    FOR_NUMERIC_TYPES(CREATE_IF_FIRST_IS)
+#undef CREATE_IF_FIRST_IS
+    return nullptr;
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
new file mode 100644
index 00000000000..fb84e92e0e6
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_binary.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/core/types.h"
+
+namespace doris::vectorized {
+
+/// Accumulator for the Pearson correlation coefficient of (x, y) pairs.
+///
+/// Only running sums are kept, so states can be merged by adding them
+/// field by field.
+template <typename T>
+struct CorrMoment {
+    T m0 {}; // number of pairs added
+    T x1 {}; // sum of x
+    T y1 {}; // sum of y
+    T xy {}; // sum of x * y
+    T x2 {}; // sum of x * x
+    T y2 {}; // sum of y * y
+
+    /// Adds one (x, y) pair to the running sums.
+    void add(T x, T y) {
+        ++m0;
+        x1 += x;
+        y1 += y;
+        xy += x * y;
+        x2 += x * x;
+        y2 += y * y;
+    }
+
+    /// Folds the sums of `rhs` into this state.
+    void merge(const CorrMoment& rhs) {
+        m0 += rhs.m0;
+        x1 += rhs.x1;
+        y1 += rhs.y1;
+        xy += rhs.xy;
+        x2 += rhs.x2;
+        y2 += rhs.y2;
+    }
+
+    /// Serializes the six sums; the field order must match read().
+    void write(BufferWritable& buf) const {
+        write_binary(m0, buf);
+        write_binary(x1, buf);
+        write_binary(y1, buf);
+        write_binary(xy, buf);
+        write_binary(x2, buf);
+        write_binary(y2, buf);
+    }
+
+    /// Deserializes the six sums in the order written by write().
+    void read(BufferReadable& buf) {
+        read_binary(m0, buf);
+        read_binary(x1, buf);
+        read_binary(y1, buf);
+        read_binary(xy, buf);
+        read_binary(x2, buf);
+        read_binary(y2, buf);
+    }
+
+    /// Returns the correlation coefficient, or 0 when it is undefined
+    /// (no rows, or x or y has no spread).
+    T get() const {
+        // m0^2 times the population variance of x and of y respectively.
+        const T scaled_var_x = m0 * x2 - x1 * x1;
+        const T scaled_var_y = m0 * y2 - y1 * y1;
+        // Mathematically both terms are >= 0, but floating-point cancellation
+        // can leave a (near-)constant column slightly negative. Testing each
+        // term with <= 0 instead of testing the product with == 0 keeps such
+        // input from reaching sqrt() with a negative argument and producing
+        // NaN. NaN terms compare false here and still propagate to the result.
+        if (scaled_var_x <= 0 || scaled_var_y <= 0) [[unlikely]] {
+            return 0;
+        }
+        return (m0 * xy - x1 * y1) / std::sqrt(scaled_var_x * scaled_var_y);
+    }
+
+    static String name() { return "corr"; }
+};
+
+/// Factory callback for `corr`.
+///
+/// Checks the argument list with assert_binary, then dispatches on the types of
+/// the two arguments to build an AggregateFunctionBinary over CorrMoment. The
+/// result is nullptr when the type dispatch finds no matching numeric type.
+AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
+                                                    const DataTypes& argument_types,
+                                                    const bool result_is_nullable) {
+    assert_binary(name, argument_types);
+    const DataTypePtr& x_type = argument_types[0];
+    const DataTypePtr& y_type = argument_types[1];
+    return create_with_two_basic_numeric_types<CorrMoment>(x_type, y_type, argument_types,
+                                                           result_is_nullable);
+}
+
+/// Registers the `corr` creator with the given factory via register_function_both.
+void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory) {
+    factory.register_function_both("corr", create_aggregate_corr_function);
+}
+
+} // namespace doris::vectorized
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 068e2efaac4..9f99a64f2bc 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -60,6 +60,7 @@ void 
register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& fa
 void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& 
factory);
+void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& 
factory);
 
 AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
     static std::once_flag oc;
@@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_replace_reader_load(instance);
         register_aggregate_function_window_lead_lag_first_last(instance);
         register_aggregate_function_HLL_union_agg(instance);
+
+        register_aggregate_functions_corr(instance);
     });
     return instance;
 }
diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md 
b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md
new file mode 100644
index 00000000000..862dbad02b1
--- /dev/null
+++ b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md
@@ -0,0 +1,49 @@
+---
+{
+    "title": "CORR",
+    "language": "en"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## CORR
+### Description
+#### Syntax
+
+` double corr(x, y)`
+
+Calculates the Pearson correlation coefficient of x and y, that is, the covariance of x and y divided by the product of their standard deviations.
+If the standard deviation of x or y is 0, the result is 0.
+
+### example
+
+```
+mysql> select corr(x,y) from baseall;
++---------------------+
+| corr(x, y)          |
++---------------------+
+| 0.89442719099991586 |
++---------------------+
+1 row in set (0.21 sec)
+
+```
+### keywords
+CORR
diff --git 
a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md 
b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md
new file mode 100644
index 00000000000..0437d5e9d8f
--- /dev/null
+++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md
@@ -0,0 +1,50 @@
+---
+{
+    "title": "CORR",
+    "language": "zh-CN"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## CORR
+### Description
+#### Syntax
+
+` double corr(x, y)`
+
+计算皮尔逊相关系数,即 x 和 y 的协方差除以 x 和 y 的标准差的乘积。
+如果 x 或 y 的标准差为 0,将返回 0。
+
+
+### example
+
+```
+mysql> select corr(x,y) from baseall;
++---------------------+
+| corr(x, y)          |
++---------------------+
+| 0.89442719099991586 |
++---------------------+
+1 row in set (0.21 sec)
+
+```
+### keywords
+CORR
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index 5a101e71014..a8fc246d239 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -28,6 +28,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
 import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
 import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -93,6 +94,7 @@ public class BuiltinAggregateFunctions implements 
FunctionHelper {
             agg(BitmapUnionInt.class, "bitmap_union_int"),
             agg(CollectList.class, "collect_list", "group_array"),
             agg(CollectSet.class, "collect_set", "group_uniq_array"),
+            agg(Corr.class, "corr"),
             agg(Count.class, "count"),
             agg(CountByEnum.class, "count_by_enum"),
             agg(GroupBitAnd.class, "group_bit_and"),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index c589cbbf505..629e4556df2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1710,6 +1710,31 @@ public class FunctionSet<T> {
                 "",
                 false, true, false, true));
 
+        // corr
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.TINYINT, Type.TINYINT), 
Type.DOUBLE, Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.SMALLINT, Type.SMALLINT), 
Type.DOUBLE, Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.INT, Type.INT), Type.DOUBLE, 
Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.BIGINT, Type.BIGINT), 
Type.DOUBLE, Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.FLOAT, Type.FLOAT), Type.DOUBLE, 
Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
+        addBuiltin(AggregateFunction.createBuiltin("corr",
+                Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), 
Type.DOUBLE, Type.DOUBLE,
+                "", "", "", "", "", "", "",
+                false, false, false, true));
     }
 
     public Map<String, List<Function>> getVectorizedFunctions() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
new file mode 100644
index 00000000000..26f8a720c26
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * AggregateFunction 'corr'. This class is generated by GenerateFunction.
+ *
+ * Every declared signature takes two arguments of the same numeric type
+ * (TINYINT, SMALLINT, INT, BIGINT, FLOAT or DOUBLE) and returns DOUBLE.
+ *
+ * NOTE(review): corr takes two arguments, yet this class implements
+ * UnaryExpression - confirm whether BinaryExpression was intended.
+ */
+public class Corr extends AggregateFunction
+        implements UnaryExpression, ExplicitlyCastableSignature, 
AlwaysNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, 
TinyIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, 
SmallIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, 
IntegerType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, 
BigIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, 
FloatType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, 
DoubleType.INSTANCE)
+    );
+
+    /**
+     * Constructor with 2 arguments: the two input expressions of corr.
+     */
+    public Corr(Expression arg1, Expression arg2) {
+        super("corr", arg1, arg2);
+    }
+
+    /**
+     * Constructor with 3 arguments: the distinct flag followed by the two
+     * input expressions of corr.
+     */
+    public Corr(boolean distinct, Expression arg1, Expression arg2) {
+        super("corr", distinct, arg1, arg2);
+    }
+
+    /**
+     * withDistinctAndChildren. Builds a new Corr from the given distinct flag
+     * and children; exactly two children are required.
+     */
+    @Override
+    public Corr withDistinctAndChildren(boolean distinct, List<Expression> 
children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new Corr(distinct, children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitCorr(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 14e3dc304e9..73dd6a838b9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -29,6 +29,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
 import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
 import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -126,6 +127,10 @@ public interface AggregateFunctionVisitor<R, C> {
         return visitAggregateFunction(collectSet, context);
     }
 
+    default R visitCorr(Corr corr, C context) {
+        return visitAggregateFunction(corr, context);
+    }
+
     default R visitCount(Count count, C context) {
         return visitAggregateFunction(count, context);
     }
diff --git 
a/regression-test/data/nereids_function_p0/agg_function/test_corr.out 
b/regression-test/data/nereids_function_p0/agg_function/test_corr.out
new file mode 100644
index 00000000000..4fc9a9d4baa
--- /dev/null
+++ b/regression-test/data/nereids_function_p0/agg_function/test_corr.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !sql --
+1.0
+
+-- !sql --
+-1.0
+
+-- !sql --
+0.0
+
+-- !sql --
+0.8944271909999159
+
diff --git 
a/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy 
b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy
new file mode 100644
index 00000000000..15f27f84276
--- /dev/null
+++ b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_corr") {
+    sql """ DROP TABLE IF EXISTS test_corr """
+
+    sql """ SET enable_nereids_planner=true """
+    sql """ SET enable_fallback_to_original_planner=false """
+
+    sql """
+        CREATE TABLE test_corr (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        );
+        """
+    
+    // Perfect positive correlation (y equals x): expected result 1.0
+    sql """
+        insert into test_corr values
+        (1, 1, 1),
+        (2, 2, 2),
+        (3, 3, 3),
+        (4, 4, 4),
+        (5, 5, 5)
+        """
+    qt_sql "select corr(x,y) from test_corr"
+    sql """ truncate table test_corr """
+    
+    // Perfect negative correlation (y decreases as x increases): expected result -1.0
+    sql """
+    insert into test_corr values
+    (1, 1, 5),
+    (2, 2, 4),
+    (3, 3, 3),
+    (4, 4, 2),
+    (5, 5, 1)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    sql """ truncate table test_corr """
+    
+    // Degenerate input: x is constant, so its standard deviation is 0 and corr returns 0.0
+    sql """
+    insert into test_corr values
+    (1, 1, 1),
+    (2, 1, 2),
+    (3, 1, 3),
+    (4, 1, 4),
+    (5, 1, 5)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    sql """ truncate table test_corr """
+    
+    // Partial linear correlation: the last point (5, 10) lies off the line y = x
+    sql """
+    insert into test_corr values
+    (1, 1, 1),
+    (2, 2, 2),
+    (3, 3, 3),
+    (4, 4, 4),
+    (5, 5, 10)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    
+    sql """ DROP TABLE IF EXISTS test_corr """
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to