This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fd74b5ec6bd6 [SPARK-52470][ML][CONNECT] Support model summary offloading
fd74b5ec6bd6 is described below
commit fd74b5ec6bd662ca51ea6bec06c4432503e566d4
Author: Weichen Xu <[email protected]>
AuthorDate: Tue Jun 17 20:41:15 2025 +0800
[SPARK-52470][ML][CONNECT] Support model summary offloading
### What changes were proposed in this pull request?
This PR makes Spark Connect ML support model summary offloading.
Model summary offloading is hard to support because the summary contains a Spark
dataset which can't be easily serialized in the Spark driver (NOTE: we can't use
Java serialization for the Spark dataset's logical plan, otherwise it would be an
RCE vulnerability).
To address the issue, when saving the summary to disk, only the necessary data
fields are saved; when loading the summary back, the client needs to send the
dataset to the Spark driver again.
To achieve this, 2 new proto messages are introduced:
1. `CreateSummary` in `MlCommand`
```
// This is for re-creating the model summary when the model summary is lost
// (model summary is lost when the model is offloaded and then loaded back)
message CreateSummary {
  ObjectRef model_ref = 1;
  Relation dataset = 2;
}
```
2. `model_summary_dataset` in `MlRelation`
```
// (Optional) the dataset for restoring the model summary
optional Relation model_summary_dataset = 3;
```
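For illustration, here is a minimal client-side sketch of the round trip this enables. It is not part of the patch; the connection URL, toy dataset, and column names are made up for the example, and the offloading itself is triggered by the server-side cache timeout (or the `evict_only` hook used in the unit tests).

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

# Hypothetical connection URL; any Spark Connect endpoint works.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [(1.0, Vectors.dense(0.0, 5.0)), (0.0, Vectors.dense(5.0, 0.0))],
    ["label", "features"],
)

model = LogisticRegression(maxIter=5).fit(df)

# The training summary is cached on the driver next to the model.
print(model.summary.accuracy)

# After the cached model is offloaded and loaded back, only the summary's
# plain data fields (e.g. objectiveHistory) are restored from disk; the
# summary's dataset is gone. The next `model.summary` access re-sends the
# training dataset to the driver (via the new `CreateSummary` command /
# `model_summary_dataset` relation) and the summary is re-created there.
print(model.summary.accuracy)  # still available after offloading
```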
### Why are the changes needed?
Support model summary offloading.
Without this, the model summary is evicted from Spark driver memory after the
default 15-minute timeout, which makes the `model.summary` API unavailable.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51187 from WeichenXu123/SPARK-52470.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 5 +
.../spark/ml/classification/FMClassifier.scala | 49 ++-
.../apache/spark/ml/classification/LinearSVC.scala | 49 ++-
.../ml/classification/LogisticRegression.scala | 73 +++--
.../MultilayerPerceptronClassifier.scala | 43 ++-
.../ml/classification/RandomForestClassifier.scala | 51 ++--
.../spark/ml/clustering/BisectingKMeans.scala | 3 +
.../spark/ml/clustering/GaussianMixture.scala | 37 ++-
.../org/apache/spark/ml/clustering/KMeans.scala | 47 ++-
.../regression/GeneralizedLinearRegression.scala | 45 ++-
.../spark/ml/regression/LinearRegression.scala | 93 ++++--
.../apache/spark/ml/util/HasTrainingSummary.scala | 11 +
python/pyspark/ml/classification.py | 129 +++-----
python/pyspark/ml/clustering.py | 63 +---
python/pyspark/ml/connect/proto.py | 11 +-
python/pyspark/ml/regression.py | 44 +--
.../pyspark/ml/tests/connect/test_connect_cache.py | 13 +-
python/pyspark/ml/tests/test_classification.py | 236 +++++++++------
python/pyspark/ml/tests/test_clustering.py | 101 ++++---
python/pyspark/ml/tests/test_regression.py | 150 +++++-----
python/pyspark/ml/util.py | 90 ++++--
python/pyspark/sql/connect/client/core.py | 3 +-
python/pyspark/sql/connect/proto/ml_pb2.py | 46 +--
python/pyspark/sql/connect/proto/ml_pb2.pyi | 66 +++-
python/pyspark/sql/connect/proto/relations_pb2.py | 332 ++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 33 +-
.../src/main/protobuf/spark/connect/ml.proto | 11 +
.../main/protobuf/spark/connect/relations.proto | 3 +
.../org/apache/spark/sql/connect/ml/MLCache.scala | 31 +-
.../apache/spark/sql/connect/ml/MLException.scala | 6 +
.../apache/spark/sql/connect/ml/MLHandler.scala | 55 +++-
31 files changed, 1194 insertions(+), 735 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 5def48196cf3..a280887da845 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -853,6 +853,11 @@
"Please fit or load a model smaller than <modelMaxSize> bytes."
]
},
+ "MODEL_SUMMARY_LOST" : {
+ "message" : [
+ "The model <objectName> summary is lost because the cached model is
offloaded."
+ ]
+ },
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index cefa13b2bbe7..b653383161e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -224,17 +224,8 @@ class FMClassifier @Since("3.0.0") (
factors: Matrix,
objectiveHistory: Array[Double]): FMClassificationModel = {
val model = copyValues(new FMClassificationModel(uid, intercept, linear,
factors))
- val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
-
- val (summaryModel, probabilityColName, predictionColName) =
model.findSummaryModel()
- val summary = new FMClassificationTrainingSummaryImpl(
- summaryModel.transform(dataset),
- probabilityColName,
- predictionColName,
- $(labelCol),
- weightColName,
- objectiveHistory)
- model.setSummary(Some(summary))
+ model.createSummary(dataset, objectiveHistory)
+ model
}
@Since("3.0.0")
@@ -343,6 +334,42 @@ class FMClassificationModel private[classification] (
s"uid=${super.toString}, numClasses=$numClasses,
numFeatures=$numFeatures, " +
s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)},
fitIntercept=${$(fitIntercept)}"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_], objectiveHistory: Array[Double]
+ ): Unit = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) =
findSummaryModel()
+ val summary = new FMClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ objectiveHistory)
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+ path, Tuple1(summary.objectiveHistory),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val Tuple1(objectiveHistory: Array[Double])
+ = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+ path,
+ dis => {
+ Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+ }
+ )
+ createSummary(dataset, objectiveHistory)
+ }
}
@Since("3.0.0")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index a50346ae88f4..0d163b761686 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -277,17 +277,8 @@ class LinearSVC @Since("2.2.0") (
intercept: Double,
objectiveHistory: Array[Double]): LinearSVCModel = {
val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
- val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
-
- val (summaryModel, rawPredictionColName, predictionColName) =
model.findSummaryModel()
- val summary = new LinearSVCTrainingSummaryImpl(
- summaryModel.transform(dataset),
- rawPredictionColName,
- predictionColName,
- $(labelCol),
- weightColName,
- objectiveHistory)
- model.setSummary(Some(summary))
+ model.createSummary(dataset, objectiveHistory)
+ model
}
private def trainImpl(
@@ -445,6 +436,42 @@ class LinearSVCModel private[classification] (
override def toString: String = {
s"LinearSVCModel: uid=$uid, numClasses=$numClasses,
numFeatures=$numFeatures"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_], objectiveHistory: Array[Double]
+ ): Unit = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
+
+ val (summaryModel, rawPredictionColName, predictionColName) =
findSummaryModel()
+ val summary = new LinearSVCTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ rawPredictionColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ objectiveHistory)
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+ path, Tuple1(summary.objectiveHistory),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val Tuple1(objectiveHistory: Array[Double])
+ = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+ path,
+ dis => {
+ Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+ }
+ )
+ createSummary(dataset, objectiveHistory)
+ }
}
@Since("2.2.0")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 58a2652d0eab..8c010f67f5e0 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -718,29 +718,8 @@ class LogisticRegression @Since("1.2.0") (
objectiveHistory: Array[Double]): LogisticRegressionModel = {
val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix,
interceptVector,
numClasses, checkMultinomial(numClasses)))
- val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
-
- val (summaryModel, probabilityColName, predictionColName) =
model.findSummaryModel()
- val logRegSummary = if (numClasses <= 2) {
- new BinaryLogisticRegressionTrainingSummaryImpl(
- summaryModel.transform(dataset),
- probabilityColName,
- predictionColName,
- $(labelCol),
- $(featuresCol),
- weightColName,
- objectiveHistory)
- } else {
- new LogisticRegressionTrainingSummaryImpl(
- summaryModel.transform(dataset),
- probabilityColName,
- predictionColName,
- $(labelCol),
- $(featuresCol),
- weightColName,
- objectiveHistory)
- }
- model.setSummary(Some(logRegSummary))
+ model.createSummary(dataset, objectiveHistory)
+ model
}
private def createBounds(
@@ -1323,6 +1302,54 @@ class LogisticRegressionModel private[spark] (
override def toString: String = {
s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses,
numFeatures=$numFeatures"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_], objectiveHistory: Array[Double]
+ ): Unit = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) =
findSummaryModel()
+ val logRegSummary = if (numClasses <= 2) {
+ new BinaryLogisticRegressionTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ weightColName,
+ objectiveHistory)
+ } else {
+ new LogisticRegressionTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ weightColName,
+ objectiveHistory)
+ }
+ setSummary(Some(logRegSummary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+ path, Tuple1(summary.objectiveHistory),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val Tuple1(objectiveHistory: Array[Double])
+ = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+ path,
+ dis => {
+ Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+ }
+ )
+ createSummary(dataset, objectiveHistory)
+ }
}
@Since("1.6.0")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 6bd46cff815d..5e52d62fb83c 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -251,14 +251,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
objectiveHistory: Array[Double]):
MultilayerPerceptronClassificationModel = {
val model = copyValues(new MultilayerPerceptronClassificationModel(uid,
weights))
- val (summaryModel, _, predictionColName) = model.findSummaryModel()
- val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
- summaryModel.transform(dataset),
- predictionColName,
- $(labelCol),
- "",
- objectiveHistory)
- model.setSummary(Some(summary))
+ model.createSummary(dataset, objectiveHistory)
+ model
}
}
@@ -365,6 +359,39 @@ class MultilayerPerceptronClassificationModel private[ml] (
s"MultilayerPerceptronClassificationModel: uid=$uid,
numLayers=${$(layers).length}, " +
s"numClasses=$numClasses, numFeatures=$numFeatures"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_], objectiveHistory: Array[Double]
+ ): Unit = {
+ val (summaryModel, _, predictionColName) = findSummaryModel()
+ val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ "",
+ objectiveHistory)
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+ path, Tuple1(summary.objectiveHistory),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val Tuple1(objectiveHistory: Array[Double])
+ = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+ path,
+ dis => {
+ Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+ }
+ )
+ createSummary(dataset, objectiveHistory)
+ }
}
@Since("2.0.0")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f64e2a6d4efc..8b580b1e075c 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -182,26 +182,8 @@ class RandomForestClassifier @Since("1.4.0") (
numFeatures: Int,
numClasses: Int): RandomForestClassificationModel = {
val model = copyValues(new RandomForestClassificationModel(uid, trees,
numFeatures, numClasses))
- val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
-
- val (summaryModel, probabilityColName, predictionColName) =
model.findSummaryModel()
- val rfSummary = if (numClasses <= 2) {
- new BinaryRandomForestClassificationTrainingSummaryImpl(
- summaryModel.transform(dataset),
- probabilityColName,
- predictionColName,
- $(labelCol),
- weightColName,
- Array(0.0))
- } else {
- new RandomForestClassificationTrainingSummaryImpl(
- summaryModel.transform(dataset),
- predictionColName,
- $(labelCol),
- weightColName,
- Array(0.0))
- }
- model.setSummary(Some(rfSummary))
+ model.createSummary(dataset)
+ model
}
@Since("1.4.1")
@@ -393,6 +375,35 @@ class RandomForestClassificationModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new
RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
+
+ private[spark] def createSummary(dataset: Dataset[_]): Unit = {
+ val weightColName = if (!isDefined(weightCol)) "weightCol" else
$(weightCol)
+
+ val (summaryModel, probabilityColName, predictionColName) =
findSummaryModel()
+ val rfSummary = if (numClasses <= 2) {
+ new BinaryRandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ } else {
+ new RandomForestClassificationTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ weightColName,
+ Array(0.0))
+ }
+ setSummary(Some(rfSummary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {}
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ createSummary(dataset)
+ }
}
@Since("2.0.0")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 3248b4b391d0..9e09ee00c3e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -180,6 +180,9 @@ class BisectingKMeansModel private[ml] (
override def summary: BisectingKMeansSummary = super.summary
override def estimatedSize: Long = SizeEstimator.estimate(parentModel)
+
+ // The BisectingKMeans model doesn't support offloading yet, so put an empty `saveSummary` here for now
+ override private[spark] def saveSummary(path: String): Unit = {}
}
object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index a94b8a87d8fc..e7f930065486 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -223,6 +223,36 @@ class GaussianMixtureModel private[ml] (
override def summary: GaussianMixtureSummary = super.summary
override def estimatedSize: Long = SizeEstimator.estimate((weights,
gaussians))
+
+ private[spark] def createSummary(
+ predictions: DataFrame, logLikelihood: Double, iteration: Int
+ ): Unit = {
+ val summary = new GaussianMixtureSummary(predictions,
+ $(predictionCol), $(probabilityCol), $(featuresCol), $(k),
logLikelihood, iteration)
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[(Double, Int)](
+ path, (summary.logLikelihood, summary.numIter),
+ (data, dos) => {
+ dos.writeDouble(data._1)
+ dos.writeInt(data._2)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val (logLikelihood: Double, numIter: Int) =
ReadWriteUtils.loadObjectFromLocal[(Double, Int)](
+ path,
+ dis => {
+ val logLikelihood = dis.readDouble()
+ val numIter = dis.readInt()
+ (logLikelihood, numIter)
+ }
+ )
+ createSummary(dataset, logLikelihood, numIter)
+ }
}
@Since("2.0.0")
@@ -453,11 +483,10 @@ class GaussianMixture @Since("2.0.0") (
val model = copyValues(new GaussianMixtureModel(uid, weights,
gaussianDists))
.setParent(this)
- val summary = new GaussianMixtureSummary(model.transform(dataset),
- $(predictionCol), $(probabilityCol), $(featuresCol), $(k),
logLikelihood, iteration)
+ model.createSummary(model.transform(dataset), logLikelihood, iteration)
instr.logNamedValue("logLikelihood", logLikelihood)
- instr.logNamedValue("clusterSizes", summary.clusterSizes)
- model.setSummary(Some(summary))
+ instr.logNamedValue("clusterSizes", model.summary.clusterSizes)
+ model
}
private def trainImpl(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f3ac58e670e5..ccae39cedd20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -215,6 +215,42 @@ class KMeansModel private[ml] (
override def summary: KMeansSummary = super.summary
override def estimatedSize: Long =
SizeEstimator.estimate(parentModel.clusterCenters)
+
+ private[spark] def createSummary(
+ predictions: DataFrame, numIter: Int, trainingCost: Double
+ ): Unit = {
+ val summary = new KMeansSummary(
+ predictions,
+ $(predictionCol),
+ $(featuresCol),
+ $(k),
+ numIter,
+ trainingCost)
+
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[(Int, Double)](
+ path, (summary.numIter, summary.trainingCost),
+ (data, dos) => {
+ dos.writeInt(data._1)
+ dos.writeDouble(data._2)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val (numIter: Int, trainingCost: Double) =
ReadWriteUtils.loadObjectFromLocal[(Int, Double)](
+ path,
+ dis => {
+ val numIter = dis.readInt()
+ val trainingCost = dis.readDouble()
+ (numIter, trainingCost)
+ }
+ )
+ createSummary(dataset, numIter, trainingCost)
+ }
}
/** Helper class for storing model data */
@@ -414,16 +450,9 @@ class KMeans @Since("1.5.0") (
}
val model = copyValues(new KMeansModel(uid, oldModel).setParent(this))
- val summary = new KMeansSummary(
- model.transform(dataset),
- $(predictionCol),
- $(featuresCol),
- $(k),
- oldModel.numIter,
- oldModel.trainingCost)
- model.setSummary(Some(summary))
- instr.logNamedValue("clusterSizes", summary.clusterSizes)
+ model.createSummary(model.transform(dataset), oldModel.numIter,
oldModel.trainingCost)
+ instr.logNamedValue("clusterSizes", model.summary.clusterSizes)
model
}
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 14467c761b21..cf62c2bf41b6 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -419,9 +419,8 @@ class GeneralizedLinearRegression @Since("2.0.0")
(@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients,
wlsModel.intercept)
.setParent(this))
- val trainingSummary = new
GeneralizedLinearRegressionTrainingSummary(dataset, model,
- wlsModel.diagInvAtWA.toArray, 1, getSolver)
- model.setSummary(Some(trainingSummary))
+ model.createSummary(dataset, wlsModel.diagInvAtWA.toArray, 1)
+ model
} else {
val instances = validated.rdd.map {
case Row(label: Double, weight: Double, offset: Double, features:
Vector) =>
@@ -436,9 +435,8 @@ class GeneralizedLinearRegression @Since("2.0.0")
(@Since("2.0.0") override val
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients,
irlsModel.intercept)
.setParent(this))
- val trainingSummary = new
GeneralizedLinearRegressionTrainingSummary(dataset, model,
- irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
- model.setSummary(Some(trainingSummary))
+ model.createSummary(dataset, irlsModel.diagInvAtWA.toArray,
irlsModel.numIterations)
+ model
}
model
@@ -1140,6 +1138,39 @@ class GeneralizedLinearRegressionModel private[ml] (
s"GeneralizedLinearRegressionModel: uid=$uid, family=${$(family)},
link=${$(link)}, " +
s"numFeatures=$numFeatures"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_], diagInvAtWA: Array[Double], numIter: Int
+ ): Unit = {
+ val summary = new GeneralizedLinearRegressionTrainingSummary(
+ dataset, this, diagInvAtWA, numIter, $(solver)
+ )
+
+ setSummary(Some(summary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[(Array[Double], Int)](
+ path, (summary.diagInvAtWA, summary.numIterations),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ dos.writeInt(data._2)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val (diagInvAtWA: Array[Double], numIterations: Int) =
+ ReadWriteUtils.loadObjectFromLocal[(Array[Double], Int)](
+ path,
+ dis => {
+ val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis)
+ val numIterations = dis.readInt()
+ (diagInvAtWA, numIterations)
+ }
+ )
+ createSummary(dataset, diagInvAtWA, numIterations)
+ }
}
@Since("2.0.0")
@@ -1467,7 +1498,7 @@ class GeneralizedLinearRegressionSummary
private[regression] (
class GeneralizedLinearRegressionTrainingSummary private[regression] (
dataset: Dataset[_],
origModel: GeneralizedLinearRegressionModel,
- private val diagInvAtWA: Array[Double],
+ private[spark] val diagInvAtWA: Array[Double],
@Since("2.0.0") val numIterations: Int,
@Since("2.0.0") val solver: String)
extends GeneralizedLinearRegressionSummary(dataset, origModel) with
Serializable {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b06140e48338..822df270c0bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -433,15 +433,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
}
val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd)
-
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) =
model.findSummaryModelAndPredictionCol()
- val trainingSummary = new LinearRegressionTrainingSummary(
- summaryModel.transform(dataset), predictionColName, $(labelCol),
$(featuresCol),
- summaryModel.get(summaryModel.weightCol).getOrElse(""),
- summaryModel.numFeatures, summaryModel.getFitIntercept,
- Array(0.0), objectiveHistory)
- model.setSummary(Some(trainingSummary))
+ model.createSummary(dataset, Array(0.0), objectiveHistory,
Array.emptyDoubleArray)
+ model
}
private def trainWithNormal(
@@ -459,20 +452,16 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
// attach returned model.
val lrModel = copyValues(new LinearRegressionModel(
uid, model.coefficients.compressed, model.intercept))
- val (summaryModel, predictionColName) =
lrModel.findSummaryModelAndPredictionCol()
- val coefficientArray = if (summaryModel.getFitIntercept) {
- summaryModel.coefficients.toArray ++ Array(summaryModel.intercept)
+ val coefficientArray = if (lrModel.getFitIntercept) {
+ lrModel.coefficients.toArray ++ Array(lrModel.intercept)
} else {
- summaryModel.coefficients.toArray
+ lrModel.coefficients.toArray
}
- val trainingSummary = new LinearRegressionTrainingSummary(
- summaryModel.transform(dataset), predictionColName, $(labelCol),
$(featuresCol),
- summaryModel.get(summaryModel.weightCol).getOrElse(""),
- summaryModel.numFeatures, summaryModel.getFitIntercept,
- model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray)
-
- lrModel.setSummary(Some(trainingSummary))
+ lrModel.createSummary(
+ dataset, model.diagInvAtWA.toArray, model.objectiveHistory,
coefficientArray
+ )
+ lrModel
}
private def trainWithConstantLabel(
@@ -497,16 +486,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0")
override val uid: String
val intercept = yMean
val model = copyValues(new LinearRegressionModel(uid, coefficients,
intercept))
- // Handle possible missing or invalid prediction columns
- val (summaryModel, predictionColName) =
model.findSummaryModelAndPredictionCol()
- val trainingSummary = new LinearRegressionTrainingSummary(
- summaryModel.transform(dataset), predictionColName, $(labelCol),
$(featuresCol),
- summaryModel.get(summaryModel.weightCol).getOrElse(""),
- summaryModel.numFeatures, summaryModel.getFitIntercept,
- Array(0.0), Array(0.0))
-
- model.setSummary(Some(trainingSummary))
+ model.createSummary(dataset, Array(0.0), Array(0.0),
Array.emptyDoubleArray)
+ model
}
private def createOptimizer(
@@ -800,6 +782,53 @@ class LinearRegressionModel private[ml] (
override def toString: String = {
s"LinearRegressionModel: uid=$uid, numFeatures=$numFeatures"
}
+
+ private[spark] def createSummary(
+ dataset: Dataset[_],
+ diagInvAtWA: Array[Double],
+ objectiveHistory: Array[Double],
+ coefficientArray: Array[Double]
+ ): Unit = {
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
+
+ val trainingSummary = new LinearRegressionTrainingSummary(
+ summaryModel.transform(dataset), predictionColName, $(labelCol),
$(featuresCol),
+ summaryModel.get(summaryModel.weightCol).getOrElse(""),
+ summaryModel.numFeatures, summaryModel.getFitIntercept,
+ diagInvAtWA, objectiveHistory, coefficientArray)
+
+ setSummary(Some(trainingSummary))
+ }
+
+ override private[spark] def saveSummary(path: String): Unit = {
+ ReadWriteUtils.saveObjectToLocal[(Array[Double], Array[Double],
Array[Double])](
+ path, (summary.diagInvAtWA, summary.objectiveHistory,
summary.coefficientArray),
+ (data, dos) => {
+ ReadWriteUtils.serializeDoubleArray(data._1, dos)
+ ReadWriteUtils.serializeDoubleArray(data._2, dos)
+ ReadWriteUtils.serializeDoubleArray(data._3, dos)
+ }
+ )
+ }
+
+ override private[spark] def loadSummary(path: String, dataset: DataFrame):
Unit = {
+ val (
+ diagInvAtWA: Array[Double],
+ objectiveHistory: Array[Double],
+ coefficientArray: Array[Double]
+ )
+ = ReadWriteUtils.loadObjectFromLocal[(Array[Double], Array[Double],
Array[Double])](
+ path,
+ dis => {
+ val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis)
+ val objectiveHistory = ReadWriteUtils.deserializeDoubleArray(dis)
+ val coefficientArray = ReadWriteUtils.deserializeDoubleArray(dis)
+ (diagInvAtWA, objectiveHistory, coefficientArray)
+ }
+ )
+ createSummary(dataset, diagInvAtWA, objectiveHistory, coefficientArray)
+ }
}
private[ml] case class LinearModelData(intercept: Double, coefficients:
Vector, scale: Double)
@@ -926,7 +955,7 @@ class LinearRegressionTrainingSummary private[regression] (
private val fitIntercept: Boolean,
diagInvAtWA: Array[Double],
val objectiveHistory: Array[Double],
- private val coefficientArray: Array[Double] = Array.emptyDoubleArray)
+ override private[regression] val coefficientArray: Array[Double] =
Array.emptyDoubleArray)
extends LinearRegressionSummary(
predictions,
predictionCol,
@@ -972,8 +1001,8 @@ class LinearRegressionSummary private[regression] (
private val weightCol: String,
private val numFeatures: Int,
private val fitIntercept: Boolean,
- private val diagInvAtWA: Array[Double],
- private val coefficientArray: Array[Double] = Array.emptyDoubleArray)
+ private[regression] val diagInvAtWA: Array[Double],
+ private[regression] val coefficientArray: Array[Double] =
Array.emptyDoubleArray)
extends Summary with Serializable {
@transient private val metrics = {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
index 0ba8ce072ab4..c6f6babf71a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.util
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
+import org.apache.spark.sql.DataFrame
/**
@@ -49,4 +50,14 @@ private[spark] trait HasTrainingSummary[T] {
this.trainingSummary = summary
this
}
+
+ private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+ throw new SparkException(
+ s"No loadSummary implementation for this ${this.getClass.getSimpleName}")
+ }
+
+ private[spark] def saveSummary(path: String): Unit = {
+ throw new SparkException(
+ s"No saveSummary implementation for this ${this.getClass.getSimpleName}")
+ }
}
diff --git a/python/pyspark/ml/classification.py
b/python/pyspark/ml/classification.py
index a5fdaed0db2c..f66fc762971b 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -889,15 +889,14 @@ class LinearSVCModel(
Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
trained on the training set. An exception is thrown if
`trainingSummary is None`.
"""
- if self.hasSummary:
- s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
+ return super().summary
+
+ @property
+ def _summaryCls(self) -> type:
+ return LinearSVCTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary":
"""
@@ -1577,29 +1576,6 @@ class LogisticRegressionModel(
"""
return self._call_java("interceptVector")
- @property
- @since("2.0.0")
- def summary(self) -> "LogisticRegressionTrainingSummary":
- """
- Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
- trained on the training set. An exception is thrown if
`trainingSummary is None`.
- """
- if self.hasSummary:
- s: LogisticRegressionTrainingSummary
- if self.numClasses <= 2:
- s = BinaryLogisticRegressionTrainingSummary(
- super(LogisticRegressionModel, self).summary
- )
- else:
- s =
LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
-
def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary":
"""
Evaluates the model on a test dataset.
@@ -1623,6 +1599,15 @@ class LogisticRegressionModel(
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
+ @property
+ def _summaryCls(self) -> type:
+ if self.numClasses <= 2:
+ return BinaryLogisticRegressionTrainingSummary
+ return LogisticRegressionTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
+
class LogisticRegressionSummary(_ClassificationSummary):
"""
@@ -2315,29 +2300,13 @@ class RandomForestClassificationModel(
return [DecisionTreeClassificationModel(m) for m in
list(self._call_java("trees"))]
@property
- @since("3.1.0")
- def summary(self) -> "RandomForestClassificationTrainingSummary":
- """
- Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
- trained on the training set. An exception is thrown if
`trainingSummary is None`.
- """
- if self.hasSummary:
- s: RandomForestClassificationTrainingSummary
- if self.numClasses <= 2:
- s = BinaryRandomForestClassificationTrainingSummary(
- super(RandomForestClassificationModel, self).summary
- )
- else:
- s = RandomForestClassificationTrainingSummary(
- super(RandomForestClassificationModel, self).summary
- )
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
+ def _summaryCls(self) -> type:
+ if self.numClasses <= 2:
+ return BinaryRandomForestClassificationTrainingSummary
+ return RandomForestClassificationTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
def evaluate(self, dataset: DataFrame) ->
"RandomForestClassificationSummary":
"""
@@ -3372,17 +3341,14 @@ class MultilayerPerceptronClassificationModel(
Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
trained on the training set. An exception is thrown if
`trainingSummary is None`.
"""
- if self.hasSummary:
- s = MultilayerPerceptronClassificationTrainingSummary(
- super(MultilayerPerceptronClassificationModel, self).summary
- )
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
+ return super().summary
+
+ @property
+ def _summaryCls(self) -> type:
+ return MultilayerPerceptronClassificationTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
def evaluate(self, dataset: DataFrame) ->
"MultilayerPerceptronClassificationSummary":
"""
@@ -4321,22 +4287,6 @@ class FMClassificationModel(
"""
return self._call_java("factors")
- @since("3.1.0")
- def summary(self) -> "FMClassificationTrainingSummary":
- """
- Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
- trained on the training set. An exception is thrown if
`trainingSummary is None`.
- """
- if self.hasSummary:
- s = FMClassificationTrainingSummary(super(FMClassificationModel,
self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
-
def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary":
"""
Evaluates the model on a test dataset.
@@ -4356,6 +4306,21 @@ class FMClassificationModel(
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
+ @since("3.1.0")
+ def summary(self) -> "FMClassificationTrainingSummary":
+ """
+ Gets summary (accuracy/precision/recall, objective history, total
iterations) of model
+ trained on the training set. An exception is thrown if
`trainingSummary is None`.
+ """
+ return super().summary
+
+ @property
+ def _summaryCls(self) -> type:
+ return FMClassificationTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
+
class FMClassificationSummary(_BinaryClassificationSummary):
"""
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7267ee280598..0e26398de3c4 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -255,23 +255,6 @@ class GaussianMixtureModel(
"""
return self._call_java("gaussiansDF")
- @property
- @since("2.1.0")
- def summary(self) -> "GaussianMixtureSummary":
- """
- Gets summary (cluster assignments, cluster sizes) of the model trained
on the
- training set. An exception is thrown if no summary exists.
- """
- if self.hasSummary:
- s = GaussianMixtureSummary(super(GaussianMixtureModel,
self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
-
@since("3.0.0")
def predict(self, value: Vector) -> int:
"""
@@ -286,6 +269,10 @@ class GaussianMixtureModel(
"""
return self._call_java("predictProbability", value)
+ @property
+ def _summaryCls(self) -> type:
+ return GaussianMixtureSummary
+
@inherit_doc
class GaussianMixture(
@@ -705,23 +692,6 @@ class KMeansModel(
"""
return self._call_java("numFeatures")
- @property
- @since("2.1.0")
- def summary(self) -> KMeansSummary:
- """
- Gets summary (cluster assignments, cluster sizes) of the model trained
on the
- training set. An exception is thrown if no summary exists.
- """
- if self.hasSummary:
- s = KMeansSummary(super(KMeansModel, self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
-
@since("3.0.0")
def predict(self, value: Vector) -> int:
"""
@@ -729,6 +699,10 @@ class KMeansModel(
"""
return self._call_java("predict", value)
+ @property
+ def _summaryCls(self) -> type:
+ return KMeansSummary
+
@inherit_doc
class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable,
JavaMLReadable["KMeans"]):
@@ -1055,23 +1029,6 @@ class BisectingKMeansModel(
"""
return self._call_java("numFeatures")
- @property
- @since("2.1.0")
- def summary(self) -> "BisectingKMeansSummary":
- """
- Gets summary (cluster assignments, cluster sizes) of the model trained
on the
- training set. An exception is thrown if no summary exists.
- """
- if self.hasSummary:
- s = BisectingKMeansSummary(super(BisectingKMeansModel,
self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
-
@since("3.0.0")
def predict(self, value: Vector) -> int:
"""
@@ -1079,6 +1036,10 @@ class BisectingKMeansModel(
"""
return self._call_java("predict", value)
+ @property
+ def _summaryCls(self) -> type:
+ return BisectingKMeansSummary
+
@inherit_doc
class BisectingKMeans(
diff --git a/python/pyspark/ml/connect/proto.py
b/python/pyspark/ml/connect/proto.py
index 31f100859281..7cffd32631ba 100644
--- a/python/pyspark/ml/connect/proto.py
+++ b/python/pyspark/ml/connect/proto.py
@@ -70,8 +70,13 @@ class AttributeRelation(LogicalPlan):
could be a model or a summary. This attribute returns a DataFrame.
"""
- def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None:
- super().__init__(None)
+ def __init__(
+ self,
+ ref_id: str,
+ methods: List[pb2.Fetch.Method],
+ child: Optional["LogicalPlan"] = None,
+ ) -> None:
+ super().__init__(child)
self._ref_id = ref_id
self._methods = methods
@@ -79,4 +84,6 @@ class AttributeRelation(LogicalPlan):
plan = self._create_proto_relation()
plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id))
plan.ml_relation.fetch.methods.extend(self._methods)
+ if self._child is not None:
+
plan.ml_relation.model_summary_dataset.CopyFrom(self._child.plan(session))
return plan
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 66d6dbd6a267..ce97b98f6665 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -479,22 +479,11 @@ class LinearRegressionModel(
return self._call_java("scale")
@property
- @since("2.0.0")
- def summary(self) -> "LinearRegressionTrainingSummary":
- """
- Gets summary (residuals, MSE, r-squared ) of model on
- training set. An exception is thrown if
- `trainingSummary is None`.
- """
- if self.hasSummary:
- s = LinearRegressionTrainingSummary(super(LinearRegressionModel,
self).summary)
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
+ def _summaryCls(self) -> type:
+ return LinearRegressionTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary":
"""
@@ -2774,24 +2763,11 @@ class GeneralizedLinearRegressionModel(
return self._call_java("intercept")
@property
- @since("2.0.0")
- def summary(self) -> "GeneralizedLinearRegressionTrainingSummary":
- """
- Gets summary (residuals, deviance, p-values) of model on
- training set. An exception is thrown if
- `trainingSummary is None`.
- """
- if self.hasSummary:
- s = GeneralizedLinearRegressionTrainingSummary(
- super(GeneralizedLinearRegressionModel, self).summary
- )
- if is_remote():
- s.__source_transformer__ = self # type: ignore[attr-defined]
- return s
- else:
- raise RuntimeError(
- "No training summary available for this %s" %
self.__class__.__name__
- )
+ def _summaryCls(self) -> type:
+ return GeneralizedLinearRegressionTrainingSummary
+
+ def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame:
+ return train_dataset
def evaluate(self, dataset: DataFrame) ->
"GeneralizedLinearRegressionSummary":
"""
diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py
b/python/pyspark/ml/tests/connect/test_connect_cache.py
index 8d156a0f11e1..f911ab22286c 100644
--- a/python/pyspark/ml/tests/connect/test_connect_cache.py
+++ b/python/pyspark/ml/tests/connect/test_connect_cache.py
@@ -48,20 +48,24 @@ class MLConnectCacheTests(ReusedConnectTestCase):
"obj: class org.apache.spark.ml.classification.LinearSVCModel" in
cache_info[0],
cache_info,
)
- assert model._java_obj._ref_count == 1
+ # the `model._summary` holds another ref to the remote model.
+ assert model._java_obj._ref_count == 2
model2 = model.copy()
cache_info = spark.client._get_ml_cache_info()
self.assertEqual(len(cache_info), 1)
- assert model._java_obj._ref_count == 2
- assert model2._java_obj._ref_count == 2
+ assert model._java_obj._ref_count == 3
+ assert model2._java_obj._ref_count == 3
# explicitly delete the model
del model
cache_info = spark.client._get_ml_cache_info()
self.assertEqual(len(cache_info), 1)
- assert model2._java_obj._ref_count == 1
+ # Note the copied model 'model2' also holds the `_summary` object,
+ # and the `_summary` object holds another ref to the remote model.
+ # so the ref count is 2.
+ assert model2._java_obj._ref_count == 2
del model2
cache_info = spark.client._get_ml_cache_info()
@@ -99,7 +103,6 @@ class MLConnectCacheTests(ReusedConnectTestCase):
cache_info,
)
- # explicitly delete the model1
del model1
cache_info = spark.client._get_ml_cache_info()
diff --git a/python/pyspark/ml/tests/test_classification.py
b/python/pyspark/ml/tests/test_classification.py
index 57e4c0ef86dc..21bce70e8735 100644
--- a/python/pyspark/ml/tests/test_classification.py
+++ b/python/pyspark/ml/tests/test_classification.py
@@ -55,6 +55,7 @@ from pyspark.ml.classification import (
MultilayerPerceptronClassificationTrainingSummary,
)
from pyspark.ml.regression import DecisionTreeRegressionModel
+from pyspark.sql import is_remote
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -241,37 +242,45 @@ class ClassificationTestsMixin:
model = lr.fit(df)
self.assertEqual(lr.uid, model.uid)
self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.probabilityCol, "probability")
- self.assertEqual(s.labelCol, "label")
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- objHist = s.objectiveHistory
- self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0],
float))
- self.assertGreater(s.totalIterations, 0)
- self.assertTrue(isinstance(s.labels, list))
- self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.precisionByLabel, list))
- self.assertTrue(isinstance(s.recallByLabel, list))
- self.assertTrue(isinstance(s.fMeasureByLabel(), list))
- self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
- self.assertTrue(isinstance(s.roc, DataFrame))
- self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
- self.assertTrue(isinstance(s.pr, DataFrame))
- self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
- self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
- self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
- self.assertAlmostEqual(s.accuracy, 1.0, 2)
- self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
- self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
- self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
- self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
- self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
- self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+ def check_summary():
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and
isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertTrue(isinstance(s.roc, DataFrame))
+ self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+ self.assertTrue(isinstance(s.pr, DataFrame))
+ self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+ self.assertAlmostEqual(s.accuracy, 1.0, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+ self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
+ check_summary()
+
+ s = model.summary
# test evaluation (with training dataset) produces a summary with same
values
# one check is enough to verify a summary is returned, Scala version
runs full test
sameSummary = model.evaluate(df)
@@ -292,31 +301,39 @@ class ClassificationTestsMixin:
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight",
fitIntercept=False)
model = lr.fit(df)
self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.probabilityCol, "probability")
- self.assertEqual(s.labelCol, "label")
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- objHist = s.objectiveHistory
- self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0],
float))
- self.assertGreater(s.totalIterations, 0)
- self.assertTrue(isinstance(s.labels, list))
- self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.precisionByLabel, list))
- self.assertTrue(isinstance(s.recallByLabel, list))
- self.assertTrue(isinstance(s.fMeasureByLabel(), list))
- self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
- self.assertAlmostEqual(s.accuracy, 0.75, 2)
- self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
- self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
- self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
- self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
- self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
- self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
+ def check_summary():
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and
isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertAlmostEqual(s.accuracy, 0.75, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
+ self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
+ check_summary()
+
+ s = model.summary
# test evaluation (with training dataset) produces a summary with same
values
# one check is enough to verify a summary is returned, Scala version
runs full test
sameSummary = model.evaluate(df)
@@ -426,15 +443,21 @@ class ClassificationTestsMixin:
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 4)
- # model summary
- self.assertTrue(model.hasSummary)
- summary = model.summary()
- self.assertIsInstance(summary, LinearSVCSummary)
- self.assertIsInstance(summary, LinearSVCTrainingSummary)
- self.assertEqual(summary.labels, [0.0, 1.0])
- self.assertEqual(summary.accuracy, 0.5)
- self.assertEqual(summary.areaUnderROC, 0.75)
- self.assertEqual(summary.predictions.columns, expected_cols)
+ def check_summary():
+ # model summary
+ self.assertTrue(model.hasSummary)
+ summary = model.summary()
+ self.assertIsInstance(summary, LinearSVCSummary)
+ self.assertIsInstance(summary, LinearSVCTrainingSummary)
+ self.assertEqual(summary.labels, [0.0, 1.0])
+ self.assertEqual(summary.accuracy, 0.5)
+ self.assertEqual(summary.areaUnderROC, 0.75)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertIsInstance(summary2, LinearSVCSummary)
@@ -526,13 +549,20 @@ class ClassificationTestsMixin:
# model summary
self.assertTrue(model.hasSummary)
- summary = model.summary()
- self.assertIsInstance(summary, FMClassificationSummary)
- self.assertIsInstance(summary, FMClassificationTrainingSummary)
- self.assertEqual(summary.labels, [0.0, 1.0])
- self.assertEqual(summary.accuracy, 0.25)
- self.assertEqual(summary.areaUnderROC, 0.5)
- self.assertEqual(summary.predictions.columns, expected_cols)
+
+ def check_summary():
+ summary = model.summary()
+ self.assertIsInstance(summary, FMClassificationSummary)
+ self.assertIsInstance(summary, FMClassificationTrainingSummary)
+ self.assertEqual(summary.labels, [0.0, 1.0])
+ self.assertEqual(summary.accuracy, 0.25)
+ self.assertEqual(summary.areaUnderROC, 0.5)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertIsInstance(summary2, FMClassificationSummary)
@@ -773,21 +803,27 @@ class ClassificationTestsMixin:
self.assertEqual(tree.transform(df).count(), 4)
self.assertEqual(tree.transform(df).columns, expected_cols)
- # model summary
- summary = model.summary
- self.assertTrue(isinstance(summary,
BinaryRandomForestClassificationSummary))
- self.assertTrue(isinstance(summary,
BinaryRandomForestClassificationTrainingSummary))
- self.assertEqual(summary.labels, [0.0, 1.0])
- self.assertEqual(summary.accuracy, 0.75)
- self.assertEqual(summary.areaUnderROC, 0.875)
- self.assertEqual(summary.predictions.columns, expected_cols)
+ def check_summary():
+ # model summary
+ summary = model.summary
+ self.assertTrue(isinstance(summary,
BinaryRandomForestClassificationSummary))
+ self.assertTrue(isinstance(summary,
BinaryRandomForestClassificationTrainingSummary))
+ self.assertEqual(summary.labels, [0.0, 1.0])
+ self.assertEqual(summary.accuracy, 0.75)
+ self.assertEqual(summary.areaUnderROC, 0.875)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertTrue(isinstance(summary2,
BinaryRandomForestClassificationSummary))
self.assertFalse(isinstance(summary2,
BinaryRandomForestClassificationTrainingSummary))
self.assertEqual(summary2.labels, [0.0, 1.0])
self.assertEqual(summary2.accuracy, 0.75)
- self.assertEqual(summary.areaUnderROC, 0.875)
+ self.assertEqual(summary2.areaUnderROC, 0.875)
self.assertEqual(summary2.predictions.columns, expected_cols)
# Model save & load
@@ -859,13 +895,19 @@ class ClassificationTestsMixin:
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 4)
- # model summary
- summary = model.summary
- self.assertTrue(isinstance(summary, RandomForestClassificationSummary))
- self.assertTrue(isinstance(summary,
RandomForestClassificationTrainingSummary))
- self.assertEqual(summary.labels, [0.0, 1.0, 2.0])
- self.assertEqual(summary.accuracy, 0.5)
- self.assertEqual(summary.predictions.columns, expected_cols)
+ def check_summary():
+ # model summary
+ summary = model.summary
+ self.assertTrue(isinstance(summary,
RandomForestClassificationSummary))
+ self.assertTrue(isinstance(summary,
RandomForestClassificationTrainingSummary))
+ self.assertEqual(summary.labels, [0.0, 1.0, 2.0])
+ self.assertEqual(summary.accuracy, 0.5)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertTrue(isinstance(summary2, RandomForestClassificationSummary))
@@ -953,14 +995,20 @@ class ClassificationTestsMixin:
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 4)
- # model summary
- self.assertTrue(model.hasSummary)
- summary = model.summary()
- self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary)
- self.assertIsInstance(summary, MultilayerPerceptronClassificationTrainingSummary)
- self.assertEqual(summary.labels, [0.0, 1.0])
- self.assertEqual(summary.accuracy, 0.75)
- self.assertEqual(summary.predictions.columns, expected_cols)
+ def check_summary():
+ # model summary
+ self.assertTrue(model.hasSummary)
+ summary = model.summary()
+ self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary)
+ self.assertIsInstance(summary, MultilayerPerceptronClassificationTrainingSummary)
+ self.assertEqual(summary.labels, [0.0, 1.0])
+ self.assertEqual(summary.accuracy, 0.75)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertIsInstance(summary2, MultilayerPerceptronClassificationSummary)
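The classification tests above all follow the same evict-then-recheck pattern. A condensed, illustrative sketch of that pattern (not taken verbatim from the patch; it assumes a Spark Connect session `spark`, a fitted `model`, and its training DataFrame `df`, and `model.summary` is a method rather than a property on some model classes):
```
from pyspark.sql import is_remote

def check_summary():
    # Cheap invariants that must hold whether the summary is still cached in
    # driver memory or has just been rebuilt from the offloaded model + data.
    assert model.hasSummary
    assert model.summary.predictions.count() == df.count()

check_summary()  # served from the in-memory cache right after fit()
if is_remote():
    # Evict the model from driver memory but keep the offloaded copy on local
    # disk, then verify the summary is restored transparently on next access.
    spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
    check_summary()
```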
diff --git a/python/pyspark/ml/tests/test_clustering.py
b/python/pyspark/ml/tests/test_clustering.py
index 1b8eb73135a9..fbf012babcc3 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -85,23 +85,39 @@ class ClusteringTestsMixin:
self.assertTrue(np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 1,
atol=1e-4))
- # Model summary
- self.assertTrue(model.hasSummary)
- summary = model.summary
- self.assertTrue(isinstance(summary, KMeansSummary))
- self.assertEqual(summary.k, 2)
- self.assertEqual(summary.numIter, 2)
- self.assertEqual(summary.clusterSizes, [4, 2])
- self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4))
+ def check_summary():
+ # Model summary
+ self.assertTrue(model.hasSummary)
+ summary = model.summary
+ self.assertTrue(isinstance(summary, KMeansSummary))
+ self.assertEqual(summary.k, 2)
+ self.assertEqual(summary.numIter, 2)
+ self.assertEqual(summary.clusterSizes, [4, 2])
+ self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4))
- self.assertEqual(summary.featuresCol, "features")
- self.assertEqual(summary.predictionCol, "prediction")
+ self.assertEqual(summary.featuresCol, "features")
+ self.assertEqual(summary.predictionCol, "prediction")
- self.assertEqual(summary.cluster.columns, ["prediction"])
- self.assertEqual(summary.cluster.count(), 6)
+ self.assertEqual(summary.cluster.columns, ["prediction"])
+ self.assertEqual(summary.cluster.count(), 6)
- self.assertEqual(summary.predictions.columns, expected_cols)
- self.assertEqual(summary.predictions.count(), 6)
+ self.assertEqual(summary.predictions.columns, expected_cols)
+ self.assertEqual(summary.predictions.count(), 6)
+
+ # check summary before model offloading occurs
+ check_summary()
+
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ # check summary "try_remote_call" path after model offloading occurs
+ self.assertEqual(model.summary.numIter, 2)
+
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ # check summary "invoke_remote_attribute_relation" path after model offloading occurs
+ self.assertEqual(model.summary.cluster.count(), 6)
+
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
# save & load
with tempfile.TemporaryDirectory(prefix="kmeans_model") as d:
@@ -112,6 +128,9 @@ class ClusteringTestsMixin:
model.write().overwrite().save(d)
model2 = KMeansModel.load(d)
self.assertEqual(str(model), str(model2))
+ self.assertFalse(model2.hasSummary)
+ with self.assertRaisesRegex(Exception, "No training summary available"):
+ model2.summary
def test_bisecting_kmeans(self):
df = (
@@ -278,30 +297,36 @@ class ClusteringTestsMixin:
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 6)
- # Model summary
- self.assertTrue(model.hasSummary)
- summary = model.summary
- self.assertTrue(isinstance(summary, GaussianMixtureSummary))
- self.assertEqual(summary.k, 2)
- self.assertEqual(summary.numIter, 2)
- self.assertEqual(len(summary.clusterSizes), 2)
- self.assertEqual(summary.clusterSizes, [3, 3])
- ll = summary.logLikelihood
- self.assertTrue(ll < 0, ll)
- self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll)
-
- self.assertEqual(summary.featuresCol, "features")
- self.assertEqual(summary.predictionCol, "prediction")
- self.assertEqual(summary.probabilityCol, "probability")
-
- self.assertEqual(summary.cluster.columns, ["prediction"])
- self.assertEqual(summary.cluster.count(), 6)
-
- self.assertEqual(summary.predictions.columns, expected_cols)
- self.assertEqual(summary.predictions.count(), 6)
-
- self.assertEqual(summary.probability.columns, ["probability"])
- self.assertEqual(summary.predictions.count(), 6)
+ def check_summary():
+ # Model summary
+ self.assertTrue(model.hasSummary)
+ summary = model.summary
+ self.assertTrue(isinstance(summary, GaussianMixtureSummary))
+ self.assertEqual(summary.k, 2)
+ self.assertEqual(summary.numIter, 2)
+ self.assertEqual(len(summary.clusterSizes), 2)
+ self.assertEqual(summary.clusterSizes, [3, 3])
+ ll = summary.logLikelihood
+ self.assertTrue(ll < 0, ll)
+ self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll)
+
+ self.assertEqual(summary.featuresCol, "features")
+ self.assertEqual(summary.predictionCol, "prediction")
+ self.assertEqual(summary.probabilityCol, "probability")
+
+ self.assertEqual(summary.cluster.columns, ["prediction"])
+ self.assertEqual(summary.cluster.count(), 6)
+
+ self.assertEqual(summary.predictions.columns, expected_cols)
+ self.assertEqual(summary.predictions.count(), 6)
+
+ self.assertEqual(summary.probability.columns, ["probability"])
+ self.assertEqual(summary.predictions.count(), 6)
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
# save & load
with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d:
diff --git a/python/pyspark/ml/tests/test_regression.py
b/python/pyspark/ml/tests/test_regression.py
index 8638fb4d6078..52688fdd63cf 100644
--- a/python/pyspark/ml/tests/test_regression.py
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -43,6 +43,7 @@ from pyspark.ml.regression import (
GBTRegressor,
GBTRegressionModel,
)
+from pyspark.sql import is_remote
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -193,50 +194,58 @@ class RegressionTestsMixin:
np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4)
)
- # Model summary
- summary = model.summary
- self.assertTrue(isinstance(summary, LinearRegressionSummary))
- self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
- self.assertEqual(summary.predictions.columns, expected_cols)
- self.assertEqual(summary.predictions.count(), 4)
- self.assertEqual(summary.residuals.columns, ["residuals"])
- self.assertEqual(summary.residuals.count(), 4)
-
- self.assertEqual(summary.degreesOfFreedom, 1)
- self.assertEqual(summary.numInstances, 4)
- self.assertEqual(summary.objectiveHistory, [0.0])
- self.assertTrue(
- np.allclose(
- summary.coefficientStandardErrors,
- [1.2859821149611763, 0.6248749874975031, 3.1645497310044184],
- atol=1e-4,
+ def check_summary():
+ # Model summary
+ summary = model.summary
+ self.assertTrue(isinstance(summary, LinearRegressionSummary))
+ self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
+ self.assertEqual(summary.predictions.columns, expected_cols)
+ self.assertEqual(summary.predictions.count(), 4)
+ self.assertEqual(summary.residuals.columns, ["residuals"])
+ self.assertEqual(summary.residuals.count(), 4)
+
+ self.assertEqual(summary.degreesOfFreedom, 1)
+ self.assertEqual(summary.numInstances, 4)
+ self.assertEqual(summary.objectiveHistory, [0.0])
+ self.assertTrue(
+ np.allclose(
+ summary.coefficientStandardErrors,
+ [1.2859821149611763, 0.6248749874975031, 3.1645497310044184],
+ atol=1e-4,
+ )
)
- )
- self.assertTrue(
- np.allclose(
- summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+ self.assertTrue(
+ np.allclose(
+ summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+ )
)
- )
- self.assertTrue(
- np.allclose(
- summary.pValues,
- [0.7020630236843428, 0.8866003086182783, 0.9298746994547682],
- atol=1e-4,
+ self.assertTrue(
+ np.allclose(
+ summary.pValues,
+ [0.7020630236843428, 0.8866003086182783, 0.9298746994547682],
+ atol=1e-4,
+ )
)
- )
- self.assertTrue(
- np.allclose(
- summary.tValues,
- [0.5054502643838291, 0.1800360108036021, -0.11060025272186746],
- atol=1e-4,
+ self.assertTrue(
+ np.allclose(
+ summary.tValues,
+ [0.5054502643838291, 0.1800360108036021, -0.11060025272186746],
+ atol=1e-4,
+ )
)
- )
- self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4))
- self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
- self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4))
- self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
- self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4))
- self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4))
+ self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4))
+ self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
+ self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4))
+ self.assertTrue(
+ np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4)
+ )
+ self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4))
+ self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4))
+
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
summary2 = model.evaluate(df)
self.assertTrue(isinstance(summary2, LinearRegressionSummary))
@@ -318,36 +327,43 @@ class RegressionTestsMixin:
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 4)
- # Model summary
- self.assertTrue(model.hasSummary)
-
- summary = model.summary
- self.assertIsInstance(summary, GeneralizedLinearRegressionSummary)
- self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary)
- self.assertEqual(summary.numIterations, 1)
- self.assertEqual(summary.numInstances, 4)
- self.assertEqual(summary.rank, 3)
- self.assertTrue(
- np.allclose(
+ def check_summary():
+ # Model summary
+ self.assertTrue(model.hasSummary)
+
+ summary = model.summary
+ self.assertIsInstance(summary, GeneralizedLinearRegressionSummary)
+ self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary)
+ self.assertEqual(summary.numIterations, 1)
+ self.assertEqual(summary.numInstances, 4)
+ self.assertEqual(summary.rank, 3)
+ self.assertTrue(
+ np.allclose(
+ summary.tValues,
+ [0.3725037662281711, -0.49418209022924164, 2.6589353685797654],
+ atol=1e-4,
+ ),
summary.tValues,
- [0.3725037662281711, -0.49418209022924164, 2.6589353685797654],
- atol=1e-4,
- ),
- summary.tValues,
- )
- self.assertTrue(
- np.allclose(
+ )
+ self.assertTrue(
+ np.allclose(
+ summary.pValues,
+ [0.7729938686180984, 0.707802691825973, 0.22900885781807023],
+ atol=1e-4,
+ ),
summary.pValues,
- [0.7729938686180984, 0.707802691825973, 0.22900885781807023],
- atol=1e-4,
- ),
- summary.pValues,
- )
- self.assertEqual(summary.predictions.columns, expected_cols)
- self.assertEqual(summary.predictions.count(), 4)
- self.assertEqual(summary.residuals().columns, ["devianceResiduals"])
- self.assertEqual(summary.residuals().count(), 4)
+ )
+ self.assertEqual(summary.predictions.columns, expected_cols)
+ self.assertEqual(summary.predictions.count(), 4)
+ self.assertEqual(summary.residuals().columns, ["devianceResiduals"])
+ self.assertEqual(summary.residuals().count(), 4)
+ check_summary()
+ if is_remote():
+ self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+ check_summary()
+
+ summary = model.summary
summary2 = model.evaluate(df)
self.assertIsInstance(summary2, GeneralizedLinearRegressionSummary)
self.assertNotIsInstance(summary2, GeneralizedLinearRegressionTrainingSummary)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index b86178a97c38..3e55241b07e2 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -40,6 +40,7 @@ from typing import (
from contextlib import contextmanager
from pyspark import since
+from pyspark.errors.exceptions.connect import SparkException
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.sql.utils import is_remote
@@ -72,20 +73,6 @@ ML_CONNECT_HELPER_ID = "______ML_CONNECT_HELPER______"
_logger = logging.getLogger("pyspark.ml.util")
-def try_remote_intermediate_result(f: FuncT) -> FuncT:
- """Mark the function/property that returns the intermediate result of the
remote call.
- Eg, model.summary"""
-
- @functools.wraps(f)
- def wrapped(self: "JavaWrapper") -> Any:
- if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- return f"{str(self._java_obj)}.{f.__name__}"
- else:
- return f(self)
-
- return cast(FuncT, wrapped)
-
-
def invoke_helper_attr(method: str, *args: Any) -> Any:
from pyspark.ml.wrapper import JavaWrapper
@@ -125,7 +112,12 @@ def invoke_remote_attribute_relation(
object_id = instance._java_obj # type: ignore
methods, obj_ref = _extract_id_methods(object_id)
methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
- plan = AttributeRelation(obj_ref, methods)
+
+ if methods[0].method == "summary":
+ child = instance._summary_dataset._plan # type: ignore
+ else:
+ child = None
+ plan = AttributeRelation(obj_ref, methods, child=child)
# To delay the GC of the model, keep a reference to the source instance,
# might be a model or a summary.
@@ -204,6 +196,15 @@ def try_remote_fit(f: FuncT) -> FuncT:
_logger.warning(warning_msg)
remote_model_ref = RemoteModelRef(model_info.obj_ref.id)
model = self._create_model(remote_model_ref)
+ if isinstance(model, HasTrainingSummary):
+ summary_dataset = model._summary_dataset(dataset)
+
+ summary = model._summaryCls(f"{str(model._java_obj)}.summary")  # type: ignore
+ summary._summary_dataset = summary_dataset
+ summary._remote_model_obj = model._java_obj # type: ignore
+ summary._remote_model_obj.add_ref()
+
+ model._summary = summary # type: ignore
if model.__class__.__name__ not in ["Bucketizer"]:
model._resetUid(self.uid)
return self._copyValues(model)
@@ -278,15 +279,16 @@ def try_remote_call(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any:
- if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- # Launch a remote call if possible
- import pyspark.sql.connect.proto as pb2
- from pyspark.sql.connect.session import SparkSession
+ import pyspark.sql.connect.proto as pb2
+ from pyspark.sql.connect.session import SparkSession
+
+ session = SparkSession.getActiveSession()
+
+ def remote_call() -> Any:
from pyspark.ml.connect.util import _extract_id_methods
from pyspark.ml.connect.serialize import serialize, deserialize
from pyspark.ml.wrapper import JavaModel
- session = SparkSession.getActiveSession()
assert session is not None
if self._java_obj == ML_CONNECT_HELPER_ID:
obj_id = ML_CONNECT_HELPER_ID
@@ -315,6 +317,28 @@ def try_remote_call(f: FuncT) -> FuncT:
return model_info.obj_ref.id
else:
return deserialize(properties)
+
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ try:
+ return remote_call()
+ except SparkException as e:
+ if e.getErrorClass() == "CONNECT_ML.MODEL_SUMMARY_LOST":
+ # the model summary is lost because the remote model was offloaded,
+ # send request to restore model.summary
+ create_summary_command = pb2.Command()
+ create_summary_command.ml_command.create_summary.CopyFrom(
+ pb2.MlCommand.CreateSummary(
+ model_ref=pb2.ObjectRef(
+ id=self._remote_model_obj.ref_id  # type: ignore
+ ),
+ dataset=self._summary_dataset._plan.plan(  # type: ignore
+ session.client  # type: ignore
+ ),
+ )
+ )
+ session.client.execute_command(create_summary_command)  # type: ignore
+
+ return remote_call()
else:
return f(self, name, *args)
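For reference, the behaviour this retry enables from the caller's side; a sketch only, with `model` standing in for any fitted Connect model that has a training summary:
```
# First access: the summary object is still cached on the driver.
acc = model.summary.accuracy

# If the driver later evicts the model (e.g. after the offloading timeout),
# the next remote call fails with CONNECT_ML.MODEL_SUMMARY_LOST; the wrapper
# above then replays MlCommand.CreateSummary with the retained training
# dataset plan and retries, so the caller just sees the same value again.
assert model.summary.accuracy == acc
```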
@@ -346,8 +370,11 @@ def try_remote_del(f: FuncT) -> FuncT:
except Exception:
return
- if in_remote and isinstance(self._java_obj, RemoteModelRef):
- self._java_obj.release_ref()
+ if in_remote:
+ if isinstance(self._java_obj, RemoteModelRef):
+ self._java_obj.release_ref()
+ if hasattr(self, "_remote_model_obj"):
+ self._remote_model_obj.release_ref()
return
else:
return f(self)
@@ -1076,17 +1103,32 @@ class HasTrainingSummary(Generic[T]):
Indicates whether a training summary exists for this model
instance.
"""
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ return hasattr(self, "_summary")
return cast("JavaWrapper", self)._call_java("hasSummary")
@property
@since("2.1.0")
- @try_remote_intermediate_result
def summary(self) -> T:
"""
Gets summary of the model trained on the training set. An exception is
thrown if
no summary exists.
"""
- return cast("JavaWrapper", self)._call_java("summary")
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ if hasattr(self, "_summary"):
+ return self._summary
+ else:
+ raise RuntimeError(
+ "No training summary available for this %s" %
self.__class__.__name__
+ )
+ return self._summaryCls(cast("JavaWrapper",
self)._call_java("summary"))
+
+ @property
+ def _summaryCls(self) -> type:
+ raise NotImplementedError()
+
+ def _summary_dataset(self, train_dataset: "DataFrame") -> "DataFrame":
+ return self.transform(train_dataset) # type: ignore
class MetaAlgorithmReadWrite:
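Each summary-bearing model plugs into this API by exposing `_summaryCls` and, where the summary is built on something other than `transform(train_dataset)`, overriding `_summary_dataset`. A hedged sketch of that wiring; `MyModel` and `MyTrainingSummary` are placeholder names, and the real implementations live in pyspark.ml.classification, clustering, and regression:
```
from pyspark.ml.util import HasTrainingSummary

class MyTrainingSummary:
    # Placeholder summary wrapper; the constructor receives either the cached
    # summary object or the "<model_ref>.summary" string used under Connect.
    def __init__(self, java_obj_or_ref):
        self._target = java_obj_or_ref

class MyModel(HasTrainingSummary):
    @property
    def _summaryCls(self) -> type:
        # Class used both for the in-memory summary created at fit() time and
        # for rebuilding the summary after the model has been offloaded.
        return MyTrainingSummary
    # _summary_dataset() defaults to self.transform(train_dataset); override it
    # only when the summary is computed on a different projection of the data.
```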
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 34719f2b0ba6..3cfb38fdfa7d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1985,7 +1985,7 @@ class SparkConnectClient(object):
profile_id = properties["create_resource_profile_command_result"]
return profile_id
- def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]:
+ def _delete_ml_cache(self, cache_ids: List[str], evict_only: bool = False)
-> List[str]:
# try best to delete the cache
try:
if len(cache_ids) > 0:
@@ -1993,6 +1993,7 @@ class SparkConnectClient(object):
command.ml_command.delete.obj_refs.extend(
[pb2.ObjectRef(id=cache_id) for cache_id in cache_ids]
)
+ command.ml_command.delete.evict_only = evict_only
(_, properties, _) = self.execute_command(command)
assert properties is not None
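Sketch of the two client-side deletion modes after this change; `spark` is a Connect SparkSession and `ref_id` an existing cached-model id, both illustrative:
```
# Remove the model from the driver's memory and delete its offloaded copy.
spark.client._delete_ml_cache([ref_id])

# Only evict the in-memory entry; the offloaded copy stays on the driver's
# local disk so the model (and its summary) can be reloaded on demand.
spark.client._delete_ml_cache([ref_id], evict_only=True)
```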
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py
b/python/pyspark/sql/connect/proto/ml_pb2.py
index 46fc82131a9e..1ede558b9414 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
[...]
+
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
[...]
)
_globals = globals()
@@ -54,25 +54,27 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
- _globals["_MLCOMMAND"]._serialized_end = 1595
- _globals["_MLCOMMAND_FIT"]._serialized_start = 631
- _globals["_MLCOMMAND_FIT"]._serialized_end = 809
- _globals["_MLCOMMAND_DELETE"]._serialized_start = 811
- _globals["_MLCOMMAND_DELETE"]._serialized_end = 872
- _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 874
- _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 886
- _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 888
- _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 902
- _globals["_MLCOMMAND_WRITE"]._serialized_start = 905
- _globals["_MLCOMMAND_WRITE"]._serialized_end = 1315
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1217
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1275
- _globals["_MLCOMMAND_READ"]._serialized_start = 1317
- _globals["_MLCOMMAND_READ"]._serialized_end = 1398
- _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401
- _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584
- _globals["_MLCOMMANDRESULT"]._serialized_start = 1598
- _globals["_MLCOMMANDRESULT"]._serialized_end = 2067
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2052
+ _globals["_MLCOMMAND"]._serialized_end = 1850
+ _globals["_MLCOMMAND_FIT"]._serialized_start = 712
+ _globals["_MLCOMMAND_FIT"]._serialized_end = 890
+ _globals["_MLCOMMAND_DELETE"]._serialized_start = 892
+ _globals["_MLCOMMAND_DELETE"]._serialized_end = 1004
+ _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006
+ _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018
+ _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020
+ _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034
+ _globals["_MLCOMMAND_WRITE"]._serialized_start = 1037
+ _globals["_MLCOMMAND_WRITE"]._serialized_end = 1447
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407
+ _globals["_MLCOMMAND_READ"]._serialized_start = 1449
+ _globals["_MLCOMMAND_READ"]._serialized_end = 1530
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716
+ _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718
+ _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839
+ _globals["_MLCOMMANDRESULT"]._serialized_start = 1853
+ _globals["_MLCOMMANDRESULT"]._serialized_end = 2322
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi
b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 88cc6cb625de..0a72c207b526 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -118,21 +118,39 @@ class MlCommand(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
OBJ_REFS_FIELD_NUMBER: builtins.int
+ EVICT_ONLY_FIELD_NUMBER: builtins.int
@property
def obj_refs(
self,
) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.ml_common_pb2.ObjectRef
]: ...
+ evict_only: builtins.bool
+ """if set `evict_only` to true, only evict the cached model from
memory,
+ but keep the offloaded model in Spark driver local disk.
+ """
def __init__(
self,
*,
obj_refs:
collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef]
| None = ...,
+ evict_only: builtins.bool | None = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_evict_only", b"_evict_only", "evict_only", b"evict_only"
+ ],
+ ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["obj_refs",
b"obj_refs"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_evict_only", b"_evict_only", "evict_only", b"evict_only",
"obj_refs", b"obj_refs"
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_evict_only",
b"_evict_only"]
+ ) -> typing_extensions.Literal["evict_only"] | None: ...
class CleanCache(google.protobuf.message.Message):
"""Force to clean up all the ML cached objects"""
@@ -342,6 +360,34 @@ class MlCommand(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_params", b"_params"]
) -> typing_extensions.Literal["params"] | None: ...
+ class CreateSummary(google.protobuf.message.Message):
+ """This is for re-creating the model summary when the model summary is
lost
+ (model summary is lost when the model is offloaded and then loaded
back)
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ MODEL_REF_FIELD_NUMBER: builtins.int
+ DATASET_FIELD_NUMBER: builtins.int
+ @property
+ def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ...
+ @property
+ def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ...
+ def __init__(
+ self,
+ *,
+ model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
+ dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal["dataset", b"dataset",
"model_ref", b"model_ref"],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["dataset", b"dataset",
"model_ref", b"model_ref"],
+ ) -> None: ...
+
FIT_FIELD_NUMBER: builtins.int
FETCH_FIELD_NUMBER: builtins.int
DELETE_FIELD_NUMBER: builtins.int
@@ -350,6 +396,7 @@ class MlCommand(google.protobuf.message.Message):
EVALUATE_FIELD_NUMBER: builtins.int
CLEAN_CACHE_FIELD_NUMBER: builtins.int
GET_CACHE_INFO_FIELD_NUMBER: builtins.int
+ CREATE_SUMMARY_FIELD_NUMBER: builtins.int
@property
def fit(self) -> global___MlCommand.Fit: ...
@property
@@ -366,6 +413,8 @@ class MlCommand(google.protobuf.message.Message):
def clean_cache(self) -> global___MlCommand.CleanCache: ...
@property
def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ...
+ @property
+ def create_summary(self) -> global___MlCommand.CreateSummary: ...
def __init__(
self,
*,
@@ -377,6 +426,7 @@ class MlCommand(google.protobuf.message.Message):
evaluate: global___MlCommand.Evaluate | None = ...,
clean_cache: global___MlCommand.CleanCache | None = ...,
get_cache_info: global___MlCommand.GetCacheInfo | None = ...,
+ create_summary: global___MlCommand.CreateSummary | None = ...,
) -> None: ...
def HasField(
self,
@@ -385,6 +435,8 @@ class MlCommand(google.protobuf.message.Message):
b"clean_cache",
"command",
b"command",
+ "create_summary",
+ b"create_summary",
"delete",
b"delete",
"evaluate",
@@ -408,6 +460,8 @@ class MlCommand(google.protobuf.message.Message):
b"clean_cache",
"command",
b"command",
+ "create_summary",
+ b"create_summary",
"delete",
b"delete",
"evaluate",
@@ -428,7 +482,15 @@ class MlCommand(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["command", b"command"]
) -> (
typing_extensions.Literal[
- "fit", "fetch", "delete", "write", "read", "evaluate",
"clean_cache", "get_cache_info"
+ "fit",
+ "fetch",
+ "delete",
+ "write",
+ "read",
+ "evaluate",
+ "clean_cache",
+ "get_cache_info",
+ "create_summary",
]
| None
): ...
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 525ba88ff67c..3774bcbdbfb0 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -43,7 +43,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
)
_globals = globals()
@@ -81,169 +81,169 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_RELATION"]._serialized_start = 224
_globals["_RELATION"]._serialized_end = 3964
_globals["_MLRELATION"]._serialized_start = 3967
- _globals["_MLRELATION"]._serialized_end = 4343
- _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4097
- _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4332
- _globals["_FETCH"]._serialized_start = 4346
- _globals["_FETCH"]._serialized_end = 4677
- _globals["_FETCH_METHOD"]._serialized_start = 4462
- _globals["_FETCH_METHOD"]._serialized_end = 4677
- _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4550
- _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4677
- _globals["_UNKNOWN"]._serialized_start = 4679
- _globals["_UNKNOWN"]._serialized_end = 4688
- _globals["_RELATIONCOMMON"]._serialized_start = 4691
- _globals["_RELATIONCOMMON"]._serialized_end = 4833
- _globals["_SQL"]._serialized_start = 4836
- _globals["_SQL"]._serialized_end = 5314
- _globals["_SQL_ARGSENTRY"]._serialized_start = 5130
- _globals["_SQL_ARGSENTRY"]._serialized_end = 5220
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5222
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5314
- _globals["_WITHRELATIONS"]._serialized_start = 5316
- _globals["_WITHRELATIONS"]._serialized_end = 5433
- _globals["_READ"]._serialized_start = 5436
- _globals["_READ"]._serialized_end = 6099
- _globals["_READ_NAMEDTABLE"]._serialized_start = 5614
- _globals["_READ_NAMEDTABLE"]._serialized_end = 5806
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5748
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5806
- _globals["_READ_DATASOURCE"]._serialized_start = 5809
- _globals["_READ_DATASOURCE"]._serialized_end = 6086
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5748
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5806
- _globals["_PROJECT"]._serialized_start = 6101
- _globals["_PROJECT"]._serialized_end = 6218
- _globals["_FILTER"]._serialized_start = 6220
- _globals["_FILTER"]._serialized_end = 6332
- _globals["_JOIN"]._serialized_start = 6335
- _globals["_JOIN"]._serialized_end = 6996
- _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6674
- _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6766
- _globals["_JOIN_JOINTYPE"]._serialized_start = 6769
- _globals["_JOIN_JOINTYPE"]._serialized_end = 6977
- _globals["_SETOPERATION"]._serialized_start = 6999
- _globals["_SETOPERATION"]._serialized_end = 7478
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7315
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7429
- _globals["_LIMIT"]._serialized_start = 7480
- _globals["_LIMIT"]._serialized_end = 7556
- _globals["_OFFSET"]._serialized_start = 7558
- _globals["_OFFSET"]._serialized_end = 7637
- _globals["_TAIL"]._serialized_start = 7639
- _globals["_TAIL"]._serialized_end = 7714
- _globals["_AGGREGATE"]._serialized_start = 7717
- _globals["_AGGREGATE"]._serialized_end = 8483
- _globals["_AGGREGATE_PIVOT"]._serialized_start = 8132
- _globals["_AGGREGATE_PIVOT"]._serialized_end = 8243
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8245
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8321
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8324
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8483
- _globals["_SORT"]._serialized_start = 8486
- _globals["_SORT"]._serialized_end = 8646
- _globals["_DROP"]._serialized_start = 8649
- _globals["_DROP"]._serialized_end = 8790
- _globals["_DEDUPLICATE"]._serialized_start = 8793
- _globals["_DEDUPLICATE"]._serialized_end = 9033
- _globals["_LOCALRELATION"]._serialized_start = 9035
- _globals["_LOCALRELATION"]._serialized_end = 9124
- _globals["_CACHEDLOCALRELATION"]._serialized_start = 9126
- _globals["_CACHEDLOCALRELATION"]._serialized_end = 9198
- _globals["_CACHEDREMOTERELATION"]._serialized_start = 9200
- _globals["_CACHEDREMOTERELATION"]._serialized_end = 9255
- _globals["_SAMPLE"]._serialized_start = 9258
- _globals["_SAMPLE"]._serialized_end = 9531
- _globals["_RANGE"]._serialized_start = 9534
- _globals["_RANGE"]._serialized_end = 9679
- _globals["_SUBQUERYALIAS"]._serialized_start = 9681
- _globals["_SUBQUERYALIAS"]._serialized_end = 9795
- _globals["_REPARTITION"]._serialized_start = 9798
- _globals["_REPARTITION"]._serialized_end = 9940
- _globals["_SHOWSTRING"]._serialized_start = 9943
- _globals["_SHOWSTRING"]._serialized_end = 10085
- _globals["_HTMLSTRING"]._serialized_start = 10087
- _globals["_HTMLSTRING"]._serialized_end = 10201
- _globals["_STATSUMMARY"]._serialized_start = 10203
- _globals["_STATSUMMARY"]._serialized_end = 10295
- _globals["_STATDESCRIBE"]._serialized_start = 10297
- _globals["_STATDESCRIBE"]._serialized_end = 10378
- _globals["_STATCROSSTAB"]._serialized_start = 10380
- _globals["_STATCROSSTAB"]._serialized_end = 10481
- _globals["_STATCOV"]._serialized_start = 10483
- _globals["_STATCOV"]._serialized_end = 10579
- _globals["_STATCORR"]._serialized_start = 10582
- _globals["_STATCORR"]._serialized_end = 10719
- _globals["_STATAPPROXQUANTILE"]._serialized_start = 10722
- _globals["_STATAPPROXQUANTILE"]._serialized_end = 10886
- _globals["_STATFREQITEMS"]._serialized_start = 10888
- _globals["_STATFREQITEMS"]._serialized_end = 11013
- _globals["_STATSAMPLEBY"]._serialized_start = 11016
- _globals["_STATSAMPLEBY"]._serialized_end = 11325
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11217
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11316
- _globals["_NAFILL"]._serialized_start = 11328
- _globals["_NAFILL"]._serialized_end = 11462
- _globals["_NADROP"]._serialized_start = 11465
- _globals["_NADROP"]._serialized_end = 11599
- _globals["_NAREPLACE"]._serialized_start = 11602
- _globals["_NAREPLACE"]._serialized_end = 11898
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11757
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 11898
- _globals["_TODF"]._serialized_start = 11900
- _globals["_TODF"]._serialized_end = 11988
- _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 11991
- _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12373
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
12235
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
12302
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12304
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12373
- _globals["_WITHCOLUMNS"]._serialized_start = 12375
- _globals["_WITHCOLUMNS"]._serialized_end = 12494
- _globals["_WITHWATERMARK"]._serialized_start = 12497
- _globals["_WITHWATERMARK"]._serialized_end = 12631
- _globals["_HINT"]._serialized_start = 12634
- _globals["_HINT"]._serialized_end = 12766
- _globals["_UNPIVOT"]._serialized_start = 12769
- _globals["_UNPIVOT"]._serialized_end = 13096
- _globals["_UNPIVOT_VALUES"]._serialized_start = 13026
- _globals["_UNPIVOT_VALUES"]._serialized_end = 13085
- _globals["_TRANSPOSE"]._serialized_start = 13098
- _globals["_TRANSPOSE"]._serialized_end = 13220
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13222
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13347
- _globals["_TOSCHEMA"]._serialized_start = 13349
- _globals["_TOSCHEMA"]._serialized_end = 13455
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13458
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13661
- _globals["_MAPPARTITIONS"]._serialized_start = 13664
- _globals["_MAPPARTITIONS"]._serialized_end = 13896
- _globals["_GROUPMAP"]._serialized_start = 13899
- _globals["_GROUPMAP"]._serialized_end = 14749
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14752
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 14975
- _globals["_COGROUPMAP"]._serialized_start = 14978
- _globals["_COGROUPMAP"]._serialized_end = 15504
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15507
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15864
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15867
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16111
- _globals["_PYTHONUDTF"]._serialized_start = 16114
- _globals["_PYTHONUDTF"]._serialized_end = 16291
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16294
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16445
- _globals["_PYTHONDATASOURCE"]._serialized_start = 16447
- _globals["_PYTHONDATASOURCE"]._serialized_end = 16522
- _globals["_COLLECTMETRICS"]._serialized_start = 16525
- _globals["_COLLECTMETRICS"]._serialized_end = 16661
- _globals["_PARSE"]._serialized_start = 16664
- _globals["_PARSE"]._serialized_end = 17052
- _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5748
- _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5806
- _globals["_PARSE_PARSEFORMAT"]._serialized_start = 16953
- _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17041
- _globals["_ASOFJOIN"]._serialized_start = 17055
- _globals["_ASOFJOIN"]._serialized_end = 17530
- _globals["_LATERALJOIN"]._serialized_start = 17533
- _globals["_LATERALJOIN"]._serialized_end = 17763
+ _globals["_MLRELATION"]._serialized_end = 4451
+ _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4179
+ _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4414
+ _globals["_FETCH"]._serialized_start = 4454
+ _globals["_FETCH"]._serialized_end = 4785
+ _globals["_FETCH_METHOD"]._serialized_start = 4570
+ _globals["_FETCH_METHOD"]._serialized_end = 4785
+ _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4658
+ _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4785
+ _globals["_UNKNOWN"]._serialized_start = 4787
+ _globals["_UNKNOWN"]._serialized_end = 4796
+ _globals["_RELATIONCOMMON"]._serialized_start = 4799
+ _globals["_RELATIONCOMMON"]._serialized_end = 4941
+ _globals["_SQL"]._serialized_start = 4944
+ _globals["_SQL"]._serialized_end = 5422
+ _globals["_SQL_ARGSENTRY"]._serialized_start = 5238
+ _globals["_SQL_ARGSENTRY"]._serialized_end = 5328
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5330
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5422
+ _globals["_WITHRELATIONS"]._serialized_start = 5424
+ _globals["_WITHRELATIONS"]._serialized_end = 5541
+ _globals["_READ"]._serialized_start = 5544
+ _globals["_READ"]._serialized_end = 6207
+ _globals["_READ_NAMEDTABLE"]._serialized_start = 5722
+ _globals["_READ_NAMEDTABLE"]._serialized_end = 5914
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5856
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5914
+ _globals["_READ_DATASOURCE"]._serialized_start = 5917
+ _globals["_READ_DATASOURCE"]._serialized_end = 6194
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5856
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5914
+ _globals["_PROJECT"]._serialized_start = 6209
+ _globals["_PROJECT"]._serialized_end = 6326
+ _globals["_FILTER"]._serialized_start = 6328
+ _globals["_FILTER"]._serialized_end = 6440
+ _globals["_JOIN"]._serialized_start = 6443
+ _globals["_JOIN"]._serialized_end = 7104
+ _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6782
+ _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6874
+ _globals["_JOIN_JOINTYPE"]._serialized_start = 6877
+ _globals["_JOIN_JOINTYPE"]._serialized_end = 7085
+ _globals["_SETOPERATION"]._serialized_start = 7107
+ _globals["_SETOPERATION"]._serialized_end = 7586
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7423
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7537
+ _globals["_LIMIT"]._serialized_start = 7588
+ _globals["_LIMIT"]._serialized_end = 7664
+ _globals["_OFFSET"]._serialized_start = 7666
+ _globals["_OFFSET"]._serialized_end = 7745
+ _globals["_TAIL"]._serialized_start = 7747
+ _globals["_TAIL"]._serialized_end = 7822
+ _globals["_AGGREGATE"]._serialized_start = 7825
+ _globals["_AGGREGATE"]._serialized_end = 8591
+ _globals["_AGGREGATE_PIVOT"]._serialized_start = 8240
+ _globals["_AGGREGATE_PIVOT"]._serialized_end = 8351
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8353
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8429
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8432
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8591
+ _globals["_SORT"]._serialized_start = 8594
+ _globals["_SORT"]._serialized_end = 8754
+ _globals["_DROP"]._serialized_start = 8757
+ _globals["_DROP"]._serialized_end = 8898
+ _globals["_DEDUPLICATE"]._serialized_start = 8901
+ _globals["_DEDUPLICATE"]._serialized_end = 9141
+ _globals["_LOCALRELATION"]._serialized_start = 9143
+ _globals["_LOCALRELATION"]._serialized_end = 9232
+ _globals["_CACHEDLOCALRELATION"]._serialized_start = 9234
+ _globals["_CACHEDLOCALRELATION"]._serialized_end = 9306
+ _globals["_CACHEDREMOTERELATION"]._serialized_start = 9308
+ _globals["_CACHEDREMOTERELATION"]._serialized_end = 9363
+ _globals["_SAMPLE"]._serialized_start = 9366
+ _globals["_SAMPLE"]._serialized_end = 9639
+ _globals["_RANGE"]._serialized_start = 9642
+ _globals["_RANGE"]._serialized_end = 9787
+ _globals["_SUBQUERYALIAS"]._serialized_start = 9789
+ _globals["_SUBQUERYALIAS"]._serialized_end = 9903
+ _globals["_REPARTITION"]._serialized_start = 9906
+ _globals["_REPARTITION"]._serialized_end = 10048
+ _globals["_SHOWSTRING"]._serialized_start = 10051
+ _globals["_SHOWSTRING"]._serialized_end = 10193
+ _globals["_HTMLSTRING"]._serialized_start = 10195
+ _globals["_HTMLSTRING"]._serialized_end = 10309
+ _globals["_STATSUMMARY"]._serialized_start = 10311
+ _globals["_STATSUMMARY"]._serialized_end = 10403
+ _globals["_STATDESCRIBE"]._serialized_start = 10405
+ _globals["_STATDESCRIBE"]._serialized_end = 10486
+ _globals["_STATCROSSTAB"]._serialized_start = 10488
+ _globals["_STATCROSSTAB"]._serialized_end = 10589
+ _globals["_STATCOV"]._serialized_start = 10591
+ _globals["_STATCOV"]._serialized_end = 10687
+ _globals["_STATCORR"]._serialized_start = 10690
+ _globals["_STATCORR"]._serialized_end = 10827
+ _globals["_STATAPPROXQUANTILE"]._serialized_start = 10830
+ _globals["_STATAPPROXQUANTILE"]._serialized_end = 10994
+ _globals["_STATFREQITEMS"]._serialized_start = 10996
+ _globals["_STATFREQITEMS"]._serialized_end = 11121
+ _globals["_STATSAMPLEBY"]._serialized_start = 11124
+ _globals["_STATSAMPLEBY"]._serialized_end = 11433
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11325
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11424
+ _globals["_NAFILL"]._serialized_start = 11436
+ _globals["_NAFILL"]._serialized_end = 11570
+ _globals["_NADROP"]._serialized_start = 11573
+ _globals["_NADROP"]._serialized_end = 11707
+ _globals["_NAREPLACE"]._serialized_start = 11710
+ _globals["_NAREPLACE"]._serialized_end = 12006
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11865
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12006
+ _globals["_TODF"]._serialized_start = 12008
+ _globals["_TODF"]._serialized_end = 12096
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12099
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12481
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
12343
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
12410
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12412
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12481
+ _globals["_WITHCOLUMNS"]._serialized_start = 12483
+ _globals["_WITHCOLUMNS"]._serialized_end = 12602
+ _globals["_WITHWATERMARK"]._serialized_start = 12605
+ _globals["_WITHWATERMARK"]._serialized_end = 12739
+ _globals["_HINT"]._serialized_start = 12742
+ _globals["_HINT"]._serialized_end = 12874
+ _globals["_UNPIVOT"]._serialized_start = 12877
+ _globals["_UNPIVOT"]._serialized_end = 13204
+ _globals["_UNPIVOT_VALUES"]._serialized_start = 13134
+ _globals["_UNPIVOT_VALUES"]._serialized_end = 13193
+ _globals["_TRANSPOSE"]._serialized_start = 13206
+ _globals["_TRANSPOSE"]._serialized_end = 13328
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13330
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13455
+ _globals["_TOSCHEMA"]._serialized_start = 13457
+ _globals["_TOSCHEMA"]._serialized_end = 13563
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13566
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13769
+ _globals["_MAPPARTITIONS"]._serialized_start = 13772
+ _globals["_MAPPARTITIONS"]._serialized_end = 14004
+ _globals["_GROUPMAP"]._serialized_start = 14007
+ _globals["_GROUPMAP"]._serialized_end = 14857
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14860
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15083
+ _globals["_COGROUPMAP"]._serialized_start = 15086
+ _globals["_COGROUPMAP"]._serialized_end = 15612
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15615
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15972
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15975
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16219
+ _globals["_PYTHONUDTF"]._serialized_start = 16222
+ _globals["_PYTHONUDTF"]._serialized_end = 16399
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16402
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16553
+ _globals["_PYTHONDATASOURCE"]._serialized_start = 16555
+ _globals["_PYTHONDATASOURCE"]._serialized_end = 16630
+ _globals["_COLLECTMETRICS"]._serialized_start = 16633
+ _globals["_COLLECTMETRICS"]._serialized_end = 16769
+ _globals["_PARSE"]._serialized_start = 16772
+ _globals["_PARSE"]._serialized_end = 17160
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5856
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5914
+ _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17061
+ _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17149
+ _globals["_ASOFJOIN"]._serialized_start = 17163
+ _globals["_ASOFJOIN"]._serialized_end = 17638
+ _globals["_LATERALJOIN"]._serialized_start = 17641
+ _globals["_LATERALJOIN"]._serialized_end = 17871
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index beeeb712da76..e1eb7945c19f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -707,28 +707,57 @@ class MlRelation(google.protobuf.message.Message):
TRANSFORM_FIELD_NUMBER: builtins.int
FETCH_FIELD_NUMBER: builtins.int
+ MODEL_SUMMARY_DATASET_FIELD_NUMBER: builtins.int
@property
def transform(self) -> global___MlRelation.Transform: ...
@property
def fetch(self) -> global___Fetch: ...
+ @property
+ def model_summary_dataset(self) -> global___Relation:
+ """(Optional) the dataset for restoring the model summary"""
def __init__(
self,
*,
transform: global___MlRelation.Transform | None = ...,
fetch: global___Fetch | None = ...,
+ model_summary_dataset: global___Relation | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform"
+ "_model_summary_dataset",
+ b"_model_summary_dataset",
+ "fetch",
+ b"fetch",
+ "ml_type",
+ b"ml_type",
+ "model_summary_dataset",
+ b"model_summary_dataset",
+ "transform",
+ b"transform",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform"
+ "_model_summary_dataset",
+ b"_model_summary_dataset",
+ "fetch",
+ b"fetch",
+ "ml_type",
+ b"ml_type",
+ "model_summary_dataset",
+ b"model_summary_dataset",
+ "transform",
+ b"transform",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal["_model_summary_dataset",
b"_model_summary_dataset"],
+ ) -> typing_extensions.Literal["model_summary_dataset"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["ml_type", b"ml_type"]
) -> typing_extensions.Literal["transform", "fetch"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index b66c0a186df3..3497284af4ab 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -38,6 +38,7 @@ message MlCommand {
Evaluate evaluate = 6;
CleanCache clean_cache = 7;
GetCacheInfo get_cache_info = 8;
+ CreateSummary create_summary = 9;
}
// Command for estimator.fit(dataset)
@@ -54,6 +55,9 @@ message MlCommand {
// or summary evaluated by a model
message Delete {
repeated ObjectRef obj_refs = 1;
+ // If `evict_only` is set to true, only evict the cached model from memory,
+ // but keep the offloaded model on the Spark driver's local disk.
+ optional bool evict_only = 2;
}
// Force to clean up all the ML cached objects
@@ -98,6 +102,13 @@ message MlCommand {
// (Required) the evaluating dataset
Relation dataset = 3;
}
+
+ // This is for re-creating the model summary when the model summary is lost
+ // (model summary is lost when the model is offloaded and then loaded back)
+ message CreateSummary {
+ ObjectRef model_ref = 1;
+ Relation dataset = 2;
+ }
}
// The result of MlCommand
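A sketch of how a client issues the new command; this mirrors the pyspark.ml.util change above, and `session`, `model_ref_id`, and `train_df` are illustrative names:
```
import pyspark.sql.connect.proto as pb2

command = pb2.Command()
command.ml_command.create_summary.CopyFrom(
    pb2.MlCommand.CreateSummary(
        # Reference to the (possibly offloaded) model in the server-side ML cache.
        model_ref=pb2.ObjectRef(id=model_ref_id),
        # Plan of the original training dataset, re-sent so the server can
        # recompute the summary's predictions DataFrame.
        dataset=train_df._plan.plan(session.client),
    )
)
session.client.execute_command(command)
```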
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
index 70a52a211149..ccb674e812dc 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -115,6 +115,9 @@ message MlRelation {
Transform transform = 1;
Fetch fetch = 2;
}
+ // (Optional) the dataset for restoring the model summary
+ optional Relation model_summary_dataset = 3;
+
// Relation to represent transform(input) of the operator
// which could be a cached model or a new transformer
message Transform {
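A sketch of an MlRelation fetch carrying the new field, roughly what the client builds when it fetches e.g. `model.summary.predictions` after the model may have been offloaded; field names follow this proto, the plan wiring is simplified, and `model_ref_id`, `train_df`, `session` are illustrative:
```
import pyspark.sql.connect.proto as pb2

ml_rel = pb2.MlRelation(
    fetch=pb2.Fetch(
        obj_ref=pb2.ObjectRef(id=model_ref_id),
        methods=[
            pb2.Fetch.Method(method="summary"),
            pb2.Fetch.Method(method="predictions"),
        ],
    ),
    # Lets the server rebuild the summary from the offloaded model plus the
    # original training data if the in-memory summary has been evicted.
    model_summary_dataset=train_df._plan.plan(session.client),
)
```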
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index ef1b17dc2221..b075187b7002 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -30,7 +30,7 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
-import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
+import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary, MLWritable, Summary}
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder
@@ -115,6 +115,12 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
}
}
+ private[spark] def getModelOffloadingPath(refId: String): Path = {
+ val path = offloadedModelsDir.resolve(refId)
+ require(path.startsWith(offloadedModelsDir))
+ path
+ }
+
/**
* Cache an object into a map of MLCache, and return its key
* @param obj
@@ -137,9 +143,14 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
}
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
if (getMemoryControlEnabled) {
- val savePath = offloadedModelsDir.resolve(objectId)
- require(savePath.startsWith(offloadedModelsDir))
+ val savePath = getModelOffloadingPath(objectId)
obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
+ if (obj.isInstanceOf[HasTrainingSummary[_]]
+ && obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) {
+ obj
+ .asInstanceOf[HasTrainingSummary[_]]
+ .saveSummary(savePath.resolve("summary").toString)
+ }
Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName)
totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes)
totalMLCacheSizeBytes.addAndGet(sizeBytes)
@@ -176,8 +187,7 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
verifyObjectId(refId)
var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(null)
if (obj == null && getMemoryControlEnabled) {
- val loadPath = offloadedModelsDir.resolve(refId)
- require(loadPath.startsWith(offloadedModelsDir))
+ val loadPath = getModelOffloadingPath(refId)
if (Files.isDirectory(loadPath)) {
val className = Files.readString(loadPath.resolve(modelClassNameFile))
obj = MLUtils.loadTransformer(
@@ -194,14 +204,13 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
}
}
- def _removeModel(refId: String): Boolean = {
+ def _removeModel(refId: String, evictOnly: Boolean): Boolean = {
verifyObjectId(refId)
val removedModel = cachedModel.remove(refId)
val removedFromMem = removedModel != null
- val removedFromDisk = if (removedModel != null && getMemoryControlEnabled) {
+ val removedFromDisk = if (!evictOnly && removedModel != null && getMemoryControlEnabled) {
totalMLCacheSizeBytes.addAndGet(-removedModel.sizeBytes)
- val removePath = offloadedModelsDir.resolve(refId)
- require(removePath.startsWith(offloadedModelsDir))
+ val removePath = getModelOffloadingPath(refId)
val offloadingPath = new File(removePath.toString)
if (offloadingPath.exists()) {
FileUtils.deleteDirectory(offloadingPath)
@@ -220,8 +229,8 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
* @param refId
* the key used to look up the corresponding object
*/
- def remove(refId: String): Boolean = {
- val modelIsRemoved = _removeModel(refId)
+ def remove(refId: String, evictOnly: Boolean = false): Boolean = {
+ val modelIsRemoved = _removeModel(refId, evictOnly)
modelIsRemoved
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
index a017c719ed16..847052be98a9 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala
@@ -51,3 +51,9 @@ private[spark] case class
MLCacheSizeOverflowException(mlCacheMaxSize: Long)
errorClass = "CONNECT_ML.ML_CACHE_SIZE_OVERFLOW_EXCEPTION",
messageParameters = Map("mlCacheMaxSize" -> mlCacheMaxSize.toString),
cause = null)
+
+private[spark] case class MLModelSummaryLostException(objectName: String)
+ extends SparkException(
+ errorClass = "CONNECT_ML.MODEL_SUMMARY_LOST",
+ messageParameters = Map("objectName" -> objectName),
+ cause = null)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index d40b70ba0813..7220acb8feac 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,9 +229,7 @@ private[connect] object MLHandler extends Logging {
if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]]
&& methods(0).getMethod == "summary"
&& !obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) {
- throw MLCacheInvalidException(
- objRefId,
- sessionHolder.mlCache.getOffloadingTimeoutMinute)
+ throw MLModelSummaryLostException(objRefId)
}
val helper = AttributeHelper(sessionHolder, objRefId, methods)
val attrResult = helper.getAttribute
@@ -264,9 +262,13 @@ private[connect] object MLHandler extends Logging {
case proto.MlCommand.CommandCase.DELETE =>
val ids = mutable.ArrayBuilder.make[String]
- mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId =>
+ val deleteCmd = mlCommand.getDelete
+ val evictOnly = if (deleteCmd.hasEvictOnly) {
+ deleteCmd.getEvictOnly
+ } else { false }
+ deleteCmd.getObjRefsList.asScala.toArray.foreach { objId =>
if (!objId.getId.contains(".")) {
- if (mlCache.remove(objId.getId)) {
+ if (mlCache.remove(objId.getId, evictOnly)) {
ids += objId.getId
}
}
@@ -400,10 +402,29 @@ private[connect] object MLHandler extends Logging {
.setParam(LiteralValueProtoConverter.toLiteralProto(metric))
.build()
+ case proto.MlCommand.CommandCase.CREATE_SUMMARY =>
+ val createSummaryCmd = mlCommand.getCreateSummary
+ createModelSummary(sessionHolder, createSummaryCmd)
+
case other => throw MlUnsupportedException(s"$other not supported")
}
}
+ private def createModelSummary(
+ sessionHolder: SessionHolder,
+ createSummaryCmd: proto.MlCommand.CreateSummary): proto.MlCommandResult = {
+ val refId = createSummaryCmd.getModelRef.getId
+ val model = sessionHolder.mlCache.get(refId).asInstanceOf[HasTrainingSummary[_]]
+ val dataset = MLUtils.parseRelationProto(createSummaryCmd.getDataset, sessionHolder)
+ val modelPath = sessionHolder.mlCache.getModelOffloadingPath(refId)
+ val summaryPath = modelPath.resolve("summary").toString
+ model.loadSummary(summaryPath, dataset)
+ proto.MlCommandResult
+ .newBuilder()
+ .setParam(LiteralValueProtoConverter.toLiteralProto(true))
+ .build()
+ }
+
def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = {
relation.getMlTypeCase match {
// Ml transform
@@ -433,10 +454,26 @@ private[connect] object MLHandler extends Logging {
// Get the attribute from a cached object which could be a model or summary
case proto.MlRelation.MlTypeCase.FETCH =>
- val helper = AttributeHelper(
- sessionHolder,
- relation.getFetch.getObjRef.getId,
- relation.getFetch.getMethodsList.asScala.toArray)
+ val objRefId = relation.getFetch.getObjRef.getId
+ val methods = relation.getFetch.getMethodsList.asScala.toArray
+ val obj = sessionHolder.mlCache.get(objRefId)
+ if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]]
+ && methods(0).getMethod == "summary"
+ && !obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) {
+
+ if (relation.hasModelSummaryDataset) {
+ val dataset =
+ MLUtils.parseRelationProto(relation.getModelSummaryDataset, sessionHolder)
+ val modelPath = sessionHolder.mlCache.getModelOffloadingPath(objRefId)
+ val summaryPath = modelPath.resolve("summary").toString
+ obj.asInstanceOf[HasTrainingSummary[_]].loadSummary(summaryPath, dataset)
+ } else {
+ // For old Spark client backward compatibility.
+ throw MLModelSummaryLostException(objRefId)
+ }
+ }
+
+ val helper = AttributeHelper(sessionHolder, objRefId, methods)
helper.getAttribute.asInstanceOf[DataFrame]
case other => throw MlUnsupportedException(s"$other not supported")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]