My machine has a V100 with CUDA 10.2. When I run the tutorial, the generated code does not use tensor core primitives; it is just ordinary CUDA code. Can anyone help me figure out what is going wrong? Thanks!
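
For reference, here is the quick check I am using to tell whether the generated kernel actually contains tensor core intrinsics. The `wmma`/`mma.sync` substring test is just my own heuristic, and `func` refers to the module built by the script below:

    # My own diagnostic, not part of the tutorial: look for tensor core
    # (wmma / mma.sync) instructions in the generated device source.
    dev_src = func.imported_modules[0].get_source()
    print("uses tensor cores:", ("wmma" in dev_src) or ("mma.sync" in dev_src))
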
Here is the code I am trying to run:

    import logging
    import sys

    import numpy as np
    import tvm
    from tvm import te
    import tvm.testing

    from tvm import autotvm
    from tvm.contrib import nvcc


    def matmul_nn(A, B, L, dtype="float16", layout="NN"):
        k = te.reduce_axis((0, L), name="k")
        if dtype == "float16":
            out_type = "float"
        elif dtype == "int8":
            out_type = "int"
        elif dtype == "int4" or dtype == "int1":
            out_type = "int"
        if layout == "NN":
            return te.compute(
                (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k)
            )
        if layout == "NT":
            return te.compute(
                (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k)
            )
        if layout == "TN":
            return te.compute(
                (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k)
            )
        if layout == "TT":
            return te.compute(
                (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k)
            )

    def test_gemm(N, L, M, dtype, layout):
        if layout == "NN":
            shape_a = (N, L)
            shape_b = (L, M)
        elif layout == "NT":
            shape_a = (L, N)
            shape_b = (L, M)
        elif layout == "TN":
            shape_a = (N, L)
            shape_b = (M, L)
        elif layout == "TT":
            shape_a = (L, N)
            shape_b = (M, L)
        else:
            print("Unsupported layout:", layout)
            sys.exit(1)
        A = te.placeholder(shape_a, name="A", dtype=dtype)
        B = te.placeholder(shape_b, name="B", dtype=dtype)
        C = matmul_nn(A, B, L, dtype, layout)

        s = te.create_schedule(C.op)
        y, x = s[C].op.axis
        k = s[C].op.reduce_axis[0]

        # storage_align params
        factor = 16
        offset = 8
        if dtype == "int8":
            factor = 32
            offset = 16
        elif dtype == "int4":
            factor = 64
            offset = 32
        elif dtype == "int1":
            factor = 256
            offset = 128

        # create cache stages
        AA = s.cache_read(A, "shared", [C])
        if layout == "NN" or layout == "TN":
            s[AA].storage_align(AA.op.axis[0], factor, offset)
        AL = s.cache_read(AA, "local", [C])
        BB = s.cache_read(B, "shared", [C])
        if layout == "TT" or layout == "NT":
            s[BB].storage_align(BB.op.axis[0], factor, offset)
        BL = s.cache_read(BB, "local", [C])
        CL = s.cache_write(C, "local")

        bx = 4
        by = 32
        step_k = 16
        v = 8

        # thread tile
        TX = 8
        TY = 1
        if dtype == "int4" or dtype == "int1":
            TX = 2
        # warp tile
        warp_tile_m = 16  # it could also be 8 or 32 on CUDA version >= 10.0
        warp_tile_k = 16  # it must be 16 for fp16/int8 data type
        if dtype == "int4":
            warp_tile_m = 8
            warp_tile_k = 32
        elif dtype == "int1":
            warp_tile_m = 8
            warp_tile_k = 128
        # block tile
        tile_x = bx * TX
        tile_y = by * TY

        yo, ty = s[C].split(y, tile_y)
        ty, yi = s[C].split(ty, TY)

        # schedule for C stage
        xo, xi = s[C].split(x, tile_x)
        WX = min(warp_tile_m, tile_x)
        tz, xi = s[C].split(xi, WX)
        tx, xi = s[C].split(xi, TX)
        s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
        s[C].bind(yo, te.thread_axis("blockIdx.y"))
        s[C].bind(xo, te.thread_axis("blockIdx.x"))
        s[C].bind(ty, te.thread_axis("threadIdx.y"))
        s[C].bind(tz, te.thread_axis("threadIdx.z"))
        s[C].bind(tx, te.thread_axis("threadIdx.x"))

        # schedule for CL stage
        ko, ki = s[CL].split(k, step_k * warp_tile_k)
        kl, ki = s[CL].split(ki, warp_tile_k)
        s[CL].compute_at(s[C], tx)
        yo, xo = CL.op.axis
        s[CL].reorder(ko, kl, ki, yo, xo)

        # schedule for AA stage
        s[AA].compute_at(s[CL], ko)
        xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
        tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
        tx, vec = s[AA].split(tx, factor=v)
        fused = s[AA].fuse(s[AA].op.axis[0], xo)
        _, ty = s[AA].split(fused, factor=by)
        s[AA].bind(ty, te.thread_axis("threadIdx.y"))
        s[AA].bind(tz, te.thread_axis("threadIdx.z"))
        s[AA].bind(tx, te.thread_axis("threadIdx.x"))
        # vectorization is very important for float16/int8 inputs
        s[AA].vectorize(vec)

        # schedule for BB stage
        s[BB].compute_at(s[CL], ko)
        xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
        tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
        tx, vec = s[BB].split(tx, factor=v)
        fused = s[BB].fuse(s[BB].op.axis[0], xo)
        _, ty = s[BB].split(fused, factor=by)
        s[BB].bind(ty, te.thread_axis("threadIdx.y"))
        s[BB].bind(tz, te.thread_axis("threadIdx.z"))
        s[BB].bind(tx, te.thread_axis("threadIdx.x"))
        s[BB].vectorize(vec)

        s[AL].compute_at(s[CL], kl)
        s[BL].compute_at(s[CL], kl)

        # set the 'tensor_core' pragma for tensorcore codegen
        s[CL].pragma(ko, "tensor_core")

        return s, [A, B, C]


    ctx = tvm.gpu()
    if not nvcc.have_tensorcore(ctx.compute_version):
        raise Exception("the gpu has no tensorcore, skipping...")

    M, N, L = 512, 32, 512
    dtype = "float16"
    layout = "NN"
    if len(sys.argv) >= 4:
        M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
    if len(sys.argv) >= 5:
        dtype = sys.argv[4]
    if len(sys.argv) >= 6:
        layout = sys.argv[5]

    # check whether the current gpu arch supports wmma codegen for the current dtype
    cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
    major, minor = nvcc.parse_compute_version(cuda_compute_capability)
    if dtype == "int8":
        assert major == 7 and minor >= 2
    elif dtype == "int4" or dtype == "int1":
        # int4/int1 only support layout TN
        assert major == 7 and minor == 5 and layout == "TN"


    def evaluate(M, N, L, dtype, layout):
        with tvm.target.Target("cuda"):
            s, arg_bufs = test_gemm(N, L, M, dtype, layout)
            print(tvm.lower(s, arg_bufs, simple_mode=True))
            func = tvm.build(s, arg_bufs)
        dev_module = func.imported_modules[0]
        print(dev_module.get_source())

        # check correctness
        if layout == "NN":
            shape_a = (N, L)
            shape_b = (L, M)
        elif layout == "NT":
            shape_a = (L, N)
            shape_b = (L, M)
        elif layout == "TN":
            shape_a = (N, L)
            shape_b = (M, L)
        elif layout == "TT":
            shape_a = (L, N)
            shape_b = (M, L)

        a_np = None
        b_np = None
        c_np = None
        c_np_type = None
        if dtype == "float16":
            c_np_type = np.float32
            a_np = np.random.uniform(size=shape_a).astype(np.float16)
            b_np = np.random.uniform(size=shape_b).astype(np.float16)
            if layout == "NN":
                c_np = np.dot(a_np, b_np)
            elif layout == "NT":
                c_np = np.dot(a_np.T, b_np)
            elif layout == "TN":
                c_np = np.dot(a_np, b_np.T)
            elif layout == "TT":
                c_np = np.dot(a_np.T, b_np.T)
        elif dtype == "int8":
            c_np_type = np.int32
            a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
            b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
            if layout == "NN":
                c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
            elif layout == "NT":
                c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
            elif layout == "TN":
                c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
            elif layout == "TT":
                c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
        elif dtype == "int4":
            c_np_type = np.int32
            a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
            b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
            # "TN"
            c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
            a_np = np.zeros(shape=(N, int(L / 8)), dtype=np.int32)
            b_np = np.zeros(shape=(M, int(L / 8)), dtype=np.int32)
            # a_np --> col_major
            for i in range(N):
                for j in range(int(L / 8)):
                    for k in range(8):
                        a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))

            # b_np --> row_major
            for i in range(M):
                for j in range(int(L / 8)):
                    for k in range(8):
                        b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))
        elif dtype == "int1":
            c_np_type = np.int32
            a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
            b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
            # "TN"
            c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
            a_np = np.zeros(shape=(N, int(L / 32)), dtype=np.int32)
            b_np = np.zeros(shape=(M, int(L / 32)), dtype=np.int32)
            for i in range(N):
                for j in range(int(L / 32)):
                    for k in range(32):
                        a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xF) << (31 - k))

            for i in range(M):
                for j in range(int(L / 32)):
                    for k in range(32):
                        b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xF) << (31 - k))

        c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
        a_tvm = tvm.nd.array(a_np, ctx=ctx)
        b_tvm = tvm.nd.array(b_np, ctx=ctx)
        func(a_tvm, b_tvm, c_tvm)

        tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)

        evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
        print("Time cost of this operator: %f" % evaluator(a_tvm, b_tvm, 
c_tvm).mean)


    evaluate(M, N, L, dtype, layout)

And the output is:

    #[version = "0.0.5"]
    primfn(A_1: handle, B_1: handle, compute_1: handle) -> ()
      attr = {"global_symbol": "main", "tir.noalias": True}
      buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [32, 512], []),
                B: Buffer(B_2: Pointer(float16), float16, [512, 512], []),
                A: Buffer(A_2: Pointer(float16), float16, [32, 512], [])}
      buffer_map = {A_1: A, B_1: B, compute_1: compute} {
      attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 1;
      attr [compute.local: Pointer(float32)] "storage_scope" = "local";
      allocate(compute.local, float32, [8]);
      attr [A.shared: Pointer(float16)] "storage_scope" = "shared";
      allocate(A.shared, float16, [8448]);
      attr [B.shared: Pointer(float16)] "storage_scope" = "shared";
      allocate(B.shared, float16, [8192]);
      attr [A.shared.local: Pointer(float16)] "storage_scope" = "local";
      allocate(A.shared.local, float16, [16]);
      attr [B.shared.local: Pointer(float16)] "storage_scope" = "local";
      allocate(B.shared.local, float16, [128]);
      attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 16;
      attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2;
      attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
      attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2 {
        for (j.c.init: int32, 0, 8) {
          compute.local[j.c.init] = 0f32
        }
        attr [IterVar(k.outer: int32, (nullptr), "CommReduce", "")] "pragma_tensor_core" = 1;
        for (k.outer, 0, 2) {
          for (ax0.ax1.outer.fused.outer: int32, 0, 8) {
            attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
            attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2;
            attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
            A.shared[ramp((((((ax0.ax1.outer.fused.outer*1056) + (floordiv(threadIdx.y_1, 8)*264)) + (floormod(threadIdx.y_1, 8)*32)) + (threadIdx.z_1*16)) + (threadIdx.x_1*8)), 1, 8)] = (float16x8*)A_2[ramp(((((((ax0.ax1.outer.fused.outer*2048) + (floordiv(threadIdx.y_1, 8)*512)) + (k.outer*256)) + (floormod(threadIdx.y_1, 8)*32)) + (threadIdx.z_1*16)) + (threadIdx.x_1*8)), 1, 8)]
          }
          for (ax0.ax1.outer.fused.outer_1: int32, 0, 8) {
            attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
            attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2;
            attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
            B.shared[ramp(((((ax0.ax1.outer.fused.outer_1*1024) + (threadIdx.y_2*32)) + (threadIdx.z_2*16)) + (threadIdx.x_2*8)), 1, 8)] = (float16x8*)B_2[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer_1*16384)) + (threadIdx.y_2*512)) + (blockIdx.x*32)) + (threadIdx.z_2*16)) + (threadIdx.x_2*8)), 1, 8)]
          }
          for (k.inner.outer: int32, 0, 16) {
            for (ax1: int32, 0, 16) {
              A.shared.local[ax1] = (float16*)A.shared[(((threadIdx.y*264) + (k.inner.outer*16)) + ax1)]
            }
            for (ax0: int32, 0, 16) {
              for (ax1_1: int32, 0, 8) {
                B.shared.local[((ax0*8) + ax1_1)] = (float16*)B.shared[(((((k.inner.outer*512) + (ax0*32)) + (threadIdx.z*16)) + (threadIdx.x*8)) + ax1_1)]
              }
            }
            for (k.inner.inner: int32, 0, 16) {
              for (j.c: int32, 0, 8) {
                compute.local[j.c] = ((float32*)compute.local[j.c] + (cast(float32, (float16*)A.shared.local[k.inner.inner])*cast(float32, (float16*)B.shared.local[((k.inner.inner*8) + j.c)])))
              }
            }
          }
        }
        for (j.inner.inner.inner: int32, 0, 8) {
          compute_2[(((((threadIdx.y*512) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = (float32*)compute.local[j.inner.inner.inner]
        }
      }
    }

    #[metadata]
    {
      "root": 1, 
      "nodes": [
        {
          "type_key": ""
        }, 
        {
          "type_key": "Map", 
          "keys": [
            "IntImm"
          ], 
          "data": [2]
        }, 
        {
          "type_key": "Array", 
          "data": [3]
        }, 
        {
          "type_key": "IntImm", 
          "attrs": {
            "dtype": "bool", 
            "value": "1"
          }
        }
      ], 
      "b64ndarrays": [], 
      "attrs": {"tvm_version": "0.8.dev0"}
    }
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
    #include <cuda_fp16.h>
    __device__ half max(half a, half b)
    {
      return __hgt(__half(a), __half(b)) ? a : b;
    }
    __device__ half min(half a, half b)
    {
      return __hlt(__half(a), __half(b)) ? a : b;
    }
    #else

    typedef unsigned short uint16_t;
    typedef unsigned char uint8_t;
    typedef signed char int8_t;
    typedef int int32_t;
    typedef unsigned long long uint64_t;
    typedef unsigned int uint32_t;

    #define TVM_FORCE_INLINE inline __attribute__((always_inline))
    #define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
    #define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
    #define TVM_HALF_OPERATOR(RTYPE, OP)                              \
      TVM_XINLINE RTYPE operator OP (half a, half b) {                \
        return RTYPE(float(a) OP float(b));                           \
      }                                                               \
      template<typename T>                                            \
      TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
        return RTYPE(float(a) OP float(b));                           \
      }                                                               \
      template<typename T>                                            \
      TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
        return RTYPE(float(a) OP float(b));                           \
      }

    #define TVM_HALF_ASSIGNOP(AOP, OP)                                \
      template<typename T>                                            \
      TVM_XINLINE half operator AOP (const T& a) {                    \
        return *this = half(float(*this) OP float(a));                \
      }                                                               \
      template<typename T>                                            \
      TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
        return *this = half(float(*this) OP float(a));                \
      }

    class TVM_ALIGNED(2) half {
    public:
      uint16_t half_;

      static TVM_XINLINE half Binary(uint16_t value) {
        half res;
        res.half_ = value;
        return res;
      }

      TVM_XINLINE half() {}

      TVM_XINLINE half(const float& value) { constructor(value); }
      TVM_XINLINE explicit half(const double& value) { constructor(value); }
      TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
      TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
      TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
      TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
      TVM_XINLINE explicit half(const long long& value) { constructor(value); }
      TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

      TVM_XINLINE operator float() const {                          \
        return float(half2float(half_));                            \
      }                                                             \
      TVM_XINLINE operator float() const volatile {                 \
        return float(half2float(half_));                            \
      }


      TVM_HALF_ASSIGNOP(+=, +)
      TVM_HALF_ASSIGNOP(-=, -)
      TVM_HALF_ASSIGNOP(*=, *)
      TVM_HALF_ASSIGNOP(/=, /)

      TVM_XINLINE half operator+() {
        return *this;
      }

      TVM_XINLINE half operator-() {
        return half(-float(*this));
      }

      TVM_XINLINE half operator=(const half& a) {
        half_ = a.half_;
        return a;
      }

      template<typename T>
      TVM_XINLINE half operator=(const T& a) {
        return *this = half(a);
      }

      TVM_XINLINE half operator=(const half& a) volatile {
        half_ = a.half_;
        return a;
      }

      template<typename T>
      TVM_XINLINE half operator=(const T& a) volatile {
        return *this = half(a);
      }

    private:
      union Bits {
        float f;
        int32_t si;
        uint32_t ui;
      };

      static int const fp16FractionBits = 10;
      static int const fp32FractionBits = 23;
      static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
      static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
      static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
      static int const shiftSign = 16;
      static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

      static int32_t const infN = 0x7F800000;   // flt32 infinity
      static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
      static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
      static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
      static int32_t const signN = 0x80000000;  // flt32 sign bit

      static int32_t const infC = infN >> shift;
      static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
      static int32_t const maxC = maxN >> shift;
      static int32_t const minC = minN >> shift;
      static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

      static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
      static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

      static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
      static int32_t const norC = 0x00400;  // min flt32 normal down shifted

      static int32_t const maxD = infC - maxC - 1;
      static int32_t const minD = minC - subC - 1;

      TVM_XINLINE uint16_t float2half(const float& value) const {
        Bits v;
        v.f = value;
        uint32_t sign = v.si & signN;    // grab sign bit
        v.si ^= sign;                    // clear sign bit from v
        sign >>= shiftSign;              // logical shift sign to fp16 position

        if (v.si <= maxZ) {
          // Handle eventual zeros here to ensure
          // vshift will not exceed 32 below.
          v.ui = 0;
        } else if (v.si < minN) {
          // Handle denorms
          uint32_t exp32 = v.ui >> fp32FractionBits;
          int32_t exp16 = exp32 - expAdjust;
          // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
          // Smaller (so negative) exp16 values should result in greater right shifts.
          uint32_t vshift = 1 - exp16;
          uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
          v.ui = significand >> vshift;
          v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
        } else if (v.si <= maxN) {
          // Handle norms
          v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
          v.ui -= expAdjust << fp32FractionBits;
        } else if (v.si <= infN) {
          v.si = infN;
        } else if (v.si < nanN) {
          v.si = nanN;
        }

        v.ui >>= shift;
        return sign | (v.ui & 0x7fff);
      }

      // Same as above routine, except for addition of volatile keyword
      TVM_XINLINE uint16_t float2half(
        const volatile float& value) const volatile {
        Bits v;
        v.f = value;
        uint32_t sign = v.si & signN;    // grab sign bit
        v.si ^= sign;                    // clear sign bit from v
        sign >>= shiftSign;              // logical shift sign to fp16 position

        if (v.si <= maxZ) {
          // Handle eventual zeros here to ensure
          // vshift will not exceed 32 below.
          v.ui = 0;
        } else if (v.si < minN) {
          // Handle denorms
          uint32_t exp32 = v.ui >> fp32FractionBits;
          int32_t exp16 = exp32 - expAdjust;
          // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
          // Smaller (so negative) exp16 values should result in greater right shifts.
          uint32_t vshift = 1 - exp16;
          uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
          v.ui = significand >> vshift;
          v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
        } else if (v.si <= maxN) {
          // Handle norms
          v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
          v.ui -= expAdjust << fp32FractionBits;
        } else if (v.si <= infN) {
          v.si = infN;
        } else if (v.si < nanN) {
          v.si = nanN;
        }

        v.ui >>= shift;
        return sign | (v.ui & 0x7fff);
      }

      TVM_XINLINE float half2float(const uint16_t& value) const {
        Bits v;
        v.ui = value;
        int32_t sign = v.si & signC;
        v.si ^= sign;
        sign <<= shiftSign;
        v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
        v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
        Bits s;
        s.si = mulC;
        s.f *= v.si;
        int32_t mask = -(norC > v.si);
        v.si <<= shift;
        v.si ^= (s.si ^ v.si) & mask;
        v.si |= sign;
        return v.f;
      }

      TVM_XINLINE float half2float(
        const volatile uint16_t& value) const volatile {
        Bits v;
        v.ui = value;
        int32_t sign = v.si & signC;
        v.si ^= sign;
        sign <<= shiftSign;
        v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
        v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
        Bits s;
        s.si = mulC;
        s.f *= v.si;
        int32_t mask = -(norC > v.si);
        v.si <<= shift;
        v.si ^= (s.si ^ v.si) & mask;
        v.si |= sign;
        return v.f;
      }

      template<typename T>
      TVM_XINLINE void constructor(const T& value) {
        half_ = float2half(float(value));
      }
    };

    TVM_HALF_OPERATOR(half, +)
    TVM_HALF_OPERATOR(half, -)
    TVM_HALF_OPERATOR(half, *)
    TVM_HALF_OPERATOR(half, /)
    TVM_HALF_OPERATOR(bool, >)
    TVM_HALF_OPERATOR(bool, <)
    TVM_HALF_OPERATOR(bool, >=)
    TVM_HALF_OPERATOR(bool, <=)

    TVM_XINLINE half __float2half_rn(const float a) {
      return half(a);
    }
    #endif


    // Pack two half values.
    static inline __device__ __host__ unsigned
    __pack_half2(const half x, const half y) {
      unsigned v0 = *((unsigned short *)&x);
      unsigned v1 = *((unsigned short *)&y);
      return (v1 << 16) | v0;
    }

    // fix undefined fp16 match function
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
    static inline __device__ __host__ half hpow(half x, half y) {
      float tmp_x = __half2float(x);
      float tmp_y = __half2float(y);
      float result = powf(tmp_x, tmp_y);
      return __float2half(result);
    }

    static inline __device__ __host__ half htanh(half x) {
      float tmp_x = __half2float(x);
      float result = tanhf(tmp_x);
      return __float2half(result);
    }
    #endif
    extern "C" __global__ void default_function_kernel0(half* __restrict__ A, 
half* __restrict__ B, float* __restrict__ compute) {
      float compute_local[8];
      __shared__ half A_shared[8448];
      __shared__ half B_shared[8192];
      half A_shared_local[16];
      half B_shared_local[128];
      for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
        compute_local[(j_c_init)] = 0.000000e+00f;
      }
      for (int k_outer = 0; k_outer < 2; ++k_outer) {
        __syncthreads();
        for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) {
            ((uint4*)(A_shared + ((((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(A + (((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0];
        }
        for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) {
            ((uint4*)(B_shared + (((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(B + (((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0];
        }
        __syncthreads();
        for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) {
          for (int ax1 = 0; ax1 < 16; ++ax1) {
            A_shared_local[(ax1)] = A_shared[((((((int)threadIdx.y) * 264) + (k_inner_outer * 16)) + ax1))];
          }
          for (int ax0 = 0; ax0 < 16; ++ax0) {
            for (int ax11 = 0; ax11 < 8; ++ax11) {
              B_shared_local[(((ax0 * 8) + ax11))] = B_shared[((((((k_inner_outer * 512) + (ax0 * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)) + ax11))];
            }
          }
          for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
            for (int j_c = 0; j_c < 8; ++j_c) {
              compute_local[(j_c)] = (compute_local[(j_c)] + (((float)A_shared_local[(k_inner_inner)]) * ((float)B_shared_local[(((k_inner_inner * 8) + j_c))])));
            }
          }
        }
      }
      for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
        compute[((((((((int)threadIdx.y) * 512) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
      }
    }


    Time cost of this operator: 0.000041





