Sorry, I have only just started working with TVM. I noticed that in order to speed up conv2d, the input and weight first have to go through some transforms:
```python
def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
    KH, KW, IC, OC = get_const_tuple(kernel.shape)
    K = KH * KW * IC
    N = OC
    kernel_flat = te.compute(
        (K, N),
        lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y],
        "weight_flatten",
    )
    pad_K = 0
    pad_N = 0
    if N % tile_rows != 0:
        pad_N = tile_rows - (N % tile_rows)
    if K % tile_cols != 0:
        pad_K = tile_cols - (K % tile_cols)
    N_padded = N + pad_N
    K_padded = K + pad_K
    if pad_K != 0 or pad_N != 0:
        kernel_flat = pad(
            kernel_flat,
            pad_before=(0, 0),
            pad_after=(pad_K, pad_N),
            name="weight_padding",
        )
    return te.compute(
        (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols),
        lambda x, y, z, w: kernel_flat[w + tile_cols * y, z + tile_rows * x],
        name="weight_block_reshape",
    )
```

You can see that this is a `te.compute`, and the computation is still visible at the TIR level. Below is the printed TIR:

```
// attr [A_padded] storage_scope = "global"
allocate A_padded[uint8 * 1605632]
// attr [weight_block_reshape] storage_scope = "global"
allocate weight_block_reshape[uint8 * 1024]
// attr [C] storage_scope = "global"
allocate C[int32 * 1605632]
// attr [iter_var(pipeline, , pipeline)] pipeline_exec_scope = 1
for (i1, 0, 50176) {
  for (i2, 0, 32) {
    A_padded[((i1*32) + i2)] = tir.if_then_else((i2 < 12), placeholder[((((floordiv(i1, 224)*675) + (floordiv(i2, 6)*675)) + (floormod(i1, 224)*3)) + floormod(i2, 6))], (uint8)0)
  }
}
for (x.y.fused, 0, 8) {
  for (z, 0, 32) {
    weight_block_reshape[((x.y.fused*128) + (z*4))] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[((x.y.fused*12) + z)], (uint8)0)
    weight_block_reshape[(((x.y.fused*128) + (z*4)) + 1)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 3)], (uint8)0)
    weight_block_reshape[(((x.y.fused*128) + (z*4)) + 2)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 6)], (uint8)0)
    weight_block_reshape[(((x.y.fused*128) + (z*4)) + 3)] = tir.if_then_else(((x.y.fused < 3) && (z < 3)), placeholder[(((x.y.fused*12) + z) + 9)], (uint8)0)
  }
}
```

Since the weight transform is just a rearrangement of constant data, why is it not optimized away here? Or, put differently, is this something LLVM will optimize later?

---

[Visit Topic](https://discuss.tvm.apache.org/t/questions-about-conv2d-weight-transform/10835/1) to respond.
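For reference, here is what the transform above computes, sketched in plain NumPy. `weight_transform_ref` is my own name, not a TVM API; it simply mirrors the two `te.compute` lambdas (flatten, pad, block reshape), assuming the usual HWIO kernel layout:

```python
import numpy as np

def weight_transform_ref(kernel, tile_rows, tile_cols):
    # kernel layout HWIO: (KH, KW, IC, OC)
    KH, KW, IC, OC = kernel.shape
    K, N = KH * KW * IC, OC

    # "weight_flatten": a C-order reshape gives the same (K, N) operand as
    # kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y]
    flat = kernel.reshape(K, N)

    # "weight_padding": zero-pad K up to a multiple of tile_cols
    # and N up to a multiple of tile_rows
    pad_K = (-K) % tile_cols
    pad_N = (-N) % tile_rows
    flat = np.pad(flat, ((0, pad_K), (0, pad_N)))
    K_padded, N_padded = K + pad_K, N + pad_N

    # "weight_block_reshape":
    #   out[x, y, z, w] = flat[w + tile_cols * y, z + tile_rows * x]
    out = np.empty(
        (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols),
        dtype=flat.dtype,
    )
    for x in range(N_padded // tile_rows):
        for y in range(K_padded // tile_cols):
            for z in range(tile_rows):
                for w in range(tile_cols):
                    out[x, y, z, w] = flat[w + tile_cols * y, z + tile_rows * x]
    return out
```

With `IC = 3`, `OC = 3` and 4x4 tiles this reproduces the padding behaviour seen in the TIR dump above: both the K and N axes are padded from 3 to 4 with zeros before the blocks are laid out.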