This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 1a11cefb8118bc7cd5ab46c088163bf99e926f14 Author: nanfeng <nanfeng_...@163.com> AuthorDate: Wed Feb 7 08:32:06 2024 +0800 [feature](agg-func) support corr function #30822 --- .../aggregate_function_binary.h | 130 +++++++++++++++++++++ .../aggregate_function_corr.cpp | 92 +++++++++++++++ .../aggregate_function_simple_factory.cpp | 3 + .../sql-functions/aggregate-functions/corr.md | 49 ++++++++ .../sql-functions/aggregate-functions/corr.md | 50 ++++++++ .../doris/catalog/BuiltinAggregateFunctions.java | 2 + .../java/org/apache/doris/catalog/FunctionSet.java | 25 ++++ .../trees/expressions/functions/agg/Corr.java | 85 ++++++++++++++ .../visitor/AggregateFunctionVisitor.java | 5 + .../nereids_function_p0/agg_function/test_corr.out | 13 +++ .../agg_function/test_corr.groovy | 85 ++++++++++++++ 11 files changed, 539 insertions(+) diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h b/be/src/vec/aggregate_functions/aggregate_function_binary.h new file mode 100644 index 00000000000..422919c52af --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <cmath>
+
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/arithmetic_overflow.h"
+#include "vec/common/string_buffer.hpp"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+/// Bundles the two argument types of a binary statistical aggregate with the moment
+/// accumulator (`Moments<ResultType>`) used to compute it.
+///
+/// ResultType is always Float64. The FE declares corr() as returning DOUBLE for every
+/// signature, including (FLOAT, FLOAT); narrowing to Float32 for float inputs would make
+/// get_return_type() report FLOAT and make insert_result_into() treat the result column as
+/// ColumnVector<Float32> while the planner expects DOUBLE.
+template <typename T1, typename T2, template <typename> typename Moments>
+struct StatFunc {
+    using Type1 = T1;
+    using Type2 = T2;
+    using ResultType = Float64;
+    using Data = Moments<ResultType>;
+};
+
+/// Generic aggregate over two numeric columns. Every (x, y) row is widened to ResultType and
+/// fed into a StatFunc::Data accumulator; the aggregate's value is Data::get().
+/// Data must provide add(x, y), merge(rhs), write(buf), read(buf), get() and static name().
+template <typename StatFunc>
+struct AggregateFunctionBinary
+        : public IAggregateFunctionDataHelper<typename StatFunc::Data,
+                                              AggregateFunctionBinary<StatFunc>> {
+    using ResultType = typename StatFunc::ResultType;
+
+    using ColVecT1 = ColumnVectorOrDecimal<typename StatFunc::Type1>;
+    using ColVecT2 = ColumnVectorOrDecimal<typename StatFunc::Type2>;
+    using ColVecResult = ColumnVector<ResultType>;
+    static constexpr UInt32 num_args = 2;
+
+    AggregateFunctionBinary(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<typename StatFunc::Data,
+                                           AggregateFunctionBinary<StatFunc>>(argument_types_) {}
+
+    String get_name() const override { return StatFunc::Data::name(); }
+
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeNumber<ResultType>>();
+    }
+
+    bool allocates_memory_in_arena() const override { return false; }
+
+    // Reads row `row_num` from both argument columns, widens both values to ResultType and
+    // accumulates the pair.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena*) const override {
+        this->data(place).add(
+                static_cast<ResultType>(
+                        static_cast<const ColVecT1&>(*columns[0]).get_data()[row_num]),
+                static_cast<ResultType>(
+                        static_cast<const ColVecT2&>(*columns[1]).get_data()[row_num]));
+    }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    // Appends the accumulator's final value; `to` must be a ColumnVector<ResultType>.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        const auto& data = this->data(place);
+        auto& dst = static_cast<ColVecResult&>(to).get_data();
+        dst.push_back(data.get());
+    }
+};
+
+/// Second dispatch stage: resolves the (nullable-stripped) type of the second argument and
+/// instantiates AggregateFunctionBinary<StatFunc<FirstType, SecondType, Moments>>.
+/// Returns nullptr when the second argument is not one of FOR_NUMERIC_TYPES.
+template <template <typename> typename Moments, typename FirstType, typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types_second(const DataTypePtr& second_type,
+                                                                TArgs&&... args) {
+    WhichDataType which(remove_nullable(second_type));
+#define DISPATCH(TYPE)                                                        \
+    if (which.idx == TypeIndex::TYPE)                                         \
+        return creator_without_type::create<                                  \
+                AggregateFunctionBinary<StatFunc<FirstType, TYPE, Moments>>>( \
+                std::forward<TArgs>(args)...);
+    FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+    return nullptr;
+}
+
+/// Entry point for binary statistical aggregates: dispatches on the (nullable-stripped) types
+/// of both arguments and forwards `args` to the aggregate's constructor.
+/// Returns nullptr when either argument is not one of FOR_NUMERIC_TYPES.
+template <template <typename> typename Moments, typename... TArgs>
+AggregateFunctionPtr create_with_two_basic_numeric_types(const DataTypePtr& first_type,
+                                                         const DataTypePtr& second_type,
+                                                         TArgs&&... args) {
+    WhichDataType which(remove_nullable(first_type));
+#define DISPATCH(TYPE)                                                     \
+    if (which.idx == TypeIndex::TYPE)                                      \
+        return create_with_two_basic_numeric_types_second<Moments, TYPE>( \
+                second_type, std::forward<TArgs>(args)...);
+    FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+    return nullptr;
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
new file mode 100644
index 00000000000..fb84e92e0e6
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cmath>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_binary.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/core/types.h"
+
+namespace doris::vectorized {
+
+/// Running sums for the Pearson correlation coefficient of (x, y): the row count and the
+/// sums of x, y, x*y, x^2 and y^2. All six fields are plain sums, so partial states merge by
+/// field-wise addition and serialize as six consecutive T values.
+template <typename T>
+struct CorrMoment {
+    T m0 {}; // number of rows
+    T x1 {}; // sum of x
+    T y1 {}; // sum of y
+    T xy {}; // sum of x * y
+    T x2 {}; // sum of x^2
+    T y2 {}; // sum of y^2
+
+    void add(T x, T y) {
+        ++m0;
+        x1 += x;
+        y1 += y;
+        xy += x * y;
+        x2 += x * x;
+        y2 += y * y;
+    }
+
+    void merge(const CorrMoment& rhs) {
+        m0 += rhs.m0;
+        x1 += rhs.x1;
+        y1 += rhs.y1;
+        xy += rhs.xy;
+        x2 += rhs.x2;
+        y2 += rhs.y2;
+    }
+
+    // The field order below is the serialized layout; write() and read() must stay in sync.
+    void write(BufferWritable& buf) const {
+        write_binary(m0, buf);
+        write_binary(x1, buf);
+        write_binary(y1, buf);
+        write_binary(xy, buf);
+        write_binary(x2, buf);
+        write_binary(y2, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(m0, buf);
+        read_binary(x1, buf);
+        read_binary(y1, buf);
+        read_binary(xy, buf);
+        read_binary(x2, buf);
+        read_binary(y2, buf);
+    }
+
+    /// Returns r = (n*Sxy - Sx*Sy) / sqrt((n*Sxx - Sx^2) * (n*Syy - Sy^2)), clamped to [-1, 1].
+    /// Returns 0 when either variable has zero variance, which includes an empty state.
+    T get() const {
+        // n^2 times the population variance of x and of y. Computed from raw sums these suffer
+        // cancellation, so a (near-)constant column can yield a tiny negative value; treat any
+        // non-positive value as zero variance instead of taking sqrt of a negative (NaN).
+        const T var_x = m0 * x2 - x1 * x1;
+        const T var_y = m0 * y2 - y1 * y1;
+        if (var_x <= 0 || var_y <= 0) [[unlikely]] {
+            return 0;
+        }
+        const T var_prod = var_x * var_y;
+        // For very large inputs the product can overflow to infinity, which would silently
+        // turn the result into 0; in that case multiply the two square roots instead.
+        const T denominator = std::isfinite(var_prod) ? std::sqrt(var_prod)
+                                                      : std::sqrt(var_x) * std::sqrt(var_y);
+        const T r = (m0 * xy - x1 * y1) / denominator;
+        // Rounding can push |r| marginally above 1; the coefficient is bounded by definition.
+        return std::clamp(r, static_cast<T>(-1), static_cast<T>(1));
+    }
+
+    static String name() { return "corr"; }
+};
+
+/// Factory for corr(x, y): calls assert_binary on the arguments (presumably rejects anything
+/// but exactly two) and instantiates the aggregate for their numeric types. The result is
+/// nullptr if either type is not supported by create_with_two_basic_numeric_types.
+AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
+                                                    const DataTypes& argument_types,
+                                                    const bool result_is_nullable) {
+    assert_binary(name, argument_types);
+    return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1],
+                                                           argument_types, result_is_nullable);
+}
+
+/// Registers "corr" with the aggregate function factory via register_function_both.
+void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory) {
+    factory.register_function_both("corr", create_aggregate_corr_function);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 068e2efaac4..9f99a64f2bc 100644
---
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -60,6 +60,7 @@ void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& fa void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& factory); +void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory); AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; @@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_replace_reader_load(instance); register_aggregate_function_window_lead_lag_first_last(instance); register_aggregate_function_HLL_union_agg(instance); + + register_aggregate_functions_corr(instance); }); return instance; } diff --git a/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md new file mode 100644 index 00000000000..862dbad02b1 --- /dev/null +++ b/docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md @@ -0,0 +1,49 @@ +--- +{ + "title": "CORR", + "language": "en" +} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## CORR +### Description +#### Syntax + +` double corr(x, y)` + +Calculates the Pearson correlation coefficient of x and y, defined as the covariance of x and y divided by the product of their standard deviations. +If the standard deviation of x or y is 0, the result will be 0. + +### example + +``` +mysql> select corr(x,y) from baseall; ++---------------------+ +| corr(x, y) | ++---------------------+ +| 0.89442719099991586 | ++---------------------+ +1 row in set (0.21 sec) + +``` +### keywords +CORR diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md new file mode 100644 index 00000000000..0437d5e9d8f --- /dev/null +++ b/docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md @@ -0,0 +1,50 @@ +--- +{ + "title": "CORR", + "language": "zh-CN" +} +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied.
See the License for the +specific language governing permissions and limitations +under the License. +--> + +## CORR +### Description +#### Syntax + +` double corr(x, y)` + +计算皮尔逊相关系数，即 x 和 y 的协方差除以 x 和 y 的标准差之积。 +如果x或y的标准差为0, 将返回0。 + + +### example + +``` +mysql> select corr(x,y) from baseall; ++---------------------+ +| corr(x, y) | ++---------------------+ +| 0.89442719099991586 | ++---------------------+ +1 row in set (0.21 sec) + +``` +### keywords +CORR diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index 5a101e71014..a8fc246d239 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt; import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList; import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet; +import org.apache.doris.nereids.trees.expressions.functions.agg.Corr; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd; @@ -93,6 +94,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(BitmapUnionInt.class, "bitmap_union_int"), agg(CollectList.class, "collect_list", "group_array"), agg(CollectSet.class, "collect_set", "group_uniq_array"), + agg(Corr.class, "corr"), agg(Count.class, "count"), agg(CountByEnum.class, "count_by_enum"), agg(GroupBitAnd.class, "group_bit_and"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index c589cbbf505..629e4556df2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1710,6 +1710,31 @@ public class FunctionSet<T> { "", false, true, false, true)); + // corr + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.TINYINT, Type.TINYINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.SMALLINT, Type.SMALLINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.INT, Type.INT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.BIGINT, Type.BIGINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.FLOAT, Type.FLOAT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin("corr", + Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); } public Map<String, List<Function>> getVectorizedFunctions() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java new file mode 100644 index 00000000000..26f8a720c26 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java @@ -0,0 +1,85 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * AggregateFunction 'corr'. This class is generated by GenerateFunction.
+ * <p>
+ * corr(x, y) takes exactly two numeric arguments, so it is modelled as a
+ * {@link BinaryExpression} (not a unary one); every signature returns DOUBLE.
+ */
+public class Corr extends AggregateFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
+
+    /** Supported signatures: both arguments of the same numeric type, result always DOUBLE. */
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
+    );
+
+    /**
+     * constructor with 2 arguments.
+     */
+    public Corr(Expression arg1, Expression arg2) {
+        super("corr", arg1, arg2);
+    }
+
+    /**
+     * constructor with 3 arguments: the distinct flag followed by the two inputs.
+     */
+    public Corr(boolean distinct, Expression arg1, Expression arg2) {
+        super("corr", distinct, arg1, arg2);
+    }
+
+    /**
+     * withDistinctAndChildren; requires exactly two children.
+     */
+    @Override
+    public Corr withDistinctAndChildren(boolean distinct, List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new Corr(distinct, children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitCorr(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 14e3dc304e9..73dd6a838b9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
 import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
 import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -126,6 +127,10 @@ public interface AggregateFunctionVisitor<R, C> {
         return visitAggregateFunction(collectSet, context);
     }
 
+    default R visitCorr(Corr corr, C context) {
+        return visitAggregateFunction(corr, context);
+    }
+
     default R visitCount(Count count, C context) {
         return visitAggregateFunction(count, context);
     }
diff --git a/regression-test/data/nereids_function_p0/agg_function/test_corr.out
b/regression-test/data/nereids_function_p0/agg_function/test_corr.out new file mode 100644 index 00000000000..4fc9a9d4baa --- /dev/null +++ b/regression-test/data/nereids_function_p0/agg_function/test_corr.out @@ -0,0 +1,13 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +1.0 + +-- !sql -- +-1.0 + +-- !sql -- +0.0 + +-- !sql -- +0.8944271909999159 + diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy new file mode 100644 index 00000000000..15f27f84276 --- /dev/null +++ b/regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+// Regression test for the corr(x, y) aggregate. Four data sets are loaded in turn; each
+// qt_sql result corresponds, in order, to a "-- !sql --" section of test_corr.out.
+suite("test_corr") {
+    sql """ DROP TABLE IF EXISTS test_corr """
+
+    // Enable the Nereids planner and disable fallback to the original planner.
+    sql """ SET enable_nereids_planner=true """
+    sql """ SET enable_fallback_to_original_planner=false """
+
+    sql """
+        CREATE TABLE test_corr (
+            `id` int,
+            `x` int,
+            `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+
+    // Perfect positive correlation (y == x); expected 1.0
+    sql """
+        insert into test_corr values
+        (1, 1, 1),
+        (2, 2, 2),
+        (3, 3, 3),
+        (4, 4, 4),
+        (5, 5, 5)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    // Empty the table so the next data set is evaluated on its own.
+    sql """ truncate table test_corr """
+
+    // Perfect negative correlation (y == 6 - x); expected -1.0
+    sql """
+        insert into test_corr values
+        (1, 1, 5),
+        (2, 2, 4),
+        (3, 3, 3),
+        (4, 4, 2),
+        (5, 5, 1)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    sql """ truncate table test_corr """
+
+    // Zero correlation: x is constant, so its standard deviation is 0 and corr returns 0.0
+    sql """
+        insert into test_corr values
+        (1, 1, 1),
+        (2, 1, 2),
+        (3, 1, 3),
+        (4, 1, 4),
+        (5, 1, 5)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+    sql """ truncate table test_corr """
+
+    // Partial linear correlation (last y is an outlier); expected 0.8944271909999159
+    sql """
+        insert into test_corr values
+        (1, 1, 1),
+        (2, 2, 2),
+        (3, 3, 3),
+        (4, 4, 4),
+        (5, 5, 10)
+    """
+    qt_sql "select corr(x,y) from test_corr"
+
+    sql """ DROP TABLE IF EXISTS test_corr """
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org