junrushao commented on code in PR #101:
URL: https://github.com/apache/tvm-ffi/pull/101#discussion_r2424506051


##########
python/tvm_ffi/stub/stubgen.py:
##########
@@ -0,0 +1,517 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``).
+
+Overview
+--------
+This module powers the ``tvm-ffi-stubgen`` command line tool which generates
+in-place type stubs for Python modules that integrate with the TVM FFI
+ecosystem. It scans ``.py``/``.pyi`` files for special comment markers and
+fills the enclosed blocks with precise, static type annotations derived from
+runtime metadata exposed by TVM FFI.
+
+Why you might use this
+----------------------
+- You author a Python module that binds to C++/C via TVM FFI and want
+  high-quality type hints for functions, objects, and methods.
+- You maintain a downstream extension that registers global functions or
+  FFI object types and want your Python API surface to be type-checker
+  friendly without manually writing stubs.
+
+How it works (in one sentence)
+------------------------------
+``tvm-ffi-stubgen`` replaces the content between special ``tvm-ffi-stubgen`` markers
+with generated code guarded by ``if TYPE_CHECKING: ...`` so that the runtime
+behavior is unchanged while static analyzers get rich types.
+
+Stub block markers
+------------------
+Insert one of the following begin/end markers in your source, then run
+``tvm-ffi-stubgen``. Indentation on the ``begin`` line is preserved; generated
+content is additionally indented by ``--indent`` spaces (default: 4).
+
+1) Global function stubs
+
+    Mark all global functions whose names start with a registry prefix
+    (e.g. ``ffi`` or ``my_ffi_extension``):
+
+    .. code-block:: python
+
+        from typing import TYPE_CHECKING
+
+        # tvm-ffi-stubgen(begin): global/ffi
+        if TYPE_CHECKING:
+            # fmt: off
+            # (generated by tvm-ffi-stubgen)
+            # fmt: on
+        # tvm-ffi-stubgen(end)
+
+    ``tvm-ffi-stubgen`` expands this with function signatures discovered via the
+    TVM FFI global function registry.
+
+2) Object type stubs
+
+    Mark fields and methods for a registered FFI object type using its
+    ``type_key`` (the key passed to ``@register_object``):
+
+    .. code-block:: python
+
+        @register_object("testing.SchemaAllTypes")
+        class _SchemaAllTypes:
+            # tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes
+            # tvm-ffi-stubgen(ty_map): testing.SchemaAllTypes -> _SchemaAllTypes
+            if TYPE_CHECKING:
+                # fmt: off
+                # (generated by tvm-ffi-stubgen)
+                # fmt: on
+            # tvm-ffi-stubgen(end)
+
+    ``tvm-ffi-stubgen`` expands this with annotated attributes and method stub
+    signatures. The special C FFI initializer ``__ffi_init__`` is exposed as
+    ``__c_ffi_init__`` to avoid interfering with your Python ``__init__``.
+
+3) Skip whole file
+
+    If a source file should never be modified by the stub generator, add the
+    following directive anywhere in the file:
+
+    .. code-block:: python
+
+        # tvm-ffi-stubgen(skip-file)
+
+    When present, ``tvm-ffi-stubgen`` skips processing this file entirely. This
+    is useful for files that are generated by other tooling or vendored.
+
+Optional type mapping lines
+---------------------------
+Inside a stub block you may add mapping hints to rename fully-qualified type
+names to simpler aliases in the generated output:
+
+.. code-block:: python
+
+    # tvm-ffi-stubgen(ty_map): A.B.C -> C
+    # tvm-ffi-stubgen(ty_map): list -> Sequence
+    # tvm-ffi-stubgen(ty_map): dict -> Mapping
+
+By default, ``list`` is shown as ``Sequence`` and ``dict`` as ``Mapping``.
+If you use names such as ``Sequence``/``Mapping``, ensure they are available
+to type checkers in your module, for example:
+
+.. code-block:: python
+
+    from typing import TYPE_CHECKING
+    if TYPE_CHECKING:
+        from collections.abc import Mapping, Sequence
+
+Runtime requirements
+--------------------
+- Python must be able to import ``tvm_ffi``.
+- The process needs access to the TVM runtime and any extension libraries that
+  provide the global functions or object types you want to stub. Use the
+  ``--dlls`` option to preload shared libraries when necessary.
+
+What files are modified
+-----------------------
+Only files with extensions ``.py`` and ``.pyi`` are scanned. Files are updated
+in place. A colored unified diff is printed for each change.
+
+CLI quick start
+---------------
+
+.. code-block:: bash
+
+    # Generate stubs for a single file
+    tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py
+
+    # Recursively scan directories for tvm-ffi-stubgen blocks
+    tvm-ffi-stubgen python/tvm_ffi examples/packaging/python/my_ffi_extension
+
+    # Preload TVM runtime and your extension library before generation
+    tvm-ffi-stubgen \
+      --dlls build/libtvm_runtime.dylib build/libmy_ext.dylib \
+      python/tvm_ffi/_ffi_api.py
+
+Exit status
+-----------
+Returns 0 on success and 1 if any file fails to process.
+
+"""
+
+from __future__ import annotations
+
+import argparse
+import ctypes
+import dataclasses
+import difflib
+import logging
+import sys
+from io import StringIO
+from pathlib import Path
+from typing import Callable
+
+from tvm_ffi.core import TypeSchema, _lookup_or_register_type_info_from_type_key
+from tvm_ffi.registry import get_global_func_metadata, list_global_func_names
+
+DEFAULT_SOURCE_EXTS = {".py", ".pyi"}
+STUB_BEGIN = "# tvm-ffi-stubgen(begin):"
+STUB_END = "# tvm-ffi-stubgen(end)"
+STUB_TY_MAP = "# tvm-ffi-stubgen(ty_map):"
+STUB_SKIP_FILE = "# tvm-ffi-stubgen(skip-file)"
+
+TERM_RESET = "\033[0m"
+TERM_BOLD = "\033[1m"
+TERM_RED = "\033[31m"
+TERM_GREEN = "\033[32m"
+TERM_YELLOW = "\033[33m"
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclasses.dataclass
+class Options:
+    """Command line options for stub generation."""
+
+    dlls: list[str] = dataclasses.field(default_factory=list)
+    indent: int = 4
+    files: list[str] = dataclasses.field(default_factory=list)
+    header_lines: list[str] = dataclasses.field(default_factory=list)
+    suppress_print: bool = False
+
+
+@dataclasses.dataclass
+class StubConfig:
+    """Configuration of a stub block."""
+
+    name: str
+    indent: int
+    lineno: int
+    ty_map: dict[str, str] = dataclasses.field(
+        default_factory=lambda: dict(
+            {
+                "list": "Sequence",
+                "dict": "Mapping",
+            }
+        )
+    )
+
+
+def _as_func_signature(
+    schema: TypeSchema,
+    func_name: str,
+    ty_map: Callable[[str], str],
+) -> str:
+    buf = StringIO()
+    buf.write(f"def {func_name}(")
+    if schema.origin != "Callable":
+        raise ValueError(f"Expected Callable type schema, but got: {schema}")
+    if not schema.args:
+        buf.write("*args: Any) -> Any:")
+        return buf.getvalue()
+    arg_ret = schema.args[0]
+    arg_args = schema.args[1:]
+    for i, arg in enumerate(arg_args):
+        buf.write(f"_{i}: ")
+        buf.write(arg.repr(ty_map))
+        buf.write(", ")
+    if arg_args:
+        buf.write("/")
+    buf.write(") -> ")
+    buf.write(arg_ret.repr(ty_map))
+    buf.write(":")
+    return buf.getvalue()
+
+
+def _filter_files(paths: list[Path]) -> list[Path]:
+    results: list[Path] = []
+    for p in paths:
+        if not p.exists():
+            raise FileNotFoundError(f"Path does not exist: {p}")
+        if p.is_dir():
+            for f in p.rglob("*"):
+                if f.is_file() and f.suffix.lower() in DEFAULT_SOURCE_EXTS:
+                    results.append(f.resolve())
+            continue
+        f = p.resolve()
+        if f.is_file() and f.suffix.lower() in DEFAULT_SOURCE_EXTS:
+            results.append(f)
+    # Deterministic order
+    return sorted(set(results))
+
+
+def _make_type_map(name_map: dict[str, str]) -> Callable[[str], str]:
+    def map_type(name: str) -> str:
+        if (ret := name_map.get(name)) is not None:
+            return ret
+        return name.rsplit(".", 1)[-1]
+
+    return map_type
+
+
+def _generate_global(
+    stub: StubConfig,
+    global_func_tab: dict[str, list[str]],
+    opt: Options,
+) -> list[str]:
+    assert stub.name.startswith("global/")
+    prefix = stub.name[len("global/") :].strip()
+    ty_map = _make_type_map(stub.ty_map)
+    indent = " " * (stub.indent + opt.indent)
+    results: list[str] = [
+        " " * stub.indent + "if TYPE_CHECKING:",
+        f"{indent}# fmt: off",
+    ]
+    for name in global_func_tab.get(prefix, []):
+        schema_str = get_global_func_metadata(f"{prefix}.{name}")["type_schema"]
+        schema = TypeSchema.from_json_str(schema_str)
+        sig = _as_func_signature(schema, name, ty_map=ty_map)
+        func = f"{indent}{sig} ..."
+        results.append(func)
+    if len(results) > 2:
+        results.append(f"{indent}# fmt: on")
+    else:
+        results = []
+    return results
+
+
+def _show_diff(old: list[str], new: list[str]) -> None:
+    for line in difflib.unified_diff(old, new, lineterm=""):
+        # Skip placeholder headers when fromfile/tofile are unspecified
+        if line.startswith("---") or line.startswith("+++"):
+            continue
+        if line.startswith("-") and not line.startswith("---"):
+            print(f"{TERM_RED}{line}{TERM_RESET}")  # Red for removals
+        elif line.startswith("+") and not line.startswith("+++"):
+            print(f"{TERM_GREEN}{line}{TERM_RESET}")  # Green for additions
+        elif line.startswith("?"):
+            print(f"{TERM_YELLOW}{line}{TERM_RESET}")  # Yellow for hints
+        else:
+            print(line)
+
+
+def _generate_object(
+    stub: StubConfig,
+    opt: Options,
+) -> list[str]:
+    assert stub.name.startswith("object/")
+    type_key = stub.name[len("object/") :].strip()
+    ty_map = _make_type_map(stub.ty_map)
+    indent = " " * (stub.indent + opt.indent)
+    results: list[str] = [
+        " " * stub.indent + "if TYPE_CHECKING:",
+        f"{indent}# fmt: off",
+    ]
+
+    type_info = _lookup_or_register_type_info_from_type_key(type_key)
+    for field in type_info.fields:
+        schema = TypeSchema.from_json_str(field.metadata["type_schema"])
+        schema_str = schema.repr(ty_map=ty_map)
+        results.append(f"{indent}{field.name}: {schema_str}")
+    for method in type_info.methods:
+        name = method.name
+        if name == "__ffi_init__":
+            name = "__c_ffi_init__"
+        schema = TypeSchema.from_json_str(method.metadata["type_schema"])
+        schema_str = _as_func_signature(schema, name, ty_map=ty_map)
+        if method.is_static:
+            results.append(f"{indent}@staticmethod")
+        results.append(f"{indent}{schema_str} ...")
+    if len(results) > 2:
+        results.append(f"{indent}# fmt: on")
+    else:
+        results = []
+    return results
+
+
+def _has_skip_file_marker(lines: list[str]) -> bool:
+    for raw in lines:
+        if raw.strip().startswith(STUB_SKIP_FILE):
+            return True
+    return False
+
+
+def _main(file: Path, opt: Options) -> None:  # noqa: PLR0912, PLR0915
+    assert file.is_file(), f"Expected a file, but got: {file}"
+
+    lines_now = file.read_text(encoding="utf-8").splitlines()
+
+    # New directive: skip processing this file entirely if present.
+    if _has_skip_file_marker(lines_now):
+        if not opt.suppress_print:
+            print(f"{TERM_YELLOW}[Skipped]  {file}{TERM_RESET}")
+        return
+
+    # Build global function table only if we are going to process blocks.
+    global_func_tab: dict[str, list[str]] = {}
+    for name in list_global_func_names():
+        prefix, suffix = name.rsplit(".", 1)
+        global_func_tab.setdefault(prefix, []).append(suffix)
+    # Ensure stable ordering for deterministic output.
+    for k in list(global_func_tab.keys()):
+        global_func_tab[k].sort()
+    lines_new: list[str] = []
+    stub: StubConfig | None = None
+    skipped: bool = True
+    for lineno, line in enumerate(lines_now, 1):
+        clean_line = line.strip()
+        if clean_line.startswith(STUB_BEGIN):
+            if stub is not None:
+                raise ValueError(f"Nested stub not permitted, but found at {file}:{lineno}")
+            stub = StubConfig(
+                name=clean_line[len(STUB_BEGIN) :].strip(),
+                indent=len(line) - len(clean_line),
+                lineno=lineno,
+            )
+            skipped = False
+            lines_new.append(line)
+        elif clean_line.startswith(STUB_END):
+            if stub is None:
+                raise ValueError(f"Unmatched stub end found at {file}:{lineno}")
+            if stub.name.startswith("global/"):
+                lines_new.extend(_generate_global(stub, global_func_tab, opt))
+            elif stub.name.startswith("object/"):
+                lines_new.extend(_generate_object(stub, opt))
+            else:
+                raise ValueError(f"Unknown stub type `{stub.name}` at {file}:{stub.lineno}")
+            stub = None
+            lines_new.append(line)
+        elif clean_line.startswith(STUB_TY_MAP):
+            if stub is None:
+                raise ValueError(f"Stub ty_map outside stub block at {file}:{lineno}")
+            ty_map = clean_line[len(STUB_TY_MAP) :].strip()
+            try:
+                lhs, rhs = ty_map.split("->")
+            except Exception as e:
+                raise ValueError(
+                    f"Invalid ty_map format at {file}:{lineno}. Example: `A.B -> C`"
+                ) from e

Review Comment:
   Good point! Fixed
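
For readers following the hunk above, here is a minimal, standalone sketch of how a `ty_map` directive line could be parsed and then applied. It is not the code from this PR and does not claim to show what was fixed: `parse_ty_map_line` and `make_type_map` are hypothetical names chosen for illustration, and the stricter validation (rejecting empty sides and repeated arrows) is an assumption rather than the project's behavior. Only the marker string and the fallback of keeping the last dotted component are taken from the quoted diff.

```python
from typing import Callable, Dict, Tuple

# Marker string copied from the quoted diff.
STUB_TY_MAP = "# tvm-ffi-stubgen(ty_map):"


def parse_ty_map_line(line: str) -> Tuple[str, str]:
    """Split a `# tvm-ffi-stubgen(ty_map): A.B -> C` line into (lhs, rhs).

    Hypothetical helper: raises ValueError for anything that is not exactly
    one `lhs -> rhs` pair with non-empty sides.
    """
    clean = line.strip()
    if not clean.startswith(STUB_TY_MAP):
        raise ValueError(f"Not a ty_map line: {line!r}")
    body = clean[len(STUB_TY_MAP):].strip()
    lhs, sep, rhs = body.partition("->")
    lhs, rhs = lhs.strip(), rhs.strip()
    if not sep or not lhs or not rhs or "->" in rhs:
        raise ValueError(f"Invalid ty_map format: {line!r}. Example: `A.B -> C`")
    return lhs, rhs


def make_type_map(name_map: Dict[str, str]) -> Callable[[str], str]:
    """Return a renaming function that mirrors the fallback in the diff:
    explicit mappings win, otherwise the last dotted component is kept."""

    def map_type(name: str) -> str:
        mapped = name_map.get(name)
        if mapped is not None:
            return mapped
        return name.rsplit(".", 1)[-1]

    return map_type


# Defaults as shown in the StubConfig dataclass of the quoted diff.
name_map = {"list": "Sequence", "dict": "Mapping"}
lhs, rhs = parse_ty_map_line(
    "    # tvm-ffi-stubgen(ty_map): testing.SchemaAllTypes -> _SchemaAllTypes"
)
name_map[lhs] = rhs
map_type = make_type_map(name_map)
```

With this sketch, `map_type("testing.SchemaAllTypes")` yields the explicit alias, `map_type("list")` yields the default `Sequence`, and an unmapped dotted name falls back to its final component.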



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

