This is an automated email from the ASF dual-hosted git repository. richardstartin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 7cb1473b8f fix regression where case order was reversed, add regression test (#8748) 7cb1473b8f is described below commit 7cb1473b8fa8aa54b9d1ab5cef8a51e19a2acedf Author: Richard Startin <rich...@startree.ai> AuthorDate: Fri May 20 21:24:59 2022 +0200 fix regression where case order was reversed, add regression test (#8748) --- .../transform/function/CaseTransformFunction.java | 15 +++-- .../function/CaseTransformFunctionTest.java | 68 +++++++++++++++++++--- 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java index c99b0c4d9d..641107bc1b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunction.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.pinot.core.operator.blocks.ProjectionBlock; @@ -86,13 +87,13 @@ public class CaseTransformFunction extends BaseTransformFunction { for (int i = 0; i < numWhenStatements; i++) { _whenStatements.add(arguments.get(i)); } - // Add ELSE Statement first _elseThenStatements = new ArrayList<>(numWhenStatements + 1); - _elseThenStatements.add(arguments.get(numWhenStatements * 2)); - for (int i = numWhenStatements; i < numWhenStatements * 2; i++) { + for (int i = numWhenStatements; i < numWhenStatements * 2 + 1; i++) { _elseThenStatements.add(arguments.get(i)); } _selections = new boolean[_elseThenStatements.size()]; + Collections.reverse(_elseThenStatements); + Collections.reverse(_whenStatements); 
_resultMetadata = calculateResultMetadata(); } @@ -212,7 +213,8 @@ public class CaseTransformFunction extends BaseTransformFunction { /** * Evaluate the ProjectionBlock for the WHEN statements, returns an array with the - * index(1 to N) of matched WHEN clause, 0 means nothing matched, so go to ELSE. + * index(1 to N) of matched WHEN clause ordered by match priority, 0 means nothing + * matched, so go to ELSE. */ private int[] getSelectedArray(ProjectionBlock projectionBlock) { int numDocs = projectionBlock.getNumDocs(); @@ -228,9 +230,12 @@ public class CaseTransformFunction extends BaseTransformFunction { int[] conditions = whenStatement.transformToIntValuesSV(projectionBlock); for (int j = 0; j < numDocs & j < conditions.length; j++) { _selectedResults[j] = Math.max(conditions[j] * (i + 1), _selectedResults[j]); - _selections[_selectedResults[j]] = true; } } + // try to prune clauses now + for (int i = 0; i < numDocs; i++) { + _selections[_selectedResults[i]] = true; + } int numSelections = 0; for (boolean selection : _selections) { if (selection) { diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java index b7f1e0534e..a9e02d3a26 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CaseTransformFunctionTest.java @@ -21,13 +21,17 @@ package org.apache.pinot.core.operator.transform.function; import java.math.BigDecimal; import java.util.Arrays; import java.util.Random; +import java.util.stream.Stream; import org.apache.pinot.common.function.TransformFunctionType; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.RequestContextUtils; import org.apache.pinot.spi.data.FieldSpec.DataType; 
import org.testng.Assert; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; + public class CaseTransformFunctionTest extends BaseTransformFunctionTest { private static final int INDEX_TO_COMPARE = new Random(System.currentTimeMillis()).nextInt(NUM_ROWS); @@ -37,6 +41,54 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { TransformFunctionType.LESS_THAN_OR_EQUAL }; + @DataProvider + public Object[][] params() { + return Stream.of(INT_SV_COLUMN, LONG_SV_COLUMN, FLOAT_SV_COLUMN, DOUBLE_SV_COLUMN) + .flatMap(col -> Stream.of( + new int[] {3, 2, 1}, + new int[] {1, 2, 3}, + new int[] {Integer.MAX_VALUE / 2, Integer.MAX_VALUE / 4, 0}, + new int[] {0, Integer.MAX_VALUE / 4, Integer.MAX_VALUE / 2}, + new int[] {0, Integer.MIN_VALUE / 4, Integer.MIN_VALUE}, + new int[] {Integer.MIN_VALUE, 0, 1}, + new int[] {Integer.MAX_VALUE, Integer.MIN_VALUE, 1}, + new int[] {Integer.MAX_VALUE, Integer.MAX_VALUE - 1, Integer.MAX_VALUE - 2} + ).map(thresholds -> new Object[]{col, thresholds[0], thresholds[1], thresholds[2]})) + .toArray(Object[][]::new); + } + + @Test(dataProvider = "params") + public void testCasePriorityObserved(String column, int threshold1, int threshold2, int threshold3) { + String statement = String.format("CASE WHEN %s > %d THEN 3 WHEN %s > %d THEN 2 WHEN %s > %d THEN 1 ELSE -1 END", + column, threshold1, column, threshold2, column, threshold3); + ExpressionContext expression = RequestContextUtils.getExpression(statement); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + int[] expectedIntResults = new int[NUM_ROWS]; + for (int i = 0; i < expectedIntResults.length; i++) { + switch (column) { + case INT_SV_COLUMN: + expectedIntResults[i] = _intSVValues[i] > threshold1 ? 3 + : _intSVValues[i] > threshold2 ? 2 : _intSVValues[i] > threshold3 ? 
1 : -1; + break; + case LONG_SV_COLUMN: + expectedIntResults[i] = _longSVValues[i] > threshold1 ? 3 + : _longSVValues[i] > threshold2 ? 2 : _longSVValues[i] > threshold3 ? 1 : -1; + break; + case FLOAT_SV_COLUMN: + expectedIntResults[i] = _floatSVValues[i] > threshold1 ? 3 + : _floatSVValues[i] > threshold2 ? 2 : _floatSVValues[i] > threshold3 ? 1 : -1; + break; + case DOUBLE_SV_COLUMN: + expectedIntResults[i] = _doubleSVValues[i] > threshold1 ? 3 + : _doubleSVValues[i] > threshold2 ? 2 : _doubleSVValues[i] > threshold3 ? 1 : -1; + break; + default: + } + } + int[] intValues = transformFunction.transformToIntValuesSV(_projectionBlock); + assertEquals(expectedIntResults, intValues); + } + @Test public void testCaseTransformFunctionWithIntResults() { int[] expectedIntResults = new int[NUM_ROWS]; @@ -149,8 +201,8 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 100 ELSE 10 END", predicate)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); - Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); - Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); testTransformFunction(transformFunction, expectedValues); } @@ -159,8 +211,8 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 100.0 ELSE 10.0 END", predicate)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); - Assert.assertEquals(transformFunction.getName(), 
CaseTransformFunction.FUNCTION_NAME); - Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.FLOAT); + assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.FLOAT); testTransformFunction(transformFunction, expectedValues); } @@ -170,8 +222,8 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { String.format("CASE WHEN %s THEN '100.99887766554433221' ELSE '10.1122334455667788909' END", predicate)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); - Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); - Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.BIG_DECIMAL); + assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.BIG_DECIMAL); testTransformFunction(transformFunction, expectedValues); } @@ -180,8 +232,8 @@ public class CaseTransformFunctionTest extends BaseTransformFunctionTest { RequestContextUtils.getExpression(String.format("CASE WHEN %s THEN 'aaa' ELSE 'bbb' END", predicate)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CaseTransformFunction); - Assert.assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); - Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertEquals(transformFunction.getName(), CaseTransformFunction.FUNCTION_NAME); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); testTransformFunction(transformFunction, expectedValues); } --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org For additional commands, e-mail: commits-help@pinot.apache.org