Repository: spark Updated Branches: refs/heads/master 101556d0f -> 7a8250581
[SPARK-19054][ML] Eliminate extra pass in NB ## What changes were proposed in this pull request? Eliminate an unnecessary extra pass in NaiveBayes training: `getNumClasses(dataset)` is now called only when `thresholds` is defined. ## How was this patch tested? Existing tests. Author: Zheng RuiFeng <[email protected]> Closes #16453 from zhengruifeng/nb_getNC. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7a825058 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7a825058 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7a825058 Branch: refs/heads/master Commit: 7a82505817d479007adff6424473063d2003fcc1 Parents: 101556d Author: Zheng RuiFeng <[email protected]> Authored: Wed Jan 4 11:54:13 2017 +0000 Committer: Sean Owen <[email protected]> Committed: Wed Jan 4 11:54:13 2017 +0000 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/NaiveBayes.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7a825058/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 94ee2a2..e90040d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -127,13 +127,11 @@ class NaiveBayes @Since("1.5.0") ( private[spark] def trainWithLabelCheck( dataset: Dataset[_], positiveLabel: Boolean): NaiveBayesModel = { - if (positiveLabel) { + if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".train() 
called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } val modelTypeValue = $(modelType) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
