gemini-code-assist[bot] commented on code in PR #18503:
URL: https://github.com/apache/tvm/pull/18503#discussion_r2559688825
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1417,10 +1422,21 @@ def create_input_vars(
torch_dtype = node.meta["tensor_meta"].dtype
break
else:
+ # Buffers
+ info = None
+ if spec.target in merged_state:
+     info = merged_state[spec.target]
+ elif spec.target.split(".")[-1] in merged_state:
+     info = merged_state[spec.target.split(".")[-1]]
+ if info is None:
+     raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")
+
+ # Handle both original and extra buffer
+ if hasattr(info, "shape") and hasattr(info, "dtype"):
+     torch_shape = info.shape
+     torch_dtype = info.dtype
Review Comment:

This block has a couple of critical issues:
1. The `merged_state` variable is used but not defined, which will cause a
`NameError`. You should define it before the loop by merging
`exported_program.state_dict` and `extra_buffers`. For example: `merged_state =
{**exported_program.state_dict, **extra_buffers}`.
2. The logic to get shape and dtype using `hasattr(info, "shape")` will not
work for the dictionaries in `extra_buffers`, as dictionaries do not have
`shape` and `dtype` attributes. This will lead to `torch_shape` and
`torch_dtype` not being defined, causing another `NameError`.
Here is a suggested implementation that fixes these issues. Please note that
you still need to define `merged_state` before this block.
```suggestion
info = merged_state.get(spec.target)
if info is None:
    short_name = spec.target.split(".")[-1]
    info = merged_state.get(short_name)
if info is None:
    raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")
if isinstance(info, torch.Tensor):
    torch_shape = info.shape
    torch_dtype = info.dtype
elif isinstance(info, dict):
    torch_shape = info["shape"]
    torch_dtype = info["dtype"]
else:
    raise TypeError(f"Unsupported type for buffer/parameter info: {type(info)}")
```
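For reference, here is a minimal, self-contained sketch of the combined lookup logic. The helper name `resolve_buffer_meta` and the sample entries are hypothetical, not the translator's actual code; the sketch only illustrates how tensor-valued `state_dict` entries and dict-valued `extra_buffers` entries can be resolved uniformly once `merged_state` is built as suggested above:
```python
import torch

def resolve_buffer_meta(target, state_dict, extra_buffers):
    """Resolve (shape, dtype) for `target`, trying the fully qualified
    name first and falling back to its last path component."""
    merged_state = {**state_dict, **extra_buffers}
    info = merged_state.get(target)
    if info is None:
        info = merged_state.get(target.split(".")[-1])
    if info is None:
        raise KeyError(f"Missing target in state_dict or extra buffers: {target}")
    if isinstance(info, torch.Tensor):
        # Original buffers/parameters carry tensor metadata directly.
        return tuple(info.shape), info.dtype
    if isinstance(info, dict):
        # Extra buffers are plain dicts, so read the keys instead.
        return tuple(info["shape"]), info["dtype"]
    raise TypeError(f"Unsupported type for buffer/parameter info: {type(info)}")

# Both kinds of entry resolve the same way:
state = {"embeddings.position_ids": torch.zeros(1, 128, dtype=torch.int64)}
extra = {"token_type_ids": {"shape": (1, 128), "dtype": torch.int64}}
print(resolve_buffer_meta("embeddings.position_ids", state, extra))  # ((1, 128), torch.int64)
print(resolve_buffer_meta("encoder.token_type_ids", state, extra))   # ((1, 128), torch.int64)
```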
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1389,6 +1389,11 @@ def create_input_vars(
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
range_constraints = {}
+ extra_buffers = {
+     "position_ids": {"shape": (1, 128), "dtype": torch.int64},
+     "token_type_ids": {"shape": (1, 128), "dtype": torch.int64},
+ }
Review Comment:

This `extra_buffers` dictionary is hardcoded, which limits the utility of
this feature to a specific model configuration (BERT with sequence length 128).
To make this more general, `extra_buffers` should be an optional parameter to
`from_exported_program` and passed down to this function. This would allow
users to provide their own non-persistent buffers for different models.
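As a sketch of what that threading could look like (the pass-through bodies here are hypothetical stand-ins, not the actual translator code), the new parameter would default to an empty mapping so existing call sites are unaffected:
```python
from typing import Any, Dict, Optional

import torch

def create_input_vars(exported_program, extra_buffers: Dict[str, Dict[str, Any]]):
    # Stand-in for the real translator logic: just shows that the
    # user-supplied mapping reaches the point where inputs are created.
    return {name: (info["shape"], info["dtype"]) for name, info in extra_buffers.items()}

def from_exported_program(
    exported_program,
    *,
    extra_buffers: Optional[Dict[str, Dict[str, Any]]] = None,
):
    # Defaulting to an empty mapping keeps current callers working unchanged.
    return create_input_vars(exported_program, extra_buffers or {})

# Callers describe their model's non-persistent buffers explicitly,
# e.g. for BERT with sequence length 128:
buffers = {
    "position_ids": {"shape": (1, 128), "dtype": torch.int64},
    "token_type_ids": {"shape": (1, 128), "dtype": torch.int64},
}
print(from_exported_program(None, extra_buffers=buffers))
```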