This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2e6ee08eaf [BugFix] Align `tir.round` to ties-to-even across all backends (#19368)
2e6ee08eaf is described below

commit 2e6ee08eafc328b82a965e49d106d61828c1d623
Author: Soowon Jeong <[email protected]>
AuthorDate: Thu Apr 9 03:35:22 2026 +0900

    [BugFix] Align `tir.round` to ties-to-even across all backends (#19368)
    
    ## Problem
    
    `tir.round` constant-folds using `std::nearbyint`, which rounds ties to even
    under the default IEEE 754 rounding mode. Every backend, however, lowers it to
    the platform `round()`, which rounds ties away from zero. As a result, compiled
    code can return different results than constant-folded code for midpoint
    (x.5) values:
    
    | Input | Constant-fold (ties-to-even) | Compiled (ties-away) |
    |-------|-----|------|
    | 0.5   | 0.0 | 1.0  |
    | 2.5   | 2.0 | 3.0  |
    | -0.5  | 0.0 | -1.0 |
    
    This was identified as a follow-up to #19367 — see [this
    comment](https://github.com/apache/tvm/pull/19367#issuecomment-4201800320).
    
    ## Fix
    
    Align all backends to use ties-to-even intrinsics, matching the
    constant-folding behavior:
    
    | Backend | Before | After |
    |---------|--------|-------|
    | LLVM/ROCm/Hexagon | `llvm::Intrinsic::round` | `llvm::Intrinsic::nearbyint` |
    | NVPTX | `__nv_round[f]` | `__nv_nearbyint[f]` |
    | CUDA | `round`/`roundf` | `nearbyint`/`nearbyintf` (f16/bf16 already used `hrint`) |
    | Metal/OpenCL | `round` | `rint` |
    | Vulkan/SPIR-V | `GLSLstd450Round` | `GLSLstd450RoundEven` |
    
    Also fixes OpenCL codegen where `tir.nearbyint` was incorrectly mapped
    to OpenCL `round()` instead of `rint()`.
    
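    Updates the `op.h` documentation to state ties-to-even semantics explicitly
    for both `round()` and `nearbyint()`.
    
    The ONNX `MaxRoiPool` coordinate mapping relies on `std::round`
    (ties-away-from-zero) semantics. `topi.vision.roi_pool` and its Python
    reference implementation therefore no longer call `tir.round` or Python's
    built-in `round()`; they compute `floor(x + 0.5)` explicitly instead. Note that
    `floor(x + 0.5)` rounds half up, which equals ties-away-from-zero only for
    non-negative inputs.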
    
    ## Testing
    
    ```
    python -m pytest tests/python/tirx-base/test_tir_intrin.py -xvs
    ```
    
    The new `test_round_ties_to_even` verifies that the midpoint inputs
    `[0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5]` produce ties-to-even results
    on the LLVM backend. Of the 12 tests in the file, 10 pass and 2 are skipped
    because they require CUDA.
    
    ---------
    
    Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
---
 include/tvm/tirx/op.h                      | 12 +++++++++---
 python/tvm/topi/testing/roi_pool_python.py | 10 ++++++----
 python/tvm/topi/vision/roi_pool.py         | 15 +++++++++++----
 src/target/llvm/intrin_rule_hexagon.cc     |  2 +-
 src/target/llvm/intrin_rule_llvm.cc        |  2 +-
 src/target/llvm/intrin_rule_nvptx.cc       | 10 +++++++++-
 src/target/llvm/intrin_rule_rocm.cc        |  2 +-
 src/target/source/codegen_opencl.cc        |  2 +-
 src/target/source/intrin_rule_cuda.cc      |  3 +++
 src/target/source/intrin_rule_metal.cc     | 11 ++++++++++-
 src/target/source/intrin_rule_opencl.cc    | 11 ++++++++++-
 src/target/spirv/intrin_rule_spirv.cc      |  6 ++++--
 tests/python/tirx-base/test_tir_intrin.py  | 21 +++++++++++++++++++++
 13 files changed, 87 insertions(+), 20 deletions(-)

diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 66d9d932b3..c953f12e38 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -654,7 +654,11 @@ TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
 TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
 
 /*!
- * \brief Calculate round(x)
+ * \brief Round x to the nearest integer, ties to even.
+ *
+ * Uses IEEE 754 default rounding mode (ties-to-even / banker's rounding).
+ * Constant-folding and all backends consistently use std::nearbyint semantics.
+ *
  * \param x The input expression.
  * \param span The location of this operation in the source.
  * \return The result expression.
@@ -662,11 +666,13 @@ TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
 TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
 
 /*!
- * \brief Calculates std::nearbyint(x)
+ * \brief Round x to the nearest integer, ties to even.
+ *
+ * Equivalent to round(). Both use IEEE 754 default rounding mode 
(ties-to-even).
+ *
  * \param x The input expression.
  * \param span The location of this operation in the source.
  * \return The result expression.
- * This is a faster alternate to round.
  */
 TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
 
diff --git a/python/tvm/topi/testing/roi_pool_python.py 
b/python/tvm/topi/testing/roi_pool_python.py
index 0f7120b466..583800e982 100644
--- a/python/tvm/topi/testing/roi_pool_python.py
+++ b/python/tvm/topi/testing/roi_pool_python.py
@@ -36,10 +36,12 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size, 
spatial_scale):
     for i in range(num_roi):
         roi = rois_np[i]
         batch_index = int(roi[0])
-        roi_start_w = round(roi[1] * spatial_scale)
-        roi_start_h = round(roi[2] * spatial_scale)
-        roi_end_w = round(roi[3] * spatial_scale)
-        roi_end_h = round(roi[4] * spatial_scale)
+        # Use ties-away-from-zero rounding to match ONNX runtime (std::round 
semantics).
+        # Python's built-in round() uses ties-to-even, so use floor(x + 0.5) 
explicitly.
+        roi_start_w = math.floor(roi[1] * spatial_scale + 0.5)
+        roi_start_h = math.floor(roi[2] * spatial_scale + 0.5)
+        roi_end_w = math.floor(roi[3] * spatial_scale + 0.5)
+        roi_end_h = math.floor(roi[4] * spatial_scale + 0.5)
         roi_h = max(roi_end_h - roi_start_h + 1, 1)
         roi_w = max(roi_end_w - roi_start_w + 1, 1)
 
diff --git a/python/tvm/topi/vision/roi_pool.py 
b/python/tvm/topi/vision/roi_pool.py
index 54a4aeba50..2e86066c5b 100644
--- a/python/tvm/topi/vision/roi_pool.py
+++ b/python/tvm/topi/vision/roi_pool.py
@@ -36,12 +36,19 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
 
     neg_inf = tvm.tirx.const(float("-inf"), data.dtype)
 
+    def _round_away(x):
+        # ONNX MaxRoiPool spec uses ties-away-from-zero rounding for coordinate
+        # mapping (matching std::round semantics in the reference 
implementation).
+        # Use floor(x + 0.5) to be explicit and independent of tir.round 
semantics.
+        half = tvm.tirx.const(0.5, roi_dtype)
+        return te.floor(x + half)
+
     def _bin_bounds(i, ph, pw):
         roi = rois[i]
-        roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
-        roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
-        roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
-        roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
+        roi_start_w = _round_away(roi[1] * spatial_scale).astype("int32")
+        roi_start_h = _round_away(roi[2] * spatial_scale).astype("int32")
+        roi_end_w = _round_away(roi[3] * spatial_scale).astype("int32")
+        roi_end_h = _round_away(roi[4] * spatial_scale).astype("int32")
 
         roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))
         roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))
diff --git a/src/target/llvm/intrin_rule_hexagon.cc 
b/src/target/llvm/intrin_rule_hexagon.cc
index 79e91c20a3..e330dba4e1 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -93,7 +93,7 @@ TVM_REGISTER_OP("tirx.fabs")
 
 TVM_REGISTER_OP("tirx.round")
     .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
-                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
 
 TVM_REGISTER_OP("tirx.ctpop")
     .set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
diff --git a/src/target/llvm/intrin_rule_llvm.cc 
b/src/target/llvm/intrin_rule_llvm.cc
index 468f0fb7b5..3244deab87 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -90,7 +90,7 @@ TVM_REGISTER_OP("tirx.fabs")
 
 TVM_REGISTER_OP("tirx.round")
     .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
-                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
 
 TVM_REGISTER_OP("tirx.nearbyint")
     .set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
diff --git a/src/target/llvm/intrin_rule_nvptx.cc 
b/src/target/llvm/intrin_rule_nvptx.cc
index 4560205a60..0707a9a787 100644
--- a/src/target/llvm/intrin_rule_nvptx.cc
+++ b/src/target/llvm/intrin_rule_nvptx.cc
@@ -66,7 +66,15 @@ TVM_REGISTER_OP("tirx.ceil")
     .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", 
DispatchPureExternLibDevice);
 
 TVM_REGISTER_OP("tirx.round")
-    .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", 
DispatchPureExternLibDevice);
+    .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", [](const PrimExpr& e) 
-> PrimExpr {
+      // Redirect to nearbyint (ties-to-even) to match constant-folding 
semantics.
+      using namespace tirx;
+      const CallNode* call = e.as<CallNode>();
+      TVM_FFI_ICHECK(call != nullptr);
+      auto nearbyint_op = Op::Get("tirx.nearbyint");
+      auto new_call = Call(call->dtype, nearbyint_op, call->args);
+      return DispatchPureExternLibDevice(new_call);
+    });
 
 TVM_REGISTER_OP("tirx.nearbyint")
     .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", 
DispatchPureExternLibDevice);
diff --git a/src/target/llvm/intrin_rule_rocm.cc 
b/src/target/llvm/intrin_rule_rocm.cc
index 6d72c77783..4d542c1299 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -132,7 +132,7 @@ TVM_REGISTER_OP("tirx.ceil")
 
 TVM_REGISTER_OP("tirx.round")
     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+                               
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
 
 TVM_REGISTER_OP("tirx.nearbyint")
     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
diff --git a/src/target/source/codegen_opencl.cc 
b/src/target/source/codegen_opencl.cc
index 5d9135ef22..b2f78c2dbd 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -526,7 +526,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, 
std::ostream& os) {
       this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), 
"atomic_add_float_emu", op->args,
                             true, os);
     } else if (func->value == "nearbyint") {
-      this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "round", 
op->args, true, os);
+      this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "rint", 
op->args, true, os);
     } else {
       if (func->value == "atomic_add") {
         enable_atomics_ = true;
diff --git a/src/target/source/intrin_rule_cuda.cc 
b/src/target/source/intrin_rule_cuda.cc
index bcd158432b..d38db9fe83 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -37,8 +37,11 @@ struct CUDAMath {
     if (t.is_float()) {
       switch (t.bits()) {
         case 64:
+          // Use nearbyint (ties-to-even) for round to match constant-folding 
semantics.
+          if (name == "round") return "nearbyint";
           return name;
         case 32:
+          if (name == "round") return "nearbyintf";
           return name + 'f';
         case 16: {
           if (name == "fabs") {
diff --git a/src/target/source/intrin_rule_metal.cc 
b/src/target/source/intrin_rule_metal.cc
index d61bf1256f..cea19519ca 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -68,7 +68,16 @@ TVM_REGISTER_OP("tirx.fabs")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
 
 TVM_REGISTER_OP("tirx.round")
-    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", [](const PrimExpr& e) 
-> PrimExpr {
+      // Metal's rint() uses ties-to-even, matching constant-folding semantics.
+      const tirx::CallNode* call = e.as<tirx::CallNode>();
+      TVM_FFI_ICHECK(call != nullptr);
+      ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
+      for (auto arg : call->args) {
+        new_args.push_back(arg);
+      }
+      return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), 
new_args);
+    });
 
 TVM_REGISTER_OP("tirx.nearbyint")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
diff --git a/src/target/source/intrin_rule_opencl.cc 
b/src/target/source/intrin_rule_opencl.cc
index 85084b1a16..ba1873bde6 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -47,7 +47,16 @@ TVM_REGISTER_OP("tirx.fabs")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
 
 TVM_REGISTER_OP("tirx.round")
-    .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
+    .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", [](const PrimExpr& e) 
-> PrimExpr {
+      // OpenCL's rint() uses ties-to-even, matching constant-folding 
semantics.
+      const tirx::CallNode* call = e.as<tirx::CallNode>();
+      TVM_FFI_ICHECK(call != nullptr);
+      ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
+      for (auto arg : call->args) {
+        new_args.push_back(arg);
+      }
+      return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), 
new_args);
+    });
 
 TVM_REGISTER_OP("tirx.nearbyint")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", 
DispatchPureExtern<Direct>);
diff --git a/src/target/spirv/intrin_rule_spirv.cc 
b/src/target/spirv/intrin_rule_spirv.cc
index cde1e0165f..4b1ffc4b6d 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -68,10 +68,12 @@ TVM_REGISTER_OP("tirx.ceil")
     .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", 
DispatchGLSLPureIntrin<GLSLstd450Ceil>);
 
 TVM_REGISTER_OP("tirx.round")
-    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", 
DispatchGLSLPureIntrin<GLSLstd450Round>);
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+                               DispatchGLSLPureIntrin<GLSLstd450RoundEven>);
 
 TVM_REGISTER_OP("tirx.nearbyint")
-    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", 
DispatchGLSLPureIntrin<GLSLstd450Round>);
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+                               DispatchGLSLPureIntrin<GLSLstd450RoundEven>);
 
 TVM_REGISTER_OP("tirx.trunc")
     .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", 
DispatchGLSLPureIntrin<GLSLstd450Trunc>);
diff --git a/tests/python/tirx-base/test_tir_intrin.py 
b/tests/python/tirx-base/test_tir_intrin.py
index 0dd06dee93..30676715b8 100644
--- a/tests/python/tirx-base/test_tir_intrin.py
+++ b/tests/python/tirx-base/test_tir_intrin.py
@@ -56,6 +56,27 @@ def test_nearbyint():
     tvm.testing.assert_allclose(a_rounded.numpy(), np.rint(a.numpy()))
 
 
+def test_round_ties_to_even():
+    """Test that tir.round uses ties-to-even (banker's rounding) semantics."""
+    m = te.var("m")
+    A = te.placeholder((m,), name="A")
+    A_rounded = te.compute((m,), lambda *i: tvm.tirx.round(A(*i)), name="A")
+
+    mod = te.create_prim_func([A, A_rounded])
+    sch = tvm.s_tir.Schedule(mod)
+    func = tvm.compile(sch.mod, target="llvm")
+
+    dev = tvm.cpu(0)
+    # Midpoint values where ties-to-even and ties-away differ
+    test_values = np.array([0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5], 
dtype="float32")
+    expected = np.array([0.0, 2.0, 2.0, 4.0, 0.0, -2.0, -2.0, -4.0], 
dtype="float32")
+
+    a = tvm.runtime.tensor(test_values, dev)
+    a_rounded = tvm.runtime.tensor(np.zeros(len(test_values), 
dtype="float32"), dev)
+    func(a, a_rounded)
+    tvm.testing.assert_allclose(a_rounded.numpy(), expected)
+
+
 def test_round_intrinsics_on_int():
     i = tvm.tirx.Var("i", "int32")
     for op in [tvm.tirx.round, tvm.tirx.trunc, tvm.tirx.ceil, tvm.tirx.floor, 
tvm.tirx.nearbyint]:

Reply via email to