This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b64b5cdd0 Support ValueWindowFunction for LEAD/LAG/FIRST_VALUE/LAST_VALUE (#12878)
9b64b5cdd0 is described below

commit 9b64b5cdd070aeade025b440f97b20a26ef32f83
Author: Xiang Fu <xiangfu.1...@gmail.com>
AuthorDate: Thu Apr 11 02:25:31 2024 +0800

    Support ValueWindowFunction for LEAD/LAG/FIRST_VALUE/LAST_VALUE (#12878)
---
 .../rules/PinotWindowExchangeNodeInsertRule.java   |  2 +-
 .../pinot/query/QueryEnvironmentTestBase.java      |  2 +
 .../runtime/operator/WindowAggregateOperator.java  | 36 +++++++++++++--
 .../runtime/operator/utils/AggregationUtils.java   |  2 +-
 .../operator/window/FirstValueWindowFunction.java  | 40 ++++++++++++++++
 .../operator/window/LagValueWindowFunction.java    | 48 +++++++++++++++++++
 .../operator/window/LastValueWindowFunction.java   | 40 ++++++++++++++++
 .../operator/window/LeadValueWindowFunction.java   | 48 +++++++++++++++++++
 .../operator/window/ValueWindowFunction.java       | 54 ++++++++++++++++++++++
 .../runtime/operator/window/WindowFunction.java    | 31 +++++++++++++
 10 files changed, 296 insertions(+), 7 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
index 317a406332..e9caf1216f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotWindowExchangeNodeInsertRule.java
@@ -65,7 +65,7 @@ public class PinotWindowExchangeNodeInsertRule extends RelOptRule {
   // OTHER_FUNCTION supported are: BOOL_AND, BOOL_OR
   private static final Set<SqlKind> SUPPORTED_WINDOW_FUNCTION_KIND = 
ImmutableSet.of(SqlKind.SUM, SqlKind.SUM0,
       SqlKind.MIN, SqlKind.MAX, SqlKind.COUNT, SqlKind.ROW_NUMBER, 
SqlKind.RANK, SqlKind.DENSE_RANK,
-      SqlKind.OTHER_FUNCTION);
+      SqlKind.LAG, SqlKind.LEAD, SqlKind.FIRST_VALUE, SqlKind.LAST_VALUE, 
SqlKind.OTHER_FUNCTION);
 
   public PinotWindowExchangeNodeInsertRule(RelBuilderFactory factory) {
     super(operand(LogicalWindow.class, any()), factory, null);
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 9a97e75b88..6b3b8a3631 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -189,6 +189,8 @@ public class QueryEnvironmentTestBase {
                 + "RANK() OVER(ORDER BY count(*) DESC) AS rank FROM a GROUP BY 
a.col1) WHERE rank < 5"
         },
         new Object[]{"SELECT RANK() OVER(PARTITION BY a.col2 ORDER BY a.col1) 
FROM a"},
+        new Object[]{"SELECT a.col1, LEAD(a.col3) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
+        new Object[]{"SELECT a.col1, LAG(a.col3) OVER (PARTITION BY a.col2 
ORDER BY a.col3) FROM a"},
         new Object[]{"SELECT DENSE_RANK() OVER(ORDER BY a.col1) FROM a"},
         new Object[]{"SELECT a.col1, SUM(a.col3) OVER (ORDER BY a.col2), 
MIN(a.col3) OVER (ORDER BY a.col2) FROM a"},
         new Object[]{
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
index 259abaea1b..d2e37598a0 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
@@ -44,6 +44,7 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
+import org.apache.pinot.query.runtime.operator.window.ValueWindowFunction;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -175,6 +176,11 @@ public class WindowAggregateOperator extends MultiStageOperator {
 
   private void validateAggregationCalls(String functionName,
       Map<String, Function<ColumnDataType, AggregationUtils.Merger>> mergers) {
+    if 
(ValueWindowFunction.VALUE_WINDOW_FUNCTION_MAP.containsKey(functionName)) {
+      Preconditions.checkState(_windowFrame.getWindowFrameType() == 
WindowNode.WindowFrameType.RANGE,
+          String.format("Only RANGE type frames are supported at present for 
VALUE function: %s", functionName));
+      return;
+    }
     if (!mergers.containsKey(functionName)) {
       throw new IllegalStateException("Unexpected aggregation function name: " 
+ functionName);
     }
@@ -219,13 +225,18 @@ public class WindowAggregateOperator extends MultiStageOperator {
       for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) {
         Key partitionKey = e.getKey();
         List<Object[]> rowList = e.getValue();
-        for (Object[] existingRow : rowList) {
+        for (int rowId = 0; rowId < rowList.size(); rowId++) {
+          Object[] existingRow = rowList.get(rowId);
           Object[] row = new Object[existingRow.length + _aggCalls.size()];
           Key orderKey = (_isPartitionByOnly && 
CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey
               : AggregationUtils.extractRowKey(existingRow, 
_orderSetInfo.getOrderSet());
           System.arraycopy(existingRow, 0, row, 0, existingRow.length);
           for (int i = 0; i < _windowAccumulators.length; i++) {
-            row[i + existingRow.length] = 
_windowAccumulators[i].getRangeResultForKeys(partitionKey, orderKey);
+            if (_windowAccumulators[i]._valueWindowFunction == null) {
+              row[i + existingRow.length] = 
_windowAccumulators[i].getRangeResultForKeys(partitionKey, orderKey);
+            } else {
+              row[i + existingRow.length] = 
_windowAccumulators[i].getValueResultForKeys(orderKey, rowId, rowList);
+            }
           }
           // Convert the results from Accumulator to the desired type
           TypeUtils.convertRow(row, resultStoredTypes);
@@ -288,7 +299,9 @@ public class WindowAggregateOperator extends MultiStageOperator {
               : AggregationUtils.extractRowKey(row, 
_orderSetInfo.getOrderSet());
           int aggCallsSize = _aggCalls.size();
           for (int i = 0; i < aggCallsSize; i++) {
-            _windowAccumulators[i].accumulateRangeResults(key, orderKey, row);
+            if (_windowAccumulators[i]._valueWindowFunction == null) {
+              _windowAccumulators[i].accumulateRangeResults(key, orderKey, 
row);
+            }
           }
         }
       } else {
@@ -430,11 +443,15 @@ public class WindowAggregateOperator extends MultiStageOperator {
   private static class WindowAggregateAccumulator extends 
AggregationUtils.Accumulator {
     private static final Map<String, Function<ColumnDataType, 
AggregationUtils.Merger>> WIN_AGG_MERGERS =
         ImmutableMap.<String, Function<ColumnDataType, 
AggregationUtils.Merger>>builder()
-            .putAll(AggregationUtils.Accumulator.MERGERS).put("ROW_NUMBER", 
cdt -> new MergeRowNumber())
-            .put("RANK", cdt -> new MergeRank()).put("DENSE_RANK", cdt -> new 
MergeDenseRank()).build();
+            .putAll(AggregationUtils.Accumulator.MERGERS)
+            .put("ROW_NUMBER", cdt -> new MergeRowNumber())
+            .put("RANK", cdt -> new MergeRank())
+            .put("DENSE_RANK", cdt -> new MergeDenseRank())
+            .build();
 
     private final boolean _isPartitionByOnly;
     private final boolean _isRankingWindowFunction;
+    private final ValueWindowFunction _valueWindowFunction;
 
     // Fields needed only for RANGE frame type queries (ORDER BY)
     private final Map<Key, OrderKeyResult> _orderByResults = new HashMap<>();
@@ -445,6 +462,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
       super(aggCall, merger, functionName, inputSchema);
       _isPartitionByOnly = CollectionUtils.isEmpty(orderSetInfo.getOrderSet()) 
|| orderSetInfo.isPartitionByOnly();
       _isRankingWindowFunction = RANKING_FUNCTION_NAMES.contains(functionName);
+      _valueWindowFunction = 
ValueWindowFunction.construnctValueWindowFunction(functionName);
     }
 
     /**
@@ -514,6 +532,14 @@ public class WindowAggregateOperator extends MultiStageOperator {
       return _orderByResults;
     }
 
+    public Object getValueResultForKeys(Key orderKey, int rowId, 
List<Object[]> partitionRows) {
+      Object[] row = _valueWindowFunction.processRow(rowId, partitionRows);
+      if (row == null) {
+        return null;
+      }
+      return _inputRef == -1 ? _literal : row[_inputRef];
+    }
+
     static class OrderKeyResult {
       final Map<Key, Object> _orderByResults;
       Key _previousOrderByKey;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index ed9b0acba2..049da05220 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -223,7 +223,7 @@ public class AggregationUtils {
         _literal = ((RexExpression.Literal) rexExpression).getValue();
         _dataType = rexExpression.getDataType();
       }
-      _merger = merger.get(functionName).apply(_dataType);
+      _merger = merger.containsKey(functionName) ? 
merger.get(functionName).apply(_dataType) : null;
     }
 
     public void accumulate(Key key, Object[] row) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java
new file mode 100644
index 0000000000..5d2ae75950
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+public class FirstValueWindowFunction extends ValueWindowFunction {
+  // FIRST_VALUE: every row takes the first row of its partition list.
+  @Override
+  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
+    return partitionedRows.get(0); // rowId unused; an empty list would throw
+  }
+  // Batch form. NOTE(review): assumes rows are in ORDER BY order - confirm.
+  @Override
+  public List<Object[]> processRows(List<Object[]> rows) {
+    List<Object[]> result = new ArrayList<>();
+    for (int i = 0; i < rows.size(); i++) {
+      result.add(rows.get(0)); // same first-row reference for every input row
+    }
+    return result;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java
new file mode 100644
index 0000000000..9bca8ec930
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+public class LagValueWindowFunction extends ValueWindowFunction {
+  // Offset is fixed at 1. NOTE(review): LAG(x, n, default) args are unused.
+  @Override
+  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
+    if (rowId == 0) {
+      return null; // row 0 has no predecessor; caller maps this to null
+    } else {
+      return partitionedRows.get(rowId - 1);
+    }
+  }
+  // Batch form: entry i is rows.get(i - 1); entry 0 is null.
+  @Override
+  public List<Object[]> processRows(List<Object[]> rows) {
+    List<Object[]> result = new ArrayList<>();
+    for (int i = 0; i < rows.size(); i++) {
+      if (i == 0) {
+        result.add(null);
+      } else {
+        result.add(rows.get(i - 1));
+      }
+    }
+    return result;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
new file mode 100644
index 0000000000..cc7db910d2
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+public class LastValueWindowFunction extends ValueWindowFunction {
+  // LAST_VALUE: every row takes the last row of its partition list.
+  @Override
+  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
+    return partitionedRows.get(partitionedRows.size() - 1); // rowId unused
+  }
+  // NOTE(review): whole-partition result; SQL default frame ends at peers.
+  @Override
+  public List<Object[]> processRows(List<Object[]> rows) {
+    List<Object[]> result = new ArrayList<>();
+    for (int i = 0; i < rows.size(); i++) {
+      result.add(rows.get(rows.size() - 1));
+    }
+    return result;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
new file mode 100644
index 0000000000..bd8a50ea48
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import java.util.ArrayList;
+import java.util.List;
+
+
+public class LeadValueWindowFunction extends ValueWindowFunction {
+  // Offset is fixed at 1. NOTE(review): LEAD(x, n, default) args are unused.
+  @Override
+  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
+    if (rowId == partitionedRows.size() - 1) {
+      return null; // last row has no successor; caller maps this to null
+    } else {
+      return partitionedRows.get(rowId + 1);
+    }
+  }
+  // Batch form: entry i is rows.get(i + 1); the last entry is null.
+  @Override
+  public List<Object[]> processRows(List<Object[]> rows) {
+    List<Object[]> result = new ArrayList<>();
+    for (int i = 0; i < rows.size(); i++) {
+      if (i == rows.size() - 1) {
+        result.add(null);
+      } else {
+        result.add(rows.get(i + 1));
+      }
+    }
+    return result;
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java
new file mode 100644
index 0000000000..c327bcf0ba
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import com.google.common.collect.ImmutableMap;
+import java.lang.reflect.InvocationTargetException;
+import java.util.List;
+import java.util.Map;
+
+
+public abstract class ValueWindowFunction implements WindowFunction {
+  public static final Map<String, Class<? extends ValueWindowFunction>> VALUE_WINDOW_FUNCTION_MAP =
+      ImmutableMap.<String, Class<? extends ValueWindowFunction>>builder()
+          .put("LEAD", LeadValueWindowFunction.class)
+          .put("LAG", LagValueWindowFunction.class)
+          .put("FIRST_VALUE", FirstValueWindowFunction.class)
+          .put("LAST_VALUE", LastValueWindowFunction.class)
+          .build(); // each class is instantiated via its no-arg constructor
+  /**
+   * Selects the row whose value is reported for the row at {@code rowId}.
+   * @param rowId           Index of the current row in {@code partitionedRows}
+   * @param partitionedRows All rows of the current partition
+   * @return The selected row, or {@code null} if none (e.g. LAG on row 0)
+   */
+  public abstract Object[] processRow(int rowId, List<Object[]> partitionedRows);
+  // NOTE(review): "construnct" is a typo of "construct"; rename with callers.
+  public static ValueWindowFunction construnctValueWindowFunction(String functionName) {
+    Class<? extends ValueWindowFunction> valueWindowFunctionClass = VALUE_WINDOW_FUNCTION_MAP.get(functionName);
+    if (valueWindowFunctionClass == null) {
+      return null; // not a registered value window function
+    }
+    try {
+      return valueWindowFunctionClass.getDeclaredConstructor().newInstance();
+    } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+      throw new RuntimeException("Failed to instantiate ValueWindowFunction for function: " + functionName, e);
+    }
+  }
+}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java
new file mode 100644
index 0000000000..56d893badf
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import java.util.List;
+
+
+public interface WindowFunction {
+  /**
+   * Applies the window function to every row of {@code rows} (batch form).
+   * @param rows Rows to process; presumably one partition, in window order
+   * @return Per-row results in input order; entries may be {@code null}
+   */
+  List<Object[]> processRows(List<Object[]> rows);
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to