This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ae4ff86439 [opt](function) Optimize the concat(col, constant, constant, constant) function (#40670)
2ae4ff86439 is described below

commit 2ae4ff86439e7532845963d470e1061ee62656ad
Author: Mryange <59914473+mrya...@users.noreply.github.com>
AuthorDate: Wed Sep 18 09:17:55 2024 +0800

    [opt](function) Optimize the concat(col, constant, constant, constant) function (#40670)
    
    ```
    before
    
    mysql [test]>select count(concat(short, "123","121231","123123",'12312313')) from strings;
    +-------------------------------------------------------------+
    | count(concat(short, '123', '121231', '123123', '12312313')) |
    +-------------------------------------------------------------+
    |                                                    10000000 |
    +-------------------------------------------------------------+
    1 row in set (0.52 sec)
    
    mysql [test]>select count(concat(short, "123","121231","123123",'12312313' , short , short, short)) from strings;
    +----------------------------------------------------------------------------------+
    | count(concat(short, '123', '121231', '123123', '12312313', short, short, short)) |
    +----------------------------------------------------------------------------------+
    |                                                                         10000000 |
    +----------------------------------------------------------------------------------+
    1 row in set (0.98 sec)
    
    
    
    now
    
    mysql [test]>select count(concat(short, "123","121231","123123",'12312313')) from strings;
    +-------------------------------------------------------------+
    | count(concat(short, '123', '121231', '123123', '12312313')) |
    +-------------------------------------------------------------+
    |                                                    10000000 |
    +-------------------------------------------------------------+
    1 row in set (0.19 sec)
    
    mysql [test]>select count(concat(short, "123","121231","123123",'12312313' , short , short, short)) from strings;
    +----------------------------------------------------------------------------------+
    | count(concat(short, '123', '121231', '123123', '12312313', short, short, short)) |
    +----------------------------------------------------------------------------------+
    |                                                                         10000000 |
    +----------------------------------------------------------------------------------+
    1 row in set (0.71 sec)
    ```
---
 be/src/vec/functions/function_string.h  | 126 ++++++++++++++++++++++++++++----
 be/test/vec/core/column_string_test.cpp |  12 ++-
 2 files changed, 121 insertions(+), 17 deletions(-)

diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index 160cc484a74..ef5122ac84d 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -1007,6 +1007,11 @@ public:
 
 class FunctionStringConcat : public IFunction {
 public:
+    struct ConcatState {
+        bool use_state = false;
+        std::string tail;
+    };
+
     static constexpr auto name = "concat";
     static FunctionPtr create() { return 
std::make_shared<FunctionStringConcat>(); }
     String get_name() const override { return name; }
@@ -1017,6 +1022,40 @@ public:
         return std::make_shared<DataTypeString>();
     }
 
+    Status open(FunctionContext* context, FunctionContext::FunctionStateScope 
scope) override {
+        if (scope == FunctionContext::THREAD_LOCAL) {
+            return Status::OK();
+        }
+        std::shared_ptr<ConcatState> state = std::make_shared<ConcatState>();
+
+        context->set_function_state(scope, state);
+
+        state->use_state = true;
+
+        // Optimize function calls like this:
+        // concat(col, "123", "abc", "456") -> tail = "123abc456"
+        for (size_t i = 1; i < context->get_num_args(); i++) {
+            const auto* column_string = context->get_constant_col(i);
+            if (column_string == nullptr) {
+                state->use_state = false;
+                return IFunction::open(context, scope);
+            }
+            auto string_vale = column_string->column_ptr->get_data_at(0);
+            if (string_vale.data == nullptr) {
+                // For concat(col, null), it is handled by 
default_implementation_for_nulls
+                state->use_state = false;
+                return IFunction::open(context, scope);
+            }
+
+            state->tail.append(string_vale.begin(), string_vale.size);
+        }
+
+        // The reserve is used here to allow the usage of 
memcpy_small_allow_read_write_overflow15 below.
+        state->tail.reserve(state->tail.size() + 16);
+
+        return IFunction::open(context, scope);
+    }
+
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         DCHECK_GE(arguments.size(), 1);
@@ -1025,7 +1064,29 @@ public:
             block.get_by_position(result).column = 
block.get_by_position(arguments[0]).column;
             return Status::OK();
         }
+        auto* concat_state = reinterpret_cast<ConcatState*>(
+                context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
+        if (!concat_state) {
+            return Status::RuntimeError("funciton context for function '{}' 
must have ConcatState;",
+                                        get_name());
+        }
+        if (concat_state->use_state) {
+            const auto& [col, is_const] =
+                    
unpack_if_const(block.get_by_position(arguments[0]).column);
+            const auto* col_str = assert_cast<const ColumnString*>(col.get());
+            if (is_const) {
+                return execute_const<true>(concat_state, block, col_str, 
result, input_rows_count);
+            } else {
+                return execute_const<false>(concat_state, block, col_str, 
result, input_rows_count);
+            }
 
+        } else {
+            return execute_vecotr(block, arguments, result, input_rows_count);
+        }
+    }
+
+    Status execute_vecotr(Block& block, const ColumnNumbers& arguments, size_t 
result,
+                          size_t input_rows_count) const {
         int argument_size = arguments.size();
         std::vector<ColumnPtr> argument_columns(argument_size);
 
@@ -1048,18 +1109,12 @@ public:
         auto& res_offset = res->get_offsets();
 
         res_offset.resize(input_rows_count);
-
         size_t res_reserve_size = 0;
-        // we could ignore null string column
-        // but it's not necessary to ignore it
         for (size_t i = 0; i < argument_size; ++i) {
             if (is_const_args[i]) {
-                res_reserve_size +=
-                        ((*offsets_list[i])[0] - (*offsets_list[i])[-1]) * 
input_rows_count;
+                res_reserve_size += (*offsets_list[i])[0] * input_rows_count;
             } else {
-                for (size_t j = 0; j < input_rows_count; ++j) {
-                    res_reserve_size += (*offsets_list[i])[j] - 
(*offsets_list[i])[j - 1];
-                }
+                res_reserve_size += (*offsets_list[i])[input_rows_count - 1];
             }
         }
 
@@ -1067,24 +1122,65 @@ public:
 
         res_data.resize(res_reserve_size);
 
+        auto* data = res_data.data();
+        size_t dst_offset = 0;
+
         for (size_t i = 0; i < input_rows_count; ++i) {
-            int current_length = 0;
             for (size_t j = 0; j < argument_size; ++j) {
                 const auto& current_offsets = *offsets_list[j];
                 const auto& current_chars = *chars_list[j];
-
                 auto idx = index_check_const(i, is_const_args[j]);
-                auto size = current_offsets[idx] - current_offsets[idx - 1];
+                const auto size = current_offsets[idx] - current_offsets[idx - 
1];
                 if (size > 0) {
                     memcpy_small_allow_read_write_overflow15(
-                            &res_data[res_offset[i - 1]] + current_length,
-                            &current_chars[current_offsets[idx - 1]], size);
-                    current_length += size;
+                            data + dst_offset, current_chars.data() + 
current_offsets[idx - 1],
+                            size);
+                    dst_offset += size;
                 }
             }
-            res_offset[i] = res_offset[i - 1] + current_length;
+            res_offset[i] = dst_offset;
+        }
+
+        block.get_by_position(result).column = std::move(res);
+        return Status::OK();
+    }
+
+    template <bool is_const>
+    Status execute_const(ConcatState* concat_state, Block& block, const 
ColumnString* col_str,
+                         size_t result, size_t input_rows_count) const {
+        // using tail optimize
+
+        auto res = ColumnString::create();
+        auto& res_data = res->get_chars();
+        auto& res_offset = res->get_offsets();
+        res_offset.resize(input_rows_count);
+
+        size_t res_reserve_size = 0;
+        if constexpr (is_const) {
+            res_reserve_size = col_str->get_offsets()[0] * input_rows_count;
+        } else {
+            res_reserve_size = col_str->get_offsets()[input_rows_count - 1];
         }
+        res_reserve_size += concat_state->tail.size() * input_rows_count;
 
+        ColumnString::check_chars_length(res_reserve_size, 0);
+        res_data.resize(res_reserve_size);
+
+        const auto& tail = concat_state->tail;
+        auto* data = res_data.data();
+        size_t dst_offset = 0;
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            const auto idx = index_check_const<is_const>(i);
+            StringRef str_val = col_str->get_data_at(idx);
+            // copy column
+            memcpy_small_allow_read_write_overflow15(data + dst_offset, 
str_val.data, str_val.size);
+            dst_offset += str_val.size;
+            // copy tail
+            memcpy_small_allow_read_write_overflow15(data + dst_offset, 
tail.data(), tail.size());
+            dst_offset += tail.size();
+            res_offset[i] = dst_offset;
+        }
         block.get_by_position(result).column = std::move(res);
         return Status::OK();
     }
diff --git a/be/test/vec/core/column_string_test.cpp 
b/be/test/vec/core/column_string_test.cpp
index 81f41bd11c4..a1967a30ce7 100644
--- a/be/test/vec/core/column_string_test.cpp
+++ b/be/test/vec/core/column_string_test.cpp
@@ -48,8 +48,16 @@ TEST(ColumnStringTest, TestConcat) {
     ColumnNumbers arguments = {0, 1};
 
     FunctionStringConcat func_concat;
-    auto status = func_concat.execute_impl(nullptr, block, arguments, 2, 3);
-    EXPECT_TRUE(status.ok());
+    auto fn_ctx = FunctionContext::create_context(nullptr, TypeDescriptor {}, 
{});
+    {
+        auto status =
+                func_concat.open(fn_ctx.get(), 
FunctionContext::FunctionStateScope::FRAGMENT_LOCAL);
+        EXPECT_TRUE(status.ok());
+    }
+    {
+        auto status = func_concat.execute_impl(fn_ctx.get(), block, arguments, 
2, 3);
+        EXPECT_TRUE(status.ok());
+    }
 
     auto actual_res_col = block.get_by_position(2).column;
     EXPECT_EQ(actual_res_col->size(), 3);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to