This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9c8e5a6037 [TIR] Handle Bind in LowerDeviceKernelLaunch (#18912)
9c8e5a6037 is described below
commit 9c8e5a60376ff16a7a88dc841271befd9f32bf96
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Mar 27 09:30:37 2026 -0400
[TIR] Handle Bind in LowerDeviceKernelLaunch (#18912)
DeviceInfoCollector did not track Bind statements, so when CSE (or
any other pass) inserted a Bind before a thread_extent AttrStmt, the
collected extent referenced a locally-bound variable instead of
function parameters. LowerDeviceKernelLaunch then produced dangling
references in the host function.
Fix: record Bind definitions in DeviceInfoCollector and inline them
when extracting thread_extent values and dynamic shared memory sizes.
---
src/tirx/transform/lower_device_kernel_launch.cc | 22 ++++++++-
.../test_tir_transform_device_kernel_launch.py | 52 ++++++++++++++++++++++
2 files changed, 73 insertions(+), 1 deletion(-)
diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc
index 3ff4cf17c5..fea8d458b9 100644
--- a/src/tirx/transform/lower_device_kernel_launch.cc
+++ b/src/tirx/transform/lower_device_kernel_launch.cc
@@ -104,6 +104,17 @@ class DeviceInfoCollector : public StmtVisitor {
return extent.value();
}
+ void VisitStmt_(const BindNode* op) final {
+ // Track Bind definitions so that thread_extent values and
+ // dyn_shmem_size expressions that reference locally-bound
+ // variables (e.g. CSE variables) can be inlined back to
+ // expressions over function parameters. Substitute earlier
+ // bindings into the value to handle chains (cse_v2 = f(cse_v1)).
+ PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
+ bind_map_.Set(op->var, value);
+ StmtVisitor::VisitStmt_(op);
+ }
+
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
@@ -113,7 +124,10 @@ class DeviceInfoCollector : public StmtVisitor {
if (!defined_thread.count(iv.get())) {
defined_thread.insert(iv.get());
info_.launch_params.push_back(iv->thread_tag);
- thread_extent.Set(iv->thread_tag, op->value);
+ // Inline any locally-bound variables (e.g. from CSE) so
+ // that the extent is expressible in terms of function params.
+ PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
+ thread_extent.Set(iv->thread_tag, value);
}
}
@@ -133,6 +147,10 @@ class DeviceInfoCollector : public StmtVisitor {
}
dyn_size *= op->buffer->dtype.bytes();
+ // Inline any locally-bound variables (e.g. from CSE).
+ if (bind_map_.size()) {
+ dyn_size = Substitute(dyn_size, bind_map_);
+ }
dyn_shmem_size = dyn_size;
}
StmtVisitor::VisitStmt_(op);
@@ -146,6 +164,8 @@ class DeviceInfoCollector : public StmtVisitor {
ffi::Map<ffi::String, PrimExpr> thread_extent;
// The amount of dynamic shared memory used
ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt};
+ // Accumulated Bind definitions for inlining into extent/size expressions.
+ ffi::Map<Var, PrimExpr> bind_map_;
};
class ReturnRemover : public StmtExprMutator {
diff --git a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
index 6d77d7e871..3dab487ab5 100644
--- a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
+++ b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
@@ -223,5 +223,57 @@ def test_same_device_different_target():
tvm.ir.assert_structural_equal(After, Expected)
+def test_bind_before_thread_extent():
+ """DeviceInfoCollector inlines Bind-defined variables in thread extents.
+
+ When CSE (or another pass) inserts Bind statements before
+ thread_extent AttrStmts, the extent value may reference a
+ locally-bound variable instead of function parameters.
+ LowerDeviceKernelLaunch must inline these bindings so that the
+ launch argument is expressible in terms of the caller's arguments.
+ """
+
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer(16, "float32"), n: T.int32):
+ T.func_attr({"target": T.target("llvm")})
+ Before.kernel(A.data, n)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32"), n: T.int32):
+ T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel"})
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ v: T.int32 = n + 1
+ i = T.launch_thread("threadIdx.x", v)
+ A[i] = 0.0
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def main(A: T.Buffer(16, "float32"), n: T.int32):
+ T.func_attr({"target": T.target("llvm")})
+ T.call_packed("kernel", A.data, n, n + 1)
+
+ @T.prim_func
+ def kernel(A_data: T.handle("float32"), n: T.int32):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "calling_conv": 2,
+ "tirx.kernel_launch_params": ["threadIdx.x"],
+ "global_symbol": "kernel",
+ "tirx.is_global_func": True,
+ }
+ )
+ A = T.decl_buffer(16, dtype="float32", data=A_data)
+ v: T.int32 = n + 1
+ i = T.launch_thread("threadIdx.x", v)
+ A[i] = 0.0
+
+ After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()