gemini-code-assist[bot] commented on code in PR #18502:
URL: https://github.com/apache/tvm/pull/18502#discussion_r2558952016


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1324,87 +1048,31 @@ def create_convert_map(
             "zero_.default": self._zeros_inplace,
             "zeros.default": self._zeros,
             "zeros_like.default": self._zeros_like,
-            "grid_sampler_2d.default": self._grid_sampler_2d,
+            "_assert_tensor_metadata": lambda node: None,
+
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
             "type_as.default": self._type_as,
             # other
             "getitem": self._getitem,
             "item.default": self._item,
-            "sym_size.int": self._sym_size_int,
-            "_local_scalar_dense.default": self._item,
         }
 
-    def _process_derived_symbol(
-        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
-    ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
-        """Process a sympy symbol to generate a descriptive name and TIR 
expression."""
-        import sympy
-
-        if isinstance(symbol, sympy.Symbol):
-            return str(symbol), None
-
-        if not isinstance(symbol, (sympy.Add, sympy.Mul)):
-            return str(symbol), None
-
-        tir_expr = None
-        for arg in symbol.args:
-            if isinstance(arg, sympy.Integer):
-                term = tvm.tir.IntImm("int64", int(arg))
-            elif isinstance(arg, sympy.Symbol):
-                term = torch_symbol_to_relax_var.setdefault(
-                    str(arg), tvm.tir.SizeVar(str(arg), "int64")
-                )
-            else:
-                _, term = self._process_derived_symbol(arg, 
torch_symbol_to_relax_var)
-
-            if term is None:
-                return str(symbol), None
-
-            if tir_expr is None:
-                tir_expr = term
-            elif isinstance(symbol, sympy.Mul):
-                tir_expr = tir_expr * term
-            elif isinstance(symbol, sympy.Add):
-                tir_expr = tir_expr + term
-
-        if isinstance(tir_expr, tvm.tir.Add):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, 
tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, 
tvm.tir.Var):
-                    return f"{var.name}___{const.value}", tir_expr
-
-        if isinstance(tir_expr, tvm.tir.Mul):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, 
tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, 
tvm.tir.Var):
-                    return f"{var.name}_{const.value}", tir_expr
-
-        return str(symbol), tir_expr
-
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, 
Tuple[int, int]]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
-        range_constraints = {}
-
-        if hasattr(exported_program, "range_constraints"):
-            for symbol, value_range in 
exported_program.range_constraints.items():
-                if hasattr(value_range, "lower") and hasattr(value_range, 
"upper"):
-                    try:
-                        lower = int(value_range.lower)
-                        upper = int(value_range.upper)
-
-                        symbol_name, _ = self._process_derived_symbol(
-                            symbol, torch_symbol_to_relax_var
-                        )
-                        range_constraints[symbol_name] = (lower, upper)
-
-                    except (OverflowError, AttributeError, TypeError):
-                        continue
-
+        
+        extra_buffers={
+            "position_ids": {"shape":(1,128), "dtype":torch.int64},
+            "token_type_ids": {"shape":(1.128), "dtype":torch.int64},

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   There is a typo in the shape for `token_type_ids`. It is written as 
`(1.128)`, which Python parses as the float `1.128` rather than a tuple, so 
code that treats the shape as a sequence will fail at runtime. It should be 
the tuple `(1, 128)`.
   
   ```suggestion
               "token_type_ids": {"shape":(1, 128), "dtype":torch.int64},
   ```



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1324,87 +1048,31 @@ def create_convert_map(
             "zero_.default": self._zeros_inplace,
             "zeros.default": self._zeros,
             "zeros_like.default": self._zeros_like,
-            "grid_sampler_2d.default": self._grid_sampler_2d,
+            "_assert_tensor_metadata": lambda node: None,
+
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
             "type_as.default": self._type_as,
             # other
             "getitem": self._getitem,
             "item.default": self._item,
-            "sym_size.int": self._sym_size_int,
-            "_local_scalar_dense.default": self._item,
         }
 
-    def _process_derived_symbol(
-        self, symbol, torch_symbol_to_relax_var: Dict[str, tvm.tir.Var]
-    ) -> Tuple[str, Optional[tvm.tir.PrimExpr]]:
-        """Process a sympy symbol to generate a descriptive name and TIR 
expression."""
-        import sympy
-
-        if isinstance(symbol, sympy.Symbol):
-            return str(symbol), None
-
-        if not isinstance(symbol, (sympy.Add, sympy.Mul)):
-            return str(symbol), None
-
-        tir_expr = None
-        for arg in symbol.args:
-            if isinstance(arg, sympy.Integer):
-                term = tvm.tir.IntImm("int64", int(arg))
-            elif isinstance(arg, sympy.Symbol):
-                term = torch_symbol_to_relax_var.setdefault(
-                    str(arg), tvm.tir.SizeVar(str(arg), "int64")
-                )
-            else:
-                _, term = self._process_derived_symbol(arg, 
torch_symbol_to_relax_var)
-
-            if term is None:
-                return str(symbol), None
-
-            if tir_expr is None:
-                tir_expr = term
-            elif isinstance(symbol, sympy.Mul):
-                tir_expr = tir_expr * term
-            elif isinstance(symbol, sympy.Add):
-                tir_expr = tir_expr + term
-
-        if isinstance(tir_expr, tvm.tir.Add):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, 
tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, 
tvm.tir.Var):
-                    return f"{var.name}___{const.value}", tir_expr
-
-        if isinstance(tir_expr, tvm.tir.Mul):
-            for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, 
tir_expr.a)]:
-                if isinstance(const, tvm.tir.IntImm) and isinstance(var, 
tvm.tir.Var):
-                    return f"{var.name}_{const.value}", tir_expr
-
-        return str(symbol), tir_expr
-
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, 
Tuple[int, int]]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
-        range_constraints = {}
-
-        if hasattr(exported_program, "range_constraints"):
-            for symbol, value_range in 
exported_program.range_constraints.items():
-                if hasattr(value_range, "lower") and hasattr(value_range, 
"upper"):
-                    try:
-                        lower = int(value_range.lower)
-                        upper = int(value_range.upper)
-
-                        symbol_name, _ = self._process_derived_symbol(
-                            symbol, torch_symbol_to_relax_var
-                        )
-                        range_constraints[symbol_name] = (lower, upper)
-
-                    except (OverflowError, AttributeError, TypeError):
-                        continue
-
+        
+        extra_buffers={
+            "position_ids": {"shape":(1,128), "dtype":torch.int64},
+            "token_type_ids": {"shape":(1.128), "dtype":torch.int64},
+        }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `extra_buffers` dictionary is hardcoded to solve a specific Hugging Face 
model issue. To make this solution more general and reusable for other models 
that might have similar issues, consider allowing users to pass a dictionary of 
extra buffers as an optional argument to `from_exported_program`.



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1533,14 +1205,18 @@ def from_exported_program(
                 if tensor_name == spec.target:
                     bind_name = spec.arg.name
                     break
-            binding[bind_name] = 
self._convert_pytorch_tensor_to_tvm(tensor_value)
+            try:
+                binding[bind_name] = 
tvm.runtime.from_dlpack(tensor_value.detach())
+            except RuntimeError:
+                tensor_cpu = tensor_value.detach().cpu().contiguous()
+                binding[bind_name] = tvm.runtime.tensor(tensor_cpu.numpy())
 
         mod = self.block_builder.get()
         mod = relax.transform.BindParams("main", binding)(mod)
 
         if keep_params_as_input:
             parameters = dict(exported_program.named_parameters())
-            params = [self._convert_pytorch_tensor_to_tvm(p) for p in 
parameters.values()]
+            params = [tvm.runtime.from_dlpack(p.detach()) for p in 
parameters.values()]

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The tensor conversion for `params` when `keep_params_as_input=True` only 
uses `tvm.runtime.from_dlpack` without a fallback. This is inconsistent with 
the conversion logic for bindings a few lines above, which includes a 
`try-except` block to handle cases where DLPack conversion fails. To prevent 
a potential `RuntimeError`, it would be safer to use the same robust 
conversion logic here.
   
   ```python
               params = []
               for p in parameters.values():
                   try:
                       params.append(tvm.runtime.from_dlpack(p.detach()))
                   except RuntimeError:
                       tensor_cpu = p.detach().cpu().contiguous()
                       params.append(tvm.runtime.tensor(tensor_cpu.numpy()))
   ```



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1418,23 +1086,32 @@ def create_input_vars(
                         break
             else:
                 # PARAMETER or BUFFER
-                torch_shape = exported_program.state_dict[spec.target].shape
-                torch_dtype = exported_program.state_dict[spec.target].dtype
-
-            relax_shape = []
-            for s in torch_shape:
-                if isinstance(s, torch.SymInt):
-                    sympy_node = s.node.expr if hasattr(s.node, "expr") else 
s.node
-                    symbol_name, _ = self._process_derived_symbol(
-                        sympy_node, torch_symbol_to_relax_var
-                    )
-
-                    size_var = torch_symbol_to_relax_var.setdefault(
-                        symbol_name, tvm.tir.SizeVar(symbol_name, "int64")
-                    )
-                    relax_shape.append(size_var)
-                else:
-                    relax_shape.append(s)
+                info= None
+                if spec.target in merged_state:
+                    info=merged_state[spec.target]
+                elif spec.target.split(".")[-1] in merged_state:
+                    info = merged_state[spec.target.split(".")[-1]]
+                    
+                if info is None:
+                    raise KeyError(f"Missing target in state_dict or extra 
buffers: {spec.target}")    
+                
+                # Handle both original and extra buffer
+                if hasattr(info,"shape") and hasattr(info,"dtype"):
+                    torch_shape=info.shape
+                    torch_dtype=info.dtype
+                    
+                
+                
+                
+                
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   There are several consecutive empty lines here, and in other places in this 
PR (e.g., lines 1124-1127, 1151-1152). Please remove them to improve code 
readability.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to