This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6775a1729dca [SPARK-55702][SQL] Support filter predicate in window 
aggregate functions
6775a1729dca is described below

commit 6775a1729dca25294bca9a33f987891ec58e031f
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri Feb 27 09:01:09 2026 +0800

    [SPARK-55702][SQL] Support filter predicate in window aggregate functions
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the `FILTER (WHERE ...)` clause on aggregate 
functions used within window expressions. Previously, Spark rejected this with 
an `AnalysisException` ("Window aggregate function with filter predicate is not 
supported yet.").
    
    The changes are:
    1. **Remove the analysis rejection** in `Analyzer.scala` that blocked 
`FILTER` in window aggregates, and extract filter expressions alongside 
aggregate function children.
    2. **Add filter support to `AggregateProcessor`** so that 
`AggregateExpression.filter` is honored during window frame evaluation:
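       - For `DeclarativeAggregate`: each update expression is wrapped as 
`If(filter, updateExpr, bufferAttr)`, so a row that does not pass the filter 
leaves the aggregation buffer unchanged.
       - For `ImperativeAggregate`: the filter predicate is evaluated before 
calling `update()`, and `update()` is skipped when the predicate evaluates to 
false or NULL.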
    3. **Pass filter expressions from `WindowEvaluatorFactoryBase`** to 
`AggregateProcessor`.
    
    ### Why are the changes needed?
    
    The SQL standard allows `FILTER` on aggregate functions in window contexts, 
and other databases (e.g. PostgreSQL) support this. Spark already supports 
`FILTER` for regular (non-window) aggregates but rejected it in window contexts.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Window aggregate expressions with `FILTER` now execute instead of 
throwing an `AnalysisException`. For example:
    ```sql
    SELECT val, cate,
      sum(val) FILTER (WHERE val > 1) OVER (PARTITION BY cate ORDER BY val
        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS 
running_sum_filtered
    FROM testData
    ```
    
    ### How was this patch tested?
    
    Added 5 SQL test cases in `window.sql` covering:
    - `first_value`/`last_value` with filter (verifying no interference with 
NULL handling)
    - Running sum/count: multiple aggregates with different filters in the same 
window
    - Entire partition frame with filter
    - Sliding `ROWS` frame (`1 PRECEDING AND 1 FOLLOWING`) with filter
    - `RANGE` frame with filter
    
    The previously failing test case (`count(val) FILTER (WHERE val > 1) 
OVER(...)`), whose golden output was an `AnalysisException`, is replaced by the 
new cases above, which return results instead of an error.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes.
    
    Made with [Cursor](https://cursor.com)
    
    Closes #54501 from cloud-fan/window-agg-filter.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   8 +-
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala |  13 ---
 .../sql/execution/window/AggregateProcessor.scala  |  37 ++++++--
 .../window/WindowEvaluatorFactoryBase.scala        |   7 +-
 .../sql-tests/analyzer-results/window.sql.out      | 102 ++++++++++++++++++++-
 .../src/test/resources/sql-tests/inputs/window.sql |  34 ++++++-
 .../resources/sql-tests/results/window.sql.out     |  99 ++++++++++++++++++--
 7 files changed, 261 insertions(+), 39 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 04bad39a88ba..a1ec413933db 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3256,16 +3256,14 @@ class Analyzer(
             }
             wsc.copy(partitionSpec = newPartitionSpec, orderSpec = 
newOrderSpec)
 
-          case WindowExpression(ae: AggregateExpression, _) if 
ae.filter.isDefined =>
-            throw 
QueryCompilationErrors.windowAggregateFunctionWithFilterNotSupportedError()
-
           // Extract Windowed AggregateExpression
           case we @ WindowExpression(
-              ae @ AggregateExpression(function, _, _, _, _),
+              ae @ AggregateExpression(function, _, _, filter, _),
               spec: WindowSpecDefinition) =>
             val newChildren = function.children.map(extractExpr)
             val newFunction = 
function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-            val newAgg = ae.copy(aggregateFunction = newFunction)
+            val newFilter = filter.map(extractExpr)
+            val newAgg = ae.copy(aggregateFunction = newFunction, filter = 
newFilter)
             seenWindowAggregates += newAgg
             WindowExpression(newAgg, spec)
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 94f650bc35c7..ee644fc62a1a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -208,19 +208,6 @@ class AnalysisErrorSuite extends AnalysisTest with 
DataTypeErrorsBase {
          | RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)"
          |""".stripMargin.replaceAll("\n", "")))
 
-  errorTest(
-    "window aggregate function with filter predicate",
-    testRelation2.select(
-      WindowExpression(
-        Count(UnresolvedAttribute("b"))
-          .toAggregateExpression(isDistinct = false, filter = 
Some(UnresolvedAttribute("b") > 1)),
-        WindowSpecDefinition(
-          UnresolvedAttribute("a") :: Nil,
-          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
-          UnspecifiedFrame)).as("window")),
-    "window aggregate function with filter predicate is not supported" :: Nil
-  )
-
   test("distinct function") {
     assertAnalysisErrorCondition(
       CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
index 60ce6b068904..63934944c21a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
@@ -46,13 +46,17 @@ private[window] object AggregateProcessor {
       functions: Array[Expression],
       ordinal: Int,
       inputAttributes: Seq[Attribute],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => 
MutableProjection)
+      newMutableProjection: (Seq[Expression], Seq[Attribute]) => 
MutableProjection,
+      filters: Array[Option[Expression]])
     : AggregateProcessor = {
+    assert(filters.length == functions.length,
+      s"filters length (${filters.length}) must match functions length 
(${functions.length})")
     val aggBufferAttributes = mutable.Buffer.empty[AttributeReference]
     val initialValues = mutable.Buffer.empty[Expression]
     val updateExpressions = mutable.Buffer.empty[Expression]
     val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp)
     val imperatives = mutable.Buffer.empty[ImperativeAggregate]
+    val imperativeFilterExprs = mutable.Buffer.empty[Option[Expression]]
 
     // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver 
side and then
     // serialized to executor side. These functions all reference a global 
singleton window
@@ -73,25 +77,34 @@ private[window] object AggregateProcessor {
     }
 
     // Add an AggregateFunction to the AggregateProcessor.
-    functions.foreach {
-      case agg: DeclarativeAggregate =>
+    functions.zip(filters).foreach {
+      case (agg: DeclarativeAggregate, filterOpt) =>
         aggBufferAttributes ++= agg.aggBufferAttributes
         initialValues ++= agg.initialValues
-        updateExpressions ++= agg.updateExpressions
+        filterOpt match {
+          case Some(filter) =>
+            updateExpressions ++= 
agg.updateExpressions.zip(agg.aggBufferAttributes).map {
+              case (updateExpr, attr) => If(filter, updateExpr, attr)
+            }
+          case None =>
+            updateExpressions ++= agg.updateExpressions
+        }
         evaluateExpressions += agg.evaluateExpression
-      case agg: ImperativeAggregate =>
+      case (agg: ImperativeAggregate, filterOpt) =>
         val offset = aggBufferAttributes.size
         val imperative = BindReferences.bindReference(agg
           .withNewInputAggBufferOffset(offset)
           .withNewMutableAggBufferOffset(offset),
           inputAttributes)
         imperatives += imperative
+        imperativeFilterExprs += filterOpt.map(f =>
+          BindReferences.bindReference(f, inputAttributes))
         aggBufferAttributes ++= imperative.aggBufferAttributes
         val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp)
         initialValues ++= noOps
         updateExpressions ++= noOps
         evaluateExpressions += imperative
-      case other =>
+      case (other, _) =>
         throw SparkException.internalError(s"Unsupported aggregate function: 
$other")
     }
 
@@ -108,6 +121,7 @@ private[window] object AggregateProcessor {
       updateProj,
       evalProj,
       imperatives.toArray,
+      imperativeFilterExprs.toArray,
       partitionSize.isDefined)
   }
 }
@@ -122,6 +136,7 @@ private[window] final class AggregateProcessor(
     private[this] val updateProjection: MutableProjection,
     private[this] val evaluateProjection: MutableProjection,
     private[this] val imperatives: Array[ImperativeAggregate],
+    private[this] val imperativeFilters: Array[Option[Expression]],
     private[this] val trackPartitionSize: Boolean) {
 
   private[this] val join = new JoinedRow
@@ -152,7 +167,15 @@ private[window] final class AggregateProcessor(
     updateProjection(join(buffer, input))
     var i = 0
     while (i < numImperatives) {
-      imperatives(i).update(buffer, input)
+      val shouldUpdate = imperativeFilters(i) match {
+        case Some(filter) =>
+          val result = filter.eval(input)
+          result != null && result.asInstanceOf[Boolean]
+        case None => true
+      }
+      if (shouldUpdate) {
+        imperatives(i).update(buffer, input)
+      }
       i += 1
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
index c2dedda832e2..9930c4a8963a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactoryBase.scala
@@ -194,12 +194,17 @@ trait WindowEvaluatorFactoryBase {
         def processor = if 
(functions.exists(_.isInstanceOf[PythonFuncExpression])) {
           null
         } else {
+          val aggFilters = expressions.map {
+            case WindowExpression(ae: AggregateExpression, _) => ae.filter
+            case _ => None
+          }.toArray
           AggregateProcessor(
             functions,
             ordinal,
             childOutput,
             (expressions, schema) =>
-              MutableProjection.create(expressions, schema))
+              MutableProjection.create(expressions, schema),
+            aggFilters)
         }
 
         // Create the factory to produce WindowFunctionFrame.
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
index a441256d3bf0..76c0fb1919ce 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out
@@ -688,13 +688,105 @@ Project [cate#x, sum(val) OVER (PARTITION BY cate ORDER 
BY val ASC NULLS FIRST R
 
 -- !query
 SELECT val, cate,
-count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
+first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a,
+last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a
+FROM testData ORDER BY val_long, cate
+-- !query analysis
+Project [val#x, cate#x, first_a#x, last_a#x]
++- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+   +- Project [val#x, cate#x, first_a#x, last_a#x, val_long#xL]
+      +- Project [val#x, cate#x, _w0#x, val_long#xL, first_a#x, last_a#x, 
first_a#x, last_a#x]
+         +- Window [first_value(val#x, false) FILTER (WHERE _w0#x) 
windowspecdefinition(val_long#xL ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS 
first_a#x, last_value(val#x, false) FILTER (WHERE _w0#x) 
windowspecdefinition(val_long#xL ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS 
last_a#x], [val_long#xL ASC NULLS FIRST]
+            +- Project [val#x, cate#x, (cate#x = a) AS _w0#x, val_long#xL]
+               +- SubqueryAlias testdata
+                  +- View (`testData`, [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x])
+                     +- Project [cast(val#x as int) AS val#x, cast(val_long#xL 
as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, 
cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS 
val_timestamp#x, cast(cate#x as string) AS cate#x]
+                        +- Project [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x]
+                           +- SubqueryAlias testData
+                              +- LocalRelation [val#x, val_long#xL, 
val_double#x, val_date#x, val_timestamp#x, cate#x]
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a,
+sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b,
+count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1
+FROM testData ORDER BY val_long, cate
+-- !query analysis
+Project [val#x, cate#x, sum_a#xL, sum_b#xL, cnt_gt1#xL]
++- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+   +- Project [val#x, cate#x, sum_a#xL, sum_b#xL, cnt_gt1#xL, val_long#xL]
+      +- Project [val#x, cate#x, _w0#x, val_long#xL, _w2#x, _w3#x, sum_a#xL, 
sum_b#xL, cnt_gt1#xL, sum_a#xL, sum_b#xL, cnt_gt1#xL]
+         +- Window [sum(val#x) FILTER (WHERE _w0#x) 
windowspecdefinition(val_long#xL ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS 
sum_a#xL, sum(val#x) FILTER (WHERE _w2#x) windowspecdefinition(val_long#xL ASC 
NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), 
currentrow$())) AS sum_b#xL, count(val#x) FILTER (WHERE _w3#x) 
windowspecdefinition(val_long#xL ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS 
cnt_gt1#xL], [val_long#xL ASC NULLS FIRST]
+            +- Project [val#x, cate#x, (cate#x = a) AS _w0#x, val_long#xL, 
(cate#x = b) AS _w2#x, (val#x > 1) AS _w3#x]
+               +- SubqueryAlias testdata
+                  +- View (`testData`, [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x])
+                     +- Project [cast(val#x as int) AS val#x, cast(val_long#xL 
as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, 
cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS 
val_timestamp#x, cast(cate#x as string) AS cate#x]
+                        +- Project [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x]
+                           +- SubqueryAlias testData
+                              +- LocalRelation [val#x, val_long#xL, 
val_double#x, val_date#x, val_timestamp#x, cate#x]
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS 
total_sum_filtered
 FROM testData ORDER BY cate, val
 -- !query analysis
-org.apache.spark.sql.AnalysisException
-{
-  "errorClass" : "_LEGACY_ERROR_TEMP_1030"
-}
+Sort [cate#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
++- Project [val#x, cate#x, total_sum_filtered#xL]
+   +- Project [val#x, cate#x, _w0#x, total_sum_filtered#xL, 
total_sum_filtered#xL]
+      +- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(cate#x, 
specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) 
AS total_sum_filtered#xL], [cate#x]
+         +- Project [val#x, cate#x, (cate#x = a) AS _w0#x]
+            +- SubqueryAlias testdata
+               +- View (`testData`, [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x])
+                  +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as 
bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, 
cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS 
val_timestamp#x, cast(cate#x as string) AS cate#x]
+                     +- Project [val#x, val_long#xL, val_double#x, val_date#x, 
val_timestamp#x, cate#x]
+                        +- SubqueryAlias testData
+                           +- LocalRelation [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x]
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered
+FROM testData ORDER BY val_long, cate
+-- !query analysis
+Project [val#x, cate#x, sliding_sum_filtered#xL]
++- Sort [val_long#xL ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
+   +- Project [val#x, cate#x, sliding_sum_filtered#xL, val_long#xL]
+      +- Project [val#x, cate#x, _w0#x, val_long#xL, sliding_sum_filtered#xL, 
sliding_sum_filtered#xL]
+         +- Window [sum(val#x) FILTER (WHERE _w0#x) 
windowspecdefinition(val_long#xL ASC NULLS FIRST, 
specifiedwindowframe(RowFrame, -1, 1)) AS sliding_sum_filtered#xL], 
[val_long#xL ASC NULLS FIRST]
+            +- Project [val#x, cate#x, (val#x > 1) AS _w0#x, val_long#xL]
+               +- SubqueryAlias testdata
+                  +- View (`testData`, [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x])
+                     +- Project [cast(val#x as int) AS val#x, cast(val_long#xL 
as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, 
cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS 
val_timestamp#x, cast(cate#x as string) AS cate#x]
+                        +- Project [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x]
+                           +- SubqueryAlias testData
+                              +- LocalRelation [val#x, val_long#xL, 
val_double#x, val_date#x, val_timestamp#x, cate#x]
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val
+  RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered
+FROM testData ORDER BY val, cate
+-- !query analysis
+Sort [val#x ASC NULLS FIRST, cate#x ASC NULLS FIRST], true
++- Project [val#x, cate#x, range_sum_filtered#xL]
+   +- Project [val#x, cate#x, _w0#x, range_sum_filtered#xL, 
range_sum_filtered#xL]
+      +- Window [sum(val#x) FILTER (WHERE _w0#x) windowspecdefinition(val#x 
ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), 
currentrow$())) AS range_sum_filtered#xL], [val#x ASC NULLS FIRST]
+         +- Project [val#x, cate#x, (cate#x = a) AS _w0#x]
+            +- SubqueryAlias testdata
+               +- View (`testData`, [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x])
+                  +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as 
bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, 
cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS 
val_timestamp#x, cast(cate#x as string) AS cate#x]
+                     +- Project [val#x, val_long#xL, val_double#x, val_date#x, 
val_timestamp#x, cate#x]
+                        +- SubqueryAlias testData
+                           +- LocalRelation [val#x, val_long#xL, val_double#x, 
val_date#x, val_timestamp#x, cate#x]
 
 
 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql 
b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index bec79247f9a6..586fe88ac305 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -182,11 +182,41 @@ FROM testData
 WHERE val is not null
 WINDOW w AS (PARTITION BY cate ORDER BY val);
 
--- with filter predicate
+-- window aggregate with filter predicate: first_value/last_value (declarative 
aggregates)
 SELECT val, cate,
-count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
+first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a,
+last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a
+FROM testData ORDER BY val_long, cate;
+
+-- window aggregate with filter predicate: multiple aggregates with different 
filters
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a,
+sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b,
+count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1
+FROM testData ORDER BY val_long, cate;
+
+-- window aggregate with filter predicate: entire partition frame
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS 
total_sum_filtered
 FROM testData ORDER BY cate, val;
 
+-- window aggregate with filter predicate: sliding window (ROWS frame)
+SELECT val, cate,
+sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered
+FROM testData ORDER BY val_long, cate;
+
+-- window aggregate with filter predicate: RANGE frame
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val
+  RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered
+FROM testData ORDER BY val, cate;
+
 -- nth_value()/first_value()/any_value() over ()
 SELECT
     employee_name,
diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out 
b/sql/core/src/test/resources/sql-tests/results/window.sql.out
index ce88fb57f8aa..44c3b175868d 100644
--- a/sql/core/src/test/resources/sql-tests/results/window.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -669,15 +669,102 @@ b        6
 
 -- !query
 SELECT val, cate,
-count(val) FILTER (WHERE val > 1) OVER(PARTITION BY cate)
+first_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_a,
+last_value(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_a
+FROM testData ORDER BY val_long, cate
+-- !query schema
+struct<val:int,cate:string,first_a:int,last_a:int>
+-- !query output
+NULL   NULL    1       NULL
+1      b       1       NULL
+3      NULL    1       1
+NULL   a       1       NULL
+1      a       1       1
+1      a       1       1
+2      b       1       1
+2      a       1       2
+3      b       1       2
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_a,
+sum(val) FILTER (WHERE cate = 'b') OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_b,
+count(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cnt_gt1
+FROM testData ORDER BY val_long, cate
+-- !query schema
+struct<val:int,cate:string,sum_a:bigint,sum_b:bigint,cnt_gt1:bigint>
+-- !query output
+NULL   NULL    NULL    1       0
+1      b       NULL    1       0
+3      NULL    1       1       1
+NULL   a       NULL    1       0
+1      a       1       1       0
+1      a       2       1       1
+2      b       2       3       2
+2      a       4       3       3
+3      b       4       6       4
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(PARTITION BY cate) AS 
total_sum_filtered
 FROM testData ORDER BY cate, val
 -- !query schema
-struct<>
+struct<val:int,cate:string,total_sum_filtered:bigint>
 -- !query output
-org.apache.spark.sql.AnalysisException
-{
-  "errorClass" : "_LEGACY_ERROR_TEMP_1030"
-}
+NULL   NULL    NULL
+3      NULL    NULL
+NULL   a       4
+1      a       4
+1      a       4
+2      a       4
+1      b       NULL
+2      b       NULL
+3      b       NULL
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE val > 1) OVER(ORDER BY val_long
+  ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sliding_sum_filtered
+FROM testData ORDER BY val_long, cate
+-- !query schema
+struct<val:int,cate:string,sliding_sum_filtered:bigint>
+-- !query output
+NULL   NULL    NULL
+1      b       NULL
+3      NULL    3
+NULL   a       NULL
+1      a       3
+1      a       5
+2      b       4
+2      a       7
+3      b       5
+
+
+-- !query
+SELECT val, cate,
+sum(val) FILTER (WHERE cate = 'a') OVER(ORDER BY val
+  RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS range_sum_filtered
+FROM testData ORDER BY val, cate
+-- !query schema
+struct<val:int,cate:string,range_sum_filtered:bigint>
+-- !query output
+NULL   NULL    NULL
+NULL   a       NULL
+1      a       2
+1      a       2
+1      b       2
+2      a       4
+2      b       4
+3      NULL    4
+3      b       4
 
 
 -- !query


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to