This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a4aa4adf4fd0 [SPARK-54918][SQL] Normalize floating numbers in array
set operations
a4aa4adf4fd0 is described below
commit a4aa4adf4fd0f790d7b71b322292050f6cf8db38
Author: Albert Sugranyes <[email protected]>
AuthorDate: Wed May 27 15:13:56 2026 +0800
[SPARK-54918][SQL] Normalize floating numbers in array set operations
### What changes were proposed in this pull request?
Extends `NormalizeFloatingNumbers` Catalyst optimizer rule to normalize
floating numbers in hash-based array set operations:
- `array_distinct`
- `array_union`
- `array_intersect`
- `array_except`
- `arrays_overlap`
### Why are the changes needed?
These expressions rely on hash-based set semantics for element comparison,
which distinguish -0.0 from 0.0. Under Spark SQL semantics, these values are
equivalent, so the resulting sets violate the expected algebraic properties of
the set operations.
Examples:
```scala
// Before fix: returns [0.0, -0.0, 1.0]
// After fix: returns [0.0, 1.0]
Seq(Array(0.0, -0.0, 1.0))
.toDF("values")
.selectExpr("array_distinct(values)")
.show()
// Before fix: returns [0.0, -0.0]
// After fix: returns [0.0]
Seq((Array(0.0), Array(-0.0)))
.toDF("a", "b")
.selectExpr("array_union(a, b)")
.show()
// Before fix: returns []
// After fix: returns [0.0]
Seq((Array(0.0, 1.0), Array(-0.0, 2.0)))
.toDF("a", "b")
.selectExpr("array_intersect(a, b)")
.show()
// Before fix: returns [0.0, 1.0]
// After fix: returns [1.0]
Seq((Array(0.0, 1.0), Array(-0.0)))
.toDF("a", "b")
.selectExpr("array_except(a, b)")
.show()
```
### Does this PR introduce _any_ user-facing change?
Yes. Hash-based array set operations now treat -0.0/0.0 and different NaN
representations as equal, consistent with the current behavior in joins, window
partitions and aggregates.
### How was this patch tested?
- 10 unit tests in `NormalizeFloatingPointNumbersSuite` covering logical
plan rewrites for all 5 operations plus idempotence.
- 11 end-to-end tests in `DataFrameFunctionsSuite` verifying runtime bit
patterns via `Double.doubleToRawLongBits`, because IEEE 754 defines 0.0 == -0.0
as true, so a plain equality assertion cannot tell the two zeros apart.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53695 from asugranyes/SPARK-54918.
Authored-by: Albert Sugranyes <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/collectionOperations.scala | 21 ++-
.../optimizer/NormalizeFloatingNumbers.scala | 95 ++++++++-----
.../spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
.../spark/sql/catalyst/trees/TreePatterns.scala | 5 +
.../NormalizeFloatingPointNumbersSuite.scala | 118 +++++++++++++++-
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 152 +++++++++++++++++++++
6 files changed, 348 insertions(+), 44 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b0396188bcdd..85172f795744 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -32,12 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
- ARRAYS_ZIP,
- CONCAT,
- MAP_FROM_ENTRIES,
- TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_DISTINCT,
ARRAY_EXCEPT, ARRAY_INTERSECT, ARRAY_UNION, ARRAYS_OVERLAP, ARRAYS_ZIP, CONCAT,
MAP_FROM_ENTRIES, TreePattern}
import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType,
PhysicalIntegralType}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -1809,6 +1804,9 @@ case class ArrayAppend(left: Expression, right:
Expression) extends ArrayPendBas
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
extends BinaryArrayExpressionWithImplicitCast with Predicate {
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_OVERLAP)
+
override def nullIntolerant: Boolean = true
override def checkInputDataTypes(): TypeCheckResult =
super.checkInputDataTypes() match {
@@ -4235,6 +4233,9 @@ trait ArraySetLike {
since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_DISTINCT)
+
override def nullIntolerant: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
@@ -4431,6 +4432,8 @@ trait ArrayBinaryLike
case class ArrayUnion(left: Expression, right: Expression) extends
ArrayBinaryLike
with ComplexTypeMergingExpression {
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_UNION)
+
@transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
@@ -4608,6 +4611,8 @@ case class ArrayUnion(left: Expression, right:
Expression) extends ArrayBinaryLi
case class ArrayIntersect(left: Expression, right: Expression) extends
ArrayBinaryLike
with ComplexTypeMergingExpression {
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_INTERSECT)
+
private lazy val internalDataType: DataType = {
dataTypeCheck
ArrayType(elementType, leftArrayElementNullable &&
rightArrayElementNullable)
@@ -4823,7 +4828,7 @@ case class ArrayIntersect(left: Expression, right:
Expression) extends ArrayBina
}
/**
- * Returns an array of the elements in the intersect of x and y, without
duplicates
+ * Returns an array of the elements in x but not in y, without duplicates
*/
@ExpressionDescription(
usage = """
@@ -4840,6 +4845,8 @@ case class ArrayIntersect(left: Expression, right:
Expression) extends ArrayBina
case class ArrayExcept(left: Expression, right: Expression) extends
ArrayBinaryLike
with ComplexTypeMergingExpression {
+ final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_EXCEPT)
+
private lazy val internalDataType: DataType = {
dataTypeCheck
left.dataType
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 776efbed273e..44add1796169 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform,
CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo,
ExpectsInputTypes, Expression, GetStructField, If, IsNull,
KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable,
TransformValues, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayDistinct,
ArrayExcept, ArrayIntersect, ArraysOverlap, ArrayTransform, ArrayUnion,
CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo,
ExpectsInputTypes, Expression, GetStructField, If, IsNull,
KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable,
TransformValues, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -31,59 +31,82 @@ import org.apache.spark.util.ArrayImplicits._
* We need to take care of special floating numbers (NaN and -0.0) in several
places:
* 1. When compare values, different NaNs should be treated as same, `-0.0`
and `0.0` should be
* treated as same.
- * 2. In aggregate grouping keys, different NaNs should belong to the same
group, -0.0 and 0.0
+ * 2. In aggregate grouping keys, different NaNs should belong to the same
group, `-0.0` and `0.0`
* should belong to the same group.
* 3. In join keys, different NaNs should be treated as same, `-0.0` and
`0.0` should be
* treated as same.
- * 4. In window partition keys, different NaNs should belong to the same
partition, -0.0 and 0.0
- * should belong to the same partition.
+ * 4. In window partition keys, different NaNs should belong to the same
partition, `-0.0`
+ * and `0.0` should belong to the same partition.
+ * 5. In hash-based array set operations, different NaNs should be treated
as same, `-0.0`
+ * and `0.0` should be treated as same.
*
- * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For
complex types, we
+ * Case 1 is fine, as we handle NaN and `-0.0` well during comparison. For
complex types, we
* recursively compare the fields/elements, so it's also fine.
*
* Case 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window
partition keys into
* binary `UnsafeRow` and compare the binary data directly. Different NaNs
have different binary
- * representation, and the same thing happens for -0.0 and 0.0.
+ * representation, and the same thing happens for `-0.0` and `0.0`.
*
- * This rule normalizes NaN and -0.0 in window partition keys, join keys and
aggregate grouping
- * keys.
+ * Case 5 is problematic for a similar reason: hash-based array set operations
compare elements by
+ * their binary representation via hash sets.
+ *
+ * This rule runs in two places:
+ * 1. Early in `FinishAnalysis` (right after `ReplaceExpressions` and
before `EvalInlineTables`)
+ * so that array set-like operations are wrapped before optimizer rules
that pre-evaluate
+ * expressions (e.g. `ConstantFolding`, `ConvertToLocalRelation`,
`EvalInlineTables`).
+ *
+ * 2. As a late batch at the end of the optimizer, because rules like
subquery rewrite and
+ * join reorder can create new joins or join conditions after
`FinishAnalysis` that still
+ * need their keys to be normalized.
*
* Ideally we should do the normalization in the physical operators that
compare the
* binary `UnsafeRow` directly. We don't need this normalization if the Spark
SQL execution engine
* is not optimized to run on binary data. This rule is created to simplify
the implementation, so
* that we have a single place to do normalization, which is more maintainable.
*
- * Note that, this rule must be executed at the end of optimizer, because the
optimizer may create
- * new joins(the subquery rewrite) and new join conditions(the join reorder).
*/
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan match {
- case _ => plan.transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
- case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
- // Although the `windowExpressions` may refer to `partitionSpec`
expressions, we don't need
- // to normalize the `windowExpressions`, as they are executed per
input row and should take
- // the input row as it is.
- w.copy(partitionSpec = w.partitionSpec.map(normalize))
-
- // Only hash join and sort merge join need the normalization. Here we
catch all Joins with
- // join keys, assuming Joins with join keys are always planned as hash
join or sort merge
- // join. It's very unlikely that we will break this assumption in the
near future.
- case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _,
_)
- // The analyzer guarantees left and right joins keys are of the same
data type. Here we
- // only need to check join keys of one side.
- if leftKeys.exists(k => needNormalize(k)) =>
- val newLeftJoinKeys = leftKeys.map(normalize)
- val newRightJoinKeys = rightKeys.map(normalize)
- val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
- case (l, r) => EqualTo(l, r)
- } ++ condition
- j.copy(condition = Some(newConditions.reduce(And)))
-
- // TODO: ideally Aggregate should also be handled here, but its grouping
expressions are
- // mixed in its aggregate expressions. It's unreliable to change the
grouping expressions
- // here. For now we normalize grouping expressions in `AggUtils` during
planning.
- }
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan
+ .transformWithPruning( _.containsAnyPattern(WINDOW, JOIN)) {
+ case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
+ // Although the `windowExpressions` may refer to `partitionSpec`
expressions,
+ // we don't need to normalize the `windowExpressions`, as they are
executed
+ // per input row and should take the input row as it is.
+ w.copy(partitionSpec = w.partitionSpec.map(normalize))
+
+ // Only hash join and sort merge join need the normalization. Here we
catch all Joins with
+ // join keys, assuming Joins with join keys are always planned as hash
join or sort merge
+ // join. It's very unlikely that we will break this assumption in the
near future.
+ case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _,
_, _)
+ // The analyzer guarantees left and right joins keys are of the
same data type. Here we
+ // only need to check join keys of one side.
+ if leftKeys.exists(k => needNormalize(k)) =>
+ val newLeftJoinKeys = leftKeys.map(normalize)
+ val newRightJoinKeys = rightKeys.map(normalize)
+ val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
+ case (l, r) => EqualTo(l, r)
+ } ++ condition
+ j.copy(condition = Some(newConditions.reduce(And)))
+
+ // TODO: ideally Aggregate should also be handled here, but its
grouping expressions are
+ // mixed in its aggregate expressions. It's unreliable to change the
grouping expressions
+ // here. For now we normalize grouping expressions in `AggUtils`
during planning.
+ }
+ .transformAllExpressionsWithPruning(_.containsAnyPattern(
+ ARRAY_DISTINCT, ARRAY_UNION, ARRAY_INTERSECT, ARRAY_EXCEPT,
ARRAYS_OVERLAP)) {
+ case e: ArrayDistinct if needNormalize(e.child) =>
+ e.copy(child = normalize(e.child))
+ case e: ArrayUnion if needNormalize(e.left) =>
+ e.copy(left = normalize(e.left), right = normalize(e.right))
+ case e: ArrayIntersect if needNormalize(e.left) =>
+ e.copy(left = normalize(e.left), right = normalize(e.right))
+ case e: ArrayExcept if needNormalize(e.left) =>
+ e.copy(left = normalize(e.left), right = normalize(e.right))
+ case e: ArraysOverlap if needNormalize(e.left) =>
+ e.copy(left = normalize(e.left), right = normalize(e.right))
+ }
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1c991729c7d4..0cf03052cbdb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -328,6 +328,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
EliminateSQLFunctionNode,
ReplaceExpressions,
+ NormalizeFloatingNumbers,
RewriteNonCorrelatedExists,
PullOutGroupingExpressions,
// Put `InsertMapSortInGroupingExpressions` after
`PullOutGroupingExpressions`,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 4e06fcb36767..557b01167d88 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -29,7 +29,12 @@ object TreePattern extends Enumeration {
val ALIAS: Value = Value
val ANALYSIS_AWARE_EXPRESSION: Value = Value
val AND: Value = Value
+ val ARRAYS_OVERLAP: Value = Value
val ARRAYS_ZIP: Value = Value
+ val ARRAY_DISTINCT: Value = Value
+ val ARRAY_EXCEPT: Value = Value
+ val ARRAY_INTERSECT: Value = Value
+ val ARRAY_UNION: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
val AVERAGE: Value = Value
val BINARY_ARITHMETIC: Value = Value
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
index 21049ca3546d..a0a9c8ec3224 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull,
KnownFloatingPointNormalized}
+import org.apache.spark.sql.catalyst.expressions.{ArrayDistinct, ArrayExcept,
ArrayIntersect, ArraysOverlap, ArrayTransform, ArrayUnion, CaseWhen,
Expression, If, IsNull, KnownFloatingPointNormalized, LambdaFunction,
NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.DoubleType
class NormalizeFloatingPointNumbersSuite extends PlanTest {
@@ -34,6 +35,18 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
val a = testRelation1.output(0)
val testRelation2 = LocalRelation($"a".double)
val b = testRelation2.output(0)
+ val arrayRelation = LocalRelation($"arr1".array(DoubleType),
$"arr2".array(DoubleType))
+ val arr1 = arrayRelation.output(0)
+ val arr2 = arrayRelation.output(1)
+
+ private def normalizedArray(e: Expression): KnownFloatingPointNormalized = {
+ val lv = NamedLambdaVariable("arg", DoubleType, nullable = true)
+ KnownFloatingPointNormalized(
+ ArrayTransform(e,
+ LambdaFunction(
+ KnownFloatingPointNormalized(NormalizeNaNAndZero(lv)),
+ Seq(lv))))
+ }
test("normalize floating points in window function expressions") {
val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
@@ -132,5 +145,108 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest
{
val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr)
assert(nestedExpr.dataType == normalizedExpr.dataType)
}
+
+ test("SPARK-54918: normalize floating points in array_distinct") {
+ val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer =
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_distinct -
idempotence") {
+ val query = arrayRelation.select(ArrayDistinct(arr1).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer =
arrayRelation.select(ArrayDistinct(normalizedArray(arr1)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_union") {
+ val query = arrayRelation.select(ArrayUnion(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer = arrayRelation.select(
+ ArrayUnion(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_union - idempotence") {
+ val query = arrayRelation.select(ArrayUnion(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer = arrayRelation.select(
+ ArrayUnion(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_intersect") {
+ val query = arrayRelation.select(ArrayIntersect(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer = arrayRelation.select(
+ ArrayIntersect(normalizedArray(arr1),
normalizedArray(arr2)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_intersect -
idempotence") {
+ val query = arrayRelation.select(ArrayIntersect(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer = arrayRelation.select(
+ ArrayIntersect(normalizedArray(arr1),
normalizedArray(arr2)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_except") {
+ val query = arrayRelation.select(ArrayExcept(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer = arrayRelation.select(
+ ArrayExcept(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in array_except - idempotence")
{
+ val query = arrayRelation.select(ArrayExcept(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer = arrayRelation.select(
+ ArrayExcept(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in arrays_overlap") {
+ val query = arrayRelation.select(ArraysOverlap(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer = arrayRelation.select(
+ ArraysOverlap(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-54918: normalize floating points in arrays_overlap -
idempotence") {
+ val query = arrayRelation.select(ArraysOverlap(arr1, arr2).as("result"))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer = arrayRelation.select(
+ ArraysOverlap(normalizedArray(arr1), normalizedArray(arr2)).as("result"))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7faccbde997d..8f3098bedccc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -6348,7 +6348,159 @@ class DataFrameFunctionsSuite extends
SharedSparkSession {
call_function("spark_catalog.default.custom_sum", $"a")),
Row(12.0, 12.0, 12.0))
}
+ }
+
+ private def isPositiveZero(d: Double): Boolean =
+ java.lang.Double.doubleToRawLongBits(d) == 0L
+
+ test("SPARK-54918: array set ops normalize -0.0 and NaN via VALUES inline
table") {
+ val r = sql("""
+ SELECT
+ array_distinct(a) AS d,
+ array_union(a, b) AS u,
+ array_intersect(a, b) AS i,
+ array_except(a, b) AS e,
+ arrays_overlap(a, b) AS o
+ FROM VALUES (array(-0.0d, 0.0d, double('NaN')), array(0.0d,
double('NaN')))
+ AS t(a, b)
+ """).head()
+
+ val distinct = r.getSeq[Double](0)
+ assert(distinct.length == 2)
+ assert(distinct.exists(isPositiveZero))
+ assert(distinct.exists(_.isNaN))
+
+ val union = r.getSeq[Double](1)
+ assert(union.length == 2)
+ assert(union.exists(isPositiveZero))
+ assert(union.exists(_.isNaN))
+
+ val intersect = r.getSeq[Double](2)
+ assert(intersect.length == 2)
+ assert(intersect.exists(isPositiveZero))
+ assert(intersect.exists(_.isNaN))
+
+ val except = r.getSeq[Double](3)
+ assert(except.isEmpty)
+
+ assert(r.getBoolean(4))
+ }
+
+ test("SPARK-54918: array_distinct normalizes -0.0 to +0.0 - literals") {
+ val r1 = Seq(1).toDF()
+ .select(array_distinct(typedLit(Array(-0.0d,
0.0d)))).head().getSeq[Double](0)
+
+ assert(r1.length == 1)
+ assert(isPositiveZero(r1.head))
+
+ val r2 = Seq(1).toDF()
+ .select(array_distinct(
+ typedLit(Array(Double.NaN, 0.0d, -0.0d, Double.NaN)))
+ ).head().getSeq[Double](0)
+
+ assert(r2.length == 2)
+ assert(r2.exists(_.isNaN))
+ assert(r2.exists(isPositiveZero))
+ }
+
+ test("SPARK-54918: array_distinct normalizes -0.0 to +0.0") {
+ val r1 = Seq(Array(-0.0d, 0.0d)).toDF("a")
+ .select(array_distinct($"a")).head().getSeq[Double](0)
+
+ assert(r1.length == 1)
+ assert(isPositiveZero(r1.head))
+
+ val r2 = Seq(Array(Double.NaN, 0.0d, -0.0d, Double.NaN)).toDF("a")
+ .select(array_distinct($"a")).head().getSeq[Double](0)
+
+ assert(r2.length == 2)
+ assert(r2.exists(_.isNaN))
+ assert(r2.exists(isPositiveZero))
+ }
+
+ test("SPARK-54918: array_union normalizes -0.0 to +0.0 - literals") {
+ val r = Seq(1).toDF()
+ .select(array_union(
+ typedLit(Array(-0.0d)),
+ typedLit(Array(0.0d)))
+ ).head().getSeq[Double](0)
+
+ assert(r.length == 1)
+ assert(isPositiveZero(r.head))
+ }
+
+ test("SPARK-54918: array_union normalizes -0.0 to +0.0") {
+ val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+ .select(array_union($"a", $"b")).head().getSeq[Double](0)
+
+ assert(r.length == 1)
+ assert(isPositiveZero(r.head))
+ }
+
+ test("SPARK-54918: array_intersect normalizes -0.0 to +0.0 - literals") {
+ val r = Seq(1).toDF()
+ .select(array_intersect(
+ typedLit(Array(-0.0d)),
+ typedLit(Array(0.0d)))
+ ).head().getSeq[Double](0)
+
+ assert(r.length == 1)
+ assert(isPositiveZero(r.head))
+ }
+
+ test("SPARK-54918: array_intersect normalizes -0.0 to +0.0") {
+ val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+ .select(array_intersect($"a", $"b")).head().getSeq[Double](0)
+
+ assert(r.length == 1)
+ assert(isPositiveZero(r.head))
+ }
+
+ test("SPARK-54918: array_except normalizes -0.0 to +0.0 - literals") {
+ val r1 = Seq(1).toDF()
+ .select(array_except(
+ typedLit(Array(-0.0d)),
+ typedLit(Array(0.0d)))
+ ).head().getSeq[Double](0)
+
+ assert(r1.isEmpty)
+
+ val r2 = Seq(1).toDF()
+ .select(array_except(
+ typedLit(Array(0.0d)),
+ typedLit(Array(-0.0d)))
+ ).head().getSeq[Double](0)
+
+ assert(r2.isEmpty)
+ }
+
+ test("SPARK-54918: array_except normalizes -0.0 to +0.0") {
+ val r1 = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+ .select(array_except($"a", $"b")).head().getSeq[Double](0)
+
+ assert(r1.isEmpty)
+
+ val r2 = Seq((Array(0.0d), Array(-0.0d))).toDF("a", "b")
+ .select(array_except($"a", $"b")).head().getSeq[Double](0)
+
+ assert(r2.isEmpty)
+ }
+
+ test("SPARK-54918: arrays_overlap normalizes -0.0 to +0.0 - literals") {
+ val r = Seq(1).toDF()
+ .select(arrays_overlap(
+ typedLit(Array(-0.0d)),
+ typedLit(Array(0.0d)))
+ ).head().getBoolean(0)
+
+ assert(r)
+ }
+
+ test("SPARK-54918: arrays_overlap normalizes -0.0 to +0.0") {
+ val r = Seq((Array(-0.0d), Array(0.0d))).toDF("a", "b")
+ .select(arrays_overlap($"a", $"b")).head().getBoolean(0)
+ assert(r)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]