Repository: spark Updated Branches: refs/heads/branch-1.6 496496b25 -> 5e53d4a8d
[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage. Pipeline.setStages fails for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. The cause is that Scala's Array type is invariant (unlike Java arrays, which are covariant), and the addition of MLWritable to some transformers means that arrays such as stages0/stages1 in the new unit test below are no longer inferred as Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage: * Pipeline.setStages() * Params.w() (in this branch-1.6 backport only Pipeline.setStages() is changed; see the diffstat below). Adds a unit test which fails to compile without this fix. Author: Joseph K. Bradley <[email protected]> Closes #12430 from jkbradley/pipeline-setstages. (cherry picked from commit f5ebb18c45ffdee2756a80f64239cb9158df1a11) Signed-off-by: Joseph K. Bradley <[email protected]> Conflicts: mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e53d4a8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e53d4a8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e53d4a8 Branch: refs/heads/branch-1.6 Commit: 5e53d4a8dc68390d0cc2722fc4a5b4f341b8125f Parents: 496496b Author: Joseph K. Bradley <[email protected]> Authored: Wed Apr 27 16:11:12 2016 -0700 Committer: Joseph K.
Bradley <[email protected]> Committed: Wed Apr 27 16:15:46 2016 -0700 ---------------------------------------------------------------------- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 5 ++++- mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5e53d4a8/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 4b2b3f8..eb57ac8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -99,7 +99,10 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") /** @group setParam */ - def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def setStages(value: Array[_ <: PipelineStage]): this.type = { + set(stages, value.asInstanceOf[Array[PipelineStage]]) + this + } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. 
http://git-wip-us.apache.org/repos/asf/spark/blob/5e53d4a8/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 8c86767..9749df6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -174,6 +174,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("Pipeline.setStages should handle Java Arrays being non-covariant") { + val stages0 = Array(new UnWritableStage("b")) + val stages1 = Array(new WritableStage("a")) + val steps = stages0 ++ stages1 + val p = new Pipeline().setStages(steps) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
