This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 1ea5844  [SPARK-31676][ML] QuantileDiscretizer raise error parameter 
splits given invalid value (splits array includes -0.0 and 0.0)
1ea5844 is described below

commit 1ea584443e9372a6a0b3c8449f5bf7e9e1369b0d
Author: Weichen Xu <[email protected]>
AuthorDate: Thu May 14 09:24:40 2020 -0500

    [SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given 
invalid value (splits array includes -0.0 and 0.0)
    
    In QuantileDiscretizer.getDistinctSplits, normalize every -0.0 split to 0.0 
before invoking `distinct`:
    ```
        for (i <- 0 until splits.length) {
          if (splits(i) == -0.0) {
            splits(i) = 0.0
          }
        }
    ```
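    Why are the changes needed? To fix the bug described above.
    
    Does this PR introduce any user-facing change? No.
    
    How was this patch tested? With a unit test (reproduction below).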
    
    ~~~scala
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ 
Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    
    import spark.implicits._
    val df1 = sc.parallelize(a1, 2).toDF("id")
    
    import org.apache.spark.ml.feature.QuantileDiscretizer
    val qd = new 
QuantileDiscretizer().setInputCol("id").setOutputCol("out").setNumBuckets(200).setRelativeError(0.0)
    
    val model = qd.fit(df1) // will raise error in spark master.
    ~~~
    
    In Scala, `0.0 == -0.0` is true, but `0.0.hashCode == -0.0.hashCode()` is false. 
This breaks the contract between equals() and hashCode(): if two objects are 
equal, they must have the same hash code.
    
    Because array.distinct relies on elem.hashCode, this inconsistency leads to the error.
    
    Test code demonstrating the `distinct` behavior:
    ```
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ 
Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    a1.distinct.sorted.foreach(x => print(x.toString + "\n"))
    ```
    
    The output then contains both -0.0 and 0.0 as separate entries, for example:
    ```
    ...
    -0.009292684662246975
    -0.0033280686465135823
    -0.0
    0.0
    0.0022219556032221366
    0.02217419561977274
    ...
    ```
    
    Closes #28498 from WeichenXu123/SPARK-31676.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Sean Owen <[email protected]>
    (cherry picked from commit b2300fca1e1a22d74c6eeda37942920a6c6299ff)
    Signed-off-by: Sean Owen <[email protected]>
---
 .../apache/spark/ml/feature/QuantileDiscretizer.scala  | 12 ++++++++++++
 .../spark/ml/feature/QuantileDiscretizerSuite.scala    | 18 ++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 56e2c54..f3ec358 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -243,6 +243,18 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
   private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
+
+    // 0.0 and -0.0 are distinct values, array.distinct will preserve both of 
them.
+    // but 0.0 > -0.0 is False which will break the parameter validation 
checking.
+    // and in scala <= 2.12, there's bug which will cause array.distinct 
generate
+    // non-deterministic results when array contains both 0.0 and -0.0
+    // So that here we should first normalize all 0.0 and -0.0 to be 0.0
+    // See https://github.com/scala/bug/issues/11995
+    for (i <- 0 until splits.length) {
+      if (splits(i) == -0.0) {
+        splits(i) = 0.0
+      }
+    }
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to 
${distinctSplits.length - 1}" +
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b009038..9c37416 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -443,4 +443,22 @@ class QuantileDiscretizerSuite extends MLTest with 
DefaultReadWriteTest {
       discretizer.fit(df)
     }
   }
+
+  test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given 
invalid value") {
+    import scala.util.Random
+    val rng = new Random(3)
+
+    val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
+      Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
+
+    val df1 = sc.parallelize(a1, 2).toDF("id")
+
+    val qd = new QuantileDiscretizer()
+      .setInputCol("id")
+      .setOutputCol("out")
+      .setNumBuckets(200)
+      .setRelativeError(0.0)
+
+    qd.fit(df1) // assert no exception raised here.
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to