Repository: spark
Updated Branches:
  refs/heads/branch-1.6 496496b25 -> 5e53d4a8d


[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage

Pipeline.setStages fails for some code examples that worked in 1.5 but fail
in 1.6. This tends to occur when using a mix of transformers from ml.feature.
The cause is that arrays in Scala (backed by Java arrays) are invariant: since
MLWritable was added to some transformers, an array built from a mix of stages
(e.g. stages0 ++ stages1 in the unit test below) is not inferred as
Array[PipelineStage] and therefore cannot be passed where an
Array[PipelineStage] is required. This PR modifies the following to accept
subclasses of PipelineStage:
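* Pipeline.setStages()
* Params.w()

Note: this branch-1.6 backport changes only Pipeline.setStages (in
Pipeline.scala) and PipelineSuite.scala, as the diffstat below shows; the
Params.w() change is not part of this patch.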

Adds a unit test that fails to compile without this fix.

Author: Joseph K. Bradley <[email protected]>

Closes #12430 from jkbradley/pipeline-setstages.

(cherry picked from commit f5ebb18c45ffdee2756a80f64239cb9158df1a11)
Signed-off-by: Joseph K. Bradley <[email protected]>

Conflicts:
        mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
        mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e53d4a8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e53d4a8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e53d4a8

Branch: refs/heads/branch-1.6
Commit: 5e53d4a8dc68390d0cc2722fc4a5b4f341b8125f
Parents: 496496b
Author: Joseph K. Bradley <[email protected]>
Authored: Wed Apr 27 16:11:12 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Wed Apr 27 16:15:46 2016 -0700

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala      | 5 ++++-
 mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 7 +++++++
 2 files changed, 11 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e53d4a8/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 4b2b3f8..eb57ac8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -99,7 +99,10 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
   val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
 
   /** @group setParam */
-  def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
+  def setStages(value: Array[_ <: PipelineStage]): this.type = {
+    set(stages, value.asInstanceOf[Array[PipelineStage]])
+    this
+  }
 
   // Below, we clone stages so that modifications to the list of stages will not change
   // the Param value in the Pipeline.

http://git-wip-us.apache.org/repos/asf/spark/blob/5e53d4a8/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 8c86767..9749df6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -174,6 +174,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       }
     }
   }
+
+  test("Pipeline.setStages should handle Java Arrays being non-covariant") {
+    val stages0 = Array(new UnWritableStage("b"))
+    val stages1 = Array(new WritableStage("a"))
+    val steps = stages0 ++ stages1
+    val p = new Pipeline().setStages(steps)
+  }
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to