This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b0327f37bfb0 [SPARK-53947][SQL] Count null in approx_top_k
b0327f37bfb0 is described below

commit b0327f37bfb03f80b120995f1d3fc344ba142831
Author: yhuang-db <[email protected]>
AuthorDate: Mon Oct 20 15:42:37 2025 -0700

    [SPARK-53947][SQL] Count null in approx_top_k
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add a null counter alongside the Frequent Items Sketch in the `approx_top_k` aggregation, so that the function now returns the NULL item and its count when NULL is among the top-k most frequent items.
    
    ### Why are the changes needed?
    
    NULL values can be meaningful in some use cases, and users may want NULL to be included in the `approx_top_k` output.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests for handling NULL values.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52655 from yhuang-db/approx_top_k_count_null.
    
    Authored-by: yhuang-db <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../aggregate/ApproxTopKAggregates.scala           | 174 +++++++++++++++++++--
 .../org/apache/spark/sql/ApproxTopKSuite.scala     |  53 ++++++-
 2 files changed, 210 insertions(+), 17 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 6c6b3b805048..1fca8ad86bc2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -79,7 +79,7 @@ case class ApproxTopK(
     maxItemsTracked: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[ItemsSketch[Any]]
+  extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
   with ImplicitCastInputTypes
   with TernaryLike[Expression] {
 
@@ -137,25 +137,30 @@ case class ApproxTopK(
 
   override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)
 
-  override def createAggregationBuffer(): ItemsSketch[Any] = {
+  override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+    val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+    new ApproxTopKAggregateBuffer[Any](sketch, 0L)
   }
 
-  override def update(buffer: ItemsSketch[Any], input: InternalRow): 
ItemsSketch[Any] =
-    ApproxTopK.updateSketchBuffer(expr, buffer, input)
+  override def update(buffer: ApproxTopKAggregateBuffer[Any], input: 
InternalRow):
+    ApproxTopKAggregateBuffer[Any] =
+    buffer.update(expr, input)
 
-  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): 
ItemsSketch[Any] =
+  override def merge(
+      buffer: ApproxTopKAggregateBuffer[Any],
+      input: ApproxTopKAggregateBuffer[Any]):
+    ApproxTopKAggregateBuffer[Any] =
     buffer.merge(input)
 
-  override def eval(buffer: ItemsSketch[Any]): GenericArrayData =
-    ApproxTopK.genEvalResult(buffer, kVal, itemDataType)
+  override def eval(buffer: ApproxTopKAggregateBuffer[Any]): GenericArrayData =
+    buffer.eval(kVal, itemDataType)
 
-  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
-    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+  override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+    buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))
 
