gemini-code-assist[bot] commented on code in PR #18499:
URL: https://github.com/apache/tvm/pull/18499#discussion_r2558522075


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -34,6 +34,112 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     from torch import fx
 
+    @staticmethod
+    def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Tensor:
+        """Convert a PyTorch tensor to TVM tensor, handling sparse tensors.
+
+        Parameters
+        ----------
+        tensor_value : torch.Tensor
+            The PyTorch tensor to convert.
+
+        Returns
+        -------
+        tvm.runtime.Tensor
+            The converted TVM tensor.
+        """
+        # PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
+        if tensor_value.layout != torch.strided:
+            tensor_to_convert = tensor_value.to_dense()
+        else:
+            tensor_to_convert = tensor_value
+        tensor_detached = tensor_to_convert.detach()
+
+        # Try DLPack conversion first (faster)
+        try:
+            return tvm.runtime.from_dlpack(tensor_detached)
+        except (RuntimeError, BufferError):
+            # Fallback: convert to numpy and then to TVM tensor
+            # This handles cases where DLPack conversion fails
+            tensor_cpu = tensor_detached.cpu().contiguous()
+            return tvm.runtime.tensor(tensor_cpu.numpy())
+
+    def _sparse_mm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse matrix multiplication by converting sparse tensor to dense."""
+        args = self.retrieve_args(node)
+        sparse_input = args[0]
+        dense_input = args[1]
+
+        # Convert sparse tensor to dense if needed
+        # Note: sparse_input should already be converted to dense in _convert_pytorch_tensor_to_tvm
+        # Use regular matrix multiplication
+        return self.block_builder.emit(
+            relax.op.linear_algebra.matmul(sparse_input, dense_input, out_dtype="float32")
+        )
+
+    def _sparse_addmm(self, node: fx.Node) -> relax.Var:
+        """Handle sparse addmm (beta * input + alpha * sparse_mm(mat1, mat2))."""
+        args = self.retrieve_args(node)
+        input_tensor = args[0]  # beta * input
+        sparse_mat1 = args[1]  # sparse matrix
+        dense_mat2 = args[2]  # dense matrix
+        alpha = node.kwargs.get("alpha", 1.0) if hasattr(node, "kwargs") else 1.0
+        beta = node.kwargs.get("beta", 1.0) if hasattr(node, "kwargs") else 1.0

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `hasattr(node, "kwargs")` check is redundant. `torch.fx.Node` objects 
always have a `kwargs` attribute, which is a dictionary. You can simplify this 
by directly using `node.kwargs.get()`.
   
   ```python
   alpha = node.kwargs.get("alpha", 1.0)
   beta = node.kwargs.get("beta", 1.0)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to