This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git
The following commit(s) were added to refs/heads/develop by this push:
new d83df84e fix(ai-gateway): Add support for content encoding in tokenizer filter (#706)
d83df84e is described below
commit d83df84ec32628938916063c096773c6df5a2cb9
Author: Xuetao Li <[email protected]>
AuthorDate: Tue Jul 29 20:09:34 2025 +0800
fix(ai-gateway): Add support for content encoding in tokenizer filter (#706)
* feat: Add support for content encoding in stream processing and update
tests
* feat: Enhance content encoding support for unary responses
* go mod tidy
* fix copilot
* fix comment
---
pkg/common/constant/http.go | 4 +
pkg/filter/llm/tokenizer/tokenizer.go | 98 ++++++++++----
pkg/filter/llm/tokenizer/tokenizer_test.go | 206 +++++++++++++++++++++++------
3 files changed, 239 insertions(+), 69 deletions(-)
diff --git a/pkg/common/constant/http.go b/pkg/common/constant/http.go
index b76dfeb0..767d49c0 100644
--- a/pkg/common/constant/http.go
+++ b/pkg/common/constant/http.go
@@ -23,6 +23,7 @@ const (
HeaderKeyConnection = "Connection"
HeaderKeyTransferEncoding = "Transfer-Encoding"
HeaderKeyContentLength = "Content-Length"
+ HeaderKeyContentEncoding = "Content-Encoding"
HeaderKeyAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HeaderKeyAccessControlAllowHeaders = "Access-Control-Allow-Headers"
@@ -41,6 +42,9 @@ const (
HeaderValueChunked = "chunked"
HeaderValueTextPrefix = "text/"
+ HeaderValueGzip = "gzip"
+ HeaderValueDeflate = "deflate"
+
HeaderValueKeepAlive = "keep-alive"
HeaderValueNoCache = "no-cache"
diff --git a/pkg/filter/llm/tokenizer/tokenizer.go b/pkg/filter/llm/tokenizer/tokenizer.go
index b9a9a122..487dd5d6 100644
--- a/pkg/filter/llm/tokenizer/tokenizer.go
+++ b/pkg/filter/llm/tokenizer/tokenizer.go
@@ -18,7 +18,9 @@
package tokenizer
import (
- "bufio"
+ "bytes"
+ "compress/flate"
+ "compress/gzip"
"encoding/json"
"fmt"
"io"
@@ -30,7 +32,7 @@ import (
"github.com/apache/dubbo-go-pixiu/pkg/client"
"github.com/apache/dubbo-go-pixiu/pkg/common/constant"
"github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
- "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+ contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
"github.com/apache/dubbo-go-pixiu/pkg/logger"
)
@@ -56,7 +58,7 @@ type (
Filter struct {
cfg *Config
}
- // Config describe the config of FilterFactory
+ // Config describes the config of FilterFactory
Config struct {
}
)
@@ -77,7 +79,7 @@ func (factory *FilterFactory) Apply() error {
return nil
}
-func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain filter.FilterChain) error {
+func (factory *FilterFactory) PrepareFilterChain(ctx *contexthttp.HttpContext, chain filter.FilterChain) error {
f := &Filter{
cfg: factory.cfg,
}
@@ -85,14 +87,16 @@ func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain fi
return nil
}
-func (f *Filter) Encode(hc *http.HttpContext) filter.FilterStatus {
+func (f *Filter) Encode(hc *contexthttp.HttpContext) filter.FilterStatus {
+ encoding := hc.Writer.Header().Get(constant.HeaderKeyContentEncoding)
+
switch res := hc.TargetResp.(type) {
case *client.StreamResponse:
pr, pw := io.Pipe()
res.Stream = newTeeReadCloser(res.Stream, pw)
- go f.processStreamResponse(pr)
+ go f.processStreamResponse(pr, encoding)
case *client.UnaryResponse:
- f.processUsageData(res.Data)
+        f.processUsageData(res.Data, encoding) // Unary response is not a stream
default:
        logger.Warnf(LoggerFmt+"Response type not suitable for token calc: %T", res)
}
@@ -100,34 +104,46 @@ func (f *Filter) Encode(hc *http.HttpContext) filter.FilterStatus {
return filter.Continue
}
-func (f *Filter) processStreamResponse(stream io.Reader) {
- scanner := bufio.NewScanner(stream)
- currentLine := make([]byte, 0, 1024)
- // read the stream by line
- // and process the data lines
- // the data line is prefixed with "data:"
- // the data line is a json string
-    // the for loop is to read the streamline by line and concat the separate "data:" lines
- for scanner.Scan() {
- line := scanner.Text()
- line = strings.TrimSpace(line)
- if strings.HasPrefix(line, "data:") {
- f.processUsageData(currentLine)
- currentLine = make([]byte, 0, 1024)
- line = strings.TrimPrefix(line, "data:")
+func (f *Filter) processStreamResponse(body io.Reader, encoding string) {
+    // For streams, we decompress the entire stream first, then process its content.
+    // The content itself (with "data:" prefixes) is passed to processUsageData.
+ decompressedData, ok := decompress(body, encoding)
+ if !ok {
+ return
+ }
+
+    decompressedDataTrim := strings.TrimPrefix(string(decompressedData), "data:")
+
+ // Now process the fully decompressed stream data
+ f.processUsageData([]byte(decompressedDataTrim), "")
+}
+
+func (f *Filter) processUsageData(data []byte, encoding string) {
+ processedData := data
+    // Decompress data if an encoding is specified (primarily for unary responses)
+ if encoding != "" {
+ bodyReader := bytes.NewReader(data)
+        if decompressedData, ok := decompress(bodyReader, encoding); ok {
+ processedData = decompressedData
}
- currentLine = append(currentLine, line...)
}
- f.processUsageData(currentLine)
- if err := scanner.Err(); err != nil && err != io.EOF {
- logger.Errorf(LoggerFmt+"Error reading stream: %v", err)
+
+ if len(processedData) == 0 {
+ return
}
+
+ f.parseAndLogUsage(processedData)
}
-func (f *Filter) processUsageData(data []byte) {
+// parseAndLogUsage is a helper to parse the final JSON data and log it.
+func (f *Filter) parseAndLogUsage(data []byte) {
+ if len(data) == 0 {
+ return
+ }
var dataCont map[string]any
err := json.Unmarshal(data, &dataCont)
if err != nil {
+        // Suppress unmarshal errors for potentially incomplete stream chunks
return
}
@@ -157,6 +173,34 @@ func (f *Filter) logUsage(usage map[string]any) {
}
}
+// getDecompressedReader returns an io.ReadCloser that decompresses the body based on the encoding.
+func getDecompressedReader(body io.Reader, encoding string) (io.ReadCloser, error) {
+ switch encoding {
+ case constant.HeaderValueGzip:
+ return gzip.NewReader(body)
+ case constant.HeaderValueDeflate:
+ return flate.NewReader(body), nil
+ default:
+ return io.NopCloser(body), nil
+ }
+}
+
+func decompress(body io.Reader, encoding string) ([]byte, bool) {
+ decompressedReader, err := getDecompressedReader(body, encoding)
+ if err != nil {
+ logger.Errorf(LoggerFmt+"%v", err)
+ return nil, false
+ }
+ defer decompressedReader.Close()
+
+ decompressedData, err := io.ReadAll(decompressedReader)
+ if err != nil {
+        logger.Errorf(LoggerFmt+"Error reading decompressed stream: %v", err)
+ return nil, false
+ }
+ return decompressedData, true
+}
+
type teeReadCloser struct {
reader io.Reader
closer io.Closer
diff --git a/pkg/filter/llm/tokenizer/tokenizer_test.go b/pkg/filter/llm/tokenizer/tokenizer_test.go
index 7881099f..73200000 100644
--- a/pkg/filter/llm/tokenizer/tokenizer_test.go
+++ b/pkg/filter/llm/tokenizer/tokenizer_test.go
@@ -4,6 +4,7 @@
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
+
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
@@ -19,6 +20,8 @@ package tokenizer
import (
"bytes"
+ "compress/flate"
+ "compress/gzip"
"io"
"net/http"
"strings"
@@ -32,55 +35,174 @@ import (
import (
"github.com/apache/dubbo-go-pixiu/pkg/client"
+ "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
"github.com/apache/dubbo-go-pixiu/pkg/context/mock"
)
-func TestUnaryResponse(t *testing.T) {
- filter := &Filter{}
-
-    request, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
- assert.NoError(t, err)
- c := mock.GetMockHTTPContext(request)
- c.TargetResp = &client.UnaryResponse{
- Data: []byte(`{
- "usage": {
- "prompt_tokens": 7,
- "completion_tokens": 32,
- "total_tokens": 39,
- "prompt_tokens_details": {
- "cached_tokens": 0
+// TestUnaryResponseWithEncodings is a table-driven test for unary (non-streaming) responses.
+// It covers multiple content encodings like gzip and deflate.
+func TestUnaryResponseWithEncodings(t *testing.T) {
+ // This is the payload we expect to process after decompression.
+ const payload = `{
+ "usage": {
+ "prompt_tokens": 7
+ }
+ }`
+
+ // Helper function to compress data with gzip for our test case.
+ compressGzipBytes := func(data string) []byte {
+ var buf bytes.Buffer
+ writer := gzip.NewWriter(&buf)
+ _, err := writer.Write([]byte(data))
+ assert.NoError(t, err)
+ err = writer.Close()
+ assert.NoError(t, err)
+ return buf.Bytes()
+ }
+
+    // Helper function to compress data with flate/deflate for our test case.
+ compressFlateBytes := func(data string) []byte {
+ var buf bytes.Buffer
+ writer, err := flate.NewWriter(&buf, -1)
+ assert.NoError(t, err)
+ _, err = writer.Write([]byte(data))
+ assert.NoError(t, err)
+ err = writer.Close()
+ assert.NoError(t, err)
+ return buf.Bytes()
+ }
+
+ // Define all test cases in a table.
+ testCases := []struct {
+ name string
+ encoding string
+ getData func(string) []byte
+ }{
+ {
+ name: "No Encoding",
+ encoding: "",
+ getData: func(s string) []byte {
+ return []byte(s)
},
- "prompt_cache_hit_tokens": 0,
- "prompt_cache_miss_tokens": 7
- }
- }`)}
- filter.Encode(c)
+ },
+ {
+ name: "Gzip Encoding",
+ encoding: "gzip",
+ getData: compressGzipBytes,
+ },
+ {
+ name: "Flate Encoding",
+ encoding: "deflate",
+ getData: compressFlateBytes,
+ },
+ }
+
+ // Run the tests for each case.
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ filter := &Filter{}
+
+            request, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+ assert.NoError(t, err)
+ c := mock.GetMockHTTPContext(request)
+
+ // Prepare the (potentially) compressed data
+ compressedData := tc.getData(payload)
+
+ c.TargetResp = &client.UnaryResponse{
+ Data: compressedData,
+ }
+            c.AddHeader(constant.HeaderKeyContentEncoding, tc.encoding)
+
+ // Call the filter's Encode method
+ filter.Encode(c)
+ })
+ }
}
-func TestStreamResponse(t *testing.T) {
- filter := &Filter{}
-
-    request, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
- assert.NoError(t, err)
- c := mock.GetMockHTTPContext(request)
- s := io.NopCloser(strings.NewReader(`data: {
- "usage": {
- "prompt_tokens": 7,
- "completion_tokens": 32,
- "total_tokens": 39,
- "prompt_tokens_details": {
- "cached_tokens": 0
+// TestStreamResponseWithEncodings is a table-driven test for streaming responses.
+// It replaces the old TestStreamResponse.
+func TestStreamResponseWithEncodings(t *testing.T) {
+ // This is the payload we expect to process after decompression.
+ const payload = `data: {
+ "usage": {
+ "prompt_tokens": 7
+ }
+ }`
+
+ // Helper function to compress data with gzip for our test case.
+ compressGzip := func(data string) io.Reader {
+ var buf bytes.Buffer
+ writer := gzip.NewWriter(&buf)
+ _, err := writer.Write([]byte(data))
+ assert.NoError(t, err)
+ err = writer.Close() // IMPORTANT: Close flushes the writer.
+ assert.NoError(t, err)
+ return &buf
+ }
+
+ compressFlate := func(data string) io.Reader {
+ var buf bytes.Buffer
+ writer, _ := flate.NewWriter(&buf, -1)
+ _, err := writer.Write([]byte(data))
+ assert.NoError(t, err)
+ err = writer.Close()
+ assert.NoError(t, err)
+ return &buf
+ }
+
+ // Define all test cases in a table.
+ testCases := []struct {
+ name string
+ encoding string
+ getStream func(string) io.Reader
+ }{
+ {
+ name: "No Encoding",
+ encoding: "",
+ getStream: func(s string) io.Reader {
+ return strings.NewReader(s)
},
- "prompt_cache_hit_tokens": 0,
- "prompt_cache_miss_tokens": 7
- }
+ },
+ {
+ name: "Gzip Encoding",
+ encoding: "gzip",
+ getStream: compressGzip,
+ },
+ {
+ name: "Flate Encoding",
+ encoding: "deflate",
+ getStream: compressFlate,
+ },
}
-`))
- c.TargetResp = &client.StreamResponse{Stream: s}
- filter.Encode(c)
- buf := make([]byte, 1024)
- c.TargetResp.(*client.StreamResponse).Stream.Read(buf)
- time.Sleep(3 * time.Millisecond)
- c.TargetResp.(*client.StreamResponse).Stream.Close()
+ // Run the tests for each case.
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ filter := &Filter{}
+
+            req, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+ assert.NoError(t, err)
+ ctx := mock.GetMockHTTPContext(req)
+
+ // Prepare the compressed stream and the response header
+ compressedStream := tc.getStream(payload)
+
+ // Set up the mock response
+ ctx.TargetResp = &client.StreamResponse{
+ Stream: io.NopCloser(compressedStream),
+ }
+
+            ctx.AddHeader(constant.HeaderKeyContentEncoding, tc.encoding)
+
+ // Call the filter's Encode method
+ filter.Encode(ctx)
+
+ // Give the goroutine a moment to process the data
+ buf := make([]byte, 1024)
+ ctx.TargetResp.(*client.StreamResponse).Stream.Read(buf)
+ time.Sleep(5 * time.Millisecond)
+ ctx.TargetResp.(*client.StreamResponse).Stream.Close()
+ })
+ }
}