This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new e949d958a8 [multistage] Support range queries in search (#11512)
e949d958a8 is described below

commit e949d958a8238a48ba147048fa72ada033ea889e
Author: Gonzalo Ortiz Jaureguizar <[email protected]>
AuthorDate: Wed Sep 6 16:23:02 2023 +0200

    [multistage] Support range queries in search (#11512)
    
    * Support range queries in search
    
    * Add a test that ensures we are testing the SEARCH case
---
 .../tests/MultiStageEngineIntegrationTest.java     | 17 +++++
 .../tests/OfflineClusterIntegrationTest.java       | 11 ++-
 .../query/planner/logical/RexExpressionUtils.java  | 83 ++++++++++++++++++++--
 3 files changed, 105 insertions(+), 6 deletions(-)

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index f6a2345a62..31583b1a0b 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -27,6 +27,7 @@ import java.time.ZoneId;
 import java.time.format.DateTimeFormatter;
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
 import org.apache.commons.io.FileUtils;
 import org.apache.pinot.spi.config.table.TableConfig;
 import org.apache.pinot.spi.data.Schema;
@@ -700,6 +701,22 @@ public class MultiStageEngineIntegrationTest extends BaseClusterIntegrationTestS
     assertNoError(jsonNode);
   }
 
+  @Test
+  public void testSearch()
+      throws Exception {
+    String sqlQuery = "SELECT CASE WHEN ArrDelay > 50 OR ArrDelay < 10 THEN 10 ELSE 0 END "
+        + "FROM mytable LIMIT 1000";
+    JsonNode jsonNode = postQuery("Explain plan for " + sqlQuery);
+    JsonNode plan = jsonNode.get("resultTable").get("rows").get(0).get(1);
+
+    Pattern pattern = Pattern.compile("SEARCH\\(\\$7, Sarg\\[\\(-∞\\.\\.10\\), \\(50\\.\\.\\+∞\\)]\\)");
+    boolean matches = pattern.matcher(plan.asText()).find();
+    Assert.assertTrue(matches, "Plan doesn't contain the expected SEARCH");
+
+    jsonNode = postQuery(sqlQuery);
+    assertNoError(jsonNode);
+  }
+
   @AfterClass
   public void tearDown()
       throws Exception {
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index 0debb9d09a..8c0d6602b7 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -83,6 +83,7 @@ import org.apache.pinot.util.TestUtils;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeMethod;
 import org.testng.annotations.Test;
 
 import static org.apache.pinot.common.function.scalar.StringFunctions.*;
@@ -222,6 +223,11 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     waitForAllDocsLoaded(600_000L);
   }
 
+  @BeforeMethod
+  public void resetMultiStage() {
+    setUseMultiStageQueryEngine(false);
+  }
+
   protected void startBrokers()
       throws Exception {
     startBrokers(getNumBrokers());
@@ -1956,9 +1962,10 @@ public class OfflineClusterIntegrationTest extends BaseClusterIntegrationTestSet
     }
   }
 
-  @Test
-  public void testCaseStatementWithLogicalTransformFunction()
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testCaseStatementWithLogicalTransformFunction(boolean useMultiStageQueryEngine)
       throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
     String sqlQuery = "SELECT ArrDelay" + ", CASE WHEN ArrDelay > 50 OR ArrDelay < 10 THEN 10 ELSE 0 END"
         + ", CASE WHEN ArrDelay < 50 AND ArrDelay >= 10 THEN 10 ELSE 0 END" + " FROM mytable LIMIT 1000";
     JsonNode response = postQuery(sqlQuery);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
index 0cb443fecc..a85df889f5 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
@@ -19,10 +19,13 @@
 package org.apache.pinot.query.planner.logical;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.BoundType;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Range;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.GregorianCalendar;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -36,8 +39,8 @@ import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.util.NlsString;
 import org.apache.calcite.util.Sarg;
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.utils.BooleanUtils;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 
@@ -141,8 +144,7 @@ public class RexExpressionUtils {
     return fromRexNode(operands.get(0));
   }
 
