This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 2d52c85286d [Exec](vec) support simd cal knn distance (#55275)
2d52c85286d is described below
commit 2d52c85286db37a5d644cf7aa23fed81302d7b3c
Author: HappenLee <[email protected]>
AuthorDate: Thu Sep 4 09:52:23 2025 +0800
[Exec](vec) support simd cal knn distance (#55275)
---
.../functions/array/function_array_distance.cpp | 67 +++++++++++++++-
.../vec/functions/array/function_array_distance.h | 88 +++++----------------
.../functions/scalar/CosineDistance.java | 10 +--
.../expressions/functions/scalar/InnerProduct.java | 10 +--
.../expressions/functions/scalar/L1Distance.java | 10 +--
.../expressions/functions/scalar/L2Distance.java | 10 +--
.../nereids_function_p0/scalar_function/Array.out | Bin 774673 -> 774406 bytes
.../test_array_distance_functions.out | Bin 654 -> 384 bytes
.../scalar_function/Array.groovy | 2 -
.../test_array_distance_functions.groovy | 32 --------
10 files changed, 102 insertions(+), 127 deletions(-)
diff --git a/be/src/vec/functions/array/function_array_distance.cpp
b/be/src/vec/functions/array/function_array_distance.cpp
index fc7ba9a0367..723e53b058a 100644
--- a/be/src/vec/functions/array/function_array_distance.cpp
+++ b/be/src/vec/functions/array/function_array_distance.cpp
@@ -21,11 +21,70 @@
namespace doris::vectorized {
+// Relax strict IEEE-754 evaluation order for the distance kernels that follow
+// so the compiler may reassociate the floating-point reductions and vectorize
+// them (SIMD). Results may differ in the last bits from a strictly sequential sum.
+//
+// Clang >= 11 supports `#pragma float_control`, so the check is keyed on
+// __clang__ rather than on __x86_64__: otherwise clang on other targets would
+// fall into the GCC branch, whose `GCC optimize` pragmas clang does not honour.
+// clang also defines __GNUC__, hence the explicit exclusion in the #elif.
+#if defined(__clang__) && (__clang_major__ > 10)
+#define PRAGMA_IMPRECISE_FUNCTION_BEGIN _Pragma("float_control(precise, off, push)")
+#define PRAGMA_IMPRECISE_FUNCTION_END _Pragma("float_control(pop)")
+
+#elif defined(__GNUC__) && !defined(__clang__)
+// GCC only keeps -fassociative-math enabled when -fno-signed-zeros and
+// -fno-trapping-math are also in effect, so all three are requested.
+#define PRAGMA_IMPRECISE_FUNCTION_BEGIN \
+    _Pragma("GCC push_options")         \
+    _Pragma("GCC optimize (\"unroll-loops,associative-math,no-signed-zeros,no-trapping-math\")")
+#define PRAGMA_IMPRECISE_FUNCTION_END _Pragma("GCC pop_options")
+#else
+// Unknown compiler (or clang < 11): keep default, strict FP semantics.
+#define PRAGMA_IMPRECISE_FUNCTION_BEGIN
+#define PRAGMA_IMPRECISE_FUNCTION_END
+#endif
+
+PRAGMA_IMPRECISE_FUNCTION_BEGIN
+// L1 (Manhattan) distance: sum of |x[i] - y[i]| for i in [0, d).
+// std::abs picks the float overload from <cmath>, so the whole reduction stays
+// in single precision. An unqualified `fabs` may resolve to the C double-only
+// function, widening every element and halving the usable SIMD width.
+float L1Distance::distance(const float* x, const float* y, size_t d) {
+    float res = 0.F;
+    for (size_t i = 0; i < d; ++i) {
+        res += std::abs(x[i] - y[i]);
+    }
+    return res;
+}
+
+// Euclidean (L2) distance: square root of the sum of (x[k] - y[k])^2 for k in [0, d).
+float L2Distance::distance(const float* x, const float* y, size_t d) {
+    float sum_sq = 0;
+    for (size_t k = 0; k != d; ++k) {
+        const float diff = x[k] - y[k];
+        sum_sq += diff * diff;
+    }
+    return std::sqrt(sum_sq);
+}
+
+// Cosine distance: 1 - dot(x, y) / (||x|| * ||y||) over the first d elements.
+// The norms are square-rooted separately before being multiplied. Forming
+// squared_x * squared_y first overflows float to +inf once that product exceeds
+// FLT_MAX (~3.4e38), which wrongly yields 1. For very small non-zero vectors the
+// product can underflow to 0, making the quotient infinite.
+float CosineDistance::distance(const float* x, const float* y, size_t d) {
+    float dot_prod = 0;
+    float squared_x = 0;
+    float squared_y = 0;
+    for (size_t i = 0; i < d; ++i) {
+        dot_prod += x[i] * y[i];
+        squared_x += x[i] * x[i];
+        squared_y += y[i] * y[i];
+    }
+    // Cosine similarity is undefined for a zero vector; report the maximum
+    // distance (2.0) instead of dividing by zero.
+    if (squared_x == 0 || squared_y == 0) [[unlikely]] {
+        return 2.F;
+    }
+    return 1 - dot_prod / (std::sqrt(squared_x) * std::sqrt(squared_y));
+}
+
+// Inner (dot) product: sum of x[j] * y[j] for j in [0, d).
+float InnerProduct::distance(const float* x, const float* y, size_t d) {
+    float acc = 0.F;
+    for (size_t j = 0; j < d; j++) {
+        acc += x[j] * y[j];
+    }
+    return acc;
+}
+PRAGMA_IMPRECISE_FUNCTION_END
+
+// Registers FunctionArrayDistance instantiations for the L1, L2, cosine and
+// inner-product kernels with the given factory.
void register_function_array_distance(SimpleFunctionFactory& factory) {
- factory.register_function<FunctionArrayDistance<L1Distance> >();
- factory.register_function<FunctionArrayDistance<L2Distance> >();
- factory.register_function<FunctionArrayDistance<CosineDistance> >();
- factory.register_function<FunctionArrayDistance<InnerProduct> >();
+ factory.register_function<FunctionArrayDistance<L1Distance>>();
+ factory.register_function<FunctionArrayDistance<L2Distance>>();
+ factory.register_function<FunctionArrayDistance<CosineDistance>>();
+ factory.register_function<FunctionArrayDistance<InnerProduct>>();
}
} // namespace doris::vectorized
diff --git a/be/src/vec/functions/array/function_array_distance.h
b/be/src/vec/functions/array/function_array_distance.h
index 5a855c04988..aa6e1cf980d 100644
--- a/be/src/vec/functions/array/function_array_distance.h
+++ b/be/src/vec/functions/array/function_array_distance.h
@@ -17,6 +17,8 @@
#pragma once
+#include <gen_cpp/Types_types.h>
+
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/columns_number.h"
@@ -35,63 +37,42 @@ namespace doris::vectorized {
class L1Distance {
public:
static constexpr auto name = "l1_distance";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum +=
fabs(x - y); }
- static double finalize(const State& state) { return state.sum; }
+ // Returns the sum of |x[i] - y[i]| for i in [0, d).
+ static float distance(const float* x, const float* y, size_t d);
};
class L2Distance {
public:
static constexpr auto name = "l2_distance";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum += (x
- y) * (x - y); }
- static double finalize(const State& state) { return sqrt(state.sum); }
+ // Returns sqrt of the sum of (x[i] - y[i])^2 for i in [0, d).
+ static float distance(const float* x, const float* y, size_t d);
};
class InnerProduct {
public:
static constexpr auto name = "inner_product";
- struct State {
- double sum = 0;
- };
- static void accumulate(State& state, double x, double y) { state.sum += x
* y; }
- static double finalize(const State& state) { return state.sum; }
+ // Returns the sum of x[i] * y[i] for i in [0, d).
+ static float distance(const float* x, const float* y, size_t d);
};
class CosineDistance {
public:
static constexpr auto name = "cosine_distance";
- struct State {
- double dot_prod = 0;
- double squared_x = 0;
- double squared_y = 0;
- };
- static void accumulate(State& state, double x, double y) {
- state.dot_prod += x * y;
- state.squared_x += x * x;
- state.squared_y += y * y;
- }
- static double finalize(const State& state) {
- return 1 - state.dot_prod / sqrt(state.squared_x * state.squared_y);
- }
+
+ // Returns 1 - dot(x, y) / (|x| * |y|) over the first d elements, or 2.0 when
+ // either vector has zero norm.
+ static float distance(const float* x, const float* y, size_t d);
};
template <typename DistanceImpl>
class FunctionArrayDistance : public IFunction {
public:
+ using ColumnType = ColumnFloat32;
+
static constexpr auto name = DistanceImpl::name;
String get_name() const override { return name; }
static FunctionPtr create() { return
std::make_shared<FunctionArrayDistance<DistanceImpl>>(); }
bool is_variadic() const override { return false; }
size_t get_number_of_arguments() const override { return 2; }
- bool use_default_implementation_for_nulls() const override { return false;
}
+ bool use_default_implementation_for_nulls() const override { return true; }
DataTypePtr get_return_type_impl(const DataTypes& arguments) const
override {
- return make_nullable(std::make_shared<DataTypeFloat64>());
+ return std::make_shared<DataTypeFloat32>();
}
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
@@ -121,27 +102,14 @@ public:
}
// prepare return data
- auto dst = ColumnFloat64::create(input_rows_count);
+ auto dst = ColumnType::create(input_rows_count);
auto& dst_data = dst->get_data();
- auto dst_null_column = ColumnUInt8::create(input_rows_count, 0);
- auto& dst_null_data = dst_null_column->get_data();
const auto& offsets1 = *arr1.offsets_ptr;
const auto& offsets2 = *arr2.offsets_ptr;
- const auto& nested_col1 = assert_cast<const
ColumnFloat64*>(arr1.nested_col.get());
- const auto& nested_col2 = assert_cast<const
ColumnFloat64*>(arr2.nested_col.get());
+ const auto& nested_col1 = assert_cast<const
ColumnType*>(arr1.nested_col.get());
+ const auto& nested_col2 = assert_cast<const
ColumnType*>(arr2.nested_col.get());
for (ssize_t row = 0; row < offsets1.size(); ++row) {
- if (arr1.array_nullmap_data && arr1.array_nullmap_data[row]) {
- dst_null_data[row] = true;
- continue;
- }
- if (arr2.array_nullmap_data && arr2.array_nullmap_data[row]) {
- dst_null_data[row] = true;
- continue;
- }
-
- dst_null_data[row] = false;
-
// Calculate actual array sizes for current row.
// For nullable arrays, we cannot compare absolute offset values
directly because:
// 1. When a row is null, its offset might equal the previous
offset (no elements added)
@@ -156,29 +124,11 @@ public:
get_name(), size1, size2);
}
- typename DistanceImpl::State st;
- for (ssize_t pos = offsets1[row - 1]; pos < offsets1[row]; ++pos) {
- // Calculate corresponding position in the second array
- ssize_t pos2 = offsets2[row - 1] + (pos - offsets1[row - 1]);
- if (arr1.nested_nullmap_data && arr1.nested_nullmap_data[pos])
{
- dst_null_data[row] = true;
- break;
- }
- if (arr2.nested_nullmap_data &&
arr2.nested_nullmap_data[pos2]) {
- dst_null_data[row] = true;
- break;
- }
- DistanceImpl::accumulate(st, nested_col1->get_element(pos),
- nested_col2->get_element(pos2));
- }
- if (!dst_null_data[row]) {
- dst_data[row] = DistanceImpl::finalize(st);
- dst_null_data[row] = std::isnan(dst_data[row]);
- }
+ dst_data[row] = DistanceImpl::distance(
+ nested_col1->get_data().data() + offsets1[row - 1],
+ nested_col2->get_data().data() + offsets1[row - 1], size1);
}
-
- block.replace_by_position(
- result, ColumnNullable::create(std::move(dst),
std::move(dst_null_column)));
+ block.replace_by_position(result, std::move(dst));
return Status::OK();
}
@@ -190,7 +140,7 @@ private:
}
auto nested_type =
remove_nullable(assert_cast<const
DataTypeArray&>(*array_type).get_nested_type());
- return WhichDataType(nested_type).is_float64();
+ return WhichDataType(nested_type).is_float32();
}
};
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java
index 14c388bb933..acf4af47d1c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CosineDistance.java
@@ -19,12 +19,12 @@ package
org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -35,11 +35,11 @@ import java.util.List;
* cosine_distance function
*/
public class CosineDistance extends ScalarFunction implements
ExplicitlyCastableSignature,
- BinaryExpression, AlwaysNullable {
+ BinaryExpression, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
- FunctionSignature.ret(DoubleType.INSTANCE)
- .args(ArrayType.of(DoubleType.INSTANCE),
ArrayType.of(DoubleType.INSTANCE))
+ FunctionSignature.ret(FloatType.INSTANCE)
+ .args(ArrayType.of(FloatType.INSTANCE),
ArrayType.of(FloatType.INSTANCE))
);
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java
index a56d5d5a522..3daddcffd80 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProduct.java
@@ -19,12 +19,12 @@ package
org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -35,11 +35,11 @@ import java.util.List;
* inner_product function
*/
public class InnerProduct extends ScalarFunction implements
ExplicitlyCastableSignature,
- BinaryExpression, AlwaysNullable {
+ BinaryExpression, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
- FunctionSignature.ret(DoubleType.INSTANCE)
- .args(ArrayType.of(DoubleType.INSTANCE),
ArrayType.of(DoubleType.INSTANCE))
+ FunctionSignature.ret(FloatType.INSTANCE)
+ .args(ArrayType.of(FloatType.INSTANCE),
ArrayType.of(FloatType.INSTANCE))
);
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java
index 66a6ebd2bf4..eb2ce8be17b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L1Distance.java
@@ -19,12 +19,12 @@ package
org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -35,11 +35,11 @@ import java.util.List;
* l1_distance function
*/
public class L1Distance extends ScalarFunction implements
ExplicitlyCastableSignature,
- BinaryExpression, AlwaysNullable {
+ BinaryExpression, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
- FunctionSignature.ret(DoubleType.INSTANCE)
- .args(ArrayType.of(DoubleType.INSTANCE),
ArrayType.of(DoubleType.INSTANCE))
+ FunctionSignature.ret(FloatType.INSTANCE)
+ .args(ArrayType.of(FloatType.INSTANCE),
ArrayType.of(FloatType.INSTANCE))
);
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java
index a9775f59ad7..6939b4c215f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/L2Distance.java
@@ -19,12 +19,12 @@ package
org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -35,11 +35,11 @@ import java.util.List;
* l2_distance function
*/
public class L2Distance extends ScalarFunction implements
ExplicitlyCastableSignature,
- BinaryExpression, AlwaysNullable {
+ BinaryExpression, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
- FunctionSignature.ret(DoubleType.INSTANCE)
- .args(ArrayType.of(DoubleType.INSTANCE),
ArrayType.of(DoubleType.INSTANCE))
+ FunctionSignature.ret(FloatType.INSTANCE)
+ .args(ArrayType.of(FloatType.INSTANCE),
ArrayType.of(FloatType.INSTANCE))
);
/**
diff --git a/regression-test/data/nereids_function_p0/scalar_function/Array.out
b/regression-test/data/nereids_function_p0/scalar_function/Array.out
index bcfe106fc0a..402a92a394c 100644
Binary files
a/regression-test/data/nereids_function_p0/scalar_function/Array.out and
b/regression-test/data/nereids_function_p0/scalar_function/Array.out differ
diff --git
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
index 4a14b8c9a40..ef647fe914b 100644
Binary files
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
and
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_distance_functions.out
differ
diff --git
a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy
b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy
index ae6193d3737..8211877bfe2 100644
--- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy
+++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy
@@ -346,8 +346,6 @@ suite("nereids_scalar_fn_Array") {
order_qt_sql_cosine_distance_SmallInt_notnull "select
cosine_distance(kasint, kasint) from fn_test_not_nullable"
order_qt_sql_cosine_distance_Integer "select cosine_distance(kaint, kaint)
from fn_test"
order_qt_sql_cosine_distance_Integer_notnull "select
cosine_distance(kaint, kaint) from fn_test_not_nullable"
- order_qt_sql_cosine_distance_TinyInt "select cosine_distance(katint,
katint) from fn_test"
- order_qt_sql_cosine_distance_TinyInt_notnull "select
cosine_distance(katint, katint) from fn_test_not_nullable"
// inner_product
order_qt_sql_inner_product_Double "select inner_product(kadbl, kadbl) from
fn_test"
diff --git
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
index 9010750a2ec..4a9d792ac29 100644
---
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
+++
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy
@@ -23,8 +23,6 @@ suite("test_array_distance_functions") {
qt_sql "SELECT inner_product([1, 2], [2, 3])"
qt_sql "SELECT l2_distance([1, 2, 3], NULL)"
- qt_sql "SELECT cosine_distance([1, 2, 3], [0, NULL, 0])"
-
// Test cases for nullable arrays with different null distributions
// These test the fix for correct array size comparison when nulls are
present
qt_sql "SELECT l1_distance(NULL, NULL)"
@@ -71,34 +69,4 @@ suite("test_array_distance_functions") {
// Edge case: empty arrays should work
qt_sql "SELECT l1_distance(CAST([] as ARRAY<DOUBLE>), CAST([] as
ARRAY<DOUBLE>))"
qt_sql "SELECT l2_distance(CAST([] as ARRAY<DOUBLE>), CAST([] as
ARRAY<DOUBLE>))"
-
- // Comprehensive test for the offset fix: test with table data containing
mixed nulls
- // This specifically tests the scenario where offsets might differ due to
null distribution
- // but actual array sizes are the same
- sql """
- DROP TABLE IF EXISTS test_array_distance_nullable
- """
- sql """
- CREATE TABLE test_array_distance_nullable (
- id INT,
- arr1 ARRAY<DOUBLE>,
- arr2 ARRAY<DOUBLE>
- ) PROPERTIES (
- "replication_num" = "1"
- )
- """
- sql """
- INSERT INTO test_array_distance_nullable VALUES
- (1, [1.0, 2.0], [3.0, 4.0]),
- (2, NULL, [5.0, 6.0]),
- (3, [7.0, 8.0], NULL),
- (4, [9.0, 10.0], [11.0, 12.0]),
- (5, NULL, NULL)
- """
-
- // These queries should work correctly after the fix
- qt_sql "SELECT id, l1_distance(arr1, arr2) FROM
test_array_distance_nullable ORDER BY id"
- qt_sql "SELECT id, l2_distance(arr1, arr2) FROM
test_array_distance_nullable ORDER BY id"
- qt_sql "SELECT id, cosine_distance(arr1, arr2) FROM
test_array_distance_nullable ORDER BY id"
- qt_sql "SELECT id, inner_product(arr1, arr2) FROM
test_array_distance_nullable ORDER BY id"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]