This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 1362c00c7ef branch-4.0: [fix](regr) Use Youngs-Cramer for
REGR_SLOPE/INTERCEPT to align with PG #55940 (#58920)
1362c00c7ef is described below
commit 1362c00c7ef164e681c2efce318a8782744ab6b8
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Thu Dec 11 18:29:33 2025 +0800
branch-4.0: [fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT to align
with PG #55940 (#58920)
Cherry-picked from #55940
Co-authored-by: Jover <[email protected]>
---
.../aggregate_function_regr_union.h | 250 ++++++++++++++++-----
.../support_type/regr_intercept/regr_intercept.out | 8 +-
.../support_type/regr_slope/regr_slope.out | 8 +-
.../query_p0/aggregate/test_regr_intercept.groovy | 18 +-
.../query_p0/aggregate/test_regr_slope.groovy | 18 +-
5 files changed, 226 insertions(+), 76 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
index 1cc30b8c430..dde9fc5e48f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
@@ -32,88 +32,238 @@
namespace doris::vectorized {
#include "common/compile_check_begin.h"
-template <PrimitiveType T>
+template <PrimitiveType T,
+ // requires Sx and Sy
+ bool NeedSxy,
+ // level 1: Sx
+ // level 2: Sxx
+ size_t SxLevel = size_t {NeedSxy},
+ // level 1: Sy
+ // level 2: Syy
+ size_t SyLevel = size_t {NeedSxy}>
struct AggregateFunctionRegrData {
static constexpr PrimitiveType Type = T;
- UInt64 count = 0;
- Float64 sum_x {};
- Float64 sum_y {};
- Float64 sum_of_x_mul_y {};
- Float64 sum_of_x_squared {};
+
+ static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0),
+ "NeedSxy requires SxLevel > 0 and SyLevel > 0");
+ static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2");
+
+ static constexpr bool need_sx = SxLevel > 0;
+ static constexpr bool need_sy = SyLevel > 0;
+ static constexpr bool need_sxx = SxLevel > 1;
+ static constexpr bool need_syy = SyLevel > 1;
+ static constexpr bool need_sxy = NeedSxy;
+
+ static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy};
+ static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array");
+
+ /**
+ * The moments array is:
+ * Sx = sum(X)
+ * Sy = sum(Y)
+ * Sxx = sum((X-Sx/N)^2)
+ * Syy = sum((Y-Sy/N)^2)
+ * Sxy = sum((X-Sx/N)*(Y-Sy/N))
+ */
+ std::array<Float64, kMomentSize> moments {};
+ UInt64 n {};
+
+ static constexpr size_t idx_sx() {
+ static_assert(need_sx, "sx not enabled");
+ return 0;
+ }
+ static constexpr size_t idx_sy() {
+ static_assert(need_sy, "sy not enabled");
+ return size_t {need_sx};
+ }
+ static constexpr size_t idx_sxx() {
+ static_assert(need_sxx, "sxx not enabled");
+ return size_t {need_sx + need_sy};
+ }
+ static constexpr size_t idx_syy() {
+ static_assert(need_syy, "syy not enabled");
+ return size_t {need_sx + need_sy + need_sxx};
+ }
+ static constexpr size_t idx_sxy() {
+ static_assert(need_sxy, "sxy not enabled");
+ return size_t {need_sx + need_sy + need_sxx + need_syy};
+ }
+
+ Float64& sx() { return moments[idx_sx()]; }
+ Float64& sy() { return moments[idx_sy()]; }
+ Float64& sxx() { return moments[idx_sxx()]; }
+ Float64& syy() { return moments[idx_syy()]; }
+ Float64& sxy() { return moments[idx_sxy()]; }
+
+ const Float64& sx() const { return moments[idx_sx()]; }
+ const Float64& sy() const { return moments[idx_sy()]; }
+ const Float64& sxx() const { return moments[idx_sxx()]; }
+ const Float64& syy() const { return moments[idx_syy()]; }
+ const Float64& sxy() const { return moments[idx_sxy()]; }
void write(BufferWritable& buf) const {
- buf.write_binary(sum_x);
- buf.write_binary(sum_y);
- buf.write_binary(sum_of_x_mul_y);
- buf.write_binary(sum_of_x_squared);
- buf.write_binary(count);
+ if constexpr (need_sx) {
+ buf.write_binary(sx());
+ }
+ if constexpr (need_sy) {
+ buf.write_binary(sy());
+ }
+ if constexpr (need_sxx) {
+ buf.write_binary(sxx());
+ }
+ if constexpr (need_syy) {
+ buf.write_binary(syy());
+ }
+ if constexpr (need_sxy) {
+ buf.write_binary(sxy());
+ }
+ buf.write_binary(n);
}
void read(BufferReadable& buf) {
- buf.read_binary(sum_x);
- buf.read_binary(sum_y);
- buf.read_binary(sum_of_x_mul_y);
- buf.read_binary(sum_of_x_squared);
- buf.read_binary(count);
+ if constexpr (need_sx) {
+ buf.read_binary(sx());
+ }
+ if constexpr (need_sy) {
+ buf.read_binary(sy());
+ }
+ if constexpr (need_sxx) {
+ buf.read_binary(sxx());
+ }
+ if constexpr (need_syy) {
+ buf.read_binary(syy());
+ }
+ if constexpr (need_sxy) {
+ buf.read_binary(sxy());
+ }
+ buf.read_binary(n);
}
void reset() {
- sum_x = {};
- sum_y = {};
- sum_of_x_mul_y = {};
- sum_of_x_squared = {};
- count = 0;
+ moments.fill({});
+ n = {};
}
+ /**
+ * The merge function uses the Youngs–Cramer algorithm:
+ * N = N1 + N2
+ * Sx = Sx1 + Sx2
+ * Sy = Sy1 + Sy2
+ * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
+ * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
+ * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
+ */
void merge(const AggregateFunctionRegrData& rhs) {
- if (rhs.count == 0) {
+ if (rhs.n == 0) {
+ return;
+ }
+ if (n == 0) {
+ *this = rhs;
return;
}
- sum_x += rhs.sum_x;
- sum_y += rhs.sum_y;
- sum_of_x_mul_y += rhs.sum_of_x_mul_y;
- sum_of_x_squared += rhs.sum_of_x_squared;
- count += rhs.count;
+ const auto n1 = static_cast<Float64>(n);
+ const auto n2 = static_cast<Float64>(rhs.n);
+ const auto nsum = n1 + n2;
+
+ Float64 dx {};
+ Float64 dy {};
+ if constexpr (need_sxx || need_sxy) {
+ dx = sx() / n1 - rhs.sx() / n2;
+ }
+ if constexpr (need_syy || need_sxy) {
+ dy = sy() / n1 - rhs.sy() / n2;
+ }
+
+ n += rhs.n;
+ if constexpr (need_sx) {
+ sx() += rhs.sx();
+ }
+ if constexpr (need_sy) {
+ sy() += rhs.sy();
+ }
+ if constexpr (need_sxx) {
+ sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum;
+ }
+ if constexpr (need_syy) {
+ syy() += rhs.syy() + n1 * n2 * dy * dy / nsum;
+ }
+ if constexpr (need_sxy) {
+ sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum;
+ }
}
+ /**
+ * N
+ * Sx = sum(X)
+ * Sy = sum(Y)
+ * Sxx = sum((X-Sx/N)^2)
+ * Syy = sum((Y-Sy/N)^2)
+ * Sxy = sum((X-Sx/N)*(Y-Sy/N))
+ */
void add(typename PrimitiveTypeTraits<T>::ColumnItemType value_y,
typename PrimitiveTypeTraits<T>::ColumnItemType value_x) {
- sum_x += (double)value_x;
- sum_y += (double)value_y;
- sum_of_x_mul_y += (double)value_x * (double)value_y;
- sum_of_x_squared += (double)value_x * (double)value_x;
- count += 1;
- }
+ const auto x = static_cast<Float64>(value_x);
+ const auto y = static_cast<Float64>(value_y);
- Float64 get_slope() const {
- Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x;
- if (count < 2 || denominator == 0.0) {
- return std::numeric_limits<Float64>::quiet_NaN();
+ if constexpr (need_sx) {
+ sx() += x;
+ }
+ if constexpr (need_sy) {
+ sy() += y;
+ }
+
+ if (n == 0) [[unlikely]] {
+ n = 1;
+ return;
+ }
+ const auto n_old = static_cast<Float64>(n);
+ const auto n_new = n_old + 1;
+ const auto scale = 1.0 / (n_new * n_old);
+ n += 1;
+
+ Float64 tmp_x {};
+ Float64 tmp_y {};
+ if constexpr (need_sxx || need_sxy) {
+ tmp_x = x * n_new - sx();
+ }
+ if constexpr (need_syy || need_sxy) {
+ tmp_y = y * n_new - sy();
+ }
+
+ if constexpr (need_sxx) {
+ sxx() += tmp_x * tmp_x * scale;
+ }
+ if constexpr (need_syy) {
+ syy() += tmp_y * tmp_y * scale;
+ }
+ if constexpr (need_sxy) {
+ sxy() += tmp_x * tmp_y * scale;
}
- Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
- return slope;
}
};
template <PrimitiveType T>
-struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
+struct RegrSlopeFunc : AggregateFunctionRegrData<T, true, 2, 1> {
static constexpr const char* name = "regr_slope";
- Float64 get_result() const { return this->get_slope(); }
+ Float64 get_result() const {
+ if (this->n < 1 || this->sxx() == 0.0) {
+ return std::numeric_limits<Float64>::quiet_NaN();
+ }
+ return this->sxy() / this->sxx();
+ }
};
template <PrimitiveType T>
-struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
+struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
static constexpr const char* name = "regr_intercept";
Float64 get_result() const {
- auto slope = this->get_slope();
- if (std::isnan(slope)) {
- return slope;
- } else {
- Float64 intercept = (this->sum_y - slope * this->sum_x) / (double)this->count;
- return intercept;
+ if (this->n < 1 || this->sxx() == 0.0) {
+ return std::numeric_limits<Float64>::quiet_NaN();
}
+ return (this->sy() - this->sx() * this->sxy() / this->sxx()) /
+ static_cast<Float64>(this->n);
}
};
@@ -147,7 +297,7 @@ public:
const XInputCol* x_nested_column = nullptr;
if constexpr (y_nullable) {
- const ColumnNullable& y_column_nullable =
+ const auto& y_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
y_null = y_column_nullable.is_null_at(row_num);
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
@@ -158,7 +308,7 @@ public:
}
if constexpr (x_nullable) {
- const ColumnNullable& x_column_nullable =
+ const auto& x_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[1]);
x_null = x_column_nullable.is_null_at(row_num);
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
index 88a91371f5f..f58aaf4a55f 100644
--- a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
+++ b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out
@@ -6,17 +6,17 @@
-990000.0
-- !regr_intercept_int --
-1000001.0
+-9.999E9
-- !regr_intercept_bigint --
\N
-- !regr_intercept_largeint --
-9.999999999999989E19
+1.0E20
-- !regr_intercept_float --
-13.241664047161644
+13.24166404716167
-- !regr_intercept_double --
-58.05515207632899
+58.05515207633332
diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
index 77140f0d1d3..0e9d13ae71d 100644
--- a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
+++ b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out
@@ -6,17 +6,17 @@
1.0
-- !regr_slope_int --
--0.0
+1.0
-- !regr_slope_bigint --
\N
-- !regr_slope_largeint --
-17725.127617654194
+0.0
-- !regr_slope_float --
--2.79289213515492
+-2.792892135154929
-- !regr_slope_double --
--0.5501239199999569
+-0.5501239199999999
diff --git a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
index f7c44642427..10683585309 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy
@@ -51,9 +51,9 @@ suite("test_regr_intercept") {
// no value
// agg function without group by should return null
- qt_sql_empty_1 "select regr_intercept(y,x) from test_regr_intercept_int"
+ qt_sql_empty_1 "select regr_intercept(y, x) from test_regr_intercept_int"
// agg function with group by should return empty set
- qt_sql_empty_2 "select regr_intercept(y,x) from test_regr_intercept_int group by id"
+ qt_sql_empty_2 "select regr_intercept(y, x) from test_regr_intercept_int group by id"
sql """ TRUNCATE TABLE test_regr_intercept_int """
@@ -83,7 +83,7 @@ suite("test_regr_intercept") {
qt_sql_int_2 "select regr_intercept(x, 4) from test_regr_intercept_int"
// int value
- qt_sql_int_3 "select regr_intercept(y,x) from test_regr_intercept_int"
+ qt_sql_int_3 "select regr_intercept(y, x) from test_regr_intercept_int"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_int_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int"
@@ -122,8 +122,8 @@ suite("test_regr_intercept") {
qt_sql_int_7 "select regr_intercept(x, 4) from test_regr_intercept_int"
// int value
- qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int"
- qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int group by id order by id"
+ qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int"
+ qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_int_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_intercept") {
qt_sql_double_2 "select regr_intercept(x, 4) from test_regr_intercept_double"
// int value
- qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double"
- qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id"
+ qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double"
+ qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_double_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double"
@@ -183,8 +183,8 @@ suite("test_regr_intercept") {
qt_sql_double_7 "select regr_intercept(x, 4) from test_regr_intercept_double"
// int value
- qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double"
- qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id"
+ qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double"
+ qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_double_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double where id >= 3"
diff --git a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
index 19397036234..0c600710367 100644
--- a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
+++ b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy
@@ -51,9 +51,9 @@ suite("test_regr_slope") {
// no value
// agg function without group by should return null
- qt_sql_empty_1 "select regr_slope(y,x) from test_regr_slope_int"
+ qt_sql_empty_1 "select regr_slope(y, x) from test_regr_slope_int"
// agg function with group by should return empty set
- qt_sql_empty_2 "select regr_slope(y,x) from test_regr_slope_int group by id"
+ qt_sql_empty_2 "select regr_slope(y, x) from test_regr_slope_int group by id"
sql """ TRUNCATE TABLE test_regr_slope_int """
@@ -83,7 +83,7 @@ suite("test_regr_slope") {
qt_sql_int_2 "select regr_slope(x, 4) from test_regr_slope_int"
// int value
- qt_sql_int_3 "select regr_slope(y,x) from test_regr_slope_int"
+ qt_sql_int_3 "select regr_slope(y, x) from test_regr_slope_int"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_int_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int"
@@ -122,8 +122,8 @@ suite("test_regr_slope") {
qt_sql_int_7 "select regr_slope(x, 4) from test_regr_slope_int"
// int value
- qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int"
- qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int group by id order by id"
+ qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int"
+ qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_int_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int where id >= 3"
@@ -142,8 +142,8 @@ suite("test_regr_slope") {
qt_sql_double_2 "select regr_slope(x, 4) from test_regr_slope_double"
// int value
- qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double"
- qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double group by id order by id"
+ qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double"
+ qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_double_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double"
@@ -183,8 +183,8 @@ suite("test_regr_slope") {
qt_sql_double_7 "select regr_slope(x, 4) from test_regr_slope_double"
// int value
- qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double"
- qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double group by id order by id"
+ qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double"
+ qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double group by id order by id"
// qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column
qt_sql_double_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double where id >= 3"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]