Repository: spark Updated Branches: refs/heads/master bdabfd43f -> 529d6ce8f
[SPARK-14181] TrainValidationSplit should have HasSeed https://issues.apache.org/jira/browse/SPARK-14181 TrainValidationSplit should have HasSeed for the random split of RDD. I also changed the random split from the RDD function to the DataFrame function. Author: Xusen Yin <[email protected]> Closes #11985 from yinxusen/SPARK-14181. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/529d6ce8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/529d6ce8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/529d6ce8 Branch: refs/heads/master Commit: 529d6ce8f96ef2b4a57c2d9066c7d80466e36209 Parents: bdabfd4 Author: Xusen Yin <[email protected]> Authored: Wed Mar 30 14:32:29 2016 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Wed Mar 30 14:32:29 2016 -0700 ---------------------------------------------------------------------- .../spark/ml/tuning/TrainValidationSplit.scala | 15 ++++++++++----- .../spark/ml/tuning/TrainValidationSplitSuite.scala | 4 ++++ 2 files changed, 14 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/529d6ce8/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 4d1d636..07330bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. */ -private[ml] trait TrainValidationSplitParams extends ValidatorParams { +private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 @@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema @@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val numModels = epm.length val metrics = new Array[Double](epm.length) - val Array(training, validation) = - dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) - val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() - val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + val Array(trainingDataset, validationDataset) = + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + trainingDataset.cache() + validationDataset.cache() // multi-model training logDebug(s"Train split with multiple sets of parameters.") http://git-wip-us.apache.org/repos/asf/spark/blob/529d6ce8/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 7cf7b3e..4030956 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -48,6 +48,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -72,6 +73,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) @@ -120,6 +122,7 @@ class TrainValidationSplitSuite .setEvaluator(evaluator) .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) @@ -140,6 +143,7 @@ class TrainValidationSplitSuite .set(tvs.evaluator, evaluator) .set(tvs.trainRatio, 0.5) .set(tvs.estimatorParamMaps, paramMaps) + .set(tvs.seed, 42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
