Repository: spark
Updated Branches:
  refs/heads/branch-1.6 393f4ba15 -> be3c41b26

[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count

## What changes were proposed in this pull request?

Currently, `AFTAggregator` is not merged correctly. For example, if the data contains even a single empty partition, this creates an `AFTAggregator` with zero total count, which causes the exception below:

```
IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.'
```

Please see [AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) as well.

Just to be clear, the Python example `aft_survival_regression.py` appears to use 5 rows. So, if there are more than 5 partitions, it throws the exception above, because some of the partitions are empty and this results in an incorrectly merged `AFTAggregator`.

Executing `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py` on a machine with more than 5 CPUs fails because it creates tasks with some empty partitions under the default configuration (AFAIK, it sets the parallelism level to the number of CPU cores).

## How was this patch tested?

A unit test in `AFTSurvivalRegressionSuite.scala`, and manual testing with `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`.

Author: hyukjinkwon <gurwls...@gmail.com>
Author: Hyukjin Kwon <gurwls...@gmail.com>

Closes #13619 from HyukjinKwon/SPARK-15892.

(cherry picked from commit e3554605b36bdce63ac180cc66dbdee5c1528ec7)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/be3c41b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/be3c41b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/be3c41b2

Branch: refs/heads/branch-1.6
Commit: be3c41b2633215ff6f20885c04f288aab25a1712
Parents: 393f4ba
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Sun Jun 12 14:26:53 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Sun Jun 12 14:27:20 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/AFTSurvivalRegression.scala     |  2 +-
 .../ml/regression/AFTSurvivalRegressionSuite.scala      | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/be3c41b2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index aedfb48..cc1d19e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -496,7 +496,7 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    * @return This AFTAggregator object.
    */
   def merge(other: AFTAggregator): this.type = {
-    if (totalCnt != 0) {
+    if (other.count != 0) {
       totalCnt += other.totalCnt
       lossSum += other.lossSum

http://git-wip-us.apache.org/repos/asf/spark/blob/be3c41b2/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index d718ef6..e452efb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -346,6 +346,18 @@ class AFTSurvivalRegressionSuite
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
       AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
   }
+
+  test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
+    // This `dataset` will contain an empty partition because it has two rows but
+    // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
+    // being merged incorrectly when it has an empty partition, running the codes below
+    // should not throw an exception.
+    val dataset = spark.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
+    val trainer = new AFTSurvivalRegression()
+    trainer.fit(dataset)
+  }
 }
 
 object AFTSurvivalRegressionSuite {
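
The effect of the one-line guard change in `merge` can be reproduced outside Spark. The following is only a minimal pure-Python sketch, not Spark's code: `ToyAggregator`, `merge_buggy`, `merge_fixed`, and `aggregate` are hypothetical names, and the aggregator is assumed to track just a count and a loss sum. It shows that when the guard checks the receiver, an empty first partition makes every later merge a no-op, so the final count is zero.

```python
from dataclasses import dataclass
from functools import reduce


@dataclass
class ToyAggregator:
    """Hypothetical, simplified stand-in for AFTAggregator (count and loss sum only)."""
    total_cnt: int = 0
    loss_sum: float = 0.0

    def add(self, loss: float) -> "ToyAggregator":
        # Fold one instance into this aggregator.
        self.total_cnt += 1
        self.loss_sum += loss
        return self

    def merge_buggy(self, other: "ToyAggregator") -> "ToyAggregator":
        # Pre-patch guard: tests the receiver, so an empty receiver
        # silently discards everything held by `other`.
        if self.total_cnt != 0:
            self.total_cnt += other.total_cnt
            self.loss_sum += other.loss_sum
        return self

    def merge_fixed(self, other: "ToyAggregator") -> "ToyAggregator":
        # Patched guard: tests `other`, so only an empty `other` is skipped.
        if other.total_cnt != 0:
            self.total_cnt += other.total_cnt
            self.loss_sum += other.loss_sum
        return self


def aggregate(partitions, merge):
    """Build one aggregator per partition, then merge them left to right."""
    per_partition = [
        reduce(lambda agg, loss: agg.add(loss), part, ToyAggregator())
        for part in partitions
    ]
    return reduce(merge, per_partition)


if __name__ == "__main__":
    # Two rows spread over three partitions; the first partition is empty.
    parts = [[], [1.0], [2.0]]
    print(aggregate(parts, ToyAggregator.merge_buggy).total_cnt)  # 0
    print(aggregate(parts, ToyAggregator.merge_fixed).total_cnt)  # 2
```

In this sketch the merge order is fixed so that the empty aggregator is the receiver; in Spark the order depends on how partition results are combined, which is presumably why the failure only appeared when empty partitions were present.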