Repository: spark
Updated Branches:
  refs/heads/master 63036aee2 -> 56d3a7eb8


[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

## What changes were proposed in this pull request?

mllib.KMeansModel.clusterCentersWithNorm is a method that ends up being called 
every time `predict` is called on a single vector, which is bad news for how 
the ml.KMeansModel Transformer works, which necessarily transforms one vector 
at a time.

This change makes the model store the vectors with their norms upfront. The 
extra norm should be small compared to the vectors. This avoids this form of 
overhead on this and other code paths.

## How was this patch tested?

Existing tests.

Author: Sean Owen <[email protected]>

Closes #16328 from srowen/SPARK-18808.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56d3a7eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56d3a7eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56d3a7eb

Branch: refs/heads/master
Commit: 56d3a7eb83f9c91d06dab2c91e10569723eeb105
Parents: 63036ae
Author: Sean Owen <[email protected]>
Authored: Fri Dec 30 10:40:17 2016 +0000
Committer: Sean Owen <[email protected]>
Committed: Fri Dec 30 10:40:17 2016 +0000

----------------------------------------------------------------------
 .../spark/mllib/clustering/KMeansModel.scala       | 17 ++++++++---------
 .../spark/mllib/clustering/StreamingKMeans.scala   |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56d3a7eb/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149..df2a9c0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: 
Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new 
VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val 
clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val 
clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new 
VectorWithNorm(p))._1)
   }
 
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val 
clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new 
VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" 
-> model.k)))
       sc.parallelize(Seq(metadata), 
1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { 
case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = 
sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56d3a7eb/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c4..3ca75e8 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to