This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1c0bca91d270 [SPARK-52798][SQL] Add function approx_top_k_combine
1c0bca91d270 is described below
commit 1c0bca91d270940800085748cd501e5421b69bb1
Author: yhuang-db <[email protected]>
AuthorDate: Fri Oct 17 19:51:44 2025 -0700
[SPARK-52798][SQL] Add function approx_top_k_combine
### What changes were proposed in this pull request?
This PR adds a SQL function: `approx_top_k_combine`, an aggregation
function that merges multiple sketches into a single sketch.
**Syntax**
```sql
approx_top_k_combine(expr[, maxItemsTracked])
```
**Arguments**
- `expr`: An expression of sketch state structs, as produced by `approx_top_k_accumulate`
- `maxItemsTracked`: An optional INTEGER literal. If maxItemsTracked is
specified, that value is used for the newly generated combined sketch. If
maxItemsTracked is not specified, all input sketches must have the same
maxItemsTracked, and the output sketch uses that same value (see the sketch below).
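A minimal sketch of the two calling conventions, adapted from the tests in this patch (`unioned` is an assumed view holding accumulated sketches in a column `acc`):
```sql
-- Explicit size: the combined sketch tracks up to 30 items.
SELECT approx_top_k_combine(acc, 30) AS com FROM unioned;

-- Size omitted: all input sketches must share the same maxItemsTracked,
-- which the combined sketch then inherits.
SELECT approx_top_k_combine(acc) AS com FROM unioned;
```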
**Returns**
The return of this function is a STRUCT with four fields: sketch,
maxItemsTracked, itemDataType and itemDataTypeDDL. The schema is exactly the
same as the output of approx_top_k_accumulate.
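Since the result is an ordinary STRUCT column, its fields can be inspected directly; for example (a hypothetical query, assuming a `combined` view with the combine result in a column `com`):
```sql
SELECT com.maxItemsTracked, com.itemDataTypeDDL FROM combined;
```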
### Why are the changes needed?
It is a useful sibling function for `approx_top_k` queries, allowing sketches to be accumulated separately and combined later.
### Does this PR introduce _any_ user-facing change?
Yes, this PR introduces a new user-facing SQL function. See the user example
below.
```sql
SELECT approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5)
FROM (
  SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr)
  UNION ALL
  SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)
)
```
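For these inputs, the expected result, as documented in the function's `ExpressionDescription` example in this patch, is:
```
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
```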
### How was this patch tested?
Unit tests for end-to-end SQL queries and invalid input for expressions.
### Was this patch authored or co-authored using generative AI tooling?
Closes #51505 from yhuang-db/SPARK-52798.
Authored-by: yhuang-db <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 12 +
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/ApproxTopKExpressions.scala | 31 +-
.../aggregate/ApproxTopKAggregates.scala | 346 ++++++++++++++++++++-
.../spark/sql/errors/QueryExecutionErrors.scala | 16 +
.../expressions/aggregate/ApproxTopKSuite.scala | 192 ++++++++++--
.../sql-functions/sql-expression-schema.md | 1 +
.../org/apache/spark/sql/ApproxTopKSuite.scala | 325 +++++++++++++++++++
8 files changed, 873 insertions(+), 51 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index b215438c89d1..448f0e926a6d 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,6 +114,18 @@
],
"sqlState" : "22004"
},
+ "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH" : {
+ "message" : [
+ "Combining approx_top_k sketches of different sizes is not allowed.
Found sketches of size <size1> and <size2>."
+ ],
+ "sqlState" : "42846"
+ },
+ "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH" : {
+ "message" : [
+ "Combining approx_top_k sketches of different types is not allowed.
Found sketches of type <type1> and <type2>."
+ ],
+ "sqlState" : "42846"
+ },
"ARITHMETIC_OVERFLOW" : {
"message" : [
"<message>.<alternative> If necessary set <config> to \"false\" to
bypass this error."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b643d2a11864..0a596a8bd63e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -534,6 +534,7 @@ object FunctionRegistry {
expression[ThetaUnionAgg]("theta_union_agg"),
expression[ThetaIntersectionAgg]("theta_intersection_agg"),
expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
+ expression[ApproxTopKCombine]("approx_top_k_combine"),
// string functions
expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3c9440764a9a..3e2d12fc5b17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -66,8 +66,8 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_K))
private lazy val itemDataType: DataType = {
- // itemDataType is the type of the second field of the output of ACCUMULATE or COMBINE
- state.dataType.asInstanceOf[StructType](1).dataType
+ // itemDataType is the type of the third field of the output of ACCUMULATE or COMBINE
+ state.dataType.asInstanceOf[StructType](2).dataType
}
override def left: Expression = state
@@ -76,35 +76,12 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
- private def checkStateFieldAndType(state: Expression): TypeCheckResult = {
- val stateStructType = state.dataType.asInstanceOf[StructType]
- if (stateStructType.length != 3) {
- return TypeCheckFailure("State must be a struct with 3 fields. " +
- "Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
- "Got: " + state.dataType.simpleString)
- }
-
- if (stateStructType.head.dataType != BinaryType) {
- TypeCheckFailure("State struct must have the first field to be binary. " +
- "Got: " + stateStructType.head.dataType.simpleString)
- } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
- TypeCheckFailure("State struct must have the second field to be a supported data type. " +
- "Got: " + itemDataType.simpleString)
- } else if (stateStructType(2).dataType != IntegerType) {
- TypeCheckFailure("State struct must have the third field to be int. " +
- "Got: " + stateStructType(2).dataType.simpleString)
- } else {
- TypeCheckSuccess
- }
- }
-
-
override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else {
- val stateCheck = checkStateFieldAndType(state)
+ val stateCheck = ApproxTopK.checkStateFieldAndType(state)
if (stateCheck.isFailure) {
stateCheck
} else if (!k.foldable) {
@@ -124,7 +101,7 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
val stateEval = left.eval(input)
val kEval = right.eval(input)
val dataSketchBytes = stateEval.asInstanceOf[InternalRow].getBinary(0)
- val maxItemsTrackedVal = stateEval.asInstanceOf[InternalRow].getInt(2)
+ val maxItemsTrackedVal = stateEval.asInstanceOf[InternalRow].getInt(1)
val kVal = kEval.asInstanceOf[Int]
ApproxTopK.checkK(kVal)
ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index cefe0a14dee5..6c6b3b805048 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+
import org.apache.datasketches.common._
import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch}
import org.apache.datasketches.memory.Memory
@@ -176,7 +179,9 @@ object ApproxTopK {
val DEFAULT_K: Int = 5
val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
- private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+ val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+ // A special value indicating no explicit maxItemsTracked input in function approx_top_k_combine
+ val VOID_MAX_ITEMS_TRACKED = -1
def checkExpressionNotNull(expr: Expression, exprName: String): Unit = {
if (expr == null || expr.eval() == null) {
@@ -317,8 +322,51 @@ object ApproxTopK {
def getSketchStateDataType(itemDataType: DataType): StructType =
StructType(
StructField("sketch", BinaryType, nullable = false) ::
+ StructField("maxItemsTracked", IntegerType, nullable = false) ::
StructField("itemDataType", itemDataType) ::
- StructField("maxItemsTracked", IntegerType, nullable = false) :: Nil)
+ StructField("itemDataTypeDDL", StringType, nullable = false) :: Nil)
+
+ def dataTypeToDDL(dataType: DataType): String = dataType match {
+ case _: StringType =>
+ // Hide collation information in DDL format
+ s"item string not null"
+ case other =>
+ StructField("item", other, nullable = false).toDDL
+ }
+
+ def DDLToDataType(ddl: String): DataType = {
+ StructType.fromDDL(ddl).fields.head.dataType
+ }
+
+ def checkStateFieldAndType(state: Expression): TypeCheckResult = {
+ val stateStructType = state.dataType.asInstanceOf[StructType]
+ if (stateStructType.length != 4) {
+ return TypeCheckFailure("State must be a struct with 4 fields. " +
+ "Expected struct: " +
+ "struct<sketch:binary,maxItemsTracked:int,itemDataType:any,itemDataTypeDDL:string>. " +
+ "Got: " + state.dataType.simpleString)
+ }
+
+ val fieldType1 = stateStructType.head.dataType
+ val fieldType2 = stateStructType(1).dataType
+ val fieldType3 = stateStructType(2).dataType
+ val fieldType4 = stateStructType(3).dataType
+ if (fieldType1 != BinaryType) {
+ TypeCheckFailure("State struct must have the first field to be binary. " +
+ "Got: " + fieldType1.simpleString)
+ } else if (fieldType2 != IntegerType) {
+ TypeCheckFailure("State struct must have the second field to be int. " +
+ "Got: " + fieldType2.simpleString)
+ } else if (!ApproxTopK.isDataTypeSupported(fieldType3)) {
+ TypeCheckFailure("State struct must have the third field to be a supported data type. " +
+ "Got: " + fieldType3.simpleString)
+ } else if (fieldType4 != StringType) {
+ TypeCheckFailure("State struct must have the fourth field to be string. " +
+ "Got: " + fieldType4.simpleString)
+ } else {
+ TypeCheckSuccess
+ }
+ }
}
/**
@@ -327,8 +375,11 @@ object ApproxTopK {
* or to estimate the top K items, via ApproxTopKEstimate.
*
* The output of this function is a struct containing the sketch in binary format,
+ * the maximum number of items tracked by the sketch,
* a null object indicating the type of items in the sketch,
- * and the maximum number of items tracked by the sketch.
+ * and a DDL string representing the data type of items in the sketch.
+ * The null object is used in approx_top_k_estimate,
+ * while the DDL is used in approx_top_k_combine.
*
* @param expr the child expression to accumulate items from
* @param maxItemsTracked the maximum number of items to track in the sketch
@@ -410,7 +461,12 @@ case class ApproxTopKAccumulate(
override def eval(buffer: ItemsSketch[Any]): Any = {
val sketchBytes = serialize(buffer)
- InternalRow.apply(sketchBytes, null, maxItemsTrackedVal)
+ val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
+ InternalRow.apply(
+ sketchBytes,
+ maxItemsTrackedVal,
+ null,
+ UTF8String.fromString(itemDataTypeDDL))
}
override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
@@ -435,3 +491,285 @@ case class ApproxTopKAccumulate(
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}
+
+/**
+ * An internal class used as the aggregation buffer for ApproxTopKCombine.
+ *
+ * @param sketch the ItemsSketch instance
+ * @param itemDataType the data type of items in the sketch
+ * @param maxItemsTracked the maximum number of items tracked in the sketch
+ */
+class CombineInternal[T](
+ sketch: ItemsSketch[T],
+ var itemDataType: DataType,
+ var maxItemsTracked: Int) {
+ def getSketch: ItemsSketch[T] = sketch
+
+ def getItemDataType: DataType = itemDataType
+
+ def getMaxItemsTracked: Int = maxItemsTracked
+
+ def updateMaxItemsTracked(combineSizeSpecified: Boolean, newMaxItemsTracked: Int): Unit = {
+ if (!combineSizeSpecified) {
+ // check size
+ if (this.maxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+ // If the buffer's maxItemsTracked is VOID_MAX_ITEMS_TRACKED, the buffer is a placeholder
+ // sketch that has not been updated by any input sketch yet.
+ // So we can set it to the input sketch's max items tracked.
+ this.maxItemsTracked = newMaxItemsTracked
+ } else {
+ if (this.maxItemsTracked != newMaxItemsTracked) {
+ // If the buffer's maxItemsTracked is not VOID_MAX_ITEMS_TRACKED, the buffer has been
+ // updated by some input sketch. So if the buffer and an input sketch have different
+ // maxItemsTracked values, at least two of the input sketches have different
+ // maxItemsTracked values. In this case, we should throw an error.
+ throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+ this.maxItemsTracked, newMaxItemsTracked)
+ }
+ }
+ }
+ }
+
+ def updateItemDataType(inputItemDataType: DataType): Unit = {
+ // When the buffer's dataType hasn't been set, set it to the input sketch's item data type.
+ // When the input sketch's item data type is null, the buffer's item data type will remain null.
+ if (this.itemDataType == null) {
+ this.itemDataType = inputItemDataType
+ } else {
+ // When the buffer's dataType has been set, throw an error
+ // if the input sketch's item data type is not null and the two data types don't match.
+ if (inputItemDataType != null && this.itemDataType != inputItemDataType) {
+ throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+ this.itemDataType, inputItemDataType)
+ }
+ }
+ }
+
+ /**
+ * Serialize the CombineInternal instance to a byte array.
+ * Serialization format:
+ * maxItemsTracked (4-byte int) +
+ * itemDataTypeDDL length n in bytes (4-byte int) +
+ * itemDataTypeDDL (n bytes) +
+ * sketchBytes
+ */
+ def serialize(): Array[Byte] = {
+ val sketchBytes = sketch.toByteArray(
+ ApproxTopK.genSketchSerDe(itemDataType).asInstanceOf[ArrayOfItemsSerDe[T]])
+ val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(itemDataType)
+ val ddlBytes: Array[Byte] = itemDataTypeDDL.getBytes(StandardCharsets.UTF_8)
+ val byteArray = new Array[Byte](
+ sketchBytes.length + Integer.BYTES + Integer.BYTES + ddlBytes.length)
+
+ val byteBuffer = ByteBuffer.wrap(byteArray)
+ byteBuffer.putInt(maxItemsTracked)
+ byteBuffer.putInt(ddlBytes.length)
+ byteBuffer.put(ddlBytes)
+ byteBuffer.put(sketchBytes)
+ byteArray
+ }
+}
+
+object CombineInternal {
+ /**
+ * Deserialize a byte array to a CombineInternal instance.
+ * Serialization format:
+ * maxItemsTracked (4-byte int) +
+ * itemDataTypeDDL length n in bytes (4-byte int) +
+ * itemDataTypeDDL (n bytes) +
+ * sketchBytes
+ */
+ def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+ val byteBuffer = ByteBuffer.wrap(buffer)
+ // read maxItemsTracked
+ val maxItemsTracked = byteBuffer.getInt
+ // read itemDataTypeDDL
+ val ddlLength = byteBuffer.getInt
+ val ddlBytes = new Array[Byte](ddlLength)
+ byteBuffer.get(ddlBytes)
+ val itemDataTypeDDL = new String(ddlBytes, StandardCharsets.UTF_8)
+ val itemDataType = ApproxTopK.DDLToDataType(itemDataTypeDDL)
+ // read sketchBytes
+ val sketchBytes = new Array[Byte](buffer.length - Integer.BYTES - Integer.BYTES - ddlLength)
+ byteBuffer.get(sketchBytes)
+ val sketch = ItemsSketch.getInstance(
+ Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
+ new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+ }
+}
+
+/**
+ * An aggregate function that combines multiple sketches into a single sketch.
+ *
+ * @param state the expression containing the sketches to combine
+ * @param maxItemsTracked the maximum number of items to track in the sketch
+ * @param mutableAggBufferOffset the offset for mutable aggregation buffer
+ * @param inputAggBufferOffset the offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch.
+ `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is specified, it will be set for the combined sketch. If maxItemsTracked is not specified, the input sketches must have the same maxItemsTracked value, otherwise an error will be thrown. The output sketch will use the same value from the input sketches.
+ """,
+ examples = """
+ Examples:
+ > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
+ [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKCombine(
+ state: Expression,
+ maxItemsTracked: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[CombineInternal[Any]]
+ with ImplicitCastInputTypes
+ with BinaryLike[Expression] {
+
+ def this(child: Expression, maxItemsTracked: Expression) = {
+ this(child, maxItemsTracked, 0, 0)
+ ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+ ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+ }
+
+ def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))
+
+ // If maxItemsTracked is not specified, set it to VOID_MAX_ITEMS_TRACKED.
+ // This indicates that there is no explicit maxItemsTracked input from the function call.
+ // Hence, the function needs to check the input sketches' maxItemsTracked values during merge.
+ def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+
+ // The item data type extracted from the third field of the state struct.
+ // It is named "unchecked" because it may be inaccurate when input sketches have different
+ // item data types. For example, if one sketch has int type null and another has bigint type
+ // null, the union of the two sketches will have bigint type null.
+ // The accurate item data type will be tracked in the aggregation buffer during update/merge.
+ // It is okay to use uncheckedItemDataType to create the output data type of this function,
+ // because if the input sketches have different item data types, an error will be thrown
+ // during update/merge. Otherwise, the uncheckedItemDataType is accurate.
+ private lazy val uncheckedItemDataType: DataType =
+ state.dataType.asInstanceOf[StructType](2).dataType
+ private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
+ private lazy val combineSizeSpecified: Boolean =
+ maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED
+
+ override def left: Expression = state
+
+ override def right: Expression = maxItemsTracked
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ val stateCheck = ApproxTopK.checkStateFieldAndType(state)
+ if (stateCheck.isFailure) {
+ stateCheck
+ } else if (!maxItemsTracked.foldable) {
+ TypeCheckFailure("Number of items tracked must be a constant literal")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+ }
+
+ override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)
+
+ /**
+ * If maxItemsTracked is specified in the function call, use it for the output sketch.
+ * Otherwise, create a placeholder sketch with VOID_MAX_ITEMS_TRACKED. The actual value will be
+ * decided during the first update.
+ */
+ override def createAggregationBuffer(): CombineInternal[Any] = {
+ if (combineSizeSpecified) {
+ val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ maxItemsTrackedVal)
+ } else {
+ // If maxItemsTracked is not specified, create a sketch with the maximum allowed size.
+ // No need to worry about memory waste, as the sketch always grows from a small init size.
+ // The actual maxItemsTracked will be checked during the updates.
+ val maxMapSize = ApproxTopK.calMaxMapSize(ApproxTopK.MAX_ITEMS_TRACKED_LIMIT)
+ new CombineInternal[Any](
+ new ItemsSketch[Any](maxMapSize),
+ null,
+ ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+ }
+ }
+
+ /**
+ * Update the aggregation buffer with an input sketch. The input has the same schema as the
+ * ApproxTopKAccumulate output, i.e., sketchBytes + maxItemsTracked + null + DDL.
+ */
+ override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = {
+ val inputState = state.eval(input).asInstanceOf[InternalRow]
+ val inputSketchBytes = inputState.getBinary(0)
+ val inputMaxItemsTracked = inputState.getInt(1)
+ val inputItemDataTypeDDL = inputState.getUTF8String(3).toString
+ val inputItemDataType = ApproxTopK.DDLToDataType(inputItemDataTypeDDL)
+ // update maxItemsTracked (throws an error on mismatch)
+ buffer.updateMaxItemsTracked(combineSizeSpecified, inputMaxItemsTracked)
+ // update itemDataType (throws an error on mismatch)
+ buffer.updateItemDataType(inputItemDataType)
+ // update sketch
+ val inputSketch = ItemsSketch.getInstance(
+ Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ buffer.getSketch.merge(inputSketch)
+ buffer
+ }
+
+ override def merge(
+ buffer: CombineInternal[Any],
+ input: CombineInternal[Any]): CombineInternal[Any] = {
+ // update maxItemsTracked (throws an error on mismatch)
+ buffer.updateMaxItemsTracked(combineSizeSpecified, input.getMaxItemsTracked)
+ // update itemDataType (throws an error on mismatch)
+ buffer.updateItemDataType(input.getItemDataType)
+ // update sketch
+ buffer.getSketch.merge(input.getSketch)
+ buffer
+ }
+
+ override def eval(buffer: CombineInternal[Any]): Any = {
+ val sketchBytes =
+ buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+ val maxItemsTracked = buffer.getMaxItemsTracked
+ val itemDataTypeDDL = ApproxTopK.dataTypeToDDL(buffer.getItemDataType)
+ InternalRow.apply(
+ sketchBytes,
+ maxItemsTracked,
+ null,
+ UTF8String.fromString(itemDataTypeDDL))
+ }
+
+ override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+ buffer.serialize()
+ }
+
+ override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+ CombineInternal.deserialize(buffer)
+ }
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): Expression =
+ copy(state = newLeft, maxItemsTracked = newRight)
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ override def nullable: Boolean = false
+
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index a6c5bbf91eb0..f62ffe2a8e60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2833,6 +2833,22 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"limit" -> toSQLValue(limit, IntegerType)))
}
+ def approxTopKSketchSizeNotMatch(size1: Int, size2: Int): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ messageParameters = Map(
+ "size1" -> toSQLValue(size1, IntegerType),
+ "size2" -> toSQLValue(size2, IntegerType)))
+ }
+
+ def approxTopKSketchTypeNotMatch(type1: DataType, type2: DataType): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ messageParameters = Map(
+ "type1" -> toSQLType(type1),
+ "type2" -> toSQLType(type2)))
+ }
+
def mergeCardinalityViolationError(): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "MERGE_CARDINALITY_VIOLATION",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
index 2b339003abd4..731c9259a8a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
@@ -137,18 +137,14 @@ class ApproxTopKSuite extends SparkFunSuite {
// ApproxTopKEstimate tests
/////////////////////////////
- val stateStructType: StructType = StructType(Seq(
- StructField("sketch", BinaryType),
- StructField("itemDataType", IntegerType),
- StructField("maxItemsTracked", IntegerType)
- ))
-
test("SPARK-52588: invalid estimate if k are not foldable") {
val badEstimate = ApproxTopKEstimate(
state = BoundReference(0, StructType(Seq(
StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
StructField("itemDataType", IntegerType),
- StructField("maxItemsTracked", IntegerType))), nullable = false),
+ StructField("itemDataTypeDDL", StringType)
+ )), nullable = false),
k = Sum(BoundReference(1, IntegerType, nullable = true))
)
assert(badEstimate.checkInputDataTypes().isFailure)
@@ -184,11 +180,11 @@ class ApproxTopKSuite extends SparkFunSuite {
)
}
- test("SPARK-52588: invalid estimate if state struct length is not 3") {
+ test("SPARK-52588: invalid estimate if state struct length is not 4") {
val invalidState = StructType(Seq(
StructField("sketch", BinaryType),
- StructField("itemDataType", IntegerType)
- // Missing "maxItemsTracked"
+ StructField("maxItemsTracked", IntegerType)
+ // Missing "itemDataType", "itemDataTypeDDL" fields
))
val badEstimate = ApproxTopKEstimate(
state = BoundReference(0, invalidState, nullable = false),
@@ -196,16 +192,18 @@ class ApproxTopKSuite extends SparkFunSuite {
)
assert(badEstimate.checkInputDataTypes().isFailure)
assert(badEstimate.checkInputDataTypes() ==
- TypeCheckFailure("State must be a struct with 3 fields. " +
- "Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
- "Got: struct<sketch:binary,itemDataType:int>"))
+ TypeCheckFailure("State must be a struct with 4 fields. " +
+ "Expected struct: " +
+ "struct<sketch:binary,maxItemsTracked:int,itemDataType:any,itemDataTypeDDL:string>. " +
+ "Got: struct<sketch:binary,maxItemsTracked:int>"))
}
test("SPARK-52588: invalid estimate if state struct's first field is not
binary") {
val invalidState = StructType(Seq(
StructField("notSketch", IntegerType), // Should be BinaryType
+ StructField("maxItemsTracked", IntegerType),
StructField("itemDataType", IntegerType),
- StructField("maxItemsTracked", IntegerType)
+ StructField("itemDataTypeDDL", StringType)
))
val badEstimate = ApproxTopKEstimate(
state = BoundReference(0, invalidState, nullable = false),
@@ -216,7 +214,23 @@ class ApproxTopKSuite extends SparkFunSuite {
TypeCheckFailure("State struct must have the first field to be binary.
Got: int"))
}
- gridTest("SPARK-52588: invalid estimate if state struct's second field is
not supported")(
+ test("SPARK-52588: invalid estimate if state struct's second field is not
int") {
+ val invalidState = StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", LongType), // Should be IntegerType
+ StructField("itemDataType", IntegerType),
+ StructField("itemDataTypeDDL", StringType)
+ ))
+ val badEstimate = ApproxTopKEstimate(
+ state = BoundReference(0, invalidState, nullable = false),
+ k = Literal(5)
+ )
+ assert(badEstimate.checkInputDataTypes().isFailure)
+ assert(badEstimate.checkInputDataTypes() ==
+ TypeCheckFailure("State struct must have the second field to be int.
Got: bigint"))
+ }
+
+ gridTest("SPARK-52588: invalid estimate if state struct's third field is not
supported")(
Seq(
("array<int>", ArrayType(IntegerType)),
("map<string,int>", MapType(StringType, IntegerType)),
@@ -226,8 +240,9 @@ class ApproxTopKSuite extends SparkFunSuite {
val (typeName, dataType) = unSupportedType
val invalidState = StructType(Seq(
StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
StructField("itemDataType", dataType),
- StructField("maxItemsTracked", IntegerType)
+ StructField("itemDataTypeDDL", StringType)
))
val badEstimate = ApproxTopKEstimate(
state = BoundReference(0, invalidState, nullable = false),
@@ -235,15 +250,16 @@ class ApproxTopKSuite extends SparkFunSuite {
)
assert(badEstimate.checkInputDataTypes().isFailure)
assert(badEstimate.checkInputDataTypes() ==
- TypeCheckFailure(s"State struct must have the second field to be a
supported data type. " +
+ TypeCheckFailure(s"State struct must have the third field to be a
supported data type. " +
s"Got: $typeName"))
}
- test("SPARK-52588: invalid estimate if state struct's third field is not
int") {
+ test("SPARK-52588: invalid estimate if state struct's fourth field is not
string") {
val invalidState = StructType(Seq(
StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
StructField("itemDataType", IntegerType),
- StructField("maxItemsTracked", LongType) // Should be IntegerType
+ StructField("itemDataTypeDDL", BinaryType) // Should be StringType
))
val badEstimate = ApproxTopKEstimate(
state = BoundReference(0, invalidState, nullable = false),
@@ -251,6 +267,142 @@ class ApproxTopKSuite extends SparkFunSuite {
)
assert(badEstimate.checkInputDataTypes().isFailure)
assert(badEstimate.checkInputDataTypes() ==
- TypeCheckFailure("State struct must have the third field to be int. Got:
bigint"))
+ TypeCheckFailure("State struct must have the fourth field to be string.
Got: binary"))
+ }
+
+ /////////////////////////////
+ // ApproxTopKCombine tests
+ /////////////////////////////
+ test("SPARK-52798: invalid combine if maxItemsTracked is not foldable") {
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
+ StructField("itemDataType", IntegerType),
+ StructField("itemDataTypeDDL", StringType)
+ )), nullable = false),
+ maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true))
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(1),
+ "requiredType" -> "\"INT\"",
+ "inputSql" -> "\"sum(boundreference())\"",
+ "inputType" -> "\"BIGINT\""
+ )
+ )
+ )
+ }
+
+ test("SPARK-52798: invalid combine if state is not a struct") {
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, IntegerType, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> "\"STRUCT\"",
+ "inputSql" -> "\"boundreference()\"",
+ "inputType" -> "\"INT\""
+ )
+ )
+ )
+ }
+
+ test("SPARK-52798: invalid combine if state struct length is not 4") {
+ val invalidState = StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType)
+ // Missing "itemDataType", "itemDataTypeDDL" fields
+ ))
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, invalidState, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ TypeCheckFailure("State must be a struct with 4 fields. " +
+ "Expected struct: " +
+
"struct<sketch:binary,maxItemsTracked:int,itemDataType:any,itemDataTypeDDL:string>.
" +
+ "Got: struct<sketch:binary,maxItemsTracked:int>"))
+ }
+
+ test("SPARK-52798: invalid combine if state struct's first field is not
binary") {
+ val invalidState = StructType(Seq(
+ StructField("sketch", IntegerType), // Should be BinaryType
+ StructField("maxItemsTracked", IntegerType),
+ StructField("itemDataType", IntegerType),
+ StructField("itemDataTypeDDL", StringType)
+ ))
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, invalidState, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ TypeCheckFailure("State struct must have the first field to be binary.
Got: int"))
+ }
+
+ test("SPARK-52798: invalid combine if state struct's second field is not
int") {
+ val invalidState = StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", LongType), // Should be IntegerType
+ StructField("itemDataType", IntegerType),
+ StructField("itemDataTypeDDL", StringType)
+ ))
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, invalidState, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ TypeCheckFailure("State struct must have the second field to be int.
Got: bigint"))
+ }
+
+ gridTest("SPARK-52798: invalid combine if state struct's third field is not
supported")(
+ Seq(
+ ("array<int>", ArrayType(IntegerType)),
+ ("map<string,int>", MapType(StringType, IntegerType)),
+ ("struct<a:int>", StructType(Seq(StructField("a", IntegerType)))),
+ ("binary", BinaryType)
+ )) { unSupportedType =>
+ val (typeName, dataType) = unSupportedType
+ val invalidState = StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
+ StructField("itemDataType", dataType),
+ StructField("itemDataTypeDDL", StringType)
+ ))
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, invalidState, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ TypeCheckFailure(s"State struct must have the third field to be a
supported data type. " +
+ s"Got: $typeName"))
+ }
+
+ test("SPARK-52798: invalid combine if state struct's fourth field is not
string") {
+ val invalidState = StructType(Seq(
+ StructField("sketch", BinaryType),
+ StructField("maxItemsTracked", IntegerType),
+ StructField("itemDataType", IntegerType),
+ StructField("itemDataTypeDDL", BinaryType) // Should be StringType
+ ))
+ val badCombine = ApproxTopKCombine(
+ state = BoundReference(0, invalidState, nullable = false),
+ maxItemsTracked = Literal(10)
+ )
+ assert(badCombine.checkInputDataTypes().isFailure)
+ assert(badCombine.checkInputDataTypes() ==
+ TypeCheckFailure("State struct must have the fourth field to be string.
Got: binary"))
}
}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index a34455f24e01..f192a020f576 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -412,6 +412,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue | any_value | SELECT any_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<any_value(col):int> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK | approx_top_k | SELECT approx_top_k(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct<approx_top_k(expr, 5, 10000):array<struct<item:int,count:bigint>>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKAccumulate | approx_top_k_accumulate | SELECT approx_top_k_estimate(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct<approx_top_k_estimate(approx_top_k_accumulate(expr, 10000), 5):array<struct<item:int,count:bigint>>> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKCombine | approx_top_k_combine | SELECT approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)) | struct<approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5):array<struct<item:int,count:bigint>>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | approx_percentile | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<approx_percentile(col, array(0.5, 0.4, 0.1), 100):array<int>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | percentile_approx | SELECT percentile_approx(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<percentile_approx(col, array(0.5, 0.4, 0.1), 100):array<int>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Average | avg | SELECT avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<avg(col):double> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 8219fce9b217..702f361ace28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -21,7 +21,10 @@ import java.sql.{Date, Timestamp}
import java.time.LocalDateTime
import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType}
class ApproxTopKSuite extends QueryTest with SharedSparkSession {
@@ -328,4 +331,326 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
)
}
+
+ /////////////////////////////////
+ // approx_top_k_combine
+ /////////////////////////////////
+
+ def setupMixedSizeAccumulations(size1: Int, size2: Int): Unit = {
+ sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+ }
+
+ def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2")
+ .createOrReplaceTempView("unioned")
+ }
+
+ val mixedNumberTypes: Seq[(DataType, String, Seq[Any])] = Seq(
+ (IntegerType, "INT",
+ Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+ (ByteType, "TINYINT",
+ Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+ (ShortType, "SMALLINT",
+ Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+ (LongType, "BIGINT",
+ Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+ (FloatType, "FLOAT",
+ Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+ (DoubleType, "DOUBLE",
+ Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+ (DecimalType(4, 2), "DECIMAL(4,2)",
+ Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")),
+ (DecimalType(10, 2), "DECIMAL(10,2)",
+ Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")),
+ (DecimalType(20, 3), "DECIMAL(20,3)",
+ Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))"))
+ )
+
+ val mixedDateTimeTypes: Seq[(DataType, String, Seq[String])] = Seq(
+ (DateType, "DATE",
+ Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+ (TimestampType, "TIMESTAMP",
+ Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")),
+ (TimestampNTZType, "TIMESTAMP_NTZ",
+ Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'"))
+ )
+
+ // positive tests for approx_top_k_combine on all types
+ gridTest("SPARK-52798: same type, same size, specified combine size - success")(itemsWithTopK) {
+ case (input, expected) =>
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+ sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS
tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+ sql("SELECT approx_top_k_combine(acc, 30) as com " +
+ "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2);")
+ .createOrReplaceTempView("combined")
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ // expected should be doubled because we combine two identical sketches
+ val expectedDoubled = expected.map {
+ case Row(value: Any, count: Int) => Row(value, count * 2)
+ }
+ checkAnswer(est, Row(expectedDoubled))
+ }
+
+ test("SPARK-52798: same type, same size, specified combine size - success") {
+ setupMixedSizeAccumulations(10, 10)
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, same size, unspecified combine size -
success") {
+ setupMixedSizeAccumulations(10, 10)
+
+ sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, different size, specified combine size -
success") {
+ setupMixedSizeAccumulations(10, 20)
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combination")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
+ checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+ }
+
+ test("SPARK-52798: same type, different size, unspecified combine size -
fail") {
+ setupMixedSizeAccumulations(10, 20)
+
+ val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ comb.collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ parameters = Map("size1" -> "10", "size2" -> "20")
+ )
+ }
+
+ gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10,
20))) {
+ case (size1, size2) =>
+ setupMixedSizeAccumulations(size1, size2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 0) as com FROM
unioned").collect()
+ },
+ condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+ parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+ )
+ }
+
+ test("SPARK-52798: among different number or datetime types - fail at
combine") {
+ def checkMixedTypeError(mixedTypeSeq: Seq[(DataType, String, Seq[Any])]):
Unit = {
+ for (i <- 0 until mixedTypeSeq.size - 1) {
+ for (j <- i + 1 until mixedTypeSeq.size) {
+ val (type1, _, seq1) = mixedTypeSeq(i)
+ val (type2, _, seq2) = mixedTypeSeq(j)
+ setupMixedTypeAccumulation(seq1, seq2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(type2))
+ )
+ }
+ }
+ }
+
+ checkMixedTypeError(mixedNumberTypes)
+ checkMixedTypeError(mixedDateTimeTypes)
+ }
+
+ // enumerate all combinations of number and datetime types
+ gridTest("SPARK-52798: number vs datetime - fail on UNION")(
+ for {
+ (type1, typeName1, seq1) <- mixedNumberTypes
+ (type2, typeName2, seq2) <- mixedDateTimeTypes
+ } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))) {
+ case ((_, type1, seq1), (_, type2, seq2)) =>
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked: INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked: INT NOT NULL, " +
+ "itemDataType: " + type2 + ", itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+ )
+ }
+
+ gridTest("SPARK-52798: number vs string - fail at
combine")(mixedNumberTypes) {
+ case (type1, _, seq1) =>
+ setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
+
+ gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes)
{
+ case (_, type1, seq1) =>
+ val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT
NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: BOOLEAN, itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2", 0, 68))
+ )
+ }
+
+ gridTest("SPARK-52798: datetime vs string - fail at
combine")(mixedDateTimeTypes) {
+ case (type1, _, seq1) =>
+ setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'",
"'c'", "'d'", "'d'"))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(type1), "type2" ->
toSQLType(StringType))
+ )
+ }
+
+ gridTest("SPARK-52798: datetime vs boolean - fail at
UNION")(mixedDateTimeTypes) {
+ case (_, type1, seq1) =>
+ val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+ checkError(
+ exception = intercept[ExtendedAnalysisException] {
+ setupMixedTypeAccumulation(seq1, seq2)
+ },
+ condition = "INCOMPATIBLE_COLUMN_TYPE",
+ parameters = Map(
+ "tableOrdinalNumber" -> "second",
+ "columnOrdinalNumber" -> "first",
+ "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: " + type1 + ", itemDataTypeDDL: STRING NOT
NULL>\""),
+ "operator" -> "UNION",
+ "hint" -> "",
+ "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, maxItemsTracked:
INT NOT NULL, " +
+ "itemDataType: BOOLEAN, itemDataTypeDDL: STRING NOT NULL>\"")
+ ),
+ queryContext = Array(
+ ExpectedContext(
+ "SELECT acc from accumulation1 UNION ALL SELECT acc FROM
accumulation2", 0, 68))
+ )
+ }
+
+ test("SPARK-52798: string vs boolean - fail at combine") {
+ val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
+ val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+ setupMixedTypeAccumulation(seq1, seq2)
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM
unioned;").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+ parameters = Map("type1" -> toSQLType(StringType), "type2" ->
toSQLType(BooleanType))
+ )
+ }
+
+ test("SPARK-52798: combine more than 2 sketches with specified size") {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+ "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+ .createOrReplaceTempView("accumulation3")
+
+ sql("SELECT acc from accumulation1 UNION ALL " +
+ "SELECT acc FROM accumulation2 UNION ALL " +
+ "SELECT acc FROM accumulation3")
+ .createOrReplaceTempView("unioned")
+
+ sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+ .createOrReplaceTempView("combined")
+
+ val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+ checkAnswer(est, Row(Seq(Row(2, 6), Row(3, 5), Row(1, 4), Row(0, 3), Row(4, 2))))
+ }
+
+ test("SPARK-52798: combine more than 2 sketches without specified size") {
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (0), (0), (0), (1), (1), (2), (2) AS tab(expr);")
+ .createOrReplaceTempView("accumulation1")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+ "FROM VALUES (1), (1), (2), (2), (3), (3), (4) AS tab(expr);")
+ .createOrReplaceTempView("accumulation2")
+
+ sql(s"SELECT approx_top_k_accumulate(expr, 20) as acc " +
+ "FROM VALUES (2), (2), (3), (3), (3), (4), (5) AS tab(expr);")
+ .createOrReplaceTempView("accumulation3")
+
+ sql("SELECT acc from accumulation1 UNION ALL " +
+ "SELECT acc FROM accumulation2 UNION ALL " +
+ "SELECT acc FROM accumulation3")
+ .createOrReplaceTempView("unioned")
+
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("SELECT approx_top_k_combine(acc) as com FROM unioned").collect()
+ },
+ condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+ parameters = Map("size1" -> "10", "size2" -> "20")
+ )
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]