This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 1c408c31941b [SPARK-52023][SQL][3.5] Fix data corruption/segfault
returning Option[Product] from udaf
1c408c31941b is described below
commit 1c408c31941baf005be6f5bc294128b2ac177815
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Wed Jul 2 06:51:40 2025 -0700
[SPARK-52023][SQL][3.5] Fix data corruption/segfault returning
Option[Product] from udaf
### What changes were proposed in this pull request?
This fixes the behavior so that defining a udaf returning an `Option[Product]` produces
correct results instead of the current behavior, where it throws an exception,
segfaults, or produces incorrect results.
### Why are the changes needed?
Fix correctness issue.
### Does this PR introduce _any_ user-facing change?
Fixes a correctness issue.
### How was this patch tested?
Existing and new unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51347 from eejbyfeldt/3.5-SPARK-52023.
Authored-by: Emil Ejbyfeldt <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/execution/aggregate/udaf.scala | 2 +-
.../spark/sql/hive/execution/UDAQuerySuite.scala | 28 ++++++++++++++++++++++
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e517376bc5fc..fe6307b5bbe8 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](
def eval(buffer: BUF): Any = {
val row = outputSerializer(agg.finish(buffer))
- if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+ if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0,
dataType)
}
private[this] lazy val bufferRow = new
UnsafeRow(bufferEncoder.namedExpressions.length)
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 0bd6b1403d39..31d0452c7061 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong,
jlLong), Long, jlLong] {
def outputEncoder: Encoder[jlLong] = Encoders.LONG
}
+final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i:
Encoder[Option[T]])
+ extends Aggregator[T, Option[T], T] {
+ def zero: Option[T] = None
+ def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))
+ def merge(b1: Option[T], b2: Option[T]): Option[T] =
+ (b1, b2) match {
+ case (Some(a), Some(b)) => Some(r(a, b))
+ case (Some(a), None) => Some(a)
+ case (None, Some(b)) => Some(b)
+ case (None, None) => None
+ }
+ def finish(reduction: Option[T]): T = reduction.get
+ def bufferEncoder: Encoder[Option[T]] = implicitly
+ def outputEncoder: Encoder[T] = implicitly
+}
+
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
@@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with
SQLTestUtils with TestHiveSi
val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
data4.write.saveAsTable("agg4")
+ val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
+ data5.write.saveAsTable("agg5")
+
val emptyDF = spark.createDataFrame(
sparkContext.emptyRDD[Row],
StructType(StructField("key", StringType) :: StructField("value",
IntegerType) :: Nil))
@@ -190,6 +209,9 @@ abstract class UDAQuerySuite extends QueryTest with
SQLTestUtils with TestHiveSi
spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
spark.udf.register("longProductSum", udaf(LongProductSumAgg))
spark.udf.register("arraysum", udaf(ArrayDataAgg))
+ spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
+ (opt1, opt2) =>
+ opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2)
}.headOption)))
}
override def afterAll(): Unit = {
@@ -371,6 +393,12 @@ abstract class UDAQuerySuite extends QueryTest with
SQLTestUtils with TestHiveSi
Row(Seq(12.0, 15.0, 18.0)) :: Nil)
}
+ test("SPARK-52023: Returning Option[Product] from udaf") {
+ checkAnswer(
+ spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"),
+ Row(Row(2, 3)) :: Nil)
+ }
+
test("verify aggregator ser/de behavior") {
val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
val agg = udaf(CountSerDeAgg)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]