This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 604842526b2 [improvement](expr) improve if expr performance (#27296) 604842526b2 is described below commit 604842526b2b1555bd2bb60d5de85d867c3e54b9 Author: TengJianPing <18241664+jackte...@users.noreply.github.com> AuthorDate: Wed Nov 22 12:48:06 2023 +0800 [improvement](expr) improve if expr performance (#27296) --- be/src/vec/core/accurate_comparison.h | 14 +++- be/src/vec/functions/functions_comparison.h | 14 ++-- be/src/vec/functions/if.cpp | 106 ++++++++++++++++++++++------ 3 files changed, 105 insertions(+), 29 deletions(-) diff --git a/be/src/vec/core/accurate_comparison.h b/be/src/vec/core/accurate_comparison.h index 71189d37d66..42dca505532 100644 --- a/be/src/vec/core/accurate_comparison.h +++ b/be/src/vec/core/accurate_comparison.h @@ -224,7 +224,7 @@ bool greaterOp(A a, B b) { } template <typename A, typename B> -bool greaterOrEqualsOp(A a, B b) { +inline bool_if_not_safe_conversion<A, B> greaterOrEqualsOp(A a, B b) { if (is_nan(a) || is_nan(b)) { return false; } @@ -233,7 +233,12 @@ bool greaterOrEqualsOp(A a, B b) { } template <typename A, typename B> -bool lessOrEqualsOp(A a, B b) { +inline bool_if_safe_conversion<A, B> greaterOrEqualsOp(A a, B b) { + return a >= b; +} + +template <typename A, typename B> +inline bool_if_not_safe_conversion<A, B> lessOrEqualsOp(A a, B b) { if (is_nan(a) || is_nan(b)) { return false; } @@ -241,6 +246,11 @@ bool lessOrEqualsOp(A a, B b) { return !lessOp(b, a); } +template <typename A, typename B> +inline bool_if_safe_conversion<A, B> lessOrEqualsOp(A a, B b) { + return a <= b; +} + template <typename A, typename B> bool equalsOp(A a, B b) { if constexpr (std::is_same_v<A, B>) { diff --git a/be/src/vec/functions/functions_comparison.h b/be/src/vec/functions/functions_comparison.h index 598bf01bfdd..6a0f3aa634d 100644 --- a/be/src/vec/functions/functions_comparison.h +++ b/be/src/vec/functions/functions_comparison.h @@ -61,10 +61,10 @@ struct NumComparisonImpl { static void NO_INLINE vector_vector(const PaddedPODArray<A>& a, const PaddedPODArray<B>& b, PaddedPODArray<UInt8>& c) { size_t size = a.size(); - const A* a_pos = a.data(); - const B* b_pos = b.data(); - UInt8* c_pos = c.data(); - const A* a_end = a_pos + size; + const A* __restrict a_pos = a.data(); + const B* __restrict b_pos = b.data(); + UInt8* __restrict c_pos = c.data(); + const A* __restrict a_end = a_pos + size; while (a_pos < a_end) { *c_pos = Op::apply(*a_pos, *b_pos); @@ -77,9 +77,9 @@ struct NumComparisonImpl { static void NO_INLINE vector_constant(const PaddedPODArray<A>& a, B b, PaddedPODArray<UInt8>& c) { size_t size = a.size(); - const A* a_pos = a.data(); - UInt8* c_pos = c.data(); - const A* a_end = a_pos + size; + const A* __restrict a_pos = a.data(); + UInt8* __restrict c_pos = c.data(); + const A* __restrict a_end = a_pos + size; while (a_pos < a_end) { *c_pos = Op::apply(*a_pos, b); diff --git a/be/src/vec/functions/if.cpp b/be/src/vec/functions/if.cpp index 9b14abce2ae..8f36e56d509 100644 --- a/be/src/vec/functions/if.cpp +++ b/be/src/vec/functions/if.cpp @@ -29,6 +29,7 @@ #include <utility> #include "common/status.h" +#include "util/simd/bits.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_const.h" @@ -49,7 +50,6 @@ #include "vec/functions/function.h" #include "vec/functions/function_helpers.h" #include "vec/functions/simple_function_factory.h" - namespace doris { class FunctionContext; @@ -140,6 +140,40 @@ public: } }; +size_t count_true_with_notnull(const ColumnPtr& col) { + if (col->only_null()) { + return 0; + } + + if (const auto* const_col = check_and_get_column_const<ColumnVector<UInt8>>(col.get())) { + bool is_true = const_col->get_bool(0); + return is_true ? col->size() : 0; + } + + auto count = col->size(); + if (col->is_nullable()) { + const auto* nullable = assert_cast<const ColumnNullable*>(col.get()); + const auto* __restrict null_data = nullable->get_null_map_data().data(); + const auto* __restrict bool_data = + ((const ColumnVector<UInt8>&)(nullable->get_nested_column())).get_data().data(); + + size_t null_count = count - simd::count_zero_num((const int8_t*)null_data, count); + + if (null_count == count) { + return 0; + } else if (null_count == 0) { + size_t true_count = count - simd::count_zero_num((const int8_t*)bool_data, count); + return true_count; + } else { + // In fact, the null_count maybe is different with true_count, but it's no impact + return null_count; + } + } else { + const auto* bool_col = typeid_cast<const ColumnUInt8*>(col.get()); + const auto* __restrict bool_data = bool_col->get_data().data(); + return count - simd::count_zero_num((const int8_t*)bool_data, count); + } +} // todo(wb) support llvm codegen class FunctionIf : public IFunction { public: @@ -306,7 +340,7 @@ public: return true; } - const ColumnUInt8* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get()); + const auto* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get()); const ColumnConst* cond_const_col = check_and_get_column_const<ColumnVector<UInt8>>(arg_cond.column.get()); @@ -347,15 +381,6 @@ public: if (else_is_null) { if (cond_col) { size_t size = input_rows_count; - auto& null_map_data = cond_col->get_data(); - - auto negated_null_map = ColumnUInt8::create(); - auto& negated_null_map_data = negated_null_map->get_data(); - negated_null_map_data.resize(size); - - for (size_t i = 0; i < size; ++i) { - negated_null_map_data[i] = !null_map_data[i]; - } if (is_column_nullable(*arg_then.column)) { // if(cond, nullable, NULL) auto arg_then_column = arg_then.column; @@ -365,6 +390,15 @@ public: assert_cast<const ColumnUInt8&>(*arg_cond.column)); block.replace_by_position(result, std::move(result_column)); } else { // if(cond, not_nullable, NULL) + const auto& null_map_data = cond_col->get_data(); + auto negated_null_map = ColumnUInt8::create(); + auto& negated_null_map_data = negated_null_map->get_data(); + negated_null_map_data.resize(size); + + for (size_t i = 0; i < size; ++i) { + negated_null_map_data[i] = !null_map_data[i]; + } + block.replace_by_position( result, ColumnNullable::create(materialize_column_if_const(arg_then.column), @@ -449,8 +483,8 @@ public: temporary_block.insert( {nullptr, std::make_shared<DataTypeUInt8>(), "result_column_null_map"}); - static_cast<void>( - execute_impl(context, temporary_block, {0, 1, 2}, 3, temporary_block.rows())); + static_cast<void>(_execute_impl_internal(context, temporary_block, {0, 1, 2}, 3, + temporary_block.rows())); result_null_mask = temporary_block.get_by_position(3).column; } @@ -464,8 +498,8 @@ public: {get_nested_column(arg_else.column), remove_nullable(arg_else.type), ""}, {nullptr, remove_nullable(block.get_by_position(result).type), ""}}); - static_cast<void>( - execute_impl(context, temporary_block, {0, 1, 2}, 3, temporary_block.rows())); + static_cast<void>(_execute_impl_internal(context, temporary_block, {0, 1, 2}, 3, + temporary_block.rows())); result_nested_column = temporary_block.get_by_position(3).column; } @@ -489,22 +523,22 @@ public: return true; } - if (auto* nullable = check_and_get_column<ColumnNullable>(*arg_cond.column)) { + if (const auto* nullable = check_and_get_column<ColumnNullable>(*arg_cond.column)) { DCHECK(remove_nullable(arg_cond.type)->get_type_id() == TypeIndex::UInt8); // update neseted column by nullmap - auto* __restrict null_map = nullable->get_null_map_data().data(); + const auto* __restrict null_map = nullable->get_null_map_data().data(); auto* __restrict nested_bool_data = ((ColumnVector<UInt8>&)(nullable->get_nested_column())).get_data().data(); auto rows = nullable->size(); for (size_t i = 0; i < rows; i++) { - nested_bool_data[i] = null_map[i] ? 0 : nested_bool_data[i]; + nested_bool_data[i] &= !null_map[i]; } auto column_size = block.columns(); block.insert({nullable->get_nested_column_ptr(), remove_nullable(arg_cond.type), arg_cond.name}); - static_cast<void>(execute_impl( + static_cast<void>(_execute_impl_internal( context, block, {column_size, arguments[1], arguments[2]}, result, rows)); return true; } @@ -527,6 +561,38 @@ public: cond_column.column = materialize_column_if_const(cond_column.column); const ColumnWithTypeAndName& arg_cond = block.get_by_position(arguments[0]); + auto true_count = count_true_with_notnull(arg_cond.column); + auto item_count = arg_cond.column->size(); + if (true_count == item_count || true_count == 0) { + bool result_nullable = block.get_by_position(result).type->is_nullable(); + if (true_count == item_count) { + block.replace_by_position( + result, + result_nullable + ? make_nullable(arg_then.column->clone_resized(input_rows_count)) + : arg_then.column->clone_resized(input_rows_count)); + } else { + block.replace_by_position( + result, + result_nullable + ? make_nullable(arg_else.column->clone_resized(input_rows_count)) + : arg_else.column->clone_resized(input_rows_count)); + } + return Status::OK(); + } + + return _execute_impl_internal(context, block, arguments, result, input_rows_count); + } + + Status _execute_impl_internal(FunctionContext* context, Block& block, + const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) const { + const ColumnWithTypeAndName& arg_then = block.get_by_position(arguments[1]); + const ColumnWithTypeAndName& arg_else = block.get_by_position(arguments[2]); + ColumnWithTypeAndName& cond_column = block.get_by_position(arguments[0]); + cond_column.column = materialize_column_if_const(cond_column.column); + const ColumnWithTypeAndName& arg_cond = block.get_by_position(arguments[0]); + Status ret = Status::OK(); if (execute_for_null_condition(context, block, arguments, arg_cond, arg_then, arg_else, result) || @@ -537,7 +603,7 @@ public: return ret; } - const ColumnUInt8* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get()); + const auto* cond_col = typeid_cast<const ColumnUInt8*>(arg_cond.column.get()); const ColumnConst* cond_const_col = check_and_get_column_const<ColumnVector<UInt8>>(arg_cond.column.get()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org