This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b0327f37bfb0 [SPARK-53947][SQL] Count null in approx_top_k
b0327f37bfb0 is described below
commit b0327f37bfb03f80b120995f1d3fc344ba142831
Author: yhuang-db <[email protected]>
AuthorDate: Mon Oct 20 15:42:37 2025 -0700
[SPARK-53947][SQL] Count null in approx_top_k
### What changes were proposed in this pull request?
This PR proposes to add a nullCounter associated with the Frequent Items
Sketch in the `approx_top_k` aggregation, so that the function now returns the
null item and its count when NULL is among the top-k frequent items.
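For example, a minimal illustration mirroring one of the new unit tests:

    SELECT approx_top_k(expr, 2)
    FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr);
    -- now returns [(NULL, 4), ('b', 3)]; previously the NULL values were ignored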
### Why are the changes needed?
NULL values can be meaningful in some use cases, and users may want to
include NULL in the `approx_top_k` output.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests on handling null values.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52655 from yhuang-db/approx_top_k_count_null.
Authored-by: yhuang-db <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../aggregate/ApproxTopKAggregates.scala | 174 +++++++++++++++++++--
.../org/apache/spark/sql/ApproxTopKSuite.scala | 53 ++++++-
2 files changed, 210 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 6c6b3b805048..1fca8ad86bc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -79,7 +79,7 @@ case class ApproxTopK(
maxItemsTracked: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[ItemsSketch[Any]]
+ extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
with ImplicitCastInputTypes
with TernaryLike[Expression] {
@@ -137,25 +137,30 @@ case class ApproxTopK(
override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)
- override def createAggregationBuffer(): ItemsSketch[Any] = {
+ override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
- ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+ val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+ new ApproxTopKAggregateBuffer[Any](sketch, 0L)
}
- override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
- ApproxTopK.updateSketchBuffer(expr, buffer, input)
+ override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
+ ApproxTopKAggregateBuffer[Any] =
+ buffer.update(expr, input)
- override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+ override def merge(
+ buffer: ApproxTopKAggregateBuffer[Any],
+ input: ApproxTopKAggregateBuffer[Any]):
+ ApproxTopKAggregateBuffer[Any] =
buffer.merge(input)
- override def eval(buffer: ItemsSketch[Any]): GenericArrayData =
- ApproxTopK.genEvalResult(buffer, kVal, itemDataType)
+ override def eval(buffer: ApproxTopKAggregateBuffer[Any]): GenericArrayData =
+ buffer.eval(kVal, itemDataType)
- override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
- buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+ override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+ buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))
- override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
- ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+ override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
+ ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -214,7 +219,7 @@ object ApproxTopK {
def getResultDataType(itemDataType: DataType): DataType = {
val resultEntryType = StructType(
- StructField("item", itemDataType, nullable = false) ::
+ StructField("item", itemDataType, nullable = true) ::
StructField("count", LongType, nullable = false) :: Nil)
ArrayType(resultEntryType, containsNull = false)
}
@@ -238,7 +243,7 @@ object ApproxTopK {
math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
}
- def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
+ def createItemsSketch(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
itemExpression.dataType match {
case _: BooleanType =>
new ItemsSketch[Boolean](maxMapSize).asInstanceOf[ItemsSketch[Any]]
@@ -369,6 +374,145 @@ object ApproxTopK {
}
}
+/**
+ * An internal class used as the aggregation buffer for ApproxTopK.
+ *
+ * @param sketch the ItemsSketch instance for counting non-null items
+ * @param nullCount the count of null items
+ */
+class ApproxTopKAggregateBuffer[T](val sketch: ItemsSketch[T], private var nullCount: Long) {
+ def update(itemExpression: Expression, input: InternalRow): ApproxTopKAggregateBuffer[T] = {
+ val v = itemExpression.eval(input)
+ if (v != null) {
+ itemExpression.dataType match {
+ case _: BooleanType =>
+ sketch.asInstanceOf[ItemsSketch[Boolean]].update(v.asInstanceOf[Boolean])
+ case _: ByteType =>
+ sketch.asInstanceOf[ItemsSketch[Byte]].update(v.asInstanceOf[Byte])
+ case _: ShortType =>
+ sketch.asInstanceOf[ItemsSketch[Short]].update(v.asInstanceOf[Short])
+ case _: IntegerType =>
+ sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+ case _: LongType =>
+ sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+ case _: FloatType =>
+ sketch.asInstanceOf[ItemsSketch[Float]].update(v.asInstanceOf[Float])
+ case _: DoubleType =>
+ sketch.asInstanceOf[ItemsSketch[Double]].update(v.asInstanceOf[Double])
+ case _: DateType =>
+ sketch.asInstanceOf[ItemsSketch[Int]].update(v.asInstanceOf[Int])
+ case _: TimestampType =>
+ sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+ case _: TimestampNTZType =>
+ sketch.asInstanceOf[ItemsSketch[Long]].update(v.asInstanceOf[Long])
+ case st: StringType =>
+ val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
+ sketch.asInstanceOf[ItemsSketch[String]].update(cKey.toString)
+ case _: DecimalType =>
+ sketch.asInstanceOf[ItemsSketch[Decimal]].update(v.asInstanceOf[Decimal])
+ }
+ } else {
+ nullCount += 1
+ }
+ this
+ }
+
+ def merge(other: ApproxTopKAggregateBuffer[T]): ApproxTopKAggregateBuffer[T] = {
+ sketch.merge(other.sketch)
+ nullCount += other.nullCount
+ this
+ }
+
+ /**
+ * Serialize the buffer into bytes.
+ * The format is:
+ * [sketch bytes][null count (8 bytes Long)]
+ */
+ def serialize(serDe: ArrayOfItemsSerDe[T]): Array[Byte] = {
+ val sketchBytes = sketch.toByteArray(serDe)
+ val result = new Array[Byte](sketchBytes.length + java.lang.Long.BYTES)
+ val byteBuffer = java.nio.ByteBuffer.wrap(result)
+ byteBuffer.put(sketchBytes)
+ byteBuffer.putLong(nullCount)
+ result
+ }
+
+ /**
+ * Evaluate the buffer and return top K items (including null) with their estimated frequency.
+ * The result is sorted by frequency in descending order.
+ */
+ def eval(k: Int, itemDataType: DataType): GenericArrayData = {
+ // frequent items from sketch
+ val frequentItems = sketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
+ // total number of frequent items (including null, if any)
+ val itemsLength = frequentItems.length + (if (nullCount > 0) 1 else 0)
+ // actual number of items to return
+ val resultLength = math.min(itemsLength, k)
+ val result = new Array[AnyRef](resultLength)
+
+ // variable pointers for merging frequent items and nullCount into result
+ var fiIndex = 0 // pointer for frequentItems
+ var resultIndex = 0 // pointer for result
+ var isNullAdded = false // whether nullCount has been added to result
+
+ // helper function to get the nullCount estimate: if nullCount has been added, return Long.MinValue
+ // so that it won't be added again; otherwise return nullCount
+ @inline def getNullEstimate: Long = if (!isNullAdded) nullCount else Long.MinValue
+
+ // loop until the result is full or we run out of frequent items
+ while (resultIndex < resultLength && fiIndex < frequentItems.length) {
+ val curFrequentItem = frequentItems(fiIndex)
+ val itemEstimate = curFrequentItem.getEstimate
+ val nullEstimate = getNullEstimate
+
+ val (item, estimate) = if (nullEstimate > itemEstimate) {
+ // insert (null, nullCount) into result
+ isNullAdded = true
+ (null, nullCount.toLong)
+ } else {
+ // insert frequent item into result
+ val item: Any = itemDataType match {
+ case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
+ _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
+ _: DateType | _: TimestampType | _: TimestampNTZType =>
+ curFrequentItem.getItem
+ case _: StringType =>
+ UTF8String.fromString(curFrequentItem.getItem.asInstanceOf[String])
+ }
+ fiIndex += 1 // move to next frequent item
+ (item, itemEstimate)
+ }
+ result(resultIndex) = InternalRow(item, estimate)
+ resultIndex += 1 // move to next result position
+ }
+
+ // in case there is still space in the result and null (with nullCount > 0) has not been added
+ if (resultIndex < resultLength && nullCount > 0 && !isNullAdded) {
+ result(resultIndex) = InternalRow(null, nullCount.toLong)
+ }
+
+ new GenericArrayData(result)
+ }
+}
+
+object ApproxTopKAggregateBuffer {
+ /**
+ * Deserialize the buffer from bytes.
+ * The format is:
+ * [sketch bytes][null count (8 bytes)]
+ */
+ def deserialize(bytes: Array[Byte], serDe: ArrayOfItemsSerDe[Any]):
+ ApproxTopKAggregateBuffer[Any] = {
+ val byteBuffer = java.nio.ByteBuffer.wrap(bytes)
+ val sketchBytesLength = bytes.length - 8
+ val sketchBytes = new Array[Byte](sketchBytesLength)
+ byteBuffer.get(sketchBytes, 0, sketchBytesLength)
+ val nullCount = byteBuffer.getLong(sketchBytesLength)
+ val deserializedSketch = ItemsSketch.getInstance(Memory.wrap(sketchBytes), serDe)
+ new ApproxTopKAggregateBuffer[Any](deserializedSketch, nullCount)
+ }
+}
+
/**
* An aggregate function that accumulates items into a sketch, which can then be used
* to combine with other sketches, via ApproxTopKCombine,
@@ -450,7 +594,7 @@ case class ApproxTopKAccumulate(
override def createAggregationBuffer(): ItemsSketch[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
- ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+ ApproxTopK.createItemsSketch(expr, maxMapSize)
}
override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 702f361ace28..d9d16d1234b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -196,13 +196,62 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
)
}
- test("SPARK-52515: does not count NULL values") {
+ test("SPARK-53947: count NULL values") {
val res = sql(
"SELECT approx_top_k(expr, 2)" +
- "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL AS tab(expr);")
+ "FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS
tab(expr);")
+ checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+ }
+
+ test("SPARK-53947: null is not in top k") {
+ val res = sql(
+ "SELECT approx_top_k(expr, 2) FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL
AS tab(expr)"
+ )
checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
}
+ test("SPARK-53947: null is the last in top k") {
+ val res = sql(
+ "SELECT approx_top_k(expr, 3) FROM VALUES 0, 0, 1, 1, 1, NULL AS
tab(expr)"
+ )
+ checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+ }
+
+ test("SPARK-53947: null + frequent items < k") {
+ val res = sql(
+ """SELECT approx_top_k(expr, 5)
+ |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+ |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)),
+ |NULL AS tab(expr)""".stripMargin)
+ checkAnswer(
+ res,
+ Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+ Row(new java.math.BigDecimal("0.0"), 2),
+ Row(null, 1))))
+ }
+
+ test("SPARK-53947: work on typed column with only NULL values") {
+ val res = sql(
+ "SELECT approx_top_k(expr) FROM VALUES cast(NULL AS INT), cast(NULL AS
INT) AS tab(expr)"
+ )
+ checkAnswer(res, Row(Seq(Row(null, 2))))
+ }
+
+ test("SPARK-53947: invalid item void columns") {
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ sql("SELECT approx_top_k(expr) FROM VALUES (NULL), (NULL), (NULL) AS
tab(expr)")
+ },
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ parameters = Map(
+ "sqlExpr" -> "\"approx_top_k(expr, 5, 10000)\"",
+ "msg" -> "void columns are not supported",
+ "hint" -> ""
+ ),
+ queryContext = Array(ExpectedContext("approx_top_k(expr)", 7, 24))
+ )
+ }
+
/////////////////////////////////
// approx_top_k_accumulate and
// approx_top_k_estimate tests
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]