Repository: spark
Updated Branches:
  refs/heads/master 585097716 -> 0874ff3aa


[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import

## What changes were proposed in this pull request?

Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. 
Added doctests.

## How was this patch tested?

Python Unit tests. Tests added to check persistence in DecisionTreeClassifier 
and DecisionTreeRegressor.

Author: GayathriMurali <[email protected]>

Closes #11892 from GayathriMurali/SPARK-13949.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0874ff3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0874ff3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0874ff3a

Branch: refs/heads/master
Commit: 0874ff3aade705a97f174b642c5db01711d214b3
Parents: 5850977
Author: GayathriMurali <[email protected]>
Authored: Thu Mar 24 19:20:49 2016 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Thu Mar 24 19:20:49 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 16 +++++++++++--
 python/pyspark/ml/regression.py     | 16 +++++++++++--
 python/pyspark/ml/tests.py          | 40 ++++++++++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 850d775..d51b80e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams):
 @inherit_doc
 class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                              HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
-                             TreeClassifierParams, HasCheckpointInterval, HasSeed):
+                             TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
+                             JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for classification.
@@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> model.transform(test1).head().prediction
     1.0
 
+    >>> dtc_path = temp_path + "/dtc"
+    >>> dt.save(dtc_path)
+    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
+    >>> dt2.getMaxDepth()
+    2
+    >>> model_path = temp_path + "/dtc_model"
+    >>> model.save(model_path)
+    >>> model2 = DecisionTreeClassificationModel.load(model_path)
+    >>> model.featureImportances == model2.featureImportances
+    True
+
     .. versionadded:: 1.4.0
     """
 
@@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
 
 @inherit_doc
-class DecisionTreeClassificationModel(DecisionTreeModel):
+class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by DecisionTreeClassifier.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 59d4fe3..3764854 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams):
 @inherit_doc
 class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                             DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
-                            HasSeed):
+                            HasSeed, JavaMLWritable, JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for regression.
@@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> dtr_path = temp_path + "/dtr"
+    >>> dt.save(dtr_path)
+    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
+    >>> dt2.getMaxDepth()
+    2
+    >>> model_path = temp_path + "/dtr_model"
+    >>> model.save(model_path)
+    >>> model2 = DecisionTreeRegressionModel.load(model_path)
+    >>> model.numNodes == model2.numNodes
+    True
+    >>> model.depth == model2.depth
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -498,7 +510,7 @@ class TreeEnsembleModels(JavaModel):
 
 
 @inherit_doc
-class DecisionTreeRegressionModel(DecisionTreeModel):
+class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by DecisionTreeRegressor.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0874ff3a/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2fa5da7..224232e 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -42,13 +42,13 @@ import tempfile
 import numpy as np
 
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.regression import LinearRegression
+from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.wrapper import JavaWrapper
@@ -655,6 +655,42 @@ class PersistenceTest(PySparkTestCase):
             except OSError:
                 pass
 
+    def test_decisiontree_classifier(self):
+        dt = DecisionTreeClassifier(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtc_path = path + "/dtc"
+        dt.save(dtc_path)
+        dt2 = DecisionTreeClassifier.load(dtc_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeClassifier instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeClassifier instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_decisiontree_regressor(self):
+        dt = DecisionTreeRegressor(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtr_path = path + "/dtr"
+        dt.save(dtr_path)
+        dt2 = DecisionTreeRegressor.load(dtr_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeRegressor instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeRegressor instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
 
 class HasThrowableProperty(Params):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to