This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 4d3cd058f065 [SPARK-53991][SQL][FOLLOWUP] Enforce 
KLL_SKETCH_AGG_GET_RANK/QUANTILE arguments are foldable
4d3cd058f065 is described below

commit 4d3cd058f065a66919d346c398f81692b4021dcb
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Mon Dec 15 13:51:10 2025 +0800

    [SPARK-53991][SQL][FOLLOWUP] Enforce KLL_SKETCH_AGG_GET_RANK/QUANTILE 
arguments are foldable
    
    ### What changes were proposed in this pull request?
    
    This PR adds a restriction that the rank/quantile arguments to 
`kll_sketch_get_quantile_*` and `kll_sketch_get_rank_*` functions must be 
foldable (compile-time constants).
    
    **Changes:**
    
    1. Added `checkInputDataTypes()` validation to `KllSketchGetQuantileBase` 
and `KllSketchGetRankBase` that returns a `NON_FOLDABLE_INPUT` error if the 
rank/quantile argument is not foldable.
    
    2. Fixed a bug in `nullSafeEval` where `right.eval()` was being called 
instead of using the already-evaluated `rightInput` parameter.
    
    3. Added negative test cases in `kllquantiles.sql` to verify the new 
foldability restriction for:
       - Non-foldable scalar rank argument to `kll_sketch_get_quantile_bigint`
       - Non-foldable array rank argument to `kll_sketch_get_quantile_bigint`
       - Non-foldable scalar quantile argument to `kll_sketch_get_rank_bigint`
       - Non-foldable array quantile argument to `kll_sketch_get_rank_bigint`
    
    ### Why are the changes needed?
    
    The foldability restriction is reasonable because passing constant 
rank/quantile arguments is the intended usage of these functions, and it is 
consistent with other existing functions in Spark.
    
    Additionally, fixing the `right.eval()` bug ensures we use the 
already-evaluated value passed to `nullSafeEval`, avoiding redundant evaluation.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users will now receive an error if they pass a non-constant expression 
(e.g., a column reference) as the rank argument to `kll_sketch_get_quantile_*` 
or the quantile argument to `kll_sketch_get_rank_*`. For example:
    
    -- This will now fail with NON_FOLDABLE_INPUT error
    SELECT kll_sketch_get_quantile_bigint(sketch, col / 10.0) FROM table;
    
    -- This is allowed (constant expression)
    SELECT kll_sketch_get_quantile_bigint(sketch, 0.5) FROM table;
    SELECT kll_sketch_get_quantile_bigint(sketch, array(0.25, 0.5, 0.75)) FROM 
table;
    
    ### How was this patch tested?
    
    Added negative test cases in `kllquantiles.sql` that exercise the new 
foldability restriction for both scalar and array inputs. The tests verify that 
appropriate error messages are returned for non-foldable arguments.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, with assistance from `claude-4.5-opus-high` with manual review and 
adjustment.
    
    Closes #53463 from dtenedor/kll-get-quantile-rank-check-foldable.
    
    Lead-authored-by: Daniel Tenedorio <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 392666f167ba47a9a1e31d9a39cb3fb03369caa5)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/expressions/kllExpressions.scala  |  36 ++++++-
 .../analyzer-results/kllquantiles.sql.out          | 112 +++++++++++++++++++
 .../resources/sql-tests/inputs/kllquantiles.sql    |  35 ++++++
 .../sql-tests/results/kllquantiles.sql.out         | 120 +++++++++++++++++++++
 .../apache/spark/sql/DataFrameAggregateSuite.scala |   7 +-
 5 files changed, 305 insertions(+), 5 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
index 18a9fc6e1f19..0556ef118e02 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, 
KllLongsSketch}
 import org.apache.datasketches.memory.Memory
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, 
toSQLType}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -466,6 +468,21 @@ abstract class KllSketchGetQuantileBase
   /** The output data type for a single value (not array) */
   protected def outputDataType: DataType
 
