LiSsHhUuAaIi opened a new issue, #18476:
URL: https://github.com/apache/tvm/issues/18476

   ### Description
   When converting a PyTorch model that contains sparse matrix multiplication (`torch.sparse.mm`) and random number generation (`torch.randn`) to a TVM Relax module via `torch.export`, an `AssertionError` is raised. The TVM PyTorch frontend currently does not support the `_sparse_mm.default` and `randn.default` operations.
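
For reference when adding a converter: `torch.sparse.mm` with a COO sparse left operand and a dense right operand returns a dense matrix. Each stored entry `A[r, c]` adds `A[r, c] * B[c, :]` to output row `r`. The sketch below is a plain-Python reference for that semantics only. It assumes the COO layout used by `torch.sparse_coo_tensor`, with row indices in `indices[0]` and column indices in `indices[1]`. `coo_sparse_mm` is a hypothetical helper, not a TVM or PyTorch API. A Relax lowering could plausibly follow the same gather, multiply, and scatter-add pattern, but that is an assumption, not how the frontend works today.

```python
from typing import List, Sequence


def coo_sparse_mm(
    indices: Sequence[Sequence[int]],
    values: Sequence[float],
    shape: Sequence[int],
    dense: List[List[float]],
) -> List[List[float]]:
    """Hypothetical reference: multiply a COO sparse matrix (rows x k) by a dense (k x n) matrix.

    indices[0] holds row indices and indices[1] holds column indices; values[i]
    is the entry at (indices[0][i], indices[1][i]). Duplicate coordinates are
    summed, as for an uncoalesced COO tensor.
    """
    rows, k = shape
    if len(dense) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(dense)}")
    n = len(dense[0]) if dense else 0
    out = [[0.0] * n for _ in range(rows)]
    # Each stored entry A[r, c] contributes A[r, c] * dense[c, :] to output row r.
    for r, c, v in zip(indices[0], indices[1], values):
        dense_row = dense[c]
        out_row = out[r]
        for j in range(n):
            out_row[j] += v * dense_row[j]
    return out
```

This example uses the issue's `indices` and `values` with a smaller `(3, 3)` shape and `dense = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]`. The result is `[[5.0, 6.0], [2.0, 4.0], [9.0, 12.0]]`.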
   
   ### Expected behavior
   A PyTorch model that uses sparse operations and random number generation should convert successfully to a TVM Relax module. This would allow deployment of models that rely on sparse computation and stochastic components.
   
   ### Actual behavior
   `from_exported_program` raises an `AssertionError` with the message `Unsupported function types ['_sparse_mm.default', 'randn.default']`. This indicates that TVM's PyTorch frontend lacks support for these operations.
   ```
   AssertionError: Unsupported function types ['_sparse_mm.default', 'randn.default']
   ```
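
The message suggests that the frontend collects the target name of every `call_function` node in the exported graph. It then appears to reject the whole program if any name is missing from its conversion table. The sketch below is a minimal, equivalent pre-check that lists every unsupported op in a model up front. `find_unsupported_ops` and the `supported` set are hypothetical. Reading names through `node.target.__name__` is an assumption about how PyTorch names `OpOverload` targets, though it matches the `randn.default` style in the error.

```python
from typing import Iterable, List, Set


def find_unsupported_ops(op_names: Iterable[str], supported: Set[str]) -> List[str]:
    """Hypothetical pre-check: return sorted, de-duplicated op names missing from `supported`."""
    return sorted({name for name in op_names if name not in supported})


# Assumption: with PyTorch installed, the op names of an exported program can be
# collected roughly like this, since aten OpOverload targets appear to expose
# names such as "randn.default" via __name__:
#
#   names = [
#       node.target.__name__
#       for node in exported_program.graph.nodes
#       if node.op == "call_function"
#   ]
```

Sorting the set gives a stable report. The original error message lists the names in no guaranteed order.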
   ### Environment
   
   - OS: Ubuntu 20.04.6 LTS
   - TVM version: 0.23.dev0
   - Python version: 3.11.14
   
   ### Steps to reproduce
   
   ```python
   import torch
   import torch.nn as nn
   import tvm
   from tvm import relax
   
   class MinimalSparseAndRandomModel(nn.Module):
       def __init__(self):
           super().__init__()
   
       def forward(self, sparse_input):
           # Unsupported operations
           random_matrix = torch.randn(100, 50)  # randn.default
           result = torch.sparse.mm(sparse_input, random_matrix)  # _sparse_mm.default
           return result
   
   model = MinimalSparseAndRandomModel()
   model.eval()
   
   indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
   values = torch.tensor([1.0, 2.0, 3.0])
   sparse_input = torch.sparse_coo_tensor(indices, values, size=(3, 100))
   
   # PyTorch execution works
   with torch.no_grad():
       output = model(sparse_input)
   
   # PyTorch export works
   exported_program = torch.export.export(model, (sparse_input,))
   
   # TVM conversion fails
   from tvm.relax.frontend.torch import from_exported_program
   mod = from_exported_program(exported_program)  # AssertionError here
   ```
   
   ### Triage
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

