This works nicely, and it is trivial to implement. Here's a complete example for reference:
```
import torch
import topi
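# note: on newer TVM releases topi lives under the tvm namespace
# (i.e. `from tvm import topi`); plain `import topi` matches older builds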
import tvm
from tvm import te
from tvm.contrib import dlpack

def _codegen_function(d1, d2, name):
    bsz = te.var('bsz')  # bsz and d3 can be variables without impact on performance,
    d3 = te.var('d3')    # but d1 and d2 should be constants for `schedule_batch_matmul` to work
    A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')
    B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')
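    # topi.nn.batch_matmul computes A @ B^T, which is why B has shape (bsz, d2, d3)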
    R = topi.nn.batch_matmul(A, B)
    s = topi.cuda.batch_matmul.schedule_batch_matmul(R)
    return tvm.lower(s, [A, B, R], name=name)

if __name__ == "__main__":
    bsz = 12
    d11 = 2048
    d12 = 1024
    d2 = 512
    d3 = 64

    # two specializations of the same function, one per value of d1
    bmm1 = _codegen_function(d11, d2, 'bmm1')
    bmm2 = _codegen_function(d12, d2, 'bmm2')

    # build both functions into one module
    module = tvm.build([bmm1, bmm2], target='cuda', target_host='llvm')

    module.export_library('libbmm.so')  # save the module into a .so file
    module = tvm.runtime.load_module('libbmm.so')  # load it back
    # get each function, then wrap it as a PyTorch function
    bmm1_pytorch = dlpack.to_pytorch_func(module['bmm1'])
    bmm2_pytorch = dlpack.to_pytorch_func(module['bmm2'])

    A1 = torch.randn(bsz, d11, d3, device='cuda')
    A2 = torch.randn(bsz, d12, d3, device='cuda')
    B = torch.randn(bsz, d2, d3, device='cuda')
    R1 = B.new_empty(bsz, d11, d2)  # pre-allocate memory for the result tensor
    R2 = B.new_empty(bsz, d12, d2)  # pre-allocate memory for the result tensor

    bmm1_pytorch(A1, B, R1)
    print(R1.sum())

    bmm2_pytorch(A2, B, R2)
    print(R2.sum())
```
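
Since `topi.nn.batch_matmul` multiplies `A` by the transpose of `B`, the matching PyTorch reference is `torch.bmm(A, B.transpose(1, 2))`. As a quick sanity check (my addition, not part of the snippet above), something like this can be run right after the script:
```
# compare the TVM results against PyTorch's own batched matmul
ref1 = torch.bmm(A1, B.transpose(1, 2))
print(torch.allclose(R1, ref1, atol=1e-3))  # expect True within fp32 tolerance

ref2 = torch.bmm(A2, B.transpose(1, 2))
print(torch.allclose(R2, ref2, atol=1e-3))
```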