Repository: spark Updated Branches: refs/heads/master 585097716 -> 0874ff3aa
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import ## What changes were proposed in this pull request? Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests. ## How was this patch tested? Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor. Author: GayathriMurali <[email protected]> Closes #11892 from GayathriMurali/SPARK-13949. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0874ff3a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0874ff3a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0874ff3a Branch: refs/heads/master Commit: 0874ff3aade705a97f174b642c5db01711d214b3 Parents: 5850977 Author: GayathriMurali <[email protected]> Authored: Thu Mar 24 19:20:49 2016 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Thu Mar 24 19:20:49 2016 -0700 ---------------------------------------------------------------------- python/pyspark/ml/classification.py | 16 +++++++++++-- python/pyspark/ml/regression.py | 16 +++++++++++-- python/pyspark/ml/tests.py | 40 ++++++++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 850d775..d51b80e 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval, HasSeed): + TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model.transform(test1).head().prediction 1.0 + >>> dtc_path = temp_path + "/dtc" + >>> dt.save(dtc_path) + >>> dt2 = DecisionTreeClassifier.load(dtc_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtc_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + .. versionadded:: 1.4.0 """ @@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred @inherit_doc -class DecisionTreeClassificationModel(DecisionTreeModel): +class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeClassifier. http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 59d4fe3..3764854 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, - HasSeed): + HasSeed, JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> dtr_path = temp_path + "/dtr" + >>> dt.save(dtr_path) + >>> dt2 = DecisionTreeRegressor.load(dtr_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtr_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeRegressionModel.load(model_path) + >>> model.numNodes == model2.numNodes + True + >>> model.depth == model2.depth + True .. versionadded:: 1.4.0 """ @@ -498,7 +510,7 @@ class TreeEnsembleModels(JavaModel): @inherit_doc -class DecisionTreeRegressionModel(DecisionTreeModel): +class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeRegressor. http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2fa5da7..224232e 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -42,13 +42,13 @@ import tempfile import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.regression import LinearRegression +from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaWrapper @@ -655,6 +655,42 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + class HasThrowableProperty(Params): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
