I also wrote a minimal example to reproduce the problem.

```
"""Test for NCHW[x]c convolution"""

import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
from tvm.topi.nn.util import get_pad_tuple
from tvm.topi.util import get_const_tuple
import os

import tvm
from tvm import te
from tvm import autotvm

from tvm.topi.cuda.injective import schedule_injective_from_existing
from tvm.topi.cuda.tensor_intrin import dp4a
from tvm.topi.nn.pad import pad
from tvm.topi.nn.util import get_pad_tuple3d
from tvm.topi.util import simplify, get_const_tuple, traverse_inline, tag

##########################################################################
#################### Operator and scheduler definition ###################
##########################################################################


def unpack_NCDHWc_to_ncdhw(packed_out, out_dtype):
    """Unpack conv3d_NCDHWc output from layout NCDHWc to NCDHW

    Parameters
    ----------
    packed_out : tvm.te.Tensor
        The 6-D output tensor of conv3d_NCDHWc, laid out as
        (batch, oc_chunk, depth, height, width, oc_block).

    out_dtype : str
        The output dtype.

    Returns
    -------
    unpacked_out : tvm.te.Tensor
        The unpacked 5-D output tensor in NCDHW layout, with
        oc_chunk * oc_block channels.
    """
    n, oc_chunk, oz, oh, ow, oc_bn = get_const_tuple(packed_out.shape)

    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    oshape = (n, oc_chunk * oc_bn, oz, oh, ow)
    # Unpacked channel c lives in packed chunk c // oc_bn, at position
    # c % oc_bn inside that chunk.
    unpacked_out = te.compute(
        oshape,
        lambda n, c, z, h, w: packed_out[
            n, idxdiv(c, oc_bn), z, h, w, idxmod(c, oc_bn)
        ].astype(out_dtype),
        name="output_unpack",
        tag=tag.INJECTIVE + ",unpack_ncdhwc",
    )
    return unpacked_out


def conv3d_ncdhw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
    """Compute an int8 conv3d on NCDHW data.

    The convolution is delegated to ``conv3d_NCDHWc_int8`` (which works on
    channel-blocked data), and its 6-D result is unpacked back into a 5-D
    NCDHW tensor of ``out_dtype``.

    ``data`` and ``kernel`` must share the same dtype, which must be either
    "int8" or "uint8".
    """
    accepted_dtypes = ("int8", "uint8")
    assert data.dtype in accepted_dtypes
    assert kernel.dtype in accepted_dtypes
    assert data.dtype == kernel.dtype

    blocked_result = conv3d_NCDHWc_int8(
        data, kernel, strides, padding, dilation, "NCDHW", out_dtype
    )
    return unpack_NCDHWc_to_ncdhw(blocked_result, out_dtype)


def schedule_conv3d_ncdhw_int8(outs):
    """Create the schedule for an NCDHW int8 conv3d.

    The NCDHW compute is built on top of the NCDHWc compute, so the NCDHWc
    schedule is reused as-is.
    """
    schedule = schedule_conv3d_NCDHWc_int8(outs)
    return schedule


def conv3d_NCDHWc_int8(data, kernel, stride, padding, dilation, layout, out_dtype):
    """Convolution operator in NCDHW[x]c layout for int8.

    Parameters
    ----------
    data : tvm.te.Tensor
        int8 input. It is either a 5-D NCDHW tensor, which is packed here into
        (batch, channels // 4, depth, height, width, 4), or an already packed
        6-D tensor when ``kernel`` is pre-packed.

    kernel : tvm.te.Tensor
        It is either a 5-D (out_channels, in_channels, kernel_d, kernel_h,
        kernel_w) kernel, which is packed here, or a pre-packed 7-D
        (oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block)
        kernel.

    stride : int or tuple of 3 ints
        Strides along depth, height and width.

    padding : int, str or tuple
        Padding specification, resolved by ``get_pad_tuple3d``.

    dilation : int or tuple of 3 ints
        Dilations along depth, height and width.

    layout : str
        Input layout, either "NCDHW" or "NCDHW4c".

    out_dtype : str
        Output dtype. The reduction itself accumulates in int32.

    Returns
    -------
    output : tvm.te.Tensor
        6-D tensor (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
        tagged "conv3d_NCDHWc_int8".
    """
    cfg = autotvm.get_config()

    assert layout in ["NCDHW", "NCDHW4c"]

    ic_block_factor = 4
    oc_block_factor = 4

    # A 7-D kernel means both operands were packed ahead of time.
    pre_computed = len(kernel.shape) == 7
    if not pre_computed:
        batch, channels, depth, height, width = get_const_tuple(data.shape)
        assert (
            channels % ic_block_factor == 0
        ), "Number of input channels should be multiple of {}".format(ic_block_factor)
        packed_data = te.compute(
            (batch, channels // ic_block_factor, depth, height, width, ic_block_factor),
            lambda n, c, d, h, w, vc: data[n, c * ic_block_factor + vc, d, h, w],
            name="packed_data",
        )

        out_channels, in_channels, kernel_d, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        # Check against the block factor actually used for packing below.
        assert (
            out_channels % oc_block_factor == 0
        ), "Number of output channels should be multiple of {}".format(oc_block_factor)
        packed_kernel = te.compute(
            (
                out_channels // oc_block_factor,
                in_channels // ic_block_factor,
                kernel_d,
                kernel_h,
                kernel_w,
                oc_block_factor,
                ic_block_factor,
            ),
            lambda oc_chunk, ic_chunk, kd, kh, kw, oc_block, ic_block: kernel[
                oc_chunk * oc_block_factor + oc_block,
                ic_chunk * ic_block_factor + ic_block,
                kd,
                kh,
                kw,
            ],
            name="packed_kernel",
        )

    else:
        packed_data = data
        packed_kernel = kernel

    batch, ic_chunk, in_depth, in_height, in_width, ic_block = get_const_tuple(packed_data.shape)
    oc_chunk, ic_chunk, kernel_d, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape
    )
    assert isinstance(stride, int) or len(stride) == 3
    assert isinstance(dilation, int) or len(dilation) == 3

    if isinstance(stride, int):
        stride_d = stride_h = stride_w = stride
    else:
        stride_d, stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_d = dilation_h = dilation_w = dilation
    else:
        dilation_d, dilation_h, dilation_w = dilation

    # Effective kernel extent once dilation is applied. The reduction below
    # reads input positions up to (kernel - 1) * dilation past each window
    # start, so the output shape must be derived from the dilated extent.
    # For dilation == 1 these equal the plain kernel sizes.
    dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1

    # Compute the output shape.
    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
    )
    out_depth = (in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1
    out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
    out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1

    oshape = (batch, oc_chunk, out_depth, out_height, out_width, oc_block)
    # Compute graph: only the spatial axes (depth, height, width) of the
    # packed data are padded.
    pad_before = [0, 0, pad_front, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_back, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

    icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
    icb = te.reduce_axis((0, ic_block), name="ic_block")
    rz = te.reduce_axis((0, kernel_d), name="rz")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    # Both operands are widened to int32 before multiplying so the
    # accumulation happens in int32.
    conv = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: te.sum(
            pad_data[
                nn,
                icc,
                zz * stride_d + rz * dilation_d,
                yy * stride_h + ry * dilation_h,
                xx * stride_w + rx * dilation_w,
                icb,
            ].astype("int32")
            * packed_kernel[oc_chunk, icc, rz, ry, rx, oc_block, icb].astype("int32"),
            axis=[icc, rz, ry, rx, icb],
        ),
    )

    output = te.compute(
        oshape,
        lambda nn, oc_chunk, zz, yy, xx, oc_block: conv[nn, oc_chunk, zz, yy, xx, oc_block].astype(
            out_dtype
        ),
        tag="conv3d_NCDHWc_int8",
    )

    # Number of FLOPs: one multiply and one add per (output element, reduction
    # element) pair. All three output spatial axes are included.
    num_flop = (
        batch
        * oc_chunk
        * oc_block
        * out_depth
        * out_height
        * out_width
        * ic_chunk
        * ic_block
        * kernel_d
        * kernel_h
        * kernel_w
        * 2
    )
    cfg.add_flop(num_flop)

    return output


# Tensor intrinsic used by ``s[conv].tensorize(rc_block, _dp4a)`` in
# _schedule_conv3d_NCDHWc_int8 for the innermost 4-element reduction.
# The three arguments are memory scopes. Presumably they are the scopes of
# the two int8 operands ("shared", matching the AA/WW cache_read stages) and
# of the accumulator ("local", matching s[conv].set_scope("local")).
# NOTE(review): confirm the argument order against
# tvm.topi.cuda.tensor_intrin.dp4a.
_dp4a = dp4a("shared", "shared", "local")


def schedule_conv3d_NCDHWc_int8(outs):
    """Build the schedule for the int8 NCDHWc conv3d template.

    ``outs`` may be a single tensor or a list of tensors. The graph is walked
    from the first output, and every op tagged "conv3d_NCDHWc_int8" is
    scheduled by ``_schedule_conv3d_NCDHWc_int8``.
    """
    if isinstance(outs, te.tensor.Tensor):
        outs = [outs]
    sched = te.create_schedule([tensor.op for tensor in outs])

    def _visit(op):
        if op.tag != "conv3d_NCDHWc_int8":
            return
        _schedule_conv3d_NCDHWc_int8(
            sched, op.output(0), "NCDHW", "conv3d_NCDHWc_int8.cuda"
        )

    traverse_inline(sched, outs[0].op, _visit)
    return sched


def _schedule_conv3d_NCDHWc_int8(s, output, layout, workload_name):
    """Schedule one conv3d_NCDHWc_int8 stage for CUDA.

    Parameters
    ----------
    s : tvm.te.Schedule
        The schedule to modify in place.

    output : tvm.te.Tensor
        Output of the op tagged "conv3d_NCDHWc_int8".

    layout : str
        Not used by this function; kept for interface compatibility.

    workload_name : str
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    s : tvm.te.Schedule
        The same schedule object that was passed in.
    """
    cfg = autotvm.get_config()

    conv = output.op.input_tensors[0]
    packed_data, packed_kernel = conv.op.input_tensors

    # When padding was applied, the reduction reads from the pad stage.
    if isinstance(packed_data.op, tvm.te.ComputeOp) and "pad" in packed_data.op.tag:
        pad_data = packed_data
        packed_data = pad_data.op.input_tensors[0]
    else:
        pad_data = packed_data

    if autotvm.GLOBAL_SCOPE.in_tuning:
        # Skip this part during tuning to make records accurate.
        # This part will be pre-computed during NNVM's pre-compute optimization pass.
        s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
        s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
    elif isinstance(packed_kernel.op, tvm.te.ComputeOp) and packed_kernel.name == "packed_kernel":
        # Data and kernel are not pre-computed, so schedule the layout transforms here.
        schedule_injective_from_existing(s, packed_data)
        schedule_injective_from_existing(s, packed_kernel)

    if pad_data != packed_data:
        s[pad_data].compute_inline()

    # Stage both reduction operands through shared memory.
    AA = s.cache_read(pad_data, "shared", [conv])
    WW = s.cache_read(packed_kernel, "shared", [conv])

    s[conv].set_scope("local")

    # If the conv output feeds a later stage (e.g. bias or the NCDHW unpack),
    # inline it and tile the schedule's final output instead.
    if output.op not in s.outputs:
        s[output].compute_inline()
        output = s.outputs[0].output(0)

    # Tile and bind spatial axes.
    if len(s[output].op.axis) == 6:
        n, f, d, y, x, c = s[output].op.axis
    else:
        # Otherwise the final output is expected to be the 5-D NCDHW tensor
        # (e.g. the unpacked output). It has no oc_block axis to tile.
        n, f, d, y, x = s[output].op.axis

    cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
    cfg.define_split("tile_d", cfg.axis(d), num_outputs=4)
    cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
    cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

    kernel_scope, n = s[output].split(n, nparts=1)

    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)

    bf = s[output].fuse(n, bf)

    s[output].bind(bf, te.thread_axis("blockIdx.z"))
    s[output].bind(bd, te.thread_axis("blockIdx.y"))
    s[output].bind(s[output].fuse(by, bx), te.thread_axis("blockIdx.x"))
    s[output].bind(vf, te.thread_axis("vthread"))
    s[output].bind(vd, te.thread_axis("vthread"))
    s[output].bind(vy, te.thread_axis("vthread"))
    s[output].bind(vx, te.thread_axis("vthread"))

    cfg.define_knob("fuse_yx", [0, 1])  # fuse (ty, tx) or (tf, td)
    if cfg["fuse_yx"].val:
        s[output].bind(tf, te.thread_axis("threadIdx.z"))
        s[output].bind(td, te.thread_axis("threadIdx.y"))
        tyx = s[output].fuse(ty, tx)
        s[output].bind(tyx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tyx)

        # Number of threads along each thread axis.
        n_tz = cfg["tile_f"].size[2]
        n_ty = cfg["tile_d"].size[2]
        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
    else:
        s[output].bind(s[output].fuse(tf, td), te.thread_axis("threadIdx.z"))
        s[output].bind(ty, te.thread_axis("threadIdx.y"))
        s[output].bind(tx, te.thread_axis("threadIdx.x"))
        s[conv].compute_at(s[output], tx)

        # Number of threads along each thread axis.
        n_tz = cfg["tile_d"].size[2] * cfg["tile_f"].size[2]
        n_ty = cfg["tile_y"].size[2]
        n_tx = cfg["tile_x"].size[2]

    # Tile reduction axes. Each split knob is defined on the axis it is
    # applied to (tile_rd must be sized from rd, not ry).
    n, f, d, y, x, c = s[conv].op.axis
    rc, rd, ry, rx, rc_block = s[conv].op.reduce_axis

    cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=2)
    cfg.define_split("tile_rd", cfg.axis(rd), num_outputs=2)
    cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=2)
    cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=2)
    rco, rci = cfg["tile_rc"].apply(s, conv, rc)
    rdo, rdi = cfg["tile_rd"].apply(s, conv, rd)
    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
    s[conv].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x, c, rc_block)

    cfg.define_reorder("reorder_inner", [rco, rdo, ryo, rxo], policy="all")
    cfg["reorder_inner"].apply(s, conv, [rco, rdo, ryo, rxo])
    cfg["reorder_inner"].apply(s, conv, [rci, rdi, ryi, rxi])

    # The innermost 4 input-channel elements are computed with the dp4a intrinsic.
    _, rc_block = s[conv].split(rc_block, factor=4)
    s[conv].tensorize(rc_block, _dp4a)

    # Attach the shared-memory loads at the innermost outer reduction loop
    # chosen by the reorder knob.
    cache_loc = [rco, rdo, ryo, rxo][cfg["reorder_inner"].perm[-1]]
    s[AA].compute_at(s[conv], cache_loc)
    s[WW].compute_at(s[conv], cache_loc)

    # Cooperative fetching: all threads of the block share the loads.
    for load in [AA, WW]:
        inner = s[load].op.axis[-1]
        inner_outer, inner = s[load].split(inner, factor=4)
        s[load].vectorize(inner)
        fused = s[load].op.axis[:-1] + [inner_outer]
        fused = s[load].fuse(*fused)
        fused, load_tx = s[load].split(fused, factor=n_tx)
        fused, load_ty = s[load].split(fused, factor=n_ty)
        fused, load_tz = s[load].split(fused, factor=n_tz)
        s[load].bind(load_tz, te.thread_axis("threadIdx.z"))
        s[load].bind(load_ty, te.thread_axis("threadIdx.y"))
        s[load].bind(load_tx, te.thread_axis("threadIdx.x"))

    # Unroll.
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    return s


##########################################################################
############################## Testing part ##############################
##########################################################################


@autotvm.template("tutorial/conv3d_int8")
def topi_conv(
    batch,
    in_channel,
    in_size,
    time_dim,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
    dtype="float32",
):
    """AutoTVM template for an int8 NCDHW conv3d.

    The input is (batch, in_channel, time_dim, in_size, in_size) and the
    kernel is (num_filter, in_channel, kernel, kernel, kernel), both int8.
    Stride, padding and dilation are applied uniformly to all three spatial
    axes. ``add_bias``, ``add_relu`` and ``dtype`` are accepted but not used.

    Returns the schedule and the argument tensors ``[A, W, out]``.
    """
    data_shape = (batch, in_channel, time_dim, in_size, in_size)
    weight_shape = (num_filter, in_channel, kernel, kernel, kernel)

    A = te.placeholder(data_shape, name="A", dtype="int8")
    W = te.placeholder(weight_shape, name="W", dtype="int8")

    out = conv3d_ncdhw_int8(A, W, (stride,) * 3, (padding,) * 3, (dilation,) * 3)
    s = schedule_conv3d_NCDHWc_int8([out])

    # To inspect the generated code: print(tvm.lower(s, [A, W, out]))
    return s, [A, W, out]


(batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation) = (
    1,
    128,
    56,
    18,
    128,
    3,
    1,
    1,
    1,
)
# Arguments of the "tutorial/conv3d_int8" template, in topi_conv's parameter
# order. They are shared by task creation and by the final build below.
task_args = (batch, in_channel, in_size, time_dim, num_filter, kernel, stride, padding, dilation)

# NOTE: these (float32) placeholders are not used below. topi_conv creates its
# own int8 placeholders from task_args.
A = te.placeholder((batch, in_channel, time_dim, in_size, in_size), name="A")

W = te.placeholder(
    (
        num_filter,
        in_channel,
        kernel,
        kernel,
        kernel,
    ),
    name="W",
)
target = "cuda"


task = autotvm.task.create("tutorial/conv3d_int8", task_args, target=target)

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=10, timeout=50),
)
tuner = autotvm.tuner.XGBTuner(task)

# Uncomment if you want to tune the 3D convolution.

# tuner.tune(
#     n_trial=10,
#     measure_option=measure_option,
#     callbacks=[
#         autotvm.callback.progress_bar(10, prefix="convolution"),
#         autotvm.callback.log_to_file("convolution.log"),
#     ],
# )

# An example of a configuration that does not work. It is wrapped here for
# readability; in convolution.log it must be written as a single line:
# {"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32",
# "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config":
# {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1,
# 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]],
# ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1,
# 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp",
# [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot",
# 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908],
# "version": 0.2, "tvm_version": "0.8.dev1"}

with autotvm.apply_history_best("convolution.log"), tvm.target.Target(target):
    s, arg_bufs = topi_conv(*task_args)
    func = tvm.build(s, arg_bufs, target=target)
```

