This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 865ace905f [Docs] Add DLight and MetaSchedule deep-dive instructions 
(#19356)
865ace905f is described below

commit 865ace905f3bcdc96d87d7cd901f19df59020e47
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Apr 8 00:57:33 2026 -0400

    [Docs] Add DLight and MetaSchedule deep-dive instructions (#19356)
    
    This PR adds instructions covering MetaSchedule and DLight usage in the
    deep dive.
---
 docs/deep_dive/tensor_ir/index.rst                 |   2 +
 .../tensor_ir/tutorials/dlight_gpu_scheduling.py   | 316 +++++++++++++++++++++
 .../deep_dive/tensor_ir/tutorials/meta_schedule.py | 307 ++++++++++++++++++++
 3 files changed, 625 insertions(+)

diff --git a/docs/deep_dive/tensor_ir/index.rst 
b/docs/deep_dive/tensor_ir/index.rst
index 95a6a3a402..2f8bd07c1b 100644
--- a/docs/deep_dive/tensor_ir/index.rst
+++ b/docs/deep_dive/tensor_ir/index.rst
@@ -39,3 +39,5 @@ In TVMScript, both modules are accessed via
     learning
     tutorials/tir_creation
     tutorials/tir_transformation
+    tutorials/dlight_gpu_scheduling
+    tutorials/meta_schedule
diff --git a/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py 
b/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py
new file mode 100644
index 0000000000..9c5fe1ff4c
--- /dev/null
+++ b/docs/deep_dive/tensor_ir/tutorials/dlight_gpu_scheduling.py
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402, E501
+
+"""
+.. _dlight_gpu_scheduling:
+
+DLight: Rule-Based GPU Scheduling
+==================================
+TIR functions produced by Relax legalization need GPU-specific scheduling — 
thread binding,
+loop tiling, shared memory usage — before they can run efficiently on a GPU. 
There are two
+main approaches in TVM:
+
+- **MetaSchedule**: explores a search space to find the best schedule. High 
quality, but
+  compilation takes minutes to hours.
+- **DLight**: applies pre-defined scheduling rules deterministically. No 
tuning required,
+  compilation completes in seconds. Performance is excellent for well-known 
patterns
+  (e.g., GEMM, GEMV in LLM workloads) and fair for the rest.
+
+This tutorial covers how DLight works, what rules are available, how to 
diagnose scheduling
+quality, and how to write custom rules.
+
+.. contents:: Table of Contents
+    :local:
+    :depth: 1
+"""
+
+######################################################################
+# Prepare a Model
+# ---------------
+# We build a small model with ``nn.Module`` that is rich enough to trigger 
multiple DLight
+# rules: ``Linear`` layers produce GEMM (matrix multiplication) kernels, 
``LayerNorm``
+# produces a general-reduction kernel, and ``ReLU`` is a simple elementwise op.
+
+import tvm
+from tvm import relax, tirx
+from tvm.relax.frontend import nn
+from tvm.s_tir import dlight as dl
+
+
+class DemoModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(768, 768)
+        self.relu = nn.ReLU()
+        self.norm = nn.LayerNorm(768)
+        self.fc2 = nn.Linear(768, 256)
+
+    def forward(self, x):
+        x = self.norm(self.relu(self.fc1(x)))
+        return self.fc2(x)
+
+
+mod, params = DemoModel().export_tvm({"forward": {"x": nn.spec.Tensor((1, 
768), "float32")}})
+
+######################################################################
+# Legalize Relax operators into TIR functions so that DLight has concrete 
kernels to schedule.
+
+device = tvm.cuda(0)
+target = tvm.target.Target.from_device(device)
+with target:
+    mod = relax.get_pipeline("zero")(mod)
+
+######################################################################
+# At this point every TIR function in ``mod`` is **unscheduled** — it has no 
thread bindings
+# and would not run efficiently on a GPU. Let's see what functions we have:
+for gv, func in mod.functions_items():
+    if isinstance(func, tirx.PrimFunc):
+        print(f"  {gv.name_hint}")
+
+######################################################################
+# Basic Usage: ApplyDefaultSchedule
+# ---------------------------------
+# ``ApplyDefaultSchedule`` is an ``IRModule`` pass. It iterates over every TIR 
function in the
+# module and tries the given rules **in order**. For each function the first 
rule whose
+# ``apply()`` returns a non-``None`` schedule wins; subsequent rules are 
skipped.
+# After scheduling, the function is marked with ``tirx.is_scheduled`` so it 
won't be
+# scheduled again by a later ``ApplyDefaultSchedule`` call.
+
+######################################################################
+# Here we use a common subset of rules. The full catalog (including 
``LowBatchGEMV``,
+# ``Transpose``, ``RMSNorm``) is listed in the next section.
+
+with target:
+    scheduled_mod = dl.ApplyDefaultSchedule(
+        dl.gpu.Matmul(),  # GEMM: dense matrix multiplication
+        dl.gpu.GEMV(),  # matrix-vector products
+        dl.gpu.Reduction(),  # simple reductions (sum, max, ...)
+        dl.gpu.GeneralReduction(),  # compound reductions (softmax, layer 
norm, ...)
+        dl.gpu.Fallback(),  # catch-all for anything unmatched above
+    )(mod)
+
+scheduled_mod.show()
+
+######################################################################
+# Compared with the unscheduled IR, you can now see thread bindings
+# (``blockIdx.x``, ``threadIdx.x``, ...) and loop transformations in each TIR 
function.
+
+######################################################################
+# Rule Catalog
+# ------------
+# DLight ships a set of GPU scheduling rules. Each rule is a subclass of
+# ``ScheduleRule`` and implements an ``apply(func, target, tunable)`` method 
that returns
+# a ``Schedule`` if the rule matches, or ``None`` to pass.
+#
+# The built-in GPU rules, roughly from most specific to most general:
+#
+# .. list-table::
+#    :header-rows: 1
+#    :widths: 20 40 40
+#
+#    * - Rule
+#      - Pattern
+#      - Typical operators
+#    * - ``Matmul``
+#      - GEMM index pattern ``C[S,I,J] += A[S,I,K] * B[S,J,K]``
+#      - ``nn.Linear``, batched matmul
+#    * - ``GEMV``
+#      - Matrix-vector multiply (one dimension is 1)
+#      - single-batch decode in attention
+#    * - ``LowBatchGEMV``
+#      - Low-batch GEMM scheduled with a GEMV strategy
+#      - small-batch decode
+#    * - ``Reduction``
+#      - Simple accumulation ``X[...] += Y[...]``
+#      - sum, max, argmax
+#    * - ``GeneralReduction``
+#      - Spatial dims followed by reduction dims (``S* R*``)
+#      - softmax, layer norm, RMS norm
+#    * - ``Transpose``
+#      - Read/write indices are permutations of each other
+#      - 2-D transpose
+#    * - ``RMSNorm``
+#      - Contains an ``rsqrt`` operation
+#      - RMS normalization
+#    * - ``Fallback``
+#      - Any function (always matches)
+#      - generic catch-all
+#
+# **Rule order matters.** ``ApplyDefaultSchedule`` stops at the first match, 
so:
+#
+# - Put **specialized** rules first (``Matmul``, ``GEMV``) — they have strict 
matching
+#   conditions but produce high-quality schedules.
+# - Put **general** rules later (``GeneralReduction``, ``Fallback``) — they 
match broadly
+#   but with less optimal schedules.
+# - If you put ``Fallback`` first, it would "steal" every function and no 
specialized
+#   rule would ever run.
+
+######################################################################
+# Diagnosing Schedule Quality
+# ---------------------------
+# A common question is: *which rule scheduled which function?* 
``ApplyDefaultSchedule``
+# does not log this directly, but you can figure it out by applying rules one 
at a time.
+#
+# **Step 1**: Apply each rule individually and record which functions it 
claims.
+
+from collections import OrderedDict
+
+rules = OrderedDict(
+    [
+        ("Matmul", dl.gpu.Matmul()),
+        ("GEMV", dl.gpu.GEMV()),
+        ("LowBatchGEMV", dl.gpu.LowBatchGEMV()),
+        ("Reduction", dl.gpu.Reduction()),
+        ("GeneralReduction", dl.gpu.GeneralReduction()),
+        ("Transpose", dl.gpu.Transpose()),
+        ("RMSNorm", dl.gpu.RMSNorm()),
+    ]
+)
+
+rule_assignment = {}
+for rule_name, rule in rules.items():
+    with target:
+        test_mod = dl.ApplyDefaultSchedule(rule)(mod)
+    for gv, func in test_mod.functions_items():
+        if isinstance(func, tirx.PrimFunc) and gv.name_hint not in 
rule_assignment:
+            if "tirx.is_scheduled" in func.attrs and 
func.attrs["tirx.is_scheduled"] == 1:
+                rule_assignment[gv.name_hint] = rule_name
+
+######################################################################
+# **Step 2**: Functions not claimed by any specialized rule will fall through 
to ``Fallback``.
+
+all_tir_funcs = [
+    gv.name_hint for gv, func in mod.functions_items() if isinstance(func, 
tirx.PrimFunc)
+]
+fallback_funcs = [name for name in all_tir_funcs if name not in 
rule_assignment]
+
+print("Rule assignments:")
+for name, rule_name in sorted(rule_assignment.items()):
+    print(f"  {name:40s} -> {rule_name}")
+if fallback_funcs:
+    print("Handled by Fallback (may have suboptimal performance):")
+    for name in sorted(fallback_funcs):
+        print(f"  {name}")
+
+######################################################################
+# If an important kernel lands in the Fallback bucket, you have three options:
+#
+# 1. Write a **custom DLight rule** for it (see below).
+# 2. Use **MetaSchedule** to auto-tune that specific function.
+# 3. Manually schedule it with the ``tvm.s_tir.Schedule`` API.
+
+######################################################################
+# DLight vs MetaSchedule
+# ----------------------
+# The two systems are complementary, not competing:
+#
+# .. list-table::
+#    :header-rows: 1
+#    :widths: 20 40 40
+#
+#    * -
+#      - DLight
+#      - MetaSchedule
+#    * - Mechanism
+#      - Deterministic rule matching
+#      - Search-space exploration
+#    * - Compile time
+#      - Seconds
+#      - Minutes to hours
+#    * - Performance
+#      - Excellent on known patterns, fair otherwise
+#      - Near-optimal with sufficient search budget
+#    * - Best for
+#      - Default path, rapid iteration, CI
+#      - Hot-spot tuning in production
+#
+# A practical workflow:
+#
+# 1. Run ``ApplyDefaultSchedule`` with the full rule set to cover all 
functions.
+# 2. Profile the compiled model to identify hot-spot kernels.
+# 3. Use ``MetaScheduleTuneTIR`` to auto-tune only those kernels.
+#
+# Note that ``MetaScheduleTuneTIR`` does **not** automatically skip functions 
already
+# scheduled by DLight — it processes every ``PrimFunc`` in the module. In 
practice this
+# is harmless (tuning an already-scheduled function simply re-explores its 
space), but if
+# you want to avoid the extra search cost, filter the module or use 
``MetaScheduleTuneIRMod``
+# with ``op_names`` to target specific functions.
+
+######################################################################
+# Writing a Custom Rule
+# ---------------------
+# You can extend DLight by writing your own ``ScheduleRule``. The simplest way 
is
+# ``ScheduleRule.from_callable``, which wraps a plain function into a rule 
**instance**.
+
+from tvm import s_tir
+from tvm.s_tir.dlight.analysis import normalize_prim_func
+from tvm.s_tir.dlight.base.schedule_rule import ScheduleRule
+
+
[email protected]_callable("MyTileAndBind")
+def my_tile_and_bind(func: tirx.PrimFunc, target: tvm.target.Target, tunable: 
bool):
+    """A minimal rule: for single-block injective functions, tile and bind to 
GPU threads."""
+    if not isinstance(func, tirx.PrimFunc):
+        return None
+    sch = s_tir.Schedule(func)
+    # Use normalize_prim_func to get block info with correct spatial/reduction 
classification.
+    # This is the same analysis used by built-in DLight rules.
+    block_infos = normalize_prim_func(sch)
+    if block_infos is None or len(block_infos) != 1:
+        return None  # only handle single-block functions
+    info = block_infos[0]
+    if not info.is_injective():
+        return None  # skip reductions — dom_kind() uses iter_type, not loop 
kind
+    loops = sch.get_loops(info.block_rv)
+    if len(loops) == 0:
+        return None
+    fused = sch.fuse(*loops)
+    bx, tx = sch.split(fused, factors=[None, 256])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
+    return sch
+
+
+######################################################################
+# Insert the custom rule into the rule chain. Note that ``from_callable`` 
returns an
+# **instance**, so pass it directly — do not call ``my_tile_and_bind()`` again.
+
+with target:
+    custom_mod = dl.ApplyDefaultSchedule(
+        dl.gpu.Matmul(),
+        dl.gpu.GeneralReduction(),
+        my_tile_and_bind,  # our custom rule, tried before Fallback
+        dl.gpu.Fallback(),
+    )(mod)
+
+custom_mod.show()
+
+######################################################################
+# To build a production-quality rule, subclass ``ScheduleRule`` directly and 
implement
+# ``apply()`` with full analysis logic (see ``tvm.s_tir.dlight.gpu.Matmul`` 
for an example).
+
+######################################################################
+# Summary
+# -------
+# - **DLight** provides fast, deterministic GPU scheduling via rule matching.
+# - Rules are tried in order; the first match wins. Put specialized rules 
before general ones.
+# - Use the **single-rule probing** technique to diagnose which rule handles 
each function.
+# - Combine DLight with MetaSchedule: DLight for baseline coverage, 
MetaSchedule for hot-spot tuning.
+# - Extend DLight by writing custom ``ScheduleRule`` implementations.
+#
+# For DLight's role in the broader optimization pipeline, see 
:ref:`customize_opt`.
diff --git a/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py 
b/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py
new file mode 100644
index 0000000000..a263397bbe
--- /dev/null
+++ b/docs/deep_dive/tensor_ir/tutorials/meta_schedule.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# ruff: noqa: E402
+
+"""
+.. _meta_schedule_deep_dive:
+
+MetaSchedule: Search-Based Auto-Tuning
+=======================================
+MetaSchedule is TVM's search-based auto-tuning framework, located in
+``python/tvm/s_tir/meta_schedule/``. It explores different TIR schedules
+(loop tiling, vectorization, thread binding, etc.) and measures them on real
+hardware to find the fastest implementation for each operator.
+
+While **DLight** (see :ref:`dlight_gpu_scheduling`) provides rule-based 
scheduling with zero
+search time, MetaSchedule trades compilation time for better performance by 
searching over
+the space of possible schedules.
+
+.. contents:: Table of Contents
+    :local:
+    :depth: 1
+"""
+
+######################################################################
+# Architecture Overview
+# ---------------------
+# A MetaSchedule tuning session involves the following components:
+#
+# - **ExtractedTask**: A unique TIR workload extracted from a Relax IRModule,
+#   with a ``task_name`` and ``weight`` (call frequency in the graph).
+# - **TuneContext**: Container holding all resources for a single tuning task
+#   (module, target, space generator, search strategy, etc.).
+# - **SpaceGenerator** (default: ``PostOrderApply``): Generates the design 
space
+#   of possible schedules by applying ``ScheduleRule`` instances to each block.
+# - **SearchStrategy** (default: ``EvolutionarySearch``): Explores the design
+#   space using an evolutionary algorithm guided by a cost model.
+# - **CostModel** (default: ``XGBModel``): Predicts schedule performance using
+#   XGBoost, reducing the number of actual hardware measurements needed.
+#   Alternatives include ``MLPModel`` (neural network) and ``RandomModel``
+#   (baseline).
+# - **Builder** / **Runner**: Compile and execute candidates on real hardware 
to
+#   obtain measured run times.
+# - **Database** (default: ``JSONDatabase``): Persistently stores tuning 
records
+#   (schedule traces + measured run times) for later retrieval.
+# - **TaskScheduler** (default: ``GradientBasedScheduler``): Allocates tuning
+#   budget across multiple tasks based on their weights and estimated 
improvement
+#   potential.
+#
+# The tuning loop works as follows:
+#
+# 1. The **TaskScheduler** picks a task to tune.
+# 2. The **SpaceGenerator** produces candidate schedules from the design space.
+# 3. The **SearchStrategy** selects candidates (guided by the **CostModel**),
+#    sends them to the **Builder** and **Runner** for measurement.
+# 4. Measured results are committed to the **Database** and used to update the
+#    **CostModel** for the next iteration.
+# 5. Repeat until the trial budget is exhausted.
+
+######################################################################
+# Prepare a Model
+# ---------------
+# We reuse a simple model to demonstrate MetaSchedule APIs.
+
+import os
+import tempfile
+
+import tvm
+from tvm import relax
+from tvm.relax.frontend import nn
+
+
+class DemoModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(784, 256)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(256, 10, bias=False)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+
+input_shape = (1, 784)
+mod, params = DemoModel().export_tvm({"forward": {"x": 
nn.spec.Tensor(input_shape, "float32")}})
+
+device = tvm.cuda(0)
+target = tvm.target.Target.from_device(device)
+
+######################################################################
+# User-Facing Entry Points
+# ------------------------
+# MetaSchedule provides several levels of API, from high-level transforms to
+# low-level tuning functions.
+#
+# Transform-Based API (Recommended)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# These are Relax passes that can be composed into a ``Sequential`` pipeline:
+#
+# - **MetaScheduleTuneIRMod**: Tunes an entire IRModule. Supports ``op_names``
+#   for selective operator tuning.
+# - **MetaScheduleTuneTIR**: Tunes all TIR functions individually (no
+#   ``op_names`` filtering).
+# - **MetaScheduleApplyDatabase**: Applies the best schedules from the tuning
+#   database. Only replaces functions that have records; the rest are left
+#   unchanged.
+#
+# Here is a typical tune-and-apply pipeline:
+#
+# .. note::
+#
+#    To save CI time and avoid flakiness, we skip the tuning process in CI.
+
+if os.getenv("CI", "") != "true":
+    with target, tempfile.TemporaryDirectory() as tmp_dir:
+        tuned_mod = tvm.ir.transform.Sequential(
+            [
+                relax.get_pipeline("zero"),
+                relax.transform.MetaScheduleTuneTIR(
+                    work_dir=tmp_dir,
+                    max_trials_global=300,
+                ),
+                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
+            ]
+        )(mod)
+
+    tuned_mod.show()
+
+######################################################################
+# Inspecting Tunable Tasks
+# ------------------------
+# Before tuning, use ``extract_tasks`` to see what MetaSchedule will tune:
+
+from tvm.s_tir.meta_schedule.relax_integration import extract_tasks
+
+with target:
+    legalized_mod = relax.get_pipeline("zero")(mod)
+
+tasks = extract_tasks(legalized_mod, target)
+for i, task in enumerate(tasks):
+    print(f"Task {i}: {task.task_name}  (weight={task.weight})")
+
+######################################################################
+# Each ``ExtractedTask`` has:
+#
+# - ``task_name``: Derived from the PrimFunc name (e.g., 
``"fused_matmul_add_relu"``).
+# - ``weight``: How many ``call_tir`` sites invoke this workload. The task
+#   scheduler uses weights to allocate more budget to frequently-called 
operators.
+# - ``dispatched``: List of candidate TIR modules for this workload.
+
+######################################################################
+# Selective Operator Tuning
+# -------------------------
+# ``MetaScheduleTuneIRMod`` accepts an ``op_names`` parameter to tune only
+# operators whose task name contains any of the given strings:
+#
+# .. code-block:: python
+#
+#     with target:
+#         mod = tvm.ir.transform.Sequential([
+#             relax.transform.MetaScheduleTuneIRMod(
+#                 params={},
+#                 work_dir="./tuning_logs",
+#                 max_trials_global=300,
+#                 op_names=["matmul"],  # Only tune matmul-related operators
+#             ),
+#             
relax.transform.MetaScheduleApplyDatabase(work_dir="./tuning_logs"),
+#         ])(mod)
+#
+# Operators without tuning records are left unscheduled -- you can apply 
DLight or
+# other rule-based schedules to cover them afterward.
+#
+# .. note::
+#
+#    ``MetaScheduleTuneTIR`` does not support ``op_names`` filtering. Use
+#    ``MetaScheduleTuneIRMod`` when you need selective tuning.
+
+######################################################################
+# Database
+# --------
+# When using a fixed ``work_dir``, tuning results are persisted in two
+# newline-delimited JSON files:
+#
+# - ``database_workload.json``: One line per unique workload (structural hash +
+#   serialized IRModule).
+# - ``database_tuning_record.json``: One line per tuning record (workload 
index +
+#   schedule trace + measured run times).
+#
+# Records are appended incrementally as tuning progresses.
+#
+# Resumption Semantics
+# ~~~~~~~~~~~~~~~~~~~~
+# When you re-run tuning with the same ``work_dir``, existing records are 
loaded
+# and used as warm-start seeds for the evolutionary search. The tuner does
+# **not** skip already-seen workloads entirely -- it starts from a better 
initial
+# population, so re-runs are faster than starting from scratch but still 
consume
+# trials.
+#
+# Once tuning is done, subsequent compilations only need
+# ``MetaScheduleApplyDatabase``:
+#
+# .. code-block:: python
+#
+#     with target:
+#         mod = relax.transform.MetaScheduleApplyDatabase(
+#             work_dir="./tuning_logs"
+#         )(mod)
+#
+# Database Implementations
+# ~~~~~~~~~~~~~~~~~~~~~~~~
+# MetaSchedule ships several database backends:
+#
+# - **JSONDatabase**: Persistent file-based storage (default). Created
+#   automatically when you pass ``work_dir``.
+# - **MemoryDatabase**: In-memory, non-persistent. Useful for testing.
+# - **UnionDatabase**: Queries all sub-databases and returns the globally best
+#   record.
+# - **OrderedUnionDatabase**: Queries sub-databases in order; returns from the
+#   first one that has a match.
+# - **ScheduleFnDatabase**: Wraps a user-provided scheduling function.
+
+######################################################################
+# Cross-Model Database Reuse
+# --------------------------
+# MetaSchedule identifies workloads by their structural hash. If two models
+# contain operators with the same shape, dtype, and computation, they share the
+# same hash and can reuse tuning records.
+#
+# module_equality Options
+# ~~~~~~~~~~~~~~~~~~~~~~~
+# - ``"structural"`` (default): Exact structural match. Safe but strict.
+# - ``"anchor-block"``: Match based on the dominant compute block, ignoring
+#   surrounding context. More permissive -- enables sharing across fused 
operators
+#   that have the same core computation but different fusion boundaries.
+#
+# ``OrderedUnionDatabase`` enables a layered lookup strategy: check a local
+# database first, then fall back to a shared team database:
+#
+# .. code-block:: python
+#
+#     from tvm.s_tir.meta_schedule.database import JSONDatabase, 
OrderedUnionDatabase
+#
+#     local_db = JSONDatabase(work_dir="./my_tuning_logs")
+#     shared_db = JSONDatabase(work_dir="/shared/tuning_db")
+#     combined_db = OrderedUnionDatabase(local_db, shared_db)
+#
+#     with target, combined_db:
+#         mod = relax.transform.MetaScheduleApplyDatabase()(mod)
+
+######################################################################
+# Key Parameters Reference
+# ------------------------
+#
+# .. list-table::
+#    :header-rows: 1
+#    :widths: 25 75
+#
+#    * - Parameter
+#      - Description
+#    * - ``max_trials_global``
+#      - Total trial budget shared across all tasks. Set proportional to the
+#        number of tasks (e.g., 200-500 trials per task for good results).
+#    * - ``max_trials_per_task``
+#      - Per-task trial cap. Defaults to ``max_trials_global`` if not set.
+#    * - ``op_names``
+#      - List of strings to filter tasks by name (substring match).
+#        ``MetaScheduleTuneIRMod`` only.
+#    * - ``work_dir``
+#      - Directory for database files and logs. Use a fixed path to enable
+#        persistence and resumption.
+#    * - ``cost_model``
+#      - ``"xgb"`` (XGBoost, default), ``"mlp"`` (neural network), or
+#        ``"random"`` (baseline). Only available via ``tune_relax``.
+#    * - ``runner``
+#      - ``"local"`` (default) or an ``RPCRunner`` instance for remote devices.
+#        Only available via ``tune_relax``.
+#    * - ``module_equality``
+#      - ``"structural"`` (default) or ``"anchor-block"`` for more permissive
+#        cross-model matching. Only available via ``tune_relax``.
+
+######################################################################
+# Summary
+# -------
+# - **MetaSchedule** finds high-quality TIR schedules by searching over the
+#   design space and measuring on real hardware.
+# - Use ``MetaScheduleTuneTIR`` for full-module tuning, or
+#   ``MetaScheduleTuneIRMod`` with ``op_names`` for selective tuning.
+# - Tuning records persist in ``work_dir`` and can be reused across runs and
+#   models with the same operator shapes.
+# - Combine with DLight: use DLight for fast baseline coverage, then 
MetaSchedule
+#   for hot-spot tuning (see :ref:`dlight_gpu_scheduling`).

Reply via email to