This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 1362c00c7ef branch-4.0: [fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT to align with PG #55940 (#58920)
1362c00c7ef is described below

commit 1362c00c7ef164e681c2efce318a8782744ab6b8
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Thu Dec 11 18:29:33 2025 +0800

    branch-4.0: [fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT to align with PG #55940 (#58920)
    
    Cherry-picked from #55940
    
    Co-authored-by: Jover <[email protected]>
---
 .../aggregate_function_regr_union.h                | 250 ++++++++++++++++-----
 .../support_type/regr_intercept/regr_intercept.out |   8 +-
 .../support_type/regr_slope/regr_slope.out         |   8 +-
 .../query_p0/aggregate/test_regr_intercept.groovy  |  18 +-
 .../query_p0/aggregate/test_regr_slope.groovy      |  18 +-
 5 files changed, 226 insertions(+), 76 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
index 1cc30b8c430..dde9fc5e48f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
@@ -32,88 +32,238 @@
 namespace doris::vectorized {
 #include "common/compile_check_begin.h"
 
-template <PrimitiveType T>
+template <PrimitiveType T,
+          // requires Sx and Sy
+          bool NeedSxy,
+          // level 1: Sx
+          // level 2: Sxx
+          size_t SxLevel = size_t {NeedSxy},
+          // level 1: Sy
+          // level 2: Syy
+          size_t SyLevel = size_t {NeedSxy}>
 struct AggregateFunctionRegrData {
     static constexpr PrimitiveType Type = T;
-    UInt64 count = 0;
-    Float64 sum_x {};
-    Float64 sum_y {};
-    Float64 sum_of_x_mul_y {};
-    Float64 sum_of_x_squared {};
+
+    static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0),
+                  "NeedSxy requires SxLevel > 0 and SyLevel > 0");
+    static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2");
+
+    static constexpr bool need_sx = SxLevel > 0;
+    static constexpr bool need_sy = SyLevel > 0;
+    static constexpr bool need_sxx = SxLevel > 1;
+    static constexpr bool need_syy = SyLevel > 1;
+    static constexpr bool need_sxy = NeedSxy;
+
+    static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy};
+    static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array");
+
+    /**
+     * The moments array is:
+     *     Sx  = sum(X)
+     *     Sy  = sum(Y)
+     *     Sxx = sum((X-Sx/N)^2)
+     *     Syy = sum((Y-Sy/N)^2)
+     *     Sxy = sum((X-Sx/N)*(Y-Sy/N))
+    */
+    std::array<Float64, kMomentSize> moments {};
+    UInt64 n {};
+
+    static constexpr size_t idx_sx() {
+        static_assert(need_sx, "sx not enabled");
+        return 0;
+    }
+    static constexpr size_t idx_sy() {
+        static_assert(need_sy, "sy not enabled");
+        return size_t {need_sx};
+    }
+    static constexpr size_t idx_sxx() {
+        static_assert(need_sxx, "sxx not enabled");
+        return size_t {need_sx + need_sy};
+    }
+    static constexpr size_t idx_syy() {
+        static_assert(need_syy, "syy not enabled");
+        return size_t {need_sx + need_sy + need_sxx};
+    }
+    static constexpr size_t idx_sxy() {
+        static_assert(need_sxy, "sxy not enabled");
+        return size_t {need_sx + need_sy + need_sxx + need_syy};
+    }
+
+    Float64& sx() { return moments[idx_sx()]; }
+    Float64& sy() { return moments[idx_sy()]; }
+    Float64& sxx() { return moments[idx_sxx()]; }
+    Float64& syy() { return moments[idx_syy()]; }
+    Float64& sxy() { return moments[idx_sxy()]; }
+
+    const Float64& sx() const { return moments[idx_sx()]; }
+    const Float64& sy() const { return moments[idx_sy()]; }
+    const Float64& sxx() const { return moments[idx_sxx()]; }
+    const Float64& syy() const { return moments[idx_syy()]; }
+    const Float64& sxy() const { return moments[idx_sxy()]; }
 
     void write(BufferWritable& buf) const {
-        buf.write_binary(sum_x);
-        buf.write_binary(sum_y);
-        buf.write_binary(sum_of_x_mul_y);
-        buf.write_binary(sum_of_x_squared);
-        buf.write_binary(count);
+        if constexpr (need_sx) {
+            buf.write_binary(sx());
+        }
+        if constexpr (need_sy) {
+            buf.write_binary(sy());
+        }
+        if constexpr (need_sxx) {
+            buf.write_binary(sxx());
+        }
+        if constexpr (need_syy) {
+            buf.write_binary(syy());
+        }
+        if constexpr (need_sxy) {
+            buf.write_binary(sxy());
+        }
+        buf.write_binary(n);
     }
 
     void read(BufferReadable& buf) {
-        buf.read_binary(sum_x);
-        buf.read_binary(sum_y);
-        buf.read_binary(sum_of_x_mul_y);
-        buf.read_binary(sum_of_x_squared);
-        buf.read_binary(count);
+        if constexpr (need_sx) {
+            buf.read_binary(sx());
+        }
+        if constexpr (need_sy) {
+            buf.read_binary(sy());
+        }
+        if constexpr (need_sxx) {
+            buf.read_binary(sxx());
+        }
+        if constexpr (need_syy) {
+            buf.read_binary(syy());
+        }
+        if constexpr (need_sxy) {
+            buf.read_binary(sxy());
+        }
+        buf.read_binary(n);
     }
 
     void reset() {
-        sum_x = {};
-        sum_y = {};
-        sum_of_x_mul_y = {};
-        sum_of_x_squared = {};
-        count = 0;
+        moments.fill({});
+        n = {};
     }
 
+    /**
+     * The merge function uses the Youngs–Cramer algorithm:
+     *     N   = N1 + N2
+     *     Sx  = Sx1 + Sx2
+     *     Sy  = Sy1 + Sy2
+     *     Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
+     *     Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
+     *     Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
+     */
     void merge(const AggregateFunctionRegrData& rhs) {
-        if (rhs.count == 0) {
+        if (rhs.n == 0) {
+            return;
+        }
+        if (n == 0) {
+            *this = rhs;
             return;
         }
-        sum_x += rhs.sum_x;
-        sum_y += rhs.sum_y;
-        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
-        sum_of_x_squared += rhs.sum_of_x_squared;
-        count += rhs.count;
+        const auto n1 = static_cast<Float64>(n);
+        const auto n2 = static_cast<Float64>(rhs.n);
+        const auto nsum = n1 + n2;
+
+        Float64 dx {};
+        Float64 dy {};
+        if constexpr (need_sxx || need_sxy) {
+            dx = sx() / n1 - rhs.sx() / n2;
+        }
+        if constexpr (need_syy || need_sxy) {
+            dy = sy() / n1 - rhs.sy() / n2;
+        }
+
+        n += rhs.n;
+        if constexpr (need_sx) {
+            sx() += rhs.sx();
+        }
+        if constexpr (need_sy) {
+            sy() += rhs.sy();
+        }
+        if constexpr (need_sxx) {
+            sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum;
+        }
+        if constexpr (need_syy) {
+            syy() += rhs.syy() + n1 * n2 * dy * dy / nsum;
+        }
+        if constexpr (need_sxy) {
+            sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum;
+        }
     }
 
+    /**
+     * N
+     * Sx  = sum(X)
+     * Sy  = sum(Y)
+     * Sxx = sum((X-Sx/N)^2)
+     * Syy = sum((Y-Sy/N)^2)
+     * Sxy = sum((X-Sx/N)*(Y-Sy/N))
+     */
     void add(typename PrimitiveTypeTraits<T>::ColumnItemType value_y,
              typename PrimitiveTypeTraits<T>::ColumnItemType value_x) {
-        sum_x += (double)value_x;
-        sum_y += (double)value_y;
-        sum_of_x_mul_y += (double)value_x * (double)value_y;
-        sum_of_x_squared += (double)value_x * (double)value_x;
-        count += 1;
-    }
+        const auto x = static_cast<Float64>(value_x);
+        const auto y = static_cast<Float64>(value_y);
 
-    Float64 get_slope() const {
-        Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x;
-        if (count < 2 || denominator == 0.0) {
-            return std::numeric_limits<Float64>::quiet_NaN();
+        if constexpr (need_sx) {
+            sx() += x;
+        }
+        if constexpr (need_sy) {
+            sy() += y;
+        }
+
+        if (n == 0) [[unlikely]] {
+            n = 1;
+            return;
+        }
+        const auto n_old = static_cast<Float64>(n);
+        const auto n_new = n_old + 1;
+        const auto scale = 1.0 / (n_new * n_old);
+        n += 1;
+
+        Float64 tmp_x {};
+        Float64 tmp_y {};
+        if constexpr (need_sxx || need_sxy) {
+            tmp_x = x * n_new - sx();
+        }
+        if constexpr (need_syy || need_sxy) {
+            tmp_y = y * n_new - sy();
+        }
+
+        if constexpr (need_sxx) {
+            sxx() += tmp_x * tmp_x * scale;
+        }
+        if constexpr (need_syy) {
+            syy() += tmp_y * tmp_y * scale;
+        }
+        if constexpr (need_sxy) {
+            sxy() += tmp_x * tmp_y * scale;
         }
-        Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
-        return slope;
     }
 };
 
 template <PrimitiveType T>
-struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
+struct RegrSlopeFunc : AggregateFunctionRegrData<T, true, 2, 1> {
     static constexpr const char* name = "regr_slope";
 
-    Float64 get_result() const { return this->get_slope(); }
+    Float64 get_result() const {
+        if (this->n < 1 || this->sxx() == 0.0) {
+            return std::numeric_limits<Float64>::quiet_NaN();
+        }
+        return this->sxy() / this->sxx();
+    }
 };
 
 template <PrimitiveType T>
-struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
+struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
     static constexpr const char* name = "regr_intercept";
 
     Float64 get_result() const {
-        auto slope = this->get_slope();
-        if (std::isnan(slope)) {
-            return slope;
-        } else {
-            Float64 intercept = (this->sum_y - slope * this->sum_x) / (double)this->count;
-            return intercept;
+        if (this->n < 1 || this->sxx() == 0.0) {
+            return std::numeric_limits<Float64>::quiet_NaN();
         }
+        return (this->sy() - this->sx() * this->sxy() / this->sxx()) /
+               static_cast<Float64>(this->n);
     }
 };
 
@@ -147,7 +297,7 @@ public:
         const XInputCol* x_nested_column = nullptr;
 
         if constexpr (y_nullable) {
-            const ColumnNullable& y_column_nullable =
+            const auto& y_column_nullable =
                     assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
             y_null = y_column_nullable.is_null_at(row_num);
             y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
@@ -158,7 +308,7 @@ public:
         }
 
         if constexpr (x_nullable) {
-            const ColumnNullable& x_column_nullable =
+            const auto& x_column_nullable =
                     assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[1]);
             x_null = x_column_nullable.is_null_at(row_num);
             x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
index 88a91371f5f..f58aaf4a55f 100644
--- a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
+++ b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
@@ -6,17 +6,17 @@
 -990000.0
 
 -- !regr_intercept_int --
-1000001.0
+-9.999E9
 
 -- !regr_intercept_bigint --
 \N
 
 -- !regr_intercept_largeint --
-9.999999999999989E19
+1.0E20
 
 -- !regr_intercept_float --
-13.241664047161644
+13.24166404716167
 
 -- !regr_intercept_double --
-58.05515207632899
+58.05515207633332
 
diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
index 77140f0d1d3..0e9d13ae71d 100644
--- a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
+++ b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
@@ -6,17 +6,17 @@
 1.0
 
 -- !regr_slope_int --
--0.0
+1.0
 
 -- !regr_slope_bigint --
 \N
 
 -- !regr_slope_largeint --
-17725.127617654194
+0.0
 
 -- !regr_slope_float --
--2.79289213515492
+-2.792892135154929
 
 -- !regr_slope_double --
--0.5501239199999569
+-0.5501239199999999
 
diff --git a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
index f7c44642427..10683585309 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
@@ -51,9 +51,9 @@ suite("test_regr_intercept") {
     
     // no value
     // agg function without group by should return null
-    qt_sql_empty_1 "select regr_intercept(y,x) from test_regr_intercept_int"
+    qt_sql_empty_1 "select regr_intercept(y, x) from test_regr_intercept_int"
     // agg function with group by should return empty set
-    qt_sql_empty_2 "select regr_intercept(y,x) from test_regr_intercept_int group by id"
+    qt_sql_empty_2 "select regr_intercept(y, x) from test_regr_intercept_int group by id"
 
     sql """ TRUNCATE TABLE test_regr_intercept_int """
 
@@ -83,7 +83,7 @@ suite("test_regr_intercept") {
     qt_sql_int_2 "select regr_intercept(x, 4) from test_regr_intercept_int"
 
     // int value
-    qt_sql_int_3 "select regr_intercept(y,x) from test_regr_intercept_int"
+    qt_sql_int_3 "select regr_intercept(y, x) from test_regr_intercept_int"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_int_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int"
@@ -122,8 +122,8 @@ suite("test_regr_intercept") {
     qt_sql_int_7 "select regr_intercept(x, 4) from test_regr_intercept_int"
 
     // int value
-    qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int"
-    qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int group by id order by id"
+    qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int"
+    qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int group by id order by id"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_int_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_intercept") {
     qt_sql_double_2 "select regr_intercept(x, 4) from test_regr_intercept_double"
 
     // int value
-    qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double"
-    qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id"
+    qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double"
+    qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_double_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double"
@@ -183,8 +183,8 @@ suite("test_regr_intercept") {
     qt_sql_double_7 "select regr_intercept(x, 4) from test_regr_intercept_double"
 
     // int value
-    qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double"
-    qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id"
+    qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double"
+    qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id"
     
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_double_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double where id >= 3"
diff --git a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
index 19397036234..0c600710367 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
@@ -51,9 +51,9 @@ suite("test_regr_slope") {
     
     // no value
     // agg function without group by should return null
-    qt_sql_empty_1 "select regr_slope(y,x) from test_regr_slope_int"
+    qt_sql_empty_1 "select regr_slope(y, x) from test_regr_slope_int"
     // agg function with group by should return empty set
-    qt_sql_empty_2 "select regr_slope(y,x) from test_regr_slope_int group by id"
+    qt_sql_empty_2 "select regr_slope(y, x) from test_regr_slope_int group by id"
 
     sql """ TRUNCATE TABLE test_regr_slope_int """
 
@@ -83,7 +83,7 @@ suite("test_regr_slope") {
     qt_sql_int_2 "select regr_slope(x, 4) from test_regr_slope_int"
 
     // int value
-    qt_sql_int_3 "select regr_slope(y,x) from test_regr_slope_int"
+    qt_sql_int_3 "select regr_slope(y, x) from test_regr_slope_int"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_int_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int"
@@ -122,8 +122,8 @@ suite("test_regr_slope") {
     qt_sql_int_7 "select regr_slope(x, 4) from test_regr_slope_int"
 
     // int value
-    qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int"
-    qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int group by id order by id"
+    qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int"
+    qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int group by id order by id"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_int_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_slope") {
     qt_sql_double_2 "select regr_slope(x, 4) from test_regr_slope_double"
 
     // int value
-    qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double"
-    qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double group by id order by id"
+    qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double"
+    qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double group by id order by id"
 
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_double_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double"
@@ -183,8 +183,8 @@ suite("test_regr_slope") {
     qt_sql_double_7 "select regr_slope(x, 4) from test_regr_slope_double"
 
     // int value
-    qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double"
-    qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double group by id order by id"
+    qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double"
+    qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double group by id order by id"
     
     // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
     qt_sql_double_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double where id >= 3"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to