Repository: spark Updated Branches: refs/heads/master 1e5c51f33 -> 7db09abb0
[SPARK-18356][ML] KMeans should cache RDD before training ## What changes were proposed in this pull request? At the request of Mr. Joseph Bradley, I updated my PR https://github.com/apache/spark/pull/15965 to eliminate the extra fit() method. jkbradley ## How was this patch tested? Existing tests pass. Author: Zakaria_Hili <[email protected]> Author: HILI Zakaria <[email protected]> Closes #16295 from ZakariaHili/zakbranch. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7db09abb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7db09abb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7db09abb Branch: refs/heads/master Commit: 7db09abb0168b77697064c69126ee82ca89609a0 Parents: 1e5c51f Author: Zakaria_Hili <[email protected]> Authored: Mon Dec 19 10:30:38 2016 +0000 Committer: Sean Owen <[email protected]> Committed: Mon Dec 19 10:30:38 2016 +0000 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7db09abb/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e168a41..e02b532 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -302,22 +302,19 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE - fit(dataset, handlePersistence) - } - - 
@Since("2.2.0") - protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) + + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } + if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) } + val instr = Instrumentation.create(this, instances) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) - val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -329,6 +326,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(Some(summary)) instr.logSuccess(model) if (handlePersistence) { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
