Repository: spark
Updated Branches:
  refs/heads/master 1fad55968 -> 8e491af52
[SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0

## What changes were proposed in this pull request?
Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0

## How was this patch tested?
local build

Author: Zheng RuiFeng <[email protected]>

Closes #15313 from zhengruifeng/revert_save_load.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e491af5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e491af5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e491af5

Branch: refs/heads/master
Commit: 8e491af52930886cbe0c54e7d67add3796ddb15f
Parents: 1fad559
Author: Zheng RuiFeng <[email protected]>
Authored: Fri Sep 30 08:18:48 2016 -0700
Committer: Yanbo Liang <[email protected]>
Committed: Fri Sep 30 08:18:48 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/NaiveBayes.scala | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/8e491af5/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 0d652aa..6775745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.Dataset
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.DoubleType
 
@@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
-      val pi = data.getAs[Vector](0)
-      val theta = data.getAs[Matrix](1)
+      val data = sparkSession.read.parquet(dataPath)
+      val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
+      val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+        .select("pi", "theta")
+        .head()
       val model = new NaiveBayesModel(metadata.uid, pi, theta)
 
       DefaultParamsReader.getAndSetParams(model, metadata)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
