This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 14ba4fc47915 [SPARK-45216][SQL] Fix non-deterministic seeded Dataset
APIs
14ba4fc47915 is described below
commit 14ba4fc479155611ca39bda6f879c34cc78af2ee
Author: Peter Toth <[email protected]>
AuthorDate: Thu Sep 21 09:36:23 2023 +0900
[SPARK-45216][SQL] Fix non-deterministic seeded Dataset APIs
### What changes were proposed in this pull request?
This PR fixes a bug regarding non-deterministic seeded Dataset functions.
If we run the following example the result is the expected equal 2 columns:
```
val c = rand()
df.select(c, c)
+--------------------------+--------------------------+
|rand(-4522010140232537566)|rand(-4522010140232537566)|
+--------------------------+--------------------------+
| 0.4520819282997137| 0.4520819282997137|
+--------------------------+--------------------------+
```
But if we use other similar APIs, their result is incorrect:
```
val r1 = random()
val r2 = uuid()
val r3 = shuffle(col("x"))
val x = df.select(r1, r1, r2, r2, r3, r3)
+------------------+------------------+--------------------+--------------------+----------+----------+
| rand()| rand()| uuid()|
uuid()|shuffle(x)|shuffle(x)|
+------------------+------------------+--------------------+--------------------+----------+----------+
|0.7407604956381952|0.7957319451135009|e55bc4b0-74e6-4b0...|a587163b-d06b-4bb...|
[1, 2, 3]| [2, 1, 3]|
+------------------+------------------+--------------------+--------------------+----------+----------+
```
This is because the current implementation of `rand()` passes a random seed
to `Rand`, but other functions like `random()`, `uuid()` and `shuffle()` don’t.
Later the `ResolveRandomSeed` rule adds the necessary seeds, but since the
resolution rules don’t track expression object identities they can’t map an
expression object 2 times to the same transformed object. I.e. in case of
`random()` the `UnresolvedFunction("random", Seq.empty, ...)` object is
transformed to 2 different `Rand(U [...]
This PR explicitly adds the seeds.
### Why are the changes needed?
To fix the above bug.
### Does this PR introduce _any_ user-facing change?
Yes, fixes the above bug.
### How was this patch tested?
Added new UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42997 from peter-toth/SPARK-45216-fix-non-deterministic-seeded.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/scala/org/apache/spark/sql/functions.scala | 11 ++++++-----
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 18 ++++++++++++++++++
.../scala/org/apache/spark/sql/FunctionTestSuite.scala | 4 ++--
.../sql/connect/planner/SparkConnectPlanner.scala | 17 +++++++++++++++++
python/pyspark/sql/connect/functions.py | 8 +++++---
.../pyspark/sql/tests/connect/test_connect_function.py | 11 +++++++++++
python/pyspark/sql/tests/test_functions.py | 11 +++++++++++
.../main/scala/org/apache/spark/sql/functions.scala | 6 +++---
.../scala/org/apache/spark/sql/DataFrameSuite.scala | 13 +++++++++++++
9 files changed, 86 insertions(+), 13 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 6f79de1c9155..5bb8a92c1d2e 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction,
UserDefinedFunction}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.DataType.parseTypeWithFallback
+import org.apache.spark.util.SparkClassUtils
/**
* Commonly used functions available for DataFrame operations. Using functions
defined here
@@ -1831,7 +1832,7 @@ object functions {
* @group normal_funcs
* @since 3.4.0
*/
- def rand(): Column = Column.fn("rand")
+ def rand(): Column = Column.fn("rand", lit(SparkClassUtils.random.nextLong))
/**
* Generate a column with independent and identically distributed (i.i.d.)
samples from the
@@ -1855,7 +1856,7 @@ object functions {
* @group normal_funcs
* @since 3.4.0
*/
- def randn(): Column = Column.fn("randn")
+ def randn(): Column = Column.fn("randn",
lit(SparkClassUtils.random.nextLong))
/**
* Partition ID.
@@ -3392,7 +3393,7 @@ object functions {
* @group misc_funcs
* @since 3.5.0
*/
- def uuid(): Column = Column.fn("uuid")
+ def uuid(): Column = Column.fn("uuid", lit(SparkClassUtils.random.nextLong))
/**
* Returns an encrypted value of `input` using AES in given `mode` with the
specified `padding`.
@@ -3711,7 +3712,7 @@ object functions {
* @group misc_funcs
* @since 3.5.0
*/
- def random(): Column = Column.fn("random")
+ def random(): Column = Column.fn("random",
lit(SparkClassUtils.random.nextLong))
/**
* Returns the bit position for the given input column.
@@ -7069,7 +7070,7 @@ object functions {
* @group collection_funcs
* @since 3.4.0
*/
- def shuffle(e: Column): Column = Column.fn("shuffle", e)
+ def shuffle(e: Column): Column = Column.fn("shuffle", e,
lit(SparkClassUtils.random.nextLong))
/**
* Returns a reversed string or an array with reverse order of elements.
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 12b8434193c6..21892542eab2 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -1296,6 +1296,24 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper with PrivateM
assert(rc == 100)
}
}
+
+ test("SPARK-45216: Non-deterministic functions with seed") {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val df = Seq(Array.range(0, 10)).toDF("a")
+
+ val r = rand()
+ val r2 = randn()
+ val r3 = random()
+ val r4 = uuid()
+ val r5 = shuffle(col("a"))
+ df.select(r, r.as("r"), r2, r2.as("r2"), r3, r3.as("r3"), r4, r4.as("r4"),
r5, r5.as("r5"))
+ .collect
+ .foreach { row =>
+ (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
+ }
+ }
}
private[sql] case class ClassData(a: String, b: Int)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
index ab470e9aaaa2..65dd5862d811 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
@@ -278,7 +278,7 @@ class FunctionTestSuite extends ConnectFunSuite {
assert(e.hasUnresolvedFunction)
val fn = e.getUnresolvedFunction
assert(fn.getFunctionName == "rand")
- assert(fn.getArgumentsCount == 0)
+ assert(fn.getArgumentsCount == 1)
}
test("randn no seed") {
@@ -286,6 +286,6 @@ class FunctionTestSuite extends ConnectFunSuite {
assert(e.hasUnresolvedFunction)
val fn = e.getUnresolvedFunction
assert(fn.getFunctionName == "randn")
- assert(fn.getArgumentsCount == 0)
+ assert(fn.getArgumentsCount == 1)
}
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 84e8e1889ffa..924169715f74 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1966,6 +1966,18 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
Some(
CatalystDataToProtobuf(children.head, messageClassName,
binaryFileDescSetOpt, options))
+ case "uuid" if fun.getArgumentsCount == 1 =>
+ // Uuid does not have a constructor which accepts Expression typed
'seed'
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val seed = extractLong(children(0), "seed")
+ Some(Uuid(Some(seed)))
+
+ case "shuffle" if fun.getArgumentsCount == 2 =>
+ // Shuffle does not have a constructor which accepts Expression typed
'seed'
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val seed = extractLong(children(1), "seed")
+ Some(Shuffle(children(0), Some(seed)))
+
case _ => None
}
}
@@ -2004,6 +2016,11 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
case other => throw InvalidPlanInput(s"$field should be a literal integer,
but got $other")
}
+ private def extractLong(expr: Expression, field: String): Long = expr match {
+ case Literal(long: Long, LongType) => long
+ case other => throw InvalidPlanInput(s"$field should be a literal long,
but got $other")
+ }
+
private def extractString(expr: Expression, field: String): String = expr
match {
case Literal(s, StringType) if s != null => s.toString
case other => throw InvalidPlanInput(s"$field should be a literal string,
but got $other")
diff --git a/python/pyspark/sql/connect/functions.py
b/python/pyspark/sql/connect/functions.py
index 4749c642975b..f065e5391fef 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -36,6 +36,8 @@ from typing import (
ValuesView,
cast,
)
+import random
+import sys
import numpy as np
@@ -388,7 +390,7 @@ def rand(seed: Optional[int] = None) -> Column:
if seed is not None:
return _invoke_function("rand", lit(seed))
else:
- return _invoke_function("rand")
+ return _invoke_function("rand", lit(random.randint(0, sys.maxsize)))
rand.__doc__ = pysparkfuncs.rand.__doc__
@@ -398,7 +400,7 @@ def randn(seed: Optional[int] = None) -> Column:
if seed is not None:
return _invoke_function("randn", lit(seed))
else:
- return _invoke_function("randn")
+ return _invoke_function("randn", lit(random.randint(0, sys.maxsize)))
randn.__doc__ = pysparkfuncs.randn.__doc__
@@ -2111,7 +2113,7 @@ schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__
def shuffle(col: "ColumnOrName") -> Column:
- return _invoke_function_over_columns("shuffle", col)
+ return _invoke_function("shuffle", _to_col(col), lit(random.randint(0,
sys.maxsize)))
shuffle.__doc__ = pysparkfuncs.shuffle.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py
b/python/pyspark/sql/tests/connect/test_connect_function.py
index f958b5fb574f..bc0cf1626488 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2494,6 +2494,17 @@ class SparkConnectFunctionTests(ReusedConnectTestCase,
PandasOnSparkTestUtils, S
"Missing functions in vanilla PySpark not as expected",
)
+ # SPARK-45216: Fix non-deterministic seeded Dataset APIs
+ def test_non_deterministic_with_seed(self):
+ df = self.connect.createDataFrame([([*range(0, 10, 1)],)], ["a"])
+
+ r = CF.rand()
+ r2 = CF.randn()
+ r3 = CF.shuffle("a")
+ res = df.select(r, r, r2, r2, r3, r3).collect()
+ for i in range(3):
+ self.assertEqual(res[0][i * 2], res[0][i * 2 + 1])
+
if __name__ == "__main__":
import os
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index b0ad311d733a..66de94a6b9b3 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1353,6 +1353,17 @@ class FunctionsTestsMixin:
message_parameters={"arg_name": "numBuckets", "arg_type": "str"},
)
+ # SPARK-45216: Fix non-deterministic seeded Dataset APIs
+ def test_non_deterministic_with_seed(self):
+ df = self.spark.createDataFrame([([*range(0, 10, 1)],)], ["a"])
+
+ r = F.rand()
+ r2 = F.randn()
+ r3 = F.shuffle("a")
+ res = df.select(r, r, r2, r2, r3, r3).collect()
+ for i in range(3):
+ self.assertEqual(res[0][i * 2], res[0][i * 2 + 1])
+
class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1b832fc437cd..2a7ed263c748 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3340,7 +3340,7 @@ object functions {
* @group misc_funcs
* @since 3.5.0
*/
- def uuid(): Column = withExpr { new Uuid() }
+ def uuid(): Column = withExpr { Uuid(Some(Utils.random.nextLong)) }
/**
* Returns an encrypted value of `input` using AES in given `mode` with the
specified `padding`.
@@ -3659,7 +3659,7 @@ object functions {
* @group misc_funcs
* @since 3.5.0
*/
- def random(): Column = call_function("random")
+ def random(): Column = random(lit(Utils.random.nextLong))
/**
* Returns the bucket number for the given input column.
@@ -6837,7 +6837,7 @@ object functions {
* @group collection_funcs
* @since 2.4.0
*/
- def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) }
+ def shuffle(e: Column): Column = withExpr { Shuffle(e.expr,
Some(Utils.random.nextLong)) }
/**
* Returns a reversed string or an array with reverse order of elements.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 805bb1ccc287..c72bc9167759 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3648,6 +3648,19 @@ class DataFrameSuite extends QueryTest
}
}
+ test("SPARK-45216: Non-deterministic functions with seed") {
+ val df = Seq(Array.range(0, 10)).toDF("a")
+
+ val r = rand()
+ val r2 = randn()
+ val r3 = random()
+ val r4 = uuid()
+ val r5 = shuffle(col("a"))
+ df.select(r, r, r2, r2, r3, r3, r4, r4, r5, r5).collect.foreach { row =>
+ (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
+ }
+ }
+
test("SPARK-41219: IntegralDivide use decimal(1, 0) to represent 0") {
val df = Seq("0.5944910").toDF("a")
checkAnswer(df.selectExpr("cast(a as decimal(7,7)) div 100"), Row(0))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]