This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 14ba4fc47915 [SPARK-45216][SQL] Fix non-deterministic seeded Dataset 
APIs
14ba4fc47915 is described below

commit 14ba4fc479155611ca39bda6f879c34cc78af2ee
Author: Peter Toth <[email protected]>
AuthorDate: Thu Sep 21 09:36:23 2023 +0900

    [SPARK-45216][SQL] Fix non-deterministic seeded Dataset APIs
    
    ### What changes were proposed in this pull request?
    This PR fixes a bug regarding non-deterministic seeded Dataset functions.
    
    If we run the following example the result is the expected equal 2 columns:
    ```
    val c = rand()
    df.select(c, c)
    
    +--------------------------+--------------------------+
    |rand(-4522010140232537566)|rand(-4522010140232537566)|
    +--------------------------+--------------------------+
    |        0.4520819282997137|        0.4520819282997137|
    +--------------------------+--------------------------+
    ```
    
    But if we use other similar APIs, their result is incorrect:
    ```
    val r1 = random()
    val r2 = uuid()
    val r3 = shuffle(col("x"))
    val x = df.select(r1, r1, r2, r2, r3, r3)
    
    
+------------------+------------------+--------------------+--------------------+----------+----------+
    |            rand()|            rand()|              uuid()|              
uuid()|shuffle(x)|shuffle(x)|
    
+------------------+------------------+--------------------+--------------------+----------+----------+
    
|0.7407604956381952|0.7957319451135009|e55bc4b0-74e6-4b0...|a587163b-d06b-4bb...|
 [1, 2, 3]| [2, 1, 3]|
    
+------------------+------------------+--------------------+--------------------+----------+----------+
    ```
    
    This is because the current implementation of `rand()` passes a random seed 
