This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7fbfeae25175 [SPARK-52191][ML][CONNECT] Remove Java deserializer in model local path loader
7fbfeae25175 is described below
commit 7fbfeae25175fbc30dfff65064c669f38a599d46
Author: Weichen Xu <[email protected]>
AuthorDate: Mon May 19 11:00:20 2025 +0800
[SPARK-52191][ML][CONNECT] Remove Java deserializer in model local path loader
### What changes were proposed in this pull request?
Remove the Java deserializer from the model local path loader, replacing it with explicit per-field serializers and deserializers.
### Why are the changes needed?
The Java deserializer is unsafe; removing it prevents a potential security issue.
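For context, the approach taken throughout the diff below replaces Java object (de)serialization with explicit per-field writes and reads over `DataOutputStream`/`DataInputStream`. A minimal standalone sketch of that pattern (the `Data` case class and helper names here are illustrative stand-ins, not the Spark-internal `ReadWriteUtils` API):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Illustrative model payload: an intercept plus a dense coefficient array.
case class Data(intercept: Double, coefficients: Array[Double])

// Write each field explicitly; no arbitrary classes are ever instantiated on load.
def serializeData(data: Data, dos: DataOutputStream): Unit = {
  dos.writeDouble(data.intercept)
  dos.writeInt(data.coefficients.length)
  data.coefficients.foreach(dos.writeDouble)
}

// Read the fields back in the same order; only primitives are decoded.
def deserializeData(dis: DataInputStream): Data = {
  val intercept = dis.readDouble()
  val n = dis.readInt()
  Data(intercept, Array.fill(n)(dis.readDouble()))
}

// Round-trip check.
val bos = new ByteArrayOutputStream()
serializeData(Data(0.5, Array(1.0, 2.0, 3.0)), new DataOutputStream(bos))
val restored = deserializeData(new DataInputStream(new ByteArrayInputStream(bos.toByteArray)))
assert(restored.intercept == 0.5 && restored.coefficients.sameElements(Array(1.0, 2.0, 3.0)))
```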
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50922 from WeichenXu123/sparkml-loader-remove-java-deserializer.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../ml/classification/DecisionTreeClassifier.scala | 4 +-
.../spark/ml/classification/FMClassifier.scala | 21 +-
.../apache/spark/ml/classification/LinearSVC.scala | 19 +-
.../ml/classification/LogisticRegression.scala | 24 +-
.../MultilayerPerceptronClassifier.scala | 17 +-
.../spark/ml/classification/NaiveBayes.scala | 21 +-
.../spark/ml/clustering/GaussianMixture.scala | 25 +-
.../org/apache/spark/ml/clustering/KMeans.scala | 24 +-
.../scala/org/apache/spark/ml/clustering/LDA.scala | 26 +-
.../ml/feature/BucketedRandomProjectionLSH.scala | 17 +-
.../apache/spark/ml/feature/ChiSqSelector.scala | 17 +-
.../apache/spark/ml/feature/CountVectorizer.scala | 17 +-
.../scala/org/apache/spark/ml/feature/IDF.scala | 21 +-
.../org/apache/spark/ml/feature/Imputer.scala | 17 +-
.../org/apache/spark/ml/feature/MaxAbsScaler.scala | 17 +-
.../org/apache/spark/ml/feature/MinHashLSH.scala | 17 +-
.../org/apache/spark/ml/feature/MinMaxScaler.scala | 19 +-
.../apache/spark/ml/feature/OneHotEncoder.scala | 17 +-
.../scala/org/apache/spark/ml/feature/PCA.scala | 19 +-
.../org/apache/spark/ml/feature/RFormula.scala | 45 +++-
.../apache/spark/ml/feature/RFormulaParser.scala | 26 ++
.../org/apache/spark/ml/feature/RobustScaler.scala | 19 +-
.../apache/spark/ml/feature/StandardScaler.scala | 19 +-
.../apache/spark/ml/feature/StringIndexer.scala | 22 +-
.../apache/spark/ml/feature/TargetEncoder.scala | 23 +-
.../ml/feature/UnivariateFeatureSelector.scala | 17 +-
.../ml/feature/VarianceThresholdSelector.scala | 17 +-
.../apache/spark/ml/feature/VectorIndexer.scala | 32 ++-
.../org/apache/spark/ml/feature/Word2Vec.scala | 19 +-
.../org/apache/spark/ml/recommendation/ALS.scala | 27 ++-
.../ml/regression/AFTSurvivalRegression.scala | 21 +-
.../ml/regression/DecisionTreeRegressor.scala | 4 +-
.../apache/spark/ml/regression/FMRegressor.scala | 21 +-
.../regression/GeneralizedLinearRegression.scala | 18 +-
.../spark/ml/regression/IsotonicRegression.scala | 21 +-
.../spark/ml/regression/LinearRegression.scala | 27 ++-
.../org/apache/spark/ml/tree/treeModels.scala | 79 +++++-
.../scala/org/apache/spark/ml/util/ReadWrite.scala | 266 ++++++++++++++++++---
38 files changed, 970 insertions(+), 112 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index e400bf13eb8d..8902d12bdf94 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -302,7 +302,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
- ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts)
+ ReadWriteUtils.saveArray(
+ dataPath, nodeData.toArray, sparkSession, NodeData.serializeData, numDataParts
+ )
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index 222cfbb80c3d..cefa13b2bbe7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -351,6 +353,21 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
factors: Matrix
)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeDouble(data.intercept)
+ serializeVector(data.linear, dos)
+ serializeMatrix(data.factors, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val intercept = dis.readDouble()
+ val linear = deserializeVector(dis)
+ val factors = deserializeMatrix(dis)
+ Data(intercept, linear, factors)
+ }
+
@Since("3.0.0")
override def read: MLReader[FMClassificationModel] = new FMClassificationModelReader
@@ -365,7 +382,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -377,7 +394,7 @@ object FMClassificationModel extends MLReadable[FMClassificationModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new FMClassificationModel(
metadata.uid, data.intercept, data.linear, data.factors
)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index c5d1170318f7..a50346ae88f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
@@ -449,6 +451,19 @@ class LinearSVCModel private[classification] (
object LinearSVCModel extends MLReadable[LinearSVCModel] {
private[ml] case class Data(coefficients: Vector, intercept: Double)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.coefficients, dos)
+ dos.writeDouble(data.intercept)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val coefficients = deserializeVector(dis)
+ val intercept = dis.readDouble()
+ Data(coefficients, intercept)
+ }
+
@Since("2.2.0")
override def read: MLReader[LinearSVCModel] = new LinearSVCReader
@@ -465,7 +480,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -477,7 +492,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
override def load(path: String): LinearSVCModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new LinearSVCModel(metadata.uid, data.coefficients, data.intercept)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index d09cacf3fb5b..f8a74df2508d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.classification
+import java.io.{DataInputStream, DataOutputStream}
import java.util.Locale
import scala.collection.mutable
@@ -1325,6 +1326,25 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
coefficientMatrix: Matrix,
isMultinomial: Boolean)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.numClasses)
+ dos.writeInt(data.numFeatures)
+ serializeVector(data.interceptVector, dos)
+ serializeMatrix(data.coefficientMatrix, dos)
+ dos.writeBoolean(data.isMultinomial)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val numClasses = dis.readInt()
+ val numFeatures = dis.readInt()
+ val interceptVector = deserializeVector(dis)
+ val coefficientMatrix = deserializeMatrix(dis)
+ val isMultinomial = dis.readBoolean()
+ Data(numClasses, numFeatures, interceptVector, coefficientMatrix, isMultinomial)
+ }
+
@Since("1.6.0")
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
@@ -1343,7 +1363,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -1372,7 +1392,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
interceptVector, numClasses, isMultinomial = false)
} else {
// 2.1+
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new LogisticRegressionModel(metadata.uid, data.coefficientMatrix, data.interceptVector,
data.numClasses, data.isMultinomial)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 2359749f8b48..2a5b00b9e5ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -370,6 +372,17 @@ object MultilayerPerceptronClassificationModel
extends MLReadable[MultilayerPerceptronClassificationModel] {
private[ml] case class Data(weights: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.weights, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val weights = deserializeVector(dis)
+ Data(weights)
+ }
+
@Since("2.0.0")
override def read: MLReader[MultilayerPerceptronClassificationModel] =
new MultilayerPerceptronClassificationModelReader
@@ -388,7 +401,7 @@ object MultilayerPerceptronClassificationModel
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -411,7 +424,7 @@ object MultilayerPerceptronClassificationModel
val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights)
model.set("layers", layers)
} else {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new MultilayerPerceptronClassificationModel(metadata.uid, data.weights)
}
metadata.getAndSetParams(model)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index ce26478c625c..aaa1ef4a4764 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -600,6 +602,21 @@ class NaiveBayesModel private[ml] (
object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
private[ml] case class Data(pi: Vector, theta: Matrix, sigma: Matrix)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.pi, dos)
+ serializeMatrix(data.theta, dos)
+ serializeMatrix(data.sigma, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val pi = deserializeVector(dis)
+ val theta = deserializeMatrix(dis)
+ val sigma = deserializeMatrix(dis)
+ Data(pi, theta, sigma)
+ }
+
@Since("1.6.0")
override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
@@ -623,7 +640,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
}
val data = Data(instance.pi, instance.theta, instance.sigma)
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -647,7 +664,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
.head()
new NaiveBayesModel(metadata.uid, pi, theta, Matrices.zeros(0, 0))
} else {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new NaiveBayesModel(metadata.uid, data.pi, data.theta, data.sigma)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index ee0b19f8129d..5935a2a18f50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -229,6 +231,25 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
sigmas: Array[OldMatrix]
)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeDoubleArray(data.weights, dos)
+ serializeGenericArray[OldVector](data.mus, dos, (v, dos) => serializeVector(v.asML, dos))
+ serializeGenericArray[OldMatrix](data.sigmas, dos, (v, dos) => serializeMatrix(v.asML, dos))
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val weights = deserializeDoubleArray(dis)
+ val mus = deserializeGenericArray[OldVector](
+ dis, dis => OldVectors.fromML(deserializeVector(dis))
+ )
+ val sigmas = deserializeGenericArray[OldMatrix](
+ dis, dis => OldMatrices.fromML(deserializeMatrix(dis))
+ )
+ Data(weights, mus, sigmas)
+ }
+
@Since("2.0.0")
override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
@@ -249,7 +270,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov))
val data = Data(weights, mus, sigmas)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -264,7 +285,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
val dataPath = new Path(path, "data").toString
val data = if (ReadWriteUtils.localSavingModeState.get()) {
- ReadWriteUtils.loadObjectFromLocal(dataPath)
+ ReadWriteUtils.loadObjectFromLocal(dataPath, deserializeData)
} else {
val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
Data(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e87dc9eb040b..53e4eb517771 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import org.apache.hadoop.fs.Path
@@ -215,6 +217,20 @@ class KMeansModel private[ml] (
/** Helper class for storing model data */
private[ml] case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+private[ml] object ClusterData {
+ private[ml] def serializeData(data: ClusterData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.clusterIdx)
+ serializeVector(data.clusterCenter, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): ClusterData = {
+ import ReadWriteUtils._
+ val clusterIdx = dis.readInt()
+ val clusterCenter = deserializeVector(dis)
+ ClusterData(clusterIdx, clusterCenter)
+ }
+}
/** A writer for KMeans that handles the "internal" (or default) format */
private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
@@ -233,7 +249,9 @@ private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegi
ClusterData(idx, center)
}
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveArray[ClusterData](dataPath, data, sparkSession)
+ ReadWriteUtils.saveArray[ClusterData](
+ dataPath, data, sparkSession, ClusterData.serializeData
+ )
}
}
@@ -281,7 +299,9 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
- val data = ReadWriteUtils.loadArray[ClusterData](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadArray[ClusterData](
+ dataPath, sparkSession, ClusterData.deserializeData
+ )
data.sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 4db66ca9325c..67c9a8f58dd2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.clustering
+import java.io.{DataInputStream, DataOutputStream}
import java.util.Locale
import breeze.linalg.normalize
@@ -650,6 +651,25 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
topicConcentration: Double,
gammaShape: Double)
+ private[ml] def serializeData(data: LocalModelData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.vocabSize)
+ serializeMatrix(data.topicsMatrix, dos)
+ serializeVector(data.docConcentration, dos)
+ dos.writeDouble(data.topicConcentration)
+ dos.writeDouble(data.gammaShape)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): LocalModelData = {
+ import ReadWriteUtils._
+ val vocabSize = dis.readInt()
+ val topicsMatrix = deserializeMatrix(dis)
+ val docConcentration = deserializeVector(dis)
+ val topicConcentration = dis.readDouble()
+ val gammaShape = dis.readDouble()
+ LocalModelData(vocabSize, topicsMatrix, docConcentration, topicConcentration, gammaShape)
+ }
+
private[LocalLDAModel]
class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
@@ -661,7 +681,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
oldModel.topicConcentration, oldModel.gammaShape
)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[LocalModelData](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[LocalModelData](dataPath, data, sparkSession, serializeData)
}
}
@@ -673,7 +693,9 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[LocalModelData](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[LocalModelData](
+ dataPath, sparkSession, deserializeData
+ )
val oldModel = new OldLocalLDAModel(
data.topicsMatrix,
data.docConcentration,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index ef7ff1be69a6..e4c5b91133f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.util.Random
import org.apache.hadoop.fs.Path
@@ -215,6 +217,17 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
private[ml] case class Data(randUnitVectors: Matrix)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeMatrix(data.randUnitVectors, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val randUnitVectors = deserializeMatrix(dis)
+ Data(randUnitVectors)
+ }
+
@Since("2.1.0")
override def read: MLReader[BucketedRandomProjectionLSHModel] = {
new BucketedRandomProjectionLSHModelReader
@@ -230,7 +243,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randMatrix)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -244,7 +257,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new BucketedRandomProjectionLSHModel(metadata.uid, data.randUnitVectors)
metadata.getAndSetParams(model)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 545bac693a93..abe03017538f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -172,13 +174,24 @@ final class ChiSqSelectorModel private[ml] (
object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
private[ml] case class Data(selectedFeatures: Seq[Int])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeIntArray(data.selectedFeatures.toArray, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val selectedFeatures = deserializeIntArray(dis).toSeq
+ Data(selectedFeatures)
+ }
+
class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -189,7 +202,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
override def load(path: String): ChiSqSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new ChiSqSelectorModel(metadata.uid, data.selectedFeatures.toArray)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 92b2a09f85b5..060e445e0254 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -370,6 +372,17 @@ class CountVectorizerModel(
object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
private[ml] case class Data(vocabulary: Seq[String])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeStringArray(data.vocabulary.toArray, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val vocabulary = deserializeStringArray(dis).toSeq
+ Data(vocabulary)
+ }
+
private[CountVectorizerModel]
class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter {
@@ -377,7 +390,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.vocabulary.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -388,7 +401,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
override def load(path: String): CountVectorizerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new CountVectorizerModel(metadata.uid, data.vocabulary.toArray)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 11ef88ac1fb8..12d957ea360f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -197,13 +199,28 @@ class IDFModel private[ml] (
object IDFModel extends MLReadable[IDFModel] {
private[ml] case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.idf, dos)
+ serializeLongArray(data.docFreq, dos)
+ dos.writeLong(data.numDocs)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val idf = deserializeVector(dis)
+ val docFreq = deserializeLongArray(dis)
+ val numDocs = dis.readLong()
+ Data(idf, docFreq, numDocs)
+ }
+
private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.idf, instance.docFreq, instance.numDocs)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -217,7 +234,7 @@ object IDFModel extends MLReadable[IDFModel] {
val data = sparkSession.read.parquet(dataPath)
val model = if (majorVersion(metadata.sparkVersion) >= 3) {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new IDFModel(
metadata.uid,
new feature.IDFModel(OldVectors.fromML(data.idf), data.docFreq, data.numDocs)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index a4109a8ad9e1..a5b1a3cfde82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -311,11 +311,16 @@ object ImputerModel extends MLReadable[ImputerModel] {
private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
+ import org.apache.spark.ml.util.ReadWriteUtils._
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
if (ReadWriteUtils.localSavingModeState.get()) {
ReadWriteUtils.saveObjectToLocal[(Array[String], Array[Double])](
- dataPath, (instance.columnNames, instance.surrogates)
+ dataPath, (instance.columnNames, instance.surrogates),
+ (v, dos) => {
+ serializeStringArray(v._1, dos)
+ serializeDoubleArray(v._2, dos)
+ }
)
} else {
instance.surrogateDF.repartition(1).write.parquet(dataPath)
@@ -328,10 +333,18 @@ object ImputerModel extends MLReadable[ImputerModel] {
private val className = classOf[ImputerModel].getName
override def load(path: String): ImputerModel = {
+ import org.apache.spark.ml.util.ReadWriteUtils._
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val model = if (ReadWriteUtils.localSavingModeState.get()) {
- val data = ReadWriteUtils.loadObjectFromLocal[(Array[String], Array[Double])](dataPath)
+ val data = ReadWriteUtils.loadObjectFromLocal[(Array[String], Array[Double])](
+ dataPath,
+ dis => {
+ val v1 = deserializeStringArray(dis)
+ val v2 = deserializeDoubleArray(dis)
+ (v1, v2)
+ }
+ )
new ImputerModel(metadata.uid, data._1, data._2)
} else {
val row = sparkSession.read.parquet(dataPath).head()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index db60ee879afb..6d2e39d65302 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -160,6 +162,17 @@ class MaxAbsScalerModel private[ml] (
object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
private[ml] case class Data(maxAbs: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.maxAbs, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val maxAbs = deserializeVector(dis)
+ Data(maxAbs)
+ }
+
private[MaxAbsScalerModel]
class MaxAbsScalerModelWriter(instance: MaxAbsScalerModel) extends MLWriter {
@@ -167,7 +180,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = new Data(instance.maxAbs)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -178,7 +191,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
override def load(path: String): MaxAbsScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new MaxAbsScalerModel(metadata.uid, data.maxAbs)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 8faadcc7db49..90aaaea7e36e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.util.Random
import org.apache.hadoop.fs.Path
@@ -212,6 +214,17 @@ object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private[ml] case class Data(randCoefficients: Array[Int])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeIntArray(data.randCoefficients, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val randCoefficients = deserializeIntArray(dis)
+ Data(randCoefficients)
+ }
+
@Since("2.1.0")
override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader
@@ -225,7 +238,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -238,7 +251,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new MinHashLSHModel(
metadata.uid,
data.randCoefficients.grouped(2).map(tuple => (tuple(0), tuple(1))).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index e02a25bf7b8d..36942bd08a3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -244,6 +246,19 @@ class MinMaxScalerModel private[ml] (
object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
private[ml] case class Data(originalMin: Vector, originalMax: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.originalMin, dos)
+ serializeVector(data.originalMax, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val originalMin = deserializeVector(dis)
+ val originalMax = deserializeVector(dis)
+ Data(originalMin, originalMax)
+ }
+
private[MinMaxScalerModel]
class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter {
@@ -251,7 +266,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.originalMin, instance.originalMax)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -262,7 +277,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
override def load(path: String): MinMaxScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new MinMaxScalerModel(metadata.uid, data.originalMin, data.originalMax)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 0a9b6c46feae..b38dfcb4ed5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -403,6 +405,17 @@ class OneHotEncoderModel private[ml] (
object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
private[ml] case class Data(categorySizes: Array[Int])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeIntArray(data.categorySizes, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val categorySizes = deserializeIntArray(dis)
+ Data(categorySizes)
+ }
+
private[OneHotEncoderModel]
class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter {
@@ -410,7 +423,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.categorySizes)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -421,7 +434,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
override def load(path: String): OneHotEncoderModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new OneHotEncoderModel(metadata.uid, data.categorySizes)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index e5fd96671b20..d09d471788cf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -183,13 +185,26 @@ class PCAModel private[ml] (
object PCAModel extends MLReadable[PCAModel] {
private[ml] case class Data(pc: Matrix, explainedVariance: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeMatrix(data.pc, dos)
+ serializeVector(data.explainedVariance, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val pc = deserializeMatrix(dis)
+ val explainedVariance = deserializeVector(dis)
+ Data(pc, explainedVariance)
+ }
+
private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.pc, instance.explainedVariance)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -211,7 +226,7 @@ object PCAModel extends MLReadable[PCAModel] {
val dataPath = new Path(path, "data").toString
val model = if (majorVersion(metadata.sparkVersion) >= 2) {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
new PCAModel(metadata.uid, data.pc.toDense, data.explainedVariance.toDense)
} else {
// pc field is the old matrix format in Spark <= 1.6
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index b482d08b2fac..d4d9aade5b25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -441,7 +443,10 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: resolvedFormula
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[ResolvedRFormula](dataPath, instance.resolvedFormula, sparkSession)
+ ReadWriteUtils.saveObject[ResolvedRFormula](
+ dataPath, instance.resolvedFormula, sparkSession,
+ ResolvedRFormula.serializeData
+ )
// Save pipeline model
val pmPath = new Path(path, "pipelineModel").toString
instance.pipelineModel.save(pmPath)
@@ -457,7 +462,9 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val resolvedRFormula = ReadWriteUtils.loadObject[ResolvedRFormula](dataPath, sparkSession)
+ val resolvedRFormula = ReadWriteUtils.loadObject[ResolvedRFormula](
+ dataPath, sparkSession, ResolvedRFormula.deserializeData
+ )
val pmPath = new Path(path, "pipelineModel").toString
val pipelineModel = PipelineModel.load(pmPath)
@@ -498,6 +505,17 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
private object ColumnPruner extends MLReadable[ColumnPruner] {
private[ml] case class Data(columnsToPrune: Seq[String])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeStringArray(data.columnsToPrune.toArray, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val columnsToPrune = deserializeStringArray(dis).toSeq
+ Data(columnsToPrune)
+ }
+
override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
override def load(path: String): ColumnPruner = super.load(path)
@@ -511,7 +529,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
// Save model data: columnsToPrune
val data = Data(instance.columnsToPrune.toSeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -524,7 +542,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val pruner = new ColumnPruner(metadata.uid, data.columnsToPrune.toSet)
metadata.getAndSetParams(pruner)
@@ -590,6 +608,21 @@ private class VectorAttributeRewriter(
private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
private[ml] case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeUTF(data.vectorCol)
+ val kvSer = (s: String, dos: DataOutputStream) => dos.writeUTF(s)
+ serializeMap(data.prefixesToRewrite, dos, kvSer, kvSer)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val vectorCol = dis.readUTF()
+ val kvDeser = (dis: DataInputStream) => dis.readUTF()
+ val prefixesToRewrite = deserializeMap(dis, kvDeser, kvDeser)
+ Data(vectorCol, prefixesToRewrite)
+ }
+
override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
override def load(path: String): VectorAttributeRewriter = super.load(path)
@@ -604,7 +637,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
// Save model data: vectorCol, prefixesToRewrite
val data = Data(instance.vectorCol, instance.prefixesToRewrite)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -617,7 +650,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val rewriter = new VectorAttributeRewriter(
metadata.uid, data.vectorCol, data.prefixesToRewrite
)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 88e63b766ca6..0d064d920cdd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import scala.util.parsing.combinator.RegexParsers
@@ -152,6 +154,30 @@ private[ml] case class ResolvedRFormula(
}
}
+private[ml] object ResolvedRFormula {
+ private[ml] def serializeData(data: ResolvedRFormula, dos: DataOutputStream): Unit = {
+ import org.apache.spark.ml.util.ReadWriteUtils._
+
+ dos.writeUTF(data.label)
+ serializeGenericArray[Seq[String]](
+ data.terms.toArray, dos,
+ (strSeq, dos) => serializeStringArray(strSeq.toArray, dos)
+ )
+ dos.writeBoolean(data.hasIntercept)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): ResolvedRFormula = {
+ import org.apache.spark.ml.util.ReadWriteUtils._
+
+ val label = dis.readUTF()
+ val terms = deserializeGenericArray[Seq[String]](
+ dis, dis => deserializeStringArray(dis).toSeq
+ ).toSeq
+ val hasIntercept = dis.readBoolean()
+ ResolvedRFormula(label, terms, hasIntercept)
+ }
+}
+
/**
* R formula terms. See the R formula docs here for more information:
* http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 246e553b3add..098976495085 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -281,6 +283,19 @@ class RobustScalerModel private[ml] (
object RobustScalerModel extends MLReadable[RobustScalerModel] {
private[ml] case class Data(range: Vector, median: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.range, dos)
+ serializeVector(data.median, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val range = deserializeVector(dis)
+ val median = deserializeVector(dis)
+ Data(range, median)
+ }
+
private[RobustScalerModel]
class RobustScalerModelWriter(instance: RobustScalerModel) extends MLWriter {
@@ -288,7 +303,7 @@ object RobustScalerModel extends MLReadable[RobustScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.range, instance.median)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -299,7 +314,7 @@ object RobustScalerModel extends MLReadable[RobustScalerModel] {
override def load(path: String): RobustScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new RobustScalerModel(metadata.uid, data.range, data.median)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 87e2557eb484..9cd2fba004a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -202,6 +204,19 @@ class StandardScalerModel private[ml] (
object StandardScalerModel extends MLReadable[StandardScalerModel] {
private[ml] case class Data(std: Vector, mean: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.std, dos)
+ serializeVector(data.mean, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val std = deserializeVector(dis)
+ val mean = deserializeVector(dis)
+ Data(std, mean)
+ }
+
private[StandardScalerModel]
class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
@@ -209,7 +224,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -220,7 +235,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new StandardScalerModel(metadata.uid, data.std, data.mean)
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 243333f9f0de..db3749558f47 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
@@ -471,6 +473,22 @@ class StringIndexerModel (
object StringIndexerModel extends MLReadable[StringIndexerModel] {
private[ml] case class Data(labelsArray: Seq[Seq[String]])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeGenericArray[Seq[String]](
+ data.labelsArray.toArray, dos,
+ (strSeq, dos) => serializeStringArray(strSeq.toArray, dos)
+ )
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val labelsArray = deserializeGenericArray[Seq[String]](
+ dis, dis => deserializeStringArray(dis).toSeq
+ ).toSeq
+ Data(labelsArray)
+ }
+
private[StringIndexerModel]
class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter {
@@ -478,7 +496,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.labelsArray.map(_.toImmutableArraySeq).toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -501,7 +519,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
val labels = data.getAs[Seq[String]](0).toArray
Array(labels)
} else {
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
data.labelsArray.map(_.toArray).toArray
}
val model = new StringIndexerModel(metadata.uid, labelsArray)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index aa11a139b022..736aa4fc5356 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -406,6 +408,23 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
index: Int, categories: Array[Double],
counts: Array[Double], stats: Array[Double])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.index)
+ serializeDoubleArray(data.categories, dos)
+ serializeDoubleArray(data.counts, dos)
+ serializeDoubleArray(data.stats, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val index = dis.readInt()
+ val categories = deserializeDoubleArray(dis)
+ val counts = deserializeDoubleArray(dis)
+ val stats = deserializeDoubleArray(dis)
+ Data(index, categories, counts, stats)
+ }
+
private[TargetEncoderModel]
class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {
@@ -417,7 +436,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
Data(index, _categories.toArray, _counts.toArray, _stats.toArray)
}.toSeq
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveArray[Data](dataPath, datum.toArray, sparkSession)
+ ReadWriteUtils.saveArray[Data](dataPath, datum.toArray, sparkSession, serializeData)
}
}
@@ -429,7 +448,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
- val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession)
+ val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession, deserializeData)
val stats = datum.map { data =>
(data.index, data.categories.zip(data.counts.zip(data.stats)).toMap)
}.sortBy(_._1).map(_._2)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index c394f121a215..ca7e9eb826fc 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable.ArrayBuilder
import org.apache.hadoop.fs.Path
@@ -340,6 +342,17 @@ class UnivariateFeatureSelectorModel private[ml](
object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelectorModel] {
private[ml] case class Data(selectedFeatures: Seq[Int])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeIntArray(data.selectedFeatures.toArray, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val selectedFeatures = deserializeIntArray(dis).toSeq
+ Data(selectedFeatures)
+ }
+
@Since("3.1.1")
override def read: MLReader[UnivariateFeatureSelectorModel] =
new UnivariateFeatureSelectorModelReader
@@ -354,7 +367,7 @@ object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelect
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -367,7 +380,7 @@ object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelect
override def load(path: String): UnivariateFeatureSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new UnivariateFeatureSelectorModel(metadata.uid,
data.selectedFeatures.toArray)
metadata.getAndSetParams(model)
model
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
index 0549434e2429..b794e6df3944 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -178,6 +180,17 @@ class VarianceThresholdSelectorModel private[ml](
object VarianceThresholdSelectorModel extends
MLReadable[VarianceThresholdSelectorModel] {
private[ml] case class Data(selectedFeatures: Seq[Int])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeIntArray(data.selectedFeatures.toArray, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val selectedFeatures = deserializeIntArray(dis).toSeq
+ Data(selectedFeatures)
+ }
+
@Since("3.1.0")
override def read: MLReader[VarianceThresholdSelectorModel] =
new VarianceThresholdSelectorModelReader
@@ -192,7 +205,7 @@ object VarianceThresholdSelectorModel extends
MLReadable[VarianceThresholdSelect
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -205,7 +218,7 @@ object VarianceThresholdSelectorModel extends
MLReadable[VarianceThresholdSelect
override def load(path: String): VarianceThresholdSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new VarianceThresholdSelectorModel(metadata.uid,
data.selectedFeatures.toArray)
metadata.getAndSetParams(model)
model
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 8d98153a8a14..290e5fe05fca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
import java.lang.{Double => JDouble, Integer => JInt}
import java.util.{Map => JMap, NoSuchElementException}
@@ -530,6 +531,33 @@ class VectorIndexerModel private[ml] (
object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
private[ml] case class Data(numFeatures: Int, categoryMaps: Map[Int,
Map[Double, Int]])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.numFeatures)
+ serializeMap[Int, Map[Double, Int]](
+ data.categoryMaps, dos,
+ (k, dos) => dos.writeInt(k),
+ (v, dos) => {
+ serializeMap[Double, Int](
+ v, dos,
+ (kk, dos) => dos.writeDouble(kk),
+ (vv, dos) => dos.writeInt(vv)
+ )
+ }
+ )
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val numFeatures = dis.readInt()
+ val categoryMaps = deserializeMap[Int, Map[Double, Int]](
+ dis,
+ dis => dis.readInt(),
+ dis => deserializeMap[Double, Int](dis, dis => dis.readDouble(), dis => dis.readInt())
+ )
+ Data(numFeatures, categoryMaps)
+ }
+
private[VectorIndexerModel]
class VectorIndexerModelWriter(instance: VectorIndexerModel) extends
MLWriter {
@@ -537,7 +565,7 @@ object VectorIndexerModel extends
MLReadable[VectorIndexerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.numFeatures, instance.categoryMaps)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -548,7 +576,7 @@ object VectorIndexerModel extends
MLReadable[VectorIndexerModel] {
override def load(path: String): VectorIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new VectorIndexerModel(metadata.uid, data.numFeatures,
data.categoryMaps)
metadata.getAndSetParams(model)
model
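Note: the VectorIndexerModel helpers above nest one serializeMap call inside another. A self-contained sketch of the resulting layout (outer size, then per feature its index and an inner size-prefixed map of value-to-category entries) is shown below; NestedMapRoundTrip and its methods are illustrative names, not Spark APIs.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object NestedMapRoundTrip {
  // Mirrors the nested serializeMap calls: outer size, then key/value pairs, where each
  // value is itself a size-prefixed sequence of (Double, Int) entries.
  def writeCategoryMaps(maps: Map[Int, Map[Double, Int]], dos: DataOutputStream): Unit = {
    dos.writeInt(maps.size)
    maps.foreach { case (feature, categories) =>
      dos.writeInt(feature)
      dos.writeInt(categories.size)
      categories.foreach { case (value, index) =>
        dos.writeDouble(value)
        dos.writeInt(index)
      }
    }
  }

  def readCategoryMaps(dis: DataInputStream): Map[Int, Map[Double, Int]] = {
    Seq.fill(dis.readInt()) {
      val feature = dis.readInt()
      val categories = Seq.fill(dis.readInt())((dis.readDouble(), dis.readInt())).toMap
      feature -> categories
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 2 -> Map(-1.0 -> 0))
    val out = new ByteArrayOutputStream()
    writeCategoryMaps(categoryMaps, new DataOutputStream(out))
    val dis = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    assert(readCategoryMaps(dis) == categoryMaps)
  }
}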
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 50e25ccf092c..bfaf0acbde4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -351,6 +353,19 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
private[Word2VecModel] case class Data(word: String, vector: Array[Float])
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeUTF(data.word)
+ serializeFloatArray(data.vector, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val word = dis.readUTF()
+ val vector = deserializeFloatArray(dis)
+ Data(word, vector)
+ }
+
private[Word2VecModel]
class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
@@ -364,7 +379,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
bufferSizeInBytes, instance.wordVectors.wordIndex.size,
instance.getVectorSize)
val datum = wordVectors.toArray.map { case (word, vector) => Data(word,
vector) }
- ReadWriteUtils.saveArray[Data](dataPath, datum, sparkSession, numPartitions)
+ ReadWriteUtils.saveArray[Data](dataPath, datum, sparkSession, serializeData, numPartitions)
}
}
@@ -416,7 +431,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
val wordVectors = data.getAs[Seq[Float]](1).toArray
new feature.Word2VecModel(wordIndex, wordVectors)
} else {
- val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession)
+ val datum = ReadWriteUtils.loadArray[Data](dataPath, sparkSession, deserializeData)
val wordVectorsMap = datum.map(wordVector => (wordVector.word,
wordVector.vector)).toMap
new feature.Word2VecModel(wordVectorsMap)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 18a6cb3b5257..276c7630d2d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.recommendation
import java.{util => ju}
-import java.io.IOException
+import java.io.{DataInputStream, DataOutputStream, IOException}
import java.util.Locale
import scala.collection.mutable
@@ -552,6 +552,19 @@ private[ml] case class FeatureData(id: Int, features:
Array[Float])
@Since("1.6.0")
object ALSModel extends MLReadable[ALSModel] {
+ private[ml] def serializeData(data: FeatureData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.id)
+ serializeFloatArray(data.features, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): FeatureData = {
+ import ReadWriteUtils._
+ val id = dis.readInt()
+ val features = deserializeFloatArray(dis)
+ FeatureData(id, features)
+ }
+
private val NaN = "nan"
private val Drop = "drop"
private[recommendation] final val supportedColdStartStrategies = Array(NaN,
Drop)
@@ -579,9 +592,9 @@ object ALSModel extends MLReadable[ALSModel] {
import sparkSession.implicits._
val userFactorsData = instance.userFactors.as[FeatureData].collect()
- ReadWriteUtils.saveArray(userPath, userFactorsData, sparkSession)
+ ReadWriteUtils.saveArray(userPath, userFactorsData, sparkSession, serializeData)
val itemFactorsData = instance.itemFactors.as[FeatureData].collect()
- ReadWriteUtils.saveArray(itemPath, itemFactorsData, sparkSession)
+ ReadWriteUtils.saveArray(itemPath, itemFactorsData, sparkSession, serializeData)
} else {
instance.userFactors.write.format("parquet").save(userPath)
instance.itemFactors.write.format("parquet").save(itemPath)
@@ -603,9 +616,13 @@ object ALSModel extends MLReadable[ALSModel] {
val (userFactors, itemFactors) = if
(ReadWriteUtils.localSavingModeState.get()) {
import org.apache.spark.util.ArrayImplicits._
- val userFactorsData = ReadWriteUtils.loadArray[FeatureData](userPath, sparkSession)
+ val userFactorsData = ReadWriteUtils.loadArray[FeatureData](
+ userPath, sparkSession, deserializeData
+ )
val userFactors =
sparkSession.createDataFrame(userFactorsData.toImmutableArraySeq)
- val itemFactorsData = ReadWriteUtils.loadArray[FeatureData](itemPath, sparkSession)
+ val itemFactorsData = ReadWriteUtils.loadArray[FeatureData](
+ itemPath, sparkSession, deserializeData
+ )
val itemFactors =
sparkSession.createDataFrame(itemFactorsData.toImmutableArraySeq)
(userFactors, itemFactors)
} else {
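Note: with this change, local-mode saveArray writes a length-prefixed sequence of fixed-layout records instead of one Java-serialized blob. A self-contained sketch of that pattern for an id-plus-float-features record, mirroring the ALS FeatureData helpers above, follows; FactorRecord and the method names are illustrative, not Spark APIs.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object FactorArrayRoundTrip {
  case class FactorRecord(id: Int, features: Array[Float]) // stand-in for FeatureData

  // Length-prefixed array of fixed-layout records, as the local saving path now produces.
  def writeFactors(records: Array[FactorRecord], dos: DataOutputStream): Unit = {
    dos.writeInt(records.length)
    records.foreach { r =>
      dos.writeInt(r.id)
      dos.writeInt(r.features.length)
      r.features.foreach(f => dos.writeFloat(f))
    }
  }

  def readFactors(dis: DataInputStream): Array[FactorRecord] = {
    Array.fill(dis.readInt()) {
      val id = dis.readInt()
      val features = Array.fill(dis.readInt())(dis.readFloat())
      FactorRecord(id, features)
    }
  }

  def main(args: Array[String]): Unit = {
    val factors = Array(FactorRecord(1, Array(0.1f, 0.2f)), FactorRecord(2, Array(0.3f)))
    val out = new ByteArrayOutputStream()
    writeFactors(factors, new DataOutputStream(out))
    val back = readFactors(new DataInputStream(new ByteArrayInputStream(out.toByteArray)))
    assert(back.map(_.id).sameElements(factors.map(_.id)))
    assert(back.zip(factors).forall { case (a, b) => a.features.sameElements(b.features) })
  }
}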
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index de9d016edea6..3aee34a148ad 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.regression
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
@@ -499,6 +501,21 @@ class AFTSurvivalRegressionModel private[ml] (
object AFTSurvivalRegressionModel extends
MLReadable[AFTSurvivalRegressionModel] {
private[ml] case class Data(coefficients: Vector, intercept: Double, scale:
Double)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeVector(data.coefficients, dos)
+ dos.writeDouble(data.intercept)
+ dos.writeDouble(data.scale)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val coefficients = deserializeVector(dis)
+ val intercept = dis.readDouble()
+ val scale = dis.readDouble()
+ Data(coefficients, intercept, scale)
+ }
+
@Since("1.6.0")
override def read: MLReader[AFTSurvivalRegressionModel] = new
AFTSurvivalRegressionModelReader
@@ -516,7 +533,7 @@ object AFTSurvivalRegressionModel extends
MLReadable[AFTSurvivalRegressionModel]
// Save model data: coefficients, intercept, scale
val data = Data(instance.coefficients, instance.intercept,
instance.scale)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -529,7 +546,7 @@ object AFTSurvivalRegressionModel extends
MLReadable[AFTSurvivalRegressionModel]
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new AFTSurvivalRegressionModel(
metadata.uid, data.coefficients, data.intercept, data.scale
)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index f53d26882d3f..2c40a2f353b7 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -311,7 +311,9 @@ object DecisionTreeRegressionModel extends
MLReadable[DecisionTreeRegressionMode
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
- ReadWriteUtils.saveArray(dataPath, nodeData.toArray, sparkSession, numDataParts)
+ ReadWriteUtils.saveArray(
+ dataPath, nodeData.toArray, sparkSession, NodeData.serializeData, numDataParts
+ )
}
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index 0bb89354c47a..1b624895c7f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.regression
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.util.Random
import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV}
@@ -515,6 +517,21 @@ object FMRegressionModel extends
MLReadable[FMRegressionModel] {
linear: Vector,
factors: Matrix)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeDouble(data.intercept)
+ serializeVector(data.linear, dos)
+ serializeMatrix(data.factors, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val intercept = dis.readDouble()
+ val linear = deserializeVector(dis)
+ val factors = deserializeMatrix(dis)
+ Data(intercept, linear, factors)
+ }
+
@Since("3.0.0")
override def read: MLReader[FMRegressionModel] = new FMRegressionModelReader
@@ -529,7 +546,7 @@ object FMRegressionModel extends
MLReadable[FMRegressionModel] {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -540,7 +557,7 @@ object FMRegressionModel extends
MLReadable[FMRegressionModel] {
override def load(path: String): FMRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new FMRegressionModel(metadata.uid, data.intercept,
data.linear, data.factors)
metadata.getAndSetParams(model)
model
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 777b70e7d021..14467c761b21 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.regression
+import java.io.{DataInputStream, DataOutputStream}
import java.util.Locale
import breeze.stats.{distributions => dist}
@@ -1145,6 +1146,19 @@ class GeneralizedLinearRegressionModel private[ml] (
object GeneralizedLinearRegressionModel extends
MLReadable[GeneralizedLinearRegressionModel] {
private[ml] case class Data(intercept: Double, coefficients: Vector)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeDouble(data.intercept)
+ serializeVector(data.coefficients, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val intercept = dis.readDouble()
+ val coefficients = deserializeVector(dis)
+ Data(intercept, coefficients)
+ }
+
@Since("2.0.0")
override def read: MLReader[GeneralizedLinearRegressionModel] =
new GeneralizedLinearRegressionModelReader
@@ -1163,7 +1177,7 @@ object GeneralizedLinearRegressionModel extends
MLReadable[GeneralizedLinearRegr
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -1177,7 +1191,7 @@ object GeneralizedLinearRegressionModel extends
MLReadable[GeneralizedLinearRegr
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new GeneralizedLinearRegressionModel(
metadata.uid, data.coefficients, data.intercept
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 131fbcd4d167..6eddcb416d8a 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.regression
+import java.io.{DataInputStream, DataOutputStream}
+
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
@@ -290,6 +292,21 @@ object IsotonicRegressionModel extends
MLReadable[IsotonicRegressionModel] {
predictions: Array[Double],
isotonic: Boolean)
+ private[ml] def serializeData(data: Data, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ serializeDoubleArray(data.boundaries, dos)
+ serializeDoubleArray(data.predictions, dos)
+ dos.writeBoolean(data.isotonic)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): Data = {
+ import ReadWriteUtils._
+ val boundaries = deserializeDoubleArray(dis)
+ val predictions = deserializeDoubleArray(dis)
+ val isotonic = dis.readBoolean()
+ Data(boundaries, predictions, isotonic)
+ }
+
@Since("1.6.0")
override def read: MLReader[IsotonicRegressionModel] = new
IsotonicRegressionModelReader
@@ -308,7 +325,7 @@ object IsotonicRegressionModel extends
MLReadable[IsotonicRegressionModel] {
val data = Data(
instance.oldModel.boundaries, instance.oldModel.predictions,
instance.oldModel.isotonic)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[Data](dataPath, data, sparkSession, serializeData)
}
}
@@ -321,7 +338,7 @@ object IsotonicRegressionModel extends
MLReadable[IsotonicRegressionModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
- val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
val model = new IsotonicRegressionModel(
metadata.uid,
new MLlibIsotonicRegressionModel(data.boundaries, data.predictions,
data.isotonic)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 847115eb02b1..b06140e48338 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.regression
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
@@ -802,6 +804,23 @@ class LinearRegressionModel private[ml] (
private[ml] case class LinearModelData(intercept: Double, coefficients:
Vector, scale: Double)
+private[ml] object LinearModelData {
+ private[ml] def serializeData(data: LinearModelData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeDouble(data.intercept)
+ serializeVector(data.coefficients, dos)
+ dos.writeDouble(data.scale)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): LinearModelData = {
+ import ReadWriteUtils._
+ val intercept = dis.readDouble()
+ val coefficients = deserializeVector(dis)
+ val scale = dis.readDouble()
+ LinearModelData(intercept, coefficients, scale)
+ }
+}
+
/** A writer for LinearRegression that handles the "internal" (or default)
format */
private class InternalLinearRegressionModelWriter
extends MLWriterFormat with MLFormatRegister {
@@ -818,7 +837,9 @@ private class InternalLinearRegressionModelWriter
// Save model data: intercept, coefficients, scale
val data = LinearModelData(instance.intercept, instance.coefficients,
instance.scale)
val dataPath = new Path(path, "data").toString
- ReadWriteUtils.saveObject[LinearModelData](dataPath, data, sparkSession)
+ ReadWriteUtils.saveObject[LinearModelData](
+ dataPath, data, sparkSession, LinearModelData.serializeData
+ )
}
}
@@ -871,7 +892,9 @@ object LinearRegressionModel extends
MLReadable[LinearRegressionModel] {
.head()
new LinearRegressionModel(metadata.uid, coefficients, intercept)
} else {
- val data = ReadWriteUtils.loadObject[LinearModelData](dataPath, sparkSession)
+ val data = ReadWriteUtils.loadObject[LinearModelData](
+ dataPath, sparkSession, LinearModelData.deserializeData
+ )
new LinearRegressionModel(
metadata.uid, data.coefficients, data.intercept, data.scale
)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ea9a0de16563..b20b2e943dee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
+import java.io.{DataInputStream, DataOutputStream}
+
import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
@@ -348,6 +350,21 @@ private[ml] object DecisionTreeModelReadWrite {
case s: ContinuousSplit =>
SplitData(s.featureIndex, Array(s.threshold), -1)
}
+
+ private[ml] def serializeData(data: SplitData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.featureIndex)
+ serializeDoubleArray(data.leftCategoriesOrThreshold, dos)
+ dos.writeInt(data.numCategories)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): SplitData = {
+ import ReadWriteUtils._
+ val featureIndex = dis.readInt()
+ val leftCategoriesOrThreshold = deserializeDoubleArray(dis)
+ val numCategories = dis.readInt()
+ SplitData(featureIndex, leftCategoriesOrThreshold, numCategories)
+ }
}
/**
@@ -403,6 +420,36 @@ private[ml] object DecisionTreeModelReadWrite {
// 7,280,000 nodes is about 128MB
(numNodes / 7280000.0).ceil.toInt
}
+
+ private[ml] def serializeData(data: NodeData, dos: DataOutputStream): Unit = {
+ import ReadWriteUtils._
+ dos.writeInt(data.id)
+ dos.writeDouble(data.prediction)
+ dos.writeDouble(data.impurity)
+ serializeDoubleArray(data.impurityStats, dos)
+ dos.writeLong(data.rawCount)
+ dos.writeDouble(data.gain)
+ dos.writeInt(data.leftChild)
+ dos.writeInt(data.rightChild)
+ SplitData.serializeData(data.split, dos)
+ }
+
+ private[ml] def deserializeData(dis: DataInputStream): NodeData = {
+ import ReadWriteUtils._
+ val id = dis.readInt()
+ val prediction = dis.readDouble()
+ val impurity = dis.readDouble()
+ val impurityStats = deserializeDoubleArray(dis)
+ val rawCount = dis.readLong()
+ val gain = dis.readDouble()
+ val leftChild = dis.readInt()
+ val rightChild = dis.readInt()
+ val split = SplitData.deserializeData(dis)
+ NodeData(
+ id, prediction, impurity, impurityStats, rawCount, gain, leftChild, rightChild, split
+ )
+ }
+
}
/**
@@ -430,7 +477,7 @@ private[ml] object DecisionTreeModelReadWrite {
df.as[NodeData].collect()
} else {
import org.apache.spark.ml.util.ReadWriteUtils
- ReadWriteUtils.loadArray[NodeData](dataPath, sparkSession)
+ ReadWriteUtils.loadArray[NodeData](dataPath, sparkSession, NodeData.deserializeData)
}
buildTreeFromNodes(nodeDataArray, impurityType)
@@ -493,7 +540,12 @@ private[ml] object EnsembleModelReadWrite {
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
ReadWriteUtils.saveArray[(Int, String, Double)](
- treesMetadataPath, treesMetadataWeights, sparkSession, numDataParts = 1
+ treesMetadataPath, treesMetadataWeights, sparkSession,
+ (v, dos) => {
+ dos.writeInt(v._1)
+ dos.writeUTF(v._2)
+ dos.writeDouble(v._3)
+ }, numDataParts = 1
)
val dataPath = new Path(path, "data").toString
@@ -503,7 +555,11 @@ private[ml] object EnsembleModelReadWrite {
case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
}
ReadWriteUtils.saveArray[EnsembleNodeData](
- dataPath, nodeDataArray, sparkSession, numDataParts
+ dataPath, nodeDataArray, sparkSession,
+ (v, dos) => {
+ dos.writeInt(v.treeID)
+ NodeData.serializeData(v.nodeData, dos)
+ }, numDataParts
)
}
@@ -535,7 +591,13 @@ private[ml] object EnsembleModelReadWrite {
val treesMetadataPath = new Path(path, "treesMetadata").toString
val treesMetadataWeights = ReadWriteUtils.loadArray[(Int, String, Double)](
- treesMetadataPath, sparkSession
+ treesMetadataPath, sparkSession,
+ dis => {
+ val treeID = dis.readInt()
+ val json = dis.readUTF()
+ val weights = dis.readDouble()
+ (treeID, json, weights)
+ }
).map { case (treeID: Int, json: String, weights: Double) =>
treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName),
weights))
}.sortBy(_._1).map(_._2)
@@ -555,7 +617,14 @@ private[ml] object EnsembleModelReadWrite {
df = df.withColumn("nodeData", newNodeDataCol)
df.as[EnsembleNodeData].collect()
} else {
- ReadWriteUtils.loadArray[EnsembleNodeData](dataPath, sparkSession)
+ ReadWriteUtils.loadArray[EnsembleNodeData](
+ dataPath, sparkSession,
+ dis => {
+ val treeID = dis.readInt()
+ val nodeData = NodeData.deserializeData(dis)
+ EnsembleNodeData(treeID, nodeData)
+ }
+ )
}
val rootNodes = ensembleNodeDataArray
.groupBy(_.treeID)
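Note: where no dedicated helper exists, the patch passes the writer and reader as inline lambdas, as with the (Int, String, Double) tree metadata above. The sketch below shows such a pair in isolation; TupleSerializerSketch and its members are illustrative names, and the only contract is that the reader consumes fields in exactly the order the writer emitted them.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object TupleSerializerSketch {
  // Writer and reader must agree on field order: Int, then UTF string, then Double.
  def writeEntry(entry: (Int, String, Double), dos: DataOutputStream): Unit = {
    dos.writeInt(entry._1)
    dos.writeUTF(entry._2)
    dos.writeDouble(entry._3)
  }

  def readEntry(dis: DataInputStream): (Int, String, Double) =
    (dis.readInt(), dis.readUTF(), dis.readDouble())

  def main(args: Array[String]): Unit = {
    val entry = (7, """{"class":"DecisionTreeRegressionModel"}""", 0.25)
    val out = new ByteArrayOutputStream()
    writeEntry(entry, new DataOutputStream(out))
    val back = readEntry(new DataInputStream(new ByteArrayInputStream(out.toByteArray)))
    assert(back == entry)
  }
}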
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 63c1b9e270ae..3a6f7fc00d6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.ml.util
-import java.io.{File, IOException}
+import java.io.{
+ BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream,
+ File, FileInputStream, FileOutputStream, IOException
+}
import java.nio.file.{Files, Paths}
import java.util.{Locale, ServiceLoader}
@@ -25,7 +28,7 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Success, Try, Using}
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
@@ -34,13 +37,14 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{SparkContext, SparkEnv, SparkException}
+import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.{Since, Unstable}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.PATH
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
import org.apache.spark.ml.feature.RFormulaModel
+import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
@@ -848,6 +852,195 @@ private[spark] object ReadWriteUtils {
override def initialValue: Boolean = false
}
+ def serializeIntArray(array: Array[Int], dos: DataOutputStream): Unit = {
+ dos.writeInt(array.length)
+ for (i <- 0 until array.length) {
+ dos.writeInt(array(i))
+ }
+ }
+
+ def deserializeIntArray(dis: DataInputStream): Array[Int] = {
+ val len = dis.readInt()
+ val data = new Array[Int](len)
+ for (i <- 0 until len) {
+ data(i) = dis.readInt()
+ }
+ data
+ }
+
+ def serializeLongArray(array: Array[Long], dos: DataOutputStream): Unit = {
+ dos.writeInt(array.length)
+ for (i <- 0 until array.length) {
+ dos.writeLong(array(i))
+ }
+ }
+
+ def deserializeLongArray(dis: DataInputStream): Array[Long] = {
+ val len = dis.readInt()
+ val data = new Array[Long](len)
+ for (i <- 0 until len) {
+ data(i) = dis.readLong()
+ }
+ data
+ }
+
+ def serializeFloatArray(array: Array[Float], dos: DataOutputStream): Unit = {
+ dos.writeInt(array.length)
+ for (i <- 0 until array.length) {
+ dos.writeFloat(array(i))
+ }
+ }
+
+ def deserializeFloatArray(dis: DataInputStream): Array[Float] = {
+ val len = dis.readInt()
+ val data = new Array[Float](len)
+ for (i <- 0 until len) {
+ data(i) = dis.readFloat()
+ }
+ data
+ }
+
+ def serializeDoubleArray(array: Array[Double], dos: DataOutputStream): Unit = {
+ dos.writeInt(array.length)
+ for (i <- 0 until array.length) {
+ dos.writeDouble(array(i))
+ }
+ }
+
+ def deserializeDoubleArray(dis: DataInputStream): Array[Double] = {
+ val len = dis.readInt()
+ val data = new Array[Double](len)
+ for (i <- 0 until len) {
+ data(i) = dis.readDouble()
+ }
+ data
+ }
+
+ def serializeStringArray(array: Array[String], dos: DataOutputStream): Unit = {
+ serializeGenericArray[String](array, dos, (s, dos) => dos.writeUTF(s))
+ }
+
+ def deserializeStringArray(dis: DataInputStream): Array[String] = {
+ deserializeGenericArray[String](dis, dis => dis.readUTF())
+ }
+
+ def serializeMap[K, V](
+ map: Map[K, V], dos: DataOutputStream,
+ keySerializer: (K, DataOutputStream) => Unit,
+ valueSerializer: (V, DataOutputStream) => Unit
+ ): Unit = {
+ dos.writeInt(map.size)
+ map.foreach { case (k, v) =>
+ keySerializer(k, dos)
+ valueSerializer(v, dos)
+ }
+ }
+
+ def deserializeMap[K, V](
+ dis: DataInputStream,
+ keyDeserializer: DataInputStream => K,
+ valueDeserializer: DataInputStream => V
+ ): Map[K, V] = {
+ val len = dis.readInt()
+ val kvList = new Array[(K, V)](len)
+ for (i <- 0 until len) {
+ val key = keyDeserializer(dis)
+ val value = valueDeserializer(dis)
+ kvList(i) = (key, value)
+ }
+ kvList.toMap
+ }
+
+ def serializeVector(vector: Vector, dos: DataOutputStream): Unit = {
+ if (vector.isInstanceOf[DenseVector]) {
+ dos.writeBoolean(false)
+ serializeDoubleArray(vector.toArray, dos)
+ } else {
+ val sparseVec = vector.asInstanceOf[SparseVector]
+ dos.writeBoolean(true)
+ dos.writeInt(sparseVec.size)
+ serializeIntArray(sparseVec.indices, dos)
+ serializeDoubleArray(sparseVec.values, dos)
+ }
+ }
+
+ def deserializeVector(dis: DataInputStream): Vector = {
+ val isSparse = dis.readBoolean()
+ if (isSparse) {
+ val len = dis.readInt()
+ val indices = deserializeIntArray(dis)
+ val values = deserializeDoubleArray(dis)
+ new SparseVector(len, indices, values)
+ } else {
+ val values = deserializeDoubleArray(dis)
+ new DenseVector(values)
+ }
+ }
+
+ def serializeMatrix(matrix: Matrix, dos: DataOutputStream): Unit = {
+ def serializeCommon(): Unit = {
+ dos.writeInt(matrix.numRows)
+ dos.writeInt(matrix.numCols)
+ dos.writeBoolean(matrix.isTransposed)
+ }
+
+ if (matrix.isInstanceOf[DenseMatrix]) {
+ val denseMatrix = matrix.asInstanceOf[DenseMatrix]
+ dos.writeBoolean(false)
+ serializeCommon()
+ serializeDoubleArray(denseMatrix.values, dos)
+ } else {
+ val sparseMatrix = matrix.asInstanceOf[SparseMatrix]
+ dos.writeBoolean(true)
+ serializeCommon()
+ serializeIntArray(sparseMatrix.colPtrs, dos)
+ serializeIntArray(sparseMatrix.rowIndices, dos)
+ serializeDoubleArray(sparseMatrix.values, dos)
+ }
+ }
+
+ def deserializeMatrix(dis: DataInputStream): Matrix = {
+ def deserializeCommon(): (Int, Int, Boolean) = {
+ val numRows = dis.readInt()
+ val numCols = dis.readInt()
+ val transposed = dis.readBoolean()
+ (numRows, numCols, transposed)
+ }
+
+ val isSparse = dis.readBoolean()
+ if (isSparse) {
+ val (numRows, numCols, transposed) = deserializeCommon()
+ val colPtrs = deserializeIntArray(dis)
+ val rowIndices = deserializeIntArray(dis)
+ val values = deserializeDoubleArray(dis)
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, transposed)
+ } else {
+ val (numRows, numCols, transposed) = deserializeCommon()
+ val values = deserializeDoubleArray(dis)
+ new DenseMatrix(numRows, numCols, values, transposed)
+ }
+ }
+
+ def serializeGenericArray[T: ClassTag](
+ array: Array[T], dos: DataOutputStream, serializer: (T, DataOutputStream) => Unit
+ ): Unit = {
+ dos.writeInt(array.length)
+ for (item <- array) {
+ serializer(item, dos)
+ }
+ }
+
+ def deserializeGenericArray[T: ClassTag](
+ dis: DataInputStream, deserializer: DataInputStream => T
+ ): Array[T] = {
+ val len = dis.readInt()
+ val data = new Array[T](len)
+ for (i <- 0 until len) {
+ data(i) = deserializer(dis)
+ }
+ data
+ }
+
def saveText(path: String, data: String, spark: SparkSession): Unit = {
if (localSavingModeState.get()) {
val filePath = Paths.get(path)
@@ -867,38 +1060,44 @@ private[spark] object ReadWriteUtils {
}
}
- def saveObjectToLocal[T <: Product: ClassTag: TypeTag](path: String, data: T): Unit = {
- val serializer = SparkEnv.get.serializer.newInstance()
- val dataBuffer = serializer.serialize(data)
- val dataBytes = new Array[Byte](dataBuffer.limit)
- dataBuffer.get(dataBytes)
-
+ def saveObjectToLocal[T <: Product: ClassTag: TypeTag](
+ path: String, data: T, serializer: (T, DataOutputStream) => Unit
+ ): Unit = {
val filePath = Paths.get(path)
-
Files.createDirectories(filePath.getParent)
- Files.write(filePath, dataBytes)
+
+ Using.resource(
+ new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
+ ) { dos =>
+ serializer(data, dos)
+ }
}
def saveObject[T <: Product: ClassTag: TypeTag](
- path: String, data: T, spark: SparkSession
+ path: String, data: T, spark: SparkSession, localSerializer: (T, DataOutputStream) => Unit
): Unit = {
if (localSavingModeState.get()) {
- saveObjectToLocal(path, data)
+ saveObjectToLocal(path, data, localSerializer)
} else {
spark.createDataFrame[T](Seq(data)).write.parquet(path)
}
}
- def loadObjectFromLocal[T <: Product: ClassTag: TypeTag](path: String): T = {
- val serializer = SparkEnv.get.serializer.newInstance()
-
- val dataBytes = Files.readAllBytes(Paths.get(path))
- serializer.deserialize[T](java.nio.ByteBuffer.wrap(dataBytes))
+ def loadObjectFromLocal[T <: Product: ClassTag: TypeTag](
+ path: String, deserializer: DataInputStream => T
+ ): T = {
+ Using.resource(
+ new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
+ ) { dis =>
+ deserializer(dis)
+ }
}
- def loadObject[T <: Product: ClassTag: TypeTag](path: String, spark: SparkSession): T = {
+ def loadObject[T <: Product: ClassTag: TypeTag](
+ path: String, spark: SparkSession, localDeserializer: DataInputStream => T
+ ): T = {
if (localSavingModeState.get()) {
- loadObjectFromLocal(path)
+ loadObjectFromLocal(path, localDeserializer)
} else {
import spark.implicits._
spark.read.parquet(path).as[T].head()
@@ -907,18 +1106,18 @@ private[spark] object ReadWriteUtils {
def saveArray[T <: Product: ClassTag: TypeTag](
path: String, data: Array[T], spark: SparkSession,
+ localSerializer: (T, DataOutputStream) => Unit,
numDataParts: Int = -1
): Unit = {
if (localSavingModeState.get()) {
- val serializer = SparkEnv.get.serializer.newInstance()
- val dataBuffer = serializer.serialize(data)
- val dataBytes = new Array[Byte](dataBuffer.limit)
- dataBuffer.get(dataBytes)
-
val filePath = Paths.get(path)
-
Files.createDirectories(filePath.getParent)
- Files.write(filePath, dataBytes)
+
+ Using.resource(
+ new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile)))
+ ) { dos =>
+ serializeGenericArray(data, dos, localSerializer)
+ }
} else {
import org.apache.spark.util.ArrayImplicits._
val df = spark.createDataFrame[T](data.toImmutableArraySeq)
@@ -930,12 +1129,15 @@ private[spark] object ReadWriteUtils {
}
}
- def loadArray[T <: Product: ClassTag: TypeTag](path: String, spark: SparkSession): Array[T] = {
+ def loadArray[T <: Product: ClassTag: TypeTag](
+ path: String, spark: SparkSession, localDeserializer: DataInputStream => T
+ ): Array[T] = {
if (localSavingModeState.get()) {
- val serializer = SparkEnv.get.serializer.newInstance()
-
- val dataBytes = Files.readAllBytes(Paths.get(path))
- serializer.deserialize[Array[T]](java.nio.ByteBuffer.wrap(dataBytes))
+ Using.resource(
+ new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
+ ) { dis =>
+ deserializeGenericArray(dis, localDeserializer)
+ }
} else {
import spark.implicits._
spark.read.parquet(path).as[T].collect()
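
Note: the most structured of the new ReadWriteUtils helpers is the vector codec: a boolean sparse flag, then either a length-prefixed dense values array or (size, length-prefixed indices, length-prefixed values). The sketch below re-implements that layout outside Spark to show a round trip; VectorWireFormatSketch and its method names are illustrative, while the org.apache.spark.ml.linalg types are the real ones.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

object VectorWireFormatSketch {
  // Sparse flag first, then either the length-prefixed dense values,
  // or (size, length-prefixed indices, length-prefixed values) for the sparse case.
  def writeVector(vector: Vector, dos: DataOutputStream): Unit = vector match {
    case dv: DenseVector =>
      dos.writeBoolean(false)
      dos.writeInt(dv.values.length)
      dv.values.foreach(d => dos.writeDouble(d))
    case sv: SparseVector =>
      dos.writeBoolean(true)
      dos.writeInt(sv.size)
      dos.writeInt(sv.indices.length)
      sv.indices.foreach(i => dos.writeInt(i))
      dos.writeInt(sv.values.length)
      sv.values.foreach(d => dos.writeDouble(d))
  }

  def readVector(dis: DataInputStream): Vector = {
    if (dis.readBoolean()) {
      val size = dis.readInt()
      val indices = Array.fill(dis.readInt())(dis.readInt())
      val values = Array.fill(dis.readInt())(dis.readDouble())
      Vectors.sparse(size, indices, values)
    } else {
      Vectors.dense(Array.fill(dis.readInt())(dis.readDouble()))
    }
  }

  def main(args: Array[String]): Unit = {
    val vectors = Seq(Vectors.dense(1.0, 0.0, 3.0), Vectors.sparse(5, Array(1, 4), Array(2.0, 5.0)))
    val out = new ByteArrayOutputStream()
    val dos = new DataOutputStream(out)
    vectors.foreach(v => writeVector(v, dos))
    val dis = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    assert(vectors.forall(v => readVector(dis) == v)) // both layouts round-trip
  }
}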
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]