My machine has a V100 with CUDA 10.2. When I run the auto TensorCore codegen tutorial, the generated kernel does not use tensor core primitives (wmma) at all; it is just ordinary CUDA code. Can anyone help me figure out what is going wrong? Thanks! Here is the code I am trying to run:
```python
import logging
import sys

import numpy as np
import tvm
from tvm import te
import tvm.testing
from tvm import autotvm
from tvm.contrib import nvcc


def matmul_nn(A, B, L, dtype="float16", layout="NN"):
    k = te.reduce_axis((0, L), name="k")
    if dtype == "float16":
        out_type = "float"
    elif dtype == "int8":
        out_type = "int"
    elif dtype == "int4" or dtype == "int1":
        out_type = "int"
    if layout == "NN":
        return te.compute(
            (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k)
        )
    if layout == "NT":
        return te.compute(
            (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[k, j].astype(out_type), axis=k)
        )
    if layout == "TN":
        return te.compute(
            (N, M), lambda i, j: te.sum(A[i, k].astype(out_type) * B[j, k].astype(out_type), axis=k)
        )
    if layout == "TT":
        return te.compute(
            (N, M), lambda i, j: te.sum(A[k, i].astype(out_type) * B[j, k].astype(out_type), axis=k)
        )


def test_gemm(N, L, M, dtype, layout):
    if layout == "NN":
        shape_a = (N, L)
        shape_b = (L, M)
    elif layout == "NT":
        shape_a = (L, N)
        shape_b = (L, M)
    elif layout == "TN":
        shape_a = (N, L)
        shape_b = (M, L)
    elif layout == "TT":
        shape_a = (L, N)
        shape_b = (M, L)
    else:
        print("Unsupported layout:", layout)
        sys.exit(1)
    A = te.placeholder(shape_a, name="A", dtype=dtype)
    B = te.placeholder(shape_b, name="B", dtype=dtype)
    C = matmul_nn(A, B, L, dtype, layout)

    s = te.create_schedule(C.op)
    y, x = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    # storage_align params
    factor = 16
    offset = 8
    if dtype == "int8":
        factor = 32
        offset = 16
    elif dtype == "int4":
        factor = 64
        offset = 32
    elif dtype == "int1":
        factor = 256
        offset = 128

    # create cache stages
    AA = s.cache_read(A, "shared", [C])
    if layout == "NN" or layout == "TN":
        s[AA].storage_align(AA.op.axis[0], factor, offset)
    AL = s.cache_read(AA, "local", [C])
    BB = s.cache_read(B, "shared", [C])
    if layout == "TT" or layout == "NT":
        s[BB].storage_align(BB.op.axis[0], factor, offset)
    BL = s.cache_read(BB, "local", [C])
    CL = s.cache_write(C, "local")

    bx = 4
    by = 32
    step_k = 16
    v = 8

    # thread tile
    TX = 8
    TY = 1
    if dtype == "int4" or dtype == "int1":
        TX = 2
    # warp tile
    warp_tile_m = 16  # it could also be 8 or 32 on CUDA version >= 10.0
    warp_tile_k = 16  # it must be 16 for fp16/int8 data type
    if dtype == "int4":
        warp_tile_m = 8
        warp_tile_k = 32
    elif dtype == "int1":
        warp_tile_m = 8
        warp_tile_k = 128
    # block tile
    tile_x = bx * TX
    tile_y = by * TY

    yo, ty = s[C].split(y, tile_y)
    ty, yi = s[C].split(ty, TY)

    # schedule for C stage
    xo, xi = s[C].split(x, tile_x)
    WX = min(warp_tile_m, tile_x)
    tz, xi = s[C].split(xi, WX)
    tx, xi = s[C].split(xi, TX)
    s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
    s[C].bind(yo, te.thread_axis("blockIdx.y"))
    s[C].bind(xo, te.thread_axis("blockIdx.x"))
    s[C].bind(ty, te.thread_axis("threadIdx.y"))
    s[C].bind(tz, te.thread_axis("threadIdx.z"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))

    # schedule for CL stage
    ko, ki = s[CL].split(k, step_k * warp_tile_k)
    kl, ki = s[CL].split(ki, warp_tile_k)
    s[CL].compute_at(s[C], tx)
    yo, xo = CL.op.axis
    s[CL].reorder(ko, kl, ki, yo, xo)

    # schedule for AA stage
    s[AA].compute_at(s[CL], ko)
    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
    tx, vec = s[AA].split(tx, factor=v)
    fused = s[AA].fuse(s[AA].op.axis[0], xo)
    _, ty = s[AA].split(fused, factor=by)
    s[AA].bind(ty, te.thread_axis("threadIdx.y"))
    s[AA].bind(tz, te.thread_axis("threadIdx.z"))
    s[AA].bind(tx, te.thread_axis("threadIdx.x"))
    # vectorization is very important for float16/int8 inputs
    s[AA].vectorize(vec)

    # schedule for BB stage
    s[BB].compute_at(s[CL], ko)
    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
    tx, vec = s[BB].split(tx, factor=v)
    fused = s[BB].fuse(s[BB].op.axis[0], xo)
    _, ty = s[BB].split(fused, factor=by)
    s[BB].bind(ty, te.thread_axis("threadIdx.y"))
    s[BB].bind(tz, te.thread_axis("threadIdx.z"))
    s[BB].bind(tx, te.thread_axis("threadIdx.x"))
    s[BB].vectorize(vec)

    s[AL].compute_at(s[CL], kl)
    s[BL].compute_at(s[CL], kl)

    # set the 'tensor_core' pragma for tensorcore codegen
    s[CL].pragma(ko, "tensor_core")

    return s, [A, B, C]


ctx = tvm.gpu()
if not nvcc.have_tensorcore(ctx.compute_version):
    raise Exception("the gpu has no tensorcore, skipping...")

M, N, L = 512, 32, 512
dtype = "float16"
layout = "NN"
if len(sys.argv) >= 4:
    M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
if len(sys.argv) >= 5:
    dtype = sys.argv[4]
if len(sys.argv) >= 6:
    layout = sys.argv[5]

# check whether the current gpu arch supports the current dtype's wmma codegen
cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
major, minor = nvcc.parse_compute_version(cuda_compute_capability)
if dtype == "int8":
    assert major == 7 and minor >= 2
elif dtype == "int4" or dtype == "int1":
    # int4/int1 only support layout TN
    assert major == 7 and minor == 5 and layout == "TN"


def evaluate(M, N, L, dtype, layout):
    with tvm.target.Target("cuda"):
        s, arg_bufs = test_gemm(N, L, M, dtype, layout)
        print(tvm.lower(s, arg_bufs, simple_mode=True))
        func = tvm.build(s, arg_bufs)
    dev_module = func.imported_modules[0]
    print(dev_module.get_source())

    # check correctness
    if layout == "NN":
        shape_a = (N, L)
        shape_b = (L, M)
    elif layout == "NT":
        shape_a = (L, N)
        shape_b = (L, M)
    elif layout == "TN":
        shape_a = (N, L)
        shape_b = (M, L)
    elif layout == "TT":
        shape_a = (L, N)
        shape_b = (M, L)

    a_np = None
    b_np = None
    c_np = None
    c_np_type = None
    if dtype == "float16":
        c_np_type = np.float32
        a_np = np.random.uniform(size=shape_a).astype(np.float16)
        b_np = np.random.uniform(size=shape_b).astype(np.float16)
        if layout == "NN":
            c_np = np.dot(a_np, b_np)
        elif layout == "NT":
            c_np = np.dot(a_np.T, b_np)
        elif layout == "TN":
            c_np = np.dot(a_np, b_np.T)
        elif layout == "TT":
            c_np = np.dot(a_np.T, b_np.T)
    elif dtype == "int8":
        c_np_type = np.int32
        a_np = np.random.randint(low=-128, high=127, size=shape_a).astype(np.int8)
        b_np = np.random.randint(low=-128, high=127, size=shape_b).astype(np.int8)
        if layout == "NN":
            c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32))
        elif layout == "NT":
            c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32))
        elif layout == "TN":
            c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
        elif layout == "TT":
            c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
    elif dtype == "int4":
        c_np_type = np.int32
        a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
        b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
        # "TN"
        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
        a_np = np.zeros(shape=(N, int(L / 8)), dtype=np.int32)
        b_np = np.zeros(shape=(M, int(L / 8)), dtype=np.int32)
        # a_np --> col_major
        for i in range(N):
            for j in range(int(L / 8)):
                for k in range(8):
                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))

        # b_np --> row_major
        for i in range(M):
            for j in range(int(L / 8)):
                for k in range(8):
                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xF) << ((7 - k) * 4))
    elif dtype == "int1":
        c_np_type = np.int32
        a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
        b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
        # "TN"
        c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
        a_np = np.zeros(shape=(N, int(L / 32)), dtype=np.int32)
        b_np = np.zeros(shape=(M, int(L / 32)), dtype=np.int32)
        for i in range(N):
            for j in range(int(L / 32)):
                for k in range(32):
                    a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xF) << (31 - k))

        for i in range(M):
            for j in range(int(L / 32)):
                for k in range(32):
                    b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xF) << (31 - k))

    c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
    a_tvm = tvm.nd.array(a_np, ctx=ctx)
    b_tvm = tvm.nd.array(b_np, ctx=ctx)
    func(a_tvm, b_tvm, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)

    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
    print("Time cost of this operator: %f" % evaluator(a_tvm, b_tvm, c_tvm).mean)


evaluate(M, N, L, dtype, layout)
```
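To make the symptom concrete, here is a small sanity check I run after the script. It is not part of the tutorial; `uses_tensorcore`, `s_chk`, `args_chk` and `func_chk` are just names I made up for this sketch. It simply scans the generated device source for wmma intrinsics (as far as I understand, TVM's tensor core codegen emits `nvcuda::wmma` calls), and on my machine it reports that none are present, which matches the plain CUDA kernel in the output below.

```python
# Hypothetical sanity check (not part of the tutorial): report whether the
# CUDA source produced by tvm.build() contains any tensor core (wmma) intrinsics.
def uses_tensorcore(cuda_src):
    # A kernel generated through the tensor_core pass should contain
    # nvcuda::wmma::* calls; the plain fallback kernel contains none.
    return "wmma" in cuda_src


with tvm.target.Target("cuda"):
    s_chk, args_chk = test_gemm(N, L, M, dtype, layout)
    func_chk = tvm.build(s_chk, args_chk)

print("wmma intrinsics in generated source:",
      uses_tensorcore(func_chk.imported_modules[0].get_source()))
```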
int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2; B.shared[ramp(((((ax0.ax1.outer.fused.outer_1*1024) + (threadIdx.y_2*32)) + (threadIdx.z_2*16)) + (threadIdx.x_2*8)), 1, 8)] = (float16x8*)B_2[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer_1*16384)) + (threadIdx.y_2*512)) + (blockIdx.x*32)) + (threadIdx.z_2*16)) + (threadIdx.x_2*8)), 1, 8)] } for (k.inner.outer: int32, 0, 16) { for (ax1: int32, 0, 16) { A.shared.local[ax1] = (float16*)A.shared[(((threadIdx.y*264) + (k.inner.outer*16)) + ax1)] } for (ax0: int32, 0, 16) { for (ax1_1: int32, 0, 8) { B.shared.local[((ax0*8) + ax1_1)] = (float16*)B.shared[(((((k.inner.outer*512) + (ax0*32)) + (threadIdx.z*16)) + (threadIdx.x*8)) + ax1_1)] } } for (k.inner.inner: int32, 0, 16) { for (j.c: int32, 0, 8) { compute.local[j.c] = ((float32*)compute.local[j.c] + (cast(float32, (float16*)A.shared.local[k.inner.inner])*cast(float32, (float16*)B.shared.local[((k.inner.inner*8) + j.c)]))) } } } } for (j.inner.inner.inner: int32, 0, 8) { compute_2[(((((threadIdx.y*512) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = (float32*)compute.local[j.inner.inner.inner] } } } #[metadata] { "root": 1, "nodes": [ { "type_key": "" }, { "type_key": "Map", "keys": [ "IntImm" ], "data": [2] }, { "type_key": "Array", "data": [3] }, { "type_key": "IntImm", "attrs": { "dtype": "bool", "value": "1" } } ], "b64ndarrays": [], "attrs": {"tvm_version": "0.8.dev0"} } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) #include <cuda_fp16.h> __device__ half max(half a, half b) { return __hgt(__half(a), __half(b)) ? a : b; } __device__ half min(half a, half b) { return __hlt(__half(a), __half(b)) ? a : b; } #else typedef unsigned short uint16_t; typedef unsigned char uint8_t; typedef signed char int8_t; typedef int int32_t; typedef unsigned long long uint64_t; typedef unsigned int uint32_t; #define TVM_FORCE_INLINE inline __attribute__((always_inline)) #define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__ #define TVM_ALIGNED(x) __attribute__ ((aligned(x))) #define TVM_HALF_OPERATOR(RTYPE, OP) \ TVM_XINLINE RTYPE operator OP (half a, half b) { \ return RTYPE(float(a) OP float(b)); \ } \ template<typename T> \ TVM_XINLINE RTYPE operator OP (half a, T b) { \ return RTYPE(float(a) OP float(b)); \ } \ template<typename T> \ TVM_XINLINE RTYPE operator OP (T a, half b) { \ return RTYPE(float(a) OP float(b)); \ } #define TVM_HALF_ASSIGNOP(AOP, OP) \ template<typename T> \ TVM_XINLINE half operator AOP (const T& a) { \ return *this = half(float(*this) OP float(a)); \ } \ template<typename T> \ TVM_XINLINE half operator AOP (const volatile T& a) volatile { \ return *this = half(float(*this) OP float(a)); \ } class TVM_ALIGNED(2) half { public: uint16_t half_; static TVM_XINLINE half Binary(uint16_t value) { half res; res.half_ = value; return res; } TVM_XINLINE half() {} TVM_XINLINE half(const float& value) { constructor(value); } TVM_XINLINE explicit half(const double& value) { constructor(value); } TVM_XINLINE explicit half(const int8_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } TVM_XINLINE explicit half(const long long& value) { constructor(value); } TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } TVM_XINLINE operator float() const { \ return float(half2float(half_)); \ } \ TVM_XINLINE operator 
float() const volatile { \ return float(half2float(half_)); \ } TVM_HALF_ASSIGNOP(+=, +) TVM_HALF_ASSIGNOP(-=, -) TVM_HALF_ASSIGNOP(*=, *) TVM_HALF_ASSIGNOP(/=, /) TVM_XINLINE half operator+() { return *this; } TVM_XINLINE half operator-() { return half(-float(*this)); } TVM_XINLINE half operator=(const half& a) { half_ = a.half_; return a; } template<typename T> TVM_XINLINE half operator=(const T& a) { return *this = half(a); } TVM_XINLINE half operator=(const half& a) volatile { half_ = a.half_; return a; } template<typename T> TVM_XINLINE half operator=(const T& a) volatile { return *this = half(a); } private: union Bits { float f; int32_t si; uint32_t ui; }; static int const fp16FractionBits = 10; static int const fp32FractionBits = 23; static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000 static int const shift = fp32FractionBits - fp16FractionBits; // == 13 static int const shiftSign = 16; static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15) static int32_t const infN = 0x7F800000; // flt32 infinity static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16 static int32_t const signN = 0x80000000; // flt32 sign bit static int32_t const infC = infN >> shift; static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 static int32_t const maxC = maxN >> shift; static int32_t const minC = minN >> shift; static int32_t const signC = signN >> shiftSign; // flt16 sign bit static int32_t const mulN = 0x52000000; // (1 << 23) / minN static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted static int32_t const norC = 0x00400; // min flt32 normal down shifted static int32_t const maxD = infC - maxC - 1; static int32_t const minD = minC - subC - 1; TVM_XINLINE uint16_t float2half(const float& value) const { Bits v; v.f = value; uint32_t sign = v.si & signN; // grab sign bit v.si ^= sign; // clear sign bit from v sign >>= shiftSign; // logical shift sign to fp16 position if (v.si <= maxZ) { // Handle eventual zeros here to ensure // vshift will not exceed 32 below. v.ui = 0; } else if (v.si < minN) { // Handle denorms uint32_t exp32 = v.ui >> fp32FractionBits; int32_t exp16 = exp32 - expAdjust; // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. // Smaller (so negative) exp16 values should result in greater right shifts. uint32_t vshift = 1 - exp16; uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); v.ui = significand >> vshift; v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; } else if (v.si <= maxN) { // Handle norms v.ui += (v.ui & 0x3fff) != 0x1000 ? 
0x1000 : 0; v.ui -= expAdjust << fp32FractionBits; } else if (v.si <= infN) { v.si = infN; } else if (v.si < nanN) { v.si = nanN; } v.ui >>= shift; return sign | (v.ui & 0x7fff); } // Same as above routine, except for addition of volatile keyword TVM_XINLINE uint16_t float2half( const volatile float& value) const volatile { Bits v; v.f = value; uint32_t sign = v.si & signN; // grab sign bit v.si ^= sign; // clear sign bit from v sign >>= shiftSign; // logical shift sign to fp16 position if (v.si <= maxZ) { // Handle eventual zeros here to ensure // vshift will not exceed 32 below. v.ui = 0; } else if (v.si < minN) { // Handle denorms uint32_t exp32 = v.ui >> fp32FractionBits; int32_t exp16 = exp32 - expAdjust; // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. // Smaller (so negative) exp16 values should result in greater right shifts. uint32_t vshift = 1 - exp16; uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); v.ui = significand >> vshift; v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; } else if (v.si <= maxN) { // Handle norms v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0; v.ui -= expAdjust << fp32FractionBits; } else if (v.si <= infN) { v.si = infN; } else if (v.si < nanN) { v.si = nanN; } v.ui >>= shift; return sign | (v.ui & 0x7fff); } TVM_XINLINE float half2float(const uint16_t& value) const { Bits v; v.ui = value; int32_t sign = v.si & signC; v.si ^= sign; sign <<= shiftSign; v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); Bits s; s.si = mulC; s.f *= v.si; int32_t mask = -(norC > v.si); v.si <<= shift; v.si ^= (s.si ^ v.si) & mask; v.si |= sign; return v.f; } TVM_XINLINE float half2float( const volatile uint16_t& value) const volatile { Bits v; v.ui = value; int32_t sign = v.si & signC; v.si ^= sign; sign <<= shiftSign; v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); Bits s; s.si = mulC; s.f *= v.si; int32_t mask = -(norC > v.si); v.si <<= shift; v.si ^= (s.si ^ v.si) & mask; v.si |= sign; return v.f; } template<typename T> TVM_XINLINE void constructor(const T& value) { half_ = float2half(float(value)); } }; TVM_HALF_OPERATOR(half, +) TVM_HALF_OPERATOR(half, -) TVM_HALF_OPERATOR(half, *) TVM_HALF_OPERATOR(half, /) TVM_HALF_OPERATOR(bool, >) TVM_HALF_OPERATOR(bool, <) TVM_HALF_OPERATOR(bool, >=) TVM_HALF_OPERATOR(bool, <=) TVM_XINLINE half __float2half_rn(const float a) { return half(a); } #endif // Pack two half values. 
static inline __device__ __host__ unsigned __pack_half2(const half x, const half y) { unsigned v0 = *((unsigned short *)&x); unsigned v1 = *((unsigned short *)&y); return (v1 << 16) | v0; } // fix undefined fp16 match function #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) static inline __device__ __host__ half hpow(half x, half y) { float tmp_x = __half2float(x); float tmp_y = __half2float(y); float result = powf(tmp_x, tmp_y); return __float2half(result); } static inline __device__ __host__ half htanh(half x) { float tmp_x = __half2float(x); float result = tanhf(tmp_x); return __float2half(result); } #endif extern "C" __global__ void default_function_kernel0(half* __restrict__ A, half* __restrict__ B, float* __restrict__ compute) { float compute_local[8]; __shared__ half A_shared[8448]; __shared__ half B_shared[8192]; half A_shared_local[16]; half B_shared_local[128]; for (int j_c_init = 0; j_c_init < 8; ++j_c_init) { compute_local[(j_c_init)] = 0.000000e+00f; } for (int k_outer = 0; k_outer < 2; ++k_outer) { __syncthreads(); for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) { ((uint4*)(A_shared + ((((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(A + (((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0]; } for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) { ((uint4*)(B_shared + (((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(B + (((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)))))[0]; } __syncthreads(); for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) { for (int ax1 = 0; ax1 < 16; ++ax1) { A_shared_local[(ax1)] = A_shared[((((((int)threadIdx.y) * 264) + (k_inner_outer * 16)) + ax1))]; } for (int ax0 = 0; ax0 < 16; ++ax0) { for (int ax11 = 0; ax11 < 8; ++ax11) { B_shared_local[(((ax0 * 8) + ax11))] = B_shared[((((((k_inner_outer * 512) + (ax0 * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)) + ax11))]; } } for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) { for (int j_c = 0; j_c < 8; ++j_c) { compute_local[(j_c)] = (compute_local[(j_c)] + (((float)A_shared_local[(k_inner_inner)]) * ((float)B_shared_local[(((k_inner_inner * 8) + j_c))]))); } } } } for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) { compute[((((((((int)threadIdx.y) * 512) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)]; } } Time cost of this operator: 0.000041 --- [Visit Topic](https://discuss.tvm.apache.org/t/tutorial-how-to-optimize-matmul-with-auto-tensorcore-codegen-cannot-work-on-my-machien/8345/1) to respond. You are receiving this because you enabled mailing list mode. To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/eb695657bdade2d4430ffa3d11dcd99d8bc39d09d00495506ed3330f4f633c7f).
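For completeness, this is how I check which device and compute capability TVM actually sees on my side (a minimal diagnostic sketch, using only the APIs already imported in the script above; a V100 should report compute capability "7.0", for which `nvcc.have_tensorcore()` returns True):

```python
import tvm
from tvm.contrib import nvcc

# Diagnostic sketch: print the CUDA device TVM sees and whether it is
# recognized as having tensor cores.
ctx = tvm.gpu(0)
print("GPU found:", ctx.exist)
print("compute version:", ctx.compute_version)
print("have_tensorcore:", nvcc.have_tensorcore(ctx.compute_version))
```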