This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 70bfd4185c Adding batch api support for WindowFunction (#12993)
70bfd4185c is described below

commit 70bfd4185ca18eb64658c5a35813e3578f7c843d
Author: Xiang Fu <xiangfu.1...@gmail.com>
AuthorDate: Tue May 7 03:44:29 2024 -0700

    Adding batch api support for WindowFunction (#12993)
---
 .../runtime/operator/WindowAggregateOperator.java  | 376 ++++-----------------
 .../runtime/operator/utils/AggregationUtils.java   |  22 +-
 .../operator/window/ValueWindowFunction.java       |  54 ---
 .../runtime/operator/window/WindowFunction.java    |  38 ++-
 .../operator/window/WindowFunctionFactory.java     |  60 ++++
 .../window/aggregate/AggregateWindowFunction.java  | 124 +++++++
 .../DenseRankWindowFunction.java}                  |  35 +-
 .../operator/window/range/RangeWindowFunction.java |  67 ++++
 .../RankWindowFunction.java}                       |  33 +-
 .../RowNumberWindowFunction.java}                  |  21 +-
 .../{ => value}/FirstValueWindowFunction.java      |  18 +-
 .../window/{ => value}/LagValueWindowFunction.java |  22 +-
 .../{ => value}/LastValueWindowFunction.java       |  18 +-
 .../{ => value}/LeadValueWindowFunction.java       |  22 +-
 .../operator/window/value/ValueWindowFunction.java |  47 +++
 .../operator/WindowAggregateOperatorTest.java      |  48 +--
 16 files changed, 496 insertions(+), 509 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
