Copilot commented on code in PR #8: URL: https://github.com/apache/tvm-ffi/pull/8#discussion_r2364372478
########## python/tvm_ffi/dataclasses/field.py: ########## @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Counterpart of `dataclasses.field`.""" + +from __future__ import annotations + +from dataclasses import MISSING, dataclass +from typing import Any, Callable + + +@dataclass(kw_only=True) +class Field: + """Produced by `tvm_ffi.dataclasses.field`. + + TODO(@junrushao): add more fields according to: https://docs.python.org/3/library/dataclasses.html#dataclasses.field + """ + + name: str | None = None + default_factory: Callable[[], Any] + + +def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field: + """Declare a dataclass-like field for FFI-backed classes. + + Parameters + ---------- + default : Any + A literal default value. + default_factory : Callable[[], Any] + A factory callable that produces the default value. + + """ + if default is not MISSING and default_factory is not MISSING: + raise ValueError("Cannot specify both `default` and `default_factory`") + if default is not MISSING: + default_factory = lambda: default Review Comment: Lambda function should be replaced with a proper function definition or use functools.partial to avoid late binding issues with mutable defaults. ########## python/tvm_ffi/dataclasses/c_class.py: ########## @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A decorator that builds Python dataclasses from C++ via TVM FFI. + +The ``c_class`` decorator reflects fields and methods from the underlying +FFI type (identified by ``type_key``) and attaches them to the decorated +Python class. +""" + +from collections.abc import Callable +from dataclasses import InitVar +from typing import ClassVar, TypeVar, get_origin, get_type_hints + +from ..core import TypeField, TypeInfo +from . import _utils, field + +try: + from typing import dataclass_transform +except ImportError: + from typing_extensions import dataclass_transform + + +_InputClsType = TypeVar("_InputClsType") + + +@dataclass_transform(field_specifiers=(field.field, field.Field)) +def c_class( + type_key: str, init: bool = True +) -> Callable[[type[_InputClsType]], type[_InputClsType]]: + """Create a decorator that binds a Python dataclass to an FFI type from C++. + + Parameters + ---------- + type_key : str + Type key registered in the TVM FFI registry. + + init : bool, default True + If True and the class does not implement ``__init__``, generate an + ``__init__`` that forwards to the FFI constructor. + + Returns + ------- + Callable[[type], type] + A class decorator that returns the finalized proxy class. + + """ + + def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]: + nonlocal init + init = init and "__init__" not in super_type_cls.__dict__ + # Step 1. Retrieve `type_info` from registry + type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key) + assert type_info.parent_type_info is None, f"Already registered type: {type_key}" + type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls) + # Step 2. Reflect all the fields of the type + type_info.fields = _inspect_c_class_fields(super_type_cls, type_info) + for type_field in type_info.fields: + _utils.fill_dataclass_field(super_type_cls, type_field) + # Step 3. Create the proxy class with the fields as properties + fn_init = _utils.method_init(super_type_cls, type_info) if init else None + type_cls: type[_InputClsType] = _utils.type_info_to_cls( + type_info=type_info, + cls=super_type_cls, + methods={"__init__": fn_init}, + ) + type_info.type_cls = type_cls + return type_cls + + return decorator + + +def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]: + type_hints_resolved = get_type_hints(type_cls, include_extras=True) + type_hints_py = { + name: type_hints_resolved[name] + for name in getattr(type_cls, "__annotations__", {}).keys() + if get_origin(type_hints_resolved[name]) + not in [ # ignore non-field annotations + ClassVar, + InitVar, + ] + } + del type_hints_resolved + + type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields} + type_fields: list[TypeField] = [] + for field_name, _field_ty_py in type_hints_py.items(): + if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip + continue + type_field: TypeField = type_fields_cxx.pop(field_name, None) + if type_field is None: + raise ValueError( + f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C" + ) + type_fields.append(type_field) + if type_fields_cxx: + extra_fields = ", ".join(f"`{f.name}`" for f in type_fields_cxx.values()) + raise ValueError( + f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C but not in Python" Review Comment: Error message should include 'C++' instead of just 'C' for clarity since this is about C++ classes. ```suggestion f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python" ``` ########## python/tvm_ffi/dataclasses/c_class.py: ########## @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A decorator that builds Python dataclasses from C++ via TVM FFI. + +The ``c_class`` decorator reflects fields and methods from the underlying +FFI type (identified by ``type_key``) and attaches them to the decorated +Python class. +""" + +from collections.abc import Callable +from dataclasses import InitVar +from typing import ClassVar, TypeVar, get_origin, get_type_hints + +from ..core import TypeField, TypeInfo +from . import _utils, field + +try: + from typing import dataclass_transform +except ImportError: + from typing_extensions import dataclass_transform + + +_InputClsType = TypeVar("_InputClsType") + + +@dataclass_transform(field_specifiers=(field.field, field.Field)) +def c_class( + type_key: str, init: bool = True +) -> Callable[[type[_InputClsType]], type[_InputClsType]]: + """Create a decorator that binds a Python dataclass to an FFI type from C++. + + Parameters + ---------- + type_key : str + Type key registered in the TVM FFI registry. + + init : bool, default True + If True and the class does not implement ``__init__``, generate an + ``__init__`` that forwards to the FFI constructor. + + Returns + ------- + Callable[[type], type] + A class decorator that returns the finalized proxy class. + + """ + + def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]: + nonlocal init + init = init and "__init__" not in super_type_cls.__dict__ + # Step 1. Retrieve `type_info` from registry + type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key) + assert type_info.parent_type_info is None, f"Already registered type: {type_key}" + type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls) + # Step 2. Reflect all the fields of the type + type_info.fields = _inspect_c_class_fields(super_type_cls, type_info) + for type_field in type_info.fields: + _utils.fill_dataclass_field(super_type_cls, type_field) + # Step 3. Create the proxy class with the fields as properties + fn_init = _utils.method_init(super_type_cls, type_info) if init else None + type_cls: type[_InputClsType] = _utils.type_info_to_cls( + type_info=type_info, + cls=super_type_cls, + methods={"__init__": fn_init}, + ) + type_info.type_cls = type_cls + return type_cls + + return decorator + + +def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]: + type_hints_resolved = get_type_hints(type_cls, include_extras=True) + type_hints_py = { + name: type_hints_resolved[name] + for name in getattr(type_cls, "__annotations__", {}).keys() + if get_origin(type_hints_resolved[name]) + not in [ # ignore non-field annotations + ClassVar, + InitVar, + ] + } + del type_hints_resolved + + type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields} + type_fields: list[TypeField] = [] + for field_name, _field_ty_py in type_hints_py.items(): + if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip + continue + type_field: TypeField = type_fields_cxx.pop(field_name, None) + if type_field is None: + raise ValueError( + f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C" Review Comment: Error message should include 'C++' instead of just 'C' for clarity since this is about C++ classes. ```suggestion f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C++" ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
