Repository: spark
Updated Branches:
  refs/heads/branch-1.6 393f4ba15 -> be3c41b26


[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count

## What changes were proposed in this pull request?

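Currently, `AFTAggregator` is not being merged correctly. For example, if there 
is even a single empty partition in the data, it produces an `AFTAggregator` with 
zero total count. Because `merge` checks this aggregator's own count (`totalCnt`) 
instead of `other`'s count, the merged result can end up with zero total count, 
which causes the exception below: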

```
IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.'
```

Please see 
[AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) 
as well.
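To make the merge problem concrete, here is a minimal, self-contained sketch. `CountingAggregator` and `MergeDemo` are hypothetical names: the aggregator is a simplified stand-in for `AFTAggregator` that only tracks `totalCnt` and `lossSum` (the real class also accumulates gradients), and the per-partition results are combined with a plain left-to-right `reduce` as a simplification of Spark's tree aggregation, whose actual combining order can depend on execution. With an empty first partition, the old check (`totalCnt != 0`) never absorbs anything, while the patched check (`other.count != 0`) does.

```scala
// Hypothetical, simplified stand-in for AFTAggregator (count and loss sum only).
final class CountingAggregator(val fixedMerge: Boolean) {
  private var totalCnt: Long = 0L
  private var lossSum: Double = 0.0

  def count: Long = totalCnt

  // Add one instance's loss to this aggregator.
  def add(loss: Double): this.type = {
    totalCnt += 1
    lossSum += loss
    this
  }

  // Merge another aggregator into this one.
  def merge(other: CountingAggregator): this.type = {
    // Old behaviour: test this aggregator's own count.
    // Patched behaviour: test the other aggregator's count.
    val shouldMerge = if (fixedMerge) other.count != 0 else totalCnt != 0
    if (shouldMerge) {
      totalCnt += other.totalCnt
      lossSum += other.lossSum
    }
    this
  }
}

object MergeDemo {
  // Aggregate each partition separately, then combine the partial results
  // left to right (a simplification of Spark's tree aggregation).
  def aggregate(partitions: Seq[Seq[Double]], fixedMerge: Boolean): Long =
    partitions
      .map(p => p.foldLeft(new CountingAggregator(fixedMerge))((agg, x) => agg.add(x)))
      .reduce((a, b) => a.merge(b))
      .count

  def main(args: Array[String]): Unit = {
    // Two rows spread over three partitions; the first partition is empty.
    val partitions = Seq(Seq.empty[Double], Seq(1.0), Seq(2.0))
    println(aggregate(partitions, fixedMerge = false)) // 0: both rows are dropped
    println(aggregate(partitions, fixedMerge = true))  // 2
  }
}
```

A zero total count is exactly what the `require` in the exception above rejects.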

Just to be clear, the Python example `aft_survival_regression.py` seems to use 5 
rows. So, if there are more than 5 partitions, some of them are necessarily 
empty, which results in an incorrectly merged `AFTAggregator` and the exception 
above.

Executing `bin/spark-submit 
examples/src/main/python/ml/aft_survival_regression.py` on a machine with more 
than 5 CPU cores fails with the default configuration, because some tasks get 
empty partitions (AFAIK, the default parallelism level is the number of CPU 
cores).

## How was this patch tested?

A unit test in `AFTSurvivalRegressionSuite.scala`, plus a manual test with 
`bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`.

Author: hyukjinkwon <gurwls...@gmail.com>
Author: Hyukjin Kwon <gurwls...@gmail.com>

Closes #13619 from HyukjinKwon/SPARK-15892.

(cherry picked from commit e3554605b36bdce63ac180cc66dbdee5c1528ec7)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/be3c41b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/be3c41b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/be3c41b2

Branch: refs/heads/branch-1.6
Commit: be3c41b2633215ff6f20885c04f288aab25a1712
Parents: 393f4ba
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Sun Jun 12 14:26:53 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Sun Jun 12 14:27:20 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/AFTSurvivalRegression.scala     |  2 +-
 .../ml/regression/AFTSurvivalRegressionSuite.scala      | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/be3c41b2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index aedfb48..cc1d19e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -496,7 +496,7 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    * @return This AFTAggregator object.
    */
   def merge(other: AFTAggregator): this.type = {
-    if (totalCnt != 0) {
+    if (other.count != 0) {
       totalCnt += other.totalCnt
       lossSum += other.lossSum
 

http://git-wip-us.apache.org/repos/asf/spark/blob/be3c41b2/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index d718ef6..e452efb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -346,6 +346,18 @@ class AFTSurvivalRegressionSuite
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
       AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
   }
+
+  test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
+    // This `dataset` will contain an empty partition because it has two rows but
+    // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
+    // being merged incorrectly when it has an empty partition, running the codes below
+    // should not throw an exception.
+    val dataset = spark.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
+    val trainer = new AFTSurvivalRegression()
+    trainer.fit(dataset)
+  }
 }
 
 object AFTSurvivalRegressionSuite {