index c797607660..e2b989b76d 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
@@ -21,7 +21,6 @@ package org.apache.pinot.query.runtime.operator;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -29,7 +28,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 import org.apache.calcite.rel.RelFieldCollation;
@@ -45,7 +43,8 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
-import org.apache.pinot.query.runtime.operator.window.ValueWindowFunction;
+import org.apache.pinot.query.runtime.operator.window.WindowFunction;
+import org.apache.pinot.query.runtime.operator.window.WindowFunctionFactory;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -85,9 +84,9 @@ public class WindowAggregateOperator extends MultiStageOperator {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(WindowAggregateOperator.class);
 
   // List of window functions which can only be applied as ROWS window frame 
type
-  private static final Set<String> ROWS_ONLY_FUNCTION_NAMES = 
ImmutableSet.of("ROW_NUMBER");
+  public static final Set<String> ROWS_ONLY_FUNCTION_NAMES = 
ImmutableSet.of("ROW_NUMBER");
   // List of ranking window functions whose output depends on the ordering of 
input rows and not on the actual values
-  private static final Set<String> RANKING_FUNCTION_NAMES = 
ImmutableSet.of("RANK", "DENSE_RANK");
+  public static final Set<String> RANKING_FUNCTION_NAMES = 
ImmutableSet.of("RANK", "DENSE_RANK");
 
   private final MultiStageOperator _inputOperator;
   private final List<RexExpression> _groupSet;
@@ -96,7 +95,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
   private final List<RexExpression.FunctionCall> _aggCalls;
   private final List<RexExpression> _constants;
   private final DataSchema _resultSchema;
-  private final WindowAggregateAccumulator[] _windowAccumulators;
+  private final WindowFunction[] _windowFunctions;
   private final Map<Key, List<Object[]>> _partitionRows;
   private final boolean _isPartitionByOnly;
 
@@ -106,22 +105,12 @@ public class WindowAggregateOperator extends MultiStageOperator {
   private TransferableBlock _eosBlock = null;
   private final StatMap<StatKey> _statMap = new StatMap<>(StatKey.class);
 
-  public WindowAggregateOperator(OpChainExecutionContext context, 
MultiStageOperator inputOperator,
-      List<RexExpression> groupSet, List<RexExpression> orderSet, 
List<RelFieldCollation.Direction> orderSetDirection,
-      List<RelFieldCollation.NullDirection> orderSetNullDirection, 
List<RexExpression> aggCalls, int lowerBound,
-      int upperBound, WindowNode.WindowFrameType windowFrameType, 
List<RexExpression> constants,
-      DataSchema resultSchema, DataSchema inputSchema) {
-    this(context, inputOperator, groupSet, orderSet, orderSetDirection, 
orderSetNullDirection, aggCalls, lowerBound,
-        upperBound, windowFrameType, constants, resultSchema, inputSchema, 
WindowAggregateAccumulator.WIN_AGG_MERGERS);
-  }
-
   @VisibleForTesting
   public WindowAggregateOperator(OpChainExecutionContext context, 
MultiStageOperator inputOperator,
       List<RexExpression> groupSet, List<RexExpression> orderSet, 
List<RelFieldCollation.Direction> orderSetDirection,
       List<RelFieldCollation.NullDirection> orderSetNullDirection, 
List<RexExpression> aggCalls, int lowerBound,
       int upperBound, WindowNode.WindowFrameType windowFrameType, 
List<RexExpression> constants,
-      DataSchema resultSchema, DataSchema inputSchema,
-      Map<String, Function<ColumnDataType, AggregationUtils.Merger>> mergers) {
+      DataSchema resultSchema, DataSchema inputSchema) {
     super(context);
 
     _inputOperator = inputOperator;
@@ -140,13 +129,13 @@ public class WindowAggregateOperator extends MultiStageOperator {
     _constants = constants;
     _resultSchema = resultSchema;
 
-    _windowAccumulators = new WindowAggregateAccumulator[_aggCalls.size()];
+    _windowFunctions = new WindowFunction[_aggCalls.size()];
     int aggCallsSize = _aggCalls.size();
     for (int i = 0; i < aggCallsSize; i++) {
       RexExpression.FunctionCall agg = _aggCalls.get(i);
       String functionName = agg.getFunctionName();
-      validateAggregationCalls(functionName, mergers);
-      _windowAccumulators[i] = new WindowAggregateAccumulator(agg, mergers, 
functionName, inputSchema, _orderSetInfo);
+      validateAggregationCalls(functionName);
+      _windowFunctions[i] = 
WindowFunctionFactory.construnctWindowFunction(agg, inputSchema, _orderSetInfo);
     }
 
     _partitionRows = new HashMap<>();
@@ -187,25 +176,10 @@ public class WindowAggregateOperator extends MultiStageOperator {
     if (_hasReturnedWindowAggregateBlock) {
       return _eosBlock;
     }
-    TransferableBlock finalBlock = consumeInputBlocks();
-    if (finalBlock.isErrorBlock()) {
-      return finalBlock;
-    }
-    _eosBlock = updateEosBlock(finalBlock, _statMap);
-    return produceWindowAggregatedBlock();
+    return computeBlocks();
   }
 
-  private void validateAggregationCalls(String functionName,
-      Map<String, Function<ColumnDataType, AggregationUtils.Merger>> mergers) {
-    if 
(ValueWindowFunction.VALUE_WINDOW_FUNCTION_MAP.containsKey(functionName)) {
-      Preconditions.checkState(_windowFrame.getWindowFrameType() == 
WindowNode.WindowFrameType.RANGE,
-          String.format("Only RANGE type frames are supported at present for 
VALUE function: %s", functionName));
-      return;
-    }
-    if (!mergers.containsKey(functionName)) {
-      throw new IllegalStateException("Unexpected aggregation function name: " 
+ functionName);
-    }
-
+  private void validateAggregationCalls(String functionName) {
     if (ROWS_ONLY_FUNCTION_NAMES.contains(functionName)) {
       Preconditions.checkState(
           _windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.ROWS 
&& _windowFrame.isUpperBoundCurrentRow(),
@@ -236,60 +210,54 @@ public class WindowAggregateOperator extends MultiStageOperator {
     return partitionByInputRefIndexes.equals(orderByInputRefIndexes);
   }
 
-  private TransferableBlock produceWindowAggregatedBlock() {
-    Key emptyOrderKey = AggregationUtils.extractEmptyKey();
+  /**
+   * @return the final block, which must be either an end of stream or an 
error.
+   */
+  private TransferableBlock computeBlocks() {
+    TransferableBlock block = _inputOperator.nextBlock();
+    while (!TransferableBlockUtils.isEndOfStream(block)) {
+      List<Object[]> container = block.getContainer();
+      for (Object[] row : container) {
+        _numRows++;
+        // TODO: Revisit null direction handling for all query types
+        Key key = AggregationUtils.extractRowKey(row, _groupSet);
+        _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
+      }
+      block = _inputOperator.nextBlock();
+    }
+    // Early termination if the block is an error block
+    if (block.isErrorBlock()) {
+      return block;
+    }
+    _eosBlock = updateEosBlock(block, _statMap);
+
     ColumnDataType[] resultStoredTypes = 
_resultSchema.getStoredColumnDataTypes();
     List<Object[]> rows = new ArrayList<>(_numRows);
-    if (_windowFrame.getWindowFrameType() == WindowNode.WindowFrameType.RANGE) 
{
-      // All aggregation window functions only support RANGE type today 
(SUM/AVG/MIN/MAX/COUNT/BOOL_AND/BOOL_OR)
-      // RANK and DENSE_RANK ranking window functions also only support RANGE 
type today
-      for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) {
-        Key partitionKey = e.getKey();
-        List<Object[]> rowList = e.getValue();
-        for (int rowId = 0; rowId < rowList.size(); rowId++) {
-          Object[] existingRow = rowList.get(rowId);
-          Object[] row = new Object[existingRow.length + _aggCalls.size()];
-          Key orderKey = (_isPartitionByOnly && 
CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey
-              : AggregationUtils.extractRowKey(existingRow, 
_orderSetInfo.getOrderSet());
-          System.arraycopy(existingRow, 0, row, 0, existingRow.length);
-          for (int i = 0; i < _windowAccumulators.length; i++) {
-            if (_windowAccumulators[i]._valueWindowFunction == null) {
-              row[i + existingRow.length] = 
_windowAccumulators[i].getRangeResultForKeys(partitionKey, orderKey);
-            } else {
-              row[i + existingRow.length] = 
_windowAccumulators[i].getValueResultForKeys(orderKey, rowId, rowList);
-            }
-          }
-          // Convert the results from Accumulator to the desired type
-          TypeUtils.convertRow(row, resultStoredTypes);
-          rows.add(row);
-        }
+    for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) {
+      List<Object[]> rowList = e.getValue();
+
+      // Each window function will return a list of results for each row in 
the input set
+      List<List<Object>> windowFunctionResults = new ArrayList<>();
+      for (WindowFunction windowFunction : _windowFunctions) {
+        List<Object> processRows = windowFunction.processRows(rowList);
+        Preconditions.checkState(processRows.size() == rowList.size(),
+            "Number of rows in the result set must match the number of rows in 
the input set");
+        windowFunctionResults.add(processRows);
       }
-    } else {
-      // Only ROW_NUMBER() window function is supported as ROWS type today
-      Key previousPartitionKey = null;
-      Object[] previousRowValues = new Object[_windowAccumulators.length];
-      for (int i = 0; i < _windowAccumulators.length; i++) {
-        previousRowValues[i] = null;
-      }
-      for (Map.Entry<Key, List<Object[]>> e : _partitionRows.entrySet()) {
-        Key partitionKey = e.getKey();
-        List<Object[]> rowList = e.getValue();
-        for (Object[] existingRow : rowList) {
-          Object[] row = new Object[existingRow.length + _aggCalls.size()];
-          System.arraycopy(existingRow, 0, row, 0, existingRow.length);
-          for (int i = 0; i < _windowAccumulators.length; i++) {
-            row[i + existingRow.length] =
-                
_windowAccumulators[i].computeRowResultForCurrentRow(partitionKey, 
previousPartitionKey, row,
-                    previousRowValues[i]);
-            previousRowValues[i] = row[i + existingRow.length];
-          }
-          // Convert the results from Accumulator to the desired type
-          TypeUtils.convertRow(row, resultStoredTypes);
-          rows.add(row);
-          previousPartitionKey = partitionKey;
+
+      for (int rowId = 0; rowId < rowList.size(); rowId++) {
+        Object[] existingRow = rowList.get(rowId);
+        Object[] row = new Object[existingRow.length + _aggCalls.size()];
+        System.arraycopy(existingRow, 0, row, 0, existingRow.length);
+        for (int i = 0; i < _windowFunctions.length; i++) {
+          row[i + existingRow.length] = 
windowFunctionResults.get(i).get(rowId);
         }
+        // Convert the results from WindowFunction to the desired type
+        TypeUtils.convertRow(row, resultStoredTypes);
+        rows.add(row);
       }
     }
+
     _hasReturnedWindowAggregateBlock = true;
     if (rows.isEmpty()) {
       return _eosBlock;
@@ -298,60 +266,20 @@ public class WindowAggregateOperator extends MultiStageOperator {
     }
   }
 
-  /**
-   * @return the final block, which must be either an end of stream or an 
error.
-   */
-  private TransferableBlock consumeInputBlocks() {
-    Key emptyOrderKey = AggregationUtils.extractEmptyKey();
-    TransferableBlock block = _inputOperator.nextBlock();
-    while (!TransferableBlockUtils.isEndOfStream(block)) {
-      List<Object[]> container = block.getContainer();
-      if (_windowFrame.getWindowFrameType() == 
WindowNode.WindowFrameType.RANGE) {
-        // Only need to accumulate the aggregate function values for RANGE 
type. ROW type can be calculated as
-        // we output the rows since the aggregation value depends on the 
neighboring rows.
-        for (Object[] row : container) {
-          _numRows++;
-          // TODO: Revisit null direction handling for all query types
-          Key key = AggregationUtils.extractRowKey(row, _groupSet);
-          _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
-          // Only need to accumulate the aggregate function values for RANGE 
type. ROW type can be calculated as
-          // we output the rows since the aggregation value depends on the 
neighboring rows.
-          Key orderKey = (_isPartitionByOnly && 
CollectionUtils.isEmpty(_orderSetInfo.getOrderSet())) ? emptyOrderKey
-              : AggregationUtils.extractRowKey(row, 
_orderSetInfo.getOrderSet());
-          int aggCallsSize = _aggCalls.size();
-          for (int i = 0; i < aggCallsSize; i++) {
-            if (_windowAccumulators[i]._valueWindowFunction == null) {
-              _windowAccumulators[i].accumulateRangeResults(key, orderKey, 
row);
-            }
-          }
-        }
-      } else {
-        for (Object[] row : container) {
-          _numRows++;
-          // TODO: Revisit null direction handling for all query types
-          Key key = AggregationUtils.extractRowKey(row, _groupSet);
-          _partitionRows.computeIfAbsent(key, k -> new ArrayList<>()).add(row);
-        }
-      }
-      block = _inputOperator.nextBlock();
-    }
-    return block;
-  }
-
   /**
    * Contains all the ORDER BY key related information such as the keys, 
direction, and null direction
    */
-  private static class OrderSetInfo {
+  public static class OrderSetInfo {
     // List of order keys
-    final List<RexExpression> _orderSet;
+    public final List<RexExpression> _orderSet;
     // List of order direction for each key
-    final List<RelFieldCollation.Direction> _orderSetDirection;
+    public final List<RelFieldCollation.Direction> _orderSetDirection;
     // List of null direction for each key
-    final List<RelFieldCollation.NullDirection> _orderSetNullDirection;
+    public final List<RelFieldCollation.NullDirection> _orderSetNullDirection;
     // Set to 'true' if this is a partition by only query
-    final boolean _isPartitionByOnly;
+    public final boolean _isPartitionByOnly;
 
-    OrderSetInfo(List<RexExpression> orderSet, 
List<RelFieldCollation.Direction> orderSetDirection,
+    public OrderSetInfo(List<RexExpression> orderSet, 
List<RelFieldCollation.Direction> orderSetDirection,
         List<RelFieldCollation.NullDirection> orderSetNullDirection, boolean 
isPartitionByOnly) {
       _orderSet = orderSet;
       _orderSetDirection = orderSetDirection;
@@ -359,19 +287,19 @@ public class WindowAggregateOperator extends MultiStageOperator {
       _isPartitionByOnly = isPartitionByOnly;
     }
 
-    List<RexExpression> getOrderSet() {
+    public List<RexExpression> getOrderSet() {
       return _orderSet;
     }
 
-    List<RelFieldCollation.Direction> getOrderSetDirection() {
+    public List<RelFieldCollation.Direction> getOrderSetDirection() {
       return _orderSetDirection;
     }
 
-    List<RelFieldCollation.NullDirection> getOrderSetNullDirection() {
+    public List<RelFieldCollation.NullDirection> getOrderSetNullDirection() {
       return _orderSetNullDirection;
     }
 
-    boolean isPartitionByOnly() {
+    public boolean isPartitionByOnly() {
       return _isPartitionByOnly;
     }
   }
@@ -419,184 +347,6 @@ public class WindowAggregateOperator extends MultiStageOperator {
     }
   }
 
-  private static class MergeRowNumber implements AggregationUtils.Merger {
-
-    @Override
-    public Long init(@Nullable Object value, ColumnDataType dataType) {
-      return 1L;
-    }
-
-    @Override
-    public Long merge(Object agg, @Nullable Object value) {
-      return (long) agg + 1;
-    }
-  }
-
-  private static class MergeRank implements AggregationUtils.Merger {
-
-    @Override
-    public Long init(Object other, ColumnDataType dataType) {
-      return 1L;
-    }
-
-    @Override
-    public Long merge(Object left, Object right) {
-      // RANK always increase by the number of duplicate entries seen for the 
given ORDER BY key.
-      return ((Number) left).longValue() + ((Number) right).longValue();
-    }
-  }
-
-  private static class MergeDenseRank implements AggregationUtils.Merger {
-
-    @Override
-    public Long init(Object other, ColumnDataType dataType) {
-      return 1L;
-    }
-
-    @Override
-    public Long merge(Object left, Object right) {
-      long rightValueInLong = ((Number) right).longValue();
-      // DENSE_RANK always increase the rank by 1, irrespective of the number 
of duplicate ORDER BY keys seen
-      return (rightValueInLong == 0L) ? ((Number) left).longValue() : 
((Number) left).longValue() + 1L;
-    }
-  }
-
-  private static class WindowAggregateAccumulator extends 
AggregationUtils.Accumulator {
-    private static final Map<String, Function<ColumnDataType, 
AggregationUtils.Merger>> WIN_AGG_MERGERS =
-        ImmutableMap.<String, Function<ColumnDataType, 
AggregationUtils.Merger>>builder()
-            .putAll(AggregationUtils.Accumulator.MERGERS)
-            .put("ROW_NUMBER", cdt -> new MergeRowNumber())
-            .put("RANK", cdt -> new MergeRank())
-            .put("DENSE_RANK", cdt -> new MergeDenseRank())
-            .build();
-
-    private final boolean _isPartitionByOnly;
-    private final boolean _isRankingWindowFunction;
-    private final ValueWindowFunction _valueWindowFunction;
-
-    // Fields needed only for RANGE frame type queries (ORDER BY)
-    private final Map<Key, OrderKeyResult> _orderByResults = new HashMap<>();
-
-    WindowAggregateAccumulator(RexExpression.FunctionCall aggCall,
-        Map<String, Function<ColumnDataType, AggregationUtils.Merger>> merger, 
String functionName,
-        DataSchema inputSchema, OrderSetInfo orderSetInfo) {
-      super(aggCall, merger, functionName, inputSchema);
-      _isPartitionByOnly = CollectionUtils.isEmpty(orderSetInfo.getOrderSet()) 
|| orderSetInfo.isPartitionByOnly();
-      _isRankingWindowFunction = RANKING_FUNCTION_NAMES.contains(functionName);
-      _valueWindowFunction = 
ValueWindowFunction.construnctValueWindowFunction(functionName);
-    }
-
-    /**
-     * For ROW type queries the aggregation function value depends on the 
order of the rows rather than on the actual
-     * keys. For such queries compute the current row value based on the 
previous row and previous partition key.
-     * This should only be called for ROW type queries.
-     */
-    public Object computeRowResultForCurrentRow(Key currentPartitionKey, Key 
previousPartitionKey, Object[] row,
-        Object previousRowOutputValue) {
-      Object value = _inputRef == -1 ? _literal : row[_inputRef];
-      if (previousPartitionKey == null || 
!currentPartitionKey.equals(previousPartitionKey)) {
-        return _merger.init(currentPartitionKey, _dataType);
-      } else {
-        return _merger.merge(previousRowOutputValue, value);
-      }
-    }
-
-    /**
-     * For RANGE type queries, accumulate the function values for each 
PARTITION BY key and ORDER BY key based on
-     * the current row. Should only be called for RANGE type queries where the 
aggregation values are tied to the
-     * RANGE key and not to the row ordering. This should only be called for 
RANGE type queries.
-     */
-    public void accumulateRangeResults(Key key, Key orderKey, Object[] row) {
-      // Ranking functions don't use the row value, thus cannot reuse the 
AggregationUtils accumulate function for them
-      if (_isPartitionByOnly && !_isRankingWindowFunction) {
-        accumulate(key, row);
-        return;
-      }
-
-      // TODO: fix that single agg result (original type) has different type 
from multiple agg results (double).
-      Key previousOrderKeyIfPresent =
-          _orderByResults.get(key) == null ? null : 
_orderByResults.get(key).getPreviousOrderByKey();
-      Object currentRes = previousOrderKeyIfPresent == null ? null
-          : 
_orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent);
-      Object value = _inputRef == -1 ? _literal : row[_inputRef];
-
-      // The ranking functions do not depend on the actual value of the data, 
but are calculated based on the
-      // position of the data ordered by the ORDER BY key. Thus they need to 
be handled differently and require setting
-      // whether the rank has changed or not and if changed then by how much.
-      _orderByResults.putIfAbsent(key, new OrderKeyResult());
-      if (currentRes == null) {
-        value = _isRankingWindowFunction ? 0 : value;
-        _orderByResults.get(key).addOrderByResult(orderKey, 
_merger.init(value, _dataType));
-      } else {
-        Object mergedResult;
-        if (orderKey.equals(previousOrderKeyIfPresent)) {
-          value = _isRankingWindowFunction ? 0 : value;
-          mergedResult = _merger.merge(currentRes, value);
-        } else {
-          Object previousValue = 
_orderByResults.get(key).getOrderByResults().get(previousOrderKeyIfPresent);
-          value = _isRankingWindowFunction ? 
_orderByResults.get(key).getCountOfDuplicateOrderByKeys() : value;
-          mergedResult = _merger.merge(previousValue, value);
-        }
-        _orderByResults.get(key).addOrderByResult(orderKey, mergedResult);
-      }
-    }
-
-    public Object getRangeResultForKeys(Key key, Key orderKey) {
-      if (_isPartitionByOnly && !_isRankingWindowFunction) {
-        return _results.get(key);
-      } else {
-        return _orderByResults.get(key).getOrderByResults().get(orderKey);
-      }
-    }
-
-    public Map<Key, OrderKeyResult> getRangeOrderByResults() {
-      return _orderByResults;
-    }
-
-    public Object getValueResultForKeys(Key orderKey, int rowId, 
List<Object[]> partitionRows) {
-      Object[] row = _valueWindowFunction.processRow(rowId, partitionRows);
-      if (row == null) {
-        return null;
-      }
-      return _inputRef == -1 ? _literal : row[_inputRef];
-    }
-
-    static class OrderKeyResult {
-      final Map<Key, Object> _orderByResults;
-      Key _previousOrderByKey;
-      // Store the counts of duplicate ORDER BY keys seen for this PARTITION 
BY key for calculating RANK/DENSE_RANK
-      long _countOfDuplicateOrderByKeys;
-
-      OrderKeyResult() {
-        _orderByResults = new HashMap<>();
-        _previousOrderByKey = null;
-        _countOfDuplicateOrderByKeys = 0;
-      }
-
-      public void addOrderByResult(Key orderByKey, Object value) {
-        // We expect to get the rows in order based on the ORDER BY key so it 
is safe to blindly assign the
-        // current key as the previous key
-        _orderByResults.put(orderByKey, value);
-        _countOfDuplicateOrderByKeys =
-            (_previousOrderByKey != null && 
_previousOrderByKey.equals(orderByKey)) ? _countOfDuplicateOrderByKeys + 1
-                : 1;
-        _previousOrderByKey = orderByKey;
-      }
-
-      public Map<Key, Object> getOrderByResults() {
-        return _orderByResults;
-      }
-
-      public Key getPreviousOrderByKey() {
-        return _previousOrderByKey;
-      }
-
-      public long getCountOfDuplicateOrderByKeys() {
-        return _countOfDuplicateOrderByKeys;
-      }
-    }
-  }
-
   public enum StatKey implements StatMap.Key {
     EXECUTION_TIME_MS(StatMap.Type.LONG) {
       @Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index 049da05220..0ea7b5df87 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -194,23 +194,17 @@ public class AggregationUtils {
     protected final int _inputRef;
     protected final Object _literal;
     protected final Map<Key, Object> _results = new HashMap<>();
-    protected final Merger _merger;
     protected final ColumnDataType _dataType;
 
     public Map<Key, Object> getResults() {
       return _results;
     }
 
-    public Merger getMerger() {
-      return _merger;
-    }
-
     public ColumnDataType getDataType() {
       return _dataType;
     }
 
-    public Accumulator(RexExpression.FunctionCall aggCall,
-        Map<String, Function<ColumnDataType, AggregationUtils.Merger>> merger, 
String functionName,
+    public Accumulator(RexExpression.FunctionCall aggCall, String functionName,
         DataSchema inputSchema) {
       // agg function operand should either be a InputRef or a Literal
       RexExpression rexExpression = toAggregationFunctionOperand(aggCall);
@@ -223,20 +217,6 @@ public class AggregationUtils {
         _literal = ((RexExpression.Literal) rexExpression).getValue();
         _dataType = rexExpression.getDataType();
       }
-      _merger = merger.containsKey(functionName) ? 
merger.get(functionName).apply(_dataType) : null;
-    }
-
-    public void accumulate(Key key, Object[] row) {
-      // TODO: fix that single agg result (original type) has different type 
from multiple agg results (double).
-      Object currentRes = _results.get(key);
-      Object value = _inputRef == -1 ? _literal : row[_inputRef];
-
-      if (currentRes == null) {
-        _results.put(key, _merger.init(value, _dataType));
-      } else {
-        Object mergedResult = _merger.merge(currentRes, value);
-        _results.put(key, mergedResult);
-      }
     }
 
     private RexExpression 
toAggregationFunctionOperand(RexExpression.FunctionCall rexExpression) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java
deleted file mode 100644
index c327bcf0ba..0000000000
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/ValueWindowFunction.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.runtime.operator.window;
-
-import com.google.common.collect.ImmutableMap;
-import java.lang.reflect.InvocationTargetException;
-import java.util.List;
-import java.util.Map;
-
-
-public abstract class ValueWindowFunction implements WindowFunction {
-  public static final Map<String, Class<? extends ValueWindowFunction>> 
VALUE_WINDOW_FUNCTION_MAP =
-      ImmutableMap.<String, Class<? extends ValueWindowFunction>>builder()
-          .put("LEAD", LeadValueWindowFunction.class)
-          .put("LAG", LagValueWindowFunction.class)
-          .put("FIRST_VALUE", FirstValueWindowFunction.class)
-          .put("LAST_VALUE", LastValueWindowFunction.class)
-          .build();
-
-  /**
-   * @param rowId           Row id to process
-   * @param partitionedRows List of rows for reference
-   * @return Row with the window function applied
-   */
-  public abstract Object[] processRow(int rowId, List<Object[]> 
partitionedRows);
-
-  public static ValueWindowFunction construnctValueWindowFunction(String 
functionName) {
-    Class<? extends ValueWindowFunction> valueWindowFunctionClass = 
VALUE_WINDOW_FUNCTION_MAP.get(functionName);
-    if (valueWindowFunctionClass == null) {
-      return null;
-    }
-    try {
-      return valueWindowFunctionClass.getDeclaredConstructor().newInstance();
-    } catch (InstantiationException | IllegalAccessException | 
InvocationTargetException | NoSuchMethodException e) {
-      throw new RuntimeException("Failed to instantiate ValueWindowFunction 
for function: " + functionName, e);
-    }
-  }
-}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java
index 56d893badf..6221caeae7 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java
@@ -19,13 +19,55 @@
 package org.apache.pinot.query.runtime.operator.window;
 
 import java.util.List;
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
+import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 
 
-public interface WindowFunction {
+/**
+ * Base class for window functions. It provides the batch row processing API {@link #processRows(List)}, which
+ * processes a batch of rows at a time and produces exactly one output value per input row.
+ */
+public abstract class WindowFunction extends AggregationUtils.Accumulator {
+  /** Name of the window function, as passed to the constructor. */
+  protected final String _functionName;
+  /**
+   * Input column indexes read by this function: the order set columns for ranking functions, otherwise the single
+   * operand reference inherited from the accumulator.
+   */
+  protected final int[] _inputRefs;
+  /** True when the order set is empty or the order set info reports a partition-by-only window. */
+  protected final boolean _isPartitionByOnly;
+  /** Order set (ORDER BY expressions) of the window. */
+  protected final List<RexExpression> _orderSet;
+
+  public WindowFunction(RexExpression.FunctionCall aggCall, String functionName,
+      DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema);
+    List<RexExpression> orderSet = orderSetInfo.getOrderSet();
+    _isPartitionByOnly = CollectionUtils.isEmpty(orderSet) || orderSetInfo.isPartitionByOnly();
+    boolean isRankingWindowFunction = WindowAggregateOperator.RANKING_FUNCTION_NAMES.contains(functionName);
+    int[] inputRefs = new int[]{_inputRef};
+    if (isRankingWindowFunction) {
+      // Ranking functions compare rows on the order set columns; each order set expression must be an input ref.
+      inputRefs = orderSet.stream().map(RexExpression.InputRef.class::cast)
+          .mapToInt(RexExpression.InputRef::getIndex).toArray();
+    }
+    _functionName = functionName;
+    _inputRefs = inputRefs;
+    _orderSet = orderSet;
+  }
 
   /**
+   * Batch processing API for Window functions.
+   * This method processes a batch of rows at a time.
+   * Each row generates one object as output.
+   * Note, the input and output list size should be the same.
+   *
    * @param rows List of rows to process
-   * @return List of rows with the window function applied
+   * @return List holding one window function result per input row, in the same order as {@code rows}
    */
-  List<Object[]> processRows(List<Object[]> rows);
+  public abstract List<Object> processRows(List<Object[]> rows);
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java
new file mode 100644
index 0000000000..7f2806b757
--- /dev/null
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunctionFactory.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window;
+
+import com.google.common.collect.ImmutableMap;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Map;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
+import org.apache.pinot.query.runtime.operator.window.aggregate.AggregateWindowFunction;
+import org.apache.pinot.query.runtime.operator.window.range.RangeWindowFunction;
+import org.apache.pinot.query.runtime.operator.window.value.ValueWindowFunction;
+
+
+/**
+ * Factory class to construct WindowFunction instances.
+ */
+public class WindowFunctionFactory {
+  /**
+   * Window functions with a dedicated implementation, keyed by function name. Names not present here are handled
+   * by {@link AggregateWindowFunction}.
+   */
+  public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP =
+      ImmutableMap.<String, Class<? extends WindowFunction>>builder()
+          .putAll(RangeWindowFunction.WINDOW_FUNCTION_MAP)
+          .putAll(ValueWindowFunction.WINDOW_FUNCTION_MAP)
+          .build();
+
+  private WindowFunctionFactory() {
+  }
+
+  /**
+   * Instantiates the window function registered for the call's function name, falling back to
+   * {@link AggregateWindowFunction} for names not in {@link #WINDOW_FUNCTION_MAP}.
+   *
+   * @param aggCall window function call
+   * @param inputSchema schema of the input rows
+   * @param orderSetInfo order set information of the window
+   * @return a new window function instance
+   * @throws RuntimeException if the implementation cannot be instantiated; when its constructor throws, the
+   *     constructor's exception is attached as the cause
+   */
+  public static WindowFunction construnctWindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    String functionName = aggCall.getFunctionName();
+    Class<? extends WindowFunction> windowFunctionClass =
+        WINDOW_FUNCTION_MAP.getOrDefault(functionName, AggregateWindowFunction.class);
+    try {
+      Constructor<? extends WindowFunction> constructor =
+          windowFunctionClass.getConstructor(RexExpression.FunctionCall.class, String.class, DataSchema.class,
+              WindowAggregateOperator.OrderSetInfo.class);
+      return constructor.newInstance(aggCall, functionName, inputSchema, orderSetInfo);
+    } catch (InvocationTargetException e) {
+      // Report the exception thrown by the constructor itself rather than the reflection wrapper around it.
+      Throwable cause = e.getCause() != null ? e.getCause() : e;
+      throw new RuntimeException("Failed to instantiate WindowFunction for function name: " + functionName, cause);
+    } catch (InstantiationException | IllegalAccessException | NoSuchMethodException e) {
+      throw new RuntimeException("Failed to instantiate WindowFunction for function name: " + functionName, e);
+    }
+  }
+}
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
new file mode 100644
index 0000000000..8dd5c791e4
--- /dev/null
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window.aggregate;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.data.table.Key;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
+import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
+import org.apache.pinot.query.runtime.operator.window.WindowFunction;
+
+
+/**
+ * Window function that evaluates an aggregation merger (looked up in {@code AggregationUtils.Accumulator.MERGERS})
+ * over the rows of a partition. Null input values are skipped when aggregating.
+ */
+public class AggregateWindowFunction extends WindowFunction {
+  private final AggregationUtils.Merger _merger;
+
+  public AggregateWindowFunction(RexExpression.FunctionCall aggCall, String functionName,
+      DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
+    if (!AggregationUtils.Accumulator.MERGERS.containsKey(_functionName)) {
+      throw new IllegalArgumentException("Unsupported aggregate window function: " + _functionName);
+    }
+    _merger = AggregationUtils.Accumulator.MERGERS.get(_functionName).apply(_dataType);
+  }
+
+  @Override
+  public final List<Object> processRows(List<Object[]> rows) {
+    if (_isPartitionByOnly) {
+      return processPartitionOnlyRows(rows);
+    } else {
+      return processRowsInternal(rows);
+    }
+  }
+
+  /**
+   * Aggregates all non-null values of the partition once and returns that single result for every row.
+   */
+  protected List<Object> processPartitionOnlyRows(List<Object[]> rows) {
+    Object mergedResult = null;
+    for (Object[] row : rows) {
+      Object value = _inputRef == -1 ? _literal : row[_inputRef];
+      if (value == null) {
+        continue;
+      }
+      if (mergedResult == null) {
+        mergedResult = _merger.init(value, _dataType);
+      } else {
+        mergedResult = _merger.merge(mergedResult, value);
+      }
+    }
+    return Collections.nCopies(rows.size(), mergedResult);
+  }
+
+  /**
+   * Computes a running aggregation in row order: every row receives the aggregate of all rows up to and including
+   * the last row sharing its order key. Rows are expected to arrive sorted by the order set; null input values are
+   * skipped, consistent with {@link #processPartitionOnlyRows(List)}.
+   */
+  protected List<Object> processRowsInternal(List<Object[]> rows) {
+    Key emptyOrderKey = AggregationUtils.extractEmptyKey();
+    OrderKeyResult orderByResult = new OrderKeyResult();
+    // Extract each row's order key once; it is needed again when emitting the per-row results.
+    List<Key> orderKeys = new ArrayList<>(rows.size());
+    for (Object[] row : rows) {
+      Key orderKey = extractOrderKey(row, emptyOrderKey);
+      orderKeys.add(orderKey);
+      Key previousOrderKey = orderByResult.getPreviousOrderByKey();
+      Object currentRes = previousOrderKey == null ? null : orderByResult.getOrderByResults().get(previousOrderKey);
+      Object value = _inputRef == -1 ? _literal : row[_inputRef];
+      Object newRes;
+      if (value == null) {
+        // Skip nulls: carry the running result forward instead of handing null to the merger.
+        newRes = currentRes;
+      } else if (currentRes == null) {
+        newRes = _merger.init(value, _dataType);
+      } else {
+        newRes = _merger.merge(currentRes, value);
+      }
+      orderByResult.addOrderByResult(orderKey, newRes);
+    }
+    // Every row gets the final result stored for its order key, so peer rows share the same value.
+    List<Object> results = new ArrayList<>(rows.size());
+    for (Key orderKey : orderKeys) {
+      results.add(orderByResult.getOrderByResults().get(orderKey));
+    }
+    return results;
+  }
+
+  /**
+   * Returns the order key of {@code row}; all rows share one empty key when the window has no order set.
+   */
+  private Key extractOrderKey(Object[] row, Key emptyOrderKey) {
+    return (_isPartitionByOnly && CollectionUtils.isEmpty(_orderSet)) ? emptyOrderKey
+        : AggregationUtils.extractRowKey(row, _orderSet);
+  }
+
+  static class OrderKeyResult {
+    final Map<Key, Object> _orderByResults;
+    Key _previousOrderByKey;
+
+    OrderKeyResult() {
+      _orderByResults = new HashMap<>();
+      _previousOrderByKey = null;
+    }
+
+    public void addOrderByResult(Key orderByKey, Object value) {
+      // We expect to get the rows in order based on the ORDER BY key so it is safe to blindly assign the
+      // current key as the previous key
+      _orderByResults.put(orderByKey, value);
+      _previousOrderByKey = orderByKey;
+    }
+
+    public Map<Key, Object> getOrderByResults() {
+      return _orderByResults;
+    }
+
+    public Key getPreviousOrderByKey() {
+      return _previousOrderByKey;
+    }
+  }
+}
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java
similarity index 51%
copy from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
copy to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java
index bd8a50ea48..00f23f851a 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/DenseRankWindowFunction.java
@@ -16,32 +16,42 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.range;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
-public class LeadValueWindowFunction extends ValueWindowFunction {
+/**
+ * DENSE_RANK window function: rows that compare equal on the order set columns share a rank, and the rank grows
+ * by one for each new distinct value, leaving no gaps. Assumes the rows arrive sorted by the order set.
+ */
+public class DenseRankWindowFunction extends RangeWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    if (rowId == partitionedRows.size() - 1) {
-      return null;
-    } else {
-      return partitionedRows.get(rowId + 1);
-    }
+  public DenseRankWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
-    for (int i = 0; i < rows.size(); i++) {
-      if (i == rows.size() - 1) {
-        result.add(null);
+  public List<Object> processRows(List<Object[]> rows) {
+    List<Object> result = new ArrayList<>(rows.size());
+    // Ranks are emitted as Long, matching RowNumberWindowFunction.
+    long rank = 1;
+    Object[] lastRow = null;
+    for (Object[] row : rows) {
+      if (lastRow == null) {
+        result.add(rank);
       } else {
-        result.add(rows.get(i + 1));
+        if (compareRows(row, lastRow) != 0) {
+          rank++;
+        }
+        result.add(rank);
       }
+      lastRow = row;
     }
     return result;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java
new file mode 100644
index 0000000000..a4ac37318f
--- /dev/null
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RangeWindowFunction.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window.range;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
+import org.apache.pinot.query.runtime.operator.window.WindowFunction;
+
+
+/**
+ * Base class for ranking window functions (ROW_NUMBER, RANK, DENSE_RANK), whose output depends on each row's
+ * position in the partition and on comparisons of adjacent rows over the order set columns.
+ */
+public abstract class RangeWindowFunction extends WindowFunction {
+  public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP =
+      ImmutableMap.<String, Class<? extends WindowFunction>>builder()
+          // Range window functions
+          .put("ROW_NUMBER", RowNumberWindowFunction.class)
+          .put("RANK", RankWindowFunction.class)
+          .put("DENSE_RANK", DenseRankWindowFunction.class)
+          .build();
+
+  public RangeWindowFunction(RexExpression.FunctionCall aggCall, String functionName,
+      DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
+  }
+
+  /**
+   * Compares two rows on the columns in {@code _inputRefs}, skipping negative references. Nulls sort before
+   * non-null values and the sort direction is not taken into account, so callers should only rely on whether the
+   * result is zero (the rows are peers) or non-zero.
+   *
+   * @return negative, zero or positive as {@code leftRow} sorts before, equal to or after {@code rightRow}
+   */
+  protected int compareRows(Object[] leftRow, Object[] rightRow) {
+    for (int inputRef : _inputRefs) {
+      if (inputRef < 0) {
+        continue;
+      }
+      Object leftValue = leftRow[inputRef];
+      Object rightValue = rightRow[inputRef];
+      if (leftValue == null) {
+        if (rightValue != null) {
+          return -1;
+        }
+      } else if (rightValue == null) {
+        return 1;
+      } else {
+        // Assumes both values of one column are of the same Comparable type.
+        @SuppressWarnings("unchecked")
+        Comparable<Object> comparableLeft = (Comparable<Object>) leftValue;
+        int result = comparableLeft.compareTo(rightValue);
+        if (result != 0) {
+          return result;
+        }
+      }
+    }
+    return 0;
+  }
+}
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java
similarity index 52%
copy from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
copy to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java
index bd8a50ea48..8688f70216 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RankWindowFunction.java
@@ -16,32 +16,40 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.range;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
-public class LeadValueWindowFunction extends ValueWindowFunction {
+/**
+ * RANK window function: peer rows (equal on the order set columns) share a rank, and a row that differs from its
+ * predecessor takes its 1-based position as rank, leaving gaps after ties. Assumes rows arrive sorted.
+ */
+public class RankWindowFunction extends RangeWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    if (rowId == partitionedRows.size() - 1) {
-      return null;
-    } else {
-      return partitionedRows.get(rowId + 1);
-    }
+  public RankWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
+  public List<Object> processRows(List<Object[]> rows) {
+    // Ranks are emitted as Long, matching RowNumberWindowFunction.
+    long rank = 1;
+    List<Object> result = new ArrayList<>(rows.size());
     for (int i = 0; i < rows.size(); i++) {
-      if (i == rows.size() - 1) {
-        result.add(null);
-      } else {
-        result.add(rows.get(i + 1));
+      if (i > 0) {
+        Object[] prevRow = rows.get(i - 1);
+        Object[] currentRow = rows.get(i);
+        if (compareRows(prevRow, currentRow) != 0) {
+          rank = i + 1;
+        }
       }
+      result.add(rank);
     }
     return result;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java
similarity index 56%
copy from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
copy to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java
index cc7db910d2..dd75d17f6c 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/range/RowNumberWindowFunction.java
@@ -16,24 +16,32 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.range;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
-public class LastValueWindowFunction extends ValueWindowFunction {
+/**
+ * ROW_NUMBER window function: numbers the rows of a partition 1, 2, 3, ... in the order they are received.
+ */
+public class RowNumberWindowFunction extends RangeWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    return partitionedRows.get(partitionedRows.size() - 1);
+  public RowNumberWindowFunction(RexExpression.FunctionCall aggCall, String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
-    for (int i = 0; i < rows.size(); i++) {
-      result.add(rows.get(rows.size() - 1));
+  public List<Object> processRows(List<Object[]> rows) {
+    int numRows = rows.size();
+    List<Object> result = new ArrayList<>(numRows);
+    // Row numbers are 1-based and boxed as Long.
+    for (long rowNumber = 1; rowNumber <= numRows; rowNumber++) {
+      result.add(rowNumber);
     }
     return result;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java
similarity index 61%
rename from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java
rename to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java
index 5d2ae75950..6894a156d6 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/FirstValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java
@@ -16,24 +16,34 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.value;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
+/**
+ * FIRST_VALUE window function: every row of the partition receives the operand value of the partition's first
+ * row.
+ */
 public class FirstValueWindowFunction extends ValueWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    return partitionedRows.get(0);
+  public FirstValueWindowFunction(RexExpression.FunctionCall aggCall,
+      String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
+  public List<Object> processRows(List<Object[]> rows) {
+    List<Object> result = new ArrayList<>(rows.size());
+    // The value is the same for every row, so extract it once instead of once per row.
+    Object firstValue = rows.isEmpty() ? null : extractValueFromRow(rows.get(0));
     for (int i = 0; i < rows.size(); i++) {
-      result.add(rows.get(0));
+      result.add(firstValue);
     }
     return result;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
similarity index 62%
rename from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java
rename to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
index 9bca8ec930..7e093ed792 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LagValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LagValueWindowFunction.java
@@ -16,31 +16,36 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.value;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
+/**
+ * LAG window function with a fixed offset of one: each row receives the operand value of the previous row, and
+ * the first row of the partition receives null. NOTE(review): offset and default arguments of the call are not
+ * read here - confirm they are rejected or unsupported upstream.
+ */
 public class LagValueWindowFunction extends ValueWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    if (rowId == 0) {
-      return null;
-    } else {
-      return partitionedRows.get(rowId - 1);
-    }
+  public LagValueWindowFunction(RexExpression.FunctionCall aggCall,
+      String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
+  public List<Object> processRows(List<Object[]> rows) {
+    List<Object> result = new ArrayList<>(rows.size());
     for (int i = 0; i < rows.size(); i++) {
       if (i == 0) {
         result.add(null);
       } else {
-        result.add(rows.get(i - 1));
+        result.add(extractValueFromRow(rows.get(i - 1)));
       }
     }
     return result;
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java
similarity index 61%
rename from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
rename to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java
index cc7db910d2..bccafccf8a 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LastValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java
@@ -16,24 +16,35 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.value;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
+/**
+ * LAST_VALUE window function: every row of the partition receives the operand value of the partition's last
+ * row. NOTE(review): the window frame is not consulted; with an ORDER BY, the SQL default frame ends at the
+ * current row's peers rather than at the partition end - confirm this matches the intended semantics.
+ */
 public class LastValueWindowFunction extends ValueWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    return partitionedRows.get(partitionedRows.size() - 1);
+  public LastValueWindowFunction(RexExpression.FunctionCall aggCall,
+      String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
+  public List<Object> processRows(List<Object[]> rows) {
+    List<Object> result = new ArrayList<>(rows.size());
+    // The value is the same for every row, so extract it once instead of once per row.
+    Object lastValue = rows.isEmpty() ? null : extractValueFromRow(rows.get(rows.size() - 1));
     for (int i = 0; i < rows.size(); i++) {
-      result.add(rows.get(rows.size() - 1));
+      result.add(lastValue);
     }
     return result;
   }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
similarity index 63%
rename from 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
rename to 
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
index bd8a50ea48..4cbd917274 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/LeadValueWindowFunction.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LeadValueWindowFunction.java
@@ -16,31 +16,36 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.query.runtime.operator.window;
+package org.apache.pinot.query.runtime.operator.window.value;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
 
 
+/**
+ * LEAD window function with a fixed offset of one: each row receives the operand value of the next row, and the
+ * last row of the partition receives null. NOTE(review): offset and default arguments of the call are not read
+ * here - confirm they are rejected or unsupported upstream.
+ */
 public class LeadValueWindowFunction extends ValueWindowFunction {
 
-  @Override
-  public Object[] processRow(int rowId, List<Object[]> partitionedRows) {
-    if (rowId == partitionedRows.size() - 1) {
-      return null;
-    } else {
-      return partitionedRows.get(rowId + 1);
-    }
+  public LeadValueWindowFunction(RexExpression.FunctionCall aggCall,
+      String functionName, DataSchema inputSchema,
+      WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
   }
 
   @Override
-  public List<Object[]> processRows(List<Object[]> rows) {
-    List<Object[]> result = new ArrayList<>();
+  public List<Object> processRows(List<Object[]> rows) {
+    List<Object> result = new ArrayList<>(rows.size());
     for (int i = 0; i < rows.size(); i++) {
       if (i == rows.size() - 1) {
         result.add(null);
       } else {
-        result.add(rows.get(i + 1));
+        result.add(extractValueFromRow(rows.get(i + 1)));
       }
     }
     return result;
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java
new file mode 100644
index 0000000000..7226a926d4
--- /dev/null
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/ValueWindowFunction.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.window.value;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.WindowAggregateOperator;
+import org.apache.pinot.query.runtime.operator.window.WindowFunction;
+
+
+/**
+ * Base class for value window functions (LEAD, LAG, FIRST_VALUE, LAST_VALUE), which copy the operand value of a
+ * related row of the partition (next, previous, first or last) into each output position.
+ */
+public abstract class ValueWindowFunction extends WindowFunction {
+  /** Value window function implementations, keyed by function name. */
+  public static final Map<String, Class<? extends WindowFunction>> WINDOW_FUNCTION_MAP =
+      ImmutableMap.<String, Class<? extends WindowFunction>>builder()
+          .put("LEAD", LeadValueWindowFunction.class)
+          .put("LAG", LagValueWindowFunction.class)
+          .put("FIRST_VALUE", FirstValueWindowFunction.class)
+          .put("LAST_VALUE", LastValueWindowFunction.class)
+          .build();
+
+  public ValueWindowFunction(RexExpression.FunctionCall aggCall, String functionName,
+      DataSchema inputSchema, WindowAggregateOperator.OrderSetInfo orderSetInfo) {
+    super(aggCall, functionName, inputSchema, orderSetInfo);
+  }
+
+  /**
+   * Returns the operand value for {@code row}: the literal operand when {@code _inputRef} is -1, otherwise the
+   * referenced column of {@code row}, or null when {@code row} itself is null.
+   */
+  protected Object extractValueFromRow(Object[] row) {
+    if (_inputRef == -1) {
+      return _literal;
+    }
+    if (row == null) {
+      return null;
+    }
+    return row[_inputRef];
+  }
+}
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index 2bfca7c149..61df71d9ad 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -19,7 +19,6 @@
 package org.apache.pinot.query.runtime.operator;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -36,7 +35,6 @@ import org.apache.pinot.query.routing.VirtualServerAddress;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockTestUtils;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
-import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
@@ -230,44 +228,6 @@ public class WindowAggregateOperatorTest {
     Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done 
processing)");
   }
 
-  @Test
-  public void testShouldCallMergerWhenWindowAggregatingMultipleRows() {
-    // Given:
-    List<RexExpression> calls = ImmutableList.of(getSum(new 
RexExpression.InputRef(1)));
-    List<RexExpression> group = ImmutableList.of(new 
RexExpression.InputRef(0));
-
-    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new 
ColumnDataType[]{INT, INT});
-    Mockito.when(_input.nextBlock())
-        .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{1, 1}, new 
Object[]{1, 2}))
-        .thenReturn(OperatorTestUtil.block(inSchema, new Object[]{1, 3}))
-        
.thenReturn(TransferableBlockTestUtils.getEndOfStreamTransferableBlock(0));
-
-    AggregationUtils.Merger merger = 
Mockito.mock(AggregationUtils.Merger.class);
-    Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
-    Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d);
-    DataSchema outSchema = new DataSchema(new String[]{"group", "arg", "sum"}, 
new ColumnDataType[]{INT, INT, DOUBLE});
-    WindowAggregateOperator operator =
-        new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), 
_input, group, Collections.emptyList(),
-            Collections.emptyList(), Collections.emptyList(), calls, 
Integer.MIN_VALUE, Integer.MAX_VALUE,
-            WindowNode.WindowFrameType.RANGE, Collections.emptyList(), 
outSchema, inSchema,
-            ImmutableMap.of("SUM", cdt -> merger));
-
-    // When:
-    TransferableBlock resultBlock = operator.nextBlock(); // (output result)
-
-    // Then:
-    // should call merger twice, one from second row in first block and two 
from the first row
-    // in second block
-    Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), 
Mockito.any());
-    Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), 
Mockito.any());
-    Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 1, 
12d},
-        "Expected three columns (original two columns, agg literal value)");
-    Assert.assertEquals(resultBlock.getContainer().get(1), new Object[]{1, 2, 
12d},
-        "Expected three columns (original two columns, agg literal value)");
-    Assert.assertEquals(resultBlock.getContainer().get(2), new Object[]{1, 3, 
12d},
-        "Expected three columns (original two columns, agg literal value)");
-  }
-
   @Test
   public void testPartitionByWindowAggregateWithHashCollision() {
     MultiStageOperator upstreamOperator = 
OperatorTestUtil.getOperator(OperatorTestUtil.OP_1);
@@ -292,8 +252,8 @@ public class WindowAggregateOperatorTest {
     Assert.assertEquals(resultRows.get(2), expectedRows.get(2));
   }
 
-  @Test(expectedExceptions = IllegalStateException.class, 
expectedExceptionsMessageRegExp = ".*Unexpected aggregation "
-      + "function name: AVERAGE.*")
+  @Test(expectedExceptions = RuntimeException.class, 
expectedExceptionsMessageRegExp = ".*Failed to instantiate "
+      + "WindowFunction for function name: AVERAGE.*")
   public void testShouldThrowOnUnknownAggFunction() {
     // Given:
     List<RexExpression> calls = ImmutableList.of(
@@ -309,8 +269,8 @@ public class WindowAggregateOperatorTest {
             WindowNode.WindowFrameType.RANGE, Collections.emptyList(), 
outSchema, inSchema);
   }
 
-  @Test(expectedExceptions = IllegalStateException.class, 
expectedExceptionsMessageRegExp = ".*Unexpected aggregation "
-      + "function name: NTILE.*")
+  @Test(expectedExceptions = RuntimeException.class, 
expectedExceptionsMessageRegExp = ".*Failed to instantiate "
+      + "WindowFunction for function name: NTILE.*")
   public void testShouldThrowOnUnknownRankAggFunction() {
     // TODO: Remove this test when support is added for NTILE function
     // Given:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to