The TVM tuner explores several configurations and picks the best one.

For reproducibility, here is an example of a configuration that makes TVM crash:

```
{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 2]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}
```

If you write this configuration manually into the log file (`convolution.log`, as a single line) and run the script, you get the following result:

```
[12:36:52] /usr/tvm/src/tir/transforms/loop_partition.cc:548: Cannot prove: ((((((floordiv(((threadIdx.z*2) + 1), 4) + 1) - floordiv(threadIdx.z, 2)) - 1) - (29 - (blockIdx.z*4))) + 1) >= 0), when generating the post doubt loop
Traceback (most recent call last):
  File "test_3dconv_optimization.py", line 491, in <module>
    func = tvm.build(s, arg_bufs, target=target)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 414, in build
    mod_host, mdev = _build_for_device(input_mod, tar, target_host)
  File "/usr/tvm/python/tvm/driver/build_module.py", line 256, in 
_build_for_device
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
  File "/usr/tvm/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 322, in 
tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in 
tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in 
tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (6) /usr/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f3e6c5faf51]
  [bt] (5) /usr/tvm/build/libtvm.so(+0x644907) [0x7f3e6ba2f907]
  [bt] (4) 
/usr/tvm/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule,
 tvm::transform::PassContext const&) const+0x40e) [0x7f3e6ba2ed1e]
  [bt] (3) 
/usr/tvm/build/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule,
 tvm::transform::PassContext const&) const+0x1e2) [0x7f3e6ba2ce52]
  [bt] (2) /usr/tvm/build/libtvm.so(+0x8d347c) [0x7f3e6bcbe47c]
  [bt] (1) 
/usr/tvm/build/libtvm.so(tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, 
int)+0x2d19) [0x7f3e6bcbb7a9]
  [bt] (0) 
/usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x61) 
[0x7f3e6b925f91]
  File "/usr/tvm/src/tir/transforms/make_packed_api.cc", line 210
TVMError: Not all Vars are passed in api_args:  'threadIdx.z'  is not bound to any variables
```