-  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
-    ItemsSketch.getInstance(Memory.wrap(storageFormat), 
ApproxTopK.genSketchSerDe(itemDataType))
+  override def deserialize(storageFormat: Array[Byte]): 
ApproxTopKAggregateBuffer[Any] =
+    ApproxTopKAggregateBuffer.deserialize(storageFormat, 
ApproxTopK.genSketchSerDe(itemDataType))
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -214,7 +219,7 @@ object ApproxTopK {
 
   def getResultDataType(itemDataType: DataType): DataType = {
     val resultEntryType = StructType(
-      StructField("item", itemDataType, nullable = false) ::
+      StructField("item", itemDataType, nullable = true) ::
         StructField("count", LongType, nullable = false) :: Nil)
     ArrayType(resultEntryType, containsNull = false)
   }
@@ -238,7 +243,7 @@ object ApproxTopK {
     math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
   }
 
-  def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): 
ItemsSketch[Any] = {
+  def createItemsSketch(itemExpression: Expression, maxMapSize: Int): 
ItemsSketch[Any] = {
     itemExpression.dataType match {
       case _: BooleanType =>
         new ItemsSketch[Boolean](maxMapSize).asInstanceOf[ItemsSketch[Any]]
@@ -369,6 +374,145 @@ object ApproxTopK {
   }
 }
 
+/**
+ * An internal aggregation buffer for [[ApproxTopK]]: a sketch plus an exact null count.
+ *
+ * @param sketch    the ItemsSketch that counts the non-null items
+ * @param nullCount the number of null items seen so far, tracked outside the sketch
+ */
+class ApproxTopKAggregateBuffer[T](val sketch: ItemsSketch[T], private var nullCount: Long) {
+  def update(itemExpression: Expression, input: InternalRow): ApproxTopKAggregateBuffer[T] = {
+    val v = itemExpression.eval(input)
+    if (v != null) {
+      itemExpression.dataType match {
+        case _: BooleanType =>
+          sketch.asInstanceOf[ItemsSketch[Boolean]].update(v.asInstanceOf[Boolean])
+        case _: ByteType =>
+          sketch.asInstanceOf[ItemsSketch[Byte]].update(v.asInstanceOf[Byte])
+        case _: ShortType =>
+          sketch.asInstanceOf[ItemsSketch[Short]].update(v.asInstanceOf[Short])
+        case _: IntegerType =>
+          sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+        case _: LongType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case _: FloatType =>
+          sketch.asInstanceOf[ItemsSketch[Float]].update(v.asInstanceOf[Float])
+        case _: DoubleType =>
+          sketch.asInstanceOf[ItemsSketch[Double]].update(v.asInstanceOf[Double])
+        case _: DateType =>
+          sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+        case _: TimestampType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case _: TimestampNTZType =>
+          sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+        case st: StringType =>
+          val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
+          sketch.asInstanceOf[ItemsSketch[String]].update(cKey.toString)
+        case _: DecimalType =>
+          sketch.asInstanceOf[ItemsSketch[Decimal]].update(v.asInstanceOf[Decimal])
+      }
+    } else {
+      nullCount += 1
+    }
+    this
+  }
+
+  def merge(other: ApproxTopKAggregateBuffer[T]): ApproxTopKAggregateBuffer[T] = {
+    sketch.merge(other.sketch)
+    nullCount += other.nullCount
+    this
+  }
+
+  /**
+   * Serialize the buffer into bytes.
+   * The format is:
+   * [sketch bytes][null count (8-byte big-endian Long)]
+   */
+  def serialize(serDe: ArrayOfItemsSerDe[T]): Array[Byte] = {
+    val sketchBytes = sketch.toByteArray(serDe)
+    val result = new Array[Byte](sketchBytes.length + java.lang.Long.BYTES)
+    val byteBuffer = java.nio.ByteBuffer.wrap(result)
+    byteBuffer.put(sketchBytes)
+    byteBuffer.putLong(nullCount)
+    result
+  }
+
+  /**
+   * Evaluate the buffer into at most `k` (item, count) rows, treating null as an item.
+   * Rows are sorted by count in descending order; on a tie, a sketch item precedes null.
+   */
+  def eval(k: Int, itemDataType: DataType): GenericArrayData = {
+    // sketch items; the merge below assumes they are sorted by estimate, descending
+    val frequentItems = sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
+    // total number of candidate rows (including null, if any null was seen)
+    val itemsLength = frequentItems.length + (if (nullCount > 0) 1 else 0)
+    // actual number of rows to return
+    val resultLength = math.min(itemsLength, k)
+    val result = new Array[AnyRef](resultLength)
+
+    // cursors for merging the frequent items with the single null row
+    var fiIndex = 0 // next position in frequentItems
+    var resultIndex = 0 // next position in result
+    var isNullAdded = false // whether the null row has been emitted
+
+    // null's estimate for comparison: Long.MinValue once emitted so it never wins again,
+    // otherwise the exact null count
+    @inline def getNullEstimate: Long = if (!isNullAdded) nullCount else Long.MinValue
+
+    // loop until the result is full or the frequent items run out
+    while (resultIndex < resultLength && fiIndex < frequentItems.length) {
+      val curFrequentItem = frequentItems(fiIndex)
+      val itemEstimate = curFrequentItem.getEstimate
+      val nullEstimate = getNullEstimate
+
+      val (item, estimate) = if (nullEstimate > itemEstimate) {
+        // emit (null, nullCount); the current frequent item is retried next iteration
+        isNullAdded = true
+        (null, nullCount)
+      } else {
+        // emit the current frequent item, converting strings back to UTF8String
+        val item: Any = itemDataType match {
+          case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
+               _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
+               _: DateType | _: TimestampType | _: TimestampNTZType =>
+            curFrequentItem.getItem
+          case _: StringType =>
+            UTF8String.fromString(curFrequentItem.getItem.asInstanceOf[String])
+        }
+        fiIndex += 1 // move to next frequent item
+        (item, itemEstimate)
+      }
+      result(resultIndex) = InternalRow(item, estimate)
+      resultIndex += 1 // move to next result position
+    }
+
+    // frequent items ran out with room left: append the null row if it was never emitted
+    if (resultIndex < resultLength && nullCount > 0 && !isNullAdded) {
+      result(resultIndex) = InternalRow(null, nullCount)
+    }
+
+    new GenericArrayData(result)
+  }
+}
+
+object ApproxTopKAggregateBuffer {
+  /**
+   * Deserialize a buffer produced by `ApproxTopKAggregateBuffer.serialize`.
+   * The format is: [sketch bytes][null count (8-byte big-endian Long)]
+   * @throws IllegalArgumentException if `bytes` is too short to contain the null count
+   */
+  def deserialize(bytes: Array[Byte], serDe: ArrayOfItemsSerDe[Any]):
+      ApproxTopKAggregateBuffer[Any] = {
+    val sketchBytesLength = bytes.length - java.lang.Long.BYTES
+    require(sketchBytesLength >= 0, s"Invalid ApproxTopK buffer of ${bytes.length} bytes")
+    val sketchBytes = java.util.Arrays.copyOfRange(bytes, 0, sketchBytesLength)
+    // the null count is the trailing Long written by serialize(), read at an absolute offset
+    val nullCount = java.nio.ByteBuffer.wrap(bytes).getLong(sketchBytesLength)
+    val deserializedSketch = ItemsSketch.getInstance(Memory.wrap(sketchBytes), serDe)
+    new ApproxTopKAggregateBuffer[Any](deserializedSketch, nullCount)
+  }
+}
+
 /**
  * An aggregate function that accumulates items into a sketch, which can then 
be used
  * to combine with other sketches, via ApproxTopKCombine,
@@ -450,7 +594,7 @@ case class ApproxTopKAccumulate(
 
   override def createAggregationBuffer(): ItemsSketch[Any] = {
     val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
-    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+    ApproxTopK.createItemsSketch(expr, maxMapSize)
   }
 
   override def update(buffer: ItemsSketch[Any], input: InternalRow): 
ItemsSketch[Any] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 702f361ace28..d9d16d1234b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -196,13 +196,62 @@ class ApproxTopKSuite extends QueryTest with 
SharedSparkSession {
     )
   }
 
-  test("SPARK-52515: does not count NULL values") {
+  test("SPARK-53947: count NULL values") {
     val res = sql(
       "SELECT approx_top_k(expr, 2)" +
-        "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL AS tab(expr);")
+        "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS 
tab(expr);")
+    checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+  }
+
+  test("SPARK-53947: null is not in top k") {
+    val res = sql(
+      "SELECT approx_top_k(expr, 2) FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL 
AS tab(expr)"
+    )
     checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
   }
 
+  test("SPARK-53947: null is the last in top k") {
+    val res = sql(
+      "SELECT approx_top_k(expr, 3) FROM VALUES 0, 0, 1, 1, 1, NULL AS 
tab(expr)"
+    )
+    checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+  }
+
+  test("SPARK-53947: null + frequent items < k") {
+    val res = sql(
+      """SELECT approx_top_k(expr, 5)
+        |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+        |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS 
DECIMAL(4, 1)),
+        |NULL AS tab(expr)""".stripMargin)
+    checkAnswer(
+      res,
+      Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+        Row(new java.math.BigDecimal("0.0"), 2),
+        Row(null, 1))))
+  }
+
+  test("SPARK-53947: work on typed column with only NULL values") {
+    val res = sql(
+      "SELECT approx_top_k(expr) FROM VALUES cast(NULL AS INT), cast(NULL AS 
INT) AS tab(expr)"
+    )
+    checkAnswer(res, Row(Seq(Row(null, 2))))
+  }
+
+  test("SPARK-53947: invalid item void columns") {
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        sql("SELECT approx_top_k(expr) FROM VALUES (NULL), (NULL), (NULL) AS 
tab(expr)")
+      },
+      condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+      parameters = Map(
+        "sqlExpr" -> "\"approx_top_k(expr, 5, 10000)\"",
+        "msg" -> "void columns are not supported",
+        "hint" -> ""
+      ),
+      queryContext = Array(ExpectedContext("approx_top_k(expr)", 7, 24))
+    )
+  }
+
   /////////////////////////////////
   // approx_top_k_accumulate and
   // approx_top_k_estimate tests


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to