This is an automated email from the ASF dual-hosted git repository.

alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git


The following commit(s) were added to refs/heads/develop by this push:
     new d83df84e fix(ai-gateway): Add support for content encoding in 
tokenizer filter (#706)
d83df84e is described below

commit d83df84ec32628938916063c096773c6df5a2cb9
Author: Xuetao Li <[email protected]>
AuthorDate: Tue Jul 29 20:09:34 2025 +0800

    fix(ai-gateway): Add support for content encoding in tokenizer filter (#706)
    
    * feat: Add support for content encoding in stream processing and update 
tests
    
    * feat: Enhance content encoding support for unary responses
    
    * go mod tidy
    
    * fix copilot
    
    * fix comment
---
 pkg/common/constant/http.go                |   4 +
 pkg/filter/llm/tokenizer/tokenizer.go      |  98 ++++++++++----
 pkg/filter/llm/tokenizer/tokenizer_test.go | 206 +++++++++++++++++++++++------
 3 files changed, 239 insertions(+), 69 deletions(-)

diff --git a/pkg/common/constant/http.go b/pkg/common/constant/http.go
index b76dfeb0..767d49c0 100644
--- a/pkg/common/constant/http.go
+++ b/pkg/common/constant/http.go
@@ -23,6 +23,7 @@ const (
        HeaderKeyConnection       = "Connection"
        HeaderKeyTransferEncoding = "Transfer-Encoding"
        HeaderKeyContentLength    = "Content-Length"
+       HeaderKeyContentEncoding  = "Content-Encoding"
 
        HeaderKeyAccessControlAllowOrigin      = "Access-Control-Allow-Origin"
        HeaderKeyAccessControlAllowHeaders     = "Access-Control-Allow-Headers"
@@ -41,6 +42,9 @@ const (
        HeaderValueChunked                = "chunked"
        HeaderValueTextPrefix             = "text/"
 
+       HeaderValueGzip    = "gzip"
+       HeaderValueDeflate = "deflate"
+
        HeaderValueKeepAlive = "keep-alive"
        HeaderValueNoCache   = "no-cache"
 
diff --git a/pkg/filter/llm/tokenizer/tokenizer.go 
b/pkg/filter/llm/tokenizer/tokenizer.go
index b9a9a122..487dd5d6 100644
--- a/pkg/filter/llm/tokenizer/tokenizer.go
+++ b/pkg/filter/llm/tokenizer/tokenizer.go
@@ -18,7 +18,9 @@
 package tokenizer
 
 import (
-       "bufio"
+       "bytes"
+       "compress/flate"
+       "compress/gzip"
        "encoding/json"
        "fmt"
        "io"
@@ -30,7 +32,7 @@ import (
        "github.com/apache/dubbo-go-pixiu/pkg/client"
        "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
        "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
-       "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+       contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
        "github.com/apache/dubbo-go-pixiu/pkg/logger"
 )
 
@@ -56,7 +58,7 @@ type (
        Filter struct {
                cfg *Config
        }
-       // Config describe the config of FilterFactory
+       // Config describes the config of FilterFactory
        Config struct {
        }
 )
@@ -77,7 +79,7 @@ func (factory *FilterFactory) Apply() error {
        return nil
 }
 
-func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain 
filter.FilterChain) error {
+func (factory *FilterFactory) PrepareFilterChain(ctx *contexthttp.HttpContext, 
chain filter.FilterChain) error {
        f := &Filter{
                cfg: factory.cfg,
        }
@@ -85,14 +87,16 @@ func (factory *FilterFactory) PrepareFilterChain(ctx 
*http.HttpContext, chain fi
        return nil
 }
 
-func (f *Filter) Encode(hc *http.HttpContext) filter.FilterStatus {
+func (f *Filter) Encode(hc *contexthttp.HttpContext) filter.FilterStatus {
+       encoding := hc.Writer.Header().Get(constant.HeaderKeyContentEncoding)
+
        switch res := hc.TargetResp.(type) {
        case *client.StreamResponse:
                pr, pw := io.Pipe()
                res.Stream = newTeeReadCloser(res.Stream, pw)
-               go f.processStreamResponse(pr)
+               go f.processStreamResponse(pr, encoding)
        case *client.UnaryResponse:
-               f.processUsageData(res.Data)
+               f.processUsageData(res.Data, encoding) // Unary response is not 
a stream
        default:
                logger.Warnf(LoggerFmt+"Response type not suitable for token 
calc: %T", res)
        }
@@ -100,34 +104,46 @@ func (f *Filter) Encode(hc *http.HttpContext) 
filter.FilterStatus {
        return filter.Continue
 }
 
-func (f *Filter) processStreamResponse(stream io.Reader) {
-       scanner := bufio.NewScanner(stream)
-       currentLine := make([]byte, 0, 1024)
-       // read the stream by line
-       // and process the data lines
-       // the data line is prefixed with "data:"
-       // the data line is a json string
-       // the for loop is to read the streamline by line and concat the 
separate "data:" lines
-       for scanner.Scan() {
-               line := scanner.Text()
-               line = strings.TrimSpace(line)
-               if strings.HasPrefix(line, "data:") {
-                       f.processUsageData(currentLine)
-                       currentLine = make([]byte, 0, 1024)
-                       line = strings.TrimPrefix(line, "data:")
+func (f *Filter) processStreamResponse(body io.Reader, encoding string) {
+       // For streams, we decompress the entire stream first, then process its 
content.
+       // The content itself (with "data:" prefixes) is passed to 
processUsageData.
+       decompressedData, ok := decompress(body, encoding)
+       if !ok {
+               return
+       }
+
+       decompressedDataTrim := strings.TrimPrefix(string(decompressedData), 
"data:")
+
+       // Now process the fully decompressed stream data
+       f.processUsageData([]byte(decompressedDataTrim), "")
+}
+
+func (f *Filter) processUsageData(data []byte, encoding string) {
+       processedData := data
+       // Decompress data if an encoding is specified (primarily for unary 
responses)
+       if encoding != "" {
+               bodyReader := bytes.NewReader(data)
+               if decompressedData, ok := decompress(bodyReader, encoding); ok 
{
+                       processedData = decompressedData
                }
-               currentLine = append(currentLine, line...)
        }
-       f.processUsageData(currentLine)
-       if err := scanner.Err(); err != nil && err != io.EOF {
-               logger.Errorf(LoggerFmt+"Error reading stream: %v", err)
+
+       if len(processedData) == 0 {
+               return
        }
+
+       f.parseAndLogUsage(processedData)
 }
 
-func (f *Filter) processUsageData(data []byte) {
+// parseAndLogUsage is a helper to parse the final JSON data and log it.
+func (f *Filter) parseAndLogUsage(data []byte) {
+       if len(data) == 0 {
+               return
+       }
        var dataCont map[string]any
        err := json.Unmarshal(data, &dataCont)
        if err != nil {
+               // Suppress unmarshal errors for potentially incomplete stream 
chunks
                return
        }
 
@@ -157,6 +173,34 @@ func (f *Filter) logUsage(usage map[string]any) {
        }
 }
 
+// getDecompressedReader returns an io.ReadCloser that decompresses the body 
based on the encoding.
+func getDecompressedReader(body io.Reader, encoding string) (io.ReadCloser, 
error) {
+       switch encoding {
+       case constant.HeaderValueGzip:
+               return gzip.NewReader(body)
+       case constant.HeaderValueDeflate:
+               return flate.NewReader(body), nil
+       default:
+               return io.NopCloser(body), nil
+       }
+}
+
+func decompress(body io.Reader, encoding string) ([]byte, bool) {
+       decompressedReader, err := getDecompressedReader(body, encoding)
+       if err != nil {
+               logger.Errorf(LoggerFmt+"%v", err)
+               return nil, false
+       }
+       defer decompressedReader.Close()
+
+       decompressedData, err := io.ReadAll(decompressedReader)
+       if err != nil {
+               logger.Errorf(LoggerFmt+"Error reading decompressed stream: 
%v", err)
+               return nil, false
+       }
+       return decompressedData, true
+}
+
 type teeReadCloser struct {
        reader   io.Reader
        closer   io.Closer
diff --git a/pkg/filter/llm/tokenizer/tokenizer_test.go 
b/pkg/filter/llm/tokenizer/tokenizer_test.go
index 7881099f..73200000 100644
--- a/pkg/filter/llm/tokenizer/tokenizer_test.go
+++ b/pkg/filter/llm/tokenizer/tokenizer_test.go
@@ -4,6 +4,7 @@
  * this work for additional information regarding copyright ownership.
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
+
  * the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
@@ -19,6 +20,8 @@ package tokenizer
 
 import (
        "bytes"
+       "compress/flate"
+       "compress/gzip"
        "io"
        "net/http"
        "strings"
@@ -32,55 +35,174 @@ import (
 
 import (
        "github.com/apache/dubbo-go-pixiu/pkg/client"
+       "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
        "github.com/apache/dubbo-go-pixiu/pkg/context/mock"
 )
 
-func TestUnaryResponse(t *testing.T) {
-       filter := &Filter{}
-
-       request, err := http.NewRequest("POST", 
"http://www.dubbogopixiu.com/mock/test?name=tc", 
bytes.NewReader([]byte("{\"id\":\"12345\"}")))
-       assert.NoError(t, err)
-       c := mock.GetMockHTTPContext(request)
-       c.TargetResp = &client.UnaryResponse{
-               Data: []byte(`{
-               "usage": {
-                       "prompt_tokens": 7,
-                       "completion_tokens": 32,
-                       "total_tokens": 39,
-                       "prompt_tokens_details": {
-                               "cached_tokens": 0
+// TestUnaryResponseWithEncodings is a table-driven test for unary 
(non-streaming) responses.
+// It covers multiple content encodings like gzip and deflate.
+func TestUnaryResponseWithEncodings(t *testing.T) {
+       // This is the payload we expect to process after decompression.
+       const payload = `{
+       "usage": {
+          "prompt_tokens": 7
+       }
+    }`
+
+       // Helper function to compress data with gzip for our test case.
+       compressGzipBytes := func(data string) []byte {
+               var buf bytes.Buffer
+               writer := gzip.NewWriter(&buf)
+               _, err := writer.Write([]byte(data))
+               assert.NoError(t, err)
+               err = writer.Close()
+               assert.NoError(t, err)
+               return buf.Bytes()
+       }
+
+       // Helper function to compress data with flate/deflate for our test 
case.
+       compressFlateBytes := func(data string) []byte {
+               var buf bytes.Buffer
+               writer, err := flate.NewWriter(&buf, -1)
+               assert.NoError(t, err)
+               _, err = writer.Write([]byte(data))
+               assert.NoError(t, err)
+               err = writer.Close()
+               assert.NoError(t, err)
+               return buf.Bytes()
+       }
+
+       // Define all test cases in a table.
+       testCases := []struct {
+               name     string
+               encoding string
+               getData  func(string) []byte
+       }{
+               {
+                       name:     "No Encoding",
+                       encoding: "",
+                       getData: func(s string) []byte {
+                               return []byte(s)
                        },
-                       "prompt_cache_hit_tokens": 0,
-                       "prompt_cache_miss_tokens": 7
-               }
-       }`)}
-       filter.Encode(c)
+               },
+               {
+                       name:     "Gzip Encoding",
+                       encoding: "gzip",
+                       getData:  compressGzipBytes,
+               },
+               {
+                       name:     "Flate Encoding",
+                       encoding: "deflate",
+                       getData:  compressFlateBytes,
+               },
+       }
+
+       // Run the tests for each case.
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       filter := &Filter{}
+
+                       request, err := http.NewRequest("POST", 
"http://www.dubbogopixiu.com/mock/test?name=tc", 
bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+                       assert.NoError(t, err)
+                       c := mock.GetMockHTTPContext(request)
+
+                       // Prepare the (potentially) compressed data
+                       compressedData := tc.getData(payload)
+
+                       c.TargetResp = &client.UnaryResponse{
+                               Data: compressedData,
+                       }
+                       c.AddHeader(constant.HeaderKeyContentEncoding, 
tc.encoding)
+
+                       // Call the filter's Encode method
+                       filter.Encode(c)
+               })
+       }
 }
 
-func TestStreamResponse(t *testing.T) {
-       filter := &Filter{}
-
-       request, err := http.NewRequest("POST", 
"http://www.dubbogopixiu.com/mock/test?name=tc", 
bytes.NewReader([]byte("{\"id\":\"12345\"}")))
-       assert.NoError(t, err)
-       c := mock.GetMockHTTPContext(request)
-       s := io.NopCloser(strings.NewReader(`data: {
-               "usage": {
-                       "prompt_tokens": 7,
-                       "completion_tokens": 32,
-                       "total_tokens": 39,
-                       "prompt_tokens_details": {
-                               "cached_tokens": 0
+// TestStreamResponseWithEncodings is a table-driven test for streaming 
responses.
+// It replaces the old TestStreamResponse.
+func TestStreamResponseWithEncodings(t *testing.T) {
+       // This is the payload we expect to process after decompression.
+       const payload = `data: {
+       "usage": {
+          "prompt_tokens": 7
+       }
+    }`
+
+       // Helper function to compress data with gzip for our test case.
+       compressGzip := func(data string) io.Reader {
+               var buf bytes.Buffer
+               writer := gzip.NewWriter(&buf)
+               _, err := writer.Write([]byte(data))
+               assert.NoError(t, err)
+               err = writer.Close() // IMPORTANT: Close flushes the writer.
+               assert.NoError(t, err)
+               return &buf
+       }
+
+       compressFlate := func(data string) io.Reader {
+               var buf bytes.Buffer
+               writer, _ := flate.NewWriter(&buf, -1)
+               _, err := writer.Write([]byte(data))
+               assert.NoError(t, err)
+               err = writer.Close()
+               assert.NoError(t, err)
+               return &buf
+       }
+
+       // Define all test cases in a table.
+       testCases := []struct {
+               name      string
+               encoding  string
+               getStream func(string) io.Reader
+       }{
+               {
+                       name:     "No Encoding",
+                       encoding: "",
+                       getStream: func(s string) io.Reader {
+                               return strings.NewReader(s)
                        },
-                       "prompt_cache_hit_tokens": 0,
-                       "prompt_cache_miss_tokens": 7
-               }
+               },
+               {
+                       name:      "Gzip Encoding",
+                       encoding:  "gzip",
+                       getStream: compressGzip,
+               },
+               {
+                       name:      "Flate Encoding",
+                       encoding:  "deflate",
+                       getStream: compressFlate,
+               },
        }
 
-`))
-       c.TargetResp = &client.StreamResponse{Stream: s}
-       filter.Encode(c)
-       buf := make([]byte, 1024)
-       c.TargetResp.(*client.StreamResponse).Stream.Read(buf)
-       time.Sleep(3 * time.Millisecond)
-       c.TargetResp.(*client.StreamResponse).Stream.Close()
+       // Run the tests for each case.
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       filter := &Filter{}
+
+                       req, err := http.NewRequest("POST", 
"http://www.dubbogopixiu.com/mock/test?name=tc", 
bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+                       assert.NoError(t, err)
+                       ctx := mock.GetMockHTTPContext(req)
+
+                       // Prepare the compressed stream and the response header
+                       compressedStream := tc.getStream(payload)
+
+                       // Set up the mock response
+                       ctx.TargetResp = &client.StreamResponse{
+                               Stream: io.NopCloser(compressedStream),
+                       }
+
+                       ctx.AddHeader(constant.HeaderKeyContentEncoding, 
tc.encoding)
+
+                       // Call the filter's Encode method
+                       filter.Encode(ctx)
+
+                       // Give the goroutine a moment to process the data
+                       buf := make([]byte, 1024)
+                       ctx.TargetResp.(*client.StreamResponse).Stream.Read(buf)
+                       time.Sleep(5 * time.Millisecond)
+                       ctx.TargetResp.(*client.StreamResponse).Stream.Close()
+               })
+       }
 }

Reply via email to