This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push: new e3b0bfc scalar functions for array (#6105) e3b0bfc is described below commit e3b0bfcdbd796db91a606d591e492741799244b5 Author: SandishKumarHN <sanysand...@gmail.com> AuthorDate: Mon Nov 2 20:54:21 2020 -0800 scalar functions for array (#6105) Inbuilt scalar functions for array columns - array_reverse_int(multi_value_int_field) - array_reverse_string(multi_value_string_field) - array_sort_int(multi_value_int_field) - array_sort_string(multi_value_string_field) - array_index_of_int(multi_value_int_field, 2) - array_index_of_string(multi_value_string_field, 'foo') - array_contains_int(multi_value_int_field, 3) - array_contains_string(multi_value_string_field, 'bar') --- .../pinot/common/function/FunctionInvoker.java | 19 +- .../pinot/common/function/FunctionUtils.java | 32 +++ .../common/function/scalar/ArrayFunctions.java | 80 +++++++ .../apache/pinot/common/utils/PinotDataType.java | 32 ++- .../function/ScalarTransformFunctionWrapper.java | 122 +++++++++- .../core/data/function/InbuiltFunctionsTest.java | 67 ++++++ .../function/BaseTransformFunctionTest.java | 47 +++- .../ScalarTransformFunctionWrapperTest.java | 248 ++++++++++++++++++--- 8 files changed, 602 insertions(+), 45 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java index b185d26..aec453c 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java @@ -23,6 +23,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; +import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.common.utils.PinotDataType; @@ -108,7 +109,23 @@ public class FunctionInvoker { PinotDataType argumentType = FunctionUtils.getArgumentType(argumentClass); Preconditions.checkArgument(parameterType != null && argumentType != null, "Cannot convert value from class: %s to class: %s", argumentClass, parameterClass); - arguments[i] = parameterType.convert(argument, argumentType); + Object convertedArgument = parameterType.convert(argument, argumentType); + // For primitive array parameter, convert the argument from Object array to primitive array + switch (parameterType) { + case INTEGER_ARRAY: + convertedArgument = ArrayUtils.toPrimitive((Integer[]) convertedArgument); + break; + case LONG_ARRAY: + convertedArgument = ArrayUtils.toPrimitive((Long[]) convertedArgument); + break; + case FLOAT_ARRAY: + convertedArgument = ArrayUtils.toPrimitive((Float[]) convertedArgument); + break; + case DOUBLE_ARRAY: + convertedArgument = ArrayUtils.toPrimitive((Double[]) convertedArgument); + break; + } + arguments[i] = convertedArgument; } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java index 33da3cc..1b04ff7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java @@ -42,6 +42,11 @@ public class FunctionUtils { put(Double.class, PinotDataType.DOUBLE); put(String.class, PinotDataType.STRING); put(byte[].class, PinotDataType.BYTES); + put(int[].class, PinotDataType.INTEGER_ARRAY); + put(long[].class, PinotDataType.LONG_ARRAY); + put(float[].class, PinotDataType.FLOAT_ARRAY); + put(double[].class, PinotDataType.DOUBLE_ARRAY); + put(String[].class, PinotDataType.STRING_ARRAY); }}; // Types allowed as the function argument (actual value passed into the function) for type conversion @@ -56,6 +61,15 @@ public class FunctionUtils { put(Double.class, PinotDataType.DOUBLE); put(String.class, PinotDataType.STRING); put(byte[].class, PinotDataType.BYTES); + put(int[].class, PinotDataType.INTEGER_ARRAY); + put(Integer[].class, PinotDataType.INTEGER_ARRAY); + put(long[].class, PinotDataType.LONG_ARRAY); + put(Long[].class, PinotDataType.LONG_ARRAY); + put(float[].class, PinotDataType.FLOAT_ARRAY); + put(Float[].class, PinotDataType.FLOAT_ARRAY); + put(double[].class, PinotDataType.DOUBLE_ARRAY); + put(Double[].class, PinotDataType.DOUBLE_ARRAY); + put(String[].class, PinotDataType.STRING_ARRAY); }}; private static final Map<Class<?>, DataType> DATA_TYPE_MAP = new HashMap<Class<?>, DataType>() {{ @@ -69,6 +83,15 @@ public class FunctionUtils { put(Double.class, DataType.DOUBLE); put(String.class, DataType.STRING); put(byte[].class, DataType.BYTES); + put(int[].class, DataType.INT); + put(Integer[].class, DataType.INT); + put(long[].class, DataType.LONG); + put(Long[].class, DataType.LONG); + put(float[].class, DataType.FLOAT); + put(Float[].class, DataType.FLOAT); + put(double[].class, DataType.DOUBLE); + put(Double[].class, DataType.DOUBLE); + put(String[].class, DataType.STRING); }}; private static final Map<Class<?>, ColumnDataType> COLUMN_DATA_TYPE_MAP = new HashMap<Class<?>, ColumnDataType>() {{ @@ -82,6 +105,15 @@ public class FunctionUtils { put(Double.class, ColumnDataType.DOUBLE); put(String.class, ColumnDataType.STRING); put(byte[].class, ColumnDataType.BYTES); + put(int[].class, ColumnDataType.INT_ARRAY); + put(Integer[].class, ColumnDataType.INT_ARRAY); + put(long[].class, ColumnDataType.LONG_ARRAY); + put(Long[].class, ColumnDataType.LONG_ARRAY); + put(float[].class, ColumnDataType.FLOAT_ARRAY); + put(Float[].class, ColumnDataType.FLOAT_ARRAY); + put(double[].class, ColumnDataType.DOUBLE_ARRAY); + put(Double[].class, ColumnDataType.DOUBLE_ARRAY); + put(String[].class, ColumnDataType.STRING_ARRAY); }}; /** diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java new file mode 100644 index 0000000..0976ab0 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar; + +import java.util.Arrays; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Inbuilt array scalar functions. See {@link ArrayUtils} for details. + */ +public class ArrayFunctions { + private ArrayFunctions() { + } + + @ScalarFunction + public static int[] arrayReverseInt(int[] values) { + int[] clone = values.clone(); + ArrayUtils.reverse(clone); + return clone; + } + + @ScalarFunction + public static String[] arrayReverseString(String[] values) { + String[] clone = values.clone(); + ArrayUtils.reverse(clone); + return clone; + } + + @ScalarFunction + public static int[] arraySortInt(int[] values) { + int[] clone = values.clone(); + Arrays.sort(clone); + return clone; + } + + @ScalarFunction + public static String[] arraySortString(String[] values) { + String[] clone = values.clone(); + Arrays.sort(clone); + return clone; + } + + @ScalarFunction + public static int arrayIndexOfInt(int[] values, int valueToFind) { + return ArrayUtils.indexOf(values, valueToFind); + } + + @ScalarFunction + public static int arrayIndexOfString(String[] values, String valueToFind) { + return ArrayUtils.indexOf(values, valueToFind); + } + + @ScalarFunction + public static boolean arrayContainsInt(int[] values, int valueToFind) { + return ArrayUtils.contains(values, valueToFind); + } + + @ScalarFunction + public static boolean arrayContainsString(String[] values, String valueToFind) { + return ArrayUtils.contains(values, valueToFind); + } +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java index 97c017e..068241b 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.common.utils; +import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.utils.BytesUtils; @@ -505,7 +506,7 @@ public enum PinotDataType { if (isSingleValue()) { return new Integer[]{toInteger(value)}; } else { - Object[] valueArray = (Object[]) value; + Object[] valueArray = toObjectArray(value); int length = valueArray.length; Integer[] integerArray = new Integer[length]; PinotDataType singleValueType = getSingleValueType(); @@ -520,7 +521,7 @@ public enum PinotDataType { if (isSingleValue()) { return new Long[]{toLong(value)}; } else { - Object[] valueArray = (Object[]) value; + Object[] valueArray = toObjectArray(value); int length = valueArray.length; Long[] longArray = new Long[length]; PinotDataType singleValueType = getSingleValueType(); @@ -535,7 +536,7 @@ public enum PinotDataType { if (isSingleValue()) { return new Float[]{toFloat(value)}; } else { - Object[] valueArray = (Object[]) value; + Object[] valueArray = toObjectArray(value); int length = valueArray.length; Float[] floatArray = new Float[length]; PinotDataType singleValueType = getSingleValueType(); @@ -550,7 +551,7 @@ public enum PinotDataType { if (isSingleValue()) { return new Double[]{toDouble(value)}; } else { - Object[] valueArray = (Object[]) value; + Object[] valueArray = toObjectArray(value); int length = valueArray.length; Double[] doubleArray = new Double[length]; PinotDataType singleValueType = getSingleValueType(); @@ -565,7 +566,7 @@ public enum PinotDataType { if (isSingleValue()) { return new String[]{toString(value)}; } else { - Object[] valueArray = (Object[]) value; + Object[] valueArray = toObjectArray(value); int length = valueArray.length; String[] stringArray = new String[length]; PinotDataType singleValueType = getSingleValueType(); @@ -576,6 +577,27 @@ public enum PinotDataType { } } + private static Object[] toObjectArray(Object array) { + Class<?> componentType = array.getClass().getComponentType(); + if (componentType.isPrimitive()) { + if (componentType == int.class) { + return ArrayUtils.toObject((int[]) array); + } + if (componentType == long.class) { + return ArrayUtils.toObject((long[]) array); + } + if (componentType == float.class) { + return ArrayUtils.toObject((float[]) array); + } + if (componentType == double.class) { + return ArrayUtils.toObject((double[]) array); + } + throw new UnsupportedOperationException("Unsupported primitive array type: " + componentType); + } else { + return (Object[]) array; + } + } + public Object convert(Object value, PinotDataType sourceType) { throw new UnsupportedOperationException("Cannot convert value form " + sourceType + " to " + this); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java index 958b570..eca96f6 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java @@ -54,6 +54,12 @@ public class ScalarTransformFunctionWrapper extends BaseTransformFunction { private String[] _stringResults; private byte[][] _bytesResults; + private int[][] _intMVResults; + private long[][] _longMVResults; + private float[][] _floatMVResults; + private double[][] _doubleMVResults; + private String[][] _stringMVResults; + public ScalarTransformFunctionWrapper(FunctionInfo functionInfo) { _name = functionInfo.getMethod().getName(); _functionInvoker = new FunctionInvoker(functionInfo); @@ -95,12 +101,14 @@ public class ScalarTransformFunctionWrapper extends BaseTransformFunction { } _nonLiteralValues = new Object[_numNonLiteralArguments][]; - DataType resultDataType = FunctionUtils.getDataType(_functionInvoker.getResultClass()); + Class<?> resultClass = _functionInvoker.getResultClass(); + DataType resultDataType = FunctionUtils.getDataType(resultClass); // Handle unrecognized result class with STRING if (resultDataType == null) { resultDataType = DataType.STRING; } - _resultMetadata = new TransformResultMetadata(resultDataType, true, false); + boolean isSingleValue = !resultClass.isArray(); + _resultMetadata = new TransformResultMetadata(resultDataType, isSingleValue, false); } @Override @@ -222,6 +230,101 @@ public class ScalarTransformFunctionWrapper extends BaseTransformFunction { return _bytesResults; } + @Override + public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) { + if (_resultMetadata.getDataType() != DataType.INT) { + return super.transformToIntValuesMV(projectionBlock); + } + if (_intMVResults == null) { + _intMVResults = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL][]; + } + getNonLiteralValues(projectionBlock); + int length = projectionBlock.getNumDocs(); + for (int i = 0; i < length; i++) { + for (int j = 0; j < _numNonLiteralArguments; j++) { + _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i]; + } + _intMVResults[i] = (int[]) _functionInvoker.invoke(_arguments); + } + return _intMVResults; + } + + @Override + public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) { + if (_resultMetadata.getDataType() != DataType.LONG) { + return super.transformToLongValuesMV(projectionBlock); + } + if (_longMVResults == null) { + _longMVResults = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL][]; + } + getNonLiteralValues(projectionBlock); + int length = projectionBlock.getNumDocs(); + for (int i = 0; i < length; i++) { + for (int j = 0; j < _numNonLiteralArguments; j++) { + _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i]; + } + _longMVResults[i] = (long[]) _functionInvoker.invoke(_arguments); + } + return _longMVResults; + } + + @Override + public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) { + if (_resultMetadata.getDataType() != DataType.FLOAT) { + return super.transformToFloatValuesMV(projectionBlock); + } + if (_floatMVResults == null) { + _floatMVResults = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL][]; + } + getNonLiteralValues(projectionBlock); + int length = projectionBlock.getNumDocs(); + for (int i = 0; i < length; i++) { + for (int j = 0; j < _numNonLiteralArguments; j++) { + _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i]; + } + _floatMVResults[i] = (float[]) _functionInvoker.invoke(_arguments); + } + return _floatMVResults; + } + + @Override + public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) { + if (_resultMetadata.getDataType() != DataType.DOUBLE) { + return super.transformToDoubleValuesMV(projectionBlock); + } + if (_doubleMVResults == null) { + _doubleMVResults = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL][]; + } + getNonLiteralValues(projectionBlock); + int length = projectionBlock.getNumDocs(); + for (int i = 0; i < length; i++) { + for (int j = 0; j < _numNonLiteralArguments; j++) { + _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i]; + } + _doubleMVResults[i] = (double[]) _functionInvoker.invoke(_arguments); + } + return _doubleMVResults; + } + + @Override + public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) { + if (_resultMetadata.getDataType() != DataType.STRING) { + return super.transformToStringValuesMV(projectionBlock); + } + if (_stringMVResults == null) { + _stringMVResults = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL][]; + } + getNonLiteralValues(projectionBlock); + int length = projectionBlock.getNumDocs(); + for (int i = 0; i < length; i++) { + for (int j = 0; j < _numNonLiteralArguments; j++) { + _arguments[_nonLiteralIndices[j]] = _nonLiteralValues[j][i]; + } + _stringMVResults[i] = (String[]) _functionInvoker.invoke(_arguments); + } + return _stringMVResults; + } + /** * Helper method to fetch values for the non-literal transform functions based on the parameter types. */ @@ -249,6 +352,21 @@ public class ScalarTransformFunctionWrapper extends BaseTransformFunction { case BYTES: _nonLiteralValues[i] = transformFunction.transformToBytesValuesSV(projectionBlock); break; + case INTEGER_ARRAY: + _nonLiteralValues[i] = transformFunction.transformToIntValuesMV(projectionBlock); + break; + case LONG_ARRAY: + _nonLiteralValues[i] = transformFunction.transformToLongValuesMV(projectionBlock); + break; + case FLOAT_ARRAY: + _nonLiteralValues[i] = transformFunction.transformToFloatValuesMV(projectionBlock); + break; + case DOUBLE_ARRAY: + _nonLiteralValues[i] = transformFunction.transformToDoubleValuesMV(projectionBlock); + break; + case STRING_ARRAY: + _nonLiteralValues[i] = transformFunction.transformToStringValuesMV(projectionBlock); + break; default: throw new IllegalStateException(); } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java index b1bd0ec..54809a6 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionsTest.java @@ -374,4 +374,71 @@ public class InbuiltFunctionsTest { return inputs.toArray(new Object[0][]); } + + @Test(dataProvider = "arrayFunctionsDataProvider") + public void testArrayFunctions(String functionExpression, List<String> expectedArguments, GenericRow row, + Object expectedResult) { + testFunction(functionExpression, expectedArguments, row, expectedResult); + } + + @DataProvider(name = "arrayFunctionsDataProvider") + public Object[][] arrayFunctionsDataProvider() { + List<Object[]> inputs = new ArrayList<>(); + + GenericRow row = new GenericRow(); + row.putValue("intArray", new int[]{3, 2, 10, 6, 1, 12}); + row.putValue("integerArray", new Integer[]{3, 2, 10, 6, 1, 12}); + row.putValue("stringArray", new String[]{"3", "2", "10", "6", "1", "12"}); + + inputs.add(new Object[]{"array_reverse_int(intArray)", Collections.singletonList( + "intArray"), row, new int[]{12, 1, 6, 10, 2, 3}}); + inputs.add(new Object[]{"array_reverse_int(integerArray)", Collections.singletonList( + "integerArray"), row, new int[]{12, 1, 6, 10, 2, 3}}); + inputs.add(new Object[]{"array_reverse_int(stringArray)", Collections.singletonList( + "stringArray"), row, new int[]{12, 1, 6, 10, 2, 3}}); + + inputs.add(new Object[]{"array_reverse_string(intArray)", Collections.singletonList( + "intArray"), row, new String[]{"12", "1", "6", "10", "2", "3"}}); + inputs.add(new Object[]{"array_reverse_string(integerArray)", Collections.singletonList( + "integerArray"), row, new String[]{"12", "1", "6", "10", "2", "3"}}); + inputs.add(new Object[]{"array_reverse_string(stringArray)", Collections.singletonList( + "stringArray"), row, new String[]{"12", "1", "6", "10", "2", "3"}}); + + inputs.add(new Object[]{"array_sort_int(intArray)", Collections.singletonList( + "intArray"), row, new int[]{1, 2, 3, 6, 10, 12}}); + inputs.add(new Object[]{"array_sort_int(integerArray)", Collections.singletonList( + "integerArray"), row, new int[]{1, 2, 3, 6, 10, 12}}); + inputs.add(new Object[]{"array_sort_int(stringArray)", Collections.singletonList( + "stringArray"), row, new int[]{1, 2, 3, 6, 10, 12}}); + + inputs.add(new Object[]{"array_sort_string(intArray)", Collections.singletonList( + "intArray"), row, new String[]{"1", "10", "12", "2", "3", "6"}}); + inputs.add(new Object[]{"array_sort_string(integerArray)", Collections.singletonList( + "integerArray"), row, new String[]{"1", "10", "12", "2", "3", "6"}}); + inputs.add(new Object[]{"array_sort_string(stringArray)", Collections.singletonList( + "stringArray"), row, new String[]{"1", "10", "12", "2", "3", "6"}}); + + inputs.add(new Object[]{"array_index_of_int(intArray, 2)", Collections.singletonList("intArray"), row, 1}); + inputs.add(new Object[]{"array_index_of_int(integerArray, 2)", Collections.singletonList("integerArray"), row, 1}); + inputs.add(new Object[]{"array_index_of_int(stringArray, 2)", Collections.singletonList("stringArray"), row, 1}); + + inputs.add(new Object[]{"array_index_of_string(intArray, '2')", Collections.singletonList("intArray"), row, 1}); + inputs.add( + new Object[]{"array_index_of_string(integerArray, '2')", Collections.singletonList("integerArray"), row, 1}); + inputs + .add(new Object[]{"array_index_of_string(stringArray, '2')", Collections.singletonList("stringArray"), row, 1}); + + inputs.add(new Object[]{"array_contains_int(intArray, 2)", Collections.singletonList("intArray"), row, true}); + inputs + .add(new Object[]{"array_contains_int(integerArray, 2)", Collections.singletonList("integerArray"), row, true}); + inputs.add(new Object[]{"array_contains_int(stringArray, 2)", Collections.singletonList("stringArray"), row, true}); + + inputs.add(new Object[]{"array_contains_string(intArray, '2')", Collections.singletonList("intArray"), row, true}); + inputs.add( + new Object[]{"array_contains_string(integerArray, '2')", Collections.singletonList("integerArray"), row, true}); + inputs.add( + new Object[]{"array_contains_string(stringArray, '2')", Collections.singletonList("stringArray"), row, true}); + + return inputs.toArray(new Object[0][]); + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java index 8d67eb8..91eb467 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java @@ -73,6 +73,11 @@ public abstract class BaseTransformFunctionTest { protected static final String BYTES_SV_COLUMN = "bytesSV"; protected static final String STRING_ALPHANUM_SV_COLUMN = "stringAlphaNumSV"; protected static final String INT_MV_COLUMN = "intMV"; + protected static final String LONG_MV_COLUMN = "longMV"; + protected static final String FLOAT_MV_COLUMN = "floatMV"; + protected static final String DOUBLE_MV_COLUMN = "doubleMV"; + protected static final String STRING_MV_COLUMN = "stringMV"; + protected static final String STRING_ALPHANUM_MV_COLUMN = "stringAlphaNumMV"; protected static final String TIME_COLUMN = "time"; protected static final String JSON_COLUMN = "json"; protected final int[] _intSVValues = new int[NUM_ROWS]; @@ -83,6 +88,11 @@ public abstract class BaseTransformFunctionTest { protected final String[] _stringAlphaNumericSVValues = new String[NUM_ROWS]; protected final byte[][] _bytesSVValues = new byte[NUM_ROWS][]; protected final int[][] _intMVValues = new int[NUM_ROWS][]; + protected final long[][] _longMVValues = new long[NUM_ROWS][]; + protected final float[][] _floatMVValues = new float[NUM_ROWS][]; + protected final double[][] _doubleMVValues = new double[NUM_ROWS][]; + protected final String[][] _stringMVValues = new String[NUM_ROWS][]; + protected final String[][] _stringAlphaNumericMVValues = new String[NUM_ROWS][]; protected final long[] _timeValues = new long[NUM_ROWS]; protected final String[] _jsonValues = new String[NUM_ROWS]; @@ -107,8 +117,19 @@ public abstract class BaseTransformFunctionTest { int numValues = 1 + RANDOM.nextInt(MAX_NUM_MULTI_VALUES); _intMVValues[i] = new int[numValues]; + _longMVValues[i] = new long[numValues]; + _floatMVValues[i] = new float[numValues]; + _doubleMVValues[i] = new double[numValues]; + _stringMVValues[i] = new String[numValues]; + _stringAlphaNumericMVValues[i] = new String[numValues]; + for (int j = 0; j < numValues; j++) { _intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE); + _longMVValues[i][j] = 1 + RANDOM.nextLong(); + _floatMVValues[i][j] = 1 + RANDOM.nextFloat(); + _doubleMVValues[i][j] = 1 + RANDOM.nextDouble(); + _stringMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextDouble()); + _stringAlphaNumericMVValues[i][j] = RandomStringUtils.randomAlphanumeric(26); } // Time in the past year @@ -126,6 +147,11 @@ public abstract class BaseTransformFunctionTest { map.put(STRING_ALPHANUM_SV_COLUMN, _stringAlphaNumericSVValues[i]); map.put(BYTES_SV_COLUMN, _bytesSVValues[i]); map.put(INT_MV_COLUMN, ArrayUtils.toObject(_intMVValues[i])); + map.put(LONG_MV_COLUMN, ArrayUtils.toObject(_longMVValues[i])); + map.put(FLOAT_MV_COLUMN, ArrayUtils.toObject(_floatMVValues[i])); + map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i])); + map.put(STRING_MV_COLUMN, _stringMVValues[i]); + map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]); map.put(TIME_COLUMN, _timeValues[i]); _jsonValues[i] = JsonUtils.objectToJsonNode(map).toString(); map.put(JSON_COLUMN, _jsonValues[i]); @@ -141,8 +167,13 @@ public abstract class BaseTransformFunctionTest { .addSingleValueDimension(STRING_SV_COLUMN, FieldSpec.DataType.STRING) .addSingleValueDimension(STRING_ALPHANUM_SV_COLUMN, FieldSpec.DataType.STRING) .addSingleValueDimension(BYTES_SV_COLUMN, FieldSpec.DataType.BYTES) - .addSingleValueDimension(JSON_COLUMN, FieldSpec.DataType.STRING) + .addSingleValueDimension(JSON_COLUMN, FieldSpec.DataType.STRING, Integer.MAX_VALUE, null) .addMultiValueDimension(INT_MV_COLUMN, FieldSpec.DataType.INT) + .addMultiValueDimension(LONG_MV_COLUMN, FieldSpec.DataType.LONG) + .addMultiValueDimension(FLOAT_MV_COLUMN, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(DOUBLE_MV_COLUMN, FieldSpec.DataType.DOUBLE) + .addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING) + .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN, FieldSpec.DataType.STRING) .addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG, TimeUnit.MILLISECONDS, TIME_COLUMN), null).build(); TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName("test").setTimeColumnName(TIME_COLUMN).build(); @@ -209,6 +240,20 @@ public abstract class BaseTransformFunctionTest { } } + protected void testTransformFunctionMV(TransformFunction transformFunction, int[][] expectedValues) { + int[][] intMVValues = transformFunction.transformToIntValuesMV(_projectionBlock); + for (int i = 0; i < NUM_ROWS; i++) { + Assert.assertEquals(intMVValues[i], expectedValues[i]); + } + } + + protected void testTransformFunctionMV(TransformFunction transformFunction, String[][] expectedValues) { + String[][] stringMVValues = transformFunction.transformToStringValuesMV(_projectionBlock); + for (int i = 0; i < NUM_ROWS; i++) { + Assert.assertEquals(stringMVValues[i], expectedValues[i]); + } + } + @AfterClass public void tearDown() { FileUtils.deleteQuietly(new File(INDEX_DIR_PATH)); diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java index 9402b8b..61a521b 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java @@ -18,14 +18,20 @@ */ package org.apache.pinot.core.operator.transform.function; +import java.util.Arrays; import org.apache.commons.codec.digest.DigestUtils; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; -import org.apache.pinot.common.function.FunctionRegistry; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; -import org.testng.Assert; +import org.apache.pinot.core.util.ArrayCopyUtils; +import org.apache.pinot.spi.data.FieldSpec.DataType; import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTest { @@ -34,8 +40,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("lower(%s)", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "lower"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "lower"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i].toLowerCase(); @@ -48,8 +54,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("UPPER(%s)", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "upper"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "upper"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i].toUpperCase(); @@ -62,8 +68,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("rEvErSe(%s)", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "reverse"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "reverse"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = new StringBuilder(_stringAlphaNumericSVValues[i]).reverse().toString(); @@ -76,8 +82,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("sub_str(%s, 0, 2)", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "substr"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "substr"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i].substring(0, 2); @@ -87,8 +93,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes expression = QueryContextConverterUtils.getExpression(String.format("substr(%s, '2', '-1')", STRING_ALPHANUM_SV_COLUMN)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "substr"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "substr"); expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i].substring(2); @@ -101,8 +107,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils .getExpression(String.format("concat(%s, %s, '-')", STRING_ALPHANUM_SV_COLUMN, STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "concat"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "concat"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i] + "-" + _stringAlphaNumericSVValues[i]; @@ -115,8 +121,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("replace(%s, 'A', 'B')", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "replace"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "replace"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = _stringAlphaNumericSVValues[i].replaceAll("A", "B"); @@ -131,8 +137,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils .getExpression(String.format("lpad(%s, %d, '%s')", STRING_ALPHANUM_SV_COLUMN, padLength, padString)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "lpad"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "lpad"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = StringUtils.leftPad(_stringAlphaNumericSVValues[i], padLength, padString); @@ -142,8 +148,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes expression = QueryContextConverterUtils .getExpression(String.format("rpad(%s, %d, '%s')", STRING_ALPHANUM_SV_COLUMN, padLength, padString)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "rpad"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "rpad"); expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = StringUtils.rightPad(_stringAlphaNumericSVValues[i], padLength, padString); @@ -156,22 +162,22 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("ltrim(lpad(%s, 50, ' '))", STRING_ALPHANUM_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "ltrim"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "ltrim"); testTransformFunction(transformFunction, _stringAlphaNumericSVValues); expression = QueryContextConverterUtils.getExpression(String.format("rtrim(rpad(%s, 50, ' '))", STRING_ALPHANUM_SV_COLUMN)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "rtrim"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "rtrim"); testTransformFunction(transformFunction, _stringAlphaNumericSVValues); expression = QueryContextConverterUtils .getExpression(String.format("trim(rpad(lpad(%s, 50, ' '), 100, ' '))", STRING_ALPHANUM_SV_COLUMN)); transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "trim"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "trim"); testTransformFunction(transformFunction, _stringAlphaNumericSVValues); } @@ -179,8 +185,8 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes public void testShaTransformFunction() { ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("sha(%s)", BYTES_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "sha"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "sha"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = DigestUtils.shaHex(_bytesSVValues[i]); @@ -190,10 +196,11 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes @Test public void testSha256TransformFunction() { - ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("sha256(%s)", BYTES_SV_COLUMN)); + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("sha256(%s)", BYTES_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "sha256"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "sha256"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = DigestUtils.sha256Hex(_bytesSVValues[i]); @@ -203,10 +210,11 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes @Test public void testSha512TransformFunction() { - ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("sha512(%s)", BYTES_SV_COLUMN)); + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("sha512(%s)", BYTES_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "sha512"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "sha512"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = DigestUtils.sha512Hex(_bytesSVValues[i]); @@ -218,12 +226,180 @@ public class ScalarTransformFunctionWrapperTest extends BaseTransformFunctionTes public void testMd5TransformFunction() { ExpressionContext expression = QueryContextConverterUtils.getExpression(String.format("md5(%s)", BYTES_SV_COLUMN)); TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); - Assert.assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); - Assert.assertEquals(transformFunction.getName(), "md5"); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "md5"); String[] expectedValues = new String[NUM_ROWS]; for (int i = 0; i < NUM_ROWS; i++) { expectedValues[i] = DigestUtils.md5Hex(_bytesSVValues[i]); } testTransformFunction(transformFunction, expectedValues); } + + @Test + public void testArrayReverseIntTransformFunction() { + { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_reverse_int(%s)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayReverseInt"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = _intMVValues[i].clone(); + ArrayUtils.reverse(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_reverse_int(%s)", LONG_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayReverseInt"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = new int[_longMVValues[i].length]; + ArrayCopyUtils.copy(_longMVValues[i], expectedValues[i], _longMVValues[i].length); + ArrayUtils.reverse(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + } + + @Test + public void testArrayReverseStringTransformFunction() { + { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_reverse_string(%s)", STRING_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayReverseString"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + String[][] expectedValues = new String[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = _stringMVValues[i].clone(); + ArrayUtils.reverse(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_reverse_string(%s)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayReverseString"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + String[][] expectedValues = new String[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = new String[_intMVValues[i].length]; + ArrayCopyUtils.copy(_intMVValues[i], expectedValues[i], _longMVValues[i].length); + ArrayUtils.reverse(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + } + + @Test + public void testArraySortIntTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_sort_int(%s)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arraySortInt"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + int[][] expectedValues = new int[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = _intMVValues[i].clone(); + Arrays.sort(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArraySortStringTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_sort_string(%s)", STRING_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arraySortString"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertFalse(transformFunction.getResultMetadata().isSingleValue()); + String[][] expectedValues = new String[NUM_ROWS][]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = _stringMVValues[i].clone(); + Arrays.sort(expectedValues[i]); + } + testTransformFunctionMV(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfIntTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_index_of_int(%s, 2)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayIndexOfInt"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + int[] expectedValues = new int[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = ArrayUtils.indexOf(_intMVValues[i], 2); + } + testTransformFunction(transformFunction, expectedValues); + } + + @Test + public void testArrayIndexOfStringTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_index_of_string(%s, 'a')", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayIndexOfString"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + int[] expectedValues = new int[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = ArrayUtils.indexOf(_intMVValues[i], 'a'); + } + testTransformFunction(transformFunction, expectedValues); + } + + @Test + public void testArrayContainsIntTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_contains_int(%s, 2)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayContainsInt"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + String[] expectedValues = new String[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = Boolean.toString(ArrayUtils.contains(_intMVValues[i], 2)); + } + testTransformFunction(transformFunction, expectedValues); + } + + @Test + public void testArrayContainsStringTransformFunction() { + ExpressionContext expression = + QueryContextConverterUtils.getExpression(String.format("array_contains_string(%s, 'a')", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "arrayContainsString"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.STRING); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + String[] expectedValues = new String[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + expectedValues[i] = Boolean.toString(ArrayUtils.contains(_intMVValues[i], 'a')); + } + testTransformFunction(transformFunction, expectedValues); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org