This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 56684e980f2 Handle NaN and +/- Infinity in ROUND_DECIMAL (#16993)
56684e980f2 is described below

commit 56684e980f2d9fd55d2b75ed2f3f7cc48dc179b1
Author: Yash Mayya <[email protected]>
AuthorDate: Fri Oct 10 09:09:13 2025 -0700

    Handle NaN and +/- Infinity in ROUND_DECIMAL (#16993)
---
 .../function/scalar/ArithmeticFunctions.java       | 10 +++++++++
 .../function/RoundDecimalTransformFunction.java    | 15 +++++++++++---
 .../data/function/ArithmeticFunctionsTest.java     | 24 ++++++++++++++++++++++
 .../RoundDecimalTransformFunctionTest.java         | 20 ++++++++++++++++++
 4 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
index d27a3fa6ccc..6412c9c22f2 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
@@ -169,6 +169,10 @@ public class ArithmeticFunctions {
   // when multiplying by Math.pow(10, scale) for rounding
   @ScalarFunction
   public static double roundDecimal(double a, int scale) {
+    if (Double.isNaN(a) || Double.isInfinite(a)) {
+      // Follow standard PostgreSQL behavior where NaN and +/- Inf are returned as is
+      return a;
+    }
     return BigDecimal.valueOf(a).setScale(scale, RoundingMode.HALF_UP).doubleValue();
   }
 
@@ -176,6 +180,12 @@ public class ArithmeticFunctions {
   // but it is not possible because of existing DateTimeFunction with same name.
   @ScalarFunction
   public static double roundDecimal(double a) {
+    if (Double.isNaN(a) || Double.isInfinite(a)) {
+      // Math.round has special handling for NaN and +/- Inf:
+      // NaN -> 0, -Inf -> Long.MIN_VALUE, +Inf -> Long.MAX_VALUE
+      // Follow standard PostgreSQL behavior where NaN and +/- Inf are returned as is
+      return a;
+    }
     return Math.round(a);
   }
 
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
index 55bc42d5ba0..56ae627ab6c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
@@ -88,9 +88,14 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
     int length = valueBlock.getNumDocs();
     initDoubleValuesSV(length);
     double[] leftValues = _leftTransformFunction.transformToDoubleValuesSV(valueBlock);
+    // Follow standard PostgreSQL behavior where NaN and +/- Inf are returned as is
     if (_fixedScale) {
       for (int i = 0; i < length; i++) {
         double value = leftValues[i];
+        if (Double.isNaN(value) || Double.isInfinite(value)) {
+          _doubleValuesSV[i] = value;
+          continue;
+        }
         try {
           _doubleValuesSV[i] = BigDecimal.valueOf(value).setScale(_scale, RoundingMode.HALF_UP).doubleValue();
         } catch (Exception e) {
@@ -101,6 +106,10 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
       int[] rightValues = _rightTransformFunction.transformToIntValuesSV(valueBlock);
       for (int i = 0; i < length; i++) {
         double value = leftValues[i];
+        if (Double.isNaN(value) || Double.isInfinite(value)) {
+          _doubleValuesSV[i] = value;
+          continue;
+        }
         int scale = rightValues[i];
         try {
           _doubleValuesSV[i] = BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP).doubleValue();
@@ -111,11 +120,11 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
     } else {
       for (int i = 0; i < length; i++) {
         double value = leftValues[i];
-        if (value == Double.NEGATIVE_INFINITY || value == Double.POSITIVE_INFINITY || Double.isNaN(value)) {
+        if (Double.isNaN(value) || Double.isInfinite(value)) {
           _doubleValuesSV[i] = value;
-        } else {
-          _doubleValuesSV[i] = Math.round(value);
+          continue;
         }
+        _doubleValuesSV[i] = Math.round(value);
       }
     }
     return _doubleValuesSV;
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
index 7d8d022d20d..226e9d56323 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
@@ -327,6 +327,30 @@ public class ArithmeticFunctionsTest {
       inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"), row, 9.46});
       inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"), row, 9.46});
     }
+    {
+      GenericRow row = new GenericRow();
+      row.putValue("a", Double.NEGATIVE_INFINITY);
+      inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row, Double.NEGATIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"), row, Double.NEGATIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"), row, Double.NEGATIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"), row, Double.NEGATIVE_INFINITY});
+    }
+    {
+      GenericRow row = new GenericRow();
+      row.putValue("a", Double.POSITIVE_INFINITY);
+      inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row, Double.POSITIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"), row, Double.POSITIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"), row, Double.POSITIVE_INFINITY});
+      inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"), row, Double.POSITIVE_INFINITY});
+    }
+    {
+      GenericRow row = new GenericRow();
+      row.putValue("a", Double.NaN);
+      inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row, Double.NaN});
+      inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"), row, Double.NaN});
+      inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"), row, Double.NaN});
+      inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"), row, Double.NaN});
+    }
     // test truncate
     {
       GenericRow row = new GenericRow();
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
index 3a2ab38e2d3..577a1c4812a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
@@ -98,4 +98,24 @@ public class RoundDecimalTransformFunctionTest extends BaseTransformFunctionTest
     }
     testTransformFunctionWithNull(transformFunction, expectedValues, roaringBitmap);
   }
+
+  @Test
+  public void testRoundDecimalNaNAndInfinity() {
+    ExpressionContext expression =
+        RequestContextUtils.getExpression(String.format("round_decimal(%s / 0)", INT_SV_COLUMN));
+    TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
+    Assert.assertTrue(transformFunction instanceof RoundDecimalTransformFunction);
+    Assert.assertEquals(transformFunction.getName(), TransformFunctionType.ROUND_DECIMAL.getName());
+    double[] expectedValues = new double[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      if (_intSVValues[i] < 0) {
+        expectedValues[i] = Double.NEGATIVE_INFINITY;
+      } else if (_intSVValues[i] == 0) {
+        expectedValues[i] = Double.NaN;
+      } else {
+        expectedValues[i] = Double.POSITIVE_INFINITY;
+      }
+    }
+    testTransformFunction(transformFunction, expectedValues);
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to