Repository: spark
Updated Branches:
  refs/heads/master bdabfd43f -> 529d6ce8f


[SPARK-14181] TrainValidationSplit should have HasSeed

https://issues.apache.org/jira/browse/SPARK-14181

TrainValidationSplit should mix in HasSeed so that its random train/validation 
split can be seeded and reproduced. This change also replaces the RDD-based 
split (`dataset.rdd.randomSplit`) with the DataFrame-based `dataset.randomSplit`.

Author: Xusen Yin <[email protected]>

Closes #11985 from yinxusen/SPARK-14181.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/529d6ce8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/529d6ce8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/529d6ce8

Branch: refs/heads/master
Commit: 529d6ce8f96ef2b4a57c2d9066c7d80466e36209
Parents: bdabfd4
Author: Xusen Yin <[email protected]>
Authored: Wed Mar 30 14:32:29 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Wed Mar 30 14:32:29 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/tuning/TrainValidationSplit.scala       | 15 ++++++++++-----
 .../spark/ml/tuning/TrainValidationSplitSuite.scala  |  4 ++++
 2 files changed, 14 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/529d6ce8/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4d1d636..07330bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
  */
-private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
   /**
    * Param for ratio between train and validation data. Must be between 0 and 1.
    * Default: 0.75
@@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("1.5.0")
   def setTrainRatio(value: Double): this.type = set(trainRatio, value)
 
+  /** @group setParam */
+  @Since("2.0.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
   @Since("1.5.0")
   override def fit(dataset: DataFrame): TrainValidationSplitModel = {
     val schema = dataset.schema
@@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
 
-    val Array(training, validation) =
-      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
-    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
-    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+    val Array(trainingDataset, validationDataset) =
+      dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
+    trainingDataset.cache()
+    validationDataset.cache()
 
     // multi-model training
     logDebug(s"Train split with multiple sets of parameters.")

http://git-wip-us.apache.org/repos/asf/spark/blob/529d6ce8/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 7cf7b3e..4030956 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -48,6 +48,7 @@ class TrainValidationSplitSuite
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
+      .setSeed(42L)
     val cvModel = cv.fit(dataset)
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
     assert(cv.getTrainRatio === 0.5)
@@ -72,6 +73,7 @@ class TrainValidationSplitSuite
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
+      .setSeed(42L)
     val cvModel = cv.fit(dataset)
     val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
     assert(parent.getRegParam === 0.001)
@@ -120,6 +122,7 @@ class TrainValidationSplitSuite
       .setEvaluator(evaluator)
       .setTrainRatio(0.5)
       .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
@@ -140,6 +143,7 @@ class TrainValidationSplitSuite
       .set(tvs.evaluator, evaluator)
       .set(tvs.trainRatio, 0.5)
       .set(tvs.estimatorParamMaps, paramMaps)
+      .set(tvs.seed, 42L)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to