This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9d0e88894514 [SPARK-50941][ML][PYTHON][CONNECT] add support for TrainValidationSplit
9d0e88894514 is described below
commit 9d0e8889451464b3e38772000f5c4597b7223a0d
Author: Bobby Wang <[email protected]>
AuthorDate: Mon Jan 27 17:06:03 2025 +0800
[SPARK-50941][ML][PYTHON][CONNECT] add support for TrainValidationSplit
### What changes were proposed in this pull request?
This PR adds support for TrainValidationSplit and TrainValidationSplitModel
on Connect
### Why are the changes needed?
New feature parity for Spark Connect ML.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
The CI passes
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49688 from wbo4958/train_validation_split.
Lead-authored-by: Bobby Wang <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/connect/readwrite.py | 52 +++++++++++++++++++++++++++--
python/pyspark/ml/tests/test_tuning.py | 60 +++++++++++++++++++++++++++++++++-
python/pyspark/ml/tuning.py | 26 ++++++++++-----
3 files changed, 125 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/ml/connect/readwrite.py
b/python/pyspark/ml/connect/readwrite.py
index 3bf2031538d9..6392e988c067 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -19,7 +19,12 @@ from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
-from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
+from pyspark.ml.tuning import (
+ CrossValidatorModelWriter,
+ CrossValidatorModel,
+ TrainValidationSplitModel,
+ TrainValidationSplitModelWriter,
+)
from pyspark.ml.util import MLWriter, MLReader, RL
from pyspark.ml.wrapper import JavaWrapper
@@ -42,6 +47,19 @@ class RemoteCrossValidatorModelWriter(CrossValidatorModelWriter):
self.session(session) # type: ignore[arg-type]
+class RemoteTrainValidationSplitModelWriter(TrainValidationSplitModelWriter):
+ def __init__(
+ self,
+ instance: "TrainValidationSplitModel",
+ optionMap: Dict[str, Any] = {},
+ session: Optional["SparkSession"] = None,
+ ):
+ super(RemoteTrainValidationSplitModelWriter, self).__init__(instance)
+ self.instance = instance
+ self.optionMap = optionMap
+ self.session(session) # type: ignore[arg-type]
+
+
class RemoteMLWriter(MLWriter):
def __init__(self, instance: "JavaMLWritable") -> None:
super().__init__()
@@ -76,7 +94,7 @@ class RemoteMLWriter(MLWriter):
        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
- from pyspark.ml.tuning import CrossValidator
+ from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -155,6 +173,20 @@ class RemoteMLWriter(MLWriter):
                warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
            cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
cvm_writer.save(path)
+ elif isinstance(instance, TrainValidationSplit):
+ from pyspark.ml.tuning import TrainValidationSplitWriter
+
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
+ tvs_writer = TrainValidationSplitWriter(instance)
+ tvs_writer.save(path)
+ elif isinstance(instance, TrainValidationSplitModel):
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
+            tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
+ tvsm_writer.save(path)
else:
            raise NotImplementedError(f"Unsupported write for {instance.__class__}")
@@ -182,7 +214,7 @@ class RemoteMLReader(MLReader[RL]):
        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