to `Rand`, but other functions like `random()`, `uuid()` and `shuffle()` don’t. 
Later the `ResolveRandomSeed` rule adds the necessary seeds, but since the 
resolution rules don’t track expression object identities they can’t map an 
expression object 2 times to the same transformed object. I.e. in case of 
`random()` the `UnresolvedFunction("random", Seq.empty, ...)` object is 
transformed to 2 different `Rand(U [...]
    
    This PR explicitly adds the seeds.
    
    ### Why are the changes needed?
    To fix the above bug.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes the above bug.
    
    ### How was this patch tested?
    Added new UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42997 from peter-toth/SPARK-45216-fix-non-deterministic-seeded.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/scala/org/apache/spark/sql/functions.scala    | 11 ++++++-----
 .../org/apache/spark/sql/ClientE2ETestSuite.scala      | 18 ++++++++++++++++++
 .../scala/org/apache/spark/sql/FunctionTestSuite.scala |  4 ++--
 .../sql/connect/planner/SparkConnectPlanner.scala      | 17 +++++++++++++++++
 python/pyspark/sql/connect/functions.py                |  8 +++++---
 .../pyspark/sql/tests/connect/test_connect_function.py | 11 +++++++++++
 python/pyspark/sql/tests/test_functions.py             | 11 +++++++++++
 .../main/scala/org/apache/spark/sql/functions.scala    |  6 +++---
 .../scala/org/apache/spark/sql/DataFrameSuite.scala    | 13 +++++++++++++
 9 files changed, 86 insertions(+), 13 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 6f79de1c9155..5bb8a92c1d2e 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, 
UserDefinedFunction}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.types.DataType.parseTypeWithFallback
+import org.apache.spark.util.SparkClassUtils
 
 /**
  * Commonly used functions available for DataFrame operations. Using functions 
defined here
@@ -1831,7 +1832,7 @@ object functions {
    * @group normal_funcs
    * @since 3.4.0
    */
-  def rand(): Column = Column.fn("rand")
+  def rand(): Column = Column.fn("rand", lit(SparkClassUtils.random.nextLong))
 
   /**
    * Generate a column with independent and identically distributed (i.i.d.) 
samples from the
@@ -1855,7 +1856,7 @@ object functions {
    * @group normal_funcs
    * @since 3.4.0
    */
-  def randn(): Column = Column.fn("randn")
+  def randn(): Column = Column.fn("randn", 
lit(SparkClassUtils.random.nextLong))
 
   /**
    * Partition ID.
@@ -3392,7 +3393,7 @@ object functions {
    * @group misc_funcs
    * @since 3.5.0
    */
-  def uuid(): Column = Column.fn("uuid")
+  def uuid(): Column = Column.fn("uuid", lit(SparkClassUtils.random.nextLong))
 
   /**
    * Returns an encrypted value of `input` using AES in given `mode` with the 
specified `padding`.
@@ -3711,7 +3712,7 @@ object functions {
    * @group misc_funcs
    * @since 3.5.0
    */
-  def random(): Column = Column.fn("random")
+  def random(): Column = Column.fn("random", 
lit(SparkClassUtils.random.nextLong))
 
   /**
    * Returns the bit position for the given input column.
@@ -7069,7 +7070,7 @@ object functions {
    * @group collection_funcs
    * @since 3.4.0
    */
-  def shuffle(e: Column): Column = Column.fn("shuffle", e)
+  def shuffle(e: Column): Column = Column.fn("shuffle", e, 
lit(SparkClassUtils.random.nextLong))
 
   /**
    * Returns a reversed string or an array with reverse order of elements.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 12b8434193c6..21892542eab2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -1296,6 +1296,24 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
       assert(rc == 100)
     }
   }
+
+  test("SPARK-45216: Non-deterministic functions with seed") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val df = Seq(Array.range(0, 10)).toDF("a")
+
+    val r = rand()
+    val r2 = randn()
+    val r3 = random()
+    val r4 = uuid()
+    val r5 = shuffle(col("a"))
+    df.select(r, r.as("r"), r2, r2.as("r2"), r3, r3.as("r3"), r4, r4.as("r4"), 
r5, r5.as("r5"))
+      .collect
+      .foreach { row =>
+        (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
+      }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
index ab470e9aaaa2..65dd5862d811 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
@@ -278,7 +278,7 @@ class FunctionTestSuite extends ConnectFunSuite {
     assert(e.hasUnresolvedFunction)
     val fn = e.getUnresolvedFunction
     assert(fn.getFunctionName == "rand")
-    assert(fn.getArgumentsCount == 0)
+    assert(fn.getArgumentsCount == 1)
   }
 
   test("randn no seed") {
@@ -286,6 +286,6 @@ class FunctionTestSuite extends ConnectFunSuite {
     assert(e.hasUnresolvedFunction)
     val fn = e.getUnresolvedFunction
     assert(fn.getFunctionName == "randn")
-    assert(fn.getArgumentsCount == 0)
+    assert(fn.getArgumentsCount == 1)
   }
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 84e8e1889ffa..924169715f74 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1966,6 +1966,18 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         Some(
           CatalystDataToProtobuf(children.head, messageClassName, 
binaryFileDescSetOpt, options))
 
+      case "uuid" if fun.getArgumentsCount == 1 =>
+        // Uuid does not have a constructor which accepts Expression typed 
'seed'
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val seed = extractLong(children(0), "seed")
+        Some(Uuid(Some(seed)))
+
+      case "shuffle" if fun.getArgumentsCount == 2 =>
+        // Shuffle does not have a constructor which accepts Expression typed 
'seed'
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val seed = extractLong(children(1), "seed")
+        Some(Shuffle(children(0), Some(seed)))
+
       case _ => None
     }
   }
@@ -2004,6 +2016,11 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
     case other => throw InvalidPlanInput(s"$field should be a literal integer, 
but got $other")
   }
 
+  private def extractLong(expr: Expression, field: String): Long = expr match {
+    case Literal(long: Long, LongType) => long
+    case other => throw InvalidPlanInput(s"$field should be a literal long, 
but got $other")
+  }
+
   private def extractString(expr: Expression, field: String): String = expr 
match {
     case Literal(s, StringType) if s != null => s.toString
     case other => throw InvalidPlanInput(s"$field should be a literal string, 
but got $other")
diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py
index 4749c642975b..f065e5391fef 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -36,6 +36,8 @@ from typing import (
     ValuesView,
     cast,
 )
+import random
+import sys
 
 import numpy as np
 
@@ -388,7 +390,7 @@ def rand(seed: Optional[int] = None) -> Column:
     if seed is not None:
         return _invoke_function("rand", lit(seed))
     else:
-        return _invoke_function("rand")
+        return _invoke_function("rand", lit(random.randint(0, sys.maxsize)))
 
 
 rand.__doc__ = pysparkfuncs.rand.__doc__
@@ -398,7 +400,7 @@ def randn(seed: Optional[int] = None) -> Column:
     if seed is not None:
         return _invoke_function("randn", lit(seed))
     else:
-        return _invoke_function("randn")
+        return _invoke_function("randn", lit(random.randint(0, sys.maxsize)))
 
 
 randn.__doc__ = pysparkfuncs.randn.__doc__
@@ -2111,7 +2113,7 @@ schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__
 
 
 def shuffle(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("shuffle", col)
+    return _invoke_function("shuffle", _to_col(col), lit(random.randint(0, 
sys.maxsize)))
 
 
 shuffle.__doc__ = pysparkfuncs.shuffle.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index f958b5fb574f..bc0cf1626488 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2494,6 +2494,17 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
             "Missing functions in vanilla PySpark not as expected",
         )
 
+    # SPARK-45216: Fix non-deterministic seeded Dataset APIs
+    def test_non_deterministic_with_seed(self):
+        df = self.connect.createDataFrame([([*range(0, 10, 1)],)], ["a"])
+
+        r = CF.rand()
+        r2 = CF.randn()
+        r3 = CF.shuffle("a")
+        res = df.select(r, r, r2, r2, r3, r3).collect()
+        for i in range(3):
+            self.assertEqual(res[0][i * 2], res[0][i * 2 + 1])
+
 
 if __name__ == "__main__":
     import os
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index b0ad311d733a..66de94a6b9b3 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1353,6 +1353,17 @@ class FunctionsTestsMixin:
             message_parameters={"arg_name": "numBuckets", "arg_type": "str"},
         )
 
+    # SPARK-45216: Fix non-deterministic seeded Dataset APIs
+    def test_non_deterministic_with_seed(self):
+        df = self.spark.createDataFrame([([*range(0, 10, 1)],)], ["a"])
+
+        r = F.rand()
+        r2 = F.randn()
+        r3 = F.shuffle("a")
+        res = df.select(r, r, r2, r2, r3, r3).collect()
+        for i in range(3):
+            self.assertEqual(res[0][i * 2], res[0][i * 2 + 1])
+
 
 class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
     pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1b832fc437cd..2a7ed263c748 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3340,7 +3340,7 @@ object functions {
    * @group misc_funcs
    * @since 3.5.0
    */
-  def uuid(): Column = withExpr { new Uuid() }
+  def uuid(): Column = withExpr { Uuid(Some(Utils.random.nextLong)) }
 
   /**
    * Returns an encrypted value of `input` using AES in given `mode` with the 
specified `padding`.
@@ -3659,7 +3659,7 @@ object functions {
    * @group misc_funcs
    * @since 3.5.0
    */
-  def random(): Column = call_function("random")
+  def random(): Column = random(lit(Utils.random.nextLong))
 
   /**
    * Returns the bucket number for the given input column.
@@ -6837,7 +6837,7 @@ object functions {
    * @group collection_funcs
    * @since 2.4.0
    */
-  def shuffle(e: Column): Column = withExpr { Shuffle(e.expr) }
+  def shuffle(e: Column): Column = withExpr { Shuffle(e.expr, 
Some(Utils.random.nextLong)) }
 
   /**
    * Returns a reversed string or an array with reverse order of elements.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 805bb1ccc287..c72bc9167759 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3648,6 +3648,19 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("SPARK-45216: Non-deterministic functions with seed") {
+    val df = Seq(Array.range(0, 10)).toDF("a")
+
+    val r = rand()
+    val r2 = randn()
+    val r3 = random()
+    val r4 = uuid()
+    val r5 = shuffle(col("a"))
+    df.select(r, r, r2, r2, r3, r3, r4, r4, r5, r5).collect.foreach { row =>
+      (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
+    }
+  }
+
   test("SPARK-41219: IntegralDivide use decimal(1, 0) to represent 0") {
     val df = Seq("0.5944910").toDF("a")
     checkAnswer(df.selectExpr("cast(a as decimal(7,7)) div 100"), Row(0))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to