This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new d062cfd  Add InIdSetTransformFunction (#5973)
d062cfd is described below

commit d062cfde5a7093a59d19240a1a678dea500279bf
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Tue Sep 8 18:15:05 2020 -0700

    Add InIdSetTransformFunction (#5973)
    
    Add `InIdSetTransformFunction` to support filtering with an `IdSet`.
    Example query:
    `SELECT COUNT(*) FROM mytable WHERE INIDSET(AirlineID, 'AgAAAAABAAAAADowAAABAAAAAAADABAAAAAAAOpMg0+zUg==') = 1`
    (`AgAAAAABAAAAADowAAABAAAAAAADABAAAAAAAOpMg0+zUg==` is the base64 encoded IdSet)
---
 .../common/function/TransformFunctionType.java     |   1 +
 .../function/InIdSetTransformFunction.java         | 132 +++++++++++++++++++++
 .../function/TransformFunctionFactory.java         |   1 +
 .../org/apache/pinot/queries/IdSetQueriesTest.java |  45 +++++++
 .../tests/BaseClusterIntegrationTestSet.java       |  22 +++-
 5 files changed, 198 insertions(+), 3 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
index 76f6446..1347f1d 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java
@@ -55,6 +55,7 @@ public enum TransformFunctionType {
   ARRAYLENGTH("arrayLength"),
   VALUEIN("valueIn"),
   MAPVALUE("mapValue"),
+  INIDSET("inIdSet"),
   GROOVY("groovy"),
   // Special type for annotation based scalar functions
   SCALAR("scalar"),
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InIdSetTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InIdSetTransformFunction.java
new file mode 100644
index 0000000..bff7f36
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/InIdSetTransformFunction.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.operator.transform.function;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.function.TransformFunctionType;
+import org.apache.pinot.core.common.DataSource;
+import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.operator.transform.TransformResultMetadata;
+import org.apache.pinot.core.plan.DocIdSetPlanNode;
+import org.apache.pinot.core.query.utils.idset.IdSet;
+import org.apache.pinot.core.query.utils.idset.IdSets;
+import org.apache.pinot.core.segment.index.readers.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+
+
+/**
+ * The IN_ID_SET ({@code inIdSet}) transform function takes 2 arguments:
+ * <ul>
+ *   <li>Expression: a single-value expression of type INT, LONG, FLOAT, DOUBLE, STRING or BYTES</li>
+ *   <li>Base64 encoded IdSet: a literal string, deserialized once when the function is initialized</li>
+ * </ul>
+ * <p>For each docId, the function returns {@code 1} if the IdSet contains the value of the expression, {@code 0} if not.
+ * <p>E.g. {@code SELECT COUNT(*) FROM myTable WHERE INIDSET(col, '<base64 encoded IdSet>') = 1}
+ */
+public class InIdSetTransformFunction extends BaseTransformFunction {
+  private TransformFunction _transformFunction;
+  private IdSet _idSet;
+  private int[] _results;
+
+  @Override
+  public String getName() {
+    return TransformFunctionType.INIDSET.getName();
+  }
+
+  @Override
+  public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
+    Preconditions.checkArgument(arguments.size() == 2,
+        "2 arguments are required for IN_ID_SET transform function: expression, base64 encoded IdSet");
+    Preconditions.checkArgument(arguments.get(0).getResultMetadata().isSingleValue(),
+        "First argument for IN_ID_SET transform function must be a single-value expression");
+    Preconditions.checkArgument(arguments.get(1) instanceof LiteralTransformFunction,
+        "Second argument for IN_ID_SET transform function must be a literal string of the base64 encoded IdSet");
+
+    _transformFunction = arguments.get(0);
+    try {
+      _idSet = IdSets.fromBase64String(((LiteralTransformFunction) arguments.get(1)).getLiteral());
+    } catch (IOException e) {
+      throw new IllegalArgumentException("Caught exception while deserializing IdSet", e);
+    }
+  }
+
+  @Override
+  public TransformResultMetadata getResultMetadata() {
+    return INT_SV_NO_DICTIONARY_METADATA;
+  }
+
+  @Override
+  public Dictionary getDictionary() {
+    return null;
+  }
+
+  @Override
+  public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
+    if (_results == null) { // Allocate the result buffer lazily and reuse it for every block
+      _results = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
+    }
+
+    int length = projectionBlock.getNumDocs(); // Only the first 'length' entries of _results are valid
+    DataType dataType = _transformFunction.getResultMetadata().getDataType();
+    switch (dataType) { // Probe the IdSet with values of the expression's own result type
+      case INT:
+        int[] intValues = _transformFunction.transformToIntValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(intValues[i]) ? 1 : 0;
+        }
+        break;
+      case LONG:
+        long[] longValues = _transformFunction.transformToLongValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(longValues[i]) ? 1 : 0;
+        }
+        break;
+      case FLOAT:
+        float[] floatValues = _transformFunction.transformToFloatValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(floatValues[i]) ? 1 : 0;
+        }
+        break;
+      case DOUBLE:
+        double[] doubleValues = _transformFunction.transformToDoubleValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(doubleValues[i]) ? 1 : 0;
+        }
+        break;
+      case STRING:
+        String[] stringValues = _transformFunction.transformToStringValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(stringValues[i]) ? 1 : 0;
+        }
+        break;
+      case BYTES:
+        byte[][] bytesValues = _transformFunction.transformToBytesValuesSV(projectionBlock);
+        for (int i = 0; i < length; i++) {
+          _results[i] = _idSet.contains(bytesValues[i]) ? 1 : 0;
+        }
+        break;
+      default:
+        throw new IllegalStateException("Unsupported data type for IN_ID_SET transform function: " + dataType);
+    }
+    return _results;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index 2e89188..e70e5ed 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -90,6 +90,7 @@ public class TransformFunctionFactory {
           put(TransformFunctionType.ARRAYLENGTH.getName().toLowerCase(), 
ArrayLengthTransformFunction.class);
           put(TransformFunctionType.VALUEIN.getName().toLowerCase(), 
ValueInTransformFunction.class);
           put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), 
