This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 604842526b2 [improvement](expr) improve if expr performance (#27296)
604842526b2 is described below

commit 604842526b2b1555bd2bb60d5de85d867c3e54b9
Author: TengJianPing <18241664+jackte...@users.noreply.github.com>
AuthorDate: Wed Nov 22 12:48:06 2023 +0800

    [improvement](expr) improve if expr performance (#27296)
---
 be/src/vec/core/accurate_comparison.h       |  14 +++-
 be/src/vec/functions/functions_comparison.h |  14 ++--
 be/src/vec/functions/if.cpp                 | 106 ++++++++++++++++++++++------
 3 files changed, 105 insertions(+), 29 deletions(-)

diff --git a/be/src/vec/core/accurate_comparison.h b/be/src/vec/core/accurate_comparison.h
index 71189d37d66..42dca505532 100644
--- a/be/src/vec/core/accurate_comparison.h
+++ b/be/src/vec/core/accurate_comparison.h
@@ -224,7 +224,7 @@ bool greaterOp(A a, B b) {
 }
 
 template <typename A, typename B>
-bool greaterOrEqualsOp(A a, B b) {
+inline bool_if_not_safe_conversion<A, B> greaterOrEqualsOp(A a, B b) {
     if (is_nan(a) || is_nan(b)) {
         return false;
     }
@@ -233,7 +233,12 @@ bool greaterOrEqualsOp(A a, B b) {
 }
 
 template <typename A, typename B>
-bool lessOrEqualsOp(A a, B b) {
+inline bool_if_safe_conversion<A, B> greaterOrEqualsOp(A a, B b) {
+    return a >= b;
+}
+
+template <typename A, typename B>
+inline bool_if_not_safe_conversion<A, B> lessOrEqualsOp(A a, B b) {
     if (is_nan(a) || is_nan(b)) {
         return false;
     }
@@ -241,6 +246,11 @@ bool lessOrEqualsOp(A a, B b) {
     return !lessOp(b, a);
 }
 
+template <typename A, typename B>
+inline bool_if_safe_conversion<A, B> lessOrEqualsOp(A a, B b) {
+    return a <= b;
+}
+
 template <typename A, typename B>
 bool equalsOp(A a, B b) {
     if constexpr (std::is_same_v<A, B>) {
diff --git a/be/src/vec/functions/functions_comparison.h b/be/src/vec/functions/functions_comparison.h
index 598bf01bfdd..6a0f3aa634d 100644
--- a/be/src/vec/functions/functions_comparison.h
+++ b/be/src/vec/functions/functions_comparison.h
@@ -61,10 +61,10 @@ struct NumComparisonImpl {
     static void NO_INLINE vector_vector(const PaddedPODArray<A>& a, const PaddedPODArray<B>& b,
                                         PaddedPODArray<UInt8>& c) {
         size_t size = a.size();
-        const A* a_pos = a.data();
-        const B* b_pos = b.data();
-        UInt8* c_pos = c.data();
-        const A* a_end = a_pos + size;
+        const A* __restrict a_pos = a.data();
+        const B* __restrict b_pos = b.data();
+        UInt8* __restrict c_pos = c.data();
+        const A* __restrict a_end = a_pos + size;
 
         while (a_pos < a_end) {
             *c_pos = Op::apply(*a_pos, *b_pos);
@@ -77,9 +77,9 @@ struct NumComparisonImpl {
     static void NO_INLINE vector_constant(const PaddedPODArray<A>& a, B b,
                                           PaddedPODArray<UInt8>& c) {
         size_t size = a.size();
-        const A* a_pos = a.data();
-        UInt8* c_pos = c.data();
-        const A* a_end = a_pos + size;
+        const A* __restrict a_pos = a.data();
+        UInt8* __restrict c_pos = c.data();
+        const A* __restrict a_end = a_pos + size;
 
         while (a_pos < a_end) {
             *c_pos = Op::apply(*a_pos, b);
diff --git a/be/src/vec/functions/if.cpp b/be/src/vec/functions/if.cpp
index 9b14abce2ae..8f36e56d509 100644
--- a/be/src/vec/functions/if.cpp
+++ b/be/src/vec/functions/if.cpp
@@ -29,6 +29,7 @@
 #include <utility>
 
 #include "common/status.h"
+#include "util/simd/bits.h"
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/columns/column.h"
 #include "vec/columns/column_const.h"
@@ -49,7 +50,6 @@
 #include "vec/functions/function.h"
 #include "vec/functions/function_helpers.h"
 #include "vec/functions/simple_function_factory.h"
-
 namespace doris {
 class FunctionContext;
 
@@ -140,6 +140,40 @@ public:
     }
 };
 
+size_t count_true_with_notnull(const ColumnPtr& col) {
+    if (col->only_null()) {
+        return 0;
+    }
+
+    if (const auto* const_col = check_and_get_column_const<ColumnVector<UInt8>>(col.get())) {
+        bool is_true = const_col->get_bool(0);
+        return is_true ? col->size() : 0;
+    }
+
+    auto count = col->size();
+    if (col->is_nullable()) {
+        const auto* nullable = assert_cast<const ColumnNullable*>(col.get());
+        const auto* __restrict null_data = nullable->get_null_map_data().data();
+        const auto* __restrict bool_data =
+                ((const ColumnVector<UInt8>&)(nullable->get_nested_column())).get_data().data();
+
+        size_t null_count = count - simd::count_zero_num((const int8_t*)null_data, count);
+
+        if (null_count == count) {
+            return 0;
+        } else if (null_count == 0) {
+            size_t true_count = count - simd::count_zero_num((const int8_t*)bool_data, count);
+            return true_count;
+        } else {
+            // In fact, the null_count maybe is different with true_count, but it's no impact
+            return null_count;
+        }
+    } else {
+        const auto* bool_col = typeid_cast<const ColumnUInt8*>(col.get());
+        const auto* __restrict bool_data = bool_col->get_data().data();
+        return count - simd::count_zero_num((const int8_t*)bool_data, count);
+    }
+}
 // todo(wb) support llvm codegen
 class FunctionIf : public IFunction {
 public:
@@ -306,7 +340,7 @@ public:
             return true;
         }
 
-        const ColumnUInt8* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get());
+        const auto* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get());
         const ColumnConst* cond_const_col =
                 check_and_get_column_const<ColumnVector<UInt8>>(arg_cond.column.get());
 
@@ -347,15 +381,6 @@ public:
         if (else_is_null) {
             if (cond_col) {
                 size_t size = input_rows_count;
-                auto& null_map_data = cond_col->get_data();
-
-                auto negated_null_map = ColumnUInt8::create();
-                auto& negated_null_map_data = negated_null_map->get_data();
-                negated_null_map_data.resize(size);
-
-                for (size_t i = 0; i < size; ++i) {
-                    negated_null_map_data[i] = !null_map_data[i];
-                }
 
                 if (is_column_nullable(*arg_then.column)) { // if(cond, nullable, NULL)
                     auto arg_then_column = arg_then.column;
@@ -365,6 +390,15 @@ public:
                                     assert_cast<const ColumnUInt8&>(*arg_cond.column));
                     block.replace_by_position(result, std::move(result_column));
                 } else { // if(cond, not_nullable, NULL)
+                    const auto& null_map_data = cond_col->get_data();
+                    auto negated_null_map = ColumnUInt8::create();
+                    auto& negated_null_map_data = negated_null_map->get_data();
+                    negated_null_map_data.resize(size);
+
+                    for (size_t i = 0; i < size; ++i) {
+                        negated_null_map_data[i] = !null_map_data[i];
+                    }
+
                     block.replace_by_position(
                             result,
                             ColumnNullable::create(materialize_column_if_const(arg_then.column),
@@ -449,8 +483,8 @@ public:
             temporary_block.insert(
                     {nullptr, std::make_shared<DataTypeUInt8>(), "result_column_null_map"});
 
-            static_cast<void>(
-                    execute_impl(context, temporary_block, {0, 1, 2}, 3, temporary_block.rows()));
+            static_cast<void>(_execute_impl_internal(context, temporary_block, {0, 1, 2}, 3,
+                                                     temporary_block.rows()));
 
             result_null_mask = temporary_block.get_by_position(3).column;
         }
@@ -464,8 +498,8 @@ public:
                      {get_nested_column(arg_else.column), remove_nullable(arg_else.type), ""},
                      {nullptr, remove_nullable(block.get_by_position(result).type), ""}});
 
-            static_cast<void>(
-                    execute_impl(context, temporary_block, {0, 1, 2}, 3, temporary_block.rows()));
+            static_cast<void>(_execute_impl_internal(context, temporary_block, {0, 1, 2}, 3,
+                                                     temporary_block.rows()));
 
             result_nested_column = temporary_block.get_by_position(3).column;
         }
@@ -489,22 +523,22 @@ public:
             return true;
         }
 
-        if (auto* nullable = check_and_get_column<ColumnNullable>(*arg_cond.column)) {
+        if (const auto* nullable = check_and_get_column<ColumnNullable>(*arg_cond.column)) {
             DCHECK(remove_nullable(arg_cond.type)->get_type_id() == TypeIndex::UInt8);
 
             // update neseted column by nullmap
-            auto* __restrict null_map = nullable->get_null_map_data().data();
+            const auto* __restrict null_map = nullable->get_null_map_data().data();
             auto* __restrict nested_bool_data =
                     ((ColumnVector<UInt8>&)(nullable->get_nested_column())).get_data().data();
             auto rows = nullable->size();
             for (size_t i = 0; i < rows; i++) {
-                nested_bool_data[i] = null_map[i] ? 0 : nested_bool_data[i];
+                nested_bool_data[i] &= !null_map[i];
             }
             auto column_size = block.columns();
             block.insert({nullable->get_nested_column_ptr(), remove_nullable(arg_cond.type),
                           arg_cond.name});
 
-            static_cast<void>(execute_impl(
+            static_cast<void>(_execute_impl_internal(
                     context, block, {column_size, arguments[1], arguments[2]}, result, rows));
             return true;
         }
@@ -527,6 +561,38 @@ public:
         cond_column.column = materialize_column_if_const(cond_column.column);
         const ColumnWithTypeAndName& arg_cond = block.get_by_position(arguments[0]);
 
+        auto true_count = count_true_with_notnull(arg_cond.column);
+        auto item_count = arg_cond.column->size();
+        if (true_count == item_count || true_count == 0) {
+            bool result_nullable = block.get_by_position(result).type->is_nullable();
+            if (true_count == item_count) {
+                block.replace_by_position(
+                        result,
+                        result_nullable
+                                ? make_nullable(arg_then.column->clone_resized(input_rows_count))
+                                : arg_then.column->clone_resized(input_rows_count));
+            } else {
+                block.replace_by_position(
+                        result,
+                        result_nullable
+                                ? make_nullable(arg_else.column->clone_resized(input_rows_count))
+                                : arg_else.column->clone_resized(input_rows_count));
+            }
+            return Status::OK();
+        }
+
+        return _execute_impl_internal(context, block, arguments, result, input_rows_count);
+    }
+
+    Status _execute_impl_internal(FunctionContext* context, Block& block,
+                                  const ColumnNumbers& arguments, size_t result,
+                                  size_t input_rows_count) const {
+        const ColumnWithTypeAndName& arg_then = block.get_by_position(arguments[1]);
+        const ColumnWithTypeAndName& arg_else = block.get_by_position(arguments[2]);
+        ColumnWithTypeAndName& cond_column = block.get_by_position(arguments[0]);
+        cond_column.column = materialize_column_if_const(cond_column.column);
+        const ColumnWithTypeAndName& arg_cond = block.get_by_position(arguments[0]);
+
         Status ret = Status::OK();
         if (execute_for_null_condition(context, block, arguments, arg_cond, arg_then, arg_else,
                                        result) ||
@@ -537,7 +603,7 @@ public:
             return ret;
         }
 
-        const ColumnUInt8* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get());
+        const auto* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get());
         const ColumnConst* cond_const_col =
                 check_and_get_column_const<ColumnVector<UInt8>>(arg_cond.column.get());
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to