gemini-code-assist[bot] commented on code in PR #18524:
URL: https://github.com/apache/tvm/pull/18524#discussion_r2572785994


##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1477,10 +1477,48 @@ def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))
 
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
-        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
-        query = transpose_S_H(self.env[node.args[0]])
-        key = transpose_S_H(self.env[node.args[1]])
-        value = transpose_S_H(self.env[node.args[2]])
+        query_tensor = self.env[node.args[0]]
+        key_tensor = self.env[node.args[1]]
+        value_tensor = self.env[node.args[2]]
+        
+        # Check the dimensionality of the input tensors
+        query_ndim = len(query_tensor.struct_info.shape)
+        
+        # TVM's nn.attention requires 4D inputs (batch, seq_len, num_heads, head_dim)
+        # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first
+        if query_ndim == 2:
+            # 2D input: (seq_len, head_dim) -> expand to (1, seq_len, 1, head_dim)
+            # Add batch dimension at axis 0
+            query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
+            key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
+            value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
+            # Add num_heads dimension at axis 2
+            query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=2))
+            key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=2))
+            value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=2))
+            
+            # No permutation needed for 2D inputs after expanding to 4D
+            # After attention, squeeze back to 2D: (1, seq_len, 1, head_dim) -> (seq_len, head_dim)
+            def transpose_and_reshape_back(tensor):
+                # Squeeze num_heads dimension (axis 2)
+                tensor_3d = self.block_builder.emit(relax.op.squeeze(tensor, axis=[2]))
+                # Squeeze batch dimension (axis 0)
+                return self.block_builder.emit(relax.op.squeeze(tensor_3d, axis=[0]))

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   There appear to be two issues with the dimension handling for 2D inputs.
   
   1.  **Dimension Expansion**: The `relax.op.nn.attention` operator expects inputs in `(batch, num_heads, seq_len, head_dim)` format. The current code expands a 2D tensor `(seq_len, head_dim)` to `(1, seq_len, 1, head_dim)`, but it should be `(1, 1, seq_len, head_dim)`. This can be fixed by inserting the `num_heads` dimension at `axis=1` instead of `axis=2`.
   
   2.  **Squeezing Dimensions**: In `transpose_and_reshape_back`, the `num_heads` dimension of the attention output is at `axis=1`, not `axis=2`. Additionally, the two `squeeze` operations can be combined into a single call for conciseness.
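The shape arithmetic behind both points can be checked without TVM. This minimal sketch uses NumPy's `np.expand_dims` and `np.squeeze` as stand-ins for `relax.op.expand_dims` and `relax.op.squeeze`, assuming the Relax ops follow the same axis semantics; the sizes and variable names are illustrative only.

```python
import numpy as np

# Stand-in for a 2D (seq_len, head_dim) input; NumPy illustrates shapes only.
seq_len, head_dim = 5, 8
x = np.arange(seq_len * head_dim, dtype=np.float32).reshape(seq_len, head_dim)

# PR as written: batch at axis 0, then num_heads at axis 2.
pr_shape = np.expand_dims(np.expand_dims(x, axis=0), axis=2).shape

# Suggested fix: batch at axis 0, then num_heads at axis 1.
suggested = np.expand_dims(np.expand_dims(x, axis=0), axis=1)

# One squeeze over both singleton axes undoes the suggested expansion and
# matches two sequential squeezes (axis 1 first, then axis 0).
once = np.squeeze(suggested, axis=(0, 1))
twice = np.squeeze(np.squeeze(suggested, axis=1), axis=0)

print(pr_shape)         # (1, 5, 1, 8)
print(suggested.shape)  # (1, 1, 5, 8)
print(once.shape, np.array_equal(once, x), np.array_equal(once, twice))  # (5, 8) True True
```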
   
   Here is a suggested fix for both points.
   
   ```suggestion
               # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim)
               # Add batch dimension at axis 0
               query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
               key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
               value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
               # Add num_heads dimension at axis 1
               query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1))
               key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1))
               value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1))
               
               # After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim)
               def transpose_and_reshape_back(tensor):
                   # Squeeze batch and num_heads dimensions
                   return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1]))
   ```
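Which expansion axis is correct depends on the layout the attention operator actually consumes: this comment states `(batch, num_heads, seq_len, head_dim)`, while the comment in the diff states `(batch, seq_len, num_heads, head_dim)`, the order the removed `transpose_S_H` permutation `[0, 2, 1, 3]` produced from PyTorch's head-major inputs. The NumPy sketch below does not call TVM; `sdpa_bhsd`, `sdpa_bsnh`, `expand_bhsd` and `expand_bsnh` are hypothetical reference helpers for the two layouts. It assumes standard softmax attention with a `1/sqrt(head_dim)` scale and shows that, for a single head, either expansion reproduces plain 2D attention as long as it matches the kernel's layout.

```python
import numpy as np


def sdpa_bhsd(q, k, v):
    """Scaled dot-product attention over the last two axes.

    For 4D inputs the layout is treated as (batch, num_heads, seq_len, head_dim).
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v


def sdpa_bsnh(q, k, v):
    """Same attention for (batch, seq_len, num_heads, head_dim) inputs."""

    def swap_seq_heads(t):
        # The permutation [0, 2, 1, 3] is its own inverse.
        return np.transpose(t, (0, 2, 1, 3))

    out = sdpa_bhsd(swap_seq_heads(q), swap_seq_heads(k), swap_seq_heads(v))
    return swap_seq_heads(out)


def expand_bhsd(t):
    """(seq_len, head_dim) -> (1, 1, seq_len, head_dim), as in the suggestion."""
    return t[np.newaxis, np.newaxis]


def expand_bsnh(t):
    """(seq_len, head_dim) -> (1, seq_len, 1, head_dim), as in the PR."""
    return t[np.newaxis, :, np.newaxis]


rng = np.random.default_rng(0)
q2, k2, v2 = rng.standard_normal((3, 6, 4))  # three (seq_len=6, head_dim=4) inputs
expected = sdpa_bhsd(q2, k2, v2)  # plain 2D attention

out_bhsd = np.squeeze(
    sdpa_bhsd(expand_bhsd(q2), expand_bhsd(k2), expand_bhsd(v2)), axis=(0, 1)
)
out_bsnh = np.squeeze(
    sdpa_bsnh(expand_bsnh(q2), expand_bsnh(k2), expand_bsnh(v2)), axis=(0, 2)
)

# Layout mismatch: a BHSD kernel reads (1, seq_len, 1, head_dim) as seq_len
# heads of length 1, so each position attends only to itself and returns v.
out_wrong = np.squeeze(
    sdpa_bhsd(expand_bsnh(q2), expand_bsnh(k2), expand_bsnh(v2)), axis=(0, 2)
)

print(np.allclose(out_bhsd, expected), np.allclose(out_bsnh, expected))  # True True
print(np.allclose(out_wrong, v2))  # True
```

The mismatch case is the one worth guarding against: it yields a tensor of the expected shape without any error, so only a numerical comparison against PyTorch's output reveals it.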



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
