This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d812a18de [Docs] Add docstrings for nn.Module classes and core APIs
in relax.frontend.nn (#19387)
4d812a18de is described below
commit 4d812a18deadd95e55d1b38104591b03fb62564e
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Apr 11 14:30:21 2026 -0400
[Docs] Add docstrings for nn.Module classes and core APIs in
relax.frontend.nn (#19387)
This PR adds docstrings for important APIs.
---
python/tvm/relax/frontend/nn/core.py | 51 +++++-
python/tvm/relax/frontend/nn/modules.py | 297 +++++++++++++++++++++++++++++---
2 files changed, 322 insertions(+), 26 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index 40659e1623..f3886e94cb 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -123,7 +123,25 @@ class Tensor(_TensorOp):
@staticmethod
def from_struct_info(struct_info: rx.TensorStructInfo, name: str =
"tensor") -> "Tensor":
- """Construct a nn.Tensor from relax TensorStructInfo"""
+ """Construct a nn.Tensor from a Relax TensorStructInfo.
+
+ TensorStructInfo is the Relax type-level description of a tensor,
carrying its shape
+ and dtype without holding actual data. This factory creates an unbound
placeholder
+ ``nn.Tensor`` that can be used as a symbolic input when tracing an
``nn.Module``.
+
+ Parameters
+ ----------
+ struct_info : rx.TensorStructInfo
+ The struct info describing the tensor's shape and dtype.
+
+ name : str
+ Name hint for the underlying Relax variable.
+
+ Returns
+ -------
+ tensor : Tensor
+ A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given
struct info.
+ """
return Tensor(
_expr=rx.Var(
name_hint=name,
@@ -492,7 +510,36 @@ class Module(SubroutineMixin):
out_format: str = "torch",
debug: bool = False,
) -> Any:
- """Just-in-time compilation of a nn.model to an executable"""
+ """Just-in-time compile an ``nn.Module`` into a callable executable.
+
+ The method exports the module to a Relax IRModule, applies the given
compilation
+ pipeline, builds a Relax VM executable, and wraps the result so it can
be called
+ directly (e.g. with PyTorch tensors when ``out_format="torch"``).
+
+ Parameters
+ ----------
+ spec : _spec.ModuleSpec
+ A specification mapping each module input to its shape and dtype.
+
+ device : Union[str, Device]
+ The device to compile and run on (e.g. ``"cpu"``, ``"cuda"``).
+
+ pipeline : Union[None, str, Pass]
+ The Relax compilation pipeline to apply. ``"default_build"`` uses
the standard
+ optimization pipeline; ``None`` skips pipeline passes.
+
+ out_format : str
+ Output wrapper format. ``"torch"`` returns a ``TorchModule`` whose
``forward``
+ accepts and returns PyTorch tensors.
+
+ debug : bool
+ If ``True``, enable effect-based debugging (e.g. printing) in the
compiled graph.
+
+ Returns
+ -------
+ module : Any
+ A callable wrapper (type depends on *out_format*) around the
compiled VM.
+ """
def _compile(spec, device, pipeline, debug):
# pylint: disable=import-outside-toplevel
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index cf6c827b9e..8753d37bac 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -95,8 +95,27 @@ class Identity(Module):
class Linear(Module):
- """
- Module for linear layer.
+ """Applies a linear transformation :math:`y = xW^T + b`.
+
+ Parameters
+ ----------
+ in_features : Union[int, str, tirx.PrimExpr]
+ Size of each input sample. Can be symbolic.
+
+ out_features : Union[int, str, tirx.PrimExpr]
+ Size of each output sample. Can be symbolic.
+
+ bias : bool
+ If ``True``, adds a learnable bias. Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for weight (and bias when *out_dtype* is ``None``).
+ ``None`` uses the default dtype.
+
+ out_dtype : Optional[str]
+ If set, the matmul accumulates in this dtype and the bias is stored in
this dtype
+ instead of *dtype*. Useful for mixed-precision (e.g. ``float32``
accumulation with
+ ``float16`` weights).
"""
def __init__(
@@ -154,8 +173,36 @@ class Linear(Module):
class Conv1D(Module):
- """
- Module for conv1d layer.
+ """Applies a 1D convolution over an input signal.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of channels in the input.
+
+ out_channels : int
+ Number of channels produced by the convolution.
+
+ kernel_size : int
+ Size of the convolving kernel.
+
+ stride : int
+ Stride of the convolution. Default: 1.
+
+ padding : int
+ Zero-padding added to both sides of the input. Default: 0.
+
+ dilation : int
+ Spacing between kernel elements. Default: 1.
+
+ groups : int
+ Number of blocked connections from input to output channels. Default:
1.
+
+ bias : bool
+ If ``True``, adds a learnable bias. Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for weight and bias. ``None`` uses the default dtype.
"""
def __init__(
@@ -212,8 +259,39 @@ class Conv1D(Module):
class Conv2D(Module):
- """
- Module for conv2d layer.
+ """Applies a 2D convolution over an input image.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of channels in the input image.
+
+ out_channels : int
+ Number of channels produced by the convolution.
+
+ kernel_size : Union[List[int], int]
+ Size of the convolving kernel. An int is expanded to a 2-element list.
+
+ stride : int
+ Stride of the convolution. Default: 1.
+
+ padding : int
+ Zero-padding added to both sides of the input. Default: 0.
+
+ dilation : int
+ Spacing between kernel elements. Default: 1.
+
+ groups : int
+ Number of blocked connections from input to output channels. Default:
1.
+
+ bias : bool
+ If ``True``, adds a learnable bias. Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for weight and bias. ``None`` uses the default dtype.
+
+ data_layout : str
+ Layout of the input data, e.g. ``"NCHW"`` or ``"NHWC"``. Default:
``"NCHW"``.
"""
def __init__( # pylint: disable=too-many-arguments
@@ -286,8 +364,39 @@ class Conv2D(Module):
class Conv3D(Module):
- """
- Module for conv3d layer.
+ """Applies a 3D convolution over an input volume.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of channels in the input volume.
+
+ out_channels : int
+ Number of channels produced by the convolution.
+
+ kernel_size : Union[List[int], int]
+ Size of the convolving kernel. An int is expanded to a 3-element list.
+
+ stride : Union[List[int], int]
+ Stride of the convolution. Default: 1.
+
+ padding : Union[List[int], int]
+ Zero-padding added to each side of the input. Default: 0.
+
+ dilation : int
+ Spacing between kernel elements. Default: 1.
+
+ groups : int
+ Number of blocked connections from input to output channels. Default:
1.
+
+ bias : bool
+ If ``True``, adds a learnable bias. Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for weight and bias. ``None`` uses the default dtype.
+
+ data_layout : str
+ Layout of the input data, e.g. ``"NCDHW"``. Default: ``"NCDHW"``.
"""
def __init__( # pylint: disable=too-many-arguments
@@ -360,8 +469,39 @@ class Conv3D(Module):
class ConvTranspose1D(Module):
- """
- Module for ConvTranspose1D layer.
+ """Applies a 1D transposed convolution (fractionally-strided convolution).
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of channels in the input.
+
+ out_channels : int
+ Number of channels produced by the transposed convolution.
+
+ kernel_size : int
+ Size of the convolving kernel.
+
+ stride : int
+ Stride of the convolution. Default: 1.
+
+ padding : int
+ Zero-padding added to both sides of the input. Default: 0.
+
+ output_padding : int
+ Additional size added to one side of the output shape. Default: 0.
+
+ dilation : int
+ Spacing between kernel elements. Default: 1.
+
+ groups : int
+ Number of blocked connections from input to output channels. Default:
1.
+
+ bias : bool
+ If ``True``, adds a learnable bias. Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for weight and bias. ``None`` uses the default dtype.
"""
def __init__(
@@ -427,8 +567,22 @@ class ConvTranspose1D(Module):
class LayerNorm(Module):
- """
- Module for Layer Normalization
+ """Applies Layer Normalization over the last dimension.
+
+ Parameters
+ ----------
+ normalized_shape : int
+ Size of the last dimension to normalize over.
+
+ eps : Optional[float]
+ Value added to the denominator for numerical stability. Default:
``1e-5``.
+
+ elementwise_affine : bool
+ If ``True``, learnable affine parameters (weight and bias) are added.
+ Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for the affine parameters. ``None`` uses the default dtype.
"""
def __init__(
@@ -473,8 +627,24 @@ class LayerNorm(Module):
class RMSNorm(Module):
- """
- Module for rms norm layer.
+ """Applies Root Mean Square Layer Normalization.
+
+ Parameters
+ ----------
+ hidden_size : int
+ Size of the weight parameter.
+
+ axes : Union[int, List[int]]
+ The axes over which to compute the RMS norm.
+
+ epsilon : float
+ Value added to the denominator for numerical stability. Default:
``1e-5``.
+
+ bias : bool
+ If ``True``, adds a learnable bias after normalization. Default:
``True``.
+
+ dtype : Optional[str]
+ Data type for the parameters. ``None`` uses the default dtype.
"""
def __init__(
@@ -515,8 +685,24 @@ class RMSNorm(Module):
class GroupNorm(Module):
- """
- Module for group norm layer.
+ """Applies Group Normalization.
+
+ Parameters
+ ----------
+ num_groups : int
+ Number of groups to separate the channels into.
+
+ num_channels : int
+ Number of channels in the input; must be divisible by *num_groups*.
+
+ eps : float
+ Value added to the denominator for numerical stability. Default:
``1e-5``.
+
+ affine : bool
+ If ``True``, learnable per-channel affine parameters are added.
Default: ``True``.
+
+ dtype : Optional[str]
+ Data type for the affine parameters. ``None`` uses the default dtype.
"""
def __init__(
@@ -563,8 +749,28 @@ class GroupNorm(Module):
class KVCache(Effect):
- """
- Effect to implement KVCache.
+ """Managed key-value cache for autoregressive decoding.
+
+ ``KVCache`` is a TVM-specific ``Effect`` that allocates and maintains a
runtime cache
+ for storing past key/value tensors in transformer models. Unlike regular
``Module``
+ parameters, effects are registered with the Relax VM and carry mutable
state across
+ calls (append, reset) without being passed as explicit function arguments.
+
+ The cache is pre-allocated with shape ``[init_seq_len, *unit_shape]`` and
grows via
+ the ``append`` method at runtime. Use ``init_seq_len`` to control the
initial
+ allocation size.
+
+ Parameters
+ ----------
+ init_seq_len : int
+ Initial sequence-length capacity of the cache allocation.
+
+ unit_shape : Sequence[int]
+ Shape of a single cache entry excluding the sequence dimension.
+ For multi-head attention this is typically ``[num_heads, head_dim]``.
+
+ dtype : Optional[str]
+ Data type of the cache tensor. ``None`` uses the default dtype.
"""
init_seq_len: int
@@ -708,8 +914,18 @@ class KVCache(Effect):
class Embedding(Module):
- """
- Module for embedding layer.
+ """A lookup table that retrieves embeddings by index.
+
+ Parameters
+ ----------
+ num : Union[int, str, tirx.PrimExpr]
+ Size of the embedding dictionary (vocabulary size). Can be symbolic.
+
+ dim : Union[int, str, tirx.PrimExpr]
+ Size of each embedding vector. Can be symbolic.
+
+ dtype : Optional[str]
+ Data type of the embedding weight. ``None`` uses the default dtype.
"""
def __init__(
@@ -749,8 +965,31 @@ class Embedding(Module):
class TimestepEmbedding(Module):
- """
- Module for HF TimestepEmbedding layer.
+ """MLP that projects timestep embeddings, following the HuggingFace
diffusers convention.
+
+ Consists of two linear layers with an activation in between, and an
optional
+ conditional projection and post-activation.
+
+ Parameters
+ ----------
+ in_channels : int
+ Dimensionality of the input timestep embedding.
+
+ time_embed_dim : int
+ Dimensionality of the intermediate (hidden) projection.
+
+ act_fn : str
+ Activation function name. Currently only ``"silu"`` is supported.
+
+ out_dim : Optional[int]
+ Dimensionality of the output. If ``None``, defaults to
*time_embed_dim*.
+
+ post_act_fn : Optional[str]
+ Optional post-activation applied after the second linear layer.
+
+ cond_proj_dim : Optional[int]
+ If set, adds a linear projection from a conditioning signal of this
+ dimensionality to *in_channels*, which is added to the input sample.
"""
def __init__(
@@ -816,8 +1055,18 @@ class TimestepEmbedding(Module):
class Timesteps(Module):
- """
- Module for HF timesteps layer.
+ """Sinusoidal positional embedding for diffusion timesteps (HuggingFace
convention).
+
+ Parameters
+ ----------
+ num_channels : int
+ Dimensionality of the embedding (number of sinusoidal channels).
+
+ flip_sin_to_cos : bool
+ If ``True``, swap sin and cos components. Default: ``False``.
+
+ downscale_freq_shift : float
+ Shift applied to the frequency denominator. Default: ``1``.
"""
def __init__(