gemini-code-assist[bot] commented on code in PR #18540:
URL: https://github.com/apache/tvm/pull/18540#discussion_r2579675681
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -48,6 +48,18 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
     tvm.runtime.Tensor
         The converted TVM tensor.
     """
+    # Fix for Issue #18407: Handle FakeTensor and lifted tensors (from torch.export)
+    # Check if this is a FakeTensor or tensor subclass that doesn't support .numpy()
+    try:
+        # Check if it's a FakeTensor
+        if hasattr(torch, '_subclasses') and hasattr(torch._subclasses, 'fake_tensor'):
+            if isinstance(tensor_value, torch._subclasses.fake_tensor.FakeTensor):
+                # Create a real tensor with the same shape and dtype
+                real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
+                return tvm.runtime.tensor(real_tensor.numpy())
+    except (AttributeError, ImportError):
+        pass
Review Comment:

This block can be simplified for better readability. The `hasattr` checks are redundant inside a `try...except AttributeError` block, and the nested `if` statements can be flattened. This makes the intent of probing an internal PyTorch API clearer.
```suggestion
    try:
        # Check if it's a FakeTensor from torch._subclasses, which is an internal API.
        FakeTensor = torch._subclasses.fake_tensor.FakeTensor
        if isinstance(tensor_value, FakeTensor):
            # Create a real tensor with the same shape and dtype as a placeholder.
            real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
            return tvm.runtime.tensor(real_tensor.numpy())
    except (AttributeError, ImportError):
        # FakeTensor class might not exist in this torch version, or other import issue.
        pass
```
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -61,8 +73,17 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
     except (RuntimeError, BufferError):
         # Fallback: convert to numpy and then to TVM tensor
         # This handles cases where DLPack conversion fails
-        tensor_cpu = tensor_detached.cpu().contiguous()
-        return tvm.runtime.tensor(tensor_cpu.numpy())
+        try:
+            tensor_cpu = tensor_detached.cpu().contiguous()
+            return tvm.runtime.tensor(tensor_cpu.numpy())
+        except RuntimeError as e:
+            # Fix for Issue #18407: Handle tensor subclasses that don't support .numpy()
+            # This can happen with lifted tensors from torch.export
+            if "tensor subclasses" in str(e) or "FakeTensor" in str(e):
+                # Create a dummy tensor with the same shape and dtype
+                dummy_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
+                return tvm.runtime.tensor(dummy_tensor.numpy())
+            raise
Review Comment:

Relying on string matching against exception messages is fragile and can break with future PyTorch updates.
* Is there a more robust way to identify this specific `RuntimeError`, for example by checking for a specific error code if available, or a more specific exception type? If string matching is the only viable approach, it would be beneficial to add a comment explaining its fragility.
* The check for `"FakeTensor" in str(e)` seems potentially redundant, as `FakeTensor` instances should already be caught by the check at the beginning of the function (lines 53-61). If there is a scenario where a `FakeTensor` error can still reach this point, could you add a comment to clarify it?
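One alternative to the string match, sketched below as a suggestion rather than the PR's actual code: resolve the internal `FakeTensor` class once behind a guarded import, then branch on `isinstance` *before* attempting `.numpy()`, instead of parsing the `RuntimeError` afterward. The function name `materialize_for_tvm` is hypothetical, and `torch._subclasses.fake_tensor` is an internal API that may change between PyTorch versions.

```python
# Hedged sketch: type-check up front instead of string-matching an error message.
import torch

try:
    # Internal PyTorch API; guard the import so older versions still work.
    from torch._subclasses.fake_tensor import FakeTensor
except ImportError:
    FakeTensor = ()  # isinstance(x, ()) is always False, so the branch is skipped


def materialize_for_tvm(tensor_value: torch.Tensor):
    """Return a NumPy array, substituting zeros for storage-less tensors."""
    if isinstance(tensor_value, FakeTensor):
        # FakeTensor carries shape/dtype metadata but no real storage,
        # so a zero-filled placeholder of the same shape/dtype is used.
        return torch.zeros(tensor_value.shape, dtype=tensor_value.dtype).numpy()
    return tensor_value.detach().cpu().contiguous().numpy()
```

This keeps a single, explicit decision point for the FakeTensor case and leaves any genuinely unexpected `RuntimeError` from `.numpy()` to propagate unchanged.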
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]