This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 116c20ce0d2b [SPARK-51757][SQL] Fix LEAD/LAG Function Offset Exceeds Window Group Size
116c20ce0d2b is described below

commit 116c20ce0d2b3eda4aeb6d4092f22ce8c2a19b7a
Author: xin-aurora <56269194+xin-aur...@users.noreply.github.com>
AuthorDate: Sun Apr 27 20:20:36 2025 +0800

    [SPARK-51757][SQL] Fix LEAD/LAG Function Offset Exceeds Window Group Size
    
    ### What changes were proposed in this pull request?
    The current implementation of `prepare` in `OffsetWindowFunctionFrameBase`:
    ```
      override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
        if (absOffset > rows.length) {
          fillDefaultValue(EmptyRow)
        } else {
        ...
      }
    ```
    The current implementation of `write` in `FrameLessOffsetWindowFunctionFrame`:
    ```
      override def write(index: Int, current: InternalRow): Unit = {
        if (absOffset > input.length) {
          // Already use default values in prepare.
        } else {
        ...
      }
    ```
    
    These implementations cause `LEAD` and `LAG` to throw a `NullPointerException` when the default value is not a constant (e.g., a column reference) and the offset exceeds the window group size, because such defaults are evaluated against `EmptyRow` in `prepare`.
    
    This PR introduces a boolean `lazy val onlyLiterals` and modifies `prepare` and `write`.
    
    `onlyLiterals` indicates whether every default value is either absent (null) or foldable (e.g., a Literal).
    
    In `prepare`, when the offset exceeds the window group size, first check `onlyLiterals`: only if it is true, call `fillDefaultValue(EmptyRow)`; otherwise the defaults are handled in `write`.
    
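    In `write` of `FrameLessOffsetWindowFunctionFrame`, when the offset exceeds the window group size and `onlyLiterals` is false (at least one default is not foldable), call `fillDefaultValue(current)` so the defaults are evaluated against the current row.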
    
    ### Why are the changes needed?
    To fix the `NullPointerException` thrown by the `LEAD` and `LAG` window functions when the default value is non-constant (SPARK-51757).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added the test "lead/lag with column reference as default when offset exceeds window group size" to `org.apache.spark.sql.DataFrameWindowFramesSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #50552 from xin-aurora/windowFuncFix.
    
    Lead-authored-by: xin-aurora <56269194+xin-aur...@users.noreply.github.com>
    Co-authored-by: Xin Zhang <xzhan...@ucr.edu>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/window/WindowFunctionFrame.scala | 20 ++++++++++++--
 .../spark/sql/DataFrameWindowFramesSuite.scala     | 31 ++++++++++++++++++++++
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 7fb6d3f36782..644603e4710f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -183,10 +183,21 @@ abstract class OffsetWindowFunctionFrameBase(
     }
   }
 
+  /** True if each default is null or foldable, so it can be evaluated against EmptyRow. */
+  protected lazy val onlyLiterals = expressions.forall { e =>
+    e.default == null || e.default.foldable
+  }
+
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     resetStates(rows)
     if (absOffset > rows.length) {
-      fillDefaultValue(EmptyRow)
+      // Avoid evaluating non-literal defaults with EmptyRow,
+      // which causes NullPointerException.
+      // Check whether defaults are Literal.
+      if (onlyLiterals) {
+        fillDefaultValue(EmptyRow)
+      }
+      // Handle non-literal defaults in write().
     } else {
       if (ignoreNulls) {
         prepareForIgnoreNulls()
@@ -312,7 +323,12 @@ class FrameLessOffsetWindowFunctionFrame(
 
   override def write(index: Int, current: InternalRow): Unit = {
     if (absOffset > input.length) {
-      // Already use default values in prepare.
+      if (!onlyLiterals) {
+        // Handle non-literal defaults, e.g., column references
+        // Use default values since the offset row does not exist.
+        fillDefaultValue(current)
+      }
+      // Literal default values were already evaluated in prepare().
     } else {
       doWrite(current)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
index 09a53edf9909..5eaf7a02a723 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
@@ -155,6 +155,37 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession {
         Row(2, default, default, default, default) :: Nil)
   }
 
+  test("lead/lag with column reference as default when offset exceeds window 
group size") {
+    val df = spark.range(0, 10, 1, 1).toDF("id")
+    val window = Window.partitionBy(expr("div(id, 2)")).orderBy($"id")
+
+    val result = df.select(
+      $"id",
+      lead($"id", 1, $"id").over(window).as("lead_1"),
+      lead($"id", 3, $"id").over(window).as("lead_3"),
+      lag($"id", 1, $"id").over(window).as("lag_1"),
+      lag($"id", 3, $"id").over(window).as("lag_3")
+    ).orderBy("id")
+
+    // Columns: id, lead_1, lead_3, lag_1, lag_3. Each partition holds 2 rows, so
+    // offset 3 always falls back to the default (the current row's id), while
+    // offset 1 does so only at the last (lead) or first (lag) row of a partition.
+    val expected = Seq(
+      Row(0, 1, 0, 0, 0),
+      Row(1, 1, 1, 0, 1),
+      Row(2, 3, 2, 2, 2),
+      Row(3, 3, 3, 2, 3),
+      Row(4, 5, 4, 4, 4),
+      Row(5, 5, 5, 4, 5),
+      Row(6, 7, 6, 6, 6),
+      Row(7, 7, 7, 6, 7),
+      Row(8, 9, 8, 8, 8),
+      Row(9, 9, 9, 8, 9)
+    )
+
+    checkAnswer(result, expected)
+  }
+
   test("rows/range between with empty data frame") {
     val df = Seq.empty[(String, Int)].toDF("key", "value")
     val window = Window.partitionBy($"key").orderBy($"value")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to