As mentioned previously, you can make this configuration valid by changing the value of `tile_f` from `[-1, 1, 8, 2]` to `[-1, 1, 8, 1]`:

```
{"input": ["cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", "tutorial/conv3d_int8", [1, 128, 56, 18, 128, 3, 1, 1, 1], {}], "config": {"index": 77070610321, "code_hash": null, "entity": [["tile_f", "sp", [-1, 1, 8, 1]], ["tile_d", "sp", [-1, 1, 1, 2]], ["tile_y", "sp", [-1, 1, 7, 2]], ["tile_x", "sp", [-1, 2, 1, 1]], ["fuse_yx", "ot", 0], ["tile_rc", "sp", [-1, 1]], ["tile_rd", "sp", [-1, 1]], ["tile_ry", "sp", [-1, 1]], ["tile_rx", "sp", [-1, 1]], ["reorder_inner", "re", [1, 2, 0, 3]], ["auto_unroll_max_step", "ot", 1500]]}, "result": [[0.0027175069], 0, 11.701743602752686, 1603898087.1376908], "version": 0.2, "tvm_version": "0.8.dev1"}
```





---
[Visit 
Topic](https://discuss.tvm.apache.org/t/quantization-and-3d-convolution/8338/7) 
to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.apache.org/email/unsubscribe/16aa416a178237edf8a0b5a9ad2603dbb0b66a25908318a57e465a96d4501984).

Reply via email to