Repository: spark Updated Branches: refs/heads/branch-2.0 7b925e500 -> b6b2c6138
[SPARK-15188] Add missing thresholds param to NaiveBayes in PySpark ## What changes were proposed in this pull request? Add missing thresholds param to NaiveBayes ## How was this patch tested? doctests Author: Holden Karau <[email protected]> Closes #12963 from holdenk/SPARK-15188-add-missing-naive-bayes-param. (cherry picked from commit d1aadea05ab1c7350e46479cc68d08e11916a751) Signed-off-by: Nick Pentreath <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b6b2c613 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b6b2c613 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b6b2c613 Branch: refs/heads/branch-2.0 Commit: b6b2c613847779daf2eec8122efdb5f2188fba76 Parents: 7b925e5 Author: Holden Karau <[email protected]> Authored: Fri May 13 08:39:59 2016 +0200 Committer: Nick Pentreath <[email protected]> Committed: Fri May 13 08:40:25 2016 +0200 ---------------------------------------------------------------------- python/pyspark/ml/classification.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b6b2c613/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index c26c2d7..5c11aa7 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -872,7 +872,7 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable) @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. 
`Multinomial NB @@ -918,6 +918,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H True >>> model.theta == model2.theta True + >>> nb = nb.setThresholds([0.01, 10.00]) + >>> model3 = nb.fit(df) + >>> result = model3.transform(test0).head() + >>> result.prediction + 0.0 .. versionadded:: 1.5.0 """ @@ -931,11 +936,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -948,11 +953,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
