This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git
The following commit(s) were added to refs/heads/main by this push:
new 551ff53b2 fix(python): support tuple dataclass fields and object
instances (#3468)
551ff53b2 is described below
commit 551ff53b2f0edb0c510da2b9966f2eef0ce41792
Author: Shawn Yang <[email protected]>
AuthorDate: Wed Mar 11 11:39:48 2026 +0800
fix(python): support tuple dataclass fields and object instances (#3468)
## Summary
- support typed `tuple[...]` dataclass fields in Python native-mode
field inference by routing them to `TupleSerializer`
- serialize instances without `__dict__` as zero-field objects so bare
`object()` round-trips cleanly
- add focused regressions for both issues, including empty-object ref
tracking behavior
## Testing
- cd python && ruff format --check pyfory/type_util.py pyfory/struct.py
pyfory/serializer.py pyfory/tests/test_struct.py
pyfory/tests/test_serializer.py
- cd python && ruff check pyfory/type_util.py pyfory/struct.py
pyfory/serializer.py pyfory/tests/test_struct.py
pyfory/tests/test_serializer.py
- cd python && ENABLE_FORY_CYTHON_SERIALIZATION=0 pytest -q .
- cd python && ENABLE_FORY_CYTHON_SERIALIZATION=1 pytest -q .
Closes #3466
Closes #3467
---
python/pyfory/collection.pxi | 78 ++++++++++++++++++--------------
python/pyfory/collection.py | 43 ++++++++++--------
python/pyfory/format/infer.py | 11 ++++-
python/pyfory/format/tests/test_infer.py | 22 ++++++++-
python/pyfory/meta/typedef.py | 21 +++++++--
python/pyfory/registry.py | 2 +
python/pyfory/serializer.py | 3 +-
python/pyfory/struct.py | 28 +++++++++++-
python/pyfory/tests/test_collection.py | 16 +++++++
python/pyfory/tests/test_serializer.py | 17 +++++++
python/pyfory/tests/test_struct.py | 59 +++++++++++++++++++++++-
python/pyfory/type_util.py | 23 ++++++++++
12 files changed, 261 insertions(+), 62 deletions(-)
diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi
index ecca3bff3..99e9dfb67 100644
--- a/python/pyfory/collection.pxi
+++ b/python/pyfory/collection.pxi
@@ -128,6 +128,7 @@ cdef class CollectionSerializer(Serializer):
cdef int8_t collect_flag
cdef TypeInfo elem_type_info = self.write_header(buffer, value,
&collect_flag)
cdef elem_type = elem_type_info.cls
+ cdef Serializer elem_serializer = self.elem_serializer if
(collect_flag & COLL_IS_DECL_ELEMENT_TYPE) != 0 and self.elem_serializer is not
None else elem_type_info.serializer
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef TypeResolver type_resolver = self.type_resolver
cdef PyObject **items = fory_sequence_get_items(value)
@@ -143,13 +144,13 @@ cdef class CollectionSerializer(Serializer):
if Fory_CanUsePrimitiveCollectionFastpath(type_id):
self._write_primitive_fastpath(buffer, value, type_id,
items, size)
elif (collect_flag & COLL_TRACKING_REF) == 0:
- self._write_same_type_no_ref(buffer, value, elem_type_info)
+ self._write_same_type_no_ref(buffer, value,
elem_serializer)
else:
- self._write_same_type_ref(buffer, value, elem_type_info)
+ self._write_same_type_ref(buffer, value, elem_serializer)
elif (collect_flag & COLL_TRACKING_REF) != 0:
- self._write_same_type_ref(buffer, value, elem_type_info)
+ self._write_same_type_ref(buffer, value, elem_serializer)
else:
- self._write_same_type_has_null(buffer, value, elem_type_info)
+ self._write_same_type_has_null(buffer, value, elem_serializer)
else:
# Check tracking_ref and has_null flags for different types writing
tracking_ref = (collect_flag & COLL_TRACKING_REF) != 0
@@ -253,7 +254,7 @@ cdef class CollectionSerializer(Serializer):
cdef inline _read_primitive_fastpath(self, Buffer buffer, int64_t len_,
object collection_, uint8_t type_id):
Fory_PyPrimitiveCollectionReadFromBuffer(collection_,
&buffer.c_buffer, len_, type_id)
- cpdef _write_same_type_no_ref(self, Buffer buffer, value, TypeInfo
typeinfo):
+ cpdef _write_same_type_no_ref(self, Buffer buffer, value, Serializer
serializer):
cdef PyObject **items = fory_sequence_get_items(value)
cdef Py_ssize_t i
cdef Py_ssize_t size
@@ -262,18 +263,18 @@ cdef class CollectionSerializer(Serializer):
size = Py_SIZE(value)
for i in range(size):
s = <object> items[i]
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
return
for s in value:
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
- cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object
collection_, TypeInfo typeinfo):
+ cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object
collection_, Serializer serializer):
cdef PyObject **items = fory_sequence_get_items(collection_)
cdef c_bool is_list = type(collection_) is list
self.fory.inc_depth()
if items != NULL:
for i in range(len_):
- obj = self.fory.read_no_ref(buffer,
serializer=typeinfo.serializer)
+ obj = self.fory.read_no_ref(buffer, serializer=serializer)
Py_INCREF(obj)
if is_list:
PyList_SET_ITEM(collection_, i, obj)
@@ -281,11 +282,11 @@ cdef class CollectionSerializer(Serializer):
PyTuple_SET_ITEM(collection_, i, obj)
else:
for i in range(len_):
- obj = self.fory.read_no_ref(buffer,
serializer=typeinfo.serializer)
+ obj = self.fory.read_no_ref(buffer, serializer=serializer)
self._add_element(collection_, i, obj)
self.fory.dec_depth()
- cpdef _write_same_type_has_null(self, Buffer buffer, value, TypeInfo
typeinfo):
+ cpdef _write_same_type_has_null(self, Buffer buffer, value, Serializer
serializer):
cdef PyObject **items = fory_sequence_get_items(value)
cdef PyObject *item
cdef Py_ssize_t i
@@ -300,16 +301,16 @@ cdef class CollectionSerializer(Serializer):
else:
buffer.write_int8(NOT_NULL_VALUE_FLAG)
s = <object> item
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
return
for s in value:
if s is None:
buffer.write_int8(NULL_FLAG)
else:
buffer.write_int8(NOT_NULL_VALUE_FLAG)
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
- cpdef _read_same_type_has_null(self, Buffer buffer, int64_t len_, object
collection_, TypeInfo typeinfo):
+ cpdef _read_same_type_has_null(self, Buffer buffer, int64_t len_, object
collection_, Serializer serializer):
cdef int8_t flag
cdef PyObject **items = fory_sequence_get_items(collection_)
cdef c_bool is_list = type(collection_) is list
@@ -320,7 +321,7 @@ cdef class CollectionSerializer(Serializer):
if flag == NULL_FLAG:
obj = None
else:
- obj = self.fory.read_no_ref(buffer,
serializer=typeinfo.serializer)
+ obj = self.fory.read_no_ref(buffer, serializer=serializer)
Py_INCREF(obj)
if is_list:
PyList_SET_ITEM(collection_, i, obj)
@@ -335,11 +336,11 @@ cdef class CollectionSerializer(Serializer):
self._add_element(
collection_,
i,
- self.fory.read_no_ref(buffer,
serializer=typeinfo.serializer),
+ self.fory.read_no_ref(buffer, serializer=serializer),
)
self.fory.dec_depth()
- cpdef _write_same_type_ref(self, Buffer buffer, value, TypeInfo typeinfo):
+ cpdef _write_same_type_ref(self, Buffer buffer, value, Serializer
serializer):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef PyObject **items = fory_sequence_get_items(value)
cdef Py_ssize_t i
@@ -350,13 +351,13 @@ cdef class CollectionSerializer(Serializer):
for i in range(size):
s = <object> items[i]
if not ref_resolver.write_ref_or_null(buffer, s):
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
return
for s in value:
if not ref_resolver.write_ref_or_null(buffer, s):
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
- cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object
collection_, TypeInfo typeinfo):
+ cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object
collection_, Serializer serializer):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef PyObject **items = fory_sequence_get_items(collection_)
cdef c_bool is_list = type(collection_) is list
@@ -367,7 +368,7 @@ cdef class CollectionSerializer(Serializer):
if ref_id < NOT_NULL_VALUE_FLAG:
obj = ref_resolver.get_read_object()
else:
- obj = typeinfo.serializer.read(buffer)
+ obj = serializer.read(buffer)
ref_resolver.set_read_object(ref_id, obj)
Py_INCREF(obj)
if is_list:
@@ -380,7 +381,7 @@ cdef class CollectionSerializer(Serializer):
if ref_id < NOT_NULL_VALUE_FLAG:
obj = ref_resolver.get_read_object()
else:
- obj = typeinfo.serializer.read(buffer)
+ obj = serializer.read(buffer)
ref_resolver.set_read_object(ref_id, obj)
self._add_element(collection_, i, obj)
self.fory.dec_depth()
@@ -406,24 +407,27 @@ cdef class ListSerializer(CollectionSerializer):
cdef c_bool tracking_ref
cdef c_bool has_null
cdef int8_t head_flag
+ cdef Serializer elem_serializer = self.elem_serializer
if (collect_flag & COLL_IS_SAME_TYPE) != 0:
if collect_flag & COLL_IS_DECL_ELEMENT_TYPE == 0:
typeinfo = self.type_resolver.read_type_info(buffer)
+ elem_serializer = typeinfo.serializer
else:
typeinfo = self.elem_type_info
+ elem_serializer = self.elem_serializer
if (collect_flag & COLL_HAS_NULL) == 0:
type_id = typeinfo.type_id
if Fory_CanUsePrimitiveCollectionFastpath(type_id):
self._read_primitive_fastpath(buffer, len_, list_, type_id)
return list_
elif (collect_flag & COLL_TRACKING_REF) == 0:
- self._read_same_type_no_ref(buffer, len_, list_, typeinfo)
+ self._read_same_type_no_ref(buffer, len_, list_,
elem_serializer)
else:
- self._read_same_type_ref(buffer, len_, list_, typeinfo)
+ self._read_same_type_ref(buffer, len_, list_,
elem_serializer)
elif (collect_flag & COLL_TRACKING_REF) != 0:
- self._read_same_type_ref(buffer, len_, list_, typeinfo)
+ self._read_same_type_ref(buffer, len_, list_, elem_serializer)
else:
- self._read_same_type_has_null(buffer, len_, list_, typeinfo)
+ self._read_same_type_has_null(buffer, len_, list_,
elem_serializer)
else:
self.fory.inc_depth()
# Check tracking_ref and has_null flags for different types
handling
@@ -508,24 +512,27 @@ cdef class TupleSerializer(CollectionSerializer):
cdef c_bool tracking_ref
cdef c_bool has_null
cdef int8_t head_flag
+ cdef Serializer elem_serializer = self.elem_serializer
if (collect_flag & COLL_IS_SAME_TYPE) != 0:
if collect_flag & COLL_IS_DECL_ELEMENT_TYPE == 0:
typeinfo = self.type_resolver.read_type_info(buffer)
+ elem_serializer = typeinfo.serializer
else:
typeinfo = self.elem_type_info
+ elem_serializer = self.elem_serializer
if (collect_flag & COLL_HAS_NULL) == 0:
type_id = typeinfo.type_id
if Fory_CanUsePrimitiveCollectionFastpath(type_id):
self._read_primitive_fastpath(buffer, len_, tuple_,
type_id)
return tuple_
elif (collect_flag & COLL_TRACKING_REF) == 0:
- self._read_same_type_no_ref(buffer, len_, tuple_, typeinfo)
+ self._read_same_type_no_ref(buffer, len_, tuple_,
elem_serializer)
else:
- self._read_same_type_ref(buffer, len_, tuple_, typeinfo)
+ self._read_same_type_ref(buffer, len_, tuple_,
elem_serializer)
elif (collect_flag & COLL_TRACKING_REF) != 0:
- self._read_same_type_ref(buffer, len_, tuple_, typeinfo)
+ self._read_same_type_ref(buffer, len_, tuple_, elem_serializer)
else:
- self._read_same_type_has_null(buffer, len_, tuple_, typeinfo)
+ self._read_same_type_has_null(buffer, len_, tuple_,
elem_serializer)
else:
self.fory.inc_depth()
# Check tracking_ref and has_null flags for different types
handling
@@ -593,24 +600,27 @@ cdef class SetSerializer(CollectionSerializer):
cdef c_bool tracking_ref
cdef c_bool has_null
cdef int8_t head_flag
+ cdef Serializer elem_serializer = self.elem_serializer
if (collect_flag & COLL_IS_SAME_TYPE) != 0:
if collect_flag & COLL_IS_DECL_ELEMENT_TYPE == 0:
typeinfo = self.type_resolver.read_type_info(buffer)
+ elem_serializer = typeinfo.serializer
else:
typeinfo = self.elem_type_info
+ elem_serializer = self.elem_serializer
if (collect_flag & COLL_HAS_NULL) == 0:
type_id = typeinfo.type_id
if Fory_CanUsePrimitiveCollectionFastpath(type_id):
self._read_primitive_fastpath(buffer, len_, instance,
type_id)
return instance
elif (collect_flag & COLL_TRACKING_REF) == 0:
- self._read_same_type_no_ref(buffer, len_, instance,
typeinfo)
+ self._read_same_type_no_ref(buffer, len_, instance,
elem_serializer)
else:
- self._read_same_type_ref(buffer, len_, instance, typeinfo)
+ self._read_same_type_ref(buffer, len_, instance,
elem_serializer)
elif (collect_flag & COLL_TRACKING_REF) != 0:
- self._read_same_type_ref(buffer, len_, instance, typeinfo)
+ self._read_same_type_ref(buffer, len_, instance,
elem_serializer)
else:
- self._read_same_type_has_null(buffer, len_, instance, typeinfo)
+ self._read_same_type_has_null(buffer, len_, instance,
elem_serializer)
else:
self.fory.inc_depth()
# Check tracking_ref and has_null flags for different types
handling
diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py
index b44fec026..35c0a7946 100644
--- a/python/pyfory/collection.py
+++ b/python/pyfory/collection.py
@@ -111,32 +111,35 @@ class CollectionSerializer(Serializer):
buffer.write_var_uint32(0)
return
collect_flag, typeinfo = self.write_header(buffer, value)
+ serializer = (
+ self.elem_serializer if (collect_flag & COLL_IS_DECL_ELEMENT_TYPE)
!= 0 and self.elem_serializer is not None else typeinfo.serializer
+ )
if (collect_flag & COLL_IS_SAME_TYPE) != 0:
if (collect_flag & COLL_TRACKING_REF) != 0:
- self._write_same_type_ref(buffer, value, typeinfo)
+ self._write_same_type_ref(buffer, value, serializer)
elif (collect_flag & COLL_HAS_NULL) == 0:
- self._write_same_type_no_ref(buffer, value, typeinfo)
+ self._write_same_type_no_ref(buffer, value, serializer)
else:
- self._write_same_type_has_null(buffer, value, typeinfo)
+ self._write_same_type_has_null(buffer, value, serializer)
else:
self._write_different_types(buffer, value, collect_flag)
- def _write_same_type_no_ref(self, buffer, value, typeinfo):
+ def _write_same_type_no_ref(self, buffer, value, serializer):
for s in value:
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
- def _write_same_type_has_null(self, buffer, value, typeinfo):
+ def _write_same_type_has_null(self, buffer, value, serializer):
for s in value:
if s is None:
buffer.write_int8(NULL_FLAG)
else:
buffer.write_int8(NOT_NULL_VALUE_FLAG)
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
- def _write_same_type_ref(self, buffer, value, typeinfo):
+ def _write_same_type_ref(self, buffer, value, serializer):
for s in value:
if not self.ref_resolver.write_ref_or_null(buffer, s):
- typeinfo.serializer.write(buffer, s)
+ serializer.write(buffer, s)
def _write_different_types(self, buffer, value, collect_flag=0):
tracking_ref = (collect_flag & COLL_TRACKING_REF) != 0
@@ -177,14 +180,16 @@ class CollectionSerializer(Serializer):
if (collect_flag & COLL_IS_SAME_TYPE) != 0:
if collect_flag & COLL_IS_DECL_ELEMENT_TYPE == 0:
typeinfo = self.type_resolver.read_type_info(buffer)
+ serializer = typeinfo.serializer
else:
typeinfo = self.elem_type_info
+ serializer = self.elem_serializer
if (collect_flag & COLL_TRACKING_REF) != 0:
- self._read_same_type_ref(buffer, len_, collection_, typeinfo)
+ self._read_same_type_ref(buffer, len_, collection_, serializer)
elif (collect_flag & COLL_HAS_NULL) == 0:
- self._read_same_type_no_ref(buffer, len_, collection_,
typeinfo)
+ self._read_same_type_no_ref(buffer, len_, collection_,
serializer)
else:
- self._read_same_type_has_null(buffer, len_, collection_,
typeinfo)
+ self._read_same_type_has_null(buffer, len_, collection_,
serializer)
else:
self._read_different_types(buffer, len_, collection_, collect_flag)
return collection_
@@ -195,16 +200,16 @@ class CollectionSerializer(Serializer):
def _add_element(self, collection_, element):
raise NotImplementedError
- def _read_same_type_no_ref(self, buffer, len_, collection_, typeinfo):
+ def _read_same_type_no_ref(self, buffer, len_, collection_, serializer):
self.fory.inc_depth()
for _ in range(len_):
self._add_element(
collection_,
- self.fory.read_no_ref(buffer, serializer=typeinfo.serializer),
+ self.fory.read_no_ref(buffer, serializer=serializer),
)
self.fory.dec_depth()
- def _read_same_type_has_null(self, buffer, len_, collection_, typeinfo):
+ def _read_same_type_has_null(self, buffer, len_, collection_, serializer):
self.fory.inc_depth()
for _ in range(len_):
if buffer.read_int8() == NULL_FLAG:
@@ -212,18 +217,18 @@ class CollectionSerializer(Serializer):
else:
self._add_element(
collection_,
- self.fory.read_no_ref(buffer,
serializer=typeinfo.serializer),
+ self.fory.read_no_ref(buffer, serializer=serializer),
)
self.fory.dec_depth()
- def _read_same_type_ref(self, buffer, len_, collection_, typeinfo):
+ def _read_same_type_ref(self, buffer, len_, collection_, serializer):
self.fory.inc_depth()
for _ in range(len_):
ref_id = self.ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
obj = self.ref_resolver.get_read_object()
else:
- obj = typeinfo.serializer.read(buffer)
+ obj = serializer.read(buffer)
self.ref_resolver.set_read_object(ref_id, obj)
self._add_element(collection_, obj)
self.fory.dec_depth()
@@ -400,8 +405,6 @@ class MapSerializer(Serializer):
value_write_ref = self.value_tracking_ref
if value_write_ref:
buffer.write_int8(NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF)
- if not ref_resolver.write_ref_or_null(buffer,
key):
- value_serializer.write(buffer, key)
if not ref_resolver.write_ref_or_null(buffer,
value):
value_serializer.write(buffer, value)
else:
diff --git a/python/pyfory/format/infer.py b/python/pyfory/format/infer.py
index 99687b97e..09ed38332 100644
--- a/python/pyfory/format/infer.py
+++ b/python/pyfory/format/infer.py
@@ -19,7 +19,7 @@ import datetime
import typing
from typing import Optional
-from pyfory.type_util import TypeVisitor, infer_field
+from pyfory.type_util import TypeVisitor, get_homogeneous_tuple_elem_type,
infer_field
from pyfory.format._format import (
Schema,
DataType,
@@ -171,6 +171,13 @@ class ForyTypeVisitor(TypeVisitor):
elem_field = infer_field("item", elem_type, self,
types_path=types_path)
return field(field_name, list_(elem_field.type))
+ def visit_tuple(self, field_name, elem_types, types_path=None):
+ elem_type = get_homogeneous_tuple_elem_type(elem_types)
+ if elem_type is None:
+ raise TypeError(f"Row format supports only homogeneous tuple
annotations, got {elem_types}")
+ elem_field = infer_field("item", elem_type, self,
types_path=types_path)
+ return field(field_name, list_(elem_field.type))
+
def visit_dict(self, field_name, key_type, value_type, types_path=None):
# Infer type recursively for type such as Dict[str, Dict[str, str]]
key_field = infer_field("key", key_type, self, types_path=types_path)
@@ -185,6 +192,8 @@ class ForyTypeVisitor(TypeVisitor):
return field(field_name, struct(fields))
def visit_other(self, field_name, type_, types_path=None):
+ if isinstance(type_, type) and type_.__module__ != "builtins":
+ return self.visit_customized(field_name, type_,
types_path=types_path)
if type_ not in _supported_types_mapping:
raise TypeError(
f"Type {type_} not supported, currently only compositions of
{_supported_types_str} are supported. types_path is {types_path}"
diff --git a/python/pyfory/format/tests/test_infer.py
b/python/pyfory/format/tests/test_infer.py
index ea3a517c0..cfe9bb6c6 100644
--- a/python/pyfory/format/tests/test_infer.py
+++ b/python/pyfory/format/tests/test_infer.py
@@ -17,13 +17,14 @@
import datetime
import pyfory
+import pytest
from dataclasses import dataclass
from pyfory.format.infer import infer_schema, infer_field, ForyTypeVisitor
from pyfory.format import (
TypeId,
)
-from typing import List, Dict
+from typing import Dict, List, Tuple
@dataclass
@@ -53,8 +54,14 @@ def test_infer_field():
assert _infer_field("", str).type.id == TypeId.STRING
assert _infer_field("", bytes).type.id == TypeId.BINARY
assert _infer_field("", List[str]).type.id == TypeId.LIST
+ assert _infer_field("", Tuple[str, ...]).type.id == TypeId.LIST
+ assert _infer_field("", Tuple[int, int]).type.id == TypeId.LIST
assert _infer_field("", Dict[str, str]).type.id == TypeId.MAP
assert _infer_field("", List[Dict[str, str]]).type.id == TypeId.LIST
+ assert _infer_field("", List[Tuple[int, ...]]).type.id == TypeId.LIST
+
+ with pytest.raises(TypeError):
+ _infer_field("", Tuple[str, int])
# Custom class is treated as a struct
class X:
@@ -89,5 +96,18 @@ def test_type_id():
assert pyfory.format.infer.get_type_id(datetime.datetime) ==
TypeId.TIMESTAMP
+def test_infer_class_schema_with_tuple_fields():
+ @dataclass
+ class TupleFoo:
+ f1: Tuple[str, ...]
+ f2: List[Tuple[int, int]]
+ f3: Dict[str, Tuple[pyfory.int32, ...]]
+
+ schema = infer_schema(TupleFoo)
+ assert schema.field(0).type.id == TypeId.LIST
+ assert schema.field(1).type.id == TypeId.LIST
+ assert schema.field(2).type.id == TypeId.MAP
+
+
if __name__ == "__main__":
test_infer_class_schema()
diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py
index 69680eaf1..8ee660864 100644
--- a/python/pyfory/meta/typedef.py
+++ b/python/pyfory/meta/typedef.py
@@ -21,7 +21,7 @@ from typing import List
from pyfory.types import TypeId, is_polymorphic_type, is_union_type
from pyfory._fory import NO_USER_TYPE_ID
from pyfory.serialization import Buffer
-from pyfory.type_util import infer_field
+from pyfory.type_util import get_homogeneous_tuple_elem_type, infer_field
from pyfory.meta.metastring import Encoding
from pyfory.type_util import infer_field_types
@@ -356,12 +356,23 @@ class CollectionFieldType(FieldType):
self.element_type = element_type
def create_serializer(self, resolver, type_):
- from pyfory.serializer import ListSerializer, SetSerializer
+ from pyfory.serializer import ListSerializer, SetSerializer,
TupleSerializer
- elem_type = type_[1] if type_ and len(type_) >= 2 else None
+ declared_root_type = type_
+ elem_type = None
+ if isinstance(type_, list):
+ declared_root_type = type_[0]
+ if isinstance(declared_root_type, tuple):
+ if declared_root_type:
+ declared_root_type, *extra = declared_root_type
+ elem_type = extra[0] if extra else None
+ elif type_ and len(type_) >= 2:
+ elem_type = type_[1]
elem_serializer = self.element_type.create_serializer(resolver,
elem_type)
elem_override = getattr(self.element_type, "tracking_ref_override",
None)
if self.type_id == TypeId.LIST:
+ if declared_root_type in (tuple, typing.Tuple):
+ return TupleSerializer(resolver.fory, tuple, elem_serializer,
elem_override)
return ListSerializer(resolver.fory, list, elem_serializer,
elem_override)
elif self.type_id == TypeId.SET:
return SetSerializer(resolver.fory, set, elem_serializer,
elem_override)
@@ -569,6 +580,10 @@ def build_field_type_from_type_ids_with_ref(
args = typing.get_args(type_hint) if hasattr(typing,
"get_args") else getattr(type_hint, "__args__", ())
if args:
elem_hint, elem_ref_override = unwrap_ref(args[0])
+ elif origin in (tuple, typing.Tuple):
+ tuple_elem_hint = get_homogeneous_tuple_elem_type(type_hint)
+ if tuple_elem_hint is not None:
+ elem_hint, elem_ref_override = unwrap_ref(tuple_elem_hint)
elem_tracking_ref = is_tracking_ref
if elem_ref_override is not None:
elem_tracking_ref = elem_ref_override and is_tracking_ref
diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py
index dcae0f42c..c3d584c7c 100644
--- a/python/pyfory/registry.py
+++ b/python/pyfory/registry.py
@@ -637,6 +637,8 @@ class TypeResolver:
return self.get_type_info(cls).serializer
def get_type_info(self, cls, create=True):
+ if cls is tuple and self.fory.xlang:
+ return self.get_type_info(list, create=create)
type_info = self._types_info.get(cls)
if type_info is not None:
if type_info.serializer is None:
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py
index e8ca76fa9..4d3089414 100644
--- a/python/pyfory/serializer.py
+++ b/python/pyfory/serializer.py
@@ -1286,7 +1286,8 @@ class ObjectSerializer(Serializer):
if self._slot_field_names is not None:
sorted_field_names = self._slot_field_names
else:
- sorted_field_names = sorted(value.__dict__.keys())
+ value_dict = getattr(value, "__dict__", None)
+ sorted_field_names = [] if value_dict is None else
sorted(value_dict.keys())
buffer.write_var_uint32(len(sorted_field_names))
for field_name in sorted_field_names:
diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py
index bc24dea93..3a842d8d7 100644
--- a/python/pyfory/struct.py
+++ b/python/pyfory/struct.py
@@ -57,6 +57,7 @@ from pyfory.types import (
from pyfory.type_util import (
TypeVisitor,
infer_field,
+ get_homogeneous_tuple_elem_type,
is_subclass,
get_type_hints,
unwrap_optional,
@@ -664,6 +665,17 @@ class StructFieldSerializerVisitor(TypeVisitor):
elem_serializer = infer_field("item", elem_type, self,
types_path=types_path)
return SetSerializer(self.fory, set, elem_serializer,
elem_ref_override)
+ def visit_tuple(self, field_name, elem_types, types_path=None):
+ from pyfory.serializer import TupleSerializer # Local import
+ from pyfory.type_util import unwrap_ref
+
+ elem_type = get_homogeneous_tuple_elem_type(elem_types)
+ if elem_type is not None:
+ elem_type, elem_ref_override = unwrap_ref(elem_type)
+ elem_serializer = infer_field("item", elem_type, self,
types_path=types_path)
+ return TupleSerializer(self.fory, tuple, elem_serializer,
elem_ref_override)
+ return TupleSerializer(self.fory, tuple)
+
def visit_dict(self, field_name, key_type, value_type, types_path=None):
from pyfory.serializer import MapSerializer # Local import
from pyfory.type_util import unwrap_ref
@@ -751,7 +763,7 @@ def group_fields(type_resolver, field_names, serializers,
nullable_map=None, fie
container = nullable_boxed_types if is_nullable else boxed_types
elif type_id == TypeId.SET:
container = set_types
- elif is_list_type(serializer.type_):
+ elif type_id == TypeId.LIST or is_list_type(serializer.type_):
container = collection_types
elif is_map_type(serializer.type_):
container = map_types
@@ -938,6 +950,13 @@ class StructTypeIdVisitor(TypeVisitor):
elem_ids = infer_field("item", elem_type, self, types_path=types_path)
return TypeId.SET, elem_ids
+ def visit_tuple(self, field_name, elem_types, types_path=None):
+ elem_type = get_homogeneous_tuple_elem_type(elem_types)
+ if elem_type is None:
+ return TypeId.LIST, [TypeId.UNKNOWN]
+ elem_ids = infer_field("item", elem_type, self, types_path=types_path)
+ return TypeId.LIST, elem_ids
+
def visit_dict(self, field_name, key_type, value_type, types_path=None):
# Infer type recursively for type such as Dict[str, Dict[str, str]]
key_ids = infer_field("key", key_type, self, types_path=types_path)
@@ -973,6 +992,13 @@ class StructTypeVisitor(TypeVisitor):
elem_types = infer_field("item", elem_type, self,
types_path=types_path)
return typing.Set, elem_types
+ def visit_tuple(self, field_name, elem_types, types_path=None):
+ elem_type = get_homogeneous_tuple_elem_type(elem_types)
+ if elem_type is None:
+ return tuple, None
+ elem_types_ = infer_field("item", elem_type, self,
types_path=types_path)
+ return tuple, elem_types_
+
def visit_dict(self, field_name, key_type, value_type, types_path=None):
# Infer type recursively for type such as Dict[str, Dict[str, str]]
key_types = infer_field("key", key_type, self, types_path=types_path)
diff --git a/python/pyfory/tests/test_collection.py
b/python/pyfory/tests/test_collection.py
index 93cb8ff2a..8a3406e72 100644
--- a/python/pyfory/tests/test_collection.py
+++ b/python/pyfory/tests/test_collection.py
@@ -140,6 +140,22 @@ class TestTupleWithNone:
assert result == data
+class TestTupleXlang:
+ @pytest.mark.parametrize("ref", [False, True])
+ def test_top_level_tuple_roundtrip_to_list(self, ref):
+ fory = pyfory.Fory(xlang=True, ref=ref, strict=False)
+ data = ("a", 1, ("nested", 2))
+ result = fory.loads(fory.dumps(data))
+ assert result == ["a", 1, ["nested", 2]]
+
+ @pytest.mark.parametrize("ref", [False, True])
+ def test_nested_dynamic_tuples_roundtrip_to_lists(self, ref):
+ fory = pyfory.Fory(xlang=True, ref=ref, strict=False)
+ data = [("a", 1), {"x": ("b", 2), "y": [("c", 3)]}]
+ result = fory.loads(fory.dumps(data))
+ assert result == [["a", 1], {"x": ["b", 2], "y": [["c", 3]]}]
+
+
class TestDictWithNone:
"""Test dict serialization with None keys/values."""
diff --git a/python/pyfory/tests/test_serializer.py
b/python/pyfory/tests/test_serializer.py
index 386c7c5e2..454fa15c2 100644
--- a/python/pyfory/tests/test_serializer.py
+++ b/python/pyfory/tests/test_serializer.py
@@ -722,6 +722,23 @@ def test_py_serialize_object(track_ref):
assert ser_de(fory, obj2) == obj2
+@pytest.mark.parametrize("track_ref", [False, True])
+def test_py_serialize_empty_object(track_ref):
+ fory = Fory(xlang=False, ref=track_ref, strict=False)
+ obj = object()
+ result = ser_de(fory, obj)
+ assert type(result) is object
+
+ repeated = [obj, obj]
+ repeated_result = ser_de(fory, repeated)
+ assert type(repeated_result[0]) is object
+ assert type(repeated_result[1]) is object
+ if track_ref:
+ assert repeated_result[0] is repeated_result[1]
+ else:
+ assert repeated_result[0] is not repeated_result[1]
+
+
def test_dumps_loads():
fory = Fory(xlang=False, ref=True)
obj = {"a": 1, "b": 2}
diff --git a/python/pyfory/tests/test_struct.py
b/python/pyfory/tests/test_struct.py
index abc8cd0ff..678fccd17 100644
--- a/python/pyfory/tests/test_struct.py
+++ b/python/pyfory/tests/test_struct.py
@@ -19,7 +19,7 @@ import dataclasses
from dataclasses import dataclass
import datetime
import enum
-from typing import Dict, Any, List, Set, Optional
+from typing import Dict, Any, List, Set, Optional, Tuple
import pytest
import typing
@@ -145,6 +145,25 @@ class BoolCoercionObject:
b: bool
+@dataclass(frozen=True)
+class TupleFieldObject:
+ bar: Tuple[str, int]
+
+
+@dataclass(frozen=True)
+class XlangTupleFieldObject:
+ bar: Tuple[str, int]
+
+
+@dataclass(frozen=True)
+class XlangNestedTupleObject:
+ tuple_field: Tuple[List[int], Dict[str, int]]
+ list_of_tuples: List[Tuple[str, int]]
+ map_of_tuples: Dict[str, Tuple[str, int]]
+ set_of_tuples: Set[Tuple[str, int]]
+ tuple_of_tuples: Tuple[Tuple[str, int], Tuple[str, int]]
+
+
def test_sort_fields():
@dataclass
class TestClass:
@@ -289,6 +308,44 @@ def test_data_class_serializer_xlang():
assert obj_deserialized_none == obj_with_none_complex
+@pytest.mark.parametrize("track_ref", [False, True])
+def test_dataclass_with_typed_tuple_field(track_ref):
+ fory = Fory(xlang=False, ref=track_ref, strict=False)
+ obj = TupleFieldObject(bar=("a", 1))
+ assert ser_de(fory, obj) == obj
+
+
+@pytest.mark.parametrize("track_ref", [False, True])
+def test_xlang_dataclass_tuple_field(track_ref):
+ fory = Fory(xlang=True, ref=track_ref, strict=False)
+ fory.register_type(XlangTupleFieldObject,
typename="example.XlangTupleFieldObject")
+ obj = XlangTupleFieldObject(bar=("a", 1))
+ result = ser_de(fory, obj)
+ assert result == obj
+ assert isinstance(result.bar, tuple)
+
+
+@pytest.mark.parametrize("track_ref", [False, True])
+def test_xlang_nested_tuple_container_fields(track_ref):
+ fory = Fory(xlang=True, ref=track_ref, strict=False)
+ fory.register_type(XlangNestedTupleObject,
typename="example.XlangNestedTupleObject")
+ obj = XlangNestedTupleObject(
+ tuple_field=([1, 2], {"a": 1, "b": 2}),
+ list_of_tuples=[("a", 1), ("b", 2)],
+ map_of_tuples={"left": ("c", 3), "right": ("d", 4)},
+ set_of_tuples={("e", 5), ("f", 6)},
+ tuple_of_tuples=(("g", 7), ("h", 8)),
+ )
+ result = ser_de(fory, obj)
+ assert result == obj
+ assert isinstance(result.tuple_field, tuple)
+ assert all(isinstance(value, tuple) for value in result.list_of_tuples)
+ assert all(isinstance(value, tuple) for value in
result.map_of_tuples.values())
+ assert all(isinstance(value, tuple) for value in result.set_of_tuples)
+ assert isinstance(result.tuple_of_tuples, tuple)
+ assert all(isinstance(value, tuple) for value in result.tuple_of_tuples)
+
+
def test_struct_evolving_override():
@pyfory.dataclass
class EvolvingStruct:
diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py
index 2b2451027..7b37e9b7e 100644
--- a/python/pyfory/type_util.py
+++ b/python/pyfory/type_util.py
@@ -181,6 +181,9 @@ class TypeVisitor(ABC):
def visit_dict(self, field_name, key_type, value_type, types_path=None):
pass
+ def visit_tuple(self, field_name, elem_types, types_path=None):
+ raise TypeError(f"Tuple type with elements {elem_types} is not
supported")
+
@abstractmethod
def visit_customized(self, field_name, type_, types_path=None):
pass
@@ -220,6 +223,24 @@ def unwrap_optional(type_, field_nullable=False):
return typing.Union[tuple(non_none_types)], True
+def get_homogeneous_tuple_elem_type(type_or_args):
+ if isinstance(type_or_args, tuple):
+ args = type_or_args
+ else:
+ origin = typing.get_origin(type_or_args) if hasattr(typing,
"get_origin") else getattr(type_or_args, "__origin__", None)
+ if origin not in (tuple, typing.Tuple):
+ return None
+ args = typing.get_args(type_or_args) if hasattr(typing, "get_args")
else getattr(type_or_args, "__args__", ())
+ if not args or args == ((),):
+ return None
+ if len(args) == 2 and args[1] is Ellipsis:
+ return args[0]
+ first = args[0]
+ if all(arg == first for arg in args[1:]):
+ return first
+ return None
+
+
def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
types_path = list(types_path or [])
type_, _ = unwrap_ref(type_)
@@ -234,6 +255,8 @@ def infer_field(field_name, type_, visitor: TypeVisitor,
types_path=None):
elif origin is set or origin == typing.Set:
elem_type = args[0]
return visitor.visit_set(field_name, elem_type,
types_path=types_path)
+ elif origin is tuple or origin == typing.Tuple:
+ return visitor.visit_tuple(field_name, args, types_path=types_path)
elif origin is dict or origin == typing.Dict:
key_type, value_type = args
return visitor.visit_dict(field_name, key_type, value_type,
types_path=types_path)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]