This is an automated email from the ASF dual-hosted git repository.
yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 56684e980f2 Handle NaN and +/- Infinity in ROUND_DECIMAL (#16993)
56684e980f2 is described below
commit 56684e980f2d9fd55d2b75ed2f3f7cc48dc179b1
Author: Yash Mayya <[email protected]>
AuthorDate: Fri Oct 10 09:09:13 2025 -0700
Handle NaN and +/- Infinity in ROUND_DECIMAL (#16993)
---
.../function/scalar/ArithmeticFunctions.java | 10 +++++++++
.../function/RoundDecimalTransformFunction.java | 15 +++++++++++---
.../data/function/ArithmeticFunctionsTest.java | 24 ++++++++++++++++++++++
.../RoundDecimalTransformFunctionTest.java | 20 ++++++++++++++++++
4 files changed, 66 insertions(+), 3 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
index d27a3fa6ccc..6412c9c22f2 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
@@ -169,6 +169,10 @@ public class ArithmeticFunctions {
// when multiplying by Math.pow(10, scale) for rounding
@ScalarFunction
public static double roundDecimal(double a, int scale) {
+ if (Double.isNaN(a) || Double.isInfinite(a)) {
+ // Follow standard PostgreSQL behavior where NaN and +/- Inf are
returned as is
+ return a;
+ }
return BigDecimal.valueOf(a).setScale(scale,
RoundingMode.HALF_UP).doubleValue();
}
@@ -176,6 +180,12 @@ public class ArithmeticFunctions {
// but it is not possible because of existing DateTimeFunction with same
name.
@ScalarFunction
public static double roundDecimal(double a) {
+ if (Double.isNaN(a) || Double.isInfinite(a)) {
+ // Math.round has special handling for NaN and +/- Inf:
+ // NaN -> 0, -Inf -> Long.MIN_VALUE, +Inf -> Long.MAX_VALUE
+ // Follow standard PostgreSQL behavior where NaN and +/- Inf are
returned as is
+ return a;
+ }
return Math.round(a);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
index 55bc42d5ba0..56ae627ab6c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunction.java
@@ -88,9 +88,14 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
int length = valueBlock.getNumDocs();
initDoubleValuesSV(length);
double[] leftValues =
_leftTransformFunction.transformToDoubleValuesSV(valueBlock);
+ // Follow standard PostgreSQL behavior where NaN and +/- Inf are returned
as is
if (_fixedScale) {
for (int i = 0; i < length; i++) {
double value = leftValues[i];
+ if (Double.isNaN(value) || Double.isInfinite(value)) {
+ _doubleValuesSV[i] = value;
+ continue;
+ }
try {
_doubleValuesSV[i] = BigDecimal.valueOf(value).setScale(_scale,
RoundingMode.HALF_UP).doubleValue();
} catch (Exception e) {
@@ -101,6 +106,10 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
int[] rightValues =
_rightTransformFunction.transformToIntValuesSV(valueBlock);
for (int i = 0; i < length; i++) {
double value = leftValues[i];
+ if (Double.isNaN(value) || Double.isInfinite(value)) {
+ _doubleValuesSV[i] = value;
+ continue;
+ }
int scale = rightValues[i];
try {
_doubleValuesSV[i] = BigDecimal.valueOf(value).setScale(scale,
RoundingMode.HALF_UP).doubleValue();
@@ -111,11 +120,11 @@ public class RoundDecimalTransformFunction extends BaseTransformFunction {
} else {
for (int i = 0; i < length; i++) {
double value = leftValues[i];
- if (value == Double.NEGATIVE_INFINITY || value ==
Double.POSITIVE_INFINITY || Double.isNaN(value)) {
+ if (Double.isNaN(value) || Double.isInfinite(value)) {
_doubleValuesSV[i] = value;
- } else {
- _doubleValuesSV[i] = Math.round(value);
+ continue;
}
+ _doubleValuesSV[i] = Math.round(value);
}
}
return _doubleValuesSV;
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
index 7d8d022d20d..226e9d56323 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/ArithmeticFunctionsTest.java
@@ -327,6 +327,30 @@ public class ArithmeticFunctionsTest {
inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"),
row, 9.46});
inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"),
row, 9.46});
}
+ {
+ GenericRow row = new GenericRow();
+ row.putValue("a", Double.NEGATIVE_INFINITY);
+ inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row,
Double.NEGATIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"),
row, Double.NEGATIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"),
row, Double.NEGATIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"),
row, Double.NEGATIVE_INFINITY});
+ }
+ {
+ GenericRow row = new GenericRow();
+ row.putValue("a", Double.POSITIVE_INFINITY);
+ inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row,
Double.POSITIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"),
row, Double.POSITIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"),
row, Double.POSITIVE_INFINITY});
+ inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"),
row, Double.POSITIVE_INFINITY});
+ }
+ {
+ GenericRow row = new GenericRow();
+ row.putValue("a", Double.NaN);
+ inputs.add(new Object[]{"roundDecimal(a)", Lists.newArrayList("a"), row,
Double.NaN});
+ inputs.add(new Object[]{"roundDecimal(a, 1)", Lists.newArrayList("a"),
row, Double.NaN});
+ inputs.add(new Object[]{"roundDecimal(a, 2)", Lists.newArrayList("a"),
row, Double.NaN});
+ inputs.add(new Object[]{"roundDecimal(a, 3)", Lists.newArrayList("a"),
row, Double.NaN});
+ }
// test truncate
{
GenericRow row = new GenericRow();
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
index 3a2ab38e2d3..577a1c4812a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/RoundDecimalTransformFunctionTest.java
@@ -98,4 +98,24 @@ public class RoundDecimalTransformFunctionTest extends BaseTransformFunctionTest
}
testTransformFunctionWithNull(transformFunction, expectedValues,
roaringBitmap);
}
+
+ @Test
+ public void testRoundDecimalNaNAndInfinity() {
+ ExpressionContext expression =
+ RequestContextUtils.getExpression(String.format("round_decimal(%s /
0)", INT_SV_COLUMN));
+ TransformFunction transformFunction =
TransformFunctionFactory.get(expression, _dataSourceMap);
+ Assert.assertTrue(transformFunction instanceof
RoundDecimalTransformFunction);
+ Assert.assertEquals(transformFunction.getName(),
TransformFunctionType.ROUND_DECIMAL.getName());
+ double[] expectedValues = new double[NUM_ROWS];
+ for (int i = 0; i < NUM_ROWS; i++) {
+ if (_intSVValues[i] < 0) {
+ expectedValues[i] = Double.NEGATIVE_INFINITY;
+ } else if (_intSVValues[i] == 0) {
+ expectedValues[i] = Double.NaN;
+ } else {
+ expectedValues[i] = Double.POSITIVE_INFINITY;
+ }
+ }
+ testTransformFunction(transformFunction, expectedValues);
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]