Repository: spark Updated Branches: refs/heads/branch-2.1 4fcecb4cf -> 42777b1b3
[SPARK-17462][MLLIB] use VersionUtils to parse Spark version strings ## What changes were proposed in this pull request? Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use VersionUtils. This PR replaces custom regexes with VersionUtils to get Spark version numbers. ## How was this patch tested? Existing tests. Signed-off-by: VinceShieh <vincent.xie@intel.com> Author: VinceShieh <[email protected]> Closes #15055 from VinceShieh/SPARK-17462. (cherry picked from commit de77c67750dc868d75d6af173c3820b75a9fe4b7) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42777b1b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42777b1b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42777b1b Branch: refs/heads/branch-2.1 Commit: 42777b1b3c10d3945494e27f1dedd43f2f836361 Parents: 4fcecb4 Author: VinceShieh <[email protected]> Authored: Thu Nov 17 13:37:42 2016 +0000 Committer: Sean Owen <[email protected]> Committed: Thu Nov 17 13:37:53 2016 +0000 ---------------------------------------------------------------------- .../src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++---- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/42777b1b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a0d481b..26505b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 
+33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Common params for KMeans and KMeansModel @@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - - val clusterCenters = if (major.toInt >= 2) { + val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { http://git-wip-us.apache.org/repos/asf/spark/blob/42777b1b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 444006f..1e49352 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Params for [[PCA]] and [[PCAModel]]. 
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - val dataPath = new Path(path, "data").toString - val model = if (major.toInt >= 2) { + val model = if (majorVersion(metadata.sparkVersion) >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
