This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a32c92c [SPARK-28140][MLLIB][PYTHON] Accept DataFrames in RowMatrix
and IndexedRowMatrix constructors
a32c92c is described below
commit a32c92c0cd72046053879416d77510d1c57bafcd
Author: Henry D <[email protected]>
AuthorDate: Tue Jul 9 16:39:21 2019 -0500
[SPARK-28140][MLLIB][PYTHON] Accept DataFrames in RowMatrix and
IndexedRowMatrix constructors
## What changes were proposed in this pull request?
In both cases, the input `DataFrame` schema must contain only the
information that's required for the matrix object, so a vector column in the
case of `RowMatrix` and long and vector columns for `IndexedRowMatrix`.
## How was this patch tested?
Unit tests that verify:
- `RowMatrix` and `IndexedRowMatrix` can be created from `DataFrame`s
- If the schema does not match expectations, we throw an
`IllegalArgumentException`
Please review https://spark.apache.org/contributing.html before opening a
pull request.
Closes #24953 from henrydavidge/row-matrix-df.
Authored-by: Henry D <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
---
.../spark/mllib/api/python/PythonMLLibAPI.scala | 10 ++++++++++
python/pyspark/mllib/linalg/distributed.py | 15 ++++++++++----
python/pyspark/mllib/tests/test_linalg.py | 23 ++++++++++++++++++++++
3 files changed, 44 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 322ef93..4c478a5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -53,6 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types.LongType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -1142,12 +1143,21 @@ private[python] class PythonMLLibAPI extends Serializable {
new RowMatrix(rows.rdd, numRows, numCols)
}
+  def createRowMatrix(df: DataFrame, numRows: Long, numCols: Int): RowMatrix = {
+    require(df.schema.length == 1 && df.schema.head.dataType.getClass == classOf[VectorUDT],
+      "DataFrame must have a single vector type column")
+    new RowMatrix(df.rdd.map { case Row(vector: Vector) => vector }, numRows, numCols)
+  }
+
/**
* Wrapper around IndexedRowMatrix constructor.
*/
   def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
     // We use DataFrames for serialization of IndexedRows from Python,
     // so map each Row in the DataFrame back to an IndexedRow.
+    require(rows.schema.length == 2 && rows.schema.head.dataType == LongType &&
+      rows.schema(1).dataType.getClass == classOf[VectorUDT],
+      "DataFrame must consist of a long type index column and a vector type column")
val indexedRows = rows.rdd.map {
case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
}
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index b7f0978..5670175 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -30,6 +30,7 @@ from pyspark import RDD, since
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
from pyspark.mllib.stat import MultivariateStatisticalSummary
+from pyspark.sql import DataFrame
from pyspark.storagelevel import StorageLevel
@@ -57,7 +58,8 @@ class RowMatrix(DistributedMatrix):
Represents a row-oriented distributed Matrix with no meaningful
row indices.
- :param rows: An RDD of vectors.
+    :param rows: An RDD or DataFrame of vectors. If a DataFrame is provided, it must have a single
+                 vector typed column.
:param numRows: Number of rows in the matrix. A non-positive
value means unknown, at which point the number
of rows will be determined by the number of
@@ -73,7 +75,7 @@ class RowMatrix(DistributedMatrix):
Create a wrapper over a Java RowMatrix.
- Publicly, we require that `rows` be an RDD. However, for
+ Publicly, we require that `rows` be an RDD or DataFrame. However, for
internal usage, `rows` can also be a Java RowMatrix
object, in which case we can wrap it directly. This
assists in clean matrix conversions.
@@ -94,6 +96,8 @@ class RowMatrix(DistributedMatrix):
if isinstance(rows, RDD):
rows = rows.map(_convert_to_vector)
             java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createRowMatrix", rows, long(numRows), int(numCols))
elif (isinstance(rows, JavaObject)
and rows.getClass().getSimpleName() == "RowMatrix"):
java_matrix = rows
@@ -461,7 +465,8 @@ class IndexedRowMatrix(DistributedMatrix):
"""
Represents a row-oriented distributed Matrix with indexed rows.
- :param rows: An RDD of IndexedRows or (long, vector) tuples.
+    :param rows: An RDD of IndexedRows or (long, vector) tuples or a DataFrame consisting of a
+                 long typed column of indices and a vector typed column.
:param numRows: Number of rows in the matrix. A non-positive
value means unknown, at which point the number
of rows will be determined by the max row
@@ -477,7 +482,7 @@ class IndexedRowMatrix(DistributedMatrix):
Create a wrapper over a Java IndexedRowMatrix.
- Publicly, we require that `rows` be an RDD. However, for
+ Publicly, we require that `rows` be an RDD or DataFrame. However, for
internal usage, `rows` can also be a Java IndexedRowMatrix
object, in which case we can wrap it directly. This
assists in clean matrix conversions.
@@ -506,6 +511,8 @@ class IndexedRowMatrix(DistributedMatrix):
# IndexedRows on the Scala side.
             java_matrix = callMLlibFunc("createIndexedRowMatrix", rows.toDF(), long(numRows), int(numCols))
+        elif isinstance(rows, DataFrame):
+            java_matrix = callMLlibFunc("createIndexedRowMatrix", rows, long(numRows), int(numCols))
elif (isinstance(rows, JavaObject)
and rows.getClass().getSimpleName() == "IndexedRowMatrix"):
java_matrix = rows
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
index b3ab25f..312730e 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -25,10 +25,15 @@ import pyspark.ml.linalg as newlinalg
from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix, IndexedRowMatrix
from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import Row
from pyspark.testing.mllibutils import MLlibTestCase
from pyspark.testing.utils import have_scipy
+if sys.version >= '3':
+ long = int
+
class VectorTests(MLlibTestCase):
@@ -431,6 +436,24 @@ class VectorUDTTests(MLlibTestCase):
else:
raise TypeError("expecting a vector but got %r of type %r" %
(v, type(v)))
+ def test_row_matrix_from_dataframe(self):
+ from pyspark.sql.utils import IllegalArgumentException
+ df = self.spark.createDataFrame([Row(Vectors.dense(1))])
+ row_matrix = RowMatrix(df)
+ self.assertEqual(row_matrix.numRows(), 1)
+ self.assertEqual(row_matrix.numCols(), 1)
+ with self.assertRaises(IllegalArgumentException):
+ RowMatrix(df.selectExpr("'monkey'"))
+
+ def test_indexed_row_matrix_from_dataframe(self):
+ from pyspark.sql.utils import IllegalArgumentException
+ df = self.spark.createDataFrame([Row(long(0), Vectors.dense(1))])
+ matrix = IndexedRowMatrix(df)
+ self.assertEqual(matrix.numRows(), 1)
+ self.assertEqual(matrix.numCols(), 1)
+ with self.assertRaises(IllegalArgumentException):
+ IndexedRowMatrix(df.drop("_1"))
+
class MatrixUDTTests(MLlibTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]