This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new fd99530a4 feat(go): add configurable size guardrails (#3475)
fd99530a4 is described below
commit fd99530a42b71813cdb7264a0b32aaedadd392de
Author: Ayush Kumar <[email protected]>
AuthorDate: Thu Mar 19 08:35:17 2026 +0530
feat(go): add configurable size guardrails (#3475)
## Why?
Go deserialization didn't have any configurable guardrails for untrusted
payloads, which can lead to high memory pressure during allocation and to
out-of-memory attacks.
## What does this PR do?
Added two configurable guardrails, `MaxCollectionSize` and
`MaxBinarySize`, and implemented size guardrails across the Fory Go runtime
and codegen.
1. Configuration: Added `MaxCollectionSize` and `MaxBinarySize` to
`Config` struct with corresponding options `WithMaxCollectionSize` and
`WithMaxBinarySize`.
2. Added `ReadCollectionLength` and `ReadBinaryLength` to `ByteBuffer`
and `ReadContext`.
3. These methods enforce configured limits and return specialized
errors: `ErrKindMaxCollectionSizeExceeded` and
`ErrKindMaxBinarySizeExceeded`.
4. Updated the Go code generator (`codegen/decoder.go`) to use these
guarded length methods in generated serializers.
5. Removed the generic `ReadLength` method from `ReadContext` to ensure all
length-reading paths in the codebase are subject to guardrails.
## Related issues
Closes #3419
## AI Contribution Checklist
- [x] Substantial AI assistance was used in this PR: `yes`
- [x] If `yes`, I included a completed [AI Contribution
Checklist](https://github.com/apache/fory/blob/main/AI_POLICY.md#9-contributor-checklist-for-ai-assisted-prs)
in this PR description and the required `AI Usage Disclosure`.
- [x] If `yes`, I included the standardized `AI Usage Disclosure` block
below.
- [x] If `yes`, I can explain and defend all important changes without
AI help.
- [x] If `yes`, I reviewed AI-assisted code changes line by line before
submission.
- [x] If `yes`, I ran adequate human verification and recorded evidence
(checks run locally or in CI, pass/fail summary, and confirmation I
reviewed results).
- [x] If `yes`, I added/updated tests and specs where required.
- [x] If `yes`, I validated protocol/performance impacts with evidence
when applicable.
- [x] If `yes`, I verified licensing and provenance compliance.
```text
AI Usage Disclosure
I used AI to find and replace the multiple occurrences of `ReadLength` with
the specific `ReadCollectionLength` / `ReadBinaryLength` across the Go runtime.
I also used it to fix some errors that came up while running tests.
I can still explain all of my work, as I tested everything myself.
```
## Does this PR introduce any user-facing change?
- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
---
go/fory/array.go | 3 +-
go/fory/array_primitive.go | 26 ++++++------
go/fory/codegen/decoder.go | 20 ++++-----
go/fory/errors.go | 24 +++++++++++
go/fory/fory.go | 34 ++++++++++++---
go/fory/limit_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++
go/fory/map.go | 2 +-
go/fory/map_primitive.go | 70 ++++++++++++++++++------------
go/fory/reader.go | 80 +++++++++++++++++++++++------------
go/fory/set.go | 2 +-
go/fory/skip.go | 14 +++---
go/fory/slice.go | 2 +-
go/fory/slice_dyn.go | 2 +-
go/fory/slice_primitive.go | 10 ++---
14 files changed, 290 insertions(+), 102 deletions(-)
diff --git a/go/fory/array.go b/go/fory/array.go
index 9b8b9c17d..6cb1751f3 100644
--- a/go/fory/array.go
+++ b/go/fory/array.go
@@ -365,8 +365,7 @@ func (s byteArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeType
func (s byteArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
- err := ctx.Err()
- length := buf.ReadLength(err)
+ length := ctx.ReadCollectionLength()
if ctx.HasError() {
return
}
diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go
index 27813060b..06c76dc78 100644
--- a/go/fory/array_primitive.go
+++ b/go/fory/array_primitive.go
@@ -66,7 +66,7 @@ func (s boolArraySerializer) Write(ctx *WriteContext, refMode
RefMode, writeType
func (s boolArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- length := buf.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
@@ -131,7 +131,7 @@ func (s int8ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeType
func (s int8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- length := buf.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
@@ -197,7 +197,7 @@ func (s int16ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTyp
func (s int16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
@@ -269,7 +269,7 @@ func (s int32ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTyp
func (s int32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 4
if ctx.HasError() {
return
@@ -341,7 +341,7 @@ func (s int64ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTyp
func (s int64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 8
if ctx.HasError() {
return
@@ -413,7 +413,7 @@ func (s float32ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeT
func (s float32ArraySerializer) ReadData(ctx *ReadContext, value
reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 4
if ctx.HasError() {
return
@@ -485,7 +485,7 @@ func (s float64ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeT
func (s float64ArraySerializer) ReadData(ctx *ReadContext, value
reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 8
if ctx.HasError() {
return
@@ -556,7 +556,7 @@ func (s uint8ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTyp
func (s uint8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
- length := buf.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
@@ -623,7 +623,7 @@ func (s uint16ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTy
func (s uint16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value)
{
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
@@ -694,7 +694,7 @@ func (s uint32ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTy
func (s uint32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value)
{
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 4
if ctx.HasError() {
return
@@ -764,7 +764,7 @@ func (s uint64ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeTy
func (s uint64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value)
{
buf := ctx.Buffer()
err := ctx.Err()
- size := buf.ReadLength(err)
+ size := ctx.ReadBinaryLength()
length := size / 8
if ctx.HasError() {
return
@@ -838,7 +838,7 @@ func (s float16ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, writeT
func (s float16ArraySerializer) ReadData(ctx *ReadContext, value
reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- size := buf.ReadLength(ctxErr)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
@@ -912,7 +912,7 @@ func (s bfloat16ArraySerializer) Write(ctx *WriteContext,
refMode RefMode, write
func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value
reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- size := buf.ReadLength(ctxErr)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go
index d3713a243..8e580b053 100644
--- a/go/fory/codegen/decoder.go
+++ b/go/fory/codegen/decoder.go
@@ -168,7 +168,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field
*FieldInfo) error {
fmt.Fprintf(buf, "\t\tisXlang :=
ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not
nullable, read directly without null flag\n")
- fmt.Fprintf(buf, "\t\t\tsliceLen :=
int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\tsliceLen :=
ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n",
fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
@@ -187,7 +187,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field
*FieldInfo) error {
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") //
NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
- fmt.Fprintf(buf, "\t\t\t\tsliceLen :=
int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\t\tsliceLen :=
ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n",
fieldAccess)
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
@@ -517,7 +517,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType
*types.Slice, fieldAcc
fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read
directly without null flag\n")
- fmt.Fprintf(buf, "\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess,
sliceType.String())
fmt.Fprintf(buf, "\t\t\t} else {\n")
@@ -532,7 +532,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType
*types.Slice, fieldAcc
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
- fmt.Fprintf(buf, "\t\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\t\tsliceLen := ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess,
sliceType.String())
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
@@ -555,7 +555,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer,
sliceType *types.Slice, fi
unwrappedElem := types.Unalias(elemType)
if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() {
fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null
flag\n", indent)
- fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n",
indent)
+ fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n",
indent)
fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent,
fieldAccess)
fmt.Fprintf(buf, "%s} else {\n", indent)
@@ -573,7 +573,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer,
sliceType *types.Slice, fi
}
elemIsReferencable := isReferencableType(elemType)
- fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n", indent)
+ fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n", indent)
fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess,
sliceType.String())
fmt.Fprintf(buf, "%s} else {\n", indent)
@@ -703,7 +703,7 @@ func writePrimitiveSliceReadCall(buf *bytes.Buffer, basic
*types.Basic, fieldAcc
case types.Int8:
fmt.Fprintf(buf, "%s%s = fory.ReadInt8Slice(buf, err)\n",
indent, fieldAccess)
case types.Uint8:
- fmt.Fprintf(buf, "%ssizeBytes := buf.ReadLength(err)\n", indent)
+ fmt.Fprintf(buf, "%ssizeBytes := ctx.ReadBinaryLength()\n",
indent)
fmt.Fprintf(buf, "%s%s = make([]uint8, sizeBytes)\n", indent,
fieldAccess)
fmt.Fprintf(buf, "%sif sizeBytes > 0 {\n", indent)
fmt.Fprintf(buf, "%s\traw := buf.ReadBinary(sizeBytes, err)\n",
indent)
@@ -925,7 +925,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType
*types.Map, fieldAccess st
fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: maps are not nullable, read
directly without null flag\n")
- fmt.Fprintf(buf, "\t\t\tmapLen := int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\tmapLen := ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess,
mapType.String())
fmt.Fprintf(buf, "\t\t\t} else {\n")
@@ -940,7 +940,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType
*types.Map, fieldAccess st
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
- fmt.Fprintf(buf, "\t\t\t\tmapLen := int(buf.ReadVarUint32(err))\n")
+ fmt.Fprintf(buf, "\t\t\t\tmapLen := ctx.ReadCollectionLength()\n")
fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess,
mapType.String())
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
@@ -972,7 +972,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType
*types.Map, fieldAcc
}
indent := "\t\t\t"
- fmt.Fprintf(buf, "%smapLen := int(buf.ReadVarUint32(err))\n", indent)
+ fmt.Fprintf(buf, "%smapLen := ctx.ReadCollectionLength()\n", indent)
fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess,
mapType.String())
fmt.Fprintf(buf, "%s} else {\n", indent)
diff --git a/go/fory/errors.go b/go/fory/errors.go
index 25f7dd087..6dc092bf2 100644
--- a/go/fory/errors.go
+++ b/go/fory/errors.go
@@ -52,6 +52,10 @@ const (
ErrKindInvalidTag
// ErrKindInvalidUTF16String indicates malformed UTF-16 string data
ErrKindInvalidUTF16String
+ // ErrKindMaxCollectionSizeExceeded indicates max collection size
exceeded
+ ErrKindMaxCollectionSizeExceeded
+ // ErrKindMaxBinarySizeExceeded indicates max binary size exceeded
+ ErrKindMaxBinarySizeExceeded
)
// Error is a lightweight error type optimized for hot path performance.
@@ -296,6 +300,26 @@ func InvalidUTF16StringError(byteCount int) Error {
})
}
+// MaxCollectionSizeExceededError creates a max collection size exceeded error
+//
+//go:noinline
+func MaxCollectionSizeExceededError(size, limit int) Error {
+ return panicIfEnabled(Error{
+ kind: ErrKindMaxCollectionSizeExceeded,
+ message: fmt.Sprintf("max collection size exceeded: size=%d,
limit=%d", size, limit),
+ })
+}
+
+// MaxBinarySizeExceededError creates a max binary size exceeded error
+//
+//go:noinline
+func MaxBinarySizeExceededError(size, limit int) Error {
+ return panicIfEnabled(Error{
+ kind: ErrKindMaxBinarySizeExceeded,
+ message: fmt.Sprintf("max binary size exceeded: size=%d,
limit=%d", size, limit),
+ })
+}
+
// WrapError wraps a standard error into a fory Error
//
//go:noinline
diff --git a/go/fory/fory.go b/go/fory/fory.go
index 342b0acc3..57da20c57 100644
--- a/go/fory/fory.go
+++ b/go/fory/fory.go
@@ -50,18 +50,22 @@ const (
// Config holds configuration options for Fory instances
type Config struct {
- TrackRef bool
- MaxDepth int
- IsXlang bool
- Compatible bool // Schema evolution compatibility mode
+ TrackRef bool
+ MaxDepth int
+ IsXlang bool
+ Compatible bool // Schema evolution compatibility mode
+ MaxCollectionSize int
+ MaxBinarySize int
}
// defaultConfig returns the default configuration
func defaultConfig() Config {
return Config{
- TrackRef: false, // Match Java's default: reference tracking
disabled
- MaxDepth: 20,
- IsXlang: false,
+ TrackRef: false, // Match Java's default: reference
tracking disabled
+ MaxDepth: 20,
+ IsXlang: false,
+ MaxCollectionSize: 1_000_000,
+ MaxBinarySize: 64 * 1024 * 1024,
}
}
@@ -101,6 +105,20 @@ func WithCompatible(enabled bool) Option {
}
}
+// WithMaxCollectionSize sets the maximum collection size limit
+func WithMaxCollectionSize(size int) Option {
+ return func(f *Fory) {
+ f.config.MaxCollectionSize = size
+ }
+}
+
+// WithMaxBinarySize sets the maximum binary size limit
+func WithMaxBinarySize(size int) Option {
+ return func(f *Fory) {
+ f.config.MaxBinarySize = size
+ }
+}
+
// ============================================================================
// Fory - Main serialization instance
// ============================================================================
@@ -152,6 +170,8 @@ func New(opts ...Option) *Fory {
f.writeCtx.xlang = f.config.IsXlang
f.readCtx = NewReadContext(f.config.TrackRef)
+ f.readCtx.maxCollectionSize = f.config.MaxCollectionSize
+ f.readCtx.maxBinarySize = f.config.MaxBinarySize
f.readCtx.typeResolver = f.typeResolver
f.readCtx.refResolver = f.refResolver
f.readCtx.compatible = f.config.Compatible
diff --git a/go/fory/limit_test.go b/go/fory/limit_test.go
new file mode 100644
index 000000000..64d32f7a3
--- /dev/null
+++ b/go/fory/limit_test.go
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package fory
+
+import (
+ "github.com/stretchr/testify/require"
+ "testing"
+)
+
+func TestMaxCollectionSizeGuardrail(t *testing.T) {
+ // 1. Test slice exceeding limit
+ t.Run("Slice exceeds MaxCollectionSize", func(t *testing.T) {
+ config := WithMaxCollectionSize(2)
+ f := NewFory(config)
+
+ slice := []string{"a", "b", "c"}
+ fBase := NewFory()
+ bytes, _ := fBase.Serialize(slice)
+
+ var decoded []string
+ err := f.Deserialize(bytes, &decoded)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "max collection size exceeded:
size=3, limit=2")
+ })
+
+ // 2. Test map exceeding limit
+ t.Run("Map exceeds MaxCollectionSize", func(t *testing.T) {
+ config := WithMaxCollectionSize(2)
+ f := NewFory(config)
+
+ m := map[int32]int32{1: 1, 2: 2, 3: 3}
+ fBase := NewFory()
+ bytes, _ := fBase.Serialize(m)
+
+ var decoded map[int32]int32
+ err := f.Deserialize(bytes, &decoded)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "max collection size exceeded:
size=3, limit=2")
+ })
+
+ // 3. Test string is not affected by MaxCollectionSize
+ t.Run("String unaffected by MaxCollectionSize", func(t *testing.T) {
+ config := WithMaxCollectionSize(2)
+ f := NewFory(config)
+
+ str := "hello world" // length 11
+ bytes, err := f.Serialize(str)
+ require.NoError(t, err)
+
+ var decoded string
+ err = f.Deserialize(bytes, &decoded)
+ require.NoError(t, err)
+ require.Equal(t, str, decoded)
+ })
+}
+
+func TestMaxBinarySizeGuardrail(t *testing.T) {
+ // 1. Test binary (byte slice) exceeding limit
+ t.Run("Byte slice exceeds MaxBinarySize", func(t *testing.T) {
+ config := WithMaxBinarySize(5)
+ f := NewFory(config)
+
+ // We can serialize a byte slice using standard serializer,
then decode with the f instance
+ slice := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ fBase := NewFory()
+ bytes, _ := fBase.Serialize(slice)
+
+ var decoded []byte
+ err := f.Deserialize(bytes, &decoded)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "max binary size exceeded:
size=10, limit=5")
+ })
+
+ // 2. Test string is not affected by MaxBinarySize
+ t.Run("String unaffected by MaxBinarySize", func(t *testing.T) {
+ config := WithMaxBinarySize(2)
+ f := NewFory(config)
+
+ str := "hello world" // length 11
+ bytes, err := f.Serialize(str)
+ require.NoError(t, err)
+
+ var decoded string
+ err = f.Deserialize(bytes, &decoded)
+ require.NoError(t, err)
+ require.Equal(t, str, decoded)
+ })
+}
diff --git a/go/fory/map.go b/go/fory/map.go
index f2489601f..ace2c5597 100644
--- a/go/fory/map.go
+++ b/go/fory/map.go
@@ -305,7 +305,7 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value
reflect.Value) {
}
refResolver.Reference(value)
- size := int(buf.ReadVarUint32(ctxErr))
+ size := ctx.ReadCollectionLength()
if size == 0 || ctx.HasError() {
return
}
diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go
index 21a4bd7b5..16e40f5ae 100644
--- a/go/fory/map_primitive.go
+++ b/go/fory/map_primitive.go
@@ -69,8 +69,10 @@ func writeMapStringString(buf *ByteBuffer, m
map[string]string, hasGenerics bool
}
// readMapStringString reads map[string]string using chunk protocol
-func readMapStringString(buf *ByteBuffer, err *Error) map[string]string {
- size := int(buf.ReadVarUint32(err))
+func readMapStringString(ctx *ReadContext) map[string]string {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]string, size)
if size == 0 {
return result
@@ -172,8 +174,10 @@ func writeMapStringInt64(buf *ByteBuffer, m
map[string]int64, hasGenerics bool)
}
// readMapStringInt64 reads map[string]int64 using chunk protocol
-func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 {
- size := int(buf.ReadVarUint32(err))
+func readMapStringInt64(ctx *ReadContext) map[string]int64 {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]int64, size)
if size == 0 {
return result
@@ -246,8 +250,10 @@ func writeMapStringInt32(buf *ByteBuffer, m
map[string]int32, hasGenerics bool)
}
// readMapStringInt32 reads map[string]int32 using chunk protocol
-func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 {
- size := int(buf.ReadVarUint32(err))
+func readMapStringInt32(ctx *ReadContext) map[string]int32 {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]int32, size)
if size == 0 {
return result
@@ -320,8 +326,10 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int,
hasGenerics bool) {
}
// readMapStringInt reads map[string]int using chunk protocol
-func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int {
- size := int(buf.ReadVarUint32(err))
+func readMapStringInt(ctx *ReadContext) map[string]int {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]int, size)
if size == 0 {
return result
@@ -394,8 +402,10 @@ func writeMapStringFloat64(buf *ByteBuffer, m
map[string]float64, hasGenerics bo
}
// readMapStringFloat64 reads map[string]float64 using chunk protocol
-func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 {
- size := int(buf.ReadVarUint32(err))
+func readMapStringFloat64(ctx *ReadContext) map[string]float64 {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]float64, size)
if size == 0 {
return result
@@ -468,8 +478,10 @@ func writeMapStringBool(buf *ByteBuffer, m
map[string]bool, hasGenerics bool) {
}
// readMapStringBool reads map[string]bool using chunk protocol
-func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool {
- size := int(buf.ReadVarUint32(err))
+func readMapStringBool(ctx *ReadContext) map[string]bool {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[string]bool, size)
if size == 0 {
return result
@@ -547,8 +559,10 @@ func writeMapInt32Int32(buf *ByteBuffer, m
map[int32]int32, hasGenerics bool) {
}
// readMapInt32Int32 reads map[int32]int32 using chunk protocol
-func readMapInt32Int32(buf *ByteBuffer, err *Error) map[int32]int32 {
- size := int(buf.ReadVarUint32(err))
+func readMapInt32Int32(ctx *ReadContext) map[int32]int32 {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[int32]int32, size)
if size == 0 {
return result
@@ -621,8 +635,10 @@ func writeMapInt64Int64(buf *ByteBuffer, m
map[int64]int64, hasGenerics bool) {
}
// readMapInt64Int64 reads map[int64]int64 using chunk protocol
-func readMapInt64Int64(buf *ByteBuffer, err *Error) map[int64]int64 {
- size := int(buf.ReadVarUint32(err))
+func readMapInt64Int64(ctx *ReadContext) map[int64]int64 {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[int64]int64, size)
if size == 0 {
return result
@@ -695,8 +711,10 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int,
hasGenerics bool) {
}
// readMapIntInt reads map[int]int using chunk protocol
-func readMapIntInt(buf *ByteBuffer, err *Error) map[int]int {
- size := int(buf.ReadVarUint32(err))
+func readMapIntInt(ctx *ReadContext) map[int]int {
+ err := ctx.Err()
+ buf := ctx.Buffer()
+ size := ctx.ReadCollectionLength()
result := make(map[int]int, size)
if size == 0 {
return result
@@ -752,7 +770,7 @@ func (s stringStringMapSerializer) ReadData(ctx
*ReadContext, value reflect.Valu
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapStringString(ctx.buffer, ctx.Err())
+ result := readMapStringString(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -787,7 +805,7 @@ func (s stringInt64MapSerializer) ReadData(ctx
*ReadContext, value reflect.Value
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapStringInt64(ctx.buffer, ctx.Err())
+ result := readMapStringInt64(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -822,7 +840,7 @@ func (s stringIntMapSerializer) ReadData(ctx *ReadContext,
value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapStringInt(ctx.buffer, ctx.Err())
+ result := readMapStringInt(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -857,7 +875,7 @@ func (s stringFloat64MapSerializer) ReadData(ctx
*ReadContext, value reflect.Val
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapStringFloat64(ctx.buffer, ctx.Err())
+ result := readMapStringFloat64(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -892,7 +910,7 @@ func (s stringBoolMapSerializer) ReadData(ctx *ReadContext,
value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapStringBool(ctx.buffer, ctx.Err())
+ result := readMapStringBool(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -927,7 +945,7 @@ func (s int32Int32MapSerializer) ReadData(ctx *ReadContext,
value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapInt32Int32(ctx.buffer, ctx.Err())
+ result := readMapInt32Int32(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -962,7 +980,7 @@ func (s int64Int64MapSerializer) ReadData(ctx *ReadContext,
value reflect.Value)
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapInt64Int64(ctx.buffer, ctx.Err())
+ result := readMapInt64Int64(ctx)
value.Set(reflect.ValueOf(result))
}
@@ -997,7 +1015,7 @@ func (s intIntMapSerializer) ReadData(ctx *ReadContext,
value reflect.Value) {
value.Set(reflect.MakeMap(value.Type()))
}
ctx.RefResolver().Reference(value)
- result := readMapIntInt(ctx.buffer, ctx.Err())
+ result := readMapIntInt(ctx)
value.Set(reflect.ValueOf(result))
}
diff --git a/go/fory/reader.go b/go/fory/reader.go
index a0a37d92f..9c8b049ad 100644
--- a/go/fory/reader.go
+++ b/go/fory/reader.go
@@ -29,20 +29,22 @@ import (
// ReadContext holds all state needed during deserialization.
type ReadContext struct {
- buffer *ByteBuffer
- refReader *RefReader
- trackRef bool // Cached flag to avoid indirection
- xlang bool // Cross-language serialization mode
- compatible bool // Schema evolution compatibility mode
- typeResolver *TypeResolver // For complex type deserialization
- refResolver *RefResolver // For reference tracking (legacy)
- outOfBandBuffers []*ByteBuffer // Out-of-band buffers for
deserialization
- outOfBandIndex int // Current index into out-of-band buffers
- depth int // Current nesting depth for cycle
detection
- maxDepth int // Maximum allowed nesting depth
- err Error // Accumulated error state for deferred
checking
- lastTypePtr uintptr
- lastTypeInfo *TypeInfo
+ buffer *ByteBuffer
+ refReader *RefReader
+ trackRef bool // Cached flag to avoid indirection
+ xlang bool // Cross-language serialization mode
+ compatible bool // Schema evolution compatibility mode
+ typeResolver *TypeResolver // For complex type deserialization
+ refResolver *RefResolver // For reference tracking (legacy)
+ outOfBandBuffers []*ByteBuffer // Out-of-band buffers for
deserialization
+ outOfBandIndex int // Current index into out-of-band
buffers
+ depth int // Current nesting depth for cycle
detection
+ maxDepth int // Maximum allowed nesting depth
+ err Error // Accumulated error state for deferred
checking
+ lastTypePtr uintptr
+ lastTypeInfo *TypeInfo
+ maxCollectionSize int // Size guardrail for collection reads
+ maxBinarySize int // Size guardrail for binary reads
}
// IsXlang returns whether cross-language serialization mode is enabled
@@ -237,10 +239,32 @@ func (c *ReadContext) ReadAndValidateTypeId(expected
TypeId) {
}
}
-// ReadLength reads a length value as varint (non-negative values)
-func (c *ReadContext) ReadLength() int {
+// ReadCollectionLength reads a length value for collections with size
guardrails
+func (c *ReadContext) ReadCollectionLength() int {
err := c.Err()
- return int(c.buffer.ReadVarUint32(err))
+ length := c.buffer.ReadLength(err)
+ if c.err.HasError() {
+ return 0
+ }
+ if length > c.maxCollectionSize {
+ c.SetError(MaxCollectionSizeExceededError(length,
c.maxCollectionSize))
+ return 0
+ }
+ return length
+}
+
+// ReadBinaryLength reads a length value for binary data with size guardrails
+func (c *ReadContext) ReadBinaryLength() int {
+ err := c.Err()
+ length := c.buffer.ReadLength(err)
+ if c.err.HasError() {
+ return 0
+ }
+ if length > c.maxBinarySize {
+ c.SetError(MaxBinarySizeExceededError(length, c.maxBinarySize))
+ return 0
+ }
+ return length
}
// ============================================================================
@@ -434,7 +458,7 @@ func (c *ReadContext) ReadByteSlice(refMode RefMode,
readType bool) []byte {
if readType {
_ = c.buffer.ReadUint8(err)
}
- size := c.buffer.ReadLength(err)
+ size := c.ReadBinaryLength()
return c.buffer.ReadBinary(size, err)
}
@@ -463,7 +487,7 @@ func (c *ReadContext) ReadStringStringMap(refMode RefMode,
readType bool) map[st
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringString(c.buffer, err)
+ return readMapStringString(c)
}
// ReadStringInt64Map reads map[string]int64 with optional ref/type info
@@ -477,7 +501,7 @@ func (c *ReadContext) ReadStringInt64Map(refMode RefMode,
readType bool) map[str
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringInt64(c.buffer, err)
+ return readMapStringInt64(c)
}
// ReadStringInt32Map reads map[string]int32 with optional ref/type info
@@ -491,7 +515,7 @@ func (c *ReadContext) ReadStringInt32Map(refMode RefMode,
readType bool) map[str
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringInt32(c.buffer, err)
+ return readMapStringInt32(c)
}
// ReadStringIntMap reads map[string]int with optional ref/type info
@@ -505,7 +529,7 @@ func (c *ReadContext) ReadStringIntMap(refMode RefMode, readType bool) map[strin
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringInt(c.buffer, err)
+ return readMapStringInt(c)
}
// ReadStringFloat64Map reads map[string]float64 with optional ref/type info
@@ -519,7 +543,7 @@ func (c *ReadContext) ReadStringFloat64Map(refMode RefMode, readType bool) map[s
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringFloat64(c.buffer, err)
+ return readMapStringFloat64(c)
}
// ReadStringBoolMap reads map[string]bool with optional ref/type info
@@ -533,7 +557,7 @@ func (c *ReadContext) ReadStringBoolMap(refMode RefMode, readType bool) map[stri
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapStringBool(c.buffer, err)
+ return readMapStringBool(c)
}
// ReadInt32Int32Map reads map[int32]int32 with optional ref/type info
@@ -547,7 +571,7 @@ func (c *ReadContext) ReadInt32Int32Map(refMode RefMode, readType bool) map[int3
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapInt32Int32(c.buffer, err)
+ return readMapInt32Int32(c)
}
// ReadInt64Int64Map reads map[int64]int64 with optional ref/type info
@@ -561,7 +585,7 @@ func (c *ReadContext) ReadInt64Int64Map(refMode RefMode, readType bool) map[int6
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapInt64Int64(c.buffer, err)
+ return readMapInt64Int64(c)
}
// ReadIntIntMap reads map[int]int with optional ref/type info
@@ -575,7 +599,7 @@ func (c *ReadContext) ReadIntIntMap(refMode RefMode, readType bool) map[int]int
if readType {
_ = c.buffer.ReadUint8(err)
}
- return readMapIntInt(c.buffer, err)
+ return readMapIntInt(c)
}
// ReadBufferObject reads a buffer object
@@ -583,7 +607,7 @@ func (c *ReadContext) ReadBufferObject() *ByteBuffer {
err := c.Err()
isInBand := c.buffer.ReadBool(err)
if isInBand {
- size := c.buffer.ReadLength(err)
+ size := c.ReadBinaryLength()
buf := c.buffer.Slice(c.buffer.readerIndex, size)
c.buffer.readerIndex += size
return buf
diff --git a/go/fory/set.go b/go/fory/set.go
index 2105b3e9d..83a7b3171 100644
--- a/go/fory/set.go
+++ b/go/fory/set.go
@@ -295,7 +295,7 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
err := ctx.Err()
type_ := value.Type()
// ReadData collection length from buffer
- length := int(buf.ReadVarUint32(err))
+ length := ctx.ReadCollectionLength()
if length == 0 {
// Initialize empty set if length is 0
value.Set(reflect.MakeMap(type_))
diff --git a/go/fory/skip.go b/go/fory/skip.go
index 34005ad74..64660dfc3 100644
--- a/go/fory/skip.go
+++ b/go/fory/skip.go
@@ -213,7 +213,7 @@ func readTypeInfoForSkip(ctx *ReadContext, fieldTypeId TypeId) *TypeInfo {
// Uses context error state for deferred error checking.
func skipCollection(ctx *ReadContext, fieldDef FieldDef) {
err := ctx.Err()
- length := ctx.buffer.ReadVarUint32(err)
+ length := uint32(ctx.ReadCollectionLength())
if ctx.HasError() || length == 0 {
return
}
@@ -283,7 +283,7 @@ func skipCollection(ctx *ReadContext, fieldDef FieldDef) {
// Uses context error state for deferred error checking.
func skipMap(ctx *ReadContext, fieldDef FieldDef) {
bufErr := ctx.Err()
- length := ctx.buffer.ReadVarUint32(bufErr)
+ length := uint32(ctx.ReadCollectionLength())
if ctx.HasError() || length == 0 {
return
}
@@ -601,31 +601,31 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo
_ = ctx.buffer.ReadBinary(int(size), err)
}
case BINARY:
- length := ctx.buffer.ReadVarUint32(err)
+ length := uint32(ctx.ReadBinaryLength())
if ctx.HasError() {
return
}
_ = ctx.buffer.ReadBinary(int(length), err)
case BOOL_ARRAY, INT8_ARRAY, UINT8_ARRAY:
- length := ctx.buffer.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
_ = ctx.buffer.ReadBinary(length, err)
case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY:
- length := ctx.buffer.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
_ = ctx.buffer.ReadBinary(length*2, err)
case INT32_ARRAY, UINT32_ARRAY, FLOAT32_ARRAY:
- length := ctx.buffer.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
_ = ctx.buffer.ReadBinary(length*4, err)
case INT64_ARRAY, UINT64_ARRAY, FLOAT64_ARRAY:
- length := ctx.buffer.ReadLength(err)
+ length := ctx.ReadBinaryLength()
if ctx.HasError() {
return
}
diff --git a/go/fory/slice.go b/go/fory/slice.go
index bd3a9aa7e..90b6ff14e 100644
--- a/go/fory/slice.go
+++ b/go/fory/slice.go
@@ -264,7 +264,7 @@ func (s *sliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, ty
func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- length := int(buf.ReadVarUint32(ctxErr))
+ length := ctx.ReadCollectionLength()
isArrayType := value.Type().Kind() == reflect.Array
if length == 0 {
diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go
index 3393d4b22..90a7b8cad 100644
--- a/go/fory/slice_dyn.go
+++ b/go/fory/slice_dyn.go
@@ -261,7 +261,7 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo
func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- length := int(buf.ReadVarUint32(ctxErr))
+ length := ctx.ReadCollectionLength()
sliceType := value.Type()
value.Set(reflect.MakeSlice(sliceType, length, length))
if length == 0 {
diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go
index c89390898..dddc80069 100644
--- a/go/fory/slice_primitive.go
+++ b/go/fory/slice_primitive.go
@@ -74,7 +74,7 @@ func (s byteSliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode,
func (s byteSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- length := buf.ReadLength(ctxErr)
+ length := ctx.ReadBinaryLength()
ptr := (*[]byte)(value.Addr().UnsafePointer())
if length == 0 {
*ptr = make([]byte, 0)
@@ -642,7 +642,7 @@ func (s stringSliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMod
func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- length := int(buf.ReadVarUint32(ctxErr))
+ length := ctx.ReadCollectionLength()
ptr := (*[]string)(value.Addr().UnsafePointer())
if length == 0 {
*ptr = make([]string, 0)
@@ -1071,7 +1071,7 @@ func (s float16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMo
func (s float16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- size := buf.ReadLength(ctxErr)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
@@ -1253,7 +1253,7 @@ func WriteStringSlice(buf *ByteBuffer, value []string, hasGenerics bool) {
// ReadStringSlice reads []string from buffer using LIST protocol
func ReadStringSlice(buf *ByteBuffer, err *Error) []string {
- length := int(buf.ReadVarUint32(err))
+ length := buf.ReadLength(err)
if length == 0 {
return make([]string, 0)
}
@@ -1328,7 +1328,7 @@ func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefM
func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
- size := buf.ReadLength(ctxErr)
+ size := ctx.ReadBinaryLength()
length := size / 2
if ctx.HasError() {
return
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]