This is an automated email from the ASF dual-hosted git repository. richardstartin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 87fb007f7d add scalar function for cast so it can be calculated at compile time (#8535) 87fb007f7d is described below commit 87fb007f7ddbeec5412b526f94ecd3bbd8fe2e08 Author: Richard Startin <rich...@startree.ai> AuthorDate: Sat Apr 16 00:53:10 2022 +0100 add scalar function for cast so it can be calculated at compile time (#8535) --- config/checkstyle.xml | 1 + .../scalar/DataTypeConversionFunctions.java | 34 ++++ .../scalar/DataTypeConversionFunctionsTest.java | 63 ++++++++ .../pinot/sql/parsers/CalciteSqlCompilerTest.java | 30 +--- .../function/CastTransformFunctionTest.java | 23 ++- .../org/apache/pinot/queries/CastQueriesTest.java | 171 +++++++++++++++++++++ 6 files changed, 295 insertions(+), 27 deletions(-) diff --git a/config/checkstyle.xml b/config/checkstyle.xml index 8c7242ca8b..86e163f50c 100644 --- a/config/checkstyle.xml +++ b/config/checkstyle.xml @@ -136,6 +136,7 @@ org.apache.pinot.controller.recommender.rules.io.params.RecommenderConstants.RulesToExecute.*, org.apache.pinot.controller.recommender.rules.utils.PredicateParseResult.*, org.apache.pinot.client.utils.Constants.*, + org.apache.pinot.common.utils.PinotDataType.*, org.apache.pinot.segment.local.startree.StarTreeBuilderUtils.*, org.apache.pinot.segment.local.startree.v2.store.StarTreeIndexMapUtils.*, org.apache.pinot.segment.local.utils.GeometryType.*, diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java index 47c9ce91f5..789cb35cb7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java @@ -18,12 +18,19 @@ */ package org.apache.pinot.common.function.scalar; +import 
com.google.common.base.Preconditions; import java.math.BigDecimal; import java.util.Base64; +import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.BigDecimalUtils; import org.apache.pinot.spi.utils.BytesUtils; +import static org.apache.pinot.common.utils.PinotDataType.DOUBLE; +import static org.apache.pinot.common.utils.PinotDataType.INTEGER; +import static org.apache.pinot.common.utils.PinotDataType.LONG; +import static org.apache.pinot.common.utils.PinotDataType.STRING; + /** * Contains function to convert a datatype to another datatype. @@ -32,6 +39,33 @@ public class DataTypeConversionFunctions { private DataTypeConversionFunctions() { } + @ScalarFunction + public static Object cast(Object value, String targetTypeLiteral) { + try { + Class<?> clazz = value.getClass(); + Preconditions.checkArgument(!clazz.isArray() | clazz == byte[].class, "%s must not be an array type", clazz); + PinotDataType sourceType = PinotDataType.getSingleValueType(clazz); + String transformed = targetTypeLiteral.toUpperCase(); + PinotDataType targetDataType; + if ("INT".equals(transformed)) { + targetDataType = INTEGER; + } else if ("VARCHAR".equals(transformed)) { + targetDataType = STRING; + } else { + targetDataType = PinotDataType.valueOf(transformed); + } + if (sourceType == STRING && (targetDataType == INTEGER || targetDataType == LONG)) { + if (String.valueOf(value).contains(".")) { + // convert integers via double to avoid parse errors + return targetDataType.convert(DOUBLE.convert(value, sourceType), DOUBLE); + } + } + return targetDataType.convert(value, sourceType); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Unknown data type: " + targetTypeLiteral); + } + } + /** * Converts big decimal string representation to bytes. 
* Only scale of upto 2 bytes is supported by the function diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java new file mode 100644 index 0000000000..981e39e2ad --- /dev/null +++ b/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.common.function.scalar; + +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + + +public class DataTypeConversionFunctionsTest { + + @DataProvider(name = "testCases") + public static Object[][] testCases() { + return new Object[][]{ + {"a", "string", "a"}, + {"10", "int", 10}, + {"10", "long", 10L}, + {"10", "float", 10F}, + {"10", "double", 10D}, + {"10.0", "int", 10}, + {"10.0", "long", 10L}, + {"10.0", "float", 10F}, + {"10.0", "double", 10D}, + {10, "string", "10"}, + {10L, "string", "10"}, + {10F, "string", "10.0"}, + {10D, "string", "10.0"}, + {"a", "string", "a"}, + {10, "int", 10}, + {10L, "long", 10L}, + {10F, "float", 10F}, + {10D, "double", 10D}, + {10L, "int", 10}, + {10, "long", 10L}, + {10D, "float", 10F}, + {10F, "double", 10D}, + {"abc1", "bytes", new byte[]{(byte) 0xab, (byte) 0xc1}}, + {new byte[]{(byte) 0xab, (byte) 0xc1}, "string", "abc1"} + }; + } + + @Test(dataProvider = "testCases") + public void test(Object value, String type, Object expected) { + assertEquals(DataTypeConversionFunctions.cast(value, type), expected); + } +} diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java index 83cde06edc..90f0d2604b 100644 --- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java @@ -1611,41 +1611,19 @@ public class CalciteSqlCompilerTest { public void testCastTransformation() { PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery("select CAST(25.65 AS int) from myTable"); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast"); - Assert.assertEquals( - 
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(), 25.65); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), - "INTEGER"); + Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 25); pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST('20170825' AS LONG) from myTable"); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast"); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getStringValue(), - "20170825"); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), "LONG"); + Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 20170825); pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS Float) from myTable"); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast"); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(), - 20170825.0); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), - "FLOAT"); + Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F); pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS dOuble) from myTable"); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); - Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast"); - Assert.assertEquals( - 
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(), - 20170825.0); - Assert.assertEquals( - pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), - "DOUBLE"); + Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F); pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(column1 AS STRING) from myTable"); Assert.assertEquals(pinotQuery.getSelectListSize(), 1); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java index 6f9c985896..19556a3d9e 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java @@ -23,6 +23,9 @@ import org.apache.pinot.common.request.context.RequestContextUtils; import org.testng.Assert; import org.testng.annotations.Test; +import static org.apache.pinot.common.function.scalar.DataTypeConversionFunctions.cast; +import static org.testng.Assert.assertEquals; + public class CastTransformFunctionTest extends BaseTransformFunctionTest { @@ -32,22 +35,28 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest { RequestContextUtils.getExpressionFromSQL(String.format("CAST(%s AS string)", INT_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CastTransformFunction); - Assert.assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME); + assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME); String[] expectedValues = new String[NUM_ROWS]; + String[] scalarStringValues = new 
String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = Integer.toString(_intSVValues[i]); + scalarStringValues[i] = (String) cast(_intSVValues[i], "string"); } testTransformFunction(transformFunction, expectedValues); + assertEquals(expectedValues, scalarStringValues); expression = RequestContextUtils.getExpressionFromSQL(String.format("CAST(CAST(%s as INT) as FLOAT)", FLOAT_SV_COLUMN)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CastTransformFunction); float[] expectedFloatValues = new float[NUM_ROWS]; + float[] scalarFloatValues = new float[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedFloatValues[i] = (int) _floatSVValues[i]; + scalarFloatValues[i] = (float) cast(cast(_floatSVValues[i], "int"), "float"); } testTransformFunction(transformFunction, expectedFloatValues); + assertEquals(expectedFloatValues, scalarFloatValues); expression = RequestContextUtils.getExpressionFromSQL( String.format("CAST(ADD(CAST(%s AS LONG), %s) AS STRING)", DOUBLE_SV_COLUMN, LONG_SV_COLUMN)); @@ -55,18 +64,26 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest { Assert.assertTrue(transformFunction instanceof CastTransformFunction); for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = Double.toString((double) (long) _doubleSVValues[i] + (double) _longSVValues[i]); + scalarStringValues[i] = (String) cast( + (double) (long) cast(_doubleSVValues[i], "long") + (double) _longSVValues[i], "string"); } testTransformFunction(transformFunction, expectedValues); + assertEquals(expectedValues, scalarStringValues); expression = RequestContextUtils.getExpressionFromSQL( String.format("caSt(cAst(casT(%s as inT) + %s aS sTring) As DouBle)", FLOAT_SV_COLUMN, INT_SV_COLUMN)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CastTransformFunction); double[] expectedDoubleValues = new 
double[NUM_ROWS]; + double[] scalarDoubleValues = new double[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedDoubleValues[i] = (double) (int) _floatSVValues[i] + (double) _intSVValues[i]; + scalarDoubleValues[i] = + (double) cast(cast((double) (int) cast(_floatSVValues[i], "int") + (double) _intSVValues[i], "string"), + "double"); } testTransformFunction(transformFunction, expectedDoubleValues); + assertEquals(expectedDoubleValues, scalarDoubleValues); expression = RequestContextUtils.getExpressionFromSQL(String .format("CAST(CAST(%s AS INT) - CAST(%s AS FLOAT) / CAST(%s AS DOUBLE) AS LONG)", DOUBLE_SV_COLUMN, @@ -74,10 +91,14 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest { transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); Assert.assertTrue(transformFunction instanceof CastTransformFunction); long[] expectedLongValues = new long[NUM_ROWS]; + long[] longScalarValues = new long[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedLongValues[i] = (long) ((double) (int) _doubleSVValues[i] - (double) (float) _longSVValues[i] / (double) _intSVValues[i]); + longScalarValues[i] = (long) cast((double) (int) cast(_doubleSVValues[i], "int") + - (double) (float) cast(_longSVValues[i], "float") / (double) cast(_intSVValues[i], "double"), "long"); } testTransformFunction(transformFunction, expectedLongValues); + assertEquals(expectedLongValues, longScalarValues); } } diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java new file mode 100644 index 0000000000..47f6b0d4c2 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.queries; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import org.apache.commons.io.FileUtils; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.query.AggregationGroupByOperator; +import org.apache.pinot.core.operator.query.AggregationOperator; +import org.apache.pinot.core.operator.query.SelectionOnlyOperator; +import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; +import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader; +import org.apache.pinot.segment.spi.ImmutableSegment; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.utils.ReadMode; +import 
org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + + +public class CastQueriesTest extends BaseQueriesTest { + + private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "CastQueriesTest"); + private static final String RAW_TABLE_NAME = "testTable"; + private static final String SEGMENT_NAME = "testSegment"; + + private static final int NUM_RECORDS = 1000; + private static final int BUCKET_SIZE = 8; + private static final String CLASSIFICATION_COLUMN = "class"; + private static final String X_COL = "x"; + private static final String Y_COL = "y"; + + private static final Schema SCHEMA = new Schema.SchemaBuilder() + .addSingleValueDimension(X_COL, FieldSpec.DataType.DOUBLE) + .addSingleValueDimension(Y_COL, FieldSpec.DataType.DOUBLE) + .addSingleValueDimension(CLASSIFICATION_COLUMN, FieldSpec.DataType.STRING) + .build(); + + private static final TableConfig TABLE_CONFIG = + new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + + private IndexSegment _indexSegment; + private List<IndexSegment> _indexSegments; + + @Override + protected String getFilter() { + return ""; + } + + @Override + protected IndexSegment getIndexSegment() { + return _indexSegment; + } + + @Override + protected List<IndexSegment> getIndexSegments() { + return _indexSegments; + } + + @BeforeClass + public void setUp() + throws Exception { + FileUtils.deleteQuietly(INDEX_DIR); + + List<GenericRow> records = new ArrayList<>(NUM_RECORDS); + for (int i = 0; i < NUM_RECORDS; i++) { + GenericRow record = new GenericRow(); + record.putValue(X_COL, 0.5); + record.putValue(Y_COL, 0.25); + record.putValue(CLASSIFICATION_COLUMN, "" + (i % BUCKET_SIZE)); + records.add(record); + } + + SegmentGeneratorConfig segmentGeneratorConfig = new 
SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA); + segmentGeneratorConfig.setTableName(RAW_TABLE_NAME); + segmentGeneratorConfig.setSegmentName(SEGMENT_NAME); + segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath()); + + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records)); + driver.build(); + + ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap); + _indexSegment = immutableSegment; + _indexSegments = Arrays.asList(immutableSegment, immutableSegment); + } + + @Test + public void testCastSum() { + String query = "select cast(sum(" + X_COL + ") as int), " + + "cast(sum(" + Y_COL + ") as int) " + + "from " + RAW_TABLE_NAME; + Operator<?> operator = getOperatorForSqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + List<Object> aggregationResult = ((AggregationOperator) operator).nextBlock().getAggregationResult(); + assertNotNull(aggregationResult); + assertEquals(aggregationResult.size(), 2); + assertEquals(((Number) aggregationResult.get(0)).intValue(), NUM_RECORDS / 2); + assertEquals(((Number) aggregationResult.get(1)).intValue(), NUM_RECORDS / 4); + } + + @Test + public void testCastSumGroupBy() { + String query = "select cast(sum(" + X_COL + ") as int), " + + "cast(sum(" + Y_COL + ") as int) " + + "from " + RAW_TABLE_NAME + " " + + "group by " + CLASSIFICATION_COLUMN; + Operator<?> operator = getOperatorForSqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + AggregationGroupByResult result = ((AggregationGroupByOperator) operator).nextBlock().getAggregationGroupByResult(); + assertNotNull(result); + Iterator<GroupKeyGenerator.GroupKey> it = result.getGroupKeyIterator(); + while (it.hasNext()) { + GroupKeyGenerator.GroupKey groupKey = it.next(); + Object aggregate = result.getResultForGroupId(0, groupKey._groupId); + assertEquals(((Number) 
aggregate).intValue(), NUM_RECORDS / (2 * BUCKET_SIZE)); + aggregate = result.getResultForGroupId(1, groupKey._groupId); + assertEquals(((Number) aggregate).intValue(), NUM_RECORDS / (4 * BUCKET_SIZE)); + } + } + + @Test + public void testCastFilterAndProject() { + String query = "select cast(" + CLASSIFICATION_COLUMN + " as int)" + + " from " + RAW_TABLE_NAME + + " where " + CLASSIFICATION_COLUMN + " = cast(0 as string) limit " + NUM_RECORDS; + Operator<?> operator = getOperatorForSqlQuery(query); + assertTrue(operator instanceof SelectionOnlyOperator); + Collection<Object[]> result = ((SelectionOnlyOperator) operator).nextBlock().getSelectionResult(); + assertNotNull(result); + assertEquals(result.size(), NUM_RECORDS / BUCKET_SIZE); + for (Object[] row : result) { + assertEquals(row.length, 1); + assertEquals(row[0], 0); + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org For additional commands, e-mail: commits-help@pinot.apache.org