Repository: spark
Updated Branches:
  refs/heads/master b88b868eb -> 198d181df


[SPARK-7105] [PYSPARK] [MLLIB] Support model save/load in GMM

This PR adds save/load support for Gaussian mixture models (GMMs) in the Python API.

It also refactors the Python `GaussianMixtureModel` to inherit from
`JavaModelWrapper`, with the wrapped model being `GaussianMixtureModelWrapper`, a
new Scala wrapper that provides convenience methods around `GaussianMixtureModel`
(needed because of serialization and deserialization issues). The creation of the
Gaussians is moved to the Scala backend.

Author: MechCoder <[email protected]>

Closes #7617 from MechCoder/python_gmm_save_load and squashes the following commits:

9c305aa [MechCoder] [SPARK-7105] [PySpark] [MLlib] Support model save/load in GMM


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/198d181d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/198d181d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/198d181d

Branch: refs/heads/master
Commit: 198d181dfb2c04102afe40680a4637d951e92c0b
Parents: b88b868
Author: MechCoder <[email protected]>
Authored: Tue Jul 28 15:00:25 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Jul 28 15:00:25 2015 -0700

----------------------------------------------------------------------
 .../python/GaussianMixtureModelWrapper.scala    | 53 ++++++++++++++
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 13 +---
 python/pyspark/mllib/clustering.py              | 75 ++++++++++++++------
 python/pyspark/mllib/util.py                    |  6 ++
 4 files changed, 114 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/198d181d/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
new file mode 100644
index 0000000..0ec88ef
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix}
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
+
+/**
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python. Its members
+ * expose the model as Vectors, Matrices and Java Lists that the Python API can consume.
+ */
+private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
+  /** Mixing weights of the components, copied into a dense Vector. */
+  val weights: Vector = Vectors.dense(model.weights)
+  /** Number of components in the mixture (the number of weights). */
+  val k: Int = weights.size
+
+  /**
+   * Returns gaussians as a List of two arrays corresponding to each MultivariateGaussian:
+   * the means (Vectors) and the covariances (Matrices). Index i of each array belongs to
+   * gaussian i, so the Python side can zip them back into MultivariateGaussian objects.
+   */
+  val gaussians: JList[Object] = {
+    val modelGaussians = model.gaussians
+    val mus: Array[Vector] = Array.tabulate(k)(i => modelGaussians(i).mu)
+    val sigmas: Array[Matrix] = Array.tabulate(k)(i => modelGaussians(i).sigma)
+    List[Object](mus, sigmas).asJava
+  }
+
+  /** Saves the wrapped model to `path` by delegating to GaussianMixtureModel.save. */
+  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/198d181d/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fda8d5a..6f080d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       seed: java.lang.Long,
       initialModelWeights: java.util.ArrayList[Double],
       initialModelMu: java.util.ArrayList[Vector],
-      initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = {
+      initialModelSigma: java.util.ArrayList[Matrix]): 
GaussianMixtureModelWrapper = {
     val gmmAlg = new GaussianMixture()
       .setK(k)
       .setConvergenceTol(convergenceTol)
@@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable {
     if (seed != null) gmmAlg.setSeed(seed)
 
     try {
-      val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
-      var wt = ArrayBuffer.empty[Double]
-      var mu = ArrayBuffer.empty[Vector]
-      var sigma = ArrayBuffer.empty[Matrix]
-      for (i <- 0 until model.k) {
-          wt += model.weights(i)
-          mu += model.gaussians(i).mu
-          sigma += model.gaussians(i).sigma
-      }
-      List(Vectors.dense(wt.toArray), mu.toArray, 
sigma.toArray).map(_.asInstanceOf[Object]).asJava
+      new 
GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)))
     } finally {
       data.rdd.unpersist(blocking = false)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/198d181d/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 58ad99d..900ade2 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -152,11 +152,19 @@ class KMeans(object):
         return KMeansModel([c.toArray() for c in centers])
 
 
-class GaussianMixtureModel(object):
+@inherit_doc
+class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
+
+    """
+    .. note:: Experimental
 
-    """A clustering model derived from the Gaussian Mixture Model method.
+    A clustering model derived from the Gaussian Mixture Model method.
 
     >>> from pyspark.mllib.linalg import Vectors, DenseMatrix
+    >>> from numpy.testing import assert_equal
+    >>> from shutil import rmtree
+    >>> import os, tempfile
+
     >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
     ...                                         0.9,0.8,0.75,0.935,
     ...                                        -0.83,-0.68,-0.91,-0.76 
]).reshape(6, 2))
@@ -169,6 +177,25 @@ class GaussianMixtureModel(object):
     True
     >>> labels[4]==labels[5]
     True
