gemini-code-assist[bot] commented on code in PR #18499:
URL: https://github.com/apache/tvm/pull/18499#discussion_r2558522075
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -34,6 +34,112 @@ class ExportedProgramImporter(BaseFXGraphImporter):
from torch import fx
+    @staticmethod
+    def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
+        """Convert a PyTorch tensor to a TVM tensor, handling sparse tensors.
+
+        Parameters
+        ----------
+        tensor_value : torch.Tensor
+            The PyTorch tensor to convert.
+
+        Returns
+        -------
+        tvm.runtime.Tensor
+            The converted TVM tensor.
+        """
+        # PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
+        if tensor_value.layout != torch.strided:
+            tensor_to_convert = tensor_value.to_dense()
+        else:
+            tensor_to_convert = tensor_value
+        tensor_detached = tensor_to_convert.detach()
+
+        # Try DLPack conversion first (faster).
+        try:
+            return tvm.runtime.from_dlpack(tensor_detached)
+        except (RuntimeError, BufferError):
+            # Fallback: convert to numpy and then to a TVM tensor.
+            # This handles cases where DLPack conversion fails.
+            tensor_cpu = tensor_detached.cpu().contiguous()
+            return tvm.runtime.tensor(tensor_cpu.numpy())
+
+    def _sparse_mm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse matrix multiplication by converting the sparse tensor to dense."""
+        args = self.retrieve_args(node)
+        sparse_input = args[0]
+        dense_input = args[1]
+
+        # sparse_input has already been densified in _convert_pytorch_tensor_to_tvm,
+        # so a regular matrix multiplication is sufficient here.
+        return self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_input, dense_input, out_dtype="float32")
+        )
+
+    def _sparse_addmm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, mat2))."""
+        args = self.retrieve_args(node)
+        input_tensor = args[0]  # beta * input
+        sparse_mat1 = args[1]  # sparse matrix
+        dense_mat2 = args[2]  # dense matrix
+        alpha = node.kwargs.get("alpha", 1.0) if hasattr(node, "kwargs") else 1.0
+        beta = node.kwargs.get("beta", 1.0) if hasattr(node, "kwargs") else 1.0
Review Comment:

The `hasattr(node, "kwargs")` check is redundant. `torch.fx.Node` objects
always have a `kwargs` attribute, which is a dictionary. You can simplify this
by directly using `node.kwargs.get()`.
```python
alpha = node.kwargs.get("alpha", 1.0)
beta = node.kwargs.get("beta", 1.0)
```
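
For context, here is a minimal, self-contained sketch (the `_Demo` module below is illustrative, not part of this PR) showing that traced FX nodes always expose `kwargs` as a plain dict, so `.get()` with a default works without a `hasattr` guard:
```python
import torch
from torch import fx


class _Demo(torch.nn.Module):
    def forward(self, inp, mat1, mat2):
        # addmm carries alpha/beta as keyword arguments on the traced node
        return torch.addmm(inp, mat1, mat2, alpha=2.0, beta=0.5)


traced = fx.symbolic_trace(_Demo())
for node in traced.graph.nodes:
    if node.op == "call_function":
        # node.kwargs is always a dict (possibly empty), never absent
        alpha = node.kwargs.get("alpha", 1.0)
        beta = node.kwargs.get("beta", 1.0)
        print(node.target, alpha, beta)
```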