linrrzqqq commented on code in PR #55017:
URL: https://github.com/apache/doris/pull/55017#discussion_r2298722873


##########
be/src/vec/functions/llm/llm_adapter.h:
##########
@@ -50,18 +52,150 @@ class LLMAdapter {
     virtual Status parse_response(const std::string& response_body,
                                   std::vector<std::string>& results) const = 0;
 
-    // Get the adapter type identifier
-    virtual std::string get_type() const = 0;
+    virtual Status build_embedding_request(const std::vector<std::string>& 
inputs,
+                                           std::string& request_body) const {
+        return Status::NotSupported("{} does not support the Embed feature.",
+                                    _config.provider_type);
+    }
+
+    virtual Status parse_embedding_response(const std::string& response_body,
+                                            std::vector<std::vector<float>>& 
results) const {
+        return Status::NotSupported("{} does not support the Embed feature.",
+                                    _config.provider_type);
+    }
 
 protected:
     TLLMResource _config;
+
+    // return true if the model support dimension parameter
+    virtual bool supports_dimension_param(const std::string& model_name) const 
{ return false; }
+
+    // Different providers may have different dimension parameter names.
+    virtual std::string get_dimension_param_name() const { return 
"dimensions"; }
+
+    virtual void add_dimension_params(rapidjson::Value& doc,
+                                      rapidjson::Document::AllocatorType& 
allocator) const {
+        if (_config.dimensions != -1 && 
supports_dimension_param(_config.model_name)) {
+            std::string param_name = get_dimension_param_name();
+            rapidjson::Value name(param_name.c_str(), allocator);
+            doc.AddMember(name, _config.dimensions, allocator);
+        }
+    }
+};
+
+// Most LLM-providers' Embedding formats are based on VoyageAI.
+// The following adapters inherit from VoyageAIAdapter to directly reuse its 
embedding logic.
+class VoyageAIAdapter : public LLMAdapter {
+public:
+    Status set_authentication(HttpClient* client) const override {
+        client->set_header(HttpHeaders::AUTHORIZATION, "Bearer " + 
_config.api_key);
+        client->set_content_type("application/json");
+
+        return Status::OK();
+    }
+
+    Status build_request_payload(const std::vector<std::string>& inputs,
+                                 const char* const system_prompt,
+                                 std::string& request_body) const override {
+        return Status::NotSupported("VoyageAI only support embedding 
function");
+    }
+
+    Status parse_response(const std::string& response_body,
+                          std::vector<std::string>& results) const override {
+        return Status::NotSupported("VoyageAI only support embedding 
function");
+    }
+
+    Status build_embedding_request(const std::vector<std::string>& inputs,
+                                   std::string& request_body) const override {
+        rapidjson::Document doc;
+        doc.SetObject();
+        auto& allocator = doc.GetAllocator();
+
+        /*{
+            "model": "xxx",
+            "input": [
+              "xxx",
+              "xxx",
+              ...
+            ],
+            "output_dimensions": 512
+        }*/
+        doc.AddMember("model", rapidjson::Value(_config.model_name.c_str(), 
allocator), allocator);
+        add_dimension_params(doc, allocator);
+
+        rapidjson::Value input(rapidjson::kArrayType);
+        for (const auto& msg : inputs) {
+            input.PushBack(rapidjson::Value(msg.c_str(), allocator), 
allocator);
+        }
+        doc.AddMember("input", input, allocator);
+
+        rapidjson::StringBuffer buffer;
+        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+        doc.Accept(writer);
+        request_body = buffer.GetString();
+
+        return Status::OK();
+    }
+
+    Status parse_embedding_response(const std::string& response_body,
+                                    std::vector<std::vector<float>>& results) 
const override {
+        rapidjson::Document doc;
+        doc.Parse(response_body.c_str());
+
+        if (doc.HasParseError() || !doc.IsObject()) {
+            return Status::InternalError("Failed to parse {} response: {}", 
_config.provider_type,
+                                         response_body);
+        }
+        if (!doc.HasMember("data") || !doc["data"].IsArray()) {
+            return Status::InternalError("Invalid {} response format: {}", 
_config.provider_type,
+                                         response_body);
+        }
+
+        /*{
+            "data":[
+              {
+                "object": "embedding",
+                "embedding": [...], <- only need this
+                "index": 0
+              },
+              {
+                "object": "embedding",
+                "embedding": [...],
+                "index": 1
+              }, ...
+            ],
+            "model"....
+        }*/
+        const auto& data = doc["data"];
+        results.reserve(data.Size());
+        for (rapidjson::SizeType i = 0; i < data.Size(); i++) {
+            if (!data[i].HasMember("embedding") || 
!data[i]["embedding"].IsArray()) {
+                return Status::InternalError("Invalid {} response format: {}",

Review Comment:
   The request is currently sent one row at a time, so nothing is wasted here for now. 
I will address this before batch requests are supported.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to