junrushao commented on code in PR #36:
URL: https://github.com/apache/tvm-ffi/pull/36#discussion_r2403509386
##########
python/tvm_ffi/cython/type_info.pxi:
##########
@@ -58,6 +59,82 @@ cdef class FieldSetter:
raise_existing_error()
raise move_from_last_error().py_error()
+_TYPE_SCHEMA_ORIGIN_CONVERTER = {
+ # A few Python-native types
+ "Variant": "Union",
+ "Optional": "Optional",
+ "Tuple": "tuple",
+ "ffi.Function": "Callable",
+ "ffi.Array": "list",
+ "ffi.Map": "dict",
+ "ffi.OpaquePyObject": "Any",
+ "ffi.Object": "Object",
+ "ffi.Tensor": "Tensor",
+ "DLTensor*": "Tensor",
+ # ctype types
+ "void*": "ctypes.c_void_p",
+ # bytes
+ "TVMFFIByteArray*": "bytes",
+ "ffi.SmallBytes": "bytes",
+ "ffi.Bytes": "bytes",
+ # strings
+ "std::string": "str",
+ "const char*": "str",
+ "ffi.SmallStr": "str",
+ "ffi.String": "str",
+}
+
+
+@dataclasses.dataclass(repr=False, frozen=True)
+class TypeSchema:
+ """Type schema for a TVM FFI type."""
+ origin: str
+ args: tuple[TypeSchema, ...] = ()
+
+ def __post_init__(self):
+ origin = self.origin
+ args = self.args
+ if origin == "Union":
+ assert len(args) >= 2, "Union must have at least two arguments"
+ elif origin == "Optional":
+ assert len(args) == 1, "Optional must have exactly one argument"
+ elif origin == "list":
+ assert len(args) == 1, "list must have exactly one argument"
+ elif origin == "dict":
+ assert len(args) == 2, "dict must have exactly two arguments"
+ elif origin == "tuple":
+ pass # tuple can have arbitrary number of arguments
+
+ def __repr__(self) -> str:
+ if self.origin == "Union":
+ return " | ".join(repr(a) for a in self.args)
+ elif self.origin == "Optional":
+ return repr(self.args[0]) + " | None"
+ elif self.origin == "Callable":
+ if not self.args:
+ return "Callable[..., Any]"
+ else:
+ arg_ret = self.args[0]
+ arg_args = self.args[1:]
+ return f"Callable[[{', '.join(repr(a) for a in arg_args)}], {repr(arg_ret)}]"
+ elif not self.args:
+ return self.origin
+ else:
+ return f"{self.origin}[{', '.join(repr(a) for a in self.args)}]"
+
+ @staticmethod
+ def from_json_obj(obj: dict[str, Any]) -> "TypeSchema":
+ assert isinstance(obj, dict) and "type" in obj, obj
+ origin = obj["type"]
+ origin = _TYPE_SCHEMA_ORIGIN_CONVERTER.get(origin, origin)
+ args = obj.get("args", ())
+ args = tuple(TypeSchema.from_json_obj(a) for a in args)
+ return TypeSchema(origin, args)
+
+ @staticmethod
+ def from_json_str(s) -> "TypeSchema":
+ return TypeSchema.from_json_obj(json.loads(s))
+
Review Comment:
it's probably fine
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]