Thank you for your reply; this is exactly what I wanted.

Following your suggestion, I ran an experiment in which I replaced the constant 128
with the symbolic variables OH and OW.
    
    @T.prim_func
    def compute_at_call_extern(a: T.handle, c: T.handle) -> None:
        # Block "B" fills B one row at a time through the extern function
        # test_cust_mul; block "C" then computes C = B + 1 element-wise.
        # The 2-D shape (OH, OW) is symbolic rather than a fixed 128 x 128.
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        OH = T.var("int32")
        OW = T.var("int32")
        A = T.match_buffer(a, (OH, OW), "float32")
        B = T.alloc_buffer((OH, OW), "float32")
        C = T.match_buffer(c, (OH, OW), "float32")
        for i in range(OH):
            with T.block("B"):
                vi = T.axis.spatial(OH, i)
                T.reads([A[0:OH, 0:OW]])
                T.writes([B[vi, 0:OW]])
                # test_cust_mul receives a pointer to all of A (read) and a
                # pointer to row vi of B (written).  The row offset must use the
                # symbolic row length OW: a hard-coded vi * 128 is only correct
                # when OW == 128.
                # tvm_access_ptr(ptype, data, offset, extent, rw_mask), where
                # rw_mask 1 = read, 2 = write; extent/rw_mask now match the
                # regions declared in T.reads / T.writes above.
                T.evaluate(
                    T.call_extern(
                        "test_cust_mul",
                        T.tvm_access_ptr(
                            T.type_annotation(dtype="float32"),
                            A.data, 0, OH * OW, 1, dtype="handle",
                        ),
                        T.tvm_access_ptr(
                            T.type_annotation(dtype="float32"),
                            B.data, vi * OW, OW, 2, dtype="handle",
                        ),
                        dtype="handle",
                    )
                )

        for i, j in T.grid(OH, OW):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0

    def test_compute_at2():
        # Print the module, move consumer block "C" under the single row loop
        # of producer block "B" (keeping unit loops), then print the result.
        schedule = tir.Schedule(compute_at_call_extern, debug_mask="all")
        print(schedule.mod.script())
        (row_loop,) = schedule.get_loops(schedule.get_block("B"))
        consumer = schedule.get_block("C")
        schedule.reverse_compute_at(consumer, row_loop, preserve_unit_loops=True)
        print(schedule.mod.script())

The printed result is:

    @tvm.script.ir_module
    class Module:
        # NOTE(review): this is the module printed by test_compute_at2 after
        # reverse_compute_at; the NOTE(review) lines are annotations, not output.
        @T.prim_func
        def main(a: T.handle, c: T.handle) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            OH = T.var("int32")
            OW = T.var("int32")
            A = T.match_buffer(a, [OH, OW], dtype="float32")
            C = T.match_buffer(c, [OH, OW], dtype="float32")
            # body
            # with T.block("root")
            B = T.alloc_buffer([OH, OW], dtype="float32")
            for i in T.serial(0, OH):
                with T.block("B"):
                    vi = T.axis.spatial(OH, i)
                    T.reads([A[0 : OH, 0 : OW]])
                    T.writes([B[vi, 0 : OW]])
                    # NOTE(review): the B pointer offset is still the hard-coded
                    # `vi * 128` from the input prim_func; with symbolic OW the
                    # row offset should be `vi * OW`.
                    T.evaluate(T.call_extern("test_cust_mul", 
T.tvm_access_ptr(T.type_annotation(dtype="float32"), A.data, 0, 0, 1, 
dtype="handle"), T.tvm_access_ptr(T.type_annotation(dtype="float32"), B.data, 
vi * 128, 0, 1, dtype="handle"), dtype="handle"))
                # NOTE(review): ax0 appears to be the unit loop retained because
                # preserve_unit_loops=True was passed; its extent is printed as
                # T.min(1, OH - i) instead of 1, presumably because the analysis
                # does not prove OH - i >= 1 for symbolic OH -- TODO confirm.
                # The desired form is a single loop over OW (ax1).
                for ax0, ax1 in T.grid(T.min(1, OH - i), OW):  
                    with T.block("C"):
                        vi = T.axis.spatial(OH, i + ax0)
                        vj = T.axis.spatial(OW, ax1)
                        T.reads([B[vi, vj]])
                        T.writes([C[vi, vj]])
                        C[vi, vj] = B[vi, vj] + T.float32(1)

I have a new problem: after `reverse_compute_at`, the generated loop nest is
    `for ax0, ax1 in T.grid(T.min(1, OH - i), OW):`

rather than
    `for j in T.serial(0, OW):`

How can I get the second form?

Thank you very much for your reply!





---
[Visit Topic](https://discuss.tvm.apache.org/t/a-failed-example-of-using-compute-at-based-on-tvmscript/11489/4) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/4502f22a6c455d7d9f9c23ee7dbd4f9da5e73fcb292024fd889ca442dc51444e).

Reply via email to