gemini-code-assist[bot] commented on code in PR #18507:
URL: https://github.com/apache/tvm/pull/18507#discussion_r2565140092


##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -364,10 +364,19 @@ def _prelu(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
 
     def _round(self, node: fx.Node) -> relax.Expr:
-        if node.kwargs.get("decimals", 0) != 0:
-            raise ValueError("specifying decimals for round is not supported yet")
         arg = self.env[node.args[0]]
-        return self.block_builder.emit(relax.op.round(arg))
+        decimals = node.kwargs.get("decimals", 0)
+
+        if decimals == 0:
+            return self.block_builder.emit(relax.op.round(arg))
+
+        # For decimals != 0, use: round(x * 10^decimals) / 10^decimals
+        dtype = arg.struct_info.dtype
+        scale = relax.const(10**decimals, dtype)
+        scaled = relax.op.multiply(arg, scale)
+        rounded = relax.op.round(scaled)
+        result = relax.op.divide(rounded, scale)
+        return self.block_builder.emit(result)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current implementation for `_round` with `decimals != 0` has a couple of 
issues:
   
   1.  **Incorrect handling of integer inputs**:
       *   If the input is an integer tensor and `decimals` is negative, 
`10**decimals` will be a fraction. Casting this to an integer `dtype` will 
result in `0`, leading to a division-by-zero error.
       *   If the input is an integer tensor and `decimals` is positive, 
`torch.round` is a no-op and should return an integer tensor. The current 
implementation performs calculations that result in a float tensor.
   2.  **Incorrect graph construction**: The intermediate results of `multiply` 
and `round` are not emitted to the block builder. This creates a single nested 
`Call` node, which is inconsistent with the dataflow style used elsewhere and 
what the accompanying test expects.
   
   I've provided a suggestion to fix these issues by:
   - Handling integer inputs with positive `decimals` as a no-op.
   - Casting integer inputs to float for calculations when `decimals` is not 
positive, to prevent errors and match PyTorch's behavior.
   - Emitting each operation to build a proper dataflow graph.
   
   ```python
           # For decimals != 0, use: round(x * 10^decimals) / 10^decimals
           input_dtype = arg.struct_info.dtype
           is_integer_input = "int" in input_dtype
   
           if is_integer_input and decimals > 0:
                # torch.round is a no-op for integer tensors with positive decimals.
               return arg
   
            # We need to use floating point for the calculation to avoid truncation
           # and to match PyTorch's behavior of returning a float tensor for
           # integer inputs with non-positive decimals.
           if is_integer_input:
               compute_dtype = "float32"
                arg_for_calc = self.block_builder.emit(relax.op.astype(arg, compute_dtype))
           else:
               compute_dtype = input_dtype
               arg_for_calc = arg
   
           scale = relax.const(10**decimals, compute_dtype)
            scaled = self.block_builder.emit(relax.op.multiply(arg_for_calc, scale))
           rounded = self.block_builder.emit(relax.op.round(scaled))
           result = self.block_builder.emit(relax.op.divide(rounded, scale))
           return result
   ```
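   
   As a quick sanity check of the scaling identity itself, independent of Relax (plain NumPy, illustrative values only):
   
   ```python
   import numpy as np
   
   x = np.array([1.2345, 2.3456, 654.321], dtype="float32")
   decimals = 2
   scale = np.float32(10.0**decimals)
   
   # round(x * 10^decimals) / 10^decimals rounds x to `decimals` places
   result = np.round(x * scale) / scale
   
   # NumPy's np.round(x, decimals) uses the same scaling trick internally,
   # so the two should agree to within float32 precision.
   np.testing.assert_allclose(result, np.round(x, decimals), rtol=1e-6)
   ```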



##########
tests/python/relax/test_frontend_from_fx.py:
##########
@@ -6273,5 +6273,80 @@ def forward(self, input):
     )
 
 
+def test_round():
+    input_info = [([3, 4], "float32")]
+
+    class Round(Module):
+        def __init__(self, decimals=0):
+            super().__init__()
+            self.decimals = decimals
+
+        def forward(self, x):
+            if self.decimals == 0:
+                return torch.round(x)
+            else:
+                return torch.round(x, decimals=self.decimals)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.round(inp_0)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 4), dtype="float32"),
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.multiply(inp_0, R.const(100.0, "float32"))
+                lv1: R.Tensor((3, 4), dtype="float32") = R.round(lv)
+                lv2: R.Tensor((3, 4), dtype="float32") = R.divide(lv1, R.const(100.0, "float32"))
+                gv: R.Tensor((3, 4), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    rounds = [
+        (0, Expected1),
+        (2, Expected2),
+    ]
+
+    for decimals, expected in rounds:
+        verify_model(Round(decimals), input_info, {}, expected)
+
+    # Test numerical accuracy with decimals
+    test_data = torch.tensor(
+        [
+            [1.2345, 2.3456, 3.4567, 4.5678],
+            [5.6789, 6.7890, 7.8901, 8.9012],
+            [9.1234, 10.2345, 11.3456, 12.4567],
+        ]
+    )
+
+    for decimals in [0, 1, 2, 3]:
+        torch_model = Round(decimals)
+        graph_model = fx.symbolic_trace(torch_model)
+        with torch.no_grad():
+            mod = from_fx(graph_model, input_info)
+
+        target = tvm.target.Target("llvm")
+        ex = relax.build(mod, target)
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+
+        torch_result = torch_model(test_data).numpy()
+        tvm_input = tvm.runtime.tensor(test_data.numpy())
+        tvm_result = vm["main"](tvm_input).numpy()
+
+        # Use relaxed tolerance due to floating-point precision in decimal operations
+        tvm.testing.assert_allclose(tvm_result, torch_result, rtol=1e-3, atol=1e-3)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The added tests for `round` only cover floating-point inputs and positive 
`decimals`. The implementation of `round` has different behavior for integer 
inputs and for negative `decimals`, which are not currently tested.
   
   Please extend the tests to cover these cases:
   -   Integer-type inputs (e.g., `int32`).
   -   Negative `decimals` for both float and integer inputs.
   
   For example, you could add test cases for `torch.round(torch.tensor([128]), 
decimals=-1)` which should be `130.0`, and `torch.round(torch.tensor([128]), 
decimals=1)` which should be `128`.
   
   This will ensure the implementation is robust and handles all supported 
scenarios correctly.
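   
   As a rough sketch (not a drop-in suggestion — the `verify_model` cases would also need matching `Expected` modules, and the integer-input behavior should first be confirmed against PyTorch), the numerical loop above could be extended to cover negative `decimals` along these lines:
   
   ```python
       # Sketch: negative-decimals coverage for float inputs, reusing the Round
       # module and the compile/run pattern from the loop above.
       neg_data = torch.tensor([[123.456, 654.321, 42.0, 7.89]])
   
       for decimals in [-1, -2]:
           torch_model = Round(decimals)
           graph_model = fx.symbolic_trace(torch_model)
           with torch.no_grad():
               mod = from_fx(graph_model, [([1, 4], "float32")])
   
           ex = relax.build(mod, tvm.target.Target("llvm"))
           vm = relax.VirtualMachine(ex, tvm.cpu())
   
           torch_result = torch_model(neg_data).numpy()
           tvm_result = vm["main"](tvm.runtime.tensor(neg_data.numpy())).numpy()
           tvm.testing.assert_allclose(tvm_result, torch_result, rtol=1e-3, atol=1e-3)
   ```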


