Thank you for your reply; this is exactly what I wanted.
Following your suggestion, I ran an experiment and replaced the constant 128
with the variables OH and OW:
# Row-wise extern call producing B, followed by an element-wise C = B + 1.
#
# Shapes are symbolic: OH x OW are bound from the shapes of the buffers
# passed as `a` and `c`. Block "B" calls the external function
# `test_cust_mul` once per row `vi`, passing a pointer to the start of A and
# a pointer to the start of row `vi` of the intermediate buffer B. Block "C"
# then adds 1.0 to every element of B to produce C.
@T.prim_func
def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    OH = T.var("int32")
    OW = T.var("int32")
    A = T.match_buffer(a, (OH, OW), "float32")
    B = T.alloc_buffer((OH, OW), "float32")
    C = T.match_buffer(c, (OH, OW), "float32")
    for i in range(OH):
        with T.block("B"):
            vi = T.axis.spatial(OH, i)
            # The extern call is opaque to TVM, so the region it reads (all
            # of A) and the region it writes (row vi of B) are declared here.
            T.reads([A[0:OH, 0:OW]])
            T.writes([B[vi, 0:OW]])
            T.evaluate(
                T.call_extern(
                    "test_cust_mul",
                    # tvm_access_ptr(ptype, data, offset, extent, rw_mask);
                    # rw_mask 1 = read, 2 = write. A is only read.
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float32"),
                        A.data,
                        0,
                        OH * OW,
                        1,
                        dtype="handle",
                    ),
                    # Row vi of the compact (OH, OW) buffer B starts at
                    # vi * OW. The previous hard-coded 128 was only correct
                    # when OW == 128. B is written by the extern, hence
                    # rw_mask 2.
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float32"),
                        B.data,
                        vi * OW,
                        OW,
                        2,
                        dtype="handle",
                    ),
                    dtype="handle",
                )
            )
    for i, j in T.grid(OH, OW):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
def test_compute_at2():
    """Move block "C" under the row loop of block "B" and print the result.

    The module is printed once before and once after `reverse_compute_at`,
    so the two scripts can be compared by eye.
    """
    sch = tir.Schedule(compute_at_call_extern, debug_mask="all")
    print(sch.mod.script())
    producer = sch.get_block("B")
    # Block "B" sits under exactly one loop (the row loop over OH).
    [row_loop] = sch.get_loops(producer)
    consumer = sch.get_block("C")
    sch.reverse_compute_at(consumer, row_loop, preserve_unit_loops=True)
    print(sch.mod.script())
The output of the second `print` (after `reverse_compute_at`) is:
# Schedule state printed after reverse_compute_at moved block "C" under the
# `i` loop of block "B" (preserve_unit_loops=True).
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(a: T.handle, c: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        OH = T.var("int32")
        OW = T.var("int32")
        A = T.match_buffer(a, [OH, OW], dtype="float32")
        C = T.match_buffer(c, [OH, OW], dtype="float32")
        # body
        # with T.block("root")
        B = T.alloc_buffer([OH, OW], dtype="float32")
        for i in T.serial(0, OH):
            with T.block("B"):
                vi = T.axis.spatial(OH, i)
                T.reads([A[0 : OH, 0 : OW]])
                T.writes([B[vi, 0 : OW]])
                # NOTE(review): the row offset is still the constant 128, not
                # OW, so it is only correct when OW == 128.
                T.evaluate(T.call_extern("test_cust_mul",
                    T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 0, 1,
                    dtype="handle"), T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data,
                    vi * 128, 0, 1, dtype="handle"), dtype="handle"))
            # Block "C" now sits inside the `i` loop, after block "B".
            # Because `i` ranges over [0, OH), OH - i >= 1, so
            # T.min(1, OH - i) always equals 1 and `ax0` only takes the
            # value 0. The printer keeps the T.min form. Presumably the
            # simplifier did not use the range of `i` here (TODO confirm).
            for ax0, ax1 in T.grid(T.min(1, OH - i), OW):
                with T.block("C"):
                    vi = T.axis.spatial(OH, i + ax0)
                    vj = T.axis.spatial(OW, ax1)
                    T.reads([B[vi, vj]])
                    T.writes([C[vi, vj]])
                    C[vi, vj] = B[vi, vj] + T.float32(1)
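I have a new problem: after `reverse_compute_at`, block C ends up nested in
`for ax0, ax1 in T.grid(T.min(1, OH - i), OW):`
rather than the simpler
`for j in T.serial(0, OW):`
Since `i` only ranges over `[0, OH)`, `T.min(1, OH - i)` is always 1, so the extra `ax0` loop should not be needed.
How can I get the simpler single loop over `OW`?
Thank you very much for your reply!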
---
[Visit
Topic](https://discuss.tvm.apache.org/t/a-failed-example-of-using-compute-at-based-on-tvmscript/11489/4)
to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click
here](https://discuss.tvm.apache.org/email/unsubscribe/4502f22a6c455d7d9f9c23ee7dbd4f9da5e73fcb292024fd889ca442dc51444e).