+
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = GaussianMixtureModel.load(sc, path)
+    >>> assert_equal(model.weights, sameModel.weights)
+    >>> mus, sigmas = list(
+    ...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
+    >>> sameMus, sameSigmas = list(
+    ...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
+    >>> mus == sameMus
+    True
+    >>> sigmas == sameSigmas
+    True
+    >>> from shutil import rmtree
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
+
     >>> data =  array([-5.1971, -2.5359, -3.8220,
     ...                -5.2211, -5.0602,  4.7118,
     ...                 6.8989, 3.4592,  4.6322,
@@ -182,25 +209,15 @@ class GaussianMixtureModel(object):
     True
     >>> labels[3]==labels[4]
     True
-    >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
-    >>> im = GaussianMixtureModel([0.5, 0.5],
-    ...      [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, 
[1.0])),
-    ...      MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, 
[1.0]))])
-    >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
     """
 
-    def __init__(self, weights, gaussians):
-        self._weights = weights
-        self._gaussians = gaussians
-        self._k = len(self._weights)
-
     @property
     def weights(self):
         """
         Weights for each Gaussian distribution in the mixture, where 
weights[i] is
         the weight for Gaussian i, and weights.sum == 1.
         """
-        return self._weights
+        return array(self.call("weights"))
 
     @property
     def gaussians(self):
@@ -208,12 +225,14 @@ class GaussianMixtureModel(object):
         Array of MultivariateGaussian where gaussians[i] represents
         the Multivariate Gaussian (Normal) Distribution for Gaussian i.
         """
-        return self._gaussians
+        return [
+            MultivariateGaussian(gaussian[0], gaussian[1])
+            for gaussian in zip(*self.call("gaussians"))]
 
     @property
     def k(self):
         """Number of gaussians in mixture."""
-        return self._k
+        return len(self.weights)
 
     def predict(self, x):
         """
@@ -238,17 +257,30 @@ class GaussianMixtureModel(object):
         :return:     membership_matrix. RDD of array of double values.
         """
         if isinstance(x, RDD):
-            means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians])
+            means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
             membership_matrix = callMLlibFunc("predictSoftGMM", 
x.map(_convert_to_vector),
-                                              
_convert_to_vector(self._weights), means, sigmas)
+                                              
_convert_to_vector(self.weights), means, sigmas)
             return membership_matrix.map(lambda x: pyarray.array('d', x))
         else:
             raise TypeError("x should be represented by an RDD, "
                             "but got %s." % type(x))
 
+    @classmethod
+    def load(cls, sc, path):
+        """Load a GaussianMixtureModel from `path` (e.g. one written by `save`).
+
+        :param sc: SparkContext
+        :param path: str, path to where the model is stored.
+        """
+        java_model = cls._load_java(sc, path)
+        # Expose the JVM model to Python through the helper wrapper class.
+        return cls(sc._jvm.GaussianMixtureModelWrapper(java_model))
+
 
 class GaussianMixture(object):
     """
+    .. note:: Experimental
+
     Learning algorithm for Gaussian Mixtures using the 
expectation-maximization algorithm.
 
     :param data:            RDD of data points
@@ -271,11 +303,10 @@ class GaussianMixture(object):
             initialModelWeights = initialModel.weights
             initialModelMu = [initialModel.gaussians[i].mu for i in 
range(initialModel.k)]
             initialModelSigma = [initialModel.gaussians[i].sigma for i in 
range(initialModel.k)]
-        weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", 
rdd.map(_convert_to_vector),
-                                          k, convergenceTol, maxIterations, 
seed,
-                                          initialModelWeights, initialModelMu, 
initialModelSigma)
-        mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
-        return GaussianMixtureModel(weight, mvg_obj)
+        java_model = callMLlibFunc("trainGaussianMixtureModel", 
rdd.map(_convert_to_vector),
+                                   k, convergenceTol, maxIterations, seed,
+                                   initialModelWeights, initialModelMu, 
initialModelSigma)
+        return GaussianMixtureModel(java_model)
 
 
 class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, 
JavaLoader):

http://git-wip-us.apache.org/repos/asf/spark/blob/198d181d/python/pyspark/mllib/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 875d3b2..916de2d 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -21,7 +21,9 @@ import warnings
 
 if sys.version > '3':
     xrange = range
+    basestring = str
 
+from pyspark import SparkContext
 from pyspark.mllib.common import callMLlibFunc, inherit_doc
 from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
 
@@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
     """
 
     def save(self, sc, path):
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % 
type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % 
type(path))
         self._java_model.save(sc._jsc.sc(), path)
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to