This is an automated email from the ASF dual-hosted git repository. xiangfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 48490933de Support parsing ARRAY literal in multistage query engine (#11268) 48490933de is described below commit 48490933ded6f20d94139eedbb57971fe0cbcb61 Author: Xiang Fu <xiangfu.1...@gmail.com> AuthorDate: Mon Aug 7 13:10:35 2023 -0700 Support parsing ARRAY literal in multistage query engine (#11268) --- .../common/function/TransformFunctionType.java | 2 + .../org/apache/pinot/common/utils/DataSchema.java | 54 +++- .../function/ArrayLiteralTransformFunction.java | 291 +++++++++++++++++++++ .../function/TransformFunctionFactory.java | 8 + .../core/data/function/VectorFunctionsTest.java | 19 ++ .../ArrayLiteralTransformFunctionTest.java | 167 ++++++++++++ .../function/VectorTransformFunctionTest.java | 12 +- .../integration/tests/VectorIntegrationTest.java | 86 +++++- .../apache/calcite/sql/fun/PinotOperatorTable.java | 8 +- .../planner/logical/RelToPlanNodeConverter.java | 21 +- .../local/function/InbuiltFunctionEvaluator.java | 39 ++- 11 files changed, 677 insertions(+), 30 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java index 471f6b128a..f741ff223e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java @@ -221,6 +221,8 @@ public enum TransformFunctionType { VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE), OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"), + ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor"), + // Trigonometry SIN("sin"), COS("cos"), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java index 354ba8cd3c..282a3d7416 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java @@ -24,6 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyOrder; import com.google.common.collect.Ordering; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.floats.FloatArrayList; +import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.longs.LongArrayList; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; @@ -358,11 +360,11 @@ public class DataSchema { case BYTES: return ((ByteArray) value).getBytes(); case INT_ARRAY: - return (int[]) value; + return toIntArray(value); case LONG_ARRAY: return toLongArray(value); case FLOAT_ARRAY: - return (float[]) value; + return toFloatArray(value); case DOUBLE_ARRAY: return toDoubleArray(value); case STRING_ARRAY: @@ -475,6 +477,38 @@ public class DataSchema { } } + private static float[] toFloatArray(Object value) { + if (value instanceof float[]) { + return (float[]) value; + } else if (value instanceof FloatArrayList) { + return ((FloatArrayList) value).elements(); + } else if (value instanceof int[]) { + int[] intValues = (int[]) value; + int length = intValues.length; + float[] floatValues = new float[length]; + for (int i = 0; i < length; i++) { + floatValues[i] = intValues[i]; + } + return floatValues; + } else if (value instanceof long[]) { + long[] longValues = (long[]) value; + int length = longValues.length; + float[] floatValues = new float[length]; + for (int i = 0; i < length; i++) { + floatValues[i] = longValues[i]; + } + return floatValues; + } else { + double[] doubleValues = (double[]) value; + int length = doubleValues.length; + float[] floatValues = new float[length]; + for (int i = 0; i < length; i++) { + floatValues[i] = (float) doubleValues[i]; + } + return floatValues; + } + } + private static long[] toLongArray(Object value) { if (value instanceof long[]) { return (long[]) value; @@ -491,6 +525,22 @@ public class DataSchema { } } + private static int[] toIntArray(Object value) { + if (value instanceof int[]) { + return (int[]) value; + } else if (value instanceof IntArrayList) { + return ((IntArrayList) value).elements(); + } else { + long[] longValues = (long[]) value; + int length = longValues.length; + int[] intValues = new int[length]; + for (int i = 0; i < length; i++) { + intValues[i] = (int) longValues[i]; + } + return intValues; + } + } + private static boolean[] toBooleanArray(Object value) { int[] ints = (int[]) value; boolean[] booleans = new boolean[ints.length]; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java new file mode 100644 index 0000000000..6208ee1966 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java @@ -0,0 +1,291 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.transform.function; + +import com.google.common.base.Preconditions; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.operator.ColumnContext; +import org.apache.pinot.core.operator.blocks.ValueBlock; +import org.apache.pinot.core.operator.transform.TransformResultMetadata; +import org.apache.pinot.segment.spi.index.reader.Dictionary; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.roaringbitmap.RoaringBitmap; + + +/** + * The <code>LiteralTransformFunction</code> class is a special transform function which is a wrapper on top of a + * LITERAL. The data type is inferred from the literal string. + */ +public class ArrayLiteralTransformFunction implements TransformFunction { + public static final String FUNCTION_NAME = "arrayValueConstructor"; + + private final DataType _dataType; + + private final int[] _intArrayLiteral; + private final long[] _longArrayLiteral; + private final float[] _floatArrayLiteral; + private final double[] _doubleArrayLiteral; + private final String[] _stringArrayLiteral; + + // literals may be shared but values are intentionally not volatile as assignment races are benign + private int[][] _intArrayResult; + private long[][] _longArrayResult; + private float[][] _floatArrayResult; + private double[][] _doubleArrayResult; + private String[][] _stringArrayResult; + + public ArrayLiteralTransformFunction(List<ExpressionContext> literalContexts) { + Preconditions.checkNotNull(literalContexts); + if (literalContexts.isEmpty()) { + _dataType = DataType.UNKNOWN; + _intArrayLiteral = new int[0]; + _longArrayLiteral = new long[0]; + _floatArrayLiteral = new float[0]; + _doubleArrayLiteral = new double[0]; + _stringArrayLiteral = new String[0]; + return; + } + for (ExpressionContext literalContext : literalContexts) { + Preconditions.checkState(literalContext.getType() == ExpressionContext.Type.LITERAL, + "ArrayLiteralTransformFunction only takes literals as arguments, found: %s", literalContext); + } + _dataType = literalContexts.get(0).getLiteral().getType(); + switch (_dataType) { + case INT: + _intArrayLiteral = new int[literalContexts.size()]; + for (int i = 0; i < _intArrayLiteral.length; i++) { + _intArrayLiteral[i] = literalContexts.get(i).getLiteral().getIntValue(); + } + _longArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case LONG: + _longArrayLiteral = new long[literalContexts.size()]; + for (int i = 0; i < _longArrayLiteral.length; i++) { + _longArrayLiteral[i] = Long.parseLong(literalContexts.get(i).getLiteral().getStringValue()); + } + _intArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case FLOAT: + _floatArrayLiteral = new float[literalContexts.size()]; + for (int i = 0; i < _floatArrayLiteral.length; i++) { + _floatArrayLiteral[i] = Float.parseFloat(literalContexts.get(i).getLiteral().getStringValue()); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _doubleArrayLiteral = null; + _stringArrayLiteral = null; + break; + case DOUBLE: + _doubleArrayLiteral = new double[literalContexts.size()]; + for (int i = 0; i < _doubleArrayLiteral.length; i++) { + _doubleArrayLiteral[i] = Double.parseDouble(literalContexts.get(i).getLiteral().getStringValue()); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _floatArrayLiteral = null; + _stringArrayLiteral = null; + break; + case STRING: + _stringArrayLiteral = new String[literalContexts.size()]; + for (int i = 0; i < _stringArrayLiteral.length; i++) { + _stringArrayLiteral[i] = literalContexts.get(i).getLiteral().getStringValue(); + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + break; + default: + throw new IllegalStateException( + "Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal contexts: " + + Arrays.toString(literalContexts.toArray())); + } + } + + public int[] getIntArrayLiteral() { + return _intArrayLiteral; + } + + public long[] getLongArrayLiteral() { + return _longArrayLiteral; + } + + public float[] getFloatArrayLiteral() { + return _floatArrayLiteral; + } + + public double[] getDoubleArrayLiteral() { + return _doubleArrayLiteral; + } + + public String[] getStringArrayLiteral() { + return _stringArrayLiteral; + } + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) { + } + + @Override + public TransformResultMetadata getResultMetadata() { + return new TransformResultMetadata(_dataType, false, false); + } + + @Override + public Dictionary getDictionary() { + return null; + } + + @Override + public int[] transformToDictIdsSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[][] transformToDictIdsMV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[] transformToIntValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public long[] transformToLongValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public float[] transformToFloatValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public double[] transformToDoubleValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal[] transformToBigDecimalValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public String[] transformToStringValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[][] transformToBytesValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[][] transformToIntValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + int[][] intArrayResult = _intArrayResult; + if (intArrayResult == null || intArrayResult.length < numDocs) { + intArrayResult = new int[numDocs][]; + Arrays.fill(intArrayResult, _intArrayLiteral); + _intArrayResult = intArrayResult; + } + return intArrayResult; + } + + @Override + public long[][] transformToLongValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + long[][] longArrayResult = _longArrayResult; + if (longArrayResult == null || longArrayResult.length < numDocs) { + longArrayResult = new long[numDocs][]; + Arrays.fill(longArrayResult, _longArrayLiteral); + _longArrayResult = longArrayResult; + } + return longArrayResult; + } + + @Override + public float[][] transformToFloatValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + float[][] floatArrayResult = _floatArrayResult; + if (floatArrayResult == null || floatArrayResult.length < numDocs) { + floatArrayResult = new float[numDocs][]; + Arrays.fill(floatArrayResult, _floatArrayLiteral); + _floatArrayResult = floatArrayResult; + } + return floatArrayResult; + } + + @Override + public double[][] transformToDoubleValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + double[][] doubleArrayResult = _doubleArrayResult; + if (doubleArrayResult == null || doubleArrayResult.length < numDocs) { + doubleArrayResult = new double[numDocs][]; + Arrays.fill(doubleArrayResult, _doubleArrayLiteral); + _doubleArrayResult = doubleArrayResult; + } + return doubleArrayResult; + } + + @Override + public String[][] transformToStringValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + String[][] stringArrayResult = _stringArrayResult; + if (stringArrayResult == null || stringArrayResult.length < numDocs) { + stringArrayResult = new String[numDocs][]; + Arrays.fill(stringArrayResult, _stringArrayLiteral); + _stringArrayResult = stringArrayResult; + } + return stringArrayResult; + } + + @Override + public byte[][][] transformToBytesValuesMV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public RoaringBitmap getNullBitmap(ValueBlock valueBlock) { + // Treat all unknown type values as null regardless of the value. + if (_dataType != DataType.UNKNOWN) { + return null; + } + int length = valueBlock.getNumDocs(); + RoaringBitmap bitmap = new RoaringBitmap(); + bitmap.add(0L, length); + return bitmap; + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index 4e3ff24119..be2b54d128 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -153,6 +153,7 @@ public class TransformFunctionFactory { typeToImplementation.put(TransformFunctionType.ARRAYMAX, ArrayMaxTransformFunction.class); typeToImplementation.put(TransformFunctionType.ARRAYMIN, ArrayMinTransformFunction.class); typeToImplementation.put(TransformFunctionType.ARRAYSUM, ArraySumTransformFunction.class); + typeToImplementation.put(TransformFunctionType.ARRAY_VALUE_CONSTRUCTOR, ArrayLiteralTransformFunction.class); typeToImplementation.put(TransformFunctionType.GROOVY, GroovyTransformFunction.class); typeToImplementation.put(TransformFunctionType.CASE, CaseTransformFunction.class); @@ -281,6 +282,13 @@ public class TransformFunctionFactory { List<ExpressionContext> arguments = function.getArguments(); int numArguments = arguments.size(); + // Check if the function is ArrayLiteraltransform function + if (functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) { + return queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class, + expression.getFunction().getArguments(), + ArrayLiteralTransformFunction::new); + } + TransformFunction transformFunction; Class<? extends TransformFunction> transformFunctionClass = TRANSFORM_FUNCTION_MAP.get(functionName); if (transformFunctionClass != null) { diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java index 972c33ee43..6600b5c10f 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java @@ -108,6 +108,25 @@ public class VectorFunctionsTest { inputs.add(new Object[]{"vectorDims(vector2)", Lists.newArrayList("vector2"), row, 5}); inputs.add(new Object[]{"vectorNorm(vector1)", Lists.newArrayList("vector1"), row, 0.741619857751291}); inputs.add(new Object[]{"vectorNorm(vector2)", Lists.newArrayList("vector2"), row, 0.0}); + + inputs.add(new Object[]{ + "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, Double.NaN + }); + inputs.add(new Object[]{ + "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 0.0)", Lists.newArrayList("vector1"), row, 0.0 + }); + inputs.add(new Object[]{ + "cosineDistance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0], 1.0)", Lists.newArrayList("vector1"), row, 1.0 + }); + inputs.add(new Object[]{ + "innerProduct(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, 0.0 + }); + inputs.add(new Object[]{ + "l2Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, 0.741619857751291 + }); + inputs.add(new Object[]{ + "l1Distance(vector1, ARRAY[0.0,0.0,0.0,0.0,0.0])", Lists.newArrayList("vector1"), row, 1.5000000223517418 + }); return inputs.toArray(new Object[0][]); } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java new file mode 100644 index 0000000000..005b8c3eeb --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunctionTest.java @@ -0,0 +1,167 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.transform.function; + +import java.util.ArrayList; +import java.util.List; +import org.apache.pinot.common.request.Literal; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.operator.blocks.ProjectionBlock; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static org.mockito.Mockito.when; + + +public class ArrayLiteralTransformFunctionTest { + private static final int NUM_DOCS = 100; + private AutoCloseable _mocks; + + @Mock + private ProjectionBlock _projectionBlock; + + @BeforeMethod + public void setUp() { + _mocks = MockitoAnnotations.openMocks(this); + when(_projectionBlock.getNumDocs()).thenReturn(NUM_DOCS); + } + + @AfterMethod + public void tearDown() + throws Exception { + _mocks.close(); + } + + @Test + public void testIntArrayLiteralTransformFunction() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT, i)); + } + + ArrayLiteralTransformFunction intArray = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(intArray.getResultMetadata().getDataType(), DataType.INT); + Assert.assertEquals(intArray.getIntArrayLiteral(), new int[]{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + } + + @Test + public void testLongArrayLiteralTransformFunction() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.LONG, (long) i)); + } + + ArrayLiteralTransformFunction longArray = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(longArray.getResultMetadata().getDataType(), DataType.LONG); + Assert.assertEquals(longArray.getLongArrayLiteral(), new long[]{ + 0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L + }); + } + + @Test + public void testFloatArrayLiteralTransformFunction() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.FLOAT, (double) i)); + } + + ArrayLiteralTransformFunction floatArray = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(floatArray.getResultMetadata().getDataType(), DataType.FLOAT); + Assert.assertEquals(floatArray.getFloatArrayLiteral(), new float[]{ + 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f + }); + } + + @Test + public void testDoubleArrayLiteralTransformFunction() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.DOUBLE, (double) i)); + } + + ArrayLiteralTransformFunction doubleArray = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(doubleArray.getResultMetadata().getDataType(), DataType.DOUBLE); + Assert.assertEquals(doubleArray.getDoubleArrayLiteral(), new double[]{ + 0d, 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d + }); + } + + @Test + public void testStringArrayLiteralTransformFunction() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + arrayExpressions.add( + ExpressionContext.forLiteralContext(new Literal(Literal._Fields.STRING_VALUE, String.valueOf(i)))); + } + + ArrayLiteralTransformFunction stringArray = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(stringArray.getResultMetadata().getDataType(), DataType.STRING); + Assert.assertEquals(stringArray.getStringArrayLiteral(), new String[]{ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" + }); + } + + @Test + public void testEmptyArrayTransform() { + List<ExpressionContext> arrayExpressions = new ArrayList<>(); + ArrayLiteralTransformFunction emptyLiteral = new ArrayLiteralTransformFunction(arrayExpressions); + Assert.assertEquals(emptyLiteral.getIntArrayLiteral(), new int[0]); + Assert.assertEquals(emptyLiteral.getLongArrayLiteral(), new long[0]); + Assert.assertEquals(emptyLiteral.getFloatArrayLiteral(), new float[0]); + Assert.assertEquals(emptyLiteral.getDoubleArrayLiteral(), new double[0]); + Assert.assertEquals(emptyLiteral.getStringArrayLiteral(), new String[0]); + + int[][] ints = emptyLiteral.transformToIntValuesMV(_projectionBlock); + Assert.assertEquals(ints.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(ints[i].length, 0); + } + + long[][] longs = emptyLiteral.transformToLongValuesMV(_projectionBlock); + Assert.assertEquals(longs.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(longs[i].length, 0); + } + + float[][] floats = emptyLiteral.transformToFloatValuesMV(_projectionBlock); + Assert.assertEquals(floats.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(floats[i].length, 0); + } + + double[][] doubles = emptyLiteral.transformToDoubleValuesMV(_projectionBlock); + Assert.assertEquals(doubles.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(doubles[i].length, 0); + } + + String[][] strings = emptyLiteral.transformToStringValuesMV(_projectionBlock); + Assert.assertEquals(strings.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(strings[i].length, 0); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java index 8aed6e4698..23b3213f3e 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java @@ -59,6 +59,9 @@ public class VectorTransformFunctionTest extends BaseTransformFunctionTest { @DataProvider(name = "testVectorTransformFunctionDataProvider") public Object[][] testVectorTransformFunctionDataProvider() { + String zeroVectorLiteral = "ARRAY[0.0" + + ",0.0".repeat(VECTOR_DIM_SIZE - 1) + + "]"; return new Object[][]{ new Object[]{"cosineDistance(vector1, vector2)", 0.1, 0.4}, new Object[]{"cosineDistance(vector1, vector2, 0)", 0.1, 0.4}, @@ -67,7 +70,14 @@ public class VectorTransformFunctionTest extends BaseTransformFunctionTest { new Object[]{"l1Distance(vector1, vector2)", 140, 210}, new Object[]{"l2Distance(vector1, vector2)", 8, 11}, new Object[]{"vectorNorm(vector1)", 10, 16}, - new Object[]{"vectorNorm(vector2)", 10, 16} + new Object[]{"vectorNorm(vector2)", 10, 16}, + + new Object[]{String.format("cosineDistance(vector1, %s, 0)", zeroVectorLiteral), 0.0, 0.0}, + new Object[]{String.format("innerProduct(vector1, %s)", zeroVectorLiteral), 0.0, 0.0}, + new Object[]{String.format("l1Distance(vector1, %s)", zeroVectorLiteral), 0, VECTOR_DIM_SIZE}, + new Object[]{String.format("l2Distance(vector1, %s)", zeroVectorLiteral), 0, VECTOR_DIM_SIZE}, + new Object[]{String.format("vectorDims(%s)", zeroVectorLiteral), VECTOR_DIM_SIZE, VECTOR_DIM_SIZE}, + new Object[]{String.format("vectorNorm(%s)", zeroVectorLiteral), 0.0, 0.0}, }; } } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java index 48efe20490..dbfcd5a347 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java @@ -100,32 +100,96 @@ public class VectorIntegrationTest extends BaseClusterIntegrationTest { + "vectorNorm(vector1), vectorNorm(vector2), " + "cosineDistance(vector1, zeroVector), " + "cosineDistance(vector1, zeroVector, 0) " - + "FROM %s", DEFAULT_TABLE_NAME); + + "FROM %s LIMIT %d", DEFAULT_TABLE_NAME, getCountStarResult()); JsonNode jsonNode = postQuery(query); for (int i = 0; i < getCountStarResult(); i++) { - double cosineDistance = jsonNode.get("resultTable").get("rows").get(0).get(0).asDouble(); + double cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble(); assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4); - double innerProduce = jsonNode.get("resultTable").get("rows").get(0).get(1).asDouble(); + double innerProduce = jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble(); assertTrue(innerProduce > 100 && innerProduce < 160); - double l1Distance = jsonNode.get("resultTable").get("rows").get(0).get(2).asDouble(); + double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble(); assertTrue(l1Distance > 140 && l1Distance < 210); - double l2Distance = jsonNode.get("resultTable").get("rows").get(0).get(3).asDouble(); + double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble(); assertTrue(l2Distance > 8 && l2Distance < 11); - int vectorDimsVector1 = jsonNode.get("resultTable").get("rows").get(0).get(4).asInt(); + int vectorDimsVector1 = jsonNode.get("resultTable").get("rows").get(i).get(4).asInt(); assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE); - int vectorDimsVector2 = jsonNode.get("resultTable").get("rows").get(0).get(5).asInt(); + int vectorDimsVector2 = jsonNode.get("resultTable").get("rows").get(i).get(5).asInt(); assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE); - double vectorNormVector1 = jsonNode.get("resultTable").get("rows").get(0).get(6).asInt(); + double vectorNormVector1 = jsonNode.get("resultTable").get("rows").get(i).get(6).asInt(); assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16); - double vectorNormVector2 = jsonNode.get("resultTable").get("rows").get(0).get(7).asInt(); + double vectorNormVector2 = jsonNode.get("resultTable").get("rows").get(i).get(7).asInt(); assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16); - cosineDistance = jsonNode.get("resultTable").get("rows").get(0).get(8).asDouble(); + cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble(); assertEquals(cosineDistance, Double.NaN); - cosineDistance = jsonNode.get("resultTable").get("rows").get(0).get(9).asDouble(); + cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble(); assertEquals(cosineDistance, 0.0); } } + @Test(dataProvider = "useBothQueryEngines") + public void testQueriesWithLiterals(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String zeroVectorStringLiteral = "ARRAY[0.0" + + ", 0.0".repeat(VECTOR_DIM_SIZE - 1) + + "]"; + String oneVectorStringLiteral = "ARRAY[1.0" + + ", 1.0".repeat(VECTOR_DIM_SIZE - 1) + + "]"; + String query = + String.format("SELECT " + + "cosineDistance(vector1, %s), " + + "innerProduct(vector1, %s), " + + "l1Distance(vector1, %s), " + + "l2Distance(vector1, %s), " + + "vectorDims(%s), " + + "vectorNorm(%s) " + + "FROM %s LIMIT %d", + zeroVectorStringLiteral, zeroVectorStringLiteral, zeroVectorStringLiteral, zeroVectorStringLiteral, + zeroVectorStringLiteral, zeroVectorStringLiteral, DEFAULT_TABLE_NAME, getCountStarResult()); + JsonNode jsonNode = postQuery(query); + for (int i = 0; i < getCountStarResult(); i++) { + double cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble(); + assertEquals(cosineDistance, Double.NaN); + double innerProduce = jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble(); + assertEquals(innerProduce, 0.0); + double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble(); + assertTrue(l1Distance > 100 && l1Distance < 300); + double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble(); + assertTrue(l2Distance > 10 && l2Distance < 16); + int vectorDimsVector = jsonNode.get("resultTable").get("rows").get(i).get(4).asInt(); + assertEquals(vectorDimsVector, VECTOR_DIM_SIZE); + double vectorNormVector = jsonNode.get("resultTable").get("rows").get(i).get(5).asInt(); + assertEquals(vectorNormVector, 0.0); + } + + query = + String.format("SELECT " + + "cosineDistance(%s, %s), " + + "cosineDistance(%s, %s, 0.0), " + + "innerProduct(%s, %s), " + + "l1Distance(%s, %s), " + + "l2Distance(%s, %s)" + + "FROM %s LIMIT 1", + zeroVectorStringLiteral, oneVectorStringLiteral, + zeroVectorStringLiteral, oneVectorStringLiteral, + zeroVectorStringLiteral, oneVectorStringLiteral, + zeroVectorStringLiteral, oneVectorStringLiteral, + zeroVectorStringLiteral, oneVectorStringLiteral, + DEFAULT_TABLE_NAME); + jsonNode = postQuery(query); + double cosineDistance = jsonNode.get("resultTable").get("rows").get(0).get(0).asDouble(); + assertEquals(cosineDistance, Double.NaN); + cosineDistance = jsonNode.get("resultTable").get("rows").get(0).get(1).asDouble(); + assertEquals(cosineDistance, 0.0); + double innerProduce = jsonNode.get("resultTable").get("rows").get(0).get(2).asDouble(); + assertEquals(innerProduce, 0.0); + double l1Distance = jsonNode.get("resultTable").get("rows").get(0).get(3).asDouble(); + assertEquals(l1Distance, 512.0); + double l2Distance = jsonNode.get("resultTable").get("rows").get(0).get(4).asDouble(); + assertEquals(l2Distance, 22.627416997969522); + } + private File createAvroFile(long totalNumRecords) throws IOException { diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java index 2ee178419f..1a63a6eb07 100644 --- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java @@ -49,8 +49,6 @@ public class PinotOperatorTable extends SqlStdOperatorTable { private static @MonotonicNonNull PinotOperatorTable _instance; - public static final SqlFunction COALESCE = new PinotSqlCoalesceFunction(); - // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around // supplier instance. this should replace all lazy init static objects in the codebase public static synchronized PinotOperatorTable instance() { @@ -75,6 +73,12 @@ public class PinotOperatorTable extends SqlStdOperatorTable { * which are multistage enabled. */ public final void initNoDuplicate() { + // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. + register(new PinotSqlCoalesceFunction()); + // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor + register(ARRAY_VALUE_CONSTRUCTOR); + + // TODO: reflection based registration is not ideal, we should use a static list of operators and register them // Use reflection to register the expressions stored in public fields. for (Field field : getClass().getFields()) { try { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index a4e6be355a..b0b7545677 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -219,7 +219,7 @@ public final class RelToPlanNodeConverter { case BIGINT: return isArray ? DataSchema.ColumnDataType.LONG_ARRAY : DataSchema.ColumnDataType.LONG; case DECIMAL: - return resolveDecimal(relDataType); + return resolveDecimal(relDataType, isArray); case FLOAT: case REAL: return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY : DataSchema.ColumnDataType.FLOAT; @@ -259,31 +259,32 @@ public final class RelToPlanNodeConverter { } /** - * Calcite uses DEMICAL type to infer data type hoisting and infer arithmetic result types. down casting this - * back to the proper primitive type for Pinot. + * Calcite uses DEMICAL type to infer data type hoisting and infer arithmetic result types. down casting this back to + * the proper primitive type for Pinot. * * @param relDataType the DECIMAL rel data type. + * @param isArray * @return proper {@link DataSchema.ColumnDataType}. * @see {@link org.apache.calcite.rel.type.RelDataTypeFactoryImpl#decimalOf}. */ - private static DataSchema.ColumnDataType resolveDecimal(RelDataType relDataType) { + private static DataSchema.ColumnDataType resolveDecimal(RelDataType relDataType, boolean isArray) { int precision = relDataType.getPrecision(); int scale = relDataType.getScale(); if (scale == 0) { if (precision <= 10) { - return DataSchema.ColumnDataType.INT; + return isArray ? DataSchema.ColumnDataType.INT_ARRAY : DataSchema.ColumnDataType.INT; } else if (precision <= 38) { - return DataSchema.ColumnDataType.LONG; + return isArray ? DataSchema.ColumnDataType.LONG_ARRAY : DataSchema.ColumnDataType.LONG; } else { - return DataSchema.ColumnDataType.BIG_DECIMAL; + return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : DataSchema.ColumnDataType.BIG_DECIMAL; } } else { if (precision <= 14) { - return DataSchema.ColumnDataType.FLOAT; + return isArray ? DataSchema.ColumnDataType.FLOAT_ARRAY : DataSchema.ColumnDataType.FLOAT; } else if (precision <= 30) { - return DataSchema.ColumnDataType.DOUBLE; + return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : DataSchema.ColumnDataType.DOUBLE; } else { - return DataSchema.ColumnDataType.BIG_DECIMAL; + return isArray ? DataSchema.ColumnDataType.DOUBLE_ARRAY : DataSchema.ColumnDataType.BIG_DECIMAL; } } } diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java index df896d2c00..823dd23b88 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java @@ -20,6 +20,7 @@ package org.apache.pinot.segment.local.function; import com.google.common.base.Preconditions; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.lang3.StringUtils; import org.apache.pinot.common.function.FunctionInfo; @@ -78,6 +79,13 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { case "not": Preconditions.checkState(numArguments == 1, "NOT function expects 1 argument, got: %s", numArguments); return new NotExecutionNode(childNodes[0]); + case "arrayvalueconstructor": + Object[] values = new Object[numArguments]; + int i = 0; + for (ExpressionContext literal : arguments) { + values[i++] = literal.getLiteral().getValue(); + } + return new ArrayConstantExecutionNode(values); default: FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numArguments); if (functionInfo == null) { @@ -145,7 +153,7 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { @Override public Object execute(GenericRow row) { - for (ExecutableNode executableNode :_argumentNodes) { + for (ExecutableNode executableNode : _argumentNodes) { Boolean res = (Boolean) executableNode.execute(row); if (res) { return true; @@ -156,7 +164,7 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { @Override public Object execute(Object[] values) { - for (ExecutableNode executableNode :_argumentNodes) { + for (ExecutableNode executableNode : _argumentNodes) { Boolean res = (Boolean) executableNode.execute(values); if (res) { return true; @@ -175,7 +183,7 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { @Override public Object execute(GenericRow row) { - for (ExecutableNode executableNode :_argumentNodes) { + for (ExecutableNode executableNode : _argumentNodes) { Boolean res = (Boolean) executableNode.execute(row); if (!res) { return false; @@ -186,7 +194,7 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { @Override public Object execute(Object[] values) { - for (ExecutableNode executableNode :_argumentNodes) { + for (ExecutableNode executableNode : _argumentNodes) { Boolean res = (Boolean) executableNode.execute(values); if (!res) { return false; @@ -284,6 +292,29 @@ public class InbuiltFunctionEvaluator implements FunctionEvaluator { } } + private static class ArrayConstantExecutionNode implements ExecutableNode { + final Object[] _value; + + ArrayConstantExecutionNode(Object[] value) { + _value = value; + } + + @Override + public Object[] execute(GenericRow row) { + return _value; + } + + @Override + public Object[] execute(Object[] values) { + return _value; + } + + @Override + public String toString() { + return String.format("'%s'", Arrays.toString(_value)); + } + } + private static class ColumnExecutionNode implements ExecutableNode { final String _column; final int _id; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org