Thank you! Your explanation of TE tensors is very clear.

However, I think those `T_reshape.op` references may refer to different stages, because
I call `topi.reshape` several times. Here is my code:

    def function():
        """Build the TE compute graph used in this question.

        The graph appears to implement grayscale morphological erosion of ``A``
        by the structuring element ``kernel``:

        1. Pad the H/W axes of ``A`` by the kernel's half-extent, filling the
           border with ``max_val``.
        2. Spread each (se_h x se_w) window of the padded image across
           ``se_h * se_w`` output channels with a conv2d whose filters come
           from ``neight2channels(kernel)``. This assumes those filters are
           one-hot, one per window position, in row-major order.
        3. Subtract a per-channel offset that is ``-max_val`` where
           ``kernel == 0`` and ``0`` elsewhere. Masked-out taps are pushed
           upward, so the following ``min`` ignores them.
        4. Take the minimum over the channel axis and restore ``[B, C, H, W]``.

        Returns:
            ``[A, kernel, out]``: the two input placeholders and the output
            tensor.
        """
        A = te.placeholder((1, 3, 5, 5), name="A", dtype="float32")
        kernel = te.placeholder((5, 5), name="kernel", dtype="float32")
        max_val = 1e4
        se_h, se_w = kernel.shape
        origin = [se_h // 2, se_w // 2]
        pad_e1 = [0, 0, origin[0], origin[1]]
        pad_e2 = [0, 0, se_h - origin[0] - 1, se_w - origin[1] - 1]
        border_value = max_val
        output = topi.nn.pad(A, pad_e1, pad_e2, pad_value=border_value)
        print(output.shape)
        # The offset table has the same shape as the structuring element:
        # -max_val where the kernel is off, 0.0 where it is on.
        neighborhood = te.compute(
            (se_h, se_w),
            lambda i0, i1: te.if_then_else(kernel[i0, i1] == 0, -max_val, 0.0),
            name="neighborhood",
        )
        B, C, H, W = A.shape
        Hpad, Wpad = output.shape[-2:]
        reshape_kernel = neight2channels(kernel)
        # Fold batch and channel together so that every image plane is
        # convolved independently with the single-input-channel filters.
        reshape1 = topi.reshape(output, [B * C, 1, Hpad, Wpad])
        conv1 = topi.nn.conv2d(reshape1, reshape_kernel, 1, 0, 1)
        # Lay the se_h*se_w offsets along conv1's channel axis (axis 1), so the
        # subtraction broadcasts over batch, H and W. This assumes channel k
        # of conv1 is window tap (k // se_w, k % se_w), which matches the
        # row-major flattening of `neighborhood`.
        reshape2 = topi.reshape(neighborhood, [1, se_h * se_w, 1, 1])
        # The mask must be applied BEFORE the reduction, otherwise it has no
        # effect on which tap wins the min.
        shifted = topi.subtract(conv1, reshape2)
        out1 = topi.min(shifted, 1)
        out = topi.reshape(out1, [B, C, H, W])
        return [A, kernel, out]


    def neight2channels(kernel):
        """Turn a 2-D kernel's shape into a bank of one-hot conv2d filters.

        An identity matrix of size (kh*kw, kh*kw) is built and reshaped to
        ``[kh*kw, 1, kh, kw]``. Filter ``n`` therefore has a single 1 at
        row-major position ``n`` of a (kh x kw) window and 0 elsewhere. Only
        ``kernel.shape`` is read; the kernel's values are not used.

        Returns:
            The reshaped identity tensor (integer 1/0 entries).
        """
        kh, kw = kernel.shape
        taps = kh * kw
        eye = te.compute(
            (taps, taps),
            lambda row, col: te.if_then_else(row == col, 1, 0),
            name="temp",
        )
        return topi.reshape(eye, [taps, 1, kh, kw])

And here is the printed schedule:
   >      PadInput_i0, PadInput_i1, PadInput_i2, PadInput_i3 = 
tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis)
>         T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = 
> tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
>         pad_temp_i0, pad_temp_i1, pad_temp_i2, pad_temp_i3 = 
> tuple(pad_temp.op.axis) + tuple(pad_temp.op.reduce_axis)
>         compute_i, compute_j = tuple(compute.op.axis) + 
> tuple(compute.op.reduce_axis)
>         T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = 
> tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
>         compute_nn, compute_ff, compute_yy, compute_xx, compute_rc, 
> compute_ry, compute_rx = tuple(compute.op.axis) + 
> tuple(compute.op.reduce_axis)
>         compute_red_ax0, compute_red_ax1, compute_red_ax2, compute_red_k1 = 
> tuple(compute_red.op.axis) + tuple(compute_red.op.reduce_axis)
>         T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3 = 
> tuple(T_reshape.op.axis) + tuple(T_reshape.op.reduce_axis)
>         compute_local, = s.cache_write([compute], "local")
>         compute_local_nn_c, compute_local_ff_c, compute_local_yy_c, 
> compute_local_xx_c, compute_local_rc, compute_local_ry, compute_local_rx = 
> tuple(compute_local.op.axis) + tuple(compute_local.op.reduce_axis)
>         compute_local_nn_c_o_i, compute_local_nn_c_i = 
> s[compute_local].split(compute_local_nn_c, factor=3)
>         compute_local_nn_c_o_o_i, compute_local_nn_c_o_i = 
> s[compute_local].split(compute_local_nn_c_o_i, factor=1)
>         compute_local_nn_c_o_o_o_i, compute_local_nn_c_o_o_i = 
> s[compute_local].split(compute_local_nn_c_o_o_i, factor=1)
>         compute_local_nn_c_o_o_o_o, compute_local_nn_c_o_o_o_i = 
> s[compute_local].split(compute_local_nn_c_o_o_o_i, factor=1)
>         compute_local_ff_c_o_i, compute_local_ff_c_i = 
> s[compute_local].split(compute_local_ff_c, factor=1)
>         compute_local_ff_c_o_o_i, compute_local_ff_c_o_i = 
> s[compute_local].split(compute_local_ff_c_o_i, factor=1)
>         compute_local_ff_c_o_o_o_i, compute_local_ff_c_o_o_i = 
> s[compute_local].split(compute_local_ff_c_o_o_i, factor=25)
>         compute_local_ff_c_o_o_o_o, compute_local_ff_c_o_o_o_i = 
> s[compute_local].split(compute_local_ff_c_o_o_o_i, factor=1)
>         compute_local_yy_c_o_i, compute_local_yy_c_i = 
> s[compute_local].split(compute_local_yy_c, factor=1)
>         compute_local_yy_c_o_o_i, compute_local_yy_c_o_i = 
> s[compute_local].split(compute_local_yy_c_o_i, factor=1)
>         compute_local_yy_c_o_o_o_i, compute_local_yy_c_o_o_i = 
> s[compute_local].split(compute_local_yy_c_o_o_i, factor=1)
>         compute_local_yy_c_o_o_o_o, compute_local_yy_c_o_o_o_i = 
> s[compute_local].split(compute_local_yy_c_o_o_o_i, factor=5)
>         compute_local_xx_c_o_i, compute_local_xx_c_i = 
> s[compute_local].split(compute_local_xx_c, factor=1)
>         compute_local_xx_c_o_o_i, compute_local_xx_c_o_i = 
> s[compute_local].split(compute_local_xx_c_o_i, factor=1)
>         compute_local_xx_c_o_o_o_i, compute_local_xx_c_o_o_i = 
> s[compute_local].split(compute_local_xx_c_o_o_i, factor=5)
>         compute_local_xx_c_o_o_o_o, compute_local_xx_c_o_o_o_i = 
> s[compute_local].split(compute_local_xx_c_o_o_o_i, factor=1)
>         compute_local_rc_o_i, compute_local_rc_i = 
> s[compute_local].split(compute_local_rc, factor=1)
>         compute_local_rc_o_o, compute_local_rc_o_i = 
> s[compute_local].split(compute_local_rc_o_i, factor=1)
>         compute_local_ry_o_i, compute_local_ry_i = 
> s[compute_local].split(compute_local_ry, factor=5)
>         compute_local_ry_o_o, compute_local_ry_o_i = 
> s[compute_local].split(compute_local_ry_o_i, factor=1)
>         compute_local_rx_o_i, compute_local_rx_i = 
> s[compute_local].split(compute_local_rx, factor=1)
>         compute_local_rx_o_o, compute_local_rx_o_i = 
> s[compute_local].split(compute_local_rx_o_i, factor=1)
>         s[compute_local].reorder(compute_local_nn_c_o_o_o_o, 
> compute_local_ff_c_o_o_o_o, compute_local_yy_c_o_o_o_o, 
> compute_local_xx_c_o_o_o_o, compute_local_nn_c_o_o_o_i, 
> compute_local_ff_c_o_o_o_i, compute_local_yy_c_o_o_o_i, 
> compute_local_xx_c_o_o_o_i, compute_local_nn_c_o_o_i, 
> compute_local_ff_c_o_o_i, compute_local_yy_c_o_o_i, compute_local_xx_c_o_o_i, 
> compute_local_rc_o_o, compute_local_ry_o_o, compute_local_rx_o_o, 
> compute_local_rc_o_i, compute_local_ry_o_i, compute_local_rx_o_i, 
> compute_local_nn_c_o_i, compute_local_ff_c_o_i, compute_local_yy_c_o_i, 
> compute_local_xx_c_o_i, compute_local_rc_i, compute_local_ry_i, 
> compute_local_rx_i, compute_local_nn_c_i, compute_local_ff_c_i, 
> compute_local_yy_c_i, compute_local_xx_c_i)
>         compute_nn_o_i, compute_nn_i = s[compute].split(compute_nn, factor=3)
>         compute_nn_o_o_i, compute_nn_o_i = s[compute].split(compute_nn_o_i, 
> factor=1)
>         compute_nn_o_o_o, compute_nn_o_o_i = 
> s[compute].split(compute_nn_o_o_i, factor=1)
>         compute_ff_o_i, compute_ff_i = s[compute].split(compute_ff, factor=1)
>         compute_ff_o_o_i, compute_ff_o_i = s[compute].split(compute_ff_o_i, 
> factor=25)
>         compute_ff_o_o_o, compute_ff_o_o_i = 
> s[compute].split(compute_ff_o_o_i, factor=1)
>         compute_yy_o_i, compute_yy_i = s[compute].split(compute_yy, factor=1)
>         compute_yy_o_o_i, compute_yy_o_i = s[compute].split(compute_yy_o_i, 
> factor=1)
>         compute_yy_o_o_o, compute_yy_o_o_i = 
> s[compute].split(compute_yy_o_o_i, factor=5)
>         compute_xx_o_i, compute_xx_i = s[compute].split(compute_xx, factor=1)
>         compute_xx_o_o_i, compute_xx_o_i = s[compute].split(compute_xx_o_i, 
> factor=5)
>         compute_xx_o_o_o, compute_xx_o_o_i = 
> s[compute].split(compute_xx_o_o_i, factor=1)
>         s[compute].reorder(compute_nn_o_o_o, compute_ff_o_o_o, 
> compute_yy_o_o_o, compute_xx_o_o_o, compute_nn_o_o_i, compute_ff_o_o_i, 
> compute_yy_o_o_i, compute_xx_o_o_i, compute_nn_o_i, compute_ff_o_i, 
> compute_yy_o_i, compute_xx_o_i, compute_nn_i, compute_ff_i, compute_yy_i, 
> compute_xx_i)
>         s[compute_local].compute_at(s[compute], compute_xx_o_i)
>         T_reshape_shared = s.cache_read(T_reshape, "shared", [compute_local])
>         T_reshape_shared_ax0, T_reshape_shared_ax1, T_reshape_shared_ax2, 
> T_reshape_shared_ax3 = tuple(T_reshape_shared.op.axis)
>         s[T_reshape_shared].compute_at(s[compute_local], compute_local_rx_o_o)
>         s[T_reshape].compute_inline()
>         s[compute].compute_inline()
>         pad_temp_shared = s.cache_read(pad_temp, "shared", [compute_local])
>         pad_temp_shared_ax0, pad_temp_shared_ax1, pad_temp_shared_ax2, 
> pad_temp_shared_ax3 = tuple(pad_temp_shared.op.axis)
>         s[pad_temp_shared].compute_at(s[compute_local], compute_local_rx_o_o)
>         s[pad_temp].compute_inline()
>         s[T_reshape].compute_inline()
>         s[PadInput].compute_inline()
>         T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused = 
> s[T_reshape].fuse(T_reshape_ax0, T_reshape_ax1, T_reshape_ax2, T_reshape_ax3)
>         T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, 
> T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i = 
> s[T_reshape].split(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused, factor=32)
>         s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_o, 
> te.thread_axis("blockIdx.x"))
>         s[T_reshape].bind(T_reshape_ax0_ax1_fused_ax2_fused_ax3_fused_i, 
> te.thread_axis("threadIdx.x"))
>         compute_red_ax0_ax1_fused_ax2_fused = 
> s[compute_red].fuse(compute_red_ax0, compute_red_ax1, compute_red_ax2)
>         compute_red_ax0_ax1_fused_ax2_fused_o, 
> compute_red_ax0_ax1_fused_ax2_fused_i = 
> s[compute_red].split(compute_red_ax0_ax1_fused_ax2_fused, factor=64)
>         s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_o, 
> te.thread_axis("blockIdx.x"))
>         s[compute_red].bind(compute_red_ax0_ax1_fused_ax2_fused_i, 
> te.thread_axis("threadIdx.x"))
>         compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused = 
> s[compute].fuse(compute_nn_o_o_o, compute_ff_o_o_o, compute_yy_o_o_o, 
> compute_xx_o_o_o)
>         
> s[compute].bind(compute_nn_o_o_o_ff_o_o_o_fused_yy_o_o_o_fused_xx_o_o_o_fused,
>  te.thread_axis("blockIdx.x"))
>         compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused = 
> s[compute].fuse(compute_nn_o_o_i, compute_ff_o_o_i, compute_yy_o_o_i, 
> compute_xx_o_o_i)
>         
> s[compute].bind(compute_nn_o_o_i_ff_o_o_i_fused_yy_o_o_i_fused_xx_o_o_i_fused,
>  te.thread_axis("vthread"))
>         compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused = 
> s[compute].fuse(compute_nn_o_i, compute_ff_o_i, compute_yy_o_i, 
> compute_xx_o_i)
>         
> s[compute].bind(compute_nn_o_i_ff_o_i_fused_yy_o_i_fused_xx_o_i_fused, 
> te.thread_axis("threadIdx.x"))
>         T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused = 
> s[T_reshape_shared].fuse(T_reshape_shared_ax0, T_reshape_shared_ax1, 
> T_reshape_shared_ax2, T_reshape_shared_ax3)
>         T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, 
> T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = 
> s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused, 
> factor=1)
>         
> s[T_reshape_shared].vectorize(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
>         T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, 
> T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = 
> s[T_reshape_shared].split(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o,
>  factor=125)
>         
> s[T_reshape_shared].bind(T_reshape_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i,
>  te.thread_axis("threadIdx.x"))
>         pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused = 
> s[pad_temp_shared].fuse(pad_temp_shared_ax0, pad_temp_shared_ax1, 
> pad_temp_shared_ax2, pad_temp_shared_ax3)
>         pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, 
> pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i = 
> s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused, 
> factor=1)
>         
> s[pad_temp_shared].vectorize(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_i)
>         pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_o, 
> pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i = 
> s[pad_temp_shared].split(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o, 
> factor=125)
>         
> s[pad_temp_shared].bind(pad_temp_shared_ax0_ax1_fused_ax2_fused_ax3_fused_o_i,
>  te.thread_axis("threadIdx.x"))
>         s[compute_local].pragma(compute_local_nn_c_o_o_o_o, 
> "auto_unroll_max_step", 512)
>         s[compute_local].pragma(compute_local_nn_c_o_o_o_o, 
> "unroll_explicit", True)
>         s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, 
> "auto_unroll_max_step", 64)
>         s[compute_red].pragma(compute_red_ax0_ax1_fused_ax2_fused_o, 
> "unroll_explicit", True)

So, given this fusion, how can I reuse the printed schedule?





---
[Visit Topic](https://discuss.tvm.apache.org/t/print-auto-schedule-python-schedule-with-topi-op/11363/3) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/1629551b4f84d2714664dc69d00871010d688c9a202cbe8913b245d137bbdc40).

Reply via email to