This is an automated email from the ASF dual-hosted git repository. xiangfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new e809b4e81e Adding vector scalar functions (#11222) e809b4e81e is described below commit e809b4e81e8c8571f9b407a28931fc2b8b20ba3c Author: Xiang Fu <xiangfu.1...@gmail.com> AuthorDate: Mon Jul 31 14:35:32 2023 -0700 Adding vector scalar functions (#11222) --- .../common/function/TransformFunctionType.java | 16 ++ .../common/function/scalar/VectorFunctions.java | 154 ++++++++++++++ .../function/TransformFunctionFactory.java | 14 ++ .../function/VectorTransformFunctions.java | 229 +++++++++++++++++++++ .../core/data/function/VectorFunctionsTest.java | 113 ++++++++++ .../function/BaseTransformFunctionTest.java | 29 ++- .../function/VectorTransformFunctionTest.java | 73 +++++++ .../integration/tests/VectorIntegrationTest.java | 193 +++++++++++++++++ 8 files changed, 818 insertions(+), 3 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java index 58cb4da3f2..471f6b128a 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java @@ -205,6 +205,22 @@ public enum TransformFunctionType { OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), ordinal -> ordinal > 1 && ordinal < 4)), + // Vector functions + // TODO: Once VECTOR type is defined, we should update here. 
+ COSINE_DISTANCE("cosineDistance", ReturnTypes.explicit(SqlTypeName.DOUBLE), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC), + ordinal -> ordinal > 1 && ordinal < 4), "cosine_distance"), + INNER_PRODUCT("innerProduct", ReturnTypes.explicit(SqlTypeName.DOUBLE), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "inner_product"), + L1_DISTANCE("l1Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l1_distance"), + L2_DISTANCE("l2Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l2_distance"), + VECTOR_DIMS("vectorDims", ReturnTypes.explicit(SqlTypeName.INTEGER), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_dims"), + VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE), + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"), + // Trigonometry SIN("sin"), COS("cos"), diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java new file mode 100644 index 0000000000..3c1ab75478 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java @@ -0,0 +1,154 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function.scalar; + +import com.google.common.base.Preconditions; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Inbuilt Vector Transformation Functions + * The functions can be used as UDFs in Query when added in the FunctionRegistry. + * @ScalarFunction annotation is used with each method for the registration + * + * Example usage: + */ +public class VectorFunctions { + private VectorFunctions() { + } + + /** + * Returns the cosine distance between two vectors + * @param vector1 vector1 + * @param vector2 vector2 + * @return cosine distance + */ + @ScalarFunction(names = {"cosinedistance", "cosine_distance"}) + public static double cosineDistance(float[] vector1, float[] vector2) { + return cosineDistance(vector1, vector2, Double.NaN); + } + + /** + * Returns the cosine distance between two vectors, with a default value if the norm of either vector is 0. 
+ * @param vector1 vector1 + * @param vector2 vector2 + * @param defaultValue default value when either vector has a norm of 0 + * @return cosine distance + */ + @ScalarFunction(names = {"cosinedistance", "cosine_distance"}) + public static double cosineDistance(float[] vector1, float[] vector2, double defaultValue) { + validateVectors(vector1, vector2); + double dotProduct = 0.0; + double norm1 = 0.0; + double norm2 = 0.0; + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + norm1 += Math.pow(vector1[i], 2); + norm2 += Math.pow(vector2[i], 2); + } + if (norm1 == 0 || norm2 == 0) { + return defaultValue; + } + return 1 - (dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2))); + } + + /** + * Returns the inner product between two vectors + * @param vector1 vector1 + * @param vector2 vector2 + * @return inner product + */ + @ScalarFunction(names = {"innerproduct", "inner_product"}) + public static double innerProduct(float[] vector1, float[] vector2) { + validateVectors(vector1, vector2); + double dotProduct = 0.0; + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + } + return dotProduct; + } + + /** + * Returns the L2 distance between two vectors + * @param vector1 vector1 + * @param vector2 vector2 + * @return L2 distance + */ + @ScalarFunction(names = {"l2distance", "l2_distance"}) + public static double l2Distance(float[] vector1, float[] vector2) { + validateVectors(vector1, vector2); + double distance = 0.0; + for (int i = 0; i < vector1.length; i++) { + distance += Math.pow(vector1[i] - vector2[i], 2); + } + return Math.sqrt(distance); + } + + /** + * Returns the L1 distance between two vectors + * @param vector1 vector1 + * @param vector2 vector2 + * @return L1 distance + */ + @ScalarFunction(names = {"l1distance", "l1_distance"}) + public static double l1Distance(float[] vector1, float[] vector2) { + validateVectors(vector1, vector2); + double distance = 0.0; + for (int i = 0; i < 
vector1.length; i++) { + distance += Math.abs(vector1[i] - vector2[i]); + } + return distance; + } + + /** + * Returns the number of dimensions in a vector + * @param vector input vector + * @return number of dimensions + */ + @ScalarFunction(names = {"vectordims", "vector_dims"}) + public static int vectorDims(float[] vector) { + validateVector(vector); + return vector.length; + } + + /** + * Returns the norm of a vector + * @param vector input vector + * @return norm + */ + @ScalarFunction(names = {"vectornorm", "vector_norm"}) + public static double vectorNorm(float[] vector) { + validateVector(vector); + double norm = 0.0; + for (int i = 0; i < vector.length; i++) { + norm += Math.pow(vector[i], 2); + } + return Math.sqrt(norm); + } + + public static void validateVectors(float[] vector1, float[] vector2) { + Preconditions.checkArgument(vector1 != null && vector2 != null, "Null vector passed"); + Preconditions.checkArgument(vector1.length == vector2.length, "Vector lengths do not match"); + } + + public static void validateVector(float[] vector) { + Preconditions.checkArgument(vector != null, "Null vector passed"); + Preconditions.checkArgument(vector.length > 0, "Empty vector passed"); + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index 3f706bf566..4e3ff24119 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -70,6 +70,12 @@ import org.apache.pinot.core.operator.transform.function.TrigonometricTransformF import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.SinhTransformFunction; import 
org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanTransformFunction; import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanhTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.CosineDistanceTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.InnerProductTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L1DistanceTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L2DistanceTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorDimsTransformFunction; +import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorNormTransformFunction; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; import org.apache.pinot.segment.spi.datasource.DataSource; @@ -217,6 +223,14 @@ public class TransformFunctionFactory { typeToImplementation.put(TransformFunctionType.DEGREES, DegreesTransformFunction.class); typeToImplementation.put(TransformFunctionType.RADIANS, RadiansTransformFunction.class); + // Vector functions + typeToImplementation.put(TransformFunctionType.COSINE_DISTANCE, CosineDistanceTransformFunction.class); + typeToImplementation.put(TransformFunctionType.INNER_PRODUCT, InnerProductTransformFunction.class); + typeToImplementation.put(TransformFunctionType.L1_DISTANCE, L1DistanceTransformFunction.class); + typeToImplementation.put(TransformFunctionType.L2_DISTANCE, L2DistanceTransformFunction.class); + typeToImplementation.put(TransformFunctionType.VECTOR_DIMS, VectorDimsTransformFunction.class); + typeToImplementation.put(TransformFunctionType.VECTOR_NORM, VectorNormTransformFunction.class); + Map<String, Class<? 
extends TransformFunction>> registry = new HashMap<>(typeToImplementation.size()); for (Map.Entry<TransformFunctionType, Class<? extends TransformFunction>> entry : typeToImplementation.entrySet()) { for (String alias : entry.getKey().getAlternativeNames()) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctions.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctions.java new file mode 100644 index 0000000000..d5d7508b0c --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctions.java @@ -0,0 +1,229 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.core.operator.transform.function; + +import com.google.common.base.Preconditions; +import java.util.List; +import java.util.Map; +import org.apache.pinot.common.function.scalar.VectorFunctions; +import org.apache.pinot.core.operator.ColumnContext; +import org.apache.pinot.core.operator.blocks.ValueBlock; +import org.apache.pinot.core.operator.transform.TransformResultMetadata; + + +public class VectorTransformFunctions { + public static class CosineDistanceTransformFunction extends VectorDistanceTransformFunction { + public static final String FUNCTION_NAME = "cosineDistance"; + private Double _defaultValue = null; + + @Override + protected void checkArgumentSize(List<TransformFunction> arguments) { + // Check that there are 2 or 3 arguments + if (arguments.size() < 2 || arguments.size() > 3) { + throw new IllegalArgumentException("2 or 3 arguments are required for CosineDistance function"); + } + } + + @Override + public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) { + super.init(arguments, columnContextMap); + if (arguments.size() == 3) { + _defaultValue = ((LiteralTransformFunction) arguments.get(2)).getDoubleLiteral(); + } + } + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + protected double computeDistance(float[] vector1, float[] vector2) { + if (_defaultValue != null) { + return VectorFunctions.cosineDistance(vector1, vector2, _defaultValue); + } + return VectorFunctions.cosineDistance(vector1, vector2); + } + } + + public static class InnerProductTransformFunction extends VectorDistanceTransformFunction { + public static final String FUNCTION_NAME = "innerProduct"; + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + protected double computeDistance(float[] vector1, float[] vector2) { + return VectorFunctions.innerProduct(vector1, vector2); + } + } + + public static class L1DistanceTransformFunction extends 
VectorDistanceTransformFunction { + public static final String FUNCTION_NAME = "l1Distance"; + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + protected double computeDistance(float[] vector1, float[] vector2) { + return VectorFunctions.l1Distance(vector1, vector2); + } + } + + public static class L2DistanceTransformFunction extends VectorDistanceTransformFunction { + public static final String FUNCTION_NAME = "l2Distance"; + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + protected double computeDistance(float[] vector1, float[] vector2) { + return VectorFunctions.l2Distance(vector1, vector2); + } + } + + public static abstract class VectorDistanceTransformFunction extends BaseTransformFunction { + + protected TransformFunction _leftTransformFunction; + protected TransformFunction _rightTransformFunction; + + @Override + public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) { + super.init(arguments, columnContextMap); + checkArgumentSize(arguments); + _leftTransformFunction = arguments.get(0); + _rightTransformFunction = arguments.get(1); + Preconditions.checkArgument( + !_leftTransformFunction.getResultMetadata().isSingleValue() + && !_rightTransformFunction.getResultMetadata().isSingleValue(), + "Argument must be multi-valued float vector for vector distance transform function: %s", getName()); + } + + protected void checkArgumentSize(List<TransformFunction> arguments) { + // Check that there are 2 arguments + if (arguments.size() != 2) { + throw new IllegalArgumentException("Exactly 2 arguments are required for Vector transform function"); + } + } + + @Override + public TransformResultMetadata getResultMetadata() { + return DOUBLE_SV_NO_DICTIONARY_METADATA; + } + + @Override + public double[] transformToDoubleValuesSV(ValueBlock valueBlock) { + int length = valueBlock.getNumDocs(); + initDoubleValuesSV(length); + float[][] leftValues = 
_leftTransformFunction.transformToFloatValuesMV(valueBlock); + float[][] rightValues = _rightTransformFunction.transformToFloatValuesMV(valueBlock); + for (int i = 0; i < length; i++) { + _doubleValuesSV[i] = computeDistance(leftValues[i], rightValues[i]); + } + return _doubleValuesSV; + } + + protected abstract double computeDistance(float[] vector1, float[] vector2); + } + + public static class VectorDimsTransformFunction extends BaseTransformFunction { + public static final String FUNCTION_NAME = "vectorDims"; + + private TransformFunction _transformFunction; + + @Override + public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) { + super.init(arguments, columnContextMap); + // Check that there is exact 1 argument + if (arguments.size() != 1) { + throw new IllegalArgumentException("Exactly 1 argument is required for Vector transform function"); + } + _transformFunction = arguments.get(0); + Preconditions.checkArgument(!_transformFunction.getResultMetadata().isSingleValue(), + "Argument must be multi-valued float vector for vector distance transform function: %s", getName()); + } + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + public TransformResultMetadata getResultMetadata() { + return INT_SV_NO_DICTIONARY_METADATA; + } + + @Override + public int[] transformToIntValuesSV(ValueBlock valueBlock) { + int length = valueBlock.getNumDocs(); + initIntValuesSV(length); + float[][] values = _transformFunction.transformToFloatValuesMV(valueBlock); + for (int i = 0; i < length; i++) { + _intValuesSV[i] = VectorFunctions.vectorDims(values[i]); + } + return _intValuesSV; + } + } + + public static class VectorNormTransformFunction extends BaseTransformFunction { + public static final String FUNCTION_NAME = "vectorNorm"; + + private TransformFunction _transformFunction; + + @Override + public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) { + 
super.init(arguments, columnContextMap); + // Check that there is exact 1 argument + if (arguments.size() != 1) { + throw new IllegalArgumentException("Exactly 1 argument is required for Vector transform function"); + } + + _transformFunction = arguments.get(0); + Preconditions.checkArgument(!_transformFunction.getResultMetadata().isSingleValue(), + "Argument must be multi-valued float vector for vector distance transform function: %s", getName()); + } + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + public TransformResultMetadata getResultMetadata() { + return DOUBLE_SV_NO_DICTIONARY_METADATA; + } + + @Override + public double[] transformToDoubleValuesSV(ValueBlock valueBlock) { + int length = valueBlock.getNumDocs(); + initDoubleValuesSV(length); + float[][] values = _transformFunction.transformToFloatValuesMV(valueBlock); + for (int i = 0; i < length; i++) { + _doubleValuesSV[i] = VectorFunctions.vectorNorm(values[i]); + } + return _doubleValuesSV; + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java new file mode 100644 index 0000000000..972c33ee43 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/VectorFunctionsTest.java @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.data.function; + +import com.google.common.collect.Lists; +import java.util.ArrayList; +import java.util.List; +import org.apache.pinot.segment.local.function.InbuiltFunctionEvaluator; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +/** + * Tests the vector scalar functions + */ +public class VectorFunctionsTest { + + private void testFunction(String functionExpression, List<String> expectedArguments, GenericRow row, + Object expectedResult) { + InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(functionExpression); + Assert.assertEquals(evaluator.getArguments(), expectedArguments); + Assert.assertEquals(evaluator.evaluate(row), expectedResult); + } + + @Test(dataProvider = "vectorFunctionsDataProvider") + public void testVectorFunctions(String functionExpression, List<String> expectedArguments, GenericRow row, + Object expectedResult) { + testFunction(functionExpression, expectedArguments, row, expectedResult); + } + + @DataProvider(name = "vectorFunctionsDataProvider") + public Object[][] vectorFunctionsDataProvider() { + List<Object[]> inputs = new ArrayList<>(); + + GenericRow row = new GenericRow(); + row.putValue("vector1", new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + row.putValue("vector2", new float[]{0.6f, 0.7f, 0.8f, 0.9f, 1.0f}); + inputs.add(new Object[]{ + "cosineDistance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.03504950750101454 + }); + 
inputs.add(new Object[]{ + "innerProduct(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.2999999970197678 + }); + inputs.add(new Object[]{ + "l2Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.1180339754218913 + }); + inputs.add(new Object[]{ + "l1Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 2.4999999701976776 + }); + inputs.add(new Object[]{"vectorDims(vector1)", Lists.newArrayList("vector1"), row, 5}); + inputs.add(new Object[]{"vectorDims(vector2)", Lists.newArrayList("vector2"), row, 5}); + inputs.add(new Object[]{"vectorNorm(vector1)", Lists.newArrayList("vector1"), row, 0.741619857751291}); + inputs.add(new Object[]{"vectorNorm(vector2)", Lists.newArrayList("vector2"), row, 1.8165902091773676}); + return inputs.toArray(new Object[0][]); + } + + @Test(dataProvider = "vectorFunctionsZeroDataProvider") + public void testVectorFunctionsWithZeroVector(String functionExpression, List<String> expectedArguments, + GenericRow row, + Object expectedResult) { + testFunction(functionExpression, expectedArguments, row, expectedResult); + } + + @DataProvider(name = "vectorFunctionsZeroDataProvider") + public Object[][] vectorFunctionsZeroDataProvider() { + List<Object[]> inputs = new ArrayList<>(); + + GenericRow row = new GenericRow(); + row.putValue("vector1", new float[]{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + row.putValue("vector2", new float[]{0f, 0f, 0f, 0f, 0f}); + inputs.add(new Object[]{ + "cosineDistance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, Double.NaN + }); + inputs.add(new Object[]{ + "cosineDistance(vector1, vector2, 0.0)", Lists.newArrayList("vector1", "vector2"), row, 0.0 + }); + inputs.add(new Object[]{ + "cosineDistance(vector1, vector2, 1.0)", Lists.newArrayList("vector1", "vector2"), row, 1.0 + }); + inputs.add(new Object[]{ + "innerProduct(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.0 + }); + inputs.add(new Object[]{ + 
"l2Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 0.741619857751291 + }); + inputs.add(new Object[]{ + "l1Distance(vector1, vector2)", Lists.newArrayList("vector1", "vector2"), row, 1.5000000223517418 + }); + inputs.add(new Object[]{"vectorDims(vector1)", Lists.newArrayList("vector1"), row, 5}); + inputs.add(new Object[]{"vectorDims(vector2)", Lists.newArrayList("vector2"), row, 5}); + inputs.add(new Object[]{"vectorNorm(vector1)", Lists.newArrayList("vector1"), row, 0.741619857751291}); + inputs.add(new Object[]{"vectorNorm(vector2)", Lists.newArrayList("vector2"), row, 0.0}); + return inputs.toArray(new Object[0][]); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java index d36bbd3250..5026896048 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/BaseTransformFunctionTest.java @@ -34,6 +34,7 @@ import java.util.concurrent.TimeUnit; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.lang3.RandomUtils; import org.apache.pinot.core.operator.DocIdSetOperator; import org.apache.pinot.core.operator.ProjectionOperator; import org.apache.pinot.core.operator.blocks.ProjectionBlock; @@ -71,6 +72,7 @@ public abstract class BaseTransformFunctionTest { protected static final int NUM_ROWS = 1000; protected static final int MAX_NUM_MULTI_VALUES = 5; protected static final int MAX_MULTI_VALUE = 10; + protected static final int VECTOR_DIM_SIZE = 512; protected static final String INT_SV_COLUMN = "intSV"; // INT_SV_NULL_COLUMN's even row equals to INT_SV_COLUMN. odd row is null. 
protected static final String INT_SV_NULL_COLUMN = "intSVNull"; @@ -82,6 +84,11 @@ public abstract class BaseTransformFunctionTest { protected static final String STRING_SV_NULL_COLUMN = "stringSVNull"; protected static final String BYTES_SV_COLUMN = "bytesSV"; + + protected static final String VECTOR_1_COLUMN = "vector1"; + protected static final String VECTOR_2_COLUMN = "vector2"; + protected static final String ZERO_VECTOR_COLUMN = "zeroVector"; + protected static final String STRING_ALPHANUM_SV_COLUMN = "stringAlphaNumSV"; protected static final String STRING_ALPHANUM_NULL_SV_COLUMN = "stringAlphaNumSVNull"; @@ -118,6 +125,8 @@ public abstract class BaseTransformFunctionTest { protected final String[][] _stringLongFormatMVValues = new String[NUM_ROWS][]; protected final long[] _timeValues = new long[NUM_ROWS]; protected final String[] _jsonValues = new String[NUM_ROWS]; + protected final float[][] _vector1Values = new float[NUM_ROWS][]; + protected final float[][] _vector2Values = new float[NUM_ROWS][]; protected Map<String, DataSource> _dataSourceMap; protected ProjectionBlock _projectionBlock; @@ -147,6 +156,8 @@ public abstract class BaseTransformFunctionTest { _stringMVValues[i] = new String[numValues]; _stringAlphaNumericMVValues[i] = new String[numValues]; _stringLongFormatMVValues[i] = new String[numValues]; + _vector1Values[i] = new float[VECTOR_DIM_SIZE]; + _vector2Values[i] = new float[VECTOR_DIM_SIZE]; for (int j = 0; j < numValues; j++) { _intMVValues[i][j] = 1 + RANDOM.nextInt(MAX_MULTI_VALUE); @@ -158,6 +169,11 @@ public abstract class BaseTransformFunctionTest { _stringLongFormatMVValues[i][j] = df.format(_intSVValues[i] * RANDOM.nextLong()); } + for (int j = 0; j < VECTOR_DIM_SIZE; j++) { + _vector1Values[i][j] = Math.abs(RandomUtils.nextFloat(0.0f, 1.0f)); + _vector2Values[i][j] = Math.abs(RandomUtils.nextFloat(0.0f, 1.0f)); + } + // Time in the past year _timeValues[i] = currentTimeMs - RANDOM.nextInt(365 * 24 * 3600) * 1000L; } @@ -188,6 
+204,7 @@ public abstract class BaseTransformFunctionTest { map.put(STRING_ALPHANUM_NULL_SV_COLUMN, _stringAlphaNumericSVValues[i]); } map.put(BYTES_SV_COLUMN, _bytesSVValues[i]); + map.put(INT_MV_COLUMN, ArrayUtils.toObject(_intMVValues[i])); if (isNullRow(i)) { map.put(INT_MV_NULL_COLUMN, null); @@ -196,6 +213,9 @@ public abstract class BaseTransformFunctionTest { } map.put(LONG_MV_COLUMN, ArrayUtils.toObject(_longMVValues[i])); map.put(FLOAT_MV_COLUMN, ArrayUtils.toObject(_floatMVValues[i])); + map.put(VECTOR_1_COLUMN, ArrayUtils.toObject(_vector1Values[i])); + map.put(VECTOR_2_COLUMN, ArrayUtils.toObject(_vector2Values[i])); + map.put(ZERO_VECTOR_COLUMN, ArrayUtils.toObject(new float[VECTOR_DIM_SIZE])); map.put(DOUBLE_MV_COLUMN, ArrayUtils.toObject(_doubleMVValues[i])); map.put(STRING_MV_COLUMN, _stringMVValues[i]); map.put(STRING_ALPHANUM_MV_COLUMN, _stringAlphaNumericMVValues[i]); @@ -235,6 +255,9 @@ public abstract class BaseTransformFunctionTest { .addMultiValueDimension(STRING_MV_COLUMN, FieldSpec.DataType.STRING) .addMultiValueDimension(STRING_ALPHANUM_MV_COLUMN, FieldSpec.DataType.STRING) .addMultiValueDimension(STRING_LONG_MV_COLUMN, FieldSpec.DataType.STRING) + .addMultiValueDimension(VECTOR_1_COLUMN, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(VECTOR_2_COLUMN, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(ZERO_VECTOR_COLUMN, FieldSpec.DataType.FLOAT) .addDateTime(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS") .addDateTime(TIMESTAMP_COLUMN_NULL, FieldSpec.DataType.TIMESTAMP, "1:MILLISECONDS:EPOCH", "1:MILLISECONDS") .addTime(new TimeGranularitySpec(FieldSpec.DataType.LONG, TimeUnit.MILLISECONDS, TIME_COLUMN), null).build(); @@ -401,9 +424,9 @@ public abstract class BaseTransformFunctionTest { protected void testTransformFunctionWithNull(TransformFunction transformFunction, double[] expectedValues, RoaringBitmap expectedNull) { int[] intValues = 
transformFunction.transformToIntValuesSV(_projectionBlock); - long[]longValues = transformFunction.transformToLongValuesSV(_projectionBlock); + long[] longValues = transformFunction.transformToLongValuesSV(_projectionBlock); float[] floatValues = transformFunction.transformToFloatValuesSV(_projectionBlock); - double[]doubleValues = transformFunction.transformToDoubleValuesSV(_projectionBlock); + double[] doubleValues = transformFunction.transformToDoubleValuesSV(_projectionBlock); BigDecimal[] bigDecimalValues = null; try { // 1- Some transform functions cannot work with BigDecimal (e.g. exp, ln, and sqrt). @@ -473,7 +496,7 @@ public abstract class BaseTransformFunctionTest { long[] longValues = transformFunction.transformToLongValuesSV(_projectionBlock); float[] floatValues = transformFunction.transformToFloatValuesSV(_projectionBlock); double[] doubleValues = transformFunction.transformToDoubleValuesSV(_projectionBlock); - BigDecimal[]bigDecimalValues = + BigDecimal[] bigDecimalValues = transformFunction.transformToBigDecimalValuesSV(_projectionBlock); for (int i = 0; i < NUM_ROWS; i++) { diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java new file mode 100644 index 0000000000..fd79aeabc1 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/VectorTransformFunctionTest.java @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.transform.function; + +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.RequestContextUtils; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + + +public class VectorTransformFunctionTest extends BaseTransformFunctionTest { + + @Test(dataProvider = "testVectorTransformFunctionDataProvider") + public void testVectorTransformFunction(String expressionStr, double lowerBound, double upperBound) { + ExpressionContext expression = RequestContextUtils.getExpression(expressionStr); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + double[] doubleValuesSV = transformFunction.transformToDoubleValuesSV(_projectionBlock); + for (int i = 0; i < NUM_ROWS; i++) { + assertTrue(doubleValuesSV[i] >= lowerBound, doubleValuesSV[i] + " < " + lowerBound); + assertTrue(doubleValuesSV[i] <= upperBound, doubleValuesSV[i] + " > " + upperBound); + } + } + + @Test + public void testVectorDimsTransformFunction() { + ExpressionContext expression = RequestContextUtils.getExpression("vectorDims(vector1)"); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + int[] intValuesSV = transformFunction.transformToIntValuesSV(_projectionBlock); + for (int i = 0; i < NUM_ROWS; i++) { + assertEquals(intValuesSV[i], VECTOR_DIM_SIZE); + } + + expression = 
RequestContextUtils.getExpression("vectorDims(vector2)"); + transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + intValuesSV = transformFunction.transformToIntValuesSV(_projectionBlock); + for (int i = 0; i < NUM_ROWS; i++) { + assertEquals(intValuesSV[i], VECTOR_DIM_SIZE); + } + } + + @DataProvider(name = "testVectorTransformFunctionDataProvider") + public Object[][] testVectorTransformFunctionDataProvider() { + return new Object[][]{ + new Object[]{"cosineDistance(vector1, vector2)", 0.1, 0.4}, + new Object[]{"cosineDistance(vector1, vector2, 0)", 0.1, 0.4}, + new Object[]{"cosineDistance(vector1, zeroVector, 0)", 0.0, 0.0}, + new Object[]{"innerProduct(vector1, vector2)", 100, 160}, + new Object[]{"l1Distance(vector1, vector2)", 150, 200}, + new Object[]{"l2Distance(vector1, vector2)", 8, 11}, + new Object[]{"vectorNorm(vector1)", 10, 16}, + new Object[]{"vectorNorm(vector2)", 10, 16} + }; + } +} diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java new file mode 100644 index 0000000000..d44462f745 --- /dev/null +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/VectorIntegrationTest.java @@ -0,0 +1,193 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.integration.tests; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.collect.ImmutableList; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import org.apache.avro.Schema.Field; +import org.apache.avro.Schema.Type; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.RandomUtils; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.apache.pinot.util.TestUtils; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + + +public class VectorIntegrationTest extends BaseClusterIntegrationTest { + private static final String VECTOR_1 = "vector1"; + private static final String VECTOR_2 = "vector2"; + private static final String ZERO_VECTOR = "zeroVector"; + private static final int VECTOR_DIM_SIZE = 512; + + @BeforeClass + public void setup() + throws Exception { + TestUtils.ensureDirectoriesExistAndEmpty(_tempDir, _segmentDir, _tarDir); + + // Start the Pinot cluster + startZk(); 
+ startController(); + startBroker(); + startServer(); + + // create & upload schema AND table config + Schema schema = new Schema.SchemaBuilder().setSchemaName(DEFAULT_SCHEMA_NAME) + .addMultiValueDimension(VECTOR_1, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(VECTOR_2, FieldSpec.DataType.FLOAT) + .addMultiValueDimension(ZERO_VECTOR, FieldSpec.DataType.FLOAT) + .build(); + addSchema(schema); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(DEFAULT_TABLE_NAME).build(); + addTableConfig(tableConfig); + + // create & upload segments + File avroFile = createAvroFile(getCountStarResult()); + ClusterIntegrationTestUtils.buildSegmentFromAvro(avroFile, tableConfig, schema, 0, _segmentDir, _tarDir); + uploadSegments(DEFAULT_TABLE_NAME, _tarDir); + + waitForAllDocsLoaded(60_000); + } + + @Override + protected long getCountStarResult() { + return 1000; + } + + @Test(dataProvider = "useBothQueryEngines") + public void testQueries(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "cosineDistance(vector1, vector2), " + + "innerProduct(vector1, vector2), " + + "l1Distance(vector1, vector2), " + + "l2Distance(vector1, vector2), " + + "vectorDims(vector1), vectorDims(vector2), " + + "vectorNorm(vector1), vectorNorm(vector2), " + + "cosineDistance(vector1, zeroVector), " + + "cosineDistance(vector1, zeroVector, 0) " + + "FROM %s LIMIT %d", DEFAULT_TABLE_NAME, getCountStarResult()); // explicit LIMIT: one result row per ingested record + JsonNode jsonNode = postQuery(query); + for (int i = 0; i < getCountStarResult(); i++) { // validate row i, not only the first row + double cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble(); + assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4); + double innerProduct = jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble(); + assertTrue(innerProduct > 100 && innerProduct < 160); + double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble(); + 
assertTrue(l1Distance > 150 && l1Distance < 200); + double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble(); + assertTrue(l2Distance > 8 && l2Distance < 11); + int vectorDimsVector1 = jsonNode.get("resultTable").get("rows").get(i).get(4).asInt(); + assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE); + int vectorDimsVector2 = jsonNode.get("resultTable").get("rows").get(i).get(5).asInt(); + assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE); + double vectorNormVector1 = jsonNode.get("resultTable").get("rows").get(i).get(6).asDouble(); // norm is fractional: do not truncate + assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16); + double vectorNormVector2 = jsonNode.get("resultTable").get("rows").get(i).get(7).asDouble(); + assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16); + cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble(); + assertEquals(cosineDistance, Double.NaN); + cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble(); + assertEquals(cosineDistance, 0.0); + } + } + + private File createAvroFile(long totalNumRecords) + throws IOException { + + // create avro schema + org.apache.avro.Schema avroSchema = org.apache.avro.Schema.createRecord("myRecord", null, null, false); + avroSchema.setFields(ImmutableList.of( + new Field(VECTOR_1, org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.FLOAT)), null, + null), + new Field(VECTOR_2, org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.FLOAT)), null, + null), + new Field(ZERO_VECTOR, org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.FLOAT)), null, + null) + )); + + // create avro file + File avroFile = new File(_tempDir, "data.avro"); + try (DataFileWriter<GenericData.Record> fileWriter = new DataFileWriter<>(new GenericDatumWriter<>(avroSchema))) { + fileWriter.create(avroSchema, avroFile); + for (int i = 0; i < totalNumRecords; i++) { + // create avro record + GenericData.Record record = new 
GenericData.Record(avroSchema); + + Collection<Float> vector1 = createRandomVector(VECTOR_DIM_SIZE); + Collection<Float> vector2 = createRandomVector(VECTOR_DIM_SIZE); + Collection<Float> zeroVector = createZeroVector(VECTOR_DIM_SIZE); + record.put(VECTOR_1, vector1); + record.put(VECTOR_2, vector2); + record.put(ZERO_VECTOR, zeroVector); + + // add avro record to file + fileWriter.append(record); + } + } + return avroFile; + } + + private Collection<Float> createZeroVector(int vectorDimSize) { + List<Float> vector = new ArrayList<>(); + for (int i = 0; i < vectorDimSize; i++) { + vector.add(i, 0.0f); + } + return vector; + } + + private Collection<Float> createRandomVector(int vectorDimSize) { + List<Float> vector = new ArrayList<>(); + for (int i = 0; i < vectorDimSize; i++) { + vector.add(i, RandomUtils.nextFloat(0.0f, 1.0f)); + } + return vector; + } + + @AfterClass + public void tearDown() + throws IOException { + dropOfflineTable(DEFAULT_TABLE_NAME); + + stopServer(); + stopBroker(); + stopController(); + stopZk(); + + FileUtils.deleteDirectory(_tempDir); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org