This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1f5583e0e [Docs] Add tutorial for importing models from PyTorch,
ONNX, and TFLite (#19354)
d1f5583e0e is described below
commit d1f5583e0efeb2bd02907e84e30034de46a934af
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Apr 6 14:52:17 2026 -0400
[Docs] Add tutorial for importing models from PyTorch, ONNX, and TFLite
(#19354)
This PR adds a tutorial for users to gain a quick understanding of how to
import models from our supported frontends.
In addition, this PR also adds `absl::InitializeLog` and `trackable_obj`
warnings to the CI docs ignore list — these are emitted by TensorFlow's
C++ runtime during import and cannot be suppressed from Python.
---
docs/how_to/tutorials/import_model.py | 407 ++++++++++++++++++++++++++++++++++
docs/index.rst | 1 +
tests/scripts/task_python_docs.sh | 3 +
3 files changed, 411 insertions(+)
diff --git a/docs/how_to/tutorials/import_model.py
b/docs/how_to/tutorials/import_model.py
new file mode 100644
index 0000000000..888235a859
--- /dev/null
+++ b/docs/how_to/tutorials/import_model.py
@@ -0,0 +1,407 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402, E501
+
+"""
+.. _import_model:
+
+Importing Models from ML Frameworks
+====================================
+Apache TVM supports importing models from popular ML frameworks including
PyTorch, ONNX,
+and TensorFlow Lite. This tutorial walks through each import path with a
minimal working
+example and explains the key parameters. The PyTorch section additionally
demonstrates
+how to handle unsupported operators via a custom converter map.
+
+For end-to-end optimization and deployment after importing, see
:ref:`optimize_model`.
+
+.. note::
+
+ The ONNX section requires the ``onnx`` package. The TFLite section requires
+ ``tensorflow`` and ``tflite``. Sections whose dependencies are missing are
skipped
+ automatically.
+
+.. contents:: Table of Contents
+ :local:
+ :depth: 2
+"""
+
+######################################################################
+# Importing from PyTorch (Recommended)
+# -------------------------------------
+# TVM's PyTorch frontend is the most feature-complete. The recommended entry
point is
+# :py:func:`~tvm.relax.frontend.torch.from_exported_program`, which works with
PyTorch's
+# ``torch.export`` API.
+#
+# We start by defining a small CNN model for demonstration. No pretrained
weights are
+# needed — we only care about the graph structure.
+
+import numpy as np
+import torch
+from torch import nn
+from torch.export import export
+
+import tvm
+from tvm import relax
+from tvm.relax.frontend.torch import from_exported_program
+
+
+class SimpleCNN(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
+ self.bn = nn.BatchNorm2d(16)
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(16, 10)
+
+ def forward(self, x):
+ x = torch.relu(self.bn(self.conv(x)))
+ x = self.pool(x).flatten(1)
+ x = self.fc(x)
+ return x
+
+
+torch_model = SimpleCNN().eval()
+example_args = (torch.randn(1, 3, 32, 32),)
+
+######################################################################
+# Basic import
+# ~~~~~~~~~~~~
+# The standard workflow is: ``torch.export.export()`` →
``from_exported_program()`` →
+# ``detach_params()``.
+
+with torch.no_grad():
+ exported_program = export(torch_model, example_args)
+ mod = from_exported_program(
+ exported_program,
+ keep_params_as_input=True,
+ unwrap_unit_return_tuple=True,
+ )
+
+mod, params = relax.frontend.detach_params(mod)
+mod.show()
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# ``from_exported_program`` accepts several parameters that control how the
model is
+# translated:
+#
+# - **keep_params_as_input** (bool, default ``False``): When ``True``, model
weights become
+# function parameters, separated via ``relax.frontend.detach_params()``.
When ``False``,
+# weights are embedded as constants inside the IRModule. Use ``True`` when
you want to
+# manage weights independently (e.g., for weight sharing or quantization).
+#
+# - **unwrap_unit_return_tuple** (bool, default ``False``): PyTorch ``export``
always wraps
+# the return value in a tuple. Set ``True`` to unwrap single-element return
tuples for a
+# cleaner Relax function signature.
+#
+# - **run_ep_decomposition** (bool, default ``True``): Runs PyTorch's built-in
operator
+# decomposition before translation. This breaks high-level ops (e.g.,
``batch_norm``) into
+# lower-level primitives, which generally improves TVM's coverage and
optimization
+# opportunities. Set ``False`` if you want to preserve the original op
granularity.
+
+######################################################################
+# Handling unsupported operators with ``custom_convert_map``
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# When TVM encounters a PyTorch operator it does not recognize, it raises an
error
+# indicating the unsupported operator name. You can extend the frontend by
providing a
+# **custom converter map** — a dictionary mapping operator names to your own
conversion
+# functions.
+#
+# A custom converter function receives two arguments:
+#
+# - **node** (``torch.fx.Node``): The FX graph node being converted, carrying
operator
+# info and references to input nodes.
+# - **importer** (``ExportedProgramImporter``): The importer instance, giving
access to:
+#
+# - ``importer.env``: Dict mapping FX nodes to their converted Relax
expressions.
+# - ``importer.block_builder``: The Relax ``BlockBuilder`` for emitting
operations.
+# - ``importer.retrieve_args(node)``: Helper to look up converted args.
+#
+# The function must return a ``relax.Var`` — the Relax expression for this
node's output.
+# Here is an example that maps an operator to ``relax.op.sigmoid``:
+
+from tvm.relax.frontend.torch.exported_program_translator import
ExportedProgramImporter
+
+
+def convert_sigmoid(node: torch.fx.Node, importer: ExportedProgramImporter) ->
relax.Var:
+ """Custom converter: map an op to relax.op.sigmoid."""
+ args = importer.retrieve_args(node)
+ return importer.block_builder.emit(relax.op.sigmoid(args[0]))
+
+
+######################################################################
+# To use the custom converter, pass it via the ``custom_convert_map``
parameter. The key
+# is the ATen operator name in ``"op_name.variant"`` format (e.g.,
``"sigmoid.default"``):
+#
+# .. code-block:: python
+#
+# mod = from_exported_program(
+# exported_program,
+# custom_convert_map={"sigmoid.default": convert_sigmoid},
+# )
+#
+# .. note::
+#
+# To find the correct operator name, check the error message TVM raises
when encountering
+# the unsupported op — it includes the exact ATen name. You can also
inspect the exported
+# program's graph via ``print(exported_program.graph_module.graph)`` to see
all operator
+# names.
+
+######################################################################
+# Alternative PyTorch import methods
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Besides ``from_exported_program``, TVM also provides:
+#
+# - :py:func:`~tvm.relax.frontend.torch.from_fx`: Works with
``torch.fx.GraphModule``
+# from ``torch.fx.symbolic_trace()``. Requires explicit ``input_info``
(shapes and dtypes).
+# Use this when ``torch.export`` fails on certain Python control flow
patterns.
+#
+# - :py:func:`~tvm.relax.frontend.torch.relax_dynamo`: A ``torch.compile``
backend that
+# compiles and executes the model through TVM in one step. Useful for
integrating TVM
+# into an existing PyTorch training or inference loop.
+#
+# - :py:func:`~tvm.relax.frontend.torch.dynamo_capture_subgraphs`: Captures
subgraphs from
+# a PyTorch model into an IRModule via ``torch.compile``. Each subgraph
becomes a separate
+# function in the IRModule.
+#
+# For most use cases, ``from_exported_program`` is the recommended path.
+
+######################################################################
+# Verifying the imported model
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# After importing, it is good practice to verify that TVM produces the same
output as the
+# original framework. We compile with the minimal ``"zero"`` pipeline (no
tuning) and
+# compare. The same approach applies to models imported via the ONNX and
TFLite frontends
+# shown below.
+
+mod_compiled = relax.get_pipeline("zero")(mod)
+exec_module = tvm.compile(mod_compiled, target="llvm")
+dev = tvm.cpu()
+vm = relax.VirtualMachine(exec_module, dev)
+
+# Run inference
+input_data = np.random.rand(1, 3, 32, 32).astype("float32")
+tvm_input = tvm.runtime.tensor(input_data, dev)
+tvm_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]
+tvm_out = vm["main"](tvm_input, *tvm_params).numpy()
+
+# Compare with PyTorch
+with torch.no_grad():
+ pt_out = torch_model(torch.from_numpy(input_data)).numpy()
+
+np.testing.assert_allclose(tvm_out, pt_out, rtol=1e-5, atol=1e-5)
+print("PyTorch vs TVM outputs match!")
+
+######################################################################
+# Importing from ONNX
+# --------------------
+# TVM can import ONNX models via
:py:func:`~tvm.relax.frontend.onnx.from_onnx`. The
+# function accepts an ``onnx.ModelProto`` object, so you need to load the
model with
+# ``onnx.load()`` first.
+#
+# Here we export the same CNN model to ONNX format and then import it into TVM.
+
+try:
+ import onnx
+ import onnxscript # noqa: F401 # required by torch.onnx.export
+
+ HAS_ONNX = True
+except ImportError:
+ onnx = None # type: ignore[assignment]
+ HAS_ONNX = False
+
+if HAS_ONNX:
+ from tvm.relax.frontend.onnx import from_onnx
+
+ # Export the PyTorch model to ONNX
+ dummy_input = torch.randn(1, 3, 32, 32)
+ onnx_path = "simple_cnn.onnx"
+ torch.onnx.export(torch_model, dummy_input, onnx_path,
input_names=["input"])
+
+ # Load and import into TVM
+ onnx_model = onnx.load(onnx_path)
+ mod_onnx = from_onnx(onnx_model, keep_params_in_input=True)
+ mod_onnx, params_onnx = relax.frontend.detach_params(mod_onnx)
+ mod_onnx.show()
+
+######################################################################
+# If you already have an ``.onnx`` file on disk, the workflow is even simpler:
+#
+# .. code-block:: python
+#
+# import onnx
+# from tvm.relax.frontend.onnx import from_onnx
+#
+# onnx_model = onnx.load("my_model.onnx")
+# mod = from_onnx(onnx_model)
+#
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# - **shape_dict** (dict, optional): Maps input names to shapes. Auto-inferred
from the
+# model if not provided. Useful when the ONNX model has dynamic dimensions
that you
+# want to fix to concrete sizes:
+#
+# .. code-block:: python
+#
+# mod = from_onnx(onnx_model, shape_dict={"input": [1, 3, 224, 224]})
+#
+# - **dtype_dict** (str or dict, default ``"float32"``): Input dtypes. A
single string
+# applies to all inputs, or use a dict to set per-input dtypes:
+#
+# .. code-block:: python
+#
+# mod = from_onnx(onnx_model, dtype_dict={"input": "float16"})
+#
+# - **keep_params_in_input** (bool, default ``False``): Same semantics as
PyTorch — whether
+# model weights are function parameters or embedded constants.
+#
+# - **opset** (int, optional): Override the opset version auto-detected from
the model.
+# Each ONNX op may have different semantics across opset versions; TVM's
converter
+# selects the appropriate implementation automatically. You rarely need to
set this
+# unless the model metadata is incorrect.
+
+######################################################################
+# Importing from TensorFlow Lite
+# -------------------------------
+# TVM can import TFLite flat buffer models via
+# :py:func:`~tvm.relax.frontend.tflite.from_tflite`. The function expects a
TFLite
+# ``Model`` object parsed from flat buffer bytes via ``GetRootAsModel``.
+#
+# .. note::
+#
+# The ``tflite`` Python package has changed its module layout across
versions.
+# Older versions use ``tflite.Model.Model.GetRootAsModel``, while newer
versions use
+# ``tflite.Model.GetRootAsModel``. The code below handles both.
+#
+# Below we create a minimal TFLite model from TensorFlow and import it.
+
+try:
+ import tensorflow as tf
+ import tflite
+ import tflite.Model
+
+ HAS_TFLITE = True
+except ImportError:
+ HAS_TFLITE = False
+
+if HAS_TFLITE:
+ from tvm.relax.frontend.tflite import from_tflite
+
+ # Define a simple TF module and convert to TFLite.
+ # We use plain TF ops (not keras layers) to avoid variable-handling ops
+ # that some TFLite converter versions do not support cleanly.
+ class TFModule(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(1, 784), dtype=tf.float32),
+ tf.TensorSpec(shape=(784, 10), dtype=tf.float32),
+ ]
+ )
+ def forward(self, x, weight):
+ return tf.matmul(x, weight) + 0.1
+
+ tf_module = TFModule()
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [tf_module.forward.get_concrete_function()], tf_module
+ )
+ tflite_buf = converter.convert()
+
+ # Parse and import into TVM (API differs between tflite package versions)
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
+ mod_tflite = from_tflite(tflite_model)
+ mod_tflite.show()
+
+######################################################################
+# Loading from a ``.tflite`` file
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# If you already have a ``.tflite`` file on disk, load the raw bytes and parse
them:
+#
+# .. code-block:: python
+#
+# import tflite
+# import tflite.Model
+# from tvm.relax.frontend.tflite import from_tflite
+#
+# with open("my_model.tflite", "rb") as f:
+# tflite_buf = f.read()
+#
+# if hasattr(tflite.Model, "Model"):
+# tflite_model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)
+# else:
+# tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)
+# mod = from_tflite(tflite_model)
+
+######################################################################
+# Key parameters
+# ~~~~~~~~~~~~~~
+# - **shape_dict** / **dtype_dict** (optional): Override input shapes and
dtypes. If not
+# provided, they are inferred from the TFLite model metadata.
+#
+# - **op_converter** (class, optional): A custom operator converter class.
Subclass
+# ``OperatorConverter`` and override its ``convert_map`` dictionary to add
or replace
+# operator conversions. For example, to add a hypothetical ``CUSTOM_RELU``
op:
+#
+# .. code-block:: python
+#
+# from tvm.relax.frontend.tflite.tflite_frontend import OperatorConverter
+#
+# class MyConverter(OperatorConverter):
+# def __init__(self, model, subgraph, exp_tab, ctx):
+# super().__init__(model, subgraph, exp_tab, ctx)
+# self.convert_map["CUSTOM_RELU"] = self._convert_custom_relu
+#
+# def _convert_custom_relu(self, op):
+# # implement your conversion logic here
+# ...
+#
+# mod = from_tflite(tflite_model, op_converter=MyConverter)
+
+######################################################################
+# Summary
+# -------
+#
+#
+---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Aspect | PyTorch | ONNX
| TFLite |
+#
+=====================+============================+===============================+=============================+
+# | Entry function | ``from_exported_program`` | ``from_onnx``
| ``from_tflite`` |
+#
+---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Input | ``ExportedProgram`` | ``onnx.ModelProto``
| TFLite ``Model`` object |
+#
+---------------------+----------------------------+-------------------------------+-----------------------------+
+# | Custom extension | ``custom_convert_map`` | —
| ``op_converter`` class |
+#
+---------------------+----------------------------+-------------------------------+-----------------------------+
+#
+# **Which to use?** Pick the frontend that matches your model format:
+#
+# - Have a PyTorch model? Use ``from_exported_program`` — it has the broadest
operator coverage.
+# - Have an ``.onnx`` file? Use ``from_onnx``.
+# - Have a ``.tflite`` file? Use ``from_tflite``.
+#
+# The verification workflow (compile → run → compare) demonstrated in the
PyTorch section
+# above applies equally to ONNX and TFLite imports.
+#
+# For the full list of supported operators, see the converter map in each
frontend's source:
+# PyTorch uses ``create_convert_map()`` in ``exported_program_translator.py``,
ONNX uses
+# ``_get_convert_map()`` in ``onnx_frontend.py``, and TFLite uses
``convert_map`` in
+# ``OperatorConverter`` in ``tflite_frontend.py``.
+#
+# After importing, refer to :ref:`optimize_model` for optimization and
deployment.
diff --git a/docs/index.rst b/docs/index.rst
index 2b5ef64646..9bebfe1772 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -41,6 +41,7 @@ driving its costs down.
:maxdepth: 1
:caption: How To
+ how_to/tutorials/import_model
how_to/tutorials/e2e_opt_model
how_to/tutorials/customize_opt
how_to/tutorials/optimize_llm
diff --git a/tests/scripts/task_python_docs.sh
b/tests/scripts/task_python_docs.sh
index ec765659c2..7c3122cd92 100755
--- a/tests/scripts/task_python_docs.sh
+++ b/tests/scripts/task_python_docs.sh
@@ -91,6 +91,9 @@ IGNORED_WARNINGS=(
# Warning is thrown during TFLite quantization for micro_train tutorial
'absl:For model inputs containing unsupported operations which cannot be
quantized, the `inference_input_type` attribute will default to the original
type.'
'absl:Found untraced functions such as _jit_compiled_convolution_op'
+ # TF C++ runtime prints this before absl logging is initialized
+ 'absl::InitializeLog'
+ 'absl:Please consider providing the trackable_obj argument'
'You are using pip version'
# Tutorial READMEs can be ignored, but other docs should be included
"tutorials/README.rst: WARNING: document isn't included in any toctree"