MapValueTransformFunction.class);
+          put(TransformFunctionType.INIDSET.getName().toLowerCase(), 
InIdSetTransformFunction.class);
 
           put(TransformFunctionType.GROOVY.getName().toLowerCase(), 
GroovyTransformFunction.class);
           put(TransformFunctionType.CASE.getName().toLowerCase(), 
CaseTransformFunction.class);
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/IdSetQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/IdSetQueriesTest.java
index 480b208..0ee4f0e 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/IdSetQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/IdSetQueriesTest.java
@@ -42,6 +42,7 @@ import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
 import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
 import org.apache.pinot.core.query.utils.idset.BloomFilterIdSet;
 import org.apache.pinot.core.query.utils.idset.EmptyIdSet;
+import org.apache.pinot.core.query.utils.idset.IdSet;
 import org.apache.pinot.core.query.utils.idset.IdSets;
 import org.apache.pinot.core.query.utils.idset.Roaring64NavigableMapIdSet;
 import org.apache.pinot.core.query.utils.idset.RoaringBitmapIdSet;
@@ -392,6 +393,50 @@ public class IdSetQueriesTest extends BaseQueriesTest {
     }
   }
 
+  @Test
+  public void testInIdSet()
+      throws IOException {
+    // Serialize an IdSet that holds the values of the first NUM_RECORDS / 2 records
+    IdSet firstHalfIdSet = IdSets.create(DataType.INT);
+    for (int docId = 0; docId < NUM_RECORDS / 2; docId++) {
+      firstHalfIdSet.add(_values[docId]);
+    }
+    String base64IdSet = firstHalfIdSet.toBase64String();
+
+    // Count through the IdSet rather than assuming NUM_RECORDS / 2, as a later record may repeat a first-half value
+    int numInIdSet = 0;
+    for (int recordValue : _values) {
+      numInIdSet += firstHalfIdSet.contains(recordValue) ? 1 : 0;
+    }
+    int numNotInIdSet = NUM_RECORDS - numInIdSet;
+
+    // INIDSET(...) = 1 must match exactly the records whose value is in the IdSet
+    {
+      AggregationOperator operator =
+          getOperatorForPqlQuery("SELECT COUNT(*) FROM testTable where INIDSET(intColumn, '" + base64IdSet + "') = 1");
+      IntermediateResultsBlock block = operator.nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), numInIdSet, NUM_RECORDS,
+          0, NUM_RECORDS);
+      List<Object> result = block.getAggregationResult();
+      assertNotNull(result);
+      assertEquals(result.size(), 1);
+      assertEquals((long) result.get(0), numInIdSet);
+    }
+
+    // INIDSET(...) = 0 must match exactly the remaining records
+    {
+      AggregationOperator operator =
+          getOperatorForPqlQuery("SELECT COUNT(*) FROM testTable where INIDSET(intColumn, '" + base64IdSet + "') = 0");
+      IntermediateResultsBlock block = operator.nextBlock();
+      QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), numNotInIdSet,
+          NUM_RECORDS, 0, NUM_RECORDS);
+      List<Object> result = block.getAggregationResult();
+      assertNotNull(result);
+      assertEquals(result.size(), 1);
+      assertEquals((long) result.get(0), numNotInIdSet);
+    }
+  }
+
   @AfterClass
   public void tearDown()
       throws IOException {
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
index cd18431..a524047 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java
@@ -33,6 +33,8 @@ import org.apache.helix.model.InstanceConfig;
 import org.apache.pinot.client.ResultSet;
 import org.apache.pinot.client.ResultSetGroup;
 import org.apache.pinot.common.utils.CommonConstants;
+import org.apache.pinot.core.query.utils.idset.IdSet;
+import org.apache.pinot.core.query.utils.idset.IdSets;
 import org.apache.pinot.spi.data.DimensionFieldSpec;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.data.MetricFieldSpec;
@@ -52,7 +54,6 @@ import static org.testng.Assert.assertTrue;
  * Shared set of common tests for cluster integration tests.
  * <p>To enable the test, override it and add @Test annotation.
  */
-@SuppressWarnings("unused")
 public abstract class BaseClusterIntegrationTestSet extends 
BaseClusterIntegrationTest {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(BaseClusterIntegrationTestSet.class);
   private static final Random RANDOM = new Random();
@@ -186,11 +187,11 @@ public abstract class BaseClusterIntegrationTestSet extends BaseClusterIntegrati
     testSqlQuery(query, Collections.singletonList(query));
     query =
         "SELECT DistanceGroup FROM mytable WHERE \"Month\" BETWEEN 1 AND 1 AND 
DivAirportSeqIDs IN (1078102, 1142303, 1530402, 1172102, 1291503) OR 
SecurityDelay IN (1, 0, 14, -9999) LIMIT 10";
-    h2queries = Arrays.asList(
+    h2queries = Collections.singletonList(
         "SELECT DistanceGroup FROM mytable WHERE Month BETWEEN 1 AND 1 AND 
(DivAirportSeqIDs__MV0 IN (1078102, 1142303, 1530402, 1172102, 1291503) OR 
DivAirportSeqIDs__MV1 IN (1078102, 1142303, 1530402, 1172102, 1291503) OR 
DivAirportSeqIDs__MV2 IN (1078102, 1142303, 1530402, 1172102, 1291503) OR 
DivAirportSeqIDs__MV3 IN (1078102, 1142303, 1530402, 1172102, 1291503) OR 
DivAirportSeqIDs__MV4 IN (1078102, 1142303, 1530402, 1172102, 1291503)) OR 
SecurityDelay IN (1, 0, 14, -9999) LIMIT 10000");
     testSqlQuery(query, h2queries);
     query = "SELECT MAX(Quarter), MAX(FlightNum) FROM mytable LIMIT 8";
-    h2queries = Arrays.asList("SELECT MAX(Quarter),MAX(FlightNum) FROM mytable 
LIMIT 10000");
+    h2queries = Collections.singletonList("SELECT MAX(Quarter),MAX(FlightNum) 
FROM mytable LIMIT 10000");
     testSqlQuery(query, h2queries);
     query = "SELECT COUNT(*) FROM mytable WHERE DaysSinceEpoch = 16312 AND 
Carrier = 'DL'";
     testSqlQuery(query, Collections.singletonList(query));
@@ -251,6 +252,21 @@ public abstract class BaseClusterIntegrationTestSet extends BaseClusterIntegrati
     query =
         "SELECT DaysSinceEpoch, MAX(ArrDelay) - MAX(AirTime) AS Diff FROM 
mytable GROUP BY DaysSinceEpoch HAVING (Diff >= 300 AND Diff < 500) OR Diff < 
-500 ORDER BY Diff DESC";
     testSqlQuery(query, Collections.singletonList(query));
+
+    // IN_ID_SET
+    IdSet idSet = IdSets.create(FieldSpec.DataType.LONG);
+    idSet.add(19690L);
+    idSet.add(20355L);
+    idSet.add(21171L);
+    // Also include a non-existing id
+    idSet.add(0L);
+    String serializedIdSet = idSet.toBase64String();
+    String inIdSetQuery = "SELECT COUNT(*) FROM mytable WHERE 
INIDSET(AirlineID, '" + serializedIdSet + "') = 1";
+    String inQuery = "SELECT COUNT(*) FROM mytable WHERE AirlineID IN (19690, 
20355, 21171, 0)";
+    testSqlQuery(inIdSetQuery, Collections.singletonList(inQuery));
+    String notInIdSetQuery = "SELECT COUNT(*) FROM mytable WHERE 
INIDSET(AirlineID, '" + serializedIdSet + "') = 0";
+    String notInQuery = "SELECT COUNT(*) FROM mytable WHERE AirlineID NOT IN 
(19690, 20355, 21171, 0)";
+    testSqlQuery(notInIdSetQuery, Collections.singletonList(notInQuery));
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to