This is an automated email from the ASF dual-hosted git repository.

nehapawar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 6911172  GROOVY transform function UDF for queries (#5748)
6911172 is described below

commit 69111727f6f011a7d864698cde292df0cd2e22d9
Author: Neha Pawar <neha.pawa...@gmail.com>
AuthorDate: Fri Jul 24 19:38:33 2020 -0700

    GROOVY transform function UDF for queries (#5748)
---
 .../common/function/TransformFunctionType.java     |   2 +-
 .../data/function/GroovyFunctionEvaluator.java     |  19 +-
 .../function/GroovyTransformFunction.java          | 438 +++++++++++++++++++++
 .../function/TransformFunctionFactory.java         |   2 +
 .../function/GroovyTransformFunctionTest.java      | 292 ++++++++++++++
 5 files changed, 749 insertions(+), 4 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 583e678..65cf1df 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -55,7 +55,7 @@ public enum TransformFunctionType {
   ARRAYLENGTH("arrayLength"),
   VALUEIN("valueIn"),
   MAPVALUE("mapValue"),
-
+  GROOVY("groovy"),
   // Special type for annotation based scalar functions
   SCALAR("scalar"),
 
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java b/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
index 64e7d9a..ee5422e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/function/GroovyFunctionEvaluator.java
@@ -55,18 +55,20 @@ public class GroovyFunctionEvaluator implements FunctionEvaluator {
   private static final String ARGUMENTS_SEPARATOR = ",";
 
   private final List<String> _arguments;
+  private final int _numArguments;
   private final Binding _binding;
   private final Script _script;
 
-  public GroovyFunctionEvaluator(String transformExpression) {
-    Matcher matcher = GROOVY_FUNCTION_PATTERN.matcher(transformExpression);
-    Preconditions.checkState(matcher.matches(), "Invalid transform expression: 
%s", transformExpression);
+  public GroovyFunctionEvaluator(String closure) {
+    Matcher matcher = GROOVY_FUNCTION_PATTERN.matcher(closure);
+    Preconditions.checkState(matcher.matches(), "Invalid transform expression: 
%s", closure);
     String arguments = matcher.group(ARGUMENTS_GROUP_NAME);
     if (arguments != null) {
       _arguments = 
Splitter.on(ARGUMENTS_SEPARATOR).trimResults().splitToList(arguments);
     } else {
       _arguments = Collections.emptyList();
     }
+    _numArguments = _arguments.size();
     _binding = new Binding();
     _script = new 
GroovyShell(_binding).parse(matcher.group(SCRIPT_GROUP_NAME));
   }
@@ -92,4 +94,15 @@ public class GroovyFunctionEvaluator implements FunctionEvaluator {
     }
     return _script.run();
   }
+
+  /**
+   * Evaluate the Groovy function with bindings provided as an array of Object
+   * The number of elements in the values must match the numArguments
+   */
+  public Object evaluate(Object[] values) {
+    for (int i = 0; i < _numArguments; i++) {
+      _binding.setVariable(_arguments.get(i), values[i]);
+    }
+    return _script.run();
+  }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java
new file mode 100644
index 0000000..504d226
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunction.java
@@ -0,0 +1,438 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.google.common.base.Preconditions;
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.commons.lang3.EnumUtils;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.data.function.GroovyFunctionEvaluator;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.utils.JsonUtils;
+
+
+/**
+ * The GroovyTransformFunction executes groovy expressions
+ * 1st argument - json string containing returnType and isSingleValue e.g. '{"returnType":"LONG", "isSingleValue":false}'
+ * 2nd argument - groovy script (string) using arg0, arg1, arg2... as arguments e.g. 'arg0 + " " + arg1', 'arg0 + arg1.toList().max() + arg2' etc
+ * rest of the arguments - identifiers/functions to the groovy script
+ *
+ * Sample queries:
+ * SELECT GROOVY('{"returnType":"LONG", "isSingleValue":false}', 'arg0.findIndexValues{it==1}', products) FROM myTable
+ * SELECT GROOVY('{"returnType":"INT", "isSingleValue":true}', 'arg0 * arg1 * 10', arraylength(units), columnB ) FROM bob
+ */
+public class GroovyTransformFunction extends BaseTransformFunction {
+  public static final String FUNCTION_NAME = "groovy";
+
+  private static final String RETURN_TYPE_KEY = "returnType";
+  private static final String IS_SINGLE_VALUE_KEY = "isSingleValue";
+  private static final String ARGUMENT_PREFIX = "arg";
+  private static final String GROOVY_TEMPLATE_WITH_ARGS = "Groovy({%s}, %s)";
+  private static final String GROOVY_TEMPLATE_WITHOUT_ARGS = "Groovy({%s})";
+  private static final String GROOVY_ARG_DELIMITER = ",";
+
+  private int[] _intResultSV;
+  private long[] _longResultSV;
+  private double[] _doubleResultSV;
+  private float[] _floatResultSV;
+  private String[] _stringResultSV;
+  private int[][] _intResultMV;
+  private long[][] _longResultMV;
+  private double[][] _doubleResultMV;
+  private float[][] _floatResultMV;
+  private String[][] _stringResultMV;
+  private TransformResultMetadata _resultMetadata;
+
+  private GroovyFunctionEvaluator _groovyFunctionEvaluator;
+  private int _numGroovyArgs;
+  private TransformFunction[] _groovyArguments;
+  private boolean[] _isSourceSingleValue;
+  private FieldSpec.DataType[] _sourceDataType;
+  private BiFunction<TransformFunction, ProjectionBlock, Object>[] 
_transformToValuesFunctions;
+  private BiFunction<Object, Integer, Object>[] _fetchElementFunctions;
+  private Object[] _sourceArrays;
+  private Object[] _bindingValues;
+
+  @Override
+  public String getName() {
+    return FUNCTION_NAME;
+  }
+
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> 
dataSourceMap) {
+    int numArgs = arguments.size();
+    if (numArgs < 2) {
+      throw new IllegalArgumentException("GROOVY transform function requires 
at least 2 arguments");
+    }
+
+    // 1st argument is a json string
+    TransformFunction returnValueMetadata = arguments.get(0);
+    Preconditions.checkState(returnValueMetadata instanceof 
LiteralTransformFunction,
+        "First argument of GROOVY transform function must be a literal, 
representing a json string");
+    String returnValueMetadataStr = ((LiteralTransformFunction) 
returnValueMetadata).getLiteral();
+    try {
+      JsonNode returnValueMetadataJson = 
JsonUtils.stringToJsonNode(returnValueMetadataStr);
+      
Preconditions.checkState(returnValueMetadataJson.hasNonNull(RETURN_TYPE_KEY),
+          "The json string in the first argument of GROOVY transform function 
must have non-null 'returnType'");
+      
Preconditions.checkState(returnValueMetadataJson.hasNonNull(IS_SINGLE_VALUE_KEY),
+          "The json string in the first argument of GROOVY transform function 
must have non-null 'isSingleValue'");
+      String returnTypeStr = 
returnValueMetadataJson.get(RETURN_TYPE_KEY).asText();
+      Preconditions.checkState(EnumUtils.isValidEnum(FieldSpec.DataType.class, 
returnTypeStr),
+          "The 'returnType' in the json string which is the first argument of 
GROOVY transform function must be a valid FieldSpec.DataType enum value");
+      _resultMetadata = new 
TransformResultMetadata(FieldSpec.DataType.valueOf(returnTypeStr),
+          returnValueMetadataJson.get(IS_SINGLE_VALUE_KEY).asBoolean(true), 
false);
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          "Caught exception when converting json string '" + 
returnValueMetadataStr + "' to JsonNode", e);
+    }
+
+    // 2nd argument is groovy expression string
+    TransformFunction groovyTransformFunction = arguments.get(1);
+    Preconditions.checkState(groovyTransformFunction instanceof 
LiteralTransformFunction,
+        "Second argument of GROOVY transform function must be a literal 
string, representing the groovy expression");
+
+    // 3rd argument onwards, all are arguments to the groovy function
+    _numGroovyArgs = numArgs - 2;
+    if (_numGroovyArgs > 0) {
+      _groovyArguments = new TransformFunction[_numGroovyArgs];
+      _isSourceSingleValue = new boolean[_numGroovyArgs];
+      _sourceDataType = new FieldSpec.DataType[_numGroovyArgs];
+      int idx = 0;
+      for (int i = 2; i < numArgs; i++) {
+        TransformFunction argument = arguments.get(i);
+        Preconditions.checkState(!(argument instanceof 
LiteralTransformFunction),
+            "Third argument onwards, all arguments must be a column or other 
transform function");
+        _groovyArguments[idx] = argument;
+        TransformResultMetadata resultMetadata = argument.getResultMetadata();
+        _isSourceSingleValue[idx] = resultMetadata.isSingleValue();
+        _sourceDataType[idx++] = resultMetadata.getDataType();
+      }
+      // construct arguments string for GroovyFunctionEvaluator
+      String argumentsStr = IntStream.range(0, _numGroovyArgs).mapToObj(i -> 
ARGUMENT_PREFIX + i)
+          .collect(Collectors.joining(GROOVY_ARG_DELIMITER));
+      _groovyFunctionEvaluator = new GroovyFunctionEvaluator(String
+          .format(GROOVY_TEMPLATE_WITH_ARGS, ((LiteralTransformFunction) 
groovyTransformFunction).getLiteral(),
+              argumentsStr));
+
+      _transformToValuesFunctions = new BiFunction[_numGroovyArgs];
+      _fetchElementFunctions = new BiFunction[_numGroovyArgs];
+      initFunctions();
+    } else {
+      _groovyFunctionEvaluator = new GroovyFunctionEvaluator(String
+          .format(GROOVY_TEMPLATE_WITHOUT_ARGS, ((LiteralTransformFunction) 
groovyTransformFunction).getLiteral()));
+    }
+    _sourceArrays = new Object[_numGroovyArgs];
+    _bindingValues = new Object[_numGroovyArgs];
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return _resultMetadata;
+  }
+
+  private void initFunctions() {
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      BiFunction<Object, Integer, Object> getElementFunction;
+      BiFunction<TransformFunction, ProjectionBlock, Object> 
transformToValuesFunction;
+      if (_isSourceSingleValue[i]) {
+        switch (_sourceDataType[i]) {
+          case INT:
+            transformToValuesFunction = 
TransformFunction::transformToIntValuesSV;
+            getElementFunction = (sourceArray, position) -> ((int[]) 
sourceArray)[position];
+            break;
+          case LONG:
+            transformToValuesFunction = 
TransformFunction::transformToLongValuesSV;
+            getElementFunction = (sourceArray, position) -> ((long[]) 
sourceArray)[position];
+            break;
+          case FLOAT:
+            transformToValuesFunction = 
TransformFunction::transformToFloatValuesSV;
+            getElementFunction = (sourceArray, position) -> ((float[]) 
sourceArray)[position];
+            break;
+          case DOUBLE:
+            transformToValuesFunction = 
TransformFunction::transformToDoubleValuesSV;
+            getElementFunction = (sourceArray, position) -> ((double[]) 
sourceArray)[position];
+            break;
+          case STRING:
+            transformToValuesFunction = 
TransformFunction::transformToStringValuesSV;
+            getElementFunction = (sourceArray, position) -> ((String[]) 
sourceArray)[position];
+            break;
+          default:
+            throw new IllegalStateException(
+                "Unsupported data type '" + _sourceDataType[i] + "' for GROOVY 
transform function");
+        }
+      } else {
+        switch (_sourceDataType[i]) {
+          case INT:
+            transformToValuesFunction = 
TransformFunction::transformToIntValuesMV;
+            getElementFunction = (sourceArray, position) -> ((int[][]) 
sourceArray)[position];
+            break;
+          case LONG:
+            transformToValuesFunction = 
TransformFunction::transformToLongValuesMV;
+            getElementFunction = (sourceArray, position) -> ((long[][]) 
sourceArray)[position];
+            break;
+          case FLOAT:
+            transformToValuesFunction = 
TransformFunction::transformToFloatValuesMV;
+            getElementFunction = (sourceArray, position) -> ((float[][]) 
sourceArray)[position];
+            break;
+          case DOUBLE:
+            transformToValuesFunction = 
TransformFunction::transformToDoubleValuesMV;
+            getElementFunction = (sourceArray, position) -> ((double[][]) 
sourceArray)[position];
+            break;
+          case STRING:
+            transformToValuesFunction = 
TransformFunction::transformToStringValuesMV;
+            getElementFunction = (sourceArray, position) -> ((String[][]) 
sourceArray)[position];
+            break;
+          default:
+            throw new IllegalStateException(
+                "Unsupported data type '" + _sourceDataType[i] + "' for GROOVY 
transform function");
+        }
+      }
+      _transformToValuesFunctions[i] = transformToValuesFunction;
+      _fetchElementFunctions[i] = getElementFunction;
+    }
+  }
+
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_intResultSV == null) {
+      _intResultSV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      _intResultSV[i] = (int) 
_groovyFunctionEvaluator.evaluate(_bindingValues);
+    }
+    return _intResultSV;
+  }
+
+  @Override
+  public int[][] transformToIntValuesMV(ProjectionBlock projectionBlock) {
+    if (_intResultMV == null) {
+      _intResultMV = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+      if (result instanceof List) {
+        _intResultMV[i] = new IntArrayList((List<Integer>) 
result).toIntArray();
+      } else if (result instanceof int[]) {
+        _intResultMV[i] = (int[]) result;
+      } else {
+        throw new IllegalStateException("Unexpected result type '" + 
result.getClass() + "' for GROOVY function");
+      }
+    }
+    return _intResultMV;
+  }
+
+  @Override
+  public double[] transformToDoubleValuesSV(ProjectionBlock projectionBlock) {
+    if (_doubleResultSV == null) {
+      _doubleResultSV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      _doubleResultSV[i] = (double) 
_groovyFunctionEvaluator.evaluate(_bindingValues);
+    }
+    return _doubleResultSV;
+  }
+
+  @Override
+  public double[][] transformToDoubleValuesMV(ProjectionBlock projectionBlock) 
{
+    if (_doubleResultMV == null) {
+      _doubleResultMV = new double[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+      if (result instanceof List) {
+        _doubleResultMV[i] = new DoubleArrayList((List<Double>) 
result).toDoubleArray();
+      } else if (result instanceof double[]) {
+        _doubleResultMV[i] = (double[]) result;
+      } else {
+        throw new IllegalStateException("Unexpected result type '" + 
result.getClass() + "' for GROOVY function");
+      }
+    }
+    return _doubleResultMV;
+  }
+
+  @Override
+  public long[] transformToLongValuesSV(ProjectionBlock projectionBlock) {
+    if (_longResultSV == null) {
+      _longResultSV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      _longResultSV[i] = (long) 
_groovyFunctionEvaluator.evaluate(_bindingValues);
+    }
+    return _longResultSV;
+  }
+
+  @Override
+  public long[][] transformToLongValuesMV(ProjectionBlock projectionBlock) {
+    if (_longResultMV == null) {
+      _longResultMV = new long[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+      if (result instanceof List) {
+        _longResultMV[i] = new LongArrayList((List<Long>) 
result).toLongArray();
+      } else if (result instanceof long[]) {
+        _longResultMV[i] = (long[]) result;
+      } else {
+        throw new IllegalStateException("Unexpected result type '" + 
result.getClass() + "' for GROOVY function");
+      }
+    }
+    return _longResultMV;
+  }
+
+  @Override
+  public float[] transformToFloatValuesSV(ProjectionBlock projectionBlock) {
+    if (_floatResultSV == null) {
+      _floatResultSV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      _floatResultSV[i] = (float) 
_groovyFunctionEvaluator.evaluate(_bindingValues);
+    }
+    return _floatResultSV;
+  }
+
+  @Override
+  public float[][] transformToFloatValuesMV(ProjectionBlock projectionBlock) {
+    if (_floatResultMV == null) {
+      _floatResultMV = new float[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+      if (result instanceof List) {
+        _floatResultMV[i] = new FloatArrayList((List<Float>) 
result).toFloatArray();
+      } else if (result instanceof float[]) {
+        _floatResultMV[i] = (float[]) result;
+      } else {
+        throw new IllegalStateException("Unexpected result type '" + 
result.getClass() + "' for GROOVY function");
+      }
+    }
+    return _floatResultMV;
+  }
+
+  @Override
+  public String[] transformToStringValuesSV(ProjectionBlock projectionBlock) {
+    if (_stringResultSV == null) {
+      _stringResultSV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      _stringResultSV[i] = (String) 
_groovyFunctionEvaluator.evaluate(_bindingValues);
+    }
+    return _stringResultSV;
+  }
+
+  @Override
+  public String[][] transformToStringValuesMV(ProjectionBlock projectionBlock) 
{
+    if (_stringResultMV == null) {
+      _stringResultMV = new String[DocIdSetPlanNode.MAX_DOC_PER_CALL][];
+    }
+    for (int i = 0; i < _numGroovyArgs; i++) {
+      _sourceArrays[i] = 
_transformToValuesFunctions[i].apply(_groovyArguments[i], projectionBlock);
+    }
+    int length = projectionBlock.getNumDocs();
+    for (int i = 0; i < length; i++) {
+      for (int j = 0; j < _numGroovyArgs; j++) {
+        _bindingValues[j] = _fetchElementFunctions[j].apply(_sourceArrays[j], 
i);
+      }
+      Object result = _groovyFunctionEvaluator.evaluate(_bindingValues);
+      if (result instanceof List) {
+        _stringResultMV[i] = ((List<String>) result).toArray(new String[0]);
+      } else if (result instanceof String[]) {
+        _stringResultMV[i] = (String[]) result;
+      } else {
+        throw new IllegalStateException("Unexpected result type '" + 
result.getClass() + "' for GROOVY function");
+      }
+    }
+    return _stringResultMV;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 670d8a1..a4f9548 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -90,6 +90,8 @@ public class TransformFunctionFactory {
           put(TransformFunctionType.ARRAYLENGTH.getName().toLowerCase(), 
ArrayLengthTransformFunction.class);
           put(TransformFunctionType.VALUEIN.getName().toLowerCase(), 
ValueInTransformFunction.class);
           put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), 
MapValueTransformFunction.class);
+
+          put(TransformFunctionType.GROOVY.getName().toLowerCase(), 
GroovyTransformFunction.class);
           put(TransformFunctionType.CASE.getName().toLowerCase(), 
CaseTransformFunction.class);
 
           put(TransformFunctionType.EQUALS.getName().toLowerCase(), 
EqualsTransformFunction.class);
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java
new file mode 100644
index 0000000..187fb61
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GroovyTransformFunctionTest.java
@@ -0,0 +1,292 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Joiner;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.IntSummaryStatistics;
+import java.util.List;
+import java.util.stream.IntStream;
+import org.apache.pinot.core.query.exception.BadQueryRequestException;
+import org.apache.pinot.core.query.request.context.ExpressionContext;
+import 
org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+
+/**
+ * Tests the GROOVY transform function
+ */
+public class GroovyTransformFunctionTest extends BaseTransformFunctionTest {
+
+  @DataProvider(name = "groovyFunctionDataProvider")
+  public Object[][] groovyFunctionDataProvider() {
+
+    String groovyTransformFunction;
+    List<Object[]> inputs = new ArrayList<>();
+
+    // max in array (returns SV INT)
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', "
+            + "'arg0.toList().max()', "
+            + "%s)", INT_MV_COLUMN);
+    int[] expectedResult1 = new int[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult1[i] = Arrays.stream(_intMVValues[i]).max().getAsInt();
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, 
true, expectedResult1});
+
+    // simple addition (returns SV LONG)
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"LONG\", \"isSingleValue\":true}', "
+                + "'arg0 + arg1', "
+                + "%s, %s)", INT_SV_COLUMN, LONG_SV_COLUMN);
+    long[] expectedResult2 = new long[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult2[i] = _intSVValues[i] + _longSVValues[i];
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.LONG, 
true, expectedResult2});
+
+    // minimum of 2 numbers (returns SV DOUBLE)
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"DOUBLE\", \"isSingleValue\":true}', 
"
+                + "'Math.min(arg0, arg1)', "
+                + "%s, %s)", DOUBLE_SV_COLUMN, INT_SV_COLUMN);
+    double[] expectedResult3 = new double[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult3[i] = Math.min(_intSVValues[i], _doubleSVValues[i]);
+    }
+    inputs.add(new Object[]{groovyTransformFunction, 
FieldSpec.DataType.DOUBLE, true, expectedResult3});
+
+    // (returns SV FLOAT)
+    groovyTransformFunction = String.format(
+        "groovy('{\"returnType\":\"FLOAT\", \"isSingleValue\":true}', "
+            + "'def result; switch(arg0.length()) { case 10: result = 1.1; 
break; case 20: result = 1.2; break; default: result = 1.3;}; return 
result.floatValue()', "
+            + "%s)", STRING_ALPHANUM_SV_COLUMN);
+    float[] expectedResult4 = new float[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult4[i] =
+          _stringAlphaNumericSVValues.length == 10 ? 1.1f : 
(_stringAlphaNumericSVValues.length == 20 ? 1.2f : 1.3f);
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.FLOAT, 
true, expectedResult4});
+
+    // string operations (returns SV STRING)
+    groovyTransformFunction = String.format(
+        "groovy('{\"returnType\":\"STRING\", \"isSingleValue\":true}', "
+            + "'[arg0, arg1, arg2].join(\"_\")', "
+            + "%s, %s, %s)", FLOAT_SV_COLUMN, STRING_SV_COLUMN, 
DOUBLE_SV_COLUMN);
+    String[] expectedResult5 = new String[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult5[i] = Joiner.on("_").join(_floatSVValues[i], 
_stringSVValues[i], _doubleSVValues[i]);
+    }
+    inputs.add(new Object[]{groovyTransformFunction, 
FieldSpec.DataType.STRING, true, expectedResult5});
+
+    // find all in array that match (returns MV INT)
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":false}', "
+                + "'arg0.findAll{it < 5}', "
+                + "%s)", INT_MV_COLUMN);
+    int[][] expectedResult6 = new int[NUM_ROWS][];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult6[i] = Arrays.stream(_intMVValues[i]).filter(e -> e < 
5).toArray();
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, 
false, expectedResult6});
+
+    // (returns MV LONG)
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"LONG\", \"isSingleValue\":false}', "
+            + "'arg0.findIndexValues{it == 5}', "
+                + "%s)", INT_MV_COLUMN);
+    long[][] expectedResult7 = new long[NUM_ROWS][];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      int[] intMVValue = _intMVValues[i];
+      expectedResult7[i] =
+          IntStream.range(0, intMVValue.length).filter(e -> intMVValue[e] == 
5).mapToLong(e -> (long) e).toArray();
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.LONG, 
false, expectedResult7});
+
+    // no-args function (returns MV STRING)
+    groovyTransformFunction = "groovy('{\"returnType\":\"STRING\", 
\"isSingleValue\":false}', '[\"foo\", \"bar\"]')";
+    String[][] expectedResult8 = new String[NUM_ROWS][];
+    Arrays.fill(expectedResult8, new String[]{"foo", "bar"});
+    inputs.add(new Object[]{groovyTransformFunction, 
FieldSpec.DataType.STRING, false, expectedResult8});
+
+    // nested groovy functions
+    String groovy1 = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 
'arg0.toList().max()', %s)", INT_MV_COLUMN);
+    String groovy2 = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 
'arg0.toList().min()', %s)", INT_MV_COLUMN);
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":false}', 
'[arg0, arg1, arg2.sum()]', %s, %s, %s)",
+            groovy1, groovy2, INT_MV_COLUMN);
+    int[][] expectedResult9 = new int[NUM_ROWS][];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      IntSummaryStatistics stats = 
Arrays.stream(_intMVValues[i]).summaryStatistics();
+      expectedResult9[i] = new int[]{stats.getMax(), stats.getMin(), (int) 
stats.getSum()};
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, 
false, expectedResult9});
+
+    // nested with other functions
+    groovyTransformFunction = String
+        .format("groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 
'arg0 + arg1', %s, arraylength(%s))",
+            INT_SV_COLUMN, INT_MV_COLUMN);
+    int[] expectedResult10 = new int[NUM_ROWS];
+    for (int i = 0; i < NUM_ROWS; i++) {
+      expectedResult10[i] = _intSVValues[i] + _intMVValues[i].length;
+    }
+    inputs.add(new Object[]{groovyTransformFunction, FieldSpec.DataType.INT, 
true, expectedResult10});
+
+    return inputs.toArray(new Object[0][]);
+  }
+
+  /**
+   * Builds the transform function for the given GROOVY expression, checks its name and result metadata, and then
+   * compares the transformed values of the first {@code NUM_ROWS} rows against the expected values.
+   *
+   * @param expressionStr GROOVY transform expression to evaluate
+   * @param resultType data type the transform function is expected to report
+   * @param isResultSingleValue whether the transform function is expected to report a single-value result
+   * @param expectedResult expected values: a 1-D array for single-value results, a 2-D array otherwise
+   */
+  @Test(dataProvider = "groovyFunctionDataProvider")
+  public void testGroovyTransformFunctions(String expressionStr, FieldSpec.DataType resultType,
+      boolean isResultSingleValue, Object expectedResult) {
+    ExpressionContext expression = QueryContextConverterUtils.getExpression(expressionStr);
+    TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
+    Assert.assertTrue(transformFunction instanceof GroovyTransformFunction);
+    Assert.assertEquals(transformFunction.getName(), GroovyTransformFunction.FUNCTION_NAME);
+    Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), resultType);
+    Assert.assertEquals(transformFunction.getResultMetadata().isSingleValue(), isResultSingleValue);
+    Assert.assertFalse(transformFunction.getResultMetadata().hasDictionary());
+
+    if (isResultSingleValue) {
+      switch (resultType) {
+
+        case INT:
+          int[] intResults = transformFunction.transformToIntValuesSV(_projectionBlock);
+          int[] expectedInts = (int[]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(intResults[i], expectedInts[i]);
+          }
+          break;
+        case LONG:
+          long[] longResults = transformFunction.transformToLongValuesSV(_projectionBlock);
+          long[] expectedLongs = (long[]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(longResults[i], expectedLongs[i]);
+          }
+          break;
+        case FLOAT:
+          float[] floatResults = transformFunction.transformToFloatValuesSV(_projectionBlock);
+          float[] expectedFloats = (float[]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(floatResults[i], expectedFloats[i]);
+          }
+          break;
+        case DOUBLE:
+          double[] doubleResults = transformFunction.transformToDoubleValuesSV(_projectionBlock);
+          double[] expectedDoubles = (double[]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(doubleResults[i], expectedDoubles[i]);
+          }
+          break;
+        case STRING:
+          String[] stringResults = transformFunction.transformToStringValuesSV(_projectionBlock);
+          String[] expectedStrings = (String[]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(stringResults[i], expectedStrings[i]);
+          }
+          break;
+        default:
+          // Fail loudly instead of silently skipping value verification for a type this test does not handle
+          Assert.fail("Unsupported single-value result type in test: " + resultType);
+          break;
+      }
+    } else {
+      switch (resultType) {
+
+        case INT:
+          int[][] intResults = transformFunction.transformToIntValuesMV(_projectionBlock);
+          int[][] expectedInts = (int[][]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(intResults[i], expectedInts[i]);
+          }
+          break;
+        case LONG:
+          long[][] longResults = transformFunction.transformToLongValuesMV(_projectionBlock);
+          long[][] expectedLongs = (long[][]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(longResults[i], expectedLongs[i]);
+          }
+          break;
+        case FLOAT:
+          float[][] floatResults = transformFunction.transformToFloatValuesMV(_projectionBlock);
+          float[][] expectedFloats = (float[][]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(floatResults[i], expectedFloats[i]);
+          }
+          break;
+        case DOUBLE:
+          double[][] doubleResults = transformFunction.transformToDoubleValuesMV(_projectionBlock);
+          double[][] expectedDoubles = (double[][]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(doubleResults[i], expectedDoubles[i]);
+          }
+          break;
+        case STRING:
+          String[][] stringResults = transformFunction.transformToStringValuesMV(_projectionBlock);
+          String[][] expectedStrings = (String[][]) expectedResult;
+          for (int i = 0; i < NUM_ROWS; i++) {
+            Assert.assertEquals(stringResults[i], expectedStrings[i]);
+          }
+          break;
+        default:
+          // Fail loudly instead of silently skipping value verification for a type this test does not handle
+          Assert.fail("Unsupported multi-value result type in test: " + resultType);
+          break;
+      }
+    }
+  }
+
+  /**
+   * Expects a {@link BadQueryRequestException} when the given malformed GROOVY expression is parsed and handed to
+   * the transform function factory.
+   *
+   * @param expressionStr malformed GROOVY transform expression
+   */
+  @Test(dataProvider = "testIllegalArguments", expectedExceptions = BadQueryRequestException.class)
+  public void testIllegalArguments(String expressionStr) {
+    TransformFunctionFactory.get(QueryContextConverterUtils.getExpression(expressionStr), _dataSourceMap);
+  }
+
+  @DataProvider(name = "testIllegalArguments")
+  public Object[][] testIllegalArguments() {
+    List<Object[]> inputs = new ArrayList<>();
+    // incorrect number of arguments
+    inputs.add(new Object[]{String.format("groovy(%s)", STRING_SV_COLUMN)});
+    // first argument must be literal
+    inputs.add(new Object[]{String.format("groovy(%s, %s)", DOUBLE_SV_COLUMN, 
STRING_SV_COLUMN)});
+    // second argument must be a literal
+    inputs.add(new Object[]{String.format("groovy('arg0 + 10', %s)", 
STRING_SV_COLUMN)});
+    // first argument must be a valid json
+    inputs.add(new Object[]{String.format("groovy(']]', 'arg0 + 10', %s)", 
STRING_SV_COLUMN)});
+    // first argument json must contain non-null key returnType
+    inputs.add(new Object[]{String.format("groovy('{\"isSingleValue\":true}', 
'arg0 + 10', %s)", INT_SV_COLUMN)});
+    inputs.add(new Object[]{String.format("groovy('{\"returnType\":null, 
\"isSingleValue\":true}', 'arg0 + 10', %s)",
+        INT_SV_COLUMN)});
+    // first argument json must contain non-null key isSingleValue
+    inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\"}', 
'arg0 + 10', %s)", INT_SV_COLUMN)});
+    inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\", 
\"isSingleValue\":null}', 'arg0 + 10', %s)",
+        INT_SV_COLUMN)});
+    // return type must be valid DataType enum
+    inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"foo\", 
\"isSingleValue\":true}', 'arg0 + 10', %s)",
+        INT_SV_COLUMN)});
+    // arguments must be columns/transform functions
+    inputs.add(new Object[]{"groovy('{\"returnType\":\"INT\", 
\"isSingleValue\":true}', 'arg0 + 10', 'foo')"});
+    inputs.add(new Object[]{String.format(
+        "groovy('{\"returnType\":\"INT\", \"isSingleValue\":true}', 'arg0 + 
arg1 + 10', 'arraylength(colB)', %s)",
+        INT_SV_COLUMN)});
+    // invalid groovy expression
+    inputs.add(new Object[]{"groovy('{\"returnType\":\"INT\"}', '+-+')"});
+    inputs.add(new Object[]{String.format("groovy('{\"returnType\":\"INT\"}', 
'+-+arg0 arg1', %s, %s)", INT_SV_COLUMN,
+        DOUBLE_SV_COLUMN)});
+    return inputs.toArray(new Object[0][]);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to