This works nicely, and it is trivial to implement. Here's a complete example for reference:

```
import torch
import topi
import tvm
from tvm import te
from tvm.contrib import dlpack


def _codegen_function(d1, d2, name):
    bsz = te.var('bsz')  # bsz and d3 can be variables without impact on performance
    d3 = te.var('d3')    # but d1 and d2 should be constants for `schedule_batch_matmul` to work
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')
    R = topi.nn.batch_matmul(A, B)
    s = topi.cuda.batch_matmul.schedule_batch_matmul(R)
    return tvm.lower(s, [A, B, R], name=name)


if __name__ == "__main__":
    bsz = 12
    d11 = 2048
    d12 = 1024
    d2 = 512
    d3 = 64

    # 2 different versions of the same function
    bmm1 = _codegen_function(d11, d2, 'bmm1')
    bmm2 = _codegen_function(d12, d2, 'bmm2')

    # build both functions into one module
    module = tvm.build([bmm1, bmm2], target='cuda', target_host='llvm')
    module.export_library('libbmm.so')            # save the module into a .so file
    module = tvm.runtime.load_module('libbmm.so')  # load it back

    # get each function then package it as a pytorch function
    bmm1_pytorch = dlpack.to_pytorch_func(module['bmm1'])
    bmm2_pytorch = dlpack.to_pytorch_func(module['bmm2'])

    A1 = torch.randn(bsz, d11, d3, device='cuda')
    A2 = torch.randn(bsz, d12, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    R1 = B.new_empty(bsz, d11, d2)  # allocate memory for the result tensor
    R2 = B.new_empty(bsz, d12, d2)  # allocate memory for the result tensor

    bmm1_pytorch(A1, B, R1)
    print(R1.sum())
    bmm2_pytorch(A2, B, R2)
    print(R2.sum())
```
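As a quick sanity check of the wrapped kernels, the results can be compared against PyTorch's own batch matmul. Note that `topi.nn.batch_matmul` computes `A @ B^T` (both inputs share the last dimension `d3`), so the PyTorch reference transposes `B`. The snippet below is a minimal sketch meant to be appended to the `__main__` block above; it is not from the original post, and the tolerance value is an assumption.

```
    # Sanity check (assumes the script above has run and R1/R2 are filled in).
    # topi.nn.batch_matmul computes A @ B^T, so transpose B for the reference.
    ref1 = torch.matmul(A1, B.transpose(1, 2))  # (bsz, d11, d2)
    ref2 = torch.matmul(A2, B.transpose(1, 2))  # (bsz, d12, d2)
    print(torch.allclose(R1, ref1, atol=1e-4))  # expect True up to float32 rounding
    print(torch.allclose(R2, ref2, atol=1e-4))
```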