+  // The rank argument must be foldable (compile-time constant).
+  // This enables Photon to efficiently handle array outputs with a known 
constant size.
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!right.foldable) {
+      TypeCheckResult.DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("rank"),
+          "inputType" -> toSQLType(right.dataType),
+          "inputExpr" -> toSQLExpr(right)))
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
+
   override def nullIntolerant: Boolean = true
   override def inputTypes: Seq[AbstractDataType] =
     Seq(
@@ -485,7 +502,7 @@ abstract class KllSketchGetQuantileBase
     val buffer = leftInput.asInstanceOf[Array[Byte]]
     val memory = Memory.wrap(buffer)
 
-    right.eval() match {
+    rightInput match {
       case null => null
       case num: Double =>
         // Single value case
@@ -617,6 +634,21 @@ abstract class KllSketchGetRankBase
    */
   protected def kllSketchGetRank(memory: Memory, quantile: Any): Double
 
+  // The quantile argument must be foldable (compile-time constant).
+  // This enables Photon to efficiently handle array outputs with a known 
constant size.
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!right.foldable) {
+      TypeCheckResult.DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("quantile"),
+          "inputType" -> toSQLType(right.dataType),
+          "inputExpr" -> toSQLExpr(right)))
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
+
   override def nullIntolerant: Boolean = true
   override def inputTypes: Seq[AbstractDataType] = {
     Seq(
@@ -636,7 +668,7 @@ abstract class KllSketchGetRankBase
     val buffer: Array[Byte] = leftInput.asInstanceOf[Array[Byte]]
     val memory: Memory = Memory.wrap(buffer)
 
-    right.eval() match {
+    rightInput match {
       case null => null
       case value if !value.isInstanceOf[ArrayData] =>
         // Single value case
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out
index 64fc8998c9e4..dc22199985f0 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out
@@ -1294,6 +1294,118 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0) AS 
non_foldable_scalar_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"(CAST(col1 AS DOUBLE) / 10.0)\"",
+    "inputName" : "`rank`",
+    "inputType" : "\"DOUBLE\"",
+    "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, (CAST(col1 AS DOUBLE) / 
10.0))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 71,
+    "fragment" : "kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 
10.0)"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 
10.0, 0.75)) AS non_foldable_array_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"array(0.25, (CAST(col1 AS DOUBLE) / 10.0), 0.75)\"",
+    "inputName" : "`rank`",
+    "inputType" : "\"ARRAY<DOUBLE>\"",
+    "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, array(0.25, (CAST(col1 
AS DOUBLE) / 10.0), 0.75))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 90,
+    "fragment" : "kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS 
DOUBLE) / 10.0, 0.75))"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, col1) AS non_foldable_scalar_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"col1\"",
+    "inputName" : "`quantile`",
+    "inputType" : "\"BIGINT\"",
+    "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, col1)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 44,
+    "fragment" : "kll_sketch_get_rank_bigint(agg, col1)"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L)) AS 
non_foldable_array_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"array(1, col1, 5)\"",
+    "inputName" : "`quantile`",
+    "inputType" : "\"ARRAY<BIGINT>\"",
+    "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, array(1, col1, 5))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 59,
+    "fragment" : "kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L))"
+  } ]
+}
+
+
 -- !query
 DROP TABLE IF EXISTS t_int_1_5_through_7_11
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql 
b/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql
index d0d7fb1f9c12..69d472ac78a6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql
@@ -464,6 +464,41 @@ FROM (
     FROM t_double_1_5_through_7_11
 );
 
+-- Negative tests for non-foldable (non-constant) rank/quantile arguments
+-- These tests verify that get_quantile and get_rank functions require 
compile-time constant arguments
+
+-- Non-foldable scalar rank argument to get_quantile (column reference)
+SELECT kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0) AS 
non_foldable_scalar_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+);
+
+-- Non-foldable array rank argument to get_quantile (array containing column 
reference)
+SELECT kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 
10.0, 0.75)) AS non_foldable_array_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+);
+
+-- Non-foldable scalar quantile argument to get_rank (column reference)
+SELECT kll_sketch_get_rank_bigint(agg, col1) AS non_foldable_scalar_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+);
+
+-- Non-foldable array quantile argument to get_rank (array containing column 
reference)
+SELECT kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L)) AS 
non_foldable_array_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+);
+
 -- Clean up
 DROP TABLE IF EXISTS t_int_1_5_through_7_11;
 DROP TABLE IF EXISTS t_long_1_5_through_7_11;
