I also wrote a minimal example to reproduce the problem:
``` """Test for NCHW[x]c convolution""" import numpy as np import tvm from tvm import te from tvm import autotvm from tvm import topi import tvm.testing import tvm.topi.testing from tvm.contrib.pickle_memoize import memoize from tvm.topi.nn.util import get_pad_tuple from tvm.topi.util import get_const_tuple import os import tvm from tvm import te from tvm import autotvm from tvm.topi.cuda.injective import schedule_injective_from_existing from tvm.topi.cuda.tensor_intrin import dp4a from tvm.topi.nn.pad import pad from tvm.topi.nn.util import get_pad_tuple3d from tvm.topi.util import simplify, get_const_tuple, traverse_inline, tag ########################################################################## #################### Operator and scheduler definition ################### ########################################################################## def unpack_NCDHWc_to_ncdhw(packed_out, out_dtype): """Unpack conv3d_NCDHWc output from layout NCDHWc to NCDHW Parameters ---------- packed_out : tvm.te.Tensor The output tensor of conv2d_NCHWc. out_dtype : str The output dtype. Returns ------- unpacked_out : tvm.te.Tensor The unpacked output tensor in NCHW layout. 
""" ######################################") n, oc_chunk, oz, oh, ow, oc_bn = get_const_tuple(packed_out.shape) idxmod = tvm.tir.indexmod idxdiv = tvm.tir.indexdiv oshape = (n, oc_chunk * oc_bn, oz, oh, ow) unpacked_out = te.compute( oshape, lambda n, c, z, h, w: packed_out[n, idxdiv(c, oc_bn), z, h, w, idxmod(c, oc_bn)].astype( out_dtype ), name="output_unpack", tag=tag.INJECTIVE + ",unpack_ncdhwc", ) return unpacked_out def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"): """Compute conv3d internally using conv3d_ncdhwc layout for int8 dtype""" assert data.dtype in ("int8", "uint8") assert kernel.dtype in ("int8", "uint8") assert data.dtype == kernel.dtype packed_out = conv3d_NCDHWc_int8(data, kernel, strides, padding, dilation, "NCDHW", out_dtype) return unpack_NCDHWc_to_ncdhw(packed_out, out_dtype) def schedule_conv3d_ncdhw_int8(outs): """Create schedule for tensors""" return schedule_conv3d_NCDHWc_int8(outs) def conv3d_NCDHWc_int8(data, kernel, stride, padding, dilation, layout, out_dtype): """Convolution operator in NCDHW[x]c layout for int8.""" cfg = autotvm.get_config() assert layout in ["NCDHW", "NCDHW4c"] ic_block_factor = 4 oc_block_factor = 4 pre_computed = len(kernel.shape) == 7 if not pre_computed: batch, channels, depth, height, width = get_const_tuple(data.shape) assert ( channels % ic_block_factor == 0 ), "Number of input channels should be multiple of {}".format(ic_block_factor) packed_data = te.compute( (batch, channels // ic_block_factor, depth, height, width, ic_block_factor), lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w], name="packed_data", ) out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape) assert out_channels % 4 == 0, "Number of output channels should be multiple of {}".format( oc_block_factor ) packed_kernel = te.compute( ( out_channels // oc_block_factor, in_channels // ic_block_factor, kernel_d, kernel_h, kernel_w, oc_block_factor, 
ic_block_factor, ), lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[ oc_chunk * oc_block_factor + oc_block, ic_chunk * ic_block_factor + ic_block, kd, kh, kw, ], name="packed_kernel", ) else: packed_data = data packed_kernel = kernel batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape) oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple( packed_kernel.shape ) assert isinstance(stride, int) or len(stride) == 3 assert isinstance(dilation, int) or len(dilation) == 3 if isinstance(stride, int): stride_d = stride_h = stride_w = stride else: stride_d, stride_h, stride_w = stride if isinstance(dilation, int): dilation_d = dilation_h = dilation_w = dilation else: dilation_d, dilation_h, dilation_w = dilation # # compute the output shape pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d( padding, (kernel_d, kernel_h, kernel_w) ) # out_channel = num_filter out_depth = (in_depth - kernel_d + pad_front + pad_back) // stride_d + 1 out_height = (in_height - kernel_h + pad_top + pad_down) // stride_h + 1 out_width = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1 oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block) # compute graph pad_before = [0, 0, pad_front, pad_top, pad_left, 0] pad_after = [0, 0, pad_back, pad_down, pad_right, 0] pad_data = pad(packed_data, pad_before, pad_after, name="pad_data") icc = te.reduce_axis((0, ic_chunk), name="ic_chunk") icb = te.reduce_axis((0, ic_block), name="ic_block") rz = te.reduce_axis((0, kernel_d), name="rz") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") conv = te.compute( oshape, lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum( pad_data[ nn, icc, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, icb, ].astype("int32") * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"), 
axis=[icc, rz, ry, rx, icb], ), ) output = te.compute( oshape, lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype( out_dtype ), tag="conv3d_NCDHWc_int8", ) # num flop num_flop = ( batch * oc_chunk * oc_block * out_height * out_width * ic_chunk * ic_block * kernel_d * kernel_h * kernel_w * 2 ) cfg.add_flop(num_flop) return output _dp4a = dp4a("shared", "shared", "local") def schedule_conv3d_NCDHWc_int8(outs): """Schedule conv3d int8 NCDHWc template""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) def _callback(op): if op.tag == "conv3d_NCDHWc_int8": _schedule_conv3d_NCDHWc_int8(s, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda") traverse_inline(s, outs[0].op, _callback) return s def _schedule_conv3d_NCDHWc_int8(s, output, layout, workload_name): cfg = autotvm.get_config() conv = output.op.input_tensors[0] packed_data, packed_kernel = conv.op.input_tensors if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag: pad_data = packed_data packed_data = pad_data.op.input_tensors[0] else: pad_data = packed_data if autotvm.GLOBAL_SCOPE.in_tuning: # skip this part during tuning to make recrods accurate # this part will be pre-computed during NNVM's pre-compute optimization pass s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region") s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region") else: if isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel": # data and kernel are not pre-computed, schedule layout transform here schedule_injective_from_existing(s, packed_data) schedule_injective_from_existing(s, packed_kernel) if pad_data != packed_data: s[pad_data].compute_inline() AA = s.cache_read(pad_data, "shared", [conv]) WW = s.cache_read(packed_kernel, "shared", [conv]) s[conv].set_scope("local") # handle bias if output.op not in s.outputs: s[output].compute_inline() output = 
s.outputs[0].output(0) # tile and bind spatial axes if len(s[output].op.axis) == 6: n, f, d, y, x, c = s[output].op.axis else: # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks # are created from scratch, therefore the real auto-tuning will still happen on 5D output. n, f, d, y, x = s[output].op.axis cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_d", cfg.axis(d), num_outputs=4) cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) kernel_scope, n = s[output].split(n, nparts=1) # bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n) bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) bd, vd, td, di = cfg["tile_d"].apply(s, output, d) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi) bf = s[output].fuse(n, bf) s[output].bind(bf, te.thread_axis("blockIdx.z")) s[output].bind(bd, te.thread_axis("blockIdx.y")) s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x")) s[output].bind(vf, te.thread_axis("vthread")) s[output].bind(vd, te.thread_axis("vthread")) s[output].bind(vy, te.thread_axis("vthread")) s[output].bind(vx, te.thread_axis("vthread")) cfg.define_knob("fuse_yx", [0, 1]) # fuse ty,tx or tn,tf if cfg["fuse_yx"].val: s[output].bind(tf, te.thread_axis("threadIdx.z")) s[output].bind(td, te.thread_axis("threadIdx.y")) tyx = s[output].fuse(ty, tx) s[output].bind(tyx, te.thread_axis("threadIdx.x")) s[conv].compute_at(s[output], tyx) # number of threads n_tz = cfg["tile_f"].size[2] n_ty = cfg["tile_d"].size[2] n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2] else: s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z")) s[output].bind(ty, te.thread_axis("threadIdx.y")) s[output].bind(tx, te.thread_axis("threadIdx.x")) s[conv].compute_at(s[output], tx) # number of threads n_tz = 
cfg["tile_d"].size[2] * cfg["tile_f"].size[2] n_ty = cfg["tile_y"].size[2] n_tx = cfg["tile_x"].size[2] # tile reduction axes n, f, d, y, x, c = s[conv].op.axis rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2) cfg.define_split("tile_rd", cfg.axis(ry), num_outputs=2) cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2) cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2) rco, rci = cfg["tile_rc"].apply(s, conv, rc) rdo, rdi = cfg["tile_rd"].apply(s, conv, rd) ryo, ryi = cfg["tile_ry"].apply(s, conv, ry) rxo, rxi = cfg["tile_rx"].apply(s, conv, rx) s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block) cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all") cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo]) cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi]) _, rc_block = s[conv].split(rc_block, factor=4) s[conv].tensorize(rc_block, _dp4a) cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]] s[AA].compute_at(s[conv], cache_loc) s[WW].compute_at(s[conv], cache_loc) # # cooperative fetching for load in [AA, WW]: c = s[load].op.axis[-1] c_outer, c = s[load].split(c, factor=4) s[load].vectorize(c) fused = s[load].op.axis[:-1] + [c_outer] fused = s[load].fuse(*fused) fused, tx = s[load].split(fused, factor=n_tx) fused, ty = s[load].split(fused, factor=n_ty) fused, tz = s[load].split(fused, factor=n_tz) s[load].bind(tz, te.thread_axis("threadIdx.z")) s[load].bind(ty, te.thread_axis("threadIdx.y")) s[load].bind(tx, te.thread_axis("threadIdx.x")) # unroll cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[output].pragma(kernel_scope, "unroll_explicit", False) return s ########################################################################## ############################## Testing part ############################## 
########################################################################## @autotvm.template("tutorial/conv3d_int8") def topi_conv( batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, dtype="float32", ): A = te.placeholder( (batch, in_channel, time_dim, in_size, in_size), name="A", dtype="int8", ) W = te.placeholder( ( num_filter, in_channel, kernel, kernel, kernel, ), name="W", dtype="int8", ) out = conv3d_ncdhw_int8( A, W, (stride, stride, stride), (padding, padding, padding), (dilation, dilation, dilation), ) s = schedule_conv3d_NCDHWc_int8([out]) # you can uncomment this line to see the generated code # print(tvm.lower(s, [A, W, out])) return s, [A, W, out] (batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation) = ( 1, 128, 56, 18, 128, 3, 1, 1, 1, ) A = te.placeholder((batch, in_channel, time_dim, in_size, in_size), name="A") W = te.placeholder( ( num_filter, in_channel, kernel, kernel, kernel, ), name="W", ) target = "cuda" task = autotvm.task.create( "tutorial/conv3d_int8", (batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation), target=target, ) measure_option = autotvm.measure_option( builder=autotvm.LocalBuilder(), runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=10, timeout=50), ) tuner = autotvm.tuner.XGBTuner(task) # uncomment if you want to tune the 3d convolution # tuner.tune( # n_trial=10, # measure_option=measure_option, # callbacks=[ # autotvm.callback.progress_bar(10, prefix="convolution"), # autotvm.callback.log_to_file("convolution.log"), # ], # ) # A example of configuration that does not work: # {"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 
2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"} with autotvm.apply_history_best("convolution.log"): with tvm.target.Target(target): s, arg_bufs = topi_conv( batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation ) func = tvm.build(s, arg_bufs, target=target) ``` The tuner of tvm will explore several configurations, and pick the best one. For reproducibility purpose, I also provide you a example of configuration that makes tvm crash: ``` {"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"} ``` You can write this configuration manually in the log file, and you will get the following result ``` [12:36:52] /usr/tvm/src/tir/transforms/loop_partition.cc:548: Cannot prove: ((((((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)) - 1) - (29 - (blockIdx.z*4))) + 1) >= 0), when generating the post doubt loop Traceback (most recent call last): File "test_3dconv_optimization.py", line 491, in <module> func = tvm.build(s, arg_bufs, target=target) File "/usr/tvm/python/tvm/driver/build_module.py", line 414, in build mod_host, mdev = 
_build_for_device(input_mod, tar, target_host)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 256, in _build_for_device
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
  File "/usr/tvm/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (6) /usr/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f3e6c5faf51]
  [bt] (5) /usr/tvm/build/libtvm.so(+0x644907) [0x7f3e6ba2f907]
  [bt] (4) /usr/tvm/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x40e) [0x7f3e6ba2ed1e]
  [bt] (3) /usr/tvm/build/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1e2) [0x7f3e6ba2ce52]
  [bt] (2) /usr/tvm/build/libtvm.so(+0x8d347c) [0x7f3e6bcbe47c]
  [bt] (1) /usr/tvm/build/libtvm.so(tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)+0x2d19) [0x7f3e6bcbb7a9]
  [bt] (0) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x61) [0x7f3e6b925f91]
  File "/usr/tvm/src/tir/transforms/make_packed_api.cc", line 210
TVMError: Not all Vars are passed in api_args: 'threadIdx.z' is not bound to any variables
```

As mentioned previously, you can make this configuration valid by changing the last factor of `tile_f` from 2 to 1:

```
{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 1]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}
```

---

[Visit Topic](https://discuss.tvm.apache.org/t/quantization-and-3d-convolution/8338/7) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/16aa416a178237edf8a0b5a9ad2603dbb0b66a25908318a57e465a96d4501984).