This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a8cfe0c682c5 [SPARK-53960][SQL] Let approx_top_k_accumulate/combine/estimate handle NULLs
a8cfe0c682c5 is described below

commit a8cfe0c682c57d2c7346fed20a0c2e76a2c7ff99
Author: yhuang-db <[email protected]>
AuthorDate: Tue Oct 21 08:49:30 2025 -0700

    [SPARK-53960][SQL] Let approx_top_k_accumulate/combine/estimate handle NULLs
    
    ### What changes were proposed in this pull request?
    
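    As a follow-up to https://github.com/apache/spark/pull/52655, add NULL
    handling to approx_top_k_accumulate/estimate/combine: NULL input values are
    now counted, and NULL can appear as an item in the approx_top_k_estimate result.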
    
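    ### Why are the changes needed?
    
    Previously, NULL input values were skipped when updating the sketch, so
    approx_top_k_accumulate/combine/estimate could never report how often NULL occurred.
    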
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests for NULL handling in accumulate, combine and estimate. Existing
    combine tests are wrapped in withView so that their temporary views are dropped.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52673 from yhuang-db/accumulate_estimate_count_null.
    
    Authored-by: yhuang-db <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../expressions/ApproxTopKExpressions.scala        |  12 +-
 .../aggregate/ApproxTopKAggregates.scala           | 114 ++----
 .../org/apache/spark/sql/ApproxTopKSuite.scala     | 436 +++++++++++++++------
 3 files changed, 363 insertions(+), 199 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3e2d12fc5b17..53c37f0a5491 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -17,14 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.datasketches.frequencies.ItemsSketch
-import org.apache.datasketches.memory.Memory
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, 
TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, 
ApproxTopKAggregateBuffer}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
 
@@ -105,9 +102,10 @@ case class ApproxTopKEstimate(state: Expression, k: 
Expression)
     val kVal = kEval.asInstanceOf[Int]
     ApproxTopK.checkK(kVal)
     ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
-    val itemsSketch = ItemsSketch.getInstance(
-      Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
-    ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
+    val approxTopKAggregateBuffer = ApproxTopKAggregateBuffer.deserialize(
+      dataSketchBytes,
+      ApproxTopK.genSketchSerDe(itemDataType))
+    approxTopKAggregateBuffer.eval(kVal, itemDataType)
   }
 
   override protected def withNewChildrenInternal(newState: Expression, newK: 
Expression)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 1fca8ad86bc2..7ae542f190d5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -260,54 +260,6 @@ object ApproxTopK {
     }
   }
 
-  def updateSketchBuffer(
-      itemExpression: Expression,
-      buffer: ItemsSketch[Any],
-      input: InternalRow): ItemsSketch[Any] = {
-    val v = itemExpression.eval(input)
-    if (v != null) {
-      itemExpression.dataType match {
-        case _: BooleanType => buffer.update(v.asInstanceOf[Boolean])
-        case _: ByteType => buffer.update(v.asInstanceOf[Byte])
-        case _: ShortType => buffer.update(v.asInstanceOf[Short])
-        case _: IntegerType => buffer.update(v.asInstanceOf[Int])
-        case _: LongType => buffer.update(v.asInstanceOf[Long])
-        case _: FloatType => buffer.update(v.asInstanceOf[Float])
-        case _: DoubleType => buffer.update(v.asInstanceOf[Double])
-        case _: DateType => buffer.update(v.asInstanceOf[Int])
-        case _: TimestampType => buffer.update(v.asInstanceOf[Long])
-        case _: TimestampNTZType => buffer.update(v.asInstanceOf[Long])
-        case st: StringType =>
-          val cKey = 
CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
-          buffer.update(cKey.toString)
-        case _: DecimalType => buffer.update(v.asInstanceOf[Decimal])
-      }
-    }
-    buffer
-  }
-
-  def genEvalResult(
-      itemsSketch: ItemsSketch[Any],
-      k: Int,
-      itemDataType: DataType): GenericArrayData = {
-    val items = itemsSketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
-    val resultLength = math.min(items.length, k)
-    val result = new Array[AnyRef](resultLength)
-    for (i <- 0 until resultLength) {
-      val row = items(i)
-      itemDataType match {
-        case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
-             _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
-             _: DateType | _: TimestampType | _: TimestampNTZType =>
-          result(i) = InternalRow.apply(row.getItem, row.getEstimate)
-        case _: StringType =>
-          val item = UTF8String.fromString(row.getItem.asInstanceOf[String])
-          result(i) = InternalRow.apply(item, row.getEstimate)
-      }
-    }
-    new GenericArrayData(result)
-  }
-
   def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = {
     dataType match {
       case _: BooleanType => new 
ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]]
@@ -333,7 +285,7 @@ object ApproxTopK {
 
   def dataTypeToDDL(dataType: DataType): String = dataType match {
     case _: StringType =>
-      // Hide collation information in DDL format
+      // Hide collation information in DDL format, otherwise 
CollationExpressionWalkerSuite fails
       s"item string not null"
     case other =>
       StructField("item", other, nullable = false).toDDL
@@ -552,7 +504,7 @@ case class ApproxTopKAccumulate(
     maxItemsTracked: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[ItemsSketch[Any]]
+  extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
   with ImplicitCastInputTypes
   with BinaryLike[Expression] {
 
@@ -592,18 +544,23 @@ case class ApproxTopKAccumulate(
 
   override def dataType: DataType = 
ApproxTopK.getSketchStateDataType(itemDataType)
 
-  override def createAggregationBuffer(): ItemsSketch[Any] = {
+  override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createItemsSketch(expr, maxMapSize)
+    val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+    new ApproxTopKAggregateBuffer[Any](sketch, 0L)
   }
 
-  override def update(buffer: ItemsSketch[Any], input: InternalRow): 
ItemsSketch[Any] =
-    ApproxTopK.updateSketchBuffer(expr, buffer, input)
+  override def update(buffer: ApproxTopKAggregateBuffer[Any], input: 
InternalRow):
+    ApproxTopKAggregateBuffer[Any] =
+    buffer.update(expr, input)
 
-  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): 
ItemsSketch[Any] =
+  override def merge(
+      buffer: ApproxTopKAggregateBuffer[Any],
+      input: ApproxTopKAggregateBuffer[Any]):
+    ApproxTopKAggregateBuffer[Any] =
     buffer.merge(input)
 
-  override def eval(buffer: ItemsSketch[Any]): Any = {
+  override def eval(buffer: ApproxTopKAggregateBuffer[Any]): Any = {
     val sketchBytes = serialize(buffer)
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
     InternalRow.apply(
@@ -613,11 +570,11 @@ case class ApproxTopKAccumulate(
       UTF8String.fromString(itemDataTypeDDL))
   }
 
-  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
-    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+  override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+    buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))
 
-  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
-    ItemsSketch.getInstance(Memory.wrap(storageFormat), 
ApproxTopK.genSketchSerDe(itemDataType))
+  override def deserialize(storageFormat: Array[Byte]): 
ApproxTopKAggregateBuffer[Any] =
+    ApproxTopKAggregateBuffer.deserialize(storageFormat, 
ApproxTopK.genSketchSerDe(itemDataType))
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -644,10 +601,10 @@ case class ApproxTopKAccumulate(
  * @param maxItemsTracked the maximum number of items tracked in the sketch
  */
 class CombineInternal[T](
-    sketch: ItemsSketch[T],
+    sketchWithNullCount: ApproxTopKAggregateBuffer[T],
     var itemDataType: DataType,
     var maxItemsTracked: Int) {
-  def getSketch: ItemsSketch[T] = sketch
+  def getSketchWithNullCount: ApproxTopKAggregateBuffer[T] = 
sketchWithNullCount
 
   def getItemDataType: DataType = itemDataType
 
@@ -689,6 +646,9 @@ class CombineInternal[T](
     }
   }
 
+  def updateSketchWithNullCount(otherSketchWithNullCount: 
ApproxTopKAggregateBuffer[T]): Unit =
+    sketchWithNullCount.merge(otherSketchWithNullCount)
+
   /**
    * Serialize the CombineInternal instance to a byte array.
    * Serialization format:
@@ -698,18 +658,18 @@ class CombineInternal[T](
    *     sketchBytes
    */
   def serialize(): Array[Byte] = {
-    val sketchBytes = sketch.toByteArray(
+    val sketchWithNullCountBytes = sketchWithNullCount.serialize(
       
ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
     val ddlBytes: Array[Byte] = 
itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
     val byteArray = new Array[Byte](
-      sketchBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
+      sketchWithNullCountBytes.length + Integer.BYTES + Integer.BYTES + 
ddlBytes.length)
 
     val byteBuffer = ByteBuffer.wrap(byteArray)
     byteBuffer.putInt(maxItemsTracked)
     byteBuffer.putInt(ddlBytes.length)
     byteBuffer.put(ddlBytes)
-    byteBuffer.put(sketchBytes)
+    byteBuffer.put(sketchWithNullCountBytes)
     byteArray
   }
 }
@@ -736,9 +696,9 @@ object CombineInternal {
     // read sketchBytes
     val sketchBytes = new Array[Byte](buffer.length - Integer.BYTES - 
Integer.BYTES - ddlLength)
     byteBuffer.get(sketchBytes)
-    val sketch = ItemsSketch.getInstance(
-      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
-    new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+    val sketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+      sketchBytes, ApproxTopK.genSketchSerDe(itemDataType))
+    new CombineInternal[Any](sketchWithNullCount, itemDataType, 
maxItemsTracked)
   }
 }
 
@@ -833,7 +793,7 @@ case class ApproxTopKCombine(
     if (combineSizeSpecified) {
       val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
       new CombineInternal[Any](
-        new ItemsSketch[Any](maxMapSize),
+        new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 
0L),
         null,
         maxItemsTrackedVal)
     } else {
@@ -842,7 +802,7 @@ case class ApproxTopKCombine(
       // The actual maxItemsTracked will be checked during the updates.
       val maxMapSize = 
ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
       new CombineInternal[Any](
-        new ItemsSketch[Any](maxMapSize),
+        new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 
0L),
         null,
         ApproxTopK.VOID_MAX_ITEMS_TRACKED)
     }
@@ -863,9 +823,9 @@ case class ApproxTopKCombine(
     // update itemDataType (throw error if not match)
     buffer.updateItemDataType(inputItemDataType)
     // update sketch
-    val inputSketch = ItemsSketch.getInstance(
-      Memory.wrap(inputSketchBytes), 
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
-    buffer.getSketch.merge(inputSketch)
+    val inputSketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+      inputSketchBytes, ApproxTopK.genSketchSerDe(inputItemDataType))
+    buffer.updateSketchWithNullCount(inputSketchWithNullCount)
     buffer
   }
 
@@ -876,14 +836,14 @@ case class ApproxTopKCombine(
     buffer.updateMaxItemsTracked(combineSizeSpecified, 
input.getMaxItemsTracked)
     // update itemDataType (throw error if not match)
     buffer.updateItemDataType(input.getItemDataType)
-    // update sketch
-    buffer.getSketch.merge(input.getSketch)
+    // update sketchWithNullCount
+    buffer.getSketchWithNullCount.merge(input.getSketchWithNullCount)
     buffer
   }
 
   override def eval(buffer: CombineInternal[Any]): Any = {
-    val sketchBytes =
-      
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val sketchBytes = buffer.getSketchWithNullCount
+      .serialize(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
     val maxItemsTracked = buffer.getMaxItemsTracked
     val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(buffer.getItemDataType)
     InternalRow.apply(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index d9d16d1234b7..982c9ff90da7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -381,6 +381,77 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
     )
   }
 
+  test("SPARK-53960: accumulate and estimate count NULL values") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+        |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS 
tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null is not in top k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+        |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null is the last in top k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 3)
+        |FROM VALUES 0, 0, 1, 1, 1, NULL AS tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+  }
+
+  test("SPARK-53960: accumulate and estimate null + frequent items < k") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 5)
+        |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+        |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS 
DECIMAL(4, 1)),
+        |NULL AS tab(expr)""".stripMargin)
+    checkAnswer(
+      res,
+      Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+        Row(new java.math.BigDecimal("0.0"), 2),
+        Row(null, 1))))
+  }
+
+  test("SPARK-53960: accumulate and estimate work on typed column with only 
NULL values") {
+    val res = sql(
+      """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr))
+        |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS 
tab(expr)""".stripMargin)
+    checkAnswer(res, Row(Seq(Row(null, 2))))
+  }
+
+  test("SPARK-53960: accumulate a column of all nulls with type - success") {
+    withView("accumulation") {
+      val res = sql(
+        """SELECT approx_top_k_accumulate(expr) AS acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS 
tab(expr)""".stripMargin)
+
+      assert(res.collect().length == 1)
+      res.createOrReplaceTempView("accumulation")
+      val est = sql("SELECT approx_top_k_estimate(acc) FROM accumulation;")
+      checkAnswer(est, Row(Seq(Row(null, 2))))
+
+    }
+  }
+
+  test("SPARK-53960: accumulate a column of all nulls without type - fail") {
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("""SELECT approx_top_k_accumulate(expr)
+            |FROM VALUES (NULL), (NULL), (NULL), (NULL) AS 
tab(expr)""".stripMargin)
+      },
+      condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+      parameters = Map(
+        "sqlExpr" -> "\"approx_top_k_accumulate(expr, 10000)\"",
+        "msg" -> "void columns are not supported",
+        "hint" -> ""
+      ),
+      queryContext = Array(ExpectedContext("approx_top_k_accumulate(expr)", 7, 
35))
+    )
+  }
+
   /////////////////////////////////
   // approx_top_k_combine
   /////////////////////////////////
@@ -445,75 +516,87 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
   // positive tests for approx_top_k_combine on every types
   gridTest("SPARK-52798: same type, same size, specified combine size - 
success")(itemsWithTopK) {
     case (input, expected) =>
-      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS 
tab(expr);")
-        .createOrReplaceTempView("accumulation1")
-      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS 
tab(expr);")
-        .createOrReplaceTempView("accumulation2")
-      sql("SELECT approx_top_k_combine(acc, 30) as com " +
-        "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2);")
-        .createOrReplaceTempView("combined")
-      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-      // expected should be doubled because we combine two identical sketches
-      val expectedDoubled = expected.map {
-        case Row(value: Any, count: Int) => Row(value, count * 2)
+      withView("accumulation1", "accumulation2", "combined") {
+        sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input 
AS tab(expr);")
+          .createOrReplaceTempView("accumulation1")
+        sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input 
AS tab(expr);")
+          .createOrReplaceTempView("accumulation2")
+        sql("SELECT approx_top_k_combine(acc, 30) as com " +
+          "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2);")
+          .createOrReplaceTempView("combined")
+        val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+        // expected should be doubled because we combine two identical sketches
+        val expectedDoubled = expected.map {
+          case Row(value: Any, count: Int) => Row(value, count * 2)
+        }
+        checkAnswer(est, Row(expectedDoubled))
       }
-      checkAnswer(est, Row(expectedDoubled))
   }
 
   test("SPARK-52798: same type, same size, specified combine size - success") {
-    setupMixedSizeAccumulations(10, 10)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 10)
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, same size, unspecified combine size - 
success") {
-    setupMixedSizeAccumulations(10, 10)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 10)
 
-    sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, different size, specified combine size - 
success") {
-    setupMixedSizeAccumulations(10, 20)
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      setupMixedSizeAccumulations(10, 20)
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combination")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
-    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), 
Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: same type, different size, unspecified combine size - 
fail") {
-    setupMixedSizeAccumulations(10, 20)
+    withView("accumulation1", "accumulation2", "unioned") {
+      setupMixedSizeAccumulations(10, 20)
 
-    val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
-
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        comb.collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
-      parameters = Map("size1" -> "10", "size2" -> "20")
-    )
-  }
+      val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
 
-  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 
20))) {
-    case (size1, size2) =>
-      setupMixedSizeAccumulations(size1, size2)
       checkError(
         exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 0) as com FROM 
unioned").collect()
+          comb.collect()
         },
-        condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
-        parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+        condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+        parameters = Map("size1" -> "10", "size2" -> "20")
       )
+    }
+  }
+
+  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 
20))) {
+    case (size1, size2) =>
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedSizeAccumulations(size1, size2)
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 0) as com FROM 
unioned").collect()
+          },
+          condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+          parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+        )
+      }
   }
 
   test("SPARK-52798: among different number or datetime types - fail at 
combine") {
@@ -523,13 +606,15 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
           val (type1, _, seq1) = mixedTypeSeq(i)
           val (type2, _, seq2) = mixedTypeSeq(j)
           setupMixedTypeAccumulation(seq1, seq2)
-          checkError(
-            exception = intercept[SparkRuntimeException] {
-              sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
-            },
-            condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-            parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(type2))
-          )
+          withView("accumulation1", "accumulation2", "unioned") {
+            checkError(
+              exception = intercept[SparkRuntimeException] {
+                sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
+              },
+              condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+              parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(type2))
+            )
+          }
         }
       }
     }
@@ -547,7 +632,9 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
     case ((_, type1, seq1), (_, type2, seq2)) =>
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -568,14 +655,17 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
 
   gridTest("SPARK-52798: number vs string - fail at 
combine")(mixedNumberTypes) {
     case (type1, _, seq1) =>
-      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", 
"'c'", "'d'", "'d'"))
-      checkError(
-        exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
-        },
-        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-        parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(StringType))
-      )
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedTypeAccumulation(
+          seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+          parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(StringType))
+        )
+      }
   }
 
   gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes) 
{
@@ -583,7 +673,9 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
       val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -604,14 +696,17 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
 
   gridTest("SPARK-52798: datetime vs string - fail at 
combine")(mixedDateTimeTypes) {
     case (type1, _, seq1) =>
-      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", 
"'c'", "'d'", "'d'"))
-      checkError(
-        exception = intercept[SparkRuntimeException] {
-          sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
-        },
-        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-        parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(StringType))
-      )
+      withView("accumulation1", "accumulation2", "unioned") {
+        setupMixedTypeAccumulation(
+          seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+          parameters = Map("type1" -> toSQLType(type1), "type2" -> 
toSQLType(StringType))
+        )
+      }
   }
 
   gridTest("SPARK-52798: datetime vs boolean - fail at 
UNION")(mixedDateTimeTypes) {
@@ -619,7 +714,9 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
       val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
-          setupMixedTypeAccumulation(seq1, seq2)
+          withView("accumulation1", "accumulation2", "unioned") {
+            setupMixedTypeAccumulation(seq1, seq2)
+          }
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
@@ -641,65 +738,174 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
   test("SPARK-52798: string vs boolean - fail at combine") {
     val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
     val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
-    setupMixedTypeAccumulation(seq1, seq2)
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
-      parameters = Map("type1" -> toSQLType(StringType), "type2" -> 
toSQLType(BooleanType))
-    )
+    withView("accumulation1", "accumulation2", "unioned") {
+      setupMixedTypeAccumulation(seq1, seq2)
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+        parameters = Map("type1" -> toSQLType(StringType), "type2" -> 
toSQLType(BooleanType))
+      )
+    }
   }
 
   test("SPARK-52798: combine more than 2 sketches with specified size") {
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
-      .createOrReplaceTempView("accumulation1")
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned", 
"combined") {
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+        .createOrReplaceTempView("accumulation1")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
-      .createOrReplaceTempView("accumulation2")
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+        .createOrReplaceTempView("accumulation2")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
-      "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
-      .createOrReplaceTempView("accumulation3")
+      sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+        "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+        .createOrReplaceTempView("accumulation3")
 
-    sql("SELECT acc from accumulation1 UNION ALL " +
-      "SELECT acc FROM accumulation2 UNION ALL " +
-      "SELECT acc FROM accumulation3")
-      .createOrReplaceTempView("unioned")
+      sql("SELECT acc from accumulation1 UNION ALL " +
+        "SELECT acc FROM accumulation2 UNION ALL " +
+        "SELECT acc FROM accumulation3")
+        .createOrReplaceTempView("unioned")
 
-    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
-      .createOrReplaceTempView("combined")
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
 
-    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
-    checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), 
Row(4, 2))))
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), 
Row(4, 2))))
+    }
   }
 
   test("SPARK-52798: combine more than 2 sketches without specified size") {
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
-      .createOrReplaceTempView("accumulation1")
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned") {
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+        .createOrReplaceTempView("accumulation1")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
-      "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
-      .createOrReplaceTempView("accumulation2")
+      sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+        "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+        .createOrReplaceTempView("accumulation2")
 
-    sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
-      "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
-      .createOrReplaceTempView("accumulation3")
+      sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+        "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+        .createOrReplaceTempView("accumulation3")
 
-    sql("SELECT acc from accumulation1 UNION ALL " +
-      "SELECT acc FROM accumulation2 UNION ALL " +
-      "SELECT acc FROM accumulation3")
-      .createOrReplaceTempView("unioned")
+      sql("SELECT acc from accumulation1 UNION ALL " +
+        "SELECT acc FROM accumulation2 UNION ALL " +
+        "SELECT acc FROM accumulation3")
+        .createOrReplaceTempView("unioned")
 
-    checkError(
-      exception = intercept[SparkRuntimeException] {
-        sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
-      },
-      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
-      parameters = Map("size1" -> "10", "size2" -> "20")
-    )
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+        parameters = Map("size1" -> "10", "size2" -> "20")
+      )
+    }
+  }
+
+  test("SPARK-53960: combine and estimate count NULL values") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 'a', 'a', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 'b', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 4), Row("b", 3))))
+    }
+  }
+
+  test("SPARK-53960: combine with a sketch of all nulls") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 1, 1, 2, 2 AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 3), Row(2, 2), Row(1, 2))))
+    }
+  }
+
+  test("SPARK-53960: combine sketches with nulls from more than 2 sketches") {
+    withView("accumulation1", "accumulation2", "accumulation3", "unioned", 
"combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 0, 0, 0, 1, 1, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES NULL, 1, 1, 2, 2, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES 2, 3, 3, NULL AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation3")
+
+      sql(
+        """SELECT acc from accumulation1 UNION ALL
+          |SELECT acc FROM accumulation2 UNION ALL
+          |SELECT acc FROM accumulation3""".stripMargin)
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+      checkAnswer(est, Row(Seq(Row(1, 4), Row(null, 4))))
+    }
+  }
+
+  test("SPARK-53960: combine 2 sketches with all nulls") {
+    withView("accumulation1", "accumulation2", "unioned", "combined") {
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation1")
+
+      sql(
+        """SELECT approx_top_k_accumulate(expr, 10) as acc
+          |FROM VALUES cast(NULL AS INT), cast(NULL AS INT)
+          |AS tab(expr)""".stripMargin)
+        .createOrReplaceTempView("accumulation2")
+
+      sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2")
+        .createOrReplaceTempView("unioned")
+
+      sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+        .createOrReplaceTempView("combined")
+
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+      checkAnswer(est, Row(Seq(Row(null, 5))))
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to