This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1f5583e0e [Docs] Add tutorial for importing models from PyTorch, 
ONNX, and TFLite (#19354)
d1f5583e0e is described below

commit d1f5583e0efeb2bd02907e84e30034de46a934af
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 6 14:52:17 2026 -0400

    [Docs] Add tutorial for importing models from PyTorch, ONNX, and TFLite 
(#19354)
    
    This PR adds a tutorial that gives users a quick overview of how to
    import models from the supported frontends.
    
    
    Besides, this pr also adds `absl::InitializeLog` and `trackable_obj`
    warnings to the CI docs ignore list — these are emitted by TensorFlow's
    C++ runtime during import and cannot be suppressed from Python.
---
 docs/how_to/tutorials/import_model.py | 407 ++++++++++++++++++++++++++++++++++
 docs/index.rst                        |   1 +
 tests/scripts/task_python_docs.sh     |   3 +
 3 files changed, 411 insertions(+)

diff --git a/docs/how_to/tutorials/import_model.py 
b/docs/how_to/tutorials/import_model.py
new file mode 100644
index 0000000000..888235a859
--- /dev/null
+++ b/docs/how_to/tutorials/import_model.py
@@ -0,0 +1,407 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402, E501
+
+"""
+.. _import_model:
+
+Importing Models from ML Frameworks
+====================================
+Apache TVM supports importing models from popular ML frameworks including 
PyTorch, ONNX,
+and TensorFlow Lite. This tutorial walks through each import path with a 
minimal working
+example and explains the key parameters. The PyTorch section additionally 
demonstrates
+how to handle unsupported operators via a custom converter map.
+
+For end-to-end optimization and deployment after importing, see 
:ref:`optimize_model`.
+
+.. note::
+
+    The ONNX section requires the ``onnx`` and ``onnxscript`` packages. The TFLite
+    section requires ``tensorflow`` and ``tflite``. Sections whose dependencies are
+    missing are skipped automatically.
+
+.. contents:: Table of Contents
+    :local:
+    :depth: 2
+"""
+
+######################################################################
+# Importing from PyTorch (Recommended)
+# -------------------------------------
+# TVM's PyTorch frontend is the most feature-complete. The recommended entry 
point is
+# :py:func:`~tvm.relax.frontend.torch.from_exported_program`, which works with 
PyTorch's
+# ``torch.export`` API.
+#
+# We start by defining a small CNN model for demonstration. No pretrained 
weights are
+# needed — we only care about the graph structure.
+
+import numpy as np
+import torch
+from torch import nn
+from torch.export import export
+
+import tvm
+from tvm import relax
+from tvm.relax.frontend.torch import from_exported_program
+
+
+class SimpleCNN(nn.Module):
+    """Small demo network: conv -> batch norm -> ReLU -> global average pool -> linear."""
+
+    def __init__(self):
+        super().__init__()
+        # 3 input channels -> 16 feature maps.
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
+        self.bn = nn.BatchNorm2d(16)
+        # Reduce every feature map to a 1x1 output.
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+        # 16 pooled features -> 10 output values.
+        self.fc = nn.Linear(16, 10)
+
+    def forward(self, x):
+        """Run the forward pass on ``x`` and return the output of the final linear layer."""
+        x = torch.relu(self.bn(self.conv(x)))
+        # flatten(1) merges every dimension after the first into one.
+        x = self.pool(x).flatten(1)
+        x = self.fc(x)
+        return x
+
+
+# eval() puts the module in evaluation mode, so BatchNorm uses its running statistics.
+torch_model = SimpleCNN().eval()
+# Example inputs as a tuple of positional arguments; passed to ``export`` below.
+example_args = (torch.randn(1, 3, 32, 32),)
+
+######################################################################
+# Basic import
+# ~~~~~~~~~~~~
+# The standard workflow is: ``torch.export.export()`` → 
``from_exported_program()`` →
+# ``detach_params()``.
+
+with torch.no_grad():
+    # Capture the model graph with torch.export, using the example input.
+    exported_program = export(torch_model, example_args)
+    # Translate the exported program into a Relax IRModule. Weights are kept as
+    # function parameters and the single-element return tuple is unwrapped.
+    mod = from_exported_program(
+        exported_program,
+        keep_params_as_input=True,
+        unwrap_unit_return_tuple=True,
+    )
+
+# Separate the weights from the module; ``params`` is indexed by function name
+# (``params["main"]`` is used when running the compiled model).
+mod, params = relax.frontend.detach_params(mod)
+mod.show()
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# ``from_exported_program`` accepts several parameters that control how the 
model is
+# translated:
+#
+# - **keep_params_as_input** (bool, default ``False``): When ``True``, model 
weights become
+#   function parameters, separated via ``relax.frontend.detach_params()``. 
When ``False``,
+#   weights are embedded as constants inside the IRModule. Use ``True`` when 
you want to
+#   manage weights independently (e.g., for weight sharing or quantization).
+#
+# - **unwrap_unit_return_tuple** (bool, default ``False``): PyTorch ``export`` 
always wraps
+#   the return value in a tuple. Set ``True`` to unwrap single-element return 
tuples for a
+#   cleaner Relax function signature.
+#
+# - **run_ep_decomposition** (bool, default ``True``): Runs PyTorch's built-in 
operator
+#   decomposition before translation. This breaks high-level ops (e.g., 
``batch_norm``) into
+#   lower-level primitives, which generally improves TVM's coverage and 
optimization
+#   opportunities. Set ``False`` if you want to preserve the original op 
granularity.
+
+######################################################################
+# Handling unsupported operators with ``custom_convert_map``
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# When TVM encounters a PyTorch operator it does not recognize, it raises an 
error
+# indicating the unsupported operator name. You can extend the frontend by 
providing a
+# **custom converter map** — a dictionary mapping operator names to your own 
conversion
+# functions.
+#
+# A custom converter function receives two arguments:
+#
+# - **node** (``torch.fx.Node``): The FX graph node being converted, carrying 
operator
+#   info and references to input nodes.
+# - **importer** (``ExportedProgramImporter``): The importer instance, giving 
access to:
+#
+#   - ``importer.env``: Dict mapping FX nodes to their converted Relax 
expressions.
+#   - ``importer.block_builder``: The Relax ``BlockBuilder`` for emitting 
operations.
+#   - ``importer.retrieve_args(node)``: Helper to look up converted args.
+#
+# The function must return a ``relax.Var`` — the Relax expression for this 
node's output.
+# Here is an example that maps an operator to ``relax.op.sigmoid``:
+
+from tvm.relax.frontend.torch.exported_program_translator import 
ExportedProgramImporter
+
+
+def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) -> 
relax.Var:
+    """Custom converter: map an op to relax.op.sigmoid.
+
+    Parameters
+    ----------
+    node : torch.fx.Node
+        The FX graph node being converted.
+    importer : ExportedProgramImporter
+        The importer instance. It is used to look up the already-converted
+        arguments of ``node`` and to emit the new operation through its
+        block builder.
+
+    Returns
+    -------
+    relax.Var
+        The result of emitting ``relax.op.sigmoid`` on the node's first argument.
+    """
+    # Converted Relax expressions for the node's inputs.
+    args = importer.retrieve_args(node)
+    return importer.block_builder.emit(relax.op.sigmoid(args[0]))
+
+
+######################################################################
+# To use the custom converter, pass it via the ``custom_convert_map`` 
parameter. The key
+# is the ATen operator name in ``"op_name.variant"`` format (e.g., 
``"sigmoid.default"``):
+#
+# .. code-block:: python
+#
+#    mod = from_exported_program(
+#        exported_program,
+#        custom_convert_map={"sigmoid.default": convert_sigmoid},
+#    )
+#
+# .. note::
+#
+#    To find the correct operator name, check the error message TVM raises 
when encountering
+#    the unsupported op — it includes the exact ATen name. You can also 
inspect the exported
+#    program's graph via ``print(exported_program.graph_module.graph)`` to see 
all operator
+#    names.
+
+######################################################################
+# Alternative PyTorch import methods
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Besides ``from_exported_program``, TVM also provides:
+#
+# - :py:func:`~tvm.relax.frontend.torch.from_fx`: Works with 
``torch.fx.GraphModule``
+#   from ``torch.fx.symbolic_trace()``. Requires explicit ``input_info`` 
(shapes and dtypes).
+#   Use this when ``torch.export`` fails on certain Python control flow 
patterns.
+#
+# - :py:func:`~tvm.relax.frontend.torch.relax_dynamo`: A ``torch.compile`` 
backend that
+#   compiles and executes the model through TVM in one step. Useful for 
integrating TVM
+#   into an existing PyTorch training or inference loop.
+#
+# - :py:func:`~tvm.relax.frontend.torch.dynamo_capture_subgraphs`: Captures 
subgraphs from
+#   a PyTorch model into an IRModule via ``torch.compile``. Each subgraph 
becomes a separate
+#   function in the IRModule.
+#
+# For most use cases, ``from_exported_program`` is the recommended path.
+
+######################################################################
+# Verifying the imported model
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# After importing, it is good practice to verify that TVM produces the same 
output as the
+# original framework. We compile with the minimal ``"zero"`` pipeline (no 
tuning) and
+# compare. The same approach applies to models imported via the ONNX and 
TFLite frontends
+# shown below.
+
+# Apply the "zero" pipeline, then build the result for the "llvm" target.
+mod_compiled = relax.get_pipeline("zero")(mod)
+exec_module = tvm.compile(mod_compiled, target="llvm")
+dev = tvm.cpu()
+vm = relax.VirtualMachine(exec_module, dev)
+
+# Run inference
+# Same shape as the example input used for export.
+input_data = np.random.rand(1, 3, 32, 32).astype("float32")
+tvm_input = tvm.runtime.tensor(input_data, dev)
+# The weights were detached from the module, so they are passed explicitly
+# after the data input.
+tvm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]
+tvm_out = vm["main"](tvm_input, *tvm_params).numpy()
+
+# Compare with PyTorch
+with torch.no_grad():
+    pt_out = torch_model(torch.from_numpy(input_data)).numpy()
+
+# Raises AssertionError if the two outputs differ beyond the tolerances.
+np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
+print("PyTorch vs TVM outputs match!")
+
+######################################################################
+# Importing from ONNX
+# --------------------
+# TVM can import ONNX models via 
:py:func:`~tvm.relax.frontend.onnx.from_onnx`. The
+# function accepts an ``onnx.ModelProto`` object, so you need to load the 
model with
+# ``onnx.load()`` first.
+#
+# Here we export the same CNN model to ONNX format and then import it into TVM.
+
+# Optional dependencies: if either import fails, the ONNX section is skipped.
+try:
+    import onnx
+    import onnxscript  # noqa: F401  # required by torch.onnx.export
+
+    HAS_ONNX = True
+except ImportError:
+    onnx = None  # type: ignore[assignment]
+    HAS_ONNX = False
+
+if HAS_ONNX:
+    from tvm.relax.frontend.onnx import from_onnx
+
+    # Export the PyTorch model to ONNX
+    dummy_input = torch.randn(1, 3, 32, 32)
+    # NOTE: relative path, so the file is written to the current working
+    # directory; it is not removed afterwards.
+    onnx_path = "simple_cnn.onnx"
+    torch.onnx.export(torch_model, dummy_input, onnx_path, 
input_names=["input"])
+
+    # Load and import into TVM
+    onnx_model = onnx.load(onnx_path)
+    # Keep the weights as function inputs, then split them out of the module.
+    mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
+    mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
+    mod_onnx.show()
+
+######################################################################
+# If you already have an ``.onnx`` file on disk, the workflow is even simpler:
+#
+# .. code-block:: python
+#
+#    import onnx
+#    from tvm.relax.frontend.onnx import from_onnx
+#
+#    onnx_model = onnx.load("my_model.onnx")
+#    mod = from_onnx(onnx_model)
+#
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# - **shape_dict** (dict, optional): Maps input names to shapes. Auto-inferred 
from the
+#   model if not provided. Useful when the ONNX model has dynamic dimensions 
that you
+#   want to fix to concrete sizes:
+#
+#   .. code-block:: python
+#
+#      mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})
+#
+# - **dtype_dict** (str or dict, default ``"float32"``): Input dtypes. A 
single string
+#   applies to all inputs, or use a dict to set per-input dtypes:
+#
+#   .. code-block:: python
+#
+#      mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})
+#
+# - **keep_params_in_input** (bool, default ``False``): Same semantics as
+#   ``keep_params_as_input`` in the PyTorch frontend — whether model weights are
+#   function parameters or embedded constants.
+#
+# - **opset** (int, optional): Override the opset version auto-detected from 
the model.
+#   Each ONNX op may have different semantics across opset versions; TVM's 
converter
+#   selects the appropriate implementation automatically. You rarely need to 
set this
+#   unless the model metadata is incorrect.
+
+######################################################################
+# Importing from TensorFlow Lite
+# -------------------------------
+# TVM can import TFLite flat buffer models via
+# :py:func:`~tvm.relax.frontend.tflite.from_tflite`. The function expects a 
TFLite
+# ``Model`` object parsed from flat buffer bytes via ``GetRootAsModel``.
+#
+# .. note::
+#
+#    The ``tflite`` Python package has changed its module layout across 
versions.
+#    Older versions use ``tflite.Model.Model.GetRootAsModel``, while newer 
versions use
+#    ``tflite.Model.GetRootAsModel``. The code below handles both.
+#
+# Below we create a minimal TFLite model from TensorFlow and import it.
+
+# Optional dependencies: if any import fails, the TFLite section is skipped.
+try:
+    import tensorflow as tf
+    import tflite
+    import tflite.Model
+
+    HAS_TFLITE = True
+except ImportError:
+    HAS_TFLITE = False
+
+if HAS_TFLITE:
+    from tvm.relax.frontend.tflite import from_tflite
+
+    # Define a simple TF module and convert to TFLite.
+    # We use plain TF ops (not keras layers) to avoid variable-handling ops
+    # that some TFLite converter versions do not support cleanly.
+    class TFModule(tf.Module):
+        """Module with a single traced function computing ``matmul(x, weight) + 0.1``."""
+
+        # Fixed signature: x is (1, 784) and weight is (784, 10), both float32.
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
+                tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
+            ]
+        )
+        def forward(self, x, weight):
+            """Return ``tf.matmul(x, weight) + 0.1``; the weight is an input, not a variable."""
+            return tf.matmul(x, weight) + 0.1
+
+    tf_module = TFModule()
+    # The input signature is fixed on the decorator, so get_concrete_function()
+    # is called without arguments.
+    converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [tf_module.forward.get_concrete_function()], tf_module
+    )
+    # Serialized TFLite flat buffer.
+    tflite_buf = converter.convert()
+
+    # Parse and import into TVM (API differs between tflite package versions)
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
+    # No shape_dict/dtype_dict is passed here.
+    mod_tflite = from_tflite(tflite_model)
+    mod_tflite.show()
+
+######################################################################
+# Loading from a ``.tflite`` file
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# If you already have a ``.tflite`` file on disk, load the raw bytes and parse 
them:
+#
+# .. code-block:: python
+#
+#    import tflite
+#    import tflite.Model
+#    from tvm.relax.frontend.tflite import from_tflite
+#
+#    with open("my_model.tflite", "rb") as f:
+#        tflite_buf = f.read()
+#
+#    if hasattr(tflite.Model, "Model"):
+#        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
+#    else:
+#        tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
+#    mod = from_tflite(tflite_model)
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# - **shape_dict** / **dtype_dict** (optional): Override input shapes and 
dtypes. If not
+#   provided, they are inferred from the TFLite model metadata.
+#
+# - **op_converter** (class, optional): A custom operator converter class. 
Subclass
+#   ``OperatorConverter`` and override its ``convert_map`` dictionary to add 
or replace
+#   operator conversions. For example, to add a hypothetical ``CUSTOM_RELU`` 
op:
+#
+#   .. code-block:: python
+#
+#      from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter
+#
+#      class MyConverter(OperatorConverter):
+#          def __init__(self, model, subgraph, exp_tab, ctx):
+#              super().__init__(model, subgraph, exp_tab, ctx)
+#              self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu
+#
+#          def _convert_custom_relu(self, op):
+#              # implement your conversion logic here
+#              ...
+#
+#      mod = from_tflite(tflite_model, op_converter=MyConverter)
+
+######################################################################
+# Summary
+# -------
+#
+# +---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Aspect              | PyTorch                    | ONNX                          | TFLite                      |
+# +=====================+============================+===============================+=============================+
+# | Entry function      | ``from_exported_program``  | ``from_onnx``                 | ``from_tflite``             |
+# +---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Input               | ``ExportedProgram``        | ``onnx.ModelProto``           | TFLite ``Model`` object     |
+# +---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Custom extension    | ``custom_convert_map``     | —                             | ``op_converter`` class      |
+# +---------------------+----------------------------+-------------------------------+-----------------------------+
+#
+# **Which to use?** Pick the frontend that matches your model format:
+#
+# - Have a PyTorch model? Use ``from_exported_program`` — it has the broadest 
operator coverage.
+# - Have an ``.onnx`` file? Use ``from_onnx``.
+# - Have a ``.tflite`` file? Use ``from_tflite``.
+#
+# The verification workflow (compile → run → compare) demonstrated in the 
PyTorch section
+# above applies equally to ONNX and TFLite imports.
+#
+# For the full list of supported operators, see the converter map in each 
frontend's source:
+# PyTorch uses ``create_convert_map()`` in ``exported_program_translator.py``, 
ONNX uses
+# ``_get_convert_map()`` in ``onnx_frontend.py``, and TFLite uses 
``convert_map`` in
+# ``OperatorConverter`` in ``tflite_frontend.py``.
+#
+# After importing, refer to :ref:`optimize_model` for optimization and 
deployment.
diff --git a/docs/index.rst b/docs/index.rst
index 2b5ef64646..9bebfe1772 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -41,6 +41,7 @@ driving its costs down.
    :maxdepth: 1
    :caption: How To
 
+   how_to/tutorials/import_model
    how_to/tutorials/e2e_opt_model
    how_to/tutorials/customize_opt
    how_to/tutorials/optimize_llm
diff --git a/tests/scripts/task_python_docs.sh 
b/tests/scripts/task_python_docs.sh
index ec765659c2..7c3122cd92 100755
--- a/tests/scripts/task_python_docs.sh
+++ b/tests/scripts/task_python_docs.sh
@@ -91,6 +91,9 @@ IGNORED_WARNINGS=(
     # Warning is thrown during TFLite quantization for micro_train tutorial
     'absl:For model inputs containing unsupported operations which cannot be 
quantized, the `inference_input_type` attribute will default to the original 
type.'
     'absl:Found untraced functions such as _jit_compiled_convolution_op'
+    # TF C++ runtime prints this before absl logging is initialized
+    'absl::InitializeLog'
+    'absl:Please consider providing the trackable_obj argument'
     'You are using pip version'
     # Tutorial READMEs can be ignored, but other docs should be included
     "tutorials/README.rst: WARNING: document isn't included in any toctree"

Reply via email to