diff --git a/sql/core/src/test/resources/sql-tests/results/kllquantiles.sql.out 
b/sql/core/src/test/resources/sql-tests/results/kllquantiles.sql.out
index 3618c851939e..6f60f30e5681 100644
--- a/sql/core/src/test/resources/sql-tests/results/kllquantiles.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/kllquantiles.sql.out
@@ -1395,6 +1395,126 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0) AS 
non_foldable_scalar_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"(CAST(col1 AS DOUBLE) / 10.0)\"",
+    "inputName" : "`rank`",
+    "inputType" : "\"DOUBLE\"",
+    "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, (CAST(col1 AS DOUBLE) / 
10.0))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 71,
+    "fragment" : "kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 
10.0)"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 
10.0, 0.75)) AS non_foldable_array_rank
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"array(0.25, (CAST(col1 AS DOUBLE) / 10.0), 0.75)\"",
+    "inputName" : "`rank`",
+    "inputType" : "\"ARRAY<DOUBLE>\"",
+    "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, array(0.25, (CAST(col1 
AS DOUBLE) / 10.0), 0.75))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 90,
+    "fragment" : "kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS 
DOUBLE) / 10.0, 0.75))"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, col1) AS non_foldable_scalar_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"col1\"",
+    "inputName" : "`quantile`",
+    "inputType" : "\"BIGINT\"",
+    "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, col1)\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 44,
+    "fragment" : "kll_sketch_get_rank_bigint(agg, col1)"
+  } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L)) AS 
non_foldable_array_quantile
+FROM (
+    SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+    FROM t_long_1_5_through_7_11
+    GROUP BY col1
+)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+  "sqlState" : "42K09",
+  "messageParameters" : {
+    "inputExpr" : "\"array(1, col1, 5)\"",
+    "inputName" : "`quantile`",
+    "inputType" : "\"ARRAY<BIGINT>\"",
+    "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, array(1, col1, 5))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 59,
+    "fragment" : "kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L))"
+  } ]
+}
+
+
 -- !query
 DROP TABLE IF EXISTS t_int_1_5_through_7_11
 -- !query schema
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index da2fbceae97e..0dfd37ebeae0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.time.{Duration, LocalDateTime, LocalTime, Period}
+import java.util.Locale
 
 import scala.util.Random
 
@@ -3348,7 +3349,7 @@ class DataFrameAggregateSuite extends QueryTest
     val result = 
sketchDf.select(kll_sketch_to_string_bigint($"sketch")).collect()(0)(0)
     assert(result != null)
     assert(result.asInstanceOf[String].length > 0)
-    assert(result.asInstanceOf[String].contains("Kll"))
+    
assert(result.asInstanceOf[String].toLowerCase(Locale.ROOT).contains("kll"))
   }
 
   test("kll_sketch_get_n functions") {
@@ -3402,7 +3403,7 @@ class DataFrameAggregateSuite extends QueryTest
 
     // Test to_string
     val str = 
sketchDf.select(kll_sketch_to_string_float($"sketch")).collect()(0)(0)
-    assert(str.asInstanceOf[String].contains("Kll"))
+    assert(str.asInstanceOf[String].toLowerCase(Locale.ROOT).contains("kll"))
 
     // Test get_n
     val n = sketchDf.select(kll_sketch_get_n_float($"sketch")).collect()(0)(0)
@@ -3433,7 +3434,7 @@ class DataFrameAggregateSuite extends QueryTest
 
     // Test to_string
     val str = 
sketchDf.select(kll_sketch_to_string_double($"sketch")).collect()(0)(0)
-    assert(str.asInstanceOf[String].contains("Kll"))
+    assert(str.asInstanceOf[String].toLowerCase(Locale.ROOT).contains("kll"))
 
     // Test get_n
     val n = sketchDf.select(kll_sketch_get_n_double($"sketch")).collect()(0)(0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to