- from pyspark.ml.tuning import CrossValidator
+ from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
if (
issubclass(clazz, JavaModel)
@@ -261,5 +293,19 @@ class RemoteMLReader(MLReader[RL]):
cvm_reader.session(session)
return cvm_reader.load(path)
+ elif issubclass(clazz, TrainValidationSplit):
+ from pyspark.ml.tuning import TrainValidationSplitReader
+
+ tvs_reader = TrainValidationSplitReader(TrainValidationSplit)
+ tvs_reader.session(session)
+ return tvs_reader.load(path)
+
+ elif issubclass(clazz, TrainValidationSplitModel):
+ from pyspark.ml.tuning import TrainValidationSplitModelReader
+
+        tvs_reader = TrainValidationSplitModelReader(TrainValidationSplitModel)
+ tvs_reader.session(session)
+ return tvs_reader.load(path)
+
else:
raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/test_tuning.py
b/python/pyspark/ml/tests/test_tuning.py
index 2bc0e22c1209..39b404d266fc 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -24,11 +24,69 @@ import numpy as np
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
-from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.tuning import (
+ ParamGridBuilder,
+ CrossValidator,
+ CrossValidatorModel,
+ TrainValidationSplit,
+ TrainValidationSplitModel,
+)
from pyspark.testing.sqlutils import ReusedSQLTestCase
class TuningTestsMixin:
+ def test_train_validation_split(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0),
+ ]
+ * 10, # Repeat the data 10 times
+ ["features", "label"],
+ ).repartition(1)
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ tvs = TrainValidationSplit(
+            estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1, seed=42
+ )
+ self.assertEqual(tvs.getEstimator(), lr)
+ self.assertEqual(tvs.getEvaluator(), evaluator)
+ self.assertEqual(tvs.getParallelism(), 1)
+ self.assertEqual(tvs.getEstimatorParamMaps(), grid)
+
+ tvs_model = tvs.fit(dataset)
+
+ # Access the train ratio
+ self.assertEqual(tvs_model.getTrainRatio(), 0.75)
+ print("----------- ", tvs_model.validationMetrics)
+        self.assertTrue(np.isclose(tvs_model.validationMetrics[0], 0.5, atol=1e-4))
+        self.assertTrue(np.isclose(tvs_model.validationMetrics[1], 0.8857142857142857, atol=1e-4))
+
+ evaluation_score = evaluator.evaluate(tvs_model.transform(dataset))
+        self.assertTrue(np.isclose(evaluation_score, 0.8333333333333333, atol=1e-4))
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
+ path1 = os.path.join(d, "cv")
+ tvs.write().save(path1)
+ tvs2 = TrainValidationSplit.load(path1)
+ self.assertEqual(str(tvs), str(tvs2))
+ self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
+ self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))
+
+ path2 = os.path.join(d, "cv_model")
+ tvs_model.write().save(path2)
+ model2 = TrainValidationSplitModel.load(path2)
+ self.assertEqual(str(tvs_model), str(model2))
+            self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
+            self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
+
def test_cross_validator(self):
dataset = self.spark.createDataFrame(
[
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 06d3837e1ae4..e27ef955ef73 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -1186,12 +1186,12 @@ class TrainValidationSplitReader(MLReader["TrainValidationSplit"]):
self.cls = cls
def load(self, path: str) -> "TrainValidationSplit":
- metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
- path, self.sc, metadata
+ path, self.sparkSession, metadata
)
tvs = TrainValidationSplit(
                estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -1209,7 +1209,7 @@ class TrainValidationSplitWriter(MLWriter):
def saveImpl(self, path: str) -> None:
_ValidatorSharedReadWrite.validateParams(self.instance)
- _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)
@inherit_doc
@@ -1219,15 +1219,17 @@ class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
self.cls = cls
def load(self, path: str) -> "TrainValidationSplitModel":
- metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
- path, self.sc, metadata
+ path, self.sparkSession, metadata
)
bestModelPath = os.path.join(path, "bestModel")
-            bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+ bestModel: Model = DefaultParamsReader.loadParamsInstance(
+ bestModelPath, self.sparkSession
+ )
validationMetrics = metadata["validationMetrics"]
            persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]
@@ -1236,7 +1238,7 @@ class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
for paramIndex in range(len(estimatorParamMaps)):
                    modelPath = os.path.join(path, "subModels", f"{paramIndex}")
subModels[paramIndex] =
DefaultParamsReader.loadParamsInstance(
- modelPath, self.sc
+ modelPath, self.sparkSession
)
else:
subModels = None
@@ -1273,7 +1275,9 @@ class TrainValidationSplitModelWriter(MLWriter):
"validationMetrics": instance.validationMetrics,
"persistSubModels": persistSubModels,
}
-        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
+ _ValidatorSharedReadWrite.saveImpl(
+ path, instance, self.sparkSession, extraMetadata=extraMetadata
+ )
bestModelPath = os.path.join(path, "bestModel")
cast(MLWritable, instance.bestModel).save(bestModelPath)
if persistSubModels:
@@ -1473,7 +1477,7 @@ class TrainValidationSplit(
subModels = [None for i in range(numModels)]
tasks = map(
- inheritable_thread_target,
+ inheritable_thread_target(dataset.sparkSession),
            _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
@@ -1529,6 +1533,7 @@ class TrainValidationSplit(
return newTVS
@since("2.3.0")
+ @try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1537,6 +1542,7 @@ class TrainValidationSplit(
@classmethod
@since("2.3.0")
+ @try_remote_read
def read(cls) -> TrainValidationSplitReader:
"""Returns an MLReader instance for this class."""
return TrainValidationSplitReader(cls)
@@ -1649,6 +1655,7 @@ class TrainValidationSplitModel(
)
@since("2.3.0")
+ @try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1657,6 +1664,7 @@ class TrainValidationSplitModel(
@classmethod
@since("2.3.0")
+ @try_remote_read
def read(cls) -> TrainValidationSplitModelReader:
"""Returns an MLReader instance for this class."""
return TrainValidationSplitModelReader(cls)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]