Repository: spark
Updated Branches:
  refs/heads/master 1e5c51f33 -> 7db09abb0


[SPARK-18356][ML] KMeans should cache RDD before training

## What changes were proposed in this pull request?

At the request of Mr. Joseph Bradley, I updated my PR 
https://github.com/apache/spark/pull/15965 in order to eliminate the extra 
fit() method.

jkbradley
## How was this patch tested?
Passes existing tests.

Author: Zakaria_Hili <[email protected]>
Author: HILI Zakaria <[email protected]>

Closes #16295 from ZakariaHili/zakbranch.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7db09abb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7db09abb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7db09abb

Branch: refs/heads/master
Commit: 7db09abb0168b77697064c69126ee82ca89609a0
Parents: 1e5c51f
Author: Zakaria_Hili <[email protected]>
Authored: Mon Dec 19 10:30:38 2016 +0000
Committer: Sean Owen <[email protected]>
Committed: Mon Dec 19 10:30:38 2016 +0000

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/clustering/KMeans.scala   | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7db09abb/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e168a41..e02b532 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -302,22 +302,19 @@ class KMeans @Since("1.5.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    fit(dataset, handlePersistence)
-  }
-
-  @Since("2.2.0")
-  protected def fit(dataset: Dataset[_], handlePersistence: Boolean): 
KMeansModel = {
     transformSchema(dataset.schema, logging = true)
+
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     val instances: RDD[OldVector] = 
dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
+
     if (handlePersistence) {
       instances.persist(StorageLevel.MEMORY_AND_DISK)
     }
+
     val instr = Instrumentation.create(this, instances)
     instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, 
maxIter, seed, tol)
-
     val algo = new MLlibKMeans()
       .setK($(k))
       .setInitializationMode($(initMode))
@@ -329,6 +326,7 @@ class KMeans @Since("1.5.0") (
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
     val summary = new KMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
+
     model.setSummary(Some(summary))
     instr.logSuccess(model)
     if (handlePersistence) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to