Thank you for your reply; this is what I wanted.
Following your suggestion, I ran an experiment and replaced the constant 128 with the symbolic variables `OH` and `OW`:

```python
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Symbolic shape variables instead of the constant 128
    OH = T.var("int32")
    OW = T.var("int32")
    A = T.match_buffer(a, (OH, OW), "float32")
    B = T.alloc_buffer((OH, OW), "float32")
    C = T.match_buffer(c, (OH, OW), "float32")
    # Block B: one extern call per row, writing row B[vi, 0:OW]
    for i in range(OH):
        with T.block("B"):
            vi = T.axis.spatial(OH, i)
            T.reads([A[0:OH, 0:OW]])
            T.writes([B[vi, 0:OW]])
            T.evaluate(
                T.call_extern(
                    "test_cust_mul",
                    T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 0, 1, dtype="handle"),
                    # Row offset into B: row index times row width OW
                    T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, vi * OW, 0, 1, dtype="handle"),
                    dtype="handle",
                )
            )
    # Block C: elementwise consumer of B
    for i, j in T.grid(OH, OW):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0


def test_compute_at2():
    sch = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(sch.mod.script())
    block_b = sch.get_block("B")
    (loop_i,) = sch.get_loops(block_b)
    block_c = sch.get_block("C")
    # Move block C under the row loop of block B
    sch.reverse_compute_at(block_c, loop_i, preserve_unit_loops=True)
    print(sch.mod.script())


if __name__ == "__main__":
    test_compute_at2()
```

The printed result is:

```python
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, c: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        OH = T.var("int32")
        OW = T.var("int32")
        A = T.match_buffer(a, [OH, OW], dtype="float32")
        C = T.match_buffer(c, [OH, OW], dtype="float32")
        # body
        # with T.block("root")
        B = T.alloc_buffer([OH, OW], dtype="float32")
        for i in T.serial(0, OH):
            with T.block("B"):
                vi = T.axis.spatial(OH, i)
                T.reads([A[0 : OH, 0 : OW]])
                T.writes([B[vi, 0 : OW]])
                T.evaluate(T.call_extern("test_cust_mul", T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 0, 1, dtype="handle"), T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, vi * OW, 0, 1, dtype="handle"), dtype="handle"))
            for ax0, ax1 in T.grid(T.min(1, OH - i), OW):
                with T.block("C"):
                    vi = T.axis.spatial(OH, i + ax0)
                    vj = T.axis.spatial(OW, ax1)
                    T.reads([B[vi, vj]])
                    T.writes([C[vi, vj]])
                    C[vi, vj] = B[vi, vj] + T.float32(1)
```

I have a new problem: `reverse_compute_at` generates `for ax0, ax1 in T.grid(T.min(1, OH - i), OW):` rather than `for j in T.serial(0, OW):`. How can I deal with this?

Thank you very much for your reply!
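As far as I can tell, the extra `T.min(1, OH - i)` loop is redundant: since `0 <= i < OH`, we have `OH - i >= 1`, so the extent is exactly 1, `ax0` only ever takes the value 0, and `vi = i`. My guess (not confirmed) is that the schedule's arithmetic simplifier cannot prove `OH - i >= 1` for a symbolic `OH`, whereas it could with the constant 128. The plain-Python sketch below (no TVM needed; `generated_nest` and `expected_nest` are hypothetical helper names) models both loop nests and checks that they visit the same `(vi, vj)` pairs for a few concrete shapes:

```python
# Plain-Python model of block C's loop nests; no TVM required.
# Assumptions: T.grid(e0, e1) iterates ax0 over range(e0) and ax1 over
# range(e1), and T.min is the ordinary integer minimum.


def generated_nest(OH, OW):
    """(vi, vj) pairs visited by block C after reverse_compute_at."""
    visited = []
    for i in range(OH):                    # outer loop shared with block B
        for ax0 in range(min(1, OH - i)):  # extent T.min(1, OH - i)
            for ax1 in range(OW):
                visited.append((i + ax0, ax1))  # vi = i + ax0, vj = ax1
    return visited


def expected_nest(OH, OW):
    """(vi, vj) pairs visited by the hoped-for `for j in T.serial(0, OW)`."""
    return [(i, j) for i in range(OH) for j in range(OW)]


for OH, OW in [(1, 1), (3, 5), (128, 128)]:
    # For every i in [0, OH), OH - i >= 1, so the ax0 extent is exactly 1 ...
    assert all(min(1, OH - i) == 1 for i in range(OH))
    # ... and both nests visit the same (vi, vj) pairs in the same order.
    assert generated_nest(OH, OW) == expected_nest(OH, OW)
print("same iteration space for every tested shape")
```

If this reasoning holds, the `T.min` extent affects only the readability of the scheduled code, not what block C computes.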