This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a32c92c  [SPARK-28140][MLLIB][PYTHON] Accept DataFrames in RowMatrix 
and IndexedRowMatrix constructors
a32c92c is described below

commit a32c92c0cd72046053879416d77510d1c57bafcd
Author: Henry D <[email protected]>
AuthorDate: Tue Jul 9 16:39:21 2019 -0500

    [SPARK-28140][MLLIB][PYTHON] Accept DataFrames in RowMatrix and IndexedRowMatrix constructors
    
    ## What changes were proposed in this pull request?
    
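    Allow the PySpark `RowMatrix` and `IndexedRowMatrix` constructors to accept a
    `DataFrame` in addition to an RDD. In both cases, the input `DataFrame` schema
    must contain only the columns required for the matrix object: a single vector
    column for `RowMatrix`, and a long index column followed by a vector column
    for `IndexedRowMatrix`.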
    
    ## How was this patch tested?
    
    Unit tests that verify:
    - `RowMatrix` and `IndexedRowMatrix` can be created from `DataFrame`s
    - If the `DataFrame` schema does not match these expectations, construction
      throws an `IllegalArgumentException`
    
    Please review https://spark.apache.org/contributing.html before opening a 
pull request.
    
    Closes #24953 from henrydavidge/row-matrix-df.
    
    Authored-by: Henry D <[email protected]>
    Signed-off-by: Sean Owen <[email protected]>
---
 .../spark/mllib/api/python/PythonMLLibAPI.scala    | 10 ++++++++++
 python/pyspark/mllib/linalg/distributed.py         | 15 ++++++++++----
 python/pyspark/mllib/tests/test_linalg.py          | 23 ++++++++++++++++++++++
 3 files changed, 44 insertions(+), 4 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 322ef93..4c478a5 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -53,6 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, 
GradientBoostedTree
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -1142,12 +1143,21 @@ private[python] class PythonMLLibAPI extends 
Serializable {
     new RowMatrix(rows.rdd, numRows, numCols)
   }
 
+  def createRowMatrix(df: DataFrame, numRows: Long, numCols: Int): RowMatrix = 
{
+    require(df.schema.length == 1 && df.schema.head.dataType.getClass == 
classOf[VectorUDT],
+      "DataFrame must have a single vector type column")
+    new RowMatrix(df.rdd.map { case Row(vector: Vector) => vector }, numRows, 
numCols)
+  }
+
   /**
    * Wrapper around IndexedRowMatrix constructor.
    */
   def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): 
IndexedRowMatrix = {
     // We use DataFrames for serialization of IndexedRows from Python,
     // so map each Row in the DataFrame back to an IndexedRow.
+    require(rows.schema.length == 2 && rows.schema.head.dataType == LongType &&
+      rows.schema(1).dataType.getClass == classOf[VectorUDT],
+      "DataFrame must consist of a long type index column and a vector type 
column")
     val indexedRows = rows.rdd.map {
       case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
     }
diff --git a/python/pyspark/mllib/linalg/distributed.py 
b/python/pyspark/mllib/linalg/distributed.py
index b7f0978..5670175 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -30,6 +30,7 @@ from pyspark import RDD, since
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, 
QRDecomposition
 from pyspark.mllib.stat import MultivariateStatisticalSummary
+from pyspark.sql import DataFrame
 from pyspark.storagelevel import StorageLevel
 
 
@@ -57,7 +58,8 @@ class RowMatrix(DistributedMatrix):
     Represents a row-oriented distributed Matrix with no meaningful
     row indices.
 
-    :param rows: An RDD of vectors.
+    :param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, 
it must have a single
+                 vector typed column.
     :param numRows: Number of rows in the matrix. A non-positive
                     value means unknown, at which point the number
                     of rows will be determined by the number of
@@ -73,7 +75,7 @@ class RowMatrix(DistributedMatrix):
 
         Create a wrapper over a Java RowMatrix.
 
-        Publicly, we require that `rows` be an RDD.  However, for
+        Publicly, we require that `rows` be an RDD or DataFrame.  However, for
         internal usage, `rows` can also be a Java RowMatrix
         object, in which case we can wrap it directly.  This
         assists in clean matrix conversions.
@@ -94,6 +96,8 @@ class RowMatrix(DistributedMatrix):
         if isinstance(rows, RDD):
             rows = rows.map(_convert_to_vector)
             java_matrix = callMLlibFunc("createRowMatrix", rows, 
long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createRowMatrix", rows, 
long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "RowMatrix"):
             java_matrix = rows
@@ -461,7 +465,8 @@ class IndexedRowMatrix(DistributedMatrix):
     """
     Represents a row-oriented distributed Matrix with indexed rows.
 
-    :param rows: An RDD of IndexedRows or (long, vector) tuples.
+    :param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame 
consisting of a
+                 long typed column of indices and a vector typed column.
     :param numRows: Number of rows in the matrix. A non-positive
                     value means unknown, at which point the number
                     of rows will be determined by the max row
@@ -477,7 +482,7 @@ class IndexedRowMatrix(DistributedMatrix):
 
         Create a wrapper over a Java IndexedRowMatrix.
 
-        Publicly, we require that `rows` be an RDD.  However, for
+        Publicly, we require that `rows` be an RDD or DataFrame.  However, for
         internal usage, `rows` can also be a Java IndexedRowMatrix
         object, in which case we can wrap it directly.  This
         assists in clean matrix conversions.
@@ -506,6 +511,8 @@ class IndexedRowMatrix(DistributedMatrix):
             # IndexedRows on the Scala side.
             java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(),
                                         long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, 
long(numRows), int(numCols))
         elif (isinstance(rows, JavaObject)
               and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
             java_matrix = rows
diff --git a/python/pyspark/mllib/tests/test_linalg.py 
b/python/pyspark/mllib/tests/test_linalg.py
index b3ab25f..312730e 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -25,10 +25,15 @@ import pyspark.ml.linalg as newlinalg
 from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, 
_convert_to_vector, \
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import Row
 from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.testing.utils import have_scipy
 
+if sys.version >= '3':
+    long = int
+
 
 class VectorTests(MLlibTestCase):
 
@@ -431,6 +436,24 @@ class VectorUDTTests(MLlibTestCase):
             else:
                 raise TypeError("expecting a vector but got %r of type %r" % 
(v, type(v)))
 
+    def test_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(Vectors.dense(1))])
+        row_matrix = RowMatrix(df)
+        self.assertEqual(row_matrix.numRows(), 1)
+        self.assertEqual(row_matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            RowMatrix(df.selectExpr("'monkey'"))
+
+    def test_indexed_row_matrix_from_dataframe(self):
+        from pyspark.sql.utils import IllegalArgumentException
+        df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))])
+        matrix = IndexedRowMatrix(df)
+        self.assertEqual(matrix.numRows(), 1)
+        self.assertEqual(matrix.numCols(), 1)
+        with self.assertRaises(IllegalArgumentException):
+            IndexedRowMatrix(df.drop("_1"))
+
 
 class MatrixUDTTests(MLlibTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to