This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 73546b773a6 [Fix](AI) remove thread_pool in AI Functions (#56057)
73546b773a6 is described below

commit 73546b773a63f4acf92e7eb979148942a4491218
Author: linrrarity <[email protected]>
AuthorDate: Tue Sep 16 11:52:00 2025 +0800

    [Fix](AI) remove thread_pool in AI Functions (#56057)
    
    ### What problem does this PR solve?
    
    Issue Number: close #xxx
    
    Related PR: https://github.com/apache/doris/pull/55886
    
    Problem Summary:
    
    AI Function creates a thread pool in execute_impl. So
    `llm_max_concurrent_requests * CPU cores / 2 threads` will be used to
    call AI api for a single AI function. It's not our expectation.
    
    ### Release note
    
    None
    
    ### Check List (For Author)
    
    - Test <!-- At least one of them must be included. -->
        - [ ] Regression test
        - [ ] Unit Test
        - [ ] Manual test (add detailed scripts or steps below)
        - [x] No need to test or manual test. Explain why:
    - [ ] This is a refactor/code format and no logic has been changed.
            - [ ] Previous test can cover this change.
            - [ ] No code files have been changed.
            - [ ] Other reason <!-- Add your reason?  -->
    
    - Behavior changed:
        - [x] No.
        - [ ] Yes. <!-- Explain the behavior change -->
    
    - Does this need documentation?
        - [x] No.
    - [ ] Yes. <!-- Add document PR link here. eg:
    https://github.com/apache/doris-website/pull/1214 -->
    
    ### Check List (For Reviewer who merge this PR)
    
    - [ ] Confirm the release note
    - [ ] Confirm test cases
    - [ ] Confirm document
    - [ ] Add branch pick label <!-- Add branch pick label that this PR
    should merge into -->
---
 be/src/common/config.cpp                           |   3 -
 be/src/common/config.h                             |   3 -
 be/src/vec/functions/ai/ai_functions.h             | 149 ++++++---------------
 ...test.cpp => aggregate_function_ai_agg_test.cpp} |   3 +-
 ...{build_prompt_test.cpp => ai_function_test.cpp} |  27 ++++
 5 files changed, 69 insertions(+), 116 deletions(-)

diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp
index 41d6e356032..ba21728d916 100644
--- a/be/src/common/config.cpp
+++ b/be/src/common/config.cpp
@@ -1579,9 +1579,6 @@ 
DEFINE_mBool(enable_auto_clone_on_mow_publish_missing_version, "false");
 // The maximum csv line reader output buffer size
 DEFINE_mInt64(max_csv_line_reader_output_buffer_size, "4294967296");
 
-// The maximum number of threads supported when executing LLMFunction
-DEFINE_mInt32(llm_max_concurrent_requests, "1");
-
 // Maximum number of openmp threads can be used by each doris threads.
 // This configuration controls the parallelism level for OpenMP operations 
within Doris,
 // helping to prevent resource contention and ensure stable performance when 
multiple
diff --git a/be/src/common/config.h b/be/src/common/config.h
index 2196550968f..959b46c9747 100644
--- a/be/src/common/config.h
+++ b/be/src/common/config.h
@@ -1636,9 +1636,6 @@ DECLARE_String(fuzzy_test_type);
 // The maximum csv line reader output buffer size
 DECLARE_mInt64(max_csv_line_reader_output_buffer_size);
 
-// The maximum number of threads supported when executing LLMFunction
-DECLARE_mInt32(llm_max_concurrent_requests);
-
 // Maximum number of OpenMP threads that can be used by each Doris thread
 DECLARE_Int32(omp_threads_limit);
 // The capacity of segment partial column cache, used to cache column readers 
for each segment.
diff --git a/be/src/vec/functions/ai/ai_functions.h 
b/be/src/vec/functions/ai/ai_functions.h
index ce9e77cc572..1dd8dcf79a3 100644
--- a/be/src/vec/functions/ai/ai_functions.h
+++ b/be/src/vec/functions/ai/ai_functions.h
@@ -75,23 +75,6 @@ public:
                 assert_cast<const 
Derived&>(*this).get_return_type_impl(DataTypes());
         MutableColumnPtr col_result = return_type_impl->create_column();
 
-        std::unique_ptr<ThreadPool> thread_pool;
-        Status st = ThreadPoolBuilder("LLMRequestPool")
-                            .set_min_threads(1)
-                            
.set_max_threads(config::llm_max_concurrent_requests > 0
-                                                     ? 
config::llm_max_concurrent_requests
-                                                     : 1)
-                            .build(&thread_pool);
-        if (!st.ok()) {
-            return Status::InternalError("Failed to create thread pool: " + 
st.to_string());
-        }
-
-        struct RowResult {
-            std::variant<std::string, std::vector<float>> data;
-            Status status;
-            bool is_null = false;
-        };
-
         TAIResource config;
         std::shared_ptr<AIAdapter> adapter;
         if (Status status =
@@ -101,114 +84,62 @@ public:
             return status;
         }
 
-        std::vector<RowResult> results(input_rows_count);
         for (size_t i = 0; i < input_rows_count; ++i) {
-            Status submit_status = thread_pool->submit_func([this, i, &block, 
&arguments, &results,
-                                                             &adapter, 
&config, context,
-                                                             
&return_type_impl]() {
-                RowResult& row_result = results[i];
-
-                try {
-                    // Build AI prompt text
-                    std::string prompt;
-                    Status status = assert_cast<const 
Derived&>(*this).build_prompt(
-                            block, arguments, i, prompt);
-
-                    if (!status.ok()) {
-                        row_result.status = status;
-                        row_result.is_null = true;
-                        return;
-                    }
-
-                    // Execute a single AI request and get the result
-                    if (return_type_impl->get_primitive_type() == 
PrimitiveType::TYPE_ARRAY) {
-                        std::vector<float> float_result;
-                        status = execute_single_request(prompt, float_result, 
config, adapter,
-                                                        context);
-                        if (!status.ok()) {
-                            row_result.status = status;
-                            row_result.is_null = true;
-                            return;
-                        }
-                        row_result.data = std::move(float_result);
-                    } else {
-                        std::string string_result;
-                        status = execute_single_request(prompt, string_result, 
config, adapter,
-                                                        context);
-                        if (!status.ok()) {
-                            row_result.status = status;
-                            row_result.is_null = true;
-                            return;
-                        }
-                        row_result.data = std::move(string_result);
-                    }
-                    row_result.status = Status::OK();
-                } catch (const std::exception& e) {
-                    row_result.status = Status::InternalError("Exception in AI 
request: " +
-                                                              
std::string(e.what()));
-                    row_result.is_null = true;
-                }
-            });
-
-            if (!submit_status.ok()) {
-                return Status::InternalError("Failed to submit task to thread 
pool: " +
-                                             submit_status.to_string());
-            }
-        }
-
-        thread_pool->wait();
-
-        for (size_t i = 0; i < input_rows_count; ++i) {
-            const RowResult& row_result = results[i];
-
-            if (!row_result.status.ok()) {
-                return row_result.status;
-            }
+            // Build AI prompt text
+            std::string prompt;
+            RETURN_IF_ERROR(
+                    assert_cast<const Derived&>(*this).build_prompt(block, 
arguments, i, prompt));
+
+            // Execute a single AI request and get the result
+            if (return_type_impl->get_primitive_type() == 
PrimitiveType::TYPE_ARRAY) {
+                // Array(Float) for AI_EMBED
+                std::vector<float> float_result;
+                RETURN_IF_ERROR(
+                        execute_single_request(prompt, float_result, config, 
adapter, context));
+
+                auto& col_array = assert_cast<ColumnArray&>(*col_result);
+                auto& offsets = col_array.get_offsets();
+                auto& nested_nullable_col = 
assert_cast<ColumnNullable&>(col_array.get_data());
+                auto& nested_col =
+                        
assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+                nested_col.reserve(nested_col.size() + float_result.size());
+
+                size_t current_offset = nested_col.size();
+                nested_col.insert_many_raw_data(reinterpret_cast<const 
char*>(float_result.data()),
+                                                float_result.size());
+                offsets.push_back(current_offset + float_result.size());
+                auto& null_map = nested_nullable_col.get_null_map_column();
+                null_map.insert_many_vals(0, float_result.size());
+            } else {
+                std::string string_result;
+                RETURN_IF_ERROR(
+                        execute_single_request(prompt, string_result, config, 
adapter, context));
 
-            if (!row_result.is_null) {
                 switch (return_type_impl->get_primitive_type()) {
                 case PrimitiveType::TYPE_STRING: { // string
-                    const auto& str_data = 
std::get<std::string>(row_result.data);
                     assert_cast<ColumnString&>(*col_result)
-                            .insert_data(str_data.data(), str_data.size());
+                            .insert_data(string_result.data(), 
string_result.size());
                     break;
                 }
-                case PrimitiveType::TYPE_BOOLEAN: { // boolean
-                    const auto& bool_data = 
std::get<std::string>(row_result.data);
-                    if (bool_data != "true" && bool_data != "false") {
-                        return Status::RuntimeError("Failed to parse boolean 
value: " + bool_data);
+                case PrimitiveType::TYPE_BOOLEAN: { // boolean for AI_FILTER
+#ifdef BE_TEST
+                    string_result = "false";
+#endif
+                    if (string_result != "true" && string_result != "false") {
+                        return Status::RuntimeError("Failed to parse boolean 
value: " +
+                                                    string_result);
                     }
                     assert_cast<ColumnUInt8&>(*col_result)
-                            .insert_value(static_cast<UInt8>(bool_data == 
"true"));
-                    break;
-                }
-                case PrimitiveType::TYPE_FLOAT: { // float
-                    const auto& str_data = 
std::get<std::string>(row_result.data);
-                    
assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(str_data));
+                            .insert_value(static_cast<UInt8>(string_result == 
"true"));
                     break;
                 }
-                case PrimitiveType::TYPE_ARRAY: { // array of floats
-                    const auto& float_data = 
std::get<std::vector<float>>(row_result.data);
-                    auto& col_array = assert_cast<ColumnArray&>(*col_result);
-                    auto& offsets = col_array.get_offsets();
-                    auto& nested_nullable_col = 
assert_cast<ColumnNullable&>(col_array.get_data());
-                    auto& nested_col = assert_cast<ColumnFloat32&>(
-                            *(nested_nullable_col.get_nested_column_ptr()));
-                    nested_col.reserve(nested_col.size() + float_data.size());
-
-                    size_t current_offset = nested_col.size();
-                    nested_col.insert_many_raw_data(
-                            reinterpret_cast<const char*>(float_data.data()), 
float_data.size());
-                    offsets.push_back(current_offset + float_data.size());
-                    auto& null_map = nested_nullable_col.get_null_map_column();
-                    null_map.insert_many_vals(0, float_data.size());
+                case PrimitiveType::TYPE_FLOAT: { // float for AI_SIMILARITY
+                    
assert_cast<ColumnFloat32&>(*col_result).insert_value(std::stof(string_result));
                     break;
                 }
                 default:
                     return Status::InternalError("Unsupported ReturnType for 
AIFunction");
                 }
-            } else {
-                col_result->insert_default();
             }
         }
 
diff --git a/be/test/ai/aggregate_function_llm_agg_test.cpp 
b/be/test/ai/aggregate_function_ai_agg_test.cpp
similarity index 99%
rename from be/test/ai/aggregate_function_llm_agg_test.cpp
rename to be/test/ai/aggregate_function_ai_agg_test.cpp
index 391f0686547..4374ff16c01 100644
--- a/be/test/ai/aggregate_function_llm_agg_test.cpp
+++ b/be/test/ai/aggregate_function_ai_agg_test.cpp
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
+
 #include <gmock/gmock-matchers.h>
 #include <gtest/gtest.h>
 
@@ -26,7 +28,6 @@
 #include "runtime/query_context.h"
 #include "testutil/column_helper.h"
 #include "testutil/mock/mock_runtime_state.h"
-#include "vec/aggregate_functions/aggregate_function_ai_agg.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/columns/column_string.h"
 #include "vec/common/arena.h"
diff --git a/be/test/ai/build_prompt_test.cpp b/be/test/ai/ai_function_test.cpp
similarity index 93%
rename from be/test/ai/build_prompt_test.cpp
rename to be/test/ai/ai_function_test.cpp
index 9a38c071784..f288a4b7b92 100644
--- a/be/test/ai/build_prompt_test.cpp
+++ b/be/test/ai/ai_function_test.cpp
@@ -316,6 +316,33 @@ TEST(AIFunctionTest, AIFilterTest) {
     ASSERT_EQ(prompt, "This is a valid sentence.");
 }
 
+TEST(AIFunctionTest, AIFilterExecuteTest) {
+    auto runtime_state = std::make_unique<MockRuntimeState>();
+    auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});
+
+    std::vector<std::string> resources = {"mock_resource"};
+    std::vector<std::string> texts = {"This is a valid sentence."};
+    auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
+    auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
+
+    Block block;
+    block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), 
"resource"});
+    block.insert({std::move(col_text), std::make_shared<DataTypeString>(), 
"text"});
+    block.insert({nullptr, std::make_shared<DataTypeBool>(), "result"});
+
+    ColumnNumbers arguments = {0, 1};
+    size_t result_idx = 2;
+
+    auto filter_func = FunctionAIFilter::create();
+    Status exec_status =
+            filter_func->execute_impl(ctx.get(), block, arguments, result_idx, 
texts.size());
+
+    const auto& res_col =
+            assert_cast<const 
ColumnUInt8&>(*block.get_by_position(result_idx).column);
+    UInt8 val = res_col.get_data()[0];
+    ASSERT_TRUE(val == 0);
+}
+
 TEST(AIFunctionTest, ResourceNotFound) {
     auto runtime_state = std::make_unique<MockRuntimeState>();
     auto ctx = FunctionContext::create_context(runtime_state.get(), {}, {});


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to