gemini-code-assist[bot] commented on code in PR #18516:
URL: https://github.com/apache/tvm/pull/18516#discussion_r2567909529


##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -402,109 +462,96 @@ def _lstm(self, node: fx.Node) -> relax.Var:
         # c_t = f_t * c_{t-1} + i_t * g_t
         # h_t = o_t * tanh(c_t)
         dtype = input_tensor.struct_info.dtype
-        if params and len(params) >= 4:
-            weight_ih = params[0]  # (4 * hidden_size, input_size)
-            weight_hh = params[1]  # (4 * hidden_size, hidden_size)
-            bias_ih = params[2] if has_biases else None  # (4 * hidden_size,)
-            bias_hh = params[3] if has_biases else None  # (4 * hidden_size,)
+        params_per_direction = 4 if has_biases else 2
+
+        weight_ih_fwd = params[0] if params else None
+        weight_hh_fwd = params[1] if params and len(params) > 1 else None
+        bias_ih_fwd = params[2] if params and has_biases and len(params) > 2 else None
+        bias_hh_fwd = params[3] if params and has_biases and len(params) > 3 else None
+
+        if bidirectional and params and len(params) >= params_per_direction * 2:
+            weight_ih_bwd = params[params_per_direction]
+            weight_hh_bwd = params[params_per_direction + 1]
+            bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
+            bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
         else:
-            # Fallback: create zero weights
-            weight_ih = self.block_builder.emit(
-                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
-            )
-            weight_hh = self.block_builder.emit(
-                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
-            )
-            bias_ih = None
-            bias_hh = None
-        # Initialize hidden and cell states
+            weight_ih_bwd = None
+            weight_hh_bwd = None
+            bias_ih_bwd = None
+            bias_hh_bwd = None

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   This change removes the fallback logic for creating zero-initialized weights 
when LSTM parameters are not provided. The new implementation assigns `None` to 
weight variables, which will cause a crash inside `_lstm_cell_unroll` when 
`relax.op.permute_dims` is called on a `None` value. This appears to be a 
regression from the previous behavior.
   
   Please consider restoring the fallback logic to create zero weights for both 
forward and backward directions if they are not available in `params`.
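   
   A minimal sketch of one way to restore it, reusing the zero-initialization pattern from the removed code (the variable names and `relax` calls below are taken from this diff; treat it as an untested suggestion, not a drop-in patch):
   
   ```python
   # Fall back to zero weights when `params` does not supply them, mirroring
   # the removed fallback. Assumes hidden_size, input_size, and dtype are
   # already computed, as in the surrounding method.
   def _zeros(shape):
       return self.block_builder.emit(relax.op.zeros(relax.ShapeExpr(shape), dtype))

   if weight_ih_fwd is None:
       weight_ih_fwd = _zeros((4 * hidden_size, input_size))
   if weight_hh_fwd is None:
       weight_hh_fwd = _zeros((4 * hidden_size, hidden_size))
   if bidirectional:
       if weight_ih_bwd is None:
           weight_ih_bwd = _zeros((4 * hidden_size, input_size))
       if weight_hh_bwd is None:
           weight_hh_bwd = _zeros((4 * hidden_size, hidden_size))
   ```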



##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -350,46 +350,106 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
             align_corners=align_corners,
         )
 
+    def _lstm_cell_unroll(
+        self,
+        input_reshaped,
+        weight_ih,
+        weight_hh,
+        bias_ih,
+        bias_hh,
+        h_prev,
+        c_prev,
+        seq_len,
+        hidden_size,
+        reverse=False,
+    ):
+        """Unroll LSTM cells for a single direction."""
+        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
+        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
+        outputs = []
+        time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)
+
+        for t in time_steps:
+            x_t = self.block_builder.emit(
+                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
+            )
+            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
+            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))
+
+            gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
+            if bias_ih is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
+            if bias_hh is not None:
+                gates = self.block_builder.emit(relax.op.add(gates, bias_hh))
+
+            i_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size])
+            )
+            f_gate = self.block_builder.emit(
+                relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size])
+            )
+            g_gate = self.block_builder.emit(
+                relax.op.strided_slice(
+                    gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size]
+                )
+            )
+            o_gate = self.block_builder.emit(
+                relax.op.strided_slice(
+                    gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size]
+                )
+            )
+
+            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
+            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
+            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
+            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))
+
+            c_t = self.block_builder.emit(
+                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
+            )
+            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))
+
+            outputs.append(h_t)
+            h_prev = h_t
+            c_prev = c_t
+
+        if reverse:
+            outputs = outputs[::-1]
+
+        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
+        return output
+
     def _lstm(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         input_tensor = args[0]
         hx = args[1] if len(args) > 1 else None
         params = args[2] if len(args) > 2 else None
         has_biases = args[3] if len(args) > 3 else True
         num_layers = args[4] if len(args) > 4 else 1
-        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
-        _train = args[6] if len(args) > 6 else False  # Not used in inference
         bidirectional = args[7] if len(args) > 7 else False
         batch_first = args[8] if len(args) > 8 else False
-        if bidirectional:
-            raise NotImplementedError("Bidirectional LSTM is not yet 
supported")
+
         if num_layers > 1:
             raise NotImplementedError("Multi-layer LSTM is not yet supported")
+
         input_shape = self.shape_of(input_tensor)
         if batch_first:
-            # Input shape: (batch, seq_len, input_size)
             batch_size, seq_len, input_size = input_shape
         else:
-            # Input shape: (seq_len, batch, input_size)
             seq_len, batch_size, input_size = input_shape
 
-        if isinstance(seq_len, tvm.tir.IntImm):
-            seq_len = seq_len.value
-        if isinstance(batch_size, tvm.tir.IntImm):
-            batch_size = batch_size.value
-        if isinstance(input_size, tvm.tir.IntImm):
-            input_size = input_size.value
+        seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
+        batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
+        input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size
         # Extract hidden size from the LSTM parameters
         # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
         # weight_ih shape: (4 * hidden_size, input_size)
         # weight_hh shape: (4 * hidden_size, hidden_size)
         if params and len(params) >= 2:
-            weight_ih = params[0]
-            weight_hh = params[1]
             # Extract hidden size from weight dimensions
             # weight_ih has shape (4 * hidden_size, input_size)
-            weight_ih_shape = self.shape_of(weight_ih)
-            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
+            weight_ih_shape = self.shape_of(params[0])
+            hidden_size = weight_ih_shape[0] // 4
         else:
             # Fallback to a default hidden size
             hidden_size = 16

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The code silently falls back to a default `hidden_size` of 16 when it cannot be inferred from the model parameters. If the actual model uses a different hidden size, this can produce wrong shapes or results with no indication of why. Consider logging a warning when this fallback is taken so users are aware of the mismatch.
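   
   A minimal sketch of what that could look like, using Python's standard `logging` module (the logger setup is an assumption for illustration, not this file's established convention):
   
   ```python
   import logging

   logger = logging.getLogger(__name__)

   if params and len(params) >= 2:
       weight_ih_shape = self.shape_of(params[0])
       hidden_size = weight_ih_shape[0] // 4
   else:
       # Fall back to a default hidden size, but tell the user about it.
       hidden_size = 16
       logger.warning(
           "Could not infer LSTM hidden_size from params; defaulting to %d. "
           "Results may be wrong if the model uses a different hidden size.",
           hidden_size,
       )
   ```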



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