-  // TODO: Add support for range filter expressions (e.g. a > 0 and a < 30)
-  private static RexExpression.FunctionCall handleSearch(RexCall rexCall) {
+  private static RexExpression handleSearch(RexCall rexCall) {
     List<RexNode> operands = rexCall.getOperands();
     RexInputRef rexInputRef = (RexInputRef) operands.get(0);
     RexLiteral rexLiteral = (RexLiteral) operands.get(1);
@@ -155,10 +157,83 @@ public class RexExpressionUtils {
       return new RexExpression.FunctionCall(SqlKind.NOT_IN, dataType, SqlKind.NOT_IN.name(),
           toFunctionOperands(rexInputRef, sarg.rangeSet.complement().asRanges(), dataType));
     } else {
-      throw new NotImplementedException("Range is not implemented yet");
+      Set<Range<?>> ranges = sarg.rangeSet.asRanges();
+      return convertRangesToOr(dataType, rexInputRef, ranges);
     }
   }
 
+  private static RexExpression convertRangesToOr(ColumnDataType dataType, RexInputRef rexInputRef,
+      Set<Range<?>> ranges) {
+    RexExpression result;
+    Iterator<Range<?>> it = ranges.iterator();
+    if (!it.hasNext()) { // no disjunctions means false
+      return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0);
+    }
+    RexExpression.InputRef rexInput = fromRexInputRef(rexInputRef);
+    result = convertRange(rexInput, dataType, it.next());
+    if (result instanceof RexExpression.Literal) {
+      Object value = ((RexExpression.Literal) result).getValue();
+      if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true
+        return result;
+      }
+    }
+    while (it.hasNext()) {
+      Range<?> range = it.next();
+      RexExpression newExp = convertRange(rexInput, dataType, range);
+      if (newExp instanceof RexExpression.Literal) {
+        Object value = ((RexExpression.Literal) newExp).getValue();
+        if (BooleanUtils.isTrueInternalValue(value)) { // one of the disjunctions is true => return true
+          return newExp;
+        } else {
+          continue; // one of the disjunctions is false => ignore it
+        }
+      }
+      ImmutableList<RexExpression> operands = ImmutableList.of(result, newExp);
+      result = new RexExpression.FunctionCall(SqlKind.OR, ColumnDataType.BOOLEAN, SqlKind.OR.name(), operands);
+    }
+    return result;
+  }
+
+  private static RexExpression convertRange(RexExpression.InputRef rexInput, ColumnDataType dataType, Range<?> range) {
+    if (range.isEmpty()) {
+      return new RexExpression.Literal(ColumnDataType.BOOLEAN, 0);
+    }
+    if (!range.hasLowerBound()) {
+      if (!range.hasUpperBound()) {
+        return new RexExpression.Literal(ColumnDataType.BOOLEAN, 1);
+      }
+      return convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint());
+    } else if (!range.hasUpperBound()) {
+      return convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint());
+    } else {
+      RexExpression lowerConstraint =
+          convertLowerBound(rexInput, dataType, range.lowerBoundType(), range.lowerEndpoint());
+      RexExpression upperConstraint =
+          convertUpperBound(rexInput, dataType, range.upperBoundType(), range.upperEndpoint());
+      ImmutableList<RexExpression> operands = ImmutableList.of(lowerConstraint, upperConstraint);
+      return new RexExpression.FunctionCall(SqlKind.AND, ColumnDataType.BOOLEAN, SqlKind.AND.name(), operands);
+    }
+  }
+
+  private static RexExpression convertLowerBound(RexExpression.InputRef inputRef, ColumnDataType dataType,
+      BoundType boundType, Comparable<?> endpoint) {
+    SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.GREATER_THAN : SqlKind.GREATER_THAN_OR_EQUAL;
+    RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint));
+    ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal);
+    return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands);
+  }
+
+  private static RexExpression convertUpperBound(RexExpression.InputRef inputRef, ColumnDataType dataType,
+      BoundType boundType, Comparable<?> endpoint) {
+    SqlKind sqlKind = boundType == BoundType.OPEN ? SqlKind.LESS_THAN : SqlKind.LESS_THAN_OR_EQUAL;
+    RexExpression.Literal literal = new RexExpression.Literal(dataType, convertValue(dataType, endpoint));
+    ImmutableList<RexExpression> operands = ImmutableList.of(inputRef, literal);
+    return new RexExpression.FunctionCall(sqlKind, ColumnDataType.BOOLEAN, sqlKind.name(), operands);
+  }
+
+  /**
+   * Transforms a set of <b>point based</b> ranges into a list of expressions.
+   */
   private static List<RexExpression> toFunctionOperands(RexInputRef rexInputRef, Set<Range> ranges,
       ColumnDataType dataType) {
     List<RexExpression> result = new ArrayList<>(ranges.size() + 1);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to