This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 6f15d7ce715 [Improvement](function) optimize for case when have many conditions (#51205) 6f15d7ce715 is described below commit 6f15d7ce7153c5c1e8c893036328c9952cf03e86 Author: Pxl <x...@selectdb.com> AuthorDate: Wed May 28 10:29:08 2025 +0800 [Improvement](function) optimize for case when have many conditions (#51205) ### What problem does this PR solve? select count(isnull(x)) from (select case i1 when -128 then -128 when -127 then -127 when -126 then -126 when -125 then -125 when -124 then -124 when -123 then -123 when -122 then -122 when -121 then -121 when -120 then -120 when -119 then -119 when -118 then -118 when -117 then -117 when -116 then -116 when -115 then -115 when -114 then -114 when -113 then -113 when -112 then -112 when -111 then -111 when -110 then -110 when -109 then -109 when -108 then -108 when -107 then -107 when -106 then -106 when -105 then -105 when -104 then -104 when -103 then -103 when -102 then -102 when -101 then -101 when -100 then -100 when -99 then -99 when -98 then -98 when -97 then -97 when -96 then -96 when -95 then -95 when -94 then -94 when -93 then -93 when -92 then -92 when -91 then -91 when -90 then -90 when -89 then -89 when -88 then -88 when -87 then -87 when -86 then -86 when -85 then -85 when -84 then -84 when -83 then -83 when -82 then -82 when -81 then -81 when -80 then -80 when -79 then -79 when -78 then -78 when -77 then -77 when -76 then -76 when -75 then -75 when -74 then -74 when -73 then -73 when -72 then -72 when -71 then -71 when -70 then -70 when -69 then -69 when -68 then -68 when -67 then -67 when -66 then -66 when -65 then -65 when -64 then -64 when -63 then -63 when -62 then -62 when -61 then -61 when -60 then -60 when -59 then -59 when -58 then -58 when -57 then -57 when -56 then -56 when -55 then -55 when -54 then -54 when -53 then -53 when -52 then -52 when -51 then -51 when -50 then -50 when -49 then -49 when -48 then -48 when -47 then -47 when -46 then -46 when -45 then -45 when -44 then -44 when -43 then -43 when -42 then -42 when -41 then -41 when -40 then -40 when -39 then -39 when -38 then -38 when -37 then -37 when -36 then -36 when -35 then -35 when -34 then -34 when -33 then -33 when -32 then -32 when -31 then -31 when -30 then -30 when -29 then -29 when -28 then -28 when -27 then -27 when -26 then -26 when -25 then -25 when -24 then -24 when -23 then -23 when -22 then -22 when -21 then -21 when -20 then -20 when -19 then -19 when -18 then -18 when -17 then -17 when -16 then -16 when -15 then -15 when -14 then -14 when -13 then -13 when -12 then -12 when -11 then -11 when -10 then -10 when -9 then -9 when -8 then -8 when -7 then -7 when -6 then -6 when -5 then -5 when -4 then -4 when -3 then -3 when -2 then -2 when -1 then -1 when 0 then 0 when 1 then 1 when 2 then 2 when 3 then 3 when 4 then 4 when 5 then 5 when 6 then 6 when 7 then 7 when 8 then 8 when 9 then 9 when 10 then 10 when 11 then 11 when 12 then 12 when 13 then 13 when 14 then 14 when 15 then 15 when 16 then 16 when 17 then 17 when 18 then 18 when 19 then 19 when 20 then 20 when 21 then 21 when 22 then 22 when 23 then 23 when 24 then 24 when 25 then 25 when 26 then 26 when 27 then 27 when 28 then 28 when 29 then 29 when 30 then 30 when 31 then 31 when 32 then 32 when 33 then 33 when 34 then 34 when 35 then 35 when 36 then 36 when 37 then 37 when 38 then 38 when 39 then 39 when 40 then 40 when 41 then 41 when 42 then 42 when 43 then 43 when 44 then 44 when 45 then 45 when 46 then 46 when 47 then 47 when 48 then 48 when 49 then 49 when 50 then 50 when 51 then 51 when 52 then 52 when 53 then 53 when 54 then 54 when 55 then 55 when 56 then 56 when 57 then 57 when 58 then 58 when 59 then 59 when 60 then 60 when 61 then 61 when 62 then 62 when 63 then 63 when 64 then 64 when 65 then 65 when 66 then 66 when 67 then 67 when 68 then 68 when 69 then 69 when 70 then 70 when 71 then 71 when 72 then 72 when 73 then 73 when 74 then 74 when 75 then 75 when 76 then 76 when 77 then 77 when 78 then 78 when 79 then 79 when 80 then 80 when 81 then 81 when 82 then 82 when 83 then 83 when 84 then 84 when 85 then 85 when 86 then 86 when 87 then 87 when 88 then 88 when 89 then 89 when 90 then 90 when 91 then 91 when 92 then 92 when 93 then 93 when 94 then 94 when 95 then 95 when 96 then 96 when 97 then 97 when 98 then 98 when 99 then 99 when 100 then 100 when 101 then 101 when 102 then 102 when 103 then 103 when 104 then 104 when 105 then 105 when 106 then 106 when 107 then 107 when 108 then 108 when 109 then 109 when 110 then 110 when 111 then 111 when 112 then 112 when 113 then 113 when 114 then 114 when 115 then 115 when 116 then 116 when 117 then 117 when 118 then 118 when 119 then 119 when 120 then 120 when 121 then 121 when 122 then 122 when 123 then 123 when 124 then 124 when 125 then 125 when 126 then 126 when 127 then 127 else null end as x from half_null)t; before:10.31 sec after:1.26 sec ### Check List (For Author) - Test <!-- At least one of them must be included. --> - [ ] Regression test - [ ] Unit Test - [ ] Manual test (add detailed scripts or steps below) - [ ] No need to test or manual test. Explain why: - [ ] This is a refactor/code format and no logic has been changed. - [ ] Previous test can cover this change. - [ ] No code files have been changed. - [ ] Other reason <!-- Add your reason? --> - Behavior changed: - [ ] No. - [ ] Yes. <!-- Explain the behavior change --> - Does this need documentation? - [ ] No. - [ ] Yes. <!-- Add document PR link here. eg: https://github.com/apache/doris-website/pull/1214 --> ### Check List (For Reviewer who merge this PR) - [ ] Confirm the release note - [ ] Confirm test cases - [ ] Confirm document - [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into --> --- be/src/vec/columns/column_nullable.h | 5 +- be/src/vec/functions/function_case.cpp | 6 +- be/src/vec/functions/function_case.h | 201 ++++++++++++--------------------- 3 files changed, 77 insertions(+), 135 deletions(-) diff --git a/be/src/vec/columns/column_nullable.h b/be/src/vec/columns/column_nullable.h index 3d4e4a7da6f..7c18174b635 100644 --- a/be/src/vec/columns/column_nullable.h +++ b/be/src/vec/columns/column_nullable.h @@ -208,8 +208,9 @@ public: template <typename ColumnType> void insert_from_with_type(const IColumn& src, size_t n) { - const auto& src_concrete = assert_cast<const ColumnNullable&>(src); - assert_cast<ColumnType*>(nested_column.get()) + const auto& src_concrete = + assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(src); + assert_cast<ColumnType*, TypeCheckOnRelease::DISABLE>(nested_column.get()) ->insert_from(src_concrete.get_nested_column(), n); auto is_null = src_concrete.get_null_map_data()[n]; if (is_null) { diff --git a/be/src/vec/functions/function_case.cpp b/be/src/vec/functions/function_case.cpp index 8f0278b3170..20dbd51b56d 100644 --- a/be/src/vec/functions/function_case.cpp +++ b/be/src/vec/functions/function_case.cpp @@ -22,10 +22,8 @@ namespace doris::vectorized { void register_function_case(SimpleFunctionFactory& factory) { - factory.register_function<FunctionCase<false, false>>(); - factory.register_function<FunctionCase<false, true>>(); - factory.register_function<FunctionCase<true, false>>(); - factory.register_function<FunctionCase<true, true>>(); + factory.register_function<FunctionCase<false>>(); + factory.register_function<FunctionCase<true>>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/function_case.h b/be/src/vec/functions/function_case.h index c1a18c83414..e760c04ecda 100644 --- a/be/src/vec/functions/function_case.h +++ b/be/src/vec/functions/function_case.h @@ -17,7 +17,6 @@ #pragma once -#include <algorithm> #include <cstdint> #include <memory> #include <optional> @@ -35,6 +34,7 @@ #include "vec/columns/column_object.h" #include "vec/columns/column_struct.h" #include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" #include "vec/core/block.h" #include "vec/core/column_numbers.h" #include "vec/core/column_with_type_and_name.h" @@ -42,34 +42,23 @@ #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_nullable.h" #include "vec/functions/function.h" -#include "vec/utils/template_helpers.hpp" namespace doris::vectorized { #include "common/compile_check_begin.h" -template <bool has_case, bool has_else> +template <bool has_else> struct FunctionCaseName; template <> -struct FunctionCaseName<false, false> { +struct FunctionCaseName<false> { static constexpr auto name = "case"; }; template <> -struct FunctionCaseName<true, false> { - static constexpr auto name = "case_has_case"; -}; - -template <> -struct FunctionCaseName<false, true> { +struct FunctionCaseName<true> { static constexpr auto name = "case_has_else"; }; -template <> -struct FunctionCaseName<true, true> { - static constexpr auto name = "case_has_case_has_else"; -}; - struct CaseWhenColumnHolder { using OptionalPtr = std::optional<ColumnPtr>; @@ -79,16 +68,15 @@ struct CaseWhenColumnHolder { size_t rows_count; CaseWhenColumnHolder(Block& block, const ColumnNumbers& arguments, size_t input_rows_count, - bool has_case, bool has_else, bool when_null, bool then_null) + bool has_else, bool when_null, bool then_null) : rows_count(input_rows_count) { - when_ptrs.emplace_back(has_case ? OptionalPtr(block.get_by_position(arguments[0]).column) - : std::nullopt); + when_ptrs.emplace_back(std::nullopt); then_ptrs.emplace_back( has_else ? OptionalPtr(block.get_by_position(arguments[arguments.size() - 1]).column) : std::nullopt); - int begin = 0 + has_case; + int begin = 0; int end = cast_set<int>(arguments.size() - has_else); pair_count = (end - begin) / 2 + 1; // when/then at [1: pair_count) @@ -120,17 +108,17 @@ struct CaseWhenColumnHolder { } }; -template <bool has_case, bool has_else> +template <bool has_else> class FunctionCase : public IFunction { public: - static constexpr auto name = FunctionCaseName<has_case, has_else>::name; + static constexpr auto name = FunctionCaseName<has_else>::name; static FunctionPtr create() { return std::make_shared<FunctionCase>(); } String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 0; } bool is_variadic() const override { return true; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - int loop_start = has_case ? 2 : 1; + int loop_start = 1; int loop_end = cast_set<int>(has_else ? arguments.size() - 1 : arguments.size()); bool is_nullable = false; @@ -152,112 +140,59 @@ public: bool use_default_implementation_for_nulls() const override { return false; } - template <typename ColumnType, bool when_null, bool then_null> - Status execute_short_circuit(const DataTypePtr& data_type, Block& block, uint32_t result, - CaseWhenColumnHolder column_holder) const { - auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr); - size_t rows_count = column_holder.rows_count; - - // `then` data index corresponding to each row of results, 0 represents `else`. - auto then_idx_uptr = std::unique_ptr<int[]>(new int[rows_count]); - int* __restrict then_idx_ptr = then_idx_uptr.get(); - memset(then_idx_ptr, 0, rows_count * sizeof(int)); - - for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) { - for (int i = 1; i < column_holder.pair_count; i++) { - auto when_column_ptr = column_holder.when_ptrs[i].value(); - if constexpr (has_case) { - if (!case_column_ptr->is_null_at(row_idx) && - case_column_ptr->compare_at(row_idx, row_idx, *when_column_ptr, -1) == 0) { - then_idx_ptr[row_idx] = i; - break; - } - } else { - if (!then_idx_ptr[row_idx] && when_column_ptr->get_bool(row_idx)) { - then_idx_ptr[row_idx] = i; - break; - } - } - } - } - - auto result_column_ptr = data_type->create_column(); - update_result_normal<int, ColumnType, then_null>(result_column_ptr, then_idx_ptr, - column_holder); - block.replace_by_position(result, std::move(result_column_ptr)); - return Status::OK(); - } - - template <typename ColumnType, bool when_null, bool then_null> + template <typename IndexType, typename ColumnType, bool when_null, bool then_null> Status execute_impl(const DataTypePtr& data_type, Block& block, uint32_t result, CaseWhenColumnHolder column_holder) const { - if (column_holder.pair_count > UINT8_MAX) { - return execute_short_circuit<ColumnType, when_null, then_null>(data_type, block, result, - column_holder); - } - size_t rows_count = column_holder.rows_count; // `then` data index corresponding to each row of results, 0 represents `else`. - auto then_idx_uptr = std::unique_ptr<uint8_t[]>(new uint8_t[rows_count]); - uint8_t* __restrict then_idx_ptr = then_idx_uptr.get(); - memset(then_idx_ptr, 0, rows_count); + auto then_idx_uptr = std::unique_ptr<IndexType[]>(new IndexType[rows_count]); + IndexType* __restrict then_idx_ptr = then_idx_uptr.get(); + memset(then_idx_ptr, 0, sizeof(IndexType) * rows_count); auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr); - for (uint8_t i = 1; i < column_holder.pair_count; i++) { + for (IndexType i = 1; i < column_holder.pair_count; i++) { auto when_column_ptr = column_holder.when_ptrs[i].value(); - if constexpr (has_case) { - // TODO: need simd + if constexpr (when_null) { + const auto* column_nullable_ptr = + assert_cast<const ColumnNullable*>(when_column_ptr.get()); + const auto* __restrict cond_raw_data = + assert_cast<const ColumnUInt8*>( + column_nullable_ptr->get_nested_column_ptr().get()) + ->get_data() + .data(); + const auto* __restrict cond_raw_nullmap = + assert_cast<const ColumnUInt8*>( + column_nullable_ptr->get_null_map_column_ptr().get()) + ->get_data() + .data(); + + // simd automatically for (int row_idx = 0; row_idx < rows_count; row_idx++) { - if (!then_idx_ptr[row_idx] && !case_column_ptr->is_null_at(row_idx) && - case_column_ptr->compare_at(row_idx, row_idx, *when_column_ptr, -1) == 0) { - then_idx_ptr[row_idx] = i; - } + then_idx_ptr[row_idx] |= (!then_idx_ptr[row_idx] * cond_raw_data[row_idx] * + !cond_raw_nullmap[row_idx]) * + i; } } else { - if constexpr (when_null) { - const auto* column_nullable_ptr = - assert_cast<const ColumnNullable*>(when_column_ptr.get()); - const auto* __restrict cond_raw_data = - assert_cast<const ColumnUInt8*>( - column_nullable_ptr->get_nested_column_ptr().get()) - ->get_data() - .data(); - const auto* __restrict cond_raw_nullmap = - assert_cast<const ColumnUInt8*>( - column_nullable_ptr->get_null_map_column_ptr().get()) - ->get_data() - .data(); - - // simd automatically - for (int row_idx = 0; row_idx < rows_count; row_idx++) { - then_idx_ptr[row_idx] |= (!then_idx_ptr[row_idx] * cond_raw_data[row_idx] * - !cond_raw_nullmap[row_idx]) * - i; - } - } else { - const auto* __restrict cond_raw_data = - assert_cast<const ColumnUInt8*>(when_column_ptr.get()) - ->get_data() - .data(); - - // simd automatically - for (int row_idx = 0; row_idx < rows_count; row_idx++) { - then_idx_ptr[row_idx] |= - (!then_idx_ptr[row_idx]) * cond_raw_data[row_idx] * i; - } + const auto* __restrict cond_raw_data = + assert_cast<const ColumnUInt8*>(when_column_ptr.get())->get_data().data(); + + // simd automatically + for (int row_idx = 0; row_idx < rows_count; row_idx++) { + then_idx_ptr[row_idx] |= (!then_idx_ptr[row_idx]) * cond_raw_data[row_idx] * i; } } } - return execute_update_result<ColumnType, then_null>(data_type, result, block, then_idx_ptr, - column_holder); + return execute_update_result<IndexType, ColumnType, then_null>(data_type, result, block, + then_idx_ptr, column_holder); } - template <typename ColumnType, bool then_null> + template <typename IndexType, typename ColumnType, bool then_null> Status execute_update_result(const DataTypePtr& data_type, uint32_t result, Block& block, - const uint8* then_idx, CaseWhenColumnHolder& column_holder) const { + const IndexType* then_idx, + CaseWhenColumnHolder& column_holder) const { auto result_column_ptr = data_type->create_column(); if constexpr (std::is_same_v<ColumnType, ColumnString> || @@ -272,13 +207,13 @@ public: std::is_same_v<ColumnType, ColumnIPv6>) { // result_column and all then_column is not nullable. // can't simd when type is string. - update_result_normal<uint8_t, ColumnType, then_null>(result_column_ptr, then_idx, - column_holder); - } else if constexpr (then_null) { + update_result_normal<IndexType, ColumnType, then_null>(result_column_ptr, then_idx, + column_holder); + } else if constexpr (then_null || !std::is_same_v<IndexType, uint8_t>) { // result_column and all then_column is nullable. // TODO: make here simd automatically. - update_result_normal<uint8_t, ColumnType, then_null>(result_column_ptr, then_idx, - column_holder); + update_result_normal<IndexType, ColumnType, then_null>(result_column_ptr, then_idx, + column_holder); } else { update_result_auto_simd<ColumnType>(result_column_ptr, then_idx, column_holder); } @@ -299,10 +234,13 @@ public: unpack_if_const(column_holder.then_ptrs[i].value()); } } + + auto* raw_result_column = result_column_ptr.get(); for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) { if constexpr (!has_else) { if (!then_idx[row_idx]) { - result_column_ptr->insert_default(); + assert_cast<ColumnNullable*, TypeCheckOnRelease::DISABLE>(raw_result_column) + ->insert_default(); continue; } } @@ -357,7 +295,7 @@ public: const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const { bool then_null = false; - for (int i = 1 + has_case; i < arguments.size() - has_else; i += 2) { + for (int i = 1; i < arguments.size() - has_else; i += 2) { if (block.get_by_position(arguments[i]).type->is_nullable()) { then_null = true; } @@ -371,14 +309,25 @@ public: } CaseWhenColumnHolder column_holder = CaseWhenColumnHolder( - block, arguments, input_rows_count, has_case, has_else, when_null, then_null); - + block, arguments, input_rows_count, has_else, when_null, then_null); + if (column_holder.pair_count > UINT16_MAX) { + return Status::NotSupported( + "case when do not support more than UINT16_MAX pairs conditions"); + } if (then_null) { - return execute_impl<ColumnType, when_null, true>(data_type, block, result, - column_holder); + if (column_holder.pair_count > UINT8_MAX) { + return execute_impl<uint16_t, ColumnType, when_null, true>(data_type, block, result, + column_holder); + } + return execute_impl<uint8_t, ColumnType, when_null, true>(data_type, block, result, + column_holder); } else { - return execute_impl<ColumnType, when_null, false>(data_type, block, result, - column_holder); + if (column_holder.pair_count > UINT8_MAX) { + return execute_impl<uint16_t, ColumnType, when_null, false>(data_type, block, + result, column_holder); + } + return execute_impl<uint8_t, ColumnType, when_null, false>(data_type, block, result, + column_holder); } } @@ -387,13 +336,7 @@ public: const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const { bool when_null = false; - if constexpr (has_case) { - block.replace_by_position_if_const(arguments[0]); - if (block.get_by_position(arguments[0]).type->is_nullable()) { - when_null = true; - } - } - for (int i = has_case; i < arguments.size() - has_else; i += 2) { + for (int i = 0; i < arguments.size() - has_else; i += 2) { block.replace_by_position_if_const(arguments[i]); if (block.get_by_position(arguments[i]).type->is_nullable()) { when_null = true; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org