Repository: spark Updated Branches: refs/heads/branch-2.0 6d82e0c1b -> 3801fb4f3
[SPARK-15610][ML] update error message for k in pca ## What changes were proposed in this pull request? Fix the misleading error message for the bound of `k` in `PCA`: the check itself is unchanged (`require(k <= numFeatures, ...)`), but its message changes from "source vector size is ... must be greater than k" to "source vector size ... must be no less than k", and `sources.first().size` is now computed once as `numFeatures`. BTW, remove unused import in `ml.ElementwiseProduct` ## How was this patch tested? manual tests Author: Zheng RuiFeng <[email protected]> Closes #13356 from zhengruifeng/fix_pca. (cherry picked from commit 9893dc975784551a62f65bbd709f8972e0204b2a) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3801fb4f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3801fb4f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3801fb4f Branch: refs/heads/branch-2.0 Commit: 3801fb4f35ba1ffb8dbaf8326eff927b738551f2 Parents: 6d82e0c Author: Zheng RuiFeng <[email protected]> Authored: Fri May 27 21:57:41 2016 -0500 Committer: Sean Owen <[email protected]> Committed: Fri May 27 21:57:48 2016 -0500 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 1 - mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3801fb4f/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 91989c3..9d2e60f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} 
import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.types.DataType http://git-wip-us.apache.org/repos/asf/spark/blob/3801fb4f/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 30c403e..15b7220 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { */ @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { - require(k <= sources.first().size, - s"source vector size is ${sources.first().size} must be greater than k=$k") + val numFeatures = sources.first().size + require(k <= numFeatures, + s"source vector size $numFeatures must be no less than k=$k") val mat = new RowMatrix(sources) val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) @@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { case m => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") - } val denseExplainedVariance = explainedVariance match { case dv: DenseVector => --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
