This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 481aa6ed12 [Docs] Add operator fusion architecture documentation
(#19394)
481aa6ed12 is described below
commit 481aa6ed12e5b1a835ffadf4db43a4f5ebbd0b2c
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Apr 12 11:46:10 2026 -0400
[Docs] Add operator fusion architecture documentation (#19394)
as per title
---
docs/arch/fusion.rst | 391 +++++++++++++++++++++++++++++++++++++++++++++++++++
docs/arch/index.rst | 12 +-
2 files changed, 400 insertions(+), 3 deletions(-)
diff --git a/docs/arch/fusion.rst b/docs/arch/fusion.rst
new file mode 100644
index 0000000000..5b7e755f94
--- /dev/null
+++ b/docs/arch/fusion.rst
@@ -0,0 +1,391 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _fusion-arch:
+
+Operator Fusion
+===============
+
+Operator fusion is one of the most impactful optimizations in TVM. Instead of launching one kernel
+per operator (e.g., conv2d, bias_add, relu), fusion merges multiple operators into a single kernel,
+eliminating intermediate memory allocations and kernel launch overhead.
+
+TVM provides two complementary fusion mechanisms:
+
+- **Automatic fusion** (``FuseOps`` + ``FuseTIR``): groups operators based on their computational
+  patterns using a post-dominator analysis algorithm.
+- **Pattern-based fusion** (``FuseOpsByPattern``): groups operators that match user-defined
+  dataflow patterns, typically for offloading to external backends (cuBLAS, CUTLASS, DNNL, etc.).
+
+Both produce the same output: Relax functions marked with ``Primitive=True`` that are later
+lowered to fused TIR kernels or dispatched to external libraries.
+
+Overview
+--------
+
+Fusion involves three passes:
+
+.. code-block:: text
+
+ IRModule (after LegalizeOps)
+ │
+   ▼ AnnotateTIROpPattern        ← label each op (elementwise, reduce, etc.)
+ IRModule (annotated)
+ │
+ ▼ FuseOps ← group ops into fused Relax functions
+ IRModule (with fused functions marked Primitive=True)
+ │
+ ▼ FuseTIR ← merge TIR PrimFuncs inside each group
+ IRModule (fused TIR kernels)
+
+In the compilation pipeline, these passes appear in the backend-specific ``legalize_passes``
+phase. For example, the CUDA pipeline (``python/tvm/relax/backend/cuda/pipeline.py``) runs:
+
+.. code-block:: python
+
+ LegalizeOps() # lower Relax ops to call_tir
+ AnnotateTIROpPattern() # annotate pattern kinds
+ FoldConstant()
+ FuseOps() # group ops
+ FuseTIR() # merge TIR functions
+
+
+Operator Pattern Classification
+-------------------------------
+
+Before fusion, ``AnnotateTIROpPattern`` analyzes each TIR function in the module and assigns
+an ``OpPatternKind``. The fusion algorithm uses these pattern kinds to decide which operators
+can be fused together.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 10 70
+
+   * - Pattern Kind
+     - Value
+     - Description
+   * - ``kElemWise``
+     - 0
+     - Elementwise: one-to-one input/output mapping (e.g., ``add``, ``relu``, ``exp``).
+   * - ``kBroadcast``
+     - 1
+     - Broadcasting: output axes map to input axes in order, but some input axes may be
+       broadcast (e.g., ``bias_add``). Note: ``transpose`` is **not** broadcast because axes
+       are reordered.
+   * - ``kInjective``
+     - 2
+     - Injective: each output element depends on a single input element, but the mapping may
+       be non-trivial (e.g., ``reshape``, ``concatenate``, ``transpose``).
+   * - ``kCommReduce``
+     - 3
+     - Commutative reduction: output elements aggregate over input elements
+       (e.g., ``sum``, ``max``, ``mean``).
+   * - ``kOutEWiseFusable``
+     - 4
+     - Complex operation whose output can accept elementwise followers, but cannot chain
+       with another complex op (e.g., ``conv2d``, ``matmul``, ``dense``).
+   * - ``kTuple``
+     - 7
+     - Tuple node. Can fuse into subsequent injective ops but is treated specially.
+   * - ``kOpaque``
+     - 8
+     - Opaque: cannot be fused (e.g., external function calls, operations with side effects).
+
+These kinds form an ordering: lower values are "simpler" and more fusable. The fusion algorithm
+uses ``CombinePattern(lhs, rhs) = max(lhs, rhs)`` when merging patterns along a path.
+
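The max-combine rule can be sketched in plain Python. The enum below is an illustrative stand-in that mirrors the values in the table above, not TVM's actual C++ definition:

```python
from enum import IntEnum

class OpPatternKind(IntEnum):
    # Values mirror the pattern-kind table above (sketch only).
    ELEM_WISE = 0
    BROADCAST = 1
    INJECTIVE = 2
    COMM_REDUCE = 3
    OUT_EWISE_FUSABLE = 4
    TUPLE = 7
    OPAQUE = 8

def combine_pattern(lhs: OpPatternKind, rhs: OpPatternKind) -> OpPatternKind:
    # Merging two patterns along a path keeps the most complex (highest) kind.
    return max(lhs, rhs)

# An elementwise op followed by a broadcast op yields a broadcast path:
assert combine_pattern(OpPatternKind.ELEM_WISE, OpPatternKind.BROADCAST) is OpPatternKind.BROADCAST
```

Because ``kOpaque`` has the highest value, any path that touches an opaque op combines to opaque and stays unfusable.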
+
+FuseOps: Automatic Fusion
+-------------------------
+
+``FuseOps`` (``src/relax/transform/fuse_ops.cc``) groups bindings in a dataflow block into
+new Relax functions. It operates only within ``DataflowBlock``\ s — if your module doesn't have
+any, run ``ConvertToDataflow`` first.
+
+Algorithm
+~~~~~~~~~
+
+The fusion algorithm addresses diamond-shaped dataflow branches, where a single producer
+(e.g., conv2d) has multiple consumers that eventually reconverge:
+
+.. code-block:: text
+
+ conv2d
+ / | \
+ / | \
+ op op op
+ \ | /
+ \ | /
+ elemwise add
+
+At the point of ``conv2d``, we don't know if all future paths will merge. The algorithm uses
+**post-dominator analysis** to resolve this:
+
+1. **Build forward graph**: construct an ``IndexedForwardGraph`` from the dataflow block.
+   Each node has an ``OpPatternKind`` and a list of forward edges.
+
+2. **Build post-dominator tree**: compute the immediate post-dominator of each node using
+   Least Common Ancestor (LCA) on the DAG. The post-dominator of a node is the closest
+   downstream node where **all** future paths converge.
+
+3. **Fuse groups**: for each node in topological order, check if it can be fused with its
+   immediate post-dominator:
+
+   - **CheckPath**: verify that all paths from the node to its post-dominator satisfy the
+     fusion conditions (pattern compatibility, depth limits, argument limits).
+   - **CommitFuse**: mark all intermediate nodes as belonging to the same group using a
+     Union-Find data structure.
+
+4. **Create grouped functions**: extract each group into a new ``relax.Function`` with the
+   attribute ``Primitive=True``. Replace the original bindings with a call to the grouped
+   function.
+
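The CommitFuse step can be illustrated with a minimal Union-Find, assuming nodes are numbered in topological order. This is a conceptual sketch, not the ``GraphPartitioner`` implementation:

```python
class UnionFind:
    """Minimal union-find, as used conceptually by CommitFuse."""

    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        # Path halving keeps lookups near-constant time.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[ra] = rb

# Diamond from the figure above: conv2d (node 0) has three consumers
# (nodes 1-3) that reconverge at the elementwise add (node 4, the
# post-dominator). CommitFuse merges every node on the paths into one group.
uf = UnionFind(5)
for node in (0, 1, 2, 3):
    uf.union(node, 4)
assert len({uf.find(i) for i in range(5)}) == 1  # all five in one fused group
```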
+Fusion rules
+~~~~~~~~~~~~
+
+The key fusion decisions depend on the ``OpPatternKind`` of the source, the path, and the
+post-dominator. The algorithm runs in three phases (via ``GraphPartitioner::RunFuse``) so that
+higher-complexity ops get a chance to fuse first:
+
+- **Phase 0**: ``kOutEWiseFusable`` ops (e.g., ``conv2d``) can fuse with their elementwise
+  post-dominator if all intermediate ops are broadcast or simpler. This enables patterns like
+  conv2d + bias_add + relu. Two ``kOutEWiseFusable`` ops cannot fuse together.
+- **Phase 1**: ``kInjective`` and ``kTuple`` ops can fuse only when all paths to the
+  post-dominator are injective or simpler. This is deferred to phase 1 so that
+  ``kOutEWiseFusable`` groups are finalized first.
+- **Phase 2**: fuse injective ops into intermediate tuple nodes that have already been absorbed
+  by subsequent injective groups.
+
+``kElemWise`` / ``kBroadcast`` ops are processed in **every** phase (not restricted to one):
+they can fuse into a post-dominator that is injective or reduction. The sink (final node) may
+also be a ``kOutEWiseFusable`` group that was formed in phase 0 — this is how elementwise
+producers merge into an existing conv2d fusion group.
+
+Additional constraints:
+
+- **Reduction** (``kCommReduce``) ops never initiate fusion — they act as sinks only. Elementwise
+  and broadcast producers can fuse *into* a reduction, but a reduction cannot fuse forward.
+- **Opaque** ops are fusion barriers.
+- A group cannot exceed ``kMaxFusedOps`` (256) nodes or the maximum function argument count.
+
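The phase-0 path check described above can be sketched with pattern kinds as plain integers. This is an illustration of the rule, not the logic in ``GraphPartitioner::CheckPath``:

```python
# Pattern-kind values from the classification table (sketch only).
ELEM_WISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE = 0, 1, 2, 3, 4

def can_fuse_out_ewise(path_kinds: list, sink_kind: int) -> bool:
    # Phase 0: a kOutEWiseFusable op may fuse with an elementwise
    # post-dominator only if every intermediate op on every path is
    # broadcast or simpler.
    return sink_kind <= ELEM_WISE and all(k <= BROADCAST for k in path_kinds)

# conv2d -> bias_add (broadcast) -> relu (elementwise sink): fusable.
assert can_fuse_out_ewise([BROADCAST], ELEM_WISE)
# conv2d -> reshape (injective) -> relu: the injective hop blocks phase 0.
assert not can_fuse_out_ewise([INJECTIVE], ELEM_WISE)
```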
+Example
+~~~~~~~
+
+Consider two elementwise ops (``add``, ``exp``) and one injective op (``squeeze``). The
+examples below are simplified pseudocode — real TVMScript would reference TIR functions via
+``cls.func_name``:
+
+.. code-block:: python
+
+    # Before FuseOps (simplified)
+    @R.function
+    def main(x: R.Tensor((10, 20), "float32")):
+        with R.dataflow():
+            lv0 = R.call_tir(add, (x, const_1), out_sinfo=R.Tensor((10, 20), "float32"))
+            lv1 = R.call_tir(exp, (lv0,), out_sinfo=R.Tensor((10, 20), "float32"))
+            gv = R.call_tir(squeeze, (lv1,), out_sinfo=R.Tensor((10, 20), "float32"))
+            R.output(gv)
+        return gv
+
+After ``FuseOps``, all three are grouped into a single function:
+
+.. code-block:: python
+
+ # After FuseOps
+ @R.function(private=True)
+ def fused_add_exp_squeeze(x, p0):
+ R.func_attr({"Primitive": True})
+ with R.dataflow():
+ lv0 = R.call_tir(add, (x, p0), ...)
+ lv1 = R.call_tir(exp, (lv0,), ...)
+ gv = R.call_tir(squeeze, (lv1,), ...)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(x: R.Tensor((10, 20), "float32")):
+ with R.dataflow():
+ gv = fused_add_exp_squeeze(x, const_1)
+ R.output(gv)
+ return gv
+
+
+FuseTIR: Merging TIR Functions
+------------------------------
+
+``FuseTIR`` (``src/relax/transform/fuse_tir.cc``) takes the grouped Relax functions produced by
+``FuseOps`` and merges their internal TIR ``PrimFunc``\ s into a single TIR function.
+
+Before ``FuseTIR``, a fused group still contains multiple ``R.call_tir`` calls to separate
+TIR functions. ``FuseTIR`` inlines and merges them:
+
+.. code-block:: text
+
+ Before FuseTIR:
+ fused_add_exp_squeeze:
+ call_tir(add, ...) → separate TIR PrimFunc
+ call_tir(exp, ...) → separate TIR PrimFunc
+ call_tir(squeeze, ...) → separate TIR PrimFunc
+
+ After FuseTIR:
+ fused_add_exp_squeeze: → single merged TIR PrimFunc
+
+The merged function eliminates intermediate buffers — the output of ``add`` is directly consumed
+by ``exp`` without writing to and reading from global memory. This is the core performance benefit
+of fusion.
+
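The effect can be illustrated in plain Python (a conceptual sketch, not generated TIR): the unfused version materializes every intermediate list, while the fused version computes each output element in one pass:

```python
import math

def unfused(xs, c):
    tmp = [x + c for x in xs]          # `add` writes a full intermediate buffer
    return [math.exp(v) for v in tmp]  # `exp` reads it back from "memory"

def fused(xs, c):
    # The merged kernel consumes each add result immediately;
    # no intermediate buffer ever exists.
    return [math.exp(x + c) for x in xs]

data = [0.0, 1.0, 2.0]
assert fused(data, 1.0) == unfused(data, 1.0)
```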
+Internally, ``FuseTIR`` uses a ``SymbolicMatcher`` to align symbolic shape variables across the
+TIR functions being merged, ensuring that dimensions are correctly mapped when combining buffer
+accesses.
+
+
+FuseOpsByPattern: Pattern-Based Fusion
+--------------------------------------
+
+While ``FuseOps`` makes fusion decisions automatically based on operator patterns,
+``FuseOpsByPattern`` lets you specify exactly which operator combinations to fuse using
+the Relax :ref:`Dataflow Pattern Language (DPL) <relax-dpl>`.
+
+This is primarily used for **backend-specific dispatch**: identifying operator subgraphs that
+should be offloaded to external libraries like cuBLAS, CUTLASS, cuDNN, or DNNL.
+
+FusionPattern
+~~~~~~~~~~~~~
+
+A ``FusionPattern`` (``python/tvm/relax/transform/transform.py``) defines what to match:
+
+.. code-block:: python
+
+ from tvm.relax.dpl import wildcard, is_op
+ from tvm.relax.transform import FusionPattern
+
+ # Match: matmul(x, w) + bias
+ x = wildcard()
+ w = wildcard()
+ bias = wildcard()
+ matmul = is_op("relax.matmul")(x, w)
+ out = is_op("relax.add")(matmul, bias)
+
+ pattern = FusionPattern(
+ name="cutlass.matmul_bias",
+ pattern=out,
+ annotation_patterns={"matmul": matmul, "bias": bias},
+ check=my_check_function, # optional validation
+ )
+
+Fields:
+
+- ``name``: pattern identifier, typically prefixed with the backend name (e.g.,
+  ``"cutlass.matmul_bias"``).
+- ``pattern``: a DFPattern describing the subgraph to match. See the
+  :ref:`DPL deep dive <relax-dpl>` for the full pattern language.
+- ``annotation_patterns``: a mapping of names to sub-patterns within the main pattern. These
+  are extracted during matching and made available to the ``check`` function and
+  ``attrs_getter``.
+- ``check``: an optional ``Callable[[PatternCheckContext], bool]`` that validates whether
+  a match should be accepted. Receives the matched expression, annotated sub-expressions,
+  variable usages, and binding information.
+- ``attrs_getter``: an optional function that extracts attributes (e.g., transpose flags,
+  data types) from the matched expressions to annotate the grouped function.
+
+Applying patterns
+~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+ from tvm.relax.transform import FuseOpsByPattern
+
+ mod = FuseOpsByPattern(
+ patterns=[pattern1, pattern2, ...], # ordered by priority
+ bind_constants=True,
+ annotate_codegen=False,
+ )(mod)
+
+Key parameters:
+
+- ``patterns``: a list of ``FusionPattern`` objects, ordered by priority. Higher-priority
+  patterns come first — if a subgraph matches multiple patterns, the first match wins.
+- ``bind_constants``: if ``True``, constants used by the matched subgraph are captured inside
+  the grouped function.
+- ``annotate_codegen``: if ``True``, wraps each composite function with an outer function
+  annotated with ``"Codegen"`` and ``"global_symbol"`` attributes for external backend dispatch.
+  The ``"Codegen"`` value is derived from the pattern name prefix (e.g., ``"dnnl"`` from
+  ``"dnnl.conv2d_relu"``).
+
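The prefix derivation amounts to splitting the pattern name at its first dot. The helper below is hypothetical, shown only to make the convention concrete:

```python
def codegen_name(pattern_name: str) -> str:
    # Hypothetical helper: the backend name is the part of the pattern
    # name before the first dot.
    return pattern_name.split(".", 1)[0]

assert codegen_name("dnnl.conv2d_relu") == "dnnl"
assert codegen_name("cutlass.matmul_bias") == "cutlass"
```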
+PatternCheckContext
+~~~~~~~~~~~~~~~~~~~
+
+The ``check`` function receives a ``PatternCheckContext`` with:
+
+- ``matched_expr``: the root expression matched by the pattern.
+- ``annotated_expr``: a mapping from annotation pattern names to their matched expressions.
+- ``matched_bindings``: variable-to-value bindings within the matched subgraph.
+- ``var_usages``: a mapping from variable definitions to all their uses in the function.
+- ``value_to_bound_var``: reverse mapping from values to the variables they are bound to.
+
+This context enables sophisticated validation logic, such as checking that an intermediate
+result is not used outside the fused group, or verifying data type compatibility.
+
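For illustration, here is the general shape such a callback takes, using a stand-in context class. ``FakeContext`` and the rank-1 bias policy are assumptions made for this sketch, not TVM API:

```python
from dataclasses import dataclass, field

@dataclass
class FakeContext:
    # Stand-in for PatternCheckContext: only the annotated_expr mapping is
    # modeled here, with shapes represented as plain lists.
    annotated_expr: dict = field(default_factory=dict)

def check_matmul_bias(ctx) -> bool:
    # Hypothetical policy: accept the match only when the bias operand is
    # rank-1, as an offload kernel might require.
    bias = ctx.annotated_expr.get("bias")
    return bias is not None and len(bias["shape"]) == 1

assert check_matmul_bias(FakeContext({"bias": {"shape": [64]}}))
assert not check_matmul_bias(FakeContext({"bias": {"shape": [1, 64]}}))
```

A real ``check`` function would inspect ``annotated_expr`` entries that are Relax expressions and consult ``var_usages`` to reject matches whose intermediates escape the group.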
+
+How Backends Use Fusion
+-----------------------
+
+The default backend pipelines (CUDA, ROCm, CPU, etc.) all include ``FuseOps`` + ``FuseTIR``
+in their ``legalize_passes`` phase for automatic fusion. For example, the CUDA pipeline
+(``python/tvm/relax/backend/cuda/pipeline.py``) runs::
+
+    LegalizeOps → AnnotateTIROpPattern → FoldConstant → FuseOps → FuseTIR → DLight
+
+For external library dispatch (cuBLAS, CUTLASS, cuDNN, DNNL), ``FuseOpsByPattern`` is used
+separately. These pattern passes are **not** included in the default pipeline — users add them
+explicitly when building a custom compilation flow. The typical sequence is:
+
+1. **Pattern-based dispatch** (``FuseOpsByPattern``): identify subgraphs that should be
+   offloaded to external libraries. For example, CUTLASS patterns match
+   matmul+bias+activation combinations (``python/tvm/relax/backend/cuda/cutlass.py``).
+   Functions marked by patterns are annotated with ``Composite`` and optionally ``Codegen``
+   attributes.
+
+2. **Automatic fusion** (``FuseOps`` + ``FuseTIR``): remaining operators that were not
+   matched by backend patterns are fused automatically based on their pattern kinds.
+
+
+Source Code Map
+---------------
+
+.. list-table::
+ :header-rows: 1
+ :widths: 50 50
+
+ * - Path
+ - Contents
+ * - ``src/relax/transform/fuse_ops.cc``
+ - FuseOps and FuseOpsByPattern implementation
+ * - ``src/relax/analysis/graph_partitioner.h``
+ - IndexedForwardGraph, DominatorTree, GraphPartitioner (Union-Find)
+ * - ``src/relax/transform/fuse_tir.cc``
+ - FuseTIR implementation, SymbolicMatcher
+ * - ``include/tvm/relax/op_attr_types.h``
+ - ``OpPatternKind`` enum definition
+ * - ``python/tvm/relax/transform/transform.py``
+ - Python API: FuseOps, FuseTIR, FuseOpsByPattern, FusionPattern
+ * - ``python/tvm/relax/dpl/``
+ - Dataflow Pattern Language (DFPattern, is_op, wildcard, etc.)
+ * - ``python/tvm/relax/backend/cuda/cutlass.py``
+ - Example: CUTLASS fusion patterns
+ * - ``python/tvm/relax/backend/cuda/cublas.py``
+ - Example: cuBLAS fusion patterns
diff --git a/docs/arch/index.rst b/docs/arch/index.rst
index ec3a3b1987..17b26a5ddc 100644
--- a/docs/arch/index.rst
+++ b/docs/arch/index.rst
@@ -105,9 +105,15 @@ Apache TVM enables cross-level optimization of end-to-end models. As the IRModul
 the IRModule by applying different transformations to these two types of functions.
 For example, ``relax.LegalizeOps`` pass mutates the IRModule by lowering relax
 operators, adding corresponding TensorIR PrimFunc into the IRModule, and replacing the relax operators
-with calls to the lowered TensorIR PrimFunc. Another example is operator fusion pipeline in relax (including ``relax.FuseOps`` and ``relax.FuseTIR``), which fuses multiple consecutive tensor operations
-into one. Different from the previous implementations, relax fusion pipeline analyzes the pattern of TensorIR functions and detects the best fusion rules automatically rather
-than human-defined operator fusion patterns.
+with calls to the lowered TensorIR PrimFunc. Another example is the operator fusion pipeline
+(``relax.FuseOps`` + ``relax.FuseTIR``), which fuses multiple consecutive tensor operations into a
+single kernel. See :ref:`fusion-arch` for a detailed explanation of the fusion algorithm, operator
+pattern classification, and pattern-based fusion for external backends.
+
+.. toctree::
+ :maxdepth: 1
+
+ fusion
Target Translation
~~~~~~~~~~~~~~~~~~