This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a3cf9c5  [SPARK-30247][PYSPARK][FOLLOWUP] Add Python class MultivariateGaussian
a3cf9c5 is described below

commit a3cf9c564e74effe0f8457eaf9835ca0d3ab8be3
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Dec 27 13:30:18 2019 +0800

    [SPARK-30247][PYSPARK][FOLLOWUP] Add Python class MultivariateGaussian
    
    ### What changes were proposed in this pull request?
    Add a corresponding Python class MultivariateGaussian, holding a mean vector and a
    covariance matrix, so that the Gaussians of a GaussianMixtureModel can be used on the Python side.
    
    ### Does this PR introduce any user-facing change?
    Yes: adds the Python class `MultivariateGaussian`; `GaussianMixtureModel.gaussians` now returns a list of these objects.
    
    ### How was this patch tested?
    Added doctests for `MultivariateGaussian` and for `GaussianMixtureModel.gaussians`.
    
    Closes #27020 from huaxingao/spark-30247.
    
    Authored-by: Huaxin Gao <[email protected]>
    Signed-off-by: zhengruifeng <[email protected]>
---
 python/pyspark/ml/clustering.py | 36 +++++++++++++++++++++++++++++++++---
 python/pyspark/ml/stat.py       | 17 +++++++++++++++++
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index f784b8f..7295b76 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -22,7 +22,8 @@ from pyspark import since, keyword_only
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
 from pyspark.ml.param.shared import *
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.common import inherit_doc, _java2py
+from pyspark.ml.stat import MultivariateGaussian
 from pyspark.sql import DataFrame
 
 __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
@@ -161,7 +162,11 @@ class GaussianMixtureModel(JavaModel, _GaussianMixtureParams, JavaMLWritable, Ja
         Array of :py:class:`MultivariateGaussian` where gaussians[i] represents
         the Multivariate Gaussian (Normal) Distribution for Gaussian i
         """
-        return self._call_java("gaussians")
+        sc = SparkContext._active_spark_context
+        jgaussians = self._java_obj.gaussians()
+        return [
+            MultivariateGaussian(_java2py(sc, jgaussian.mean()), _java2py(sc, jgaussian.cov()))
+            for jgaussian in jgaussians]
 
     @property
     @since("2.0.0")
@@ -263,6 +268,21 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
     >>> gaussians = model.gaussians
     >>> len(gaussians)
     3
+    >>> gaussians[0].mean
+    DenseVector([0.825, 0.8675])
+    >>> gaussians[0].cov.toArray()
+    array([[ 0.005625  , -0.0050625 ],
+           [-0.0050625 ,  0.00455625]])
+    >>> gaussians[1].mean
+    DenseVector([-0.4777, -0.4096])
+    >>> gaussians[1].cov.toArray()
+    array([[ 0.1679695 ,  0.13181786],
+           [ 0.13181786,  0.10524592]])
+    >>> gaussians[2].mean
+    DenseVector([-0.4473, -0.3853])
+    >>> gaussians[2].cov.toArray()
+    array([[ 0.16730412,  0.13112435],
+           [ 0.13112435,  0.10469614]])
     >>> model.gaussiansDF.select("mean").head()
     Row(mean=DenseVector([0.825, 0.8675]))
     >>> model.gaussiansDF.select("cov").head()
@@ -285,7 +305,17 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
     False
     >>> model2.weights == model.weights
     True
-    >>> model2.gaussians == model.gaussians
+    >>> model2.gaussians[0].mean == model.gaussians[0].mean
+    True
+    >>> model2.gaussians[0].cov == model.gaussians[0].cov
+    True
+    >>> model2.gaussians[1].mean == model.gaussians[1].mean
+    True
+    >>> model2.gaussians[1].cov == model.gaussians[1].cov
+    True
+    >>> model2.gaussians[2].mean == model.gaussians[2].mean
+    True
+    >>> model2.gaussians[2].cov == model.gaussians[2].cov
     True
     >>> model2.gaussiansDF.select("mean").head()
     Row(mean=DenseVector([0.825, 0.8675]))
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 8f2eadd..53a57af 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -19,6 +19,7 @@ import sys
 
 from pyspark import since, SparkContext
 from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.linalg import DenseMatrix, Vectors
 from pyspark.ml.wrapper import JavaWrapper, _jvm
 from pyspark.sql.column import Column, _to_seq
 from pyspark.sql.functions import lit
@@ -394,6 +395,22 @@ class SummaryBuilder(JavaWrapper):
         return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc))
 
 
+class MultivariateGaussian(object):
+    """Represents a (mean, cov) tuple
+
+    >>> m = MultivariateGaussian(Vectors.dense([11,12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
+    >>> (m.mean, m.cov.toArray())
+    (DenseVector([11.0, 12.0]), array([[ 1.,  5.],
+           [ 3.,  2.]]))
+
+    .. versionadded:: 3.0.0
+
+    """
+    def __init__(self, mean, cov):
+        self.mean = mean
+        self.cov = cov
+
+
 if __name__ == "__main__":
     import doctest
     import numpy


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to