This is an automated email from the ASF dual-hosted git repository.

richardstartin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 7cb1473b8f fix regression where case order was reversed, add regression test (#8748)
7cb1473b8f is described below

commit 7cb1473b8fa8aa54b9d1ab5cef8a51e19a2acedf
Author: Richard Startin <rich...@startree.ai>
AuthorDate: Fri May 20 21:24:59 2022 +0200

    fix regression where case order was reversed, add regression test (#8748)
---
 .../transform/function/CaseTransformFunction.java  | 15 +++--
 .../function/CaseTransformFunctionTest.java        | 68 +++++++++++++++++++---
 2 files changed, 70 insertions(+), 13 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
index c99b0c4d9d..641107bc1b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java
@@ -22,6 +22,7 @@ import com.google.common.base.Preconditions;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import org.apache.pinot.core.operator.blocks.ProjectionBlock;
@@ -86,13 +87,13 @@ public class CaseTransformFunction extends BaseTransformFunction {
     for (int i = 0; i < numWhenStatements; i++) {
       _whenStatements.add(arguments.get(i));
     }
-    // Add ELSE Statement first
     _elseThenStatements = new ArrayList<>(numWhenStatements + 1);
-    _elseThenStatements.add(arguments.get(numWhenStatements * 2));
-    for (int i = numWhenStatements; i < numWhenStatements * 2; i++) {
+    for (int i = numWhenStatements; i < numWhenStatements * 2 + 1; i++) {
       _elseThenStatements.add(arguments.get(i));
     }
     _selections = new boolean[_elseThenStatements.size()];
+    Collections.reverse(_elseThenStatements);
+    Collections.reverse(_whenStatements);
     _resultMetadata = calculateResultMetadata();
   }
 
@@ -212,7 +213,8 @@ public class CaseTransformFunction extends BaseTransformFunction {
 
   /**
    * Evaluate the ProjectionBlock for the WHEN statements, returns an array with the
-   * index(1 to N) of matched WHEN clause, 0 means nothing matched, so go to ELSE.
+   * index(1 to N) of matched WHEN clause ordered by match priority, 0 means nothing
+   * matched, so go to ELSE.
    */
   private int[] getSelectedArray(ProjectionBlock projectionBlock) {
     int numDocs = projectionBlock.getNumDocs();
@@ -228,9 +230,12 @@ public class CaseTransformFunction extends BaseTransformFunction {
       int[] conditions = whenStatement.transformToIntValuesSV(projectionBlock);
       for (int j = 0; j < numDocs & j < conditions.length; j++) {
         _selectedResults[j] = Math.max(conditions[j] * (i + 1), _selectedResults[j]);
-        _selections[_selectedResults[j]] = true;
       }
     }
+    // try to prune clauses now
+    for (int i = 0; i < numDocs; i++) {
+      _selections[_selectedResults[i]] = true;
+    }
     int numSelections = 0;
     for (boolean selection : _selections) {
       if (selection) {
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
index b7f1e0534e..a9e02d3a26 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java
@@ -21,13 +21,17 @@ package org.apache.pinot.core.operator.transform.function;
 import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.Random;
+import java.util.stream.Stream;
 import org.apache.pinot.common.function.TransformFunctionType;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.RequestContextUtils;
 import org.apache.pinot.spi.data.FieldSpec.DataType;
 import org.testng.Assert;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import static org.testng.Assert.assertEquals;
+
 
 public class CaseTransformFunctionTest extends BaseTransformFunctionTest {
   private static final int INDEX_TO_COMPARE = new 
Random(System.currentTimeMillis()).nextInt(NUM_ROWS);
@@ -37,6 +41,54 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
       TransformFunctionType.LESS_THAN_OR_EQUAL
   };
 
+  @DataProvider
+  public Object[][] params() {
+    return Stream.of(INT_SV_COLUMN, LONG_SV_COLUMN, FLOAT_SV_COLUMN, 
DOUBLE_SV_COLUMN)
+        .flatMap(col -> Stream.of(
+            new int[] {3, 2, 1},
+            new int[] {1, 2, 3},
+            new int[] {Integer.MAX_VALUE / 2, Integer.MAX_VALUE / 4, 0},
+            new int[] {0, Integer.MAX_VALUE / 4, Integer.MAX_VALUE / 2},
+            new int[] {0, Integer.MIN_VALUE / 4, Integer.MIN_VALUE},
+            new int[] {Integer.MIN_VALUE, 0, 1},
+            new int[] {Integer.MAX_VALUE, Integer.MIN_VALUE, 1},
+            new int[] {Integer.MAX_VALUE, Integer.MAX_VALUE - 1, 
Integer.MAX_VALUE - 2}
+        ).map(thresholds -> new Object[]{col, thresholds[0], thresholds[1], 
thresholds[2]}))
+        .toArray(Object[][]::new);
+  }
+
+  @Test(dataProvider = "params")
+  public void testCasePriorityObserved(String column, int threshold1, int 
threshold2, int threshold3) {
+    String statement = String.format("CASE WHEN %s > %d THEN 3 WHEN %s > %d 
THEN 2 WHEN %s > %d THEN 1 ELSE -1 END",
+        column, threshold1, column, threshold2, column, threshold3);
+    ExpressionContext expression = 
RequestContextUtils.getExpression(statement);
+    TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
+    int[] expectedIntResults = new int[NUM_ROWS];
+    for (int i = 0; i < expectedIntResults.length; i++) {
+      switch (column) {
+        case INT_SV_COLUMN:
+          expectedIntResults[i] = _intSVValues[i] > threshold1 ? 3
+              : _intSVValues[i] > threshold2 ? 2 : _intSVValues[i] > 
threshold3 ? 1 : -1;
+          break;
+        case LONG_SV_COLUMN:
+          expectedIntResults[i] = _longSVValues[i] > threshold1 ? 3
+              : _longSVValues[i] > threshold2 ? 2 : _longSVValues[i] > 
threshold3 ? 1 : -1;
+          break;
+        case FLOAT_SV_COLUMN:
+          expectedIntResults[i] = _floatSVValues[i] > threshold1 ? 3
+              : _floatSVValues[i] > threshold2 ? 2 : _floatSVValues[i] > 
threshold3 ? 1 : -1;
+          break;
+        case DOUBLE_SV_COLUMN:
+          expectedIntResults[i] = _doubleSVValues[i] > threshold1 ? 3
+              : _doubleSVValues[i] > threshold2 ? 2 : _doubleSVValues[i] > 
threshold3 ? 1 : -1;
+          break;
+        default:
+      }
+    }
+    int[] intValues = 
transformFunction.transformToIntValuesSV(_projectionBlock);
+    assertEquals(expectedIntResults, intValues);
+  }
+
   @Test
   public void testCaseTransformFunctionWithIntResults() {
     int[] expectedIntResults = new int[NUM_ROWS];
@@ -149,8 +201,8 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
         RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 100 
ELSE 10 END", predicate));
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
-    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.INT);
+    assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.INT);
     testTransformFunction(transformFunction, expectedValues);
   }
 
@@ -159,8 +211,8 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
         RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 
100.0 ELSE 10.0 END", predicate));
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
-    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.FLOAT);
+    assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.FLOAT);
     testTransformFunction(transformFunction, expectedValues);
   }
 
@@ -170,8 +222,8 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
         String.format("CASE WHEN %s THEN '100.99887766554433221' ELSE 
'10.1122334455667788909' END", predicate));
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
-    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.BIG_DECIMAL);
+    assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.BIG_DECIMAL);
     testTransformFunction(transformFunction, expectedValues);
   }
 
@@ -180,8 +232,8 @@ public class CaseTransformFunctionTest extends 
BaseTransformFunctionTest {
         RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 
'aaa' ELSE 'bbb' END", predicate));
     TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CaseTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
-    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.STRING);
+    assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME);
+    assertEquals(transformFunction.getResultMetadata().getDataType(), 
DataType.STRING);
     testTransformFunction(transformFunction, expectedValues);
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to