gemini-code-assist[bot] commented on code in PR #18502:
URL: https://github.com/apache/tvm/pull/18502#discussion_r2558952016
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1324,87 +1048,31 @@ def create_convert_map(
"zero_.default": self._zeros_inplace,
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
- "grid_sampler_2d.default": self._grid_sampler_2d,
+ "_assert_tensor_metadata": lambda node: None,
+
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
"type_as.default": self._type_as,
# other
"getitem": self._getitem,
"item.default": self._item,
- "sym_size.int": self._sym_size_int,
- "_local_scalar_dense.default": self._item,
}
- def _process_derived_symbol(
- self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
- ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
-        """Process a sympy symbol to generate a descriptive name and TIR expression."""
- import sympy
-
- if isinstance(symbol, sympy.Symbol):
- return str(symbol), None
-
- if not isinstance(symbol, (sympy.Add, sympy.Mul)):
- return str(symbol), None
-
- tir_expr = None
- for arg in symbol.args:
- if isinstance(arg, sympy.Integer):
- term = tvm.tir.IntImm("int64", int(arg))
- elif isinstance(arg, sympy.Symbol):
- term = torch_symbol_to_relax_var.setdefault(
- str(arg), tvm.tir.SizeVar(str(arg), "int64")
- )
- else:
-                _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var)
-
- if term is None:
- return str(symbol), None
-
- if tir_expr is None:
- tir_expr = term
- elif isinstance(symbol, sympy.Mul):
- tir_expr = tir_expr * term
- elif isinstance(symbol, sympy.Add):
- tir_expr = tir_expr + term
-
- if isinstance(tir_expr, tvm.tir.Add):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
- return f"{var.name}___{const.value}", tir_expr
-
- if isinstance(tir_expr, tvm.tir.Mul):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
- return f"{var.name}_{const.value}", tir_expr
-
- return str(symbol), tir_expr
-
def create_input_vars(
self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
+ ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
- range_constraints = {}
-
- if hasattr(exported_program, "range_constraints"):
-            for symbol, value_range in exported_program.range_constraints.items():
-                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
- try:
- lower = int(value_range.lower)
- upper = int(value_range.upper)
-
- symbol_name, _ = self._process_derived_symbol(
- symbol, torch_symbol_to_relax_var
- )
- range_constraints[symbol_name] = (lower, upper)
-
- except (OverflowError, AttributeError, TypeError):
- continue
-
+
+ extra_buffers={
+ "position_ids": {"shape":(1,128), "dtype":torch.int64},
+ "token_type_ids": {"shape":(1.128), "dtype":torch.int64},
Review Comment:

There is a typo in the shape for `token_type_ids`. It is written as
`(1.128)`, which Python parses as a parenthesized float literal rather than a
tuple, so the translator will crash when the value is used as a shape. It
should be the tuple `(1, 128)`.
```suggestion
"token_type_ids": {"shape":(1, 128), "dtype":torch.int64},
```
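Note that this is not caught at parse time: `(1.128)` is a valid expression,
so the failure only surfaces when the value is consumed as a shape. A quick
check illustrating the difference:

```python
# Only the comma creates a tuple; parentheses alone do not.
assert isinstance((1.128), float)
assert isinstance((1, 128), tuple)
```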
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1324,87 +1048,31 @@ def create_convert_map(
"zero_.default": self._zeros_inplace,
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
- "grid_sampler_2d.default": self._grid_sampler_2d,
+ "_assert_tensor_metadata": lambda node: None,
+
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
"type_as.default": self._type_as,
# other
"getitem": self._getitem,
"item.default": self._item,
- "sym_size.int": self._sym_size_int,
- "_local_scalar_dense.default": self._item,
}
- def _process_derived_symbol(
- self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
- ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
-        """Process a sympy symbol to generate a descriptive name and TIR expression."""
- import sympy
-
- if isinstance(symbol, sympy.Symbol):
- return str(symbol), None
-
- if not isinstance(symbol, (sympy.Add, sympy.Mul)):
- return str(symbol), None
-
- tir_expr = None
- for arg in symbol.args:
- if isinstance(arg, sympy.Integer):
- term = tvm.tir.IntImm("int64", int(arg))
- elif isinstance(arg, sympy.Symbol):
- term = torch_symbol_to_relax_var.setdefault(
- str(arg), tvm.tir.SizeVar(str(arg), "int64")
- )
- else:
-                _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var)
-
- if term is None:
- return str(symbol), None
-
- if tir_expr is None:
- tir_expr = term
- elif isinstance(symbol, sympy.Mul):
- tir_expr = tir_expr * term
- elif isinstance(symbol, sympy.Add):
- tir_expr = tir_expr + term
-
- if isinstance(tir_expr, tvm.tir.Add):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
- return f"{var.name}___{const.value}", tir_expr
-
- if isinstance(tir_expr, tvm.tir.Mul):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var):
- return f"{var.name}_{const.value}", tir_expr
-
- return str(symbol), tir_expr
-
def create_input_vars(
self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
+ ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
- range_constraints = {}
-
- if hasattr(exported_program, "range_constraints"):
-            for symbol, value_range in exported_program.range_constraints.items():
-                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
- try:
- lower = int(value_range.lower)
- upper = int(value_range.upper)
-
- symbol_name, _ = self._process_derived_symbol(
- symbol, torch_symbol_to_relax_var
- )
- range_constraints[symbol_name] = (lower, upper)
-
- except (OverflowError, AttributeError, TypeError):
- continue
-
+
+ extra_buffers={
+ "position_ids": {"shape":(1,128), "dtype":torch.int64},
+ "token_type_ids": {"shape":(1.128), "dtype":torch.int64},
+ }
Review Comment:

The `extra_buffers` dictionary is hardcoded to work around a specific Hugging
Face model issue. To make the fix general and reusable for other models with
similar gaps, consider allowing users to pass a dictionary of extra buffers as
an optional argument to `from_exported_program`.
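For illustration, a minimal sketch of what that could look like. The keyword
name `extra_buffers`, the metadata format, and the merge into `merged_state`
are assumptions extrapolated from this diff, not the existing API:

```python
from typing import Any, Dict, Optional

import torch


def from_exported_program(
    exported_program: torch.export.ExportedProgram,
    *,
    # Hypothetical parameter: callers describe buffers that are missing from
    # the exported state_dict instead of relying on hardcoded names.
    extra_buffers: Optional[Dict[str, Dict[str, Any]]] = None,
):
    """Sketch only: merge caller-supplied buffer metadata into the lookup."""
    merged_state: Dict[str, Any] = dict(exported_program.state_dict)
    # Caller-provided entries fill gaps the export did not record, e.g.
    # {"position_ids": {"shape": (1, 128), "dtype": torch.int64}}.
    merged_state.update(extra_buffers or {})
    return merged_state  # translation would then proceed against merged_state
```

This keeps existing callers unaffected (the default is `None`) while letting
models with other missing buffers supply their own metadata.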
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1533,14 +1205,18 @@ def from_exported_program(
if tensor_name == spec.target:
bind_name = spec.arg.name
break
-                binding[bind_name] = self._convert_pytorch_tensor_to_tvm(tensor_value)
+ try:
+                    binding[bind_name] = tvm.runtime.from_dlpack(tensor_value.detach())
+ except RuntimeError:
+ tensor_cpu = tensor_value.detach().cpu().contiguous()
+ binding[bind_name] = tvm.runtime.tensor(tensor_cpu.numpy())
mod = self.block_builder.get()
mod = relax.transform.BindParams("main", binding)(mod)
if keep_params_as_input:
parameters = dict(exported_program.named_parameters())
-            params = [self._convert_pytorch_tensor_to_tvm(p) for p in parameters.values()]
+            params = [tvm.runtime.from_dlpack(p.detach()) for p in parameters.values()]
Review Comment:

The tensor conversion for `params` when `keep_params_as_input=True` only uses
`tvm.runtime.from_dlpack` without a fallback. This is inconsistent with the
binding conversion a few lines above, which wraps the DLPack call in a
`try`/`except` to handle tensors that cannot be exported zero-copy. To avoid a
potential `RuntimeError`, it would be safer to use the same robust conversion
logic here.
```python
params = []
for p in parameters.values():
try:
params.append(tvm.runtime.from_dlpack(p.detach()))
except RuntimeError:
tensor_cpu = p.detach().cpu().contiguous()
params.append(tvm.runtime.tensor(tensor_cpu.numpy()))
```
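Alternatively, since this fallback would then exist in two places, it could be
factored into a small module-level helper so both conversion paths stay in
sync. A sketch, with the helper name `_to_tvm_tensor` being illustrative:

```python
import torch
import tvm


def _to_tvm_tensor(tensor: torch.Tensor):
    """Zero-copy conversion via DLPack when supported, else a CPU copy."""
    try:
        return tvm.runtime.from_dlpack(tensor.detach())
    except RuntimeError:
        # Fallback for tensors DLPack cannot export directly,
        # e.g. non-contiguous layouts or unsupported devices.
        return tvm.runtime.tensor(tensor.detach().cpu().contiguous().numpy())
```

Both the binding loop and the `params` construction could then call this
helper instead of duplicating the `try`/`except`.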
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1418,23 +1086,32 @@ def create_input_vars(
break
else:
# PARAMETER or BUFFER
- torch_shape = exported_program.state_dict[spec.target].shape
- torch_dtype = exported_program.state_dict[spec.target].dtype
-
- relax_shape = []
- for s in torch_shape:
- if isinstance(s, torch.SymInt):
-                    sympy_node = s.node.expr if hasattr(s.node, "expr") else s.node
- symbol_name, _ = self._process_derived_symbol(
- sympy_node, torch_symbol_to_relax_var
- )
-
- size_var = torch_symbol_to_relax_var.setdefault(
- symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
- )
- relax_shape.append(size_var)
- else:
- relax_shape.append(s)
+ info= None
+ if spec.target in merged_state:
+ info=merged_state[spec.target]
+ elif spec.target.split(".")[-1] in merged_state:
+ info = merged_state[spec.target.split(".")[-1]]
+
+ if info is None:
+                raise KeyError(f"Missing target in state_dict or extra buffers: {spec.target}")
+
+ # Handle both original and extra buffer
+ if hasattr(info,"shape") and hasattr(info,"dtype"):
+ torch_shape=info.shape
+ torch_dtype=info.dtype
+
+
+
+
+
+
Review Comment:

There are several consecutive empty lines here, and in other places in this
PR (e.g., lines 1124-1127, 1151-1152). Please remove them to improve code
readability.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]