This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a8cfe0c682c5 [SPARK-53960][SQL] Let approx_top_k_accumulate/combine/estimate handle NULLs
a8cfe0c682c5 is described below
commit a8cfe0c682c57d2c7346fed20a0c2e76a2c7ff99
Author: yhuang-db <[email protected]>
AuthorDate: Tue Oct 21 08:49:30 2025 -0700
[SPARK-53960][SQL] Let approx_top_k_accumulate/combine/estimate handle NULLs
### What changes were proposed in this pull request?
As a follow-up of https://github.com/apache/spark/pull/52655, add NULL handling in approx_top_k_accumulate/estimate/combine.
### Why are the changes needed?
So that approx_top_k_accumulate/combine/estimate count NULL values consistently with approx_top_k.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests on null handling for accumulate, combine and estimate.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52673 from yhuang-db/accumulate_estimate_count_null.
Authored-by: yhuang-db <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../expressions/ApproxTopKExpressions.scala | 12 +-
.../aggregate/ApproxTopKAggregates.scala | 114 ++----
.../org/apache/spark/sql/ApproxTopKSuite.scala | 436 +++++++++++++++------
3 files changed, 363 insertions(+), 199 deletions(-)
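At a glance: NULLs now count toward the top-k results across the whole accumulate/combine/estimate pipeline. A minimal illustration in the test suite's own idiom (`sql` and `checkAnswer` come from Spark's QueryTest; the data and expected rows mirror the new tests in the diff below):

```scala
// Four NULLs outrank three 'b's, so NULL appears first in the top-2.
val res = sql(
  """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
    |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr)""".stripMargin)
checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
```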
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3e2d12fc5b17..53c37f0a5491 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -17,14 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.datasketches.frequencies.ItemsSketch
-import org.apache.datasketches.memory.Memory
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAggregateBuffer}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -105,9 +102,10 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
val kVal = kEval.asInstanceOf[Int]
ApproxTopK.checkK(kVal)
ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
- val itemsSketch = ItemsSketch.getInstance(
- Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
- ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
+ val approxTopKAggregateBuffer = ApproxTopKAggregateBuffer.deserialize(
+ dataSketchBytes,
+ ApproxTopK.genSketchSerDe(itemDataType))
+ approxTopKAggregateBuffer.eval(kVal, itemDataType)
}
override protected def withNewChildrenInternal(newState: Expression, newK: Expression)
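The estimate path now delegates to a new ApproxTopKAggregateBuffer, whose definition is not shown in this diff. Judging from its call sites (update/merge/serialize/deserialize/eval), it pairs the frequency sketch with a running NULL count. A simplified, self-contained sketch of that idea, with a plain mutable Map standing in for DataSketches' ItemsSketch (the class and member names below are hypothetical):

```scala
import scala.collection.mutable

// Hypothetical stand-in for ApproxTopKAggregateBuffer: item counts plus
// a NULL counter that survives updates, merges, and (de)serialization.
final class NullCountingBuffer[T](
    val counts: mutable.Map[T, Long] = mutable.Map.empty[T, Long],
    var nullCount: Long = 0L) {

  // Count a value, or bump the NULL counter when the input is null.
  def update(v: Option[T]): this.type = {
    v match {
      case Some(item) => counts.update(item, counts.getOrElse(item, 0L) + 1L)
      case None => nullCount += 1L
    }
    this
  }

  // Merging combines both the item counts and the NULL counts.
  def merge(other: NullCountingBuffer[T]): this.type = {
    other.counts.foreach { case (item, c) =>
      counts.update(item, counts.getOrElse(item, 0L) + c)
    }
    nullCount += other.nullCount
    this
  }

  // Evaluation ranks NULL alongside ordinary items by frequency.
  def topK(k: Int): Seq[(Option[T], Long)] = {
    val items = counts.toSeq.map { case (item, c) => (Some(item): Option[T], c) }
    val all = if (nullCount > 0) items :+ ((None: Option[T]) -> nullCount) else items
    all.sortBy(-_._2).take(k)
  }
}
```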
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 1fca8ad86bc2..7ae542f190d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -260,54 +260,6 @@ object ApproxTopK {
}
}
- def updateSketchBuffer(
- itemExpression: Expression,
- buffer: ItemsSketch[Any],
- input: InternalRow): ItemsSketch[Any] = {
- val v = itemExpression.eval(input)
- if (v != null) {
- itemExpression.dataType match {
- case _: BooleanType => buffer.update(v.asInstanceOf[Boolean])
- case _: ByteType => buffer.update(v.asInstanceOf[Byte])
- case _: ShortType => buffer.update(v.asInstanceOf[Short])
- case _: IntegerType => buffer.update(v.asInstanceOf[Int])
- case _: LongType => buffer.update(v.asInstanceOf[Long])
- case _: FloatType => buffer.update(v.asInstanceOf[Float])
- case _: DoubleType => buffer.update(v.asInstanceOf[Double])
- case _: DateType => buffer.update(v.asInstanceOf[Int])
- case _: TimestampType => buffer.update(v.asInstanceOf[Long])
- case _: TimestampNTZType => buffer.update(v.asInstanceOf[Long])
- case st: StringType =>
- val cKey = CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId)
- buffer.update(cKey.toString)
- case _: DecimalType => buffer.update(v.asInstanceOf[Decimal])
- }
- }
- buffer
- }
-
- def genEvalResult(
- itemsSketch: ItemsSketch[Any],
- k: Int,
- itemDataType: DataType): GenericArrayData = {
- val items = itemsSketch.getFrequentItems(ErrorType.NO_FALSE_POSITIVES)
- val resultLength = math.min(items.length, k)
- val result = new Array[AnyRef](resultLength)
- for (i <- 0 until resultLength) {
- val row = items(i)
- itemDataType match {
- case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType |
- _: LongType | _: FloatType | _: DoubleType | _: DecimalType |
- _: DateType | _: TimestampType | _: TimestampNTZType =>
- result(i) = InternalRow.apply(row.getItem, row.getEstimate)
- case _: StringType =>
- val item = UTF8String.fromString(row.getItem.asInstanceOf[String])
- result(i) = InternalRow.apply(item, row.getEstimate)
- }
- }
- new GenericArrayData(result)
- }
-
def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = {
dataType match {
case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]]
@@ -333,7 +285,7 @@ object ApproxTopK {
def dataTypeToDDL(dataType: DataType): String = dataType match {
case _: StringType =>
- // Hide collation information in DDL format
+ // Hide collation information in DDL format, otherwise CollationExpressionWalkerSuite fails
s"item string not null"
case other =>
StructField("item", other, nullable = false).toDDL
@@ -552,7 +504,7 @@ case class ApproxTopKAccumulate(
maxItemsTracked: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[ItemsSketch[Any]]
+ extends TypedImperativeAggregate[ApproxTopKAggregateBuffer[Any]]
with ImplicitCastInputTypes
with BinaryLike[Expression] {
@@ -592,18 +544,23 @@ case class ApproxTopKAccumulate(
override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)
- override def createAggregationBuffer(): ItemsSketch[Any] = {
+ override def createAggregationBuffer(): ApproxTopKAggregateBuffer[Any] = {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
- ApproxTopK.createItemsSketch(expr, maxMapSize)
+ val sketch = ApproxTopK.createItemsSketch(expr, maxMapSize)
+ new ApproxTopKAggregateBuffer[Any](sketch, 0L)
}
- override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
- ApproxTopK.updateSketchBuffer(expr, buffer, input)
+ override def update(buffer: ApproxTopKAggregateBuffer[Any], input: InternalRow):
+ ApproxTopKAggregateBuffer[Any] =
+ buffer.update(expr, input)
- override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+ override def merge(
+ buffer: ApproxTopKAggregateBuffer[Any],
+ input: ApproxTopKAggregateBuffer[Any]):
+ ApproxTopKAggregateBuffer[Any] =
buffer.merge(input)
- override def eval(buffer: ItemsSketch[Any]): Any = {
+ override def eval(buffer: ApproxTopKAggregateBuffer[Any]): Any = {
val sketchBytes = serialize(buffer)
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
InternalRow.apply(
@@ -613,11 +570,11 @@ case class ApproxTopKAccumulate(
UTF8String.fromString(itemDataTypeDDL))
}
- override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
- buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+ override def serialize(buffer: ApproxTopKAggregateBuffer[Any]): Array[Byte] =
+ buffer.serialize(ApproxTopK.genSketchSerDe(itemDataType))
- override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
- ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+ override def deserialize(storageFormat: Array[Byte]): ApproxTopKAggregateBuffer[Any] =
+ ApproxTopKAggregateBuffer.deserialize(storageFormat, ApproxTopK.genSketchSerDe(itemDataType))
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -644,10 +601,10 @@ case class ApproxTopKAccumulate(
* @param maxItemsTracked the maximum number of items tracked in the sketch
*/
class CombineInternal[T](
- sketch: ItemsSketch[T],
+ sketchWithNullCount: ApproxTopKAggregateBuffer[T],
var itemDataType: DataType,
var maxItemsTracked: Int) {
- def getSketch: ItemsSketch[T] = sketch
+ def getSketchWithNullCount: ApproxTopKAggregateBuffer[T] = sketchWithNullCount
def getItemDataType: DataType = itemDataType
@@ -689,6 +646,9 @@ class CombineInternal[T](
}
}
+ def updateSketchWithNullCount(otherSketchWithNullCount: ApproxTopKAggregateBuffer[T]): Unit =
+ sketchWithNullCount.merge(otherSketchWithNullCount)
+
/**
* Serialize the CombineInternal instance to a byte array.
* Serialization format:
@@ -698,18 +658,18 @@ class CombineInternal[T](
* sketchBytes
*/
def serialize(): Array[Byte] = {
- val sketchBytes = sketch.toByteArray(
+ val sketchWithNullCountBytes = sketchWithNullCount.serialize(
ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
val ddlBytes: Array[Byte] = itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
val byteArray = new Array[Byte](
- sketchBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
+ sketchWithNullCountBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
val byteBuffer = ByteBuffer.wrap(byteArray)
byteBuffer.putInt(maxItemsTracked)
byteBuffer.putInt(ddlBytes.length)
byteBuffer.put(ddlBytes)
- byteBuffer.put(sketchBytes)
+ byteBuffer.put(sketchWithNullCountBytes)
byteArray
}
}
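For reference, the byte layout this method writes (maxItemsTracked, DDL length, DDL bytes, then the serialized sketch-plus-null-count) round-trips as in the standalone sketch below; `pack`/`unpack` are hypothetical helpers and `payload` stands in for the bytes produced by the DataSketches serde:

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object CombineStateLayout {
  // Layout: [int maxItemsTracked][int ddlLength][ddlBytes][sketchBytes]
  def pack(maxItemsTracked: Int, ddl: String, payload: Array[Byte]): Array[Byte] = {
    val ddlBytes = ddl.getBytes(StandardCharsets.UTF_8)
    val buf = ByteBuffer.allocate(
      Integer.BYTES + Integer.BYTES + ddlBytes.length + payload.length)
    buf.putInt(maxItemsTracked) // int: max items tracked
    buf.putInt(ddlBytes.length) // int: length of the DDL string
    buf.put(ddlBytes)           // item data type as DDL text
    buf.put(payload)            // serialized sketch + null count
    buf.array()
  }

  def unpack(bytes: Array[Byte]): (Int, String, Array[Byte]) = {
    val buf = ByteBuffer.wrap(bytes)
    val maxItemsTracked = buf.getInt()
    val ddlBytes = new Array[Byte](buf.getInt())
    buf.get(ddlBytes)
    val payload = new Array[Byte](buf.remaining())
    buf.get(payload)
    (maxItemsTracked, new String(ddlBytes, StandardCharsets.UTF_8), payload)
  }
}
```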
@@ -736,9 +696,9 @@ object CombineInternal {
// read sketchBytes
val sketchBytes = new Array[Byte](buffer.length - Integer.BYTES - Integer.BYTES - ddlLength)
byteBuffer.get(sketchBytes)
- val sketch = ItemsSketch.getInstance(
- Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
- new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+ val sketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+ sketchBytes, ApproxTopK.genSketchSerDe(itemDataType))
+ new CombineInternal[Any](sketchWithNullCount, itemDataType, maxItemsTracked)
}
}
@@ -833,7 +793,7 @@ case class ApproxTopKCombine(
if (combineSizeSpecified) {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
new CombineInternal[Any](
- new ItemsSketch[Any](maxMapSize),
+ new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
null,
maxItemsTrackedVal)
} else {
@@ -842,7 +802,7 @@ case class ApproxTopKCombine(
// The actual maxItemsTracked will be checked during the updates.
val maxMapSize = ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
new CombineInternal[Any](
- new ItemsSketch[Any](maxMapSize),
+ new ApproxTopKAggregateBuffer[Any](new ItemsSketch[Any](maxMapSize), 0L),
null,
ApproxTopK.VOID_MAX_ITEMS_TRACKED)
}
@@ -863,9 +823,9 @@ case class ApproxTopKCombine(
// update itemDataType (throw error if not match)
buffer.updateItemDataType(inputItemDataType)
// update sketch
- val inputSketch = ItemsSketch.getInstance(
- Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
- buffer.getSketch.merge(inputSketch)
+ val inputSketchWithNullCount = ApproxTopKAggregateBuffer.deserialize(
+ inputSketchBytes, ApproxTopK.genSketchSerDe(inputItemDataType))
+ buffer.updateSketchWithNullCount(inputSketchWithNullCount)
buffer
}
@@ -876,14 +836,14 @@ case class ApproxTopKCombine(
buffer.updateMaxItemsTracked(combineSizeSpecified, input.getMaxItemsTracked)
// update itemDataType (throw error if not match)
buffer.updateItemDataType(input.getItemDataType)
- // update sketch
- buffer.getSketch.merge(input.getSketch)
+ // update sketchWithNullCount
+ buffer.getSketchWithNullCount.merge(input.getSketchWithNullCount)
buffer
}
override def eval(buffer: CombineInternal[Any]): Any = {
- val sketchBytes =
- buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val sketchBytes = buffer.getSketchWithNullCount
+ .serialize(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
val maxItemsTracked = buffer.getMaxItemsTracked
val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(buffer.getItemDataType)
InternalRow.apply(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index d9d16d1234b7..982c9ff90da7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -381,6 +381,77 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
)
}
+ test("SPARK-53960: accumulate and estimate count NULL values") {
+ val res = sql(
+ """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+ |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL, NULL, NULL, NULL AS tab(expr)""".stripMargin)
+ checkAnswer(res, Row(Seq(Row(null, 4), Row("b", 3))))
+ }
+
+ test("SPARK-53960: accumulate and estimate null is not in top k") {
+ val res = sql(
+ """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+ |FROM VALUES 'a', 'a', 'b', 'b', 'b', NULL AS tab(expr)""".stripMargin)
+ checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2))))
+ }
+
+ test("SPARK-53960: accumulate and estimate null is the last in top k") {
+ val res = sql(
+ """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 3)
+ |FROM VALUES 0, 0, 1, 1, 1, NULL AS tab(expr)""".stripMargin)
+ checkAnswer(res, Row(Seq(Row(1, 3), Row(0, 2), Row(null, 1))))
+ }
+
+ test("SPARK-53960: accumulate and estimate null + frequent items < k") {
+ val res = sql(
+ """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 5)
+ |FROM VALUES cast(0.0 AS DECIMAL(4, 1)), cast(0.0 AS DECIMAL(4, 1)),
+ |cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)), cast(0.1 AS DECIMAL(4, 1)),
+ |NULL AS tab(expr)""".stripMargin)
+ checkAnswer(
+ res,
+ Row(Seq(Row(new java.math.BigDecimal("0.1"), 3),
+ Row(new java.math.BigDecimal("0.0"), 2),
+ Row(null, 1))))
+ }
+
+ test("SPARK-53960: accumulate and estimate work on typed column with only
NULL values") {
+ val res = sql(
+ """SELECT approx_top_k_estimate(approx_top_k_accumulate(expr))
+ |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)""".stripMargin)
+ checkAnswer(res, Row(Seq(Row(null, 2))))
+ }
+
+ test("SPARK-53960: accumulate a column of all nulls with type - success") {
+ withView("accumulation") {
+ val res = sql(
+ """SELECT approx_top_k_accumulate(expr) AS acc
+ |FROM VALUES cast(NULL AS INT), cast(NULL AS INT) AS tab(expr)""".stripMargin)
+
+ assert(res.collect().length == 1)
+ res.createOrReplaceTempView("accumulation")
+ val est = sql("SELECT approx_top_k_estimate(acc) FROM accumulation;")
+ checkAnswer(est, Row(Seq(Row(null, 2))))
+
+ }
+ }
+
+ test("SPARK-53960: accumulate a column of all nulls without type - fail") {
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ sql("""SELECT approx_top_k_accumulate(expr)
+ |FROM VALUES (NULL), (NULL), (NULL), (NULL) AS
tab(expr)""".stripMargin)
+ },
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ parameters = Map(
+ "sqlExpr" -> "\"approx_top_k_accumulate(expr, 10000)\"",
+ "msg" -> "void columns are not supported",
+ "hint" -> ""
+ ),
+ queryContext = Array(ExpectedContext("approx_top_k_accumulate(expr)", 7, 35))
+ )
+ }
+
/////////////////////////////////
// approx_top_k_combine
/////////////////////////////////
@@ -445,75 +516,87 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
// positive tests for approx_top_k_combine on every types
gridTest("SPARK-52798: same type, same size, specified combine size -
success")(itemsWithTopK) {
case (input, expected) =>
- sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
- .createOrReplaceTempView("accumulation1")
- sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
- .createOrReplaceTempView("accumulation2")
- sql("SELECT approx_top_k_combine(acc, 30) as com " +
- "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2);")
- .createOrReplaceTempView("combined")
- val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
- // expected should be doubled because we combine two identical sketches
- val expectedDoubled = expected.map {
- case Row(value: Any, count: Int) => Row(value, count * 2)
+ withView("accumulation1", "accumulation2", "combines") {
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input
AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input
AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+ sql("SELECT approx_top_k_combine(acc, 30) as com " +
+ "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2);")
+ .createOrReplaceTempView("combined")
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ // expected should be doubled because we combine two identical sketches
+ val expectedDoubled = expected.map {
+ case Row(value: Any, count: Int) => Row(value, count * 2)
+ }
+ checkAnswer(est, Row(expectedDoubled))
}
- checkAnswer(est, Row(expectedDoubled))
}
test("SPARK-52798: same type, same size, specified combine size - success") {
- setupMixedSizeAccumulations(10, 10)
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ setupMixedSizeAccumulations(10, 10)
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
- .createOrReplaceTempView("combined")
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
- val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
- checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
}
test("SPARK-52798: same type, same size, unspecified combine size -
success") {
- setupMixedSizeAccumulations(10, 10)
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ setupMixedSizeAccumulations(10, 10)
- sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
- .createOrReplaceTempView("combined")
+ sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+ .createOrReplaceTempView("combined")
- val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
- checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
}
test("SPARK-52798: same type, different size, specified combine size -
success") {
- setupMixedSizeAccumulations(10, 20)
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ setupMixedSizeAccumulations(10, 20)
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
- .createOrReplaceTempView("combination")
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
- val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
- checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
}
test("SPARK-52798: same type, different size, unspecified combine size -
fail") {
- setupMixedSizeAccumulations(10, 20)
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedSizeAccumulations(10, 20)
- val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
-
- checkError(
- exception = intercept[SparkRuntimeException] {
- comb.collect()
- },
- condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
- parameters = Map("size1" -> "10", "size2" -> "20")
- )
- }
+ val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
- gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10,
20))) {
- case (size1, size2) =>
- setupMixedSizeAccumulations(size1, size2)
checkError(
exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc, 0) as com FROM
unioned").collect()
+ comb.collect()
},
- condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
- parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+ condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ parameters = Map("size1" -> "10", "size2" -> "20")
)
+ }
+ }
+
+ gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10,
20))) {
+ case (size1, size2) =>
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedSizeAccumulations(size1, size2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 0) as com FROM
unioned").collect()
+ },
+ condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+ parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+ )
+ }
}
test("SPARK-52798: among different number or datetime types - fail at
combine") {
@@ -523,13 +606,15 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
val (type1, _, seq1) = mixedTypeSeq(i)
val (type2, _, seq2) = mixedTypeSeq(j)
setupMixedTypeAccumulation(seq1, seq2)
- checkError(
- exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
- },
- condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
- parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(type2))
- )
+ withView("accumulation1", "accumulation2", "unioned") {
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(type2))
+ )
+ }
}
}
}
@@ -547,7 +632,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
case ((_, type1, seq1), (_, type2, seq2)) =>
checkError(
exception = intercept[ExtendedAnalysisException] {
- setupMixedTypeAccumulation(seq1, seq2)
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(seq1, seq2)
+ }
},
condition = "INCOMPATIBLE_COLUMN_TYPE",
parameters = Map(
@@ -568,14 +655,17 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
gridTest("SPARK-52798: number vs string - fail at
combine")(mixedNumberTypes) {
case (type1, _, seq1) =>
- setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
- checkError(
- exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
- },
- condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
- parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
- )
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(
+ seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
}
gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes)
{
@@ -583,7 +673,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
checkError(
exception = intercept[ExtendedAnalysisException] {
- setupMixedTypeAccumulation(seq1, seq2)
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(seq1, seq2)
+ }
},
condition = "INCOMPATIBLE_COLUMN_TYPE",
parameters = Map(
@@ -604,14 +696,17 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
gridTest("SPARK-52798: datetime vs string - fail at
combine")(mixedDateTimeTypes) {
case (type1, _, seq1) =>
- setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
- checkError(
- exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
- },
- condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
- parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
- )
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(
+ seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
}
gridTest("SPARK-52798: datetime vs boolean - fail at
UNION")(mixedDateTimeTypes) {
@@ -619,7 +714,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
checkError(
exception = intercept[ExtendedAnalysisException] {
- setupMixedTypeAccumulation(seq1, seq2)
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(seq1, seq2)
+ }
},
condition = "INCOMPATIBLE_COLUMN_TYPE",
parameters = Map(
@@ -641,65 +738,174 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
test("SPARK-52798: string vs boolean - fail at combine") {
val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
- setupMixedTypeAccumulation(seq1, seq2)
- checkError(
- exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
- },
- condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
- parameters = Map("type1" -> toSQLType(StringType), "type2" ->
toSQLType(BooleanType))
- )
+ withView("accumulation1", "accumulation2", "unioned") {
+ setupMixedTypeAccumulation(seq1, seq2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(StringType), "type2" ->
toSQLType(BooleanType))
+ )
+ }
}
test("SPARK-52798: combine more than 2 sketches with specified size") {
- sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
- "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
- .createOrReplaceTempView("accumulation1")
+ withView("accumulation1", "accumulation2", "accumulation3", "unioned",
"combined") {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
- sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
- "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
- .createOrReplaceTempView("accumulation2")
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
- sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
- "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
- .createOrReplaceTempView("accumulation3")
+ sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+ "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+ .createOrReplaceTempView("accumulation3")
- sql("SELECT acc from accumulation1 UNION ALL " +
- "SELECT acc FROM accumulation2 UNION ALL " +
- "SELECT acc FROM accumulation3")
- .createOrReplaceTempView("unioned")
+ sql("SELECT acc from accumulation1 UNION ALL " +
+ "SELECT acc FROM accumulation2 UNION ALL " +
+ "SELECT acc FROM accumulation3")
+ .createOrReplaceTempView("unioned")
- sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
- .createOrReplaceTempView("combined")
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
- val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
- checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), Row(4, 2))))
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), Row(4, 2))))
+ }
}
test("SPARK-52798: combine more than 2 sketches without specified size") {
- sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
- "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
- .createOrReplaceTempView("accumulation1")
+ withView("accumulation1", "accumulation2", "accumulation3", "unioned") {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
- sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
- "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
- .createOrReplaceTempView("accumulation2")
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
- sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
- "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
- .createOrReplaceTempView("accumulation3")
+ sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+ "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+ .createOrReplaceTempView("accumulation3")
- sql("SELECT acc from accumulation1 UNION ALL " +
- "SELECT acc FROM accumulation2 UNION ALL " +
- "SELECT acc FROM accumulation3")
- .createOrReplaceTempView("unioned")
+ sql("SELECT acc from accumulation1 UNION ALL " +
+ "SELECT acc FROM accumulation2 UNION ALL " +
+ "SELECT acc FROM accumulation3")
+ .createOrReplaceTempView("unioned")
- checkError(
- exception = intercept[SparkRuntimeException] {
- sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
- },
- condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
- parameters = Map("size1" -> "10", "size2" -> "20")
- )
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ parameters = Map("size1" -> "10", "size2" -> "20")
+ )
+ }
+ }
+
+ test("SPARK-53960: combine and estimate count NULL values") {
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES 'a', 'a', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation1")
+
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES 'b', 'b', NULL, NULL AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+
+ sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+ checkAnswer(est, Row(Seq(Row(null, 4), Row("b", 3))))
+ }
+ }
+
+ test("SPARK-53960: combine with a sketch of all nulls") {
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+ |AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation1")
+
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES 1, 1, 2, 2 AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+
+ sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+ checkAnswer(est, Row(Seq(Row(null, 3), Row(2, 2), Row(1, 2))))
+ }
+ }
+
+ test("SPARK-53960: combine sketches with nulls from more than 2 sketches") {
+ withView("accumulation1", "accumulation2", "accumulation3", "unioned",
"combined") {
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES 0, 0, 0, 1, 1, NULL AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation1")
+
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES NULL, 1, 1, 2, 2, NULL AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation2")
+
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES 2, 3, 3, NULL AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation3")
+
+ sql(
+ """SELECT acc from accumulation1 UNION ALL
+ |SELECT acc FROM accumulation2 UNION ALL
+ |SELECT acc FROM accumulation3""".stripMargin)
+ .createOrReplaceTempView("unioned")
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com, 2) FROM combined")
+ checkAnswer(est, Row(Seq(Row(1, 4), Row(null, 4))))
+ }
+ }
+
+ test("SPARK-53960: combine 2 sketches with all nulls") {
+ withView("accumulation1", "accumulation2", "unioned", "combined") {
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES cast(NULL AS INT), cast(NULL AS INT), cast(NULL AS INT)
+ |AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation1")
+
+ sql(
+ """SELECT approx_top_k_accumulate(expr, 10) as acc
+ |FROM VALUES cast(NULL AS INT), cast(NULL AS INT)
+ |AS tab(expr)""".stripMargin)
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+
+ sql("SELECT approx_top_k_combine(acc, 20) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined")
+ checkAnswer(est, Row(Seq(Row(null, 5))))
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]