This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 3f27cb738583 [SPARK-50941][ML][PYTHON][CONNECT] add supports for TrainValidationSplit
3f27cb738583 is described below

commit 3f27cb738583a01f43d3f4fa8fcb06f4da4a0ca2
Author: Bobby Wang <[email protected]>
AuthorDate: Mon Jan 27 17:06:03 2025 +0800

    [SPARK-50941][ML][PYTHON][CONNECT] add supports for TrainValidationSplit
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for TrainValidationSplit and TrainValidationSplitModel on Connect.
    
    ### Why are the changes needed?
    For feature parity with classic (non-Connect) PySpark ML.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    Added test_train_validation_split to python/pyspark/ml/tests/test_tuning.py; the CI passes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49688 from wbo4958/train_validation_split.
    
    Lead-authored-by: Bobby Wang <[email protected]>
    Co-authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 9d0e8889451464b3e38772000f5c4597b7223a0d)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/ml/connect/readwrite.py | 52 +++++++++++++++++++++++++++--
 python/pyspark/ml/tests/test_tuning.py | 60 +++++++++++++++++++++++++++++++++-
 python/pyspark/ml/tuning.py            | 26 ++++++++++-----
 3 files changed, 125 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/ml/connect/readwrite.py 
b/python/pyspark/ml/connect/readwrite.py
index 3bf2031538d9..6392e988c067 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -19,7 +19,12 @@ from typing import cast, Type, TYPE_CHECKING, Union, List, 
Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, 
deserialize_param
-from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
+from pyspark.ml.tuning import (
+    CrossValidatorModelWriter,
+    CrossValidatorModel,
+    TrainValidationSplitModel,
+    TrainValidationSplitModelWriter,
+)
 from pyspark.ml.util import MLWriter, MLReader, RL
 from pyspark.ml.wrapper import JavaWrapper
 
@@ -42,6 +47,19 @@ class 
RemoteCrossValidatorModelWriter(CrossValidatorModelWriter):
         self.session(session)  # type: ignore[arg-type]
 
 
+class RemoteTrainValidationSplitModelWriter(TrainValidationSplitModelWriter):
+    def __init__(
+        self,
+        instance: "TrainValidationSplitModel",
+        optionMap: Dict[str, Any] = {},
+        session: Optional["SparkSession"] = None,
+    ):
+        super(RemoteTrainValidationSplitModelWriter, self).__init__(instance)
+        self.instance = instance
+        self.optionMap = optionMap
+        self.session(session)  # type: ignore[arg-type]
+
+
 class RemoteMLWriter(MLWriter):
     def __init__(self, instance: "JavaMLWritable") -> None:
         super().__init__()
@@ -76,7 +94,7 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, 
JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
-        from pyspark.ml.tuning import CrossValidator
+        from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -155,6 +173,20 @@ class RemoteMLWriter(MLWriter):
                 warnings.warn("Overwrite doesn't take effect for 
CrossValidatorModel")
             cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, 
session)
             cvm_writer.save(path)
+        elif isinstance(instance, TrainValidationSplit):
+            from pyspark.ml.tuning import TrainValidationSplitWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for 
TrainValidationSplit")
+            tvs_writer = TrainValidationSplitWriter(instance)
+            tvs_writer.save(path)
+        elif isinstance(instance, TrainValidationSplitModel):
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for 
TrainValidationSplitModel")
+            tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, 
optionMap, session)
+            tvsm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for 
{instance.__class__}")
 
@@ -182,7 +214,7 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, 
JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
-        from pyspark.ml.tuning import CrossValidator
+        from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
 
         if (
             issubclass(clazz, JavaModel)
@@ -261,5 +293,19 @@ class RemoteMLReader(MLReader[RL]):
             cvm_reader.session(session)
             return cvm_reader.load(path)
 
+        elif issubclass(clazz, TrainValidationSplit):
+            from pyspark.ml.tuning import TrainValidationSplitReader
+
+            tvs_reader = TrainValidationSplitReader(TrainValidationSplit)
+            tvs_reader.session(session)
+            return tvs_reader.load(path)
+
+        elif issubclass(clazz, TrainValidationSplitModel):
+            from pyspark.ml.tuning import TrainValidationSplitModelReader
+
+            tvs_reader = 
TrainValidationSplitModelReader(TrainValidationSplitModel)
+            tvs_reader.session(session)
+            return tvs_reader.load(path)
+
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/test_tuning.py 
b/python/pyspark/ml/tests/test_tuning.py
index 2bc0e22c1209..39b404d266fc 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -24,11 +24,69 @@ import numpy as np
 from pyspark.ml.evaluation import BinaryClassificationEvaluator
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.classification import LogisticRegression
-from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, 
CrossValidatorModel
+from pyspark.ml.tuning import (
+    ParamGridBuilder,
+    CrossValidator,
+    CrossValidatorModel,
+    TrainValidationSplit,
+    TrainValidationSplitModel,
+)
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class TuningTestsMixin:
+    def test_train_validation_split(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (Vectors.dense([0.0]), 0.0),
+                (Vectors.dense([0.4]), 1.0),
+                (Vectors.dense([0.5]), 0.0),
+                (Vectors.dense([0.6]), 1.0),
+                (Vectors.dense([1.0]), 1.0),
+            ]
+            * 10,  # Repeat the data 10 times
+            ["features", "label"],
+        ).repartition(1)
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        tvs = TrainValidationSplit(
+            estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, 
parallelism=1, seed=42
+        )
+        self.assertEqual(tvs.getEstimator(), lr)
+        self.assertEqual(tvs.getEvaluator(), evaluator)
+        self.assertEqual(tvs.getParallelism(), 1)
+        self.assertEqual(tvs.getEstimatorParamMaps(), grid)
+
+        tvs_model = tvs.fit(dataset)
+
+        # Access the train ratio
+        self.assertEqual(tvs_model.getTrainRatio(), 0.75)
+        print("----------- ", tvs_model.validationMetrics)
+        self.assertTrue(np.isclose(tvs_model.validationMetrics[0], 0.5, 
atol=1e-4))
+        self.assertTrue(np.isclose(tvs_model.validationMetrics[1], 
0.8857142857142857, atol=1e-4))
+
+        evaluation_score = evaluator.evaluate(tvs_model.transform(dataset))
+        self.assertTrue(np.isclose(evaluation_score, 0.8333333333333333, 
atol=1e-4))
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
+            path1 = os.path.join(d, "cv")
+            tvs.write().save(path1)
+            tvs2 = TrainValidationSplit.load(path1)
+            self.assertEqual(str(tvs), str(tvs2))
+            self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
+            self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))
+
+            path2 = os.path.join(d, "cv_model")
+            tvs_model.write().save(path2)
+            model2 = TrainValidationSplitModel.load(path2)
+            self.assertEqual(str(tvs_model), str(model2))
+            self.assertEqual(str(tvs_model.getEstimator()), 
str(model2.getEstimator()))
+            self.assertEqual(str(tvs_model.getEvaluator()), 
str(model2.getEvaluator()))
+
     def test_cross_validator(self):
         dataset = self.spark.createDataFrame(
             [
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 06d3837e1ae4..e27ef955ef73 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -1186,12 +1186,12 @@ class 
TrainValidationSplitReader(MLReader["TrainValidationSplit"]):
         self.cls = cls
 
     def load(self, path: str) -> "TrainValidationSplit":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = 
_ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             tvs = TrainValidationSplit(
                 estimator=estimator, estimatorParamMaps=estimatorParamMaps, 
evaluator=evaluator
@@ -1209,7 +1209,7 @@ class TrainValidationSplitWriter(MLWriter):
 
     def saveImpl(self, path: str) -> None:
         _ValidatorSharedReadWrite.validateParams(self.instance)
-        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+        _ValidatorSharedReadWrite.saveImpl(path, self.instance, 
self.sparkSession)
 
 
 @inherit_doc
@@ -1219,15 +1219,17 @@ class 
TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
         self.cls = cls
 
     def load(self, path: str) -> "TrainValidationSplitModel":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = 
_ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             bestModelPath = os.path.join(path, "bestModel")
-            bestModel: Model = 
DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+            bestModel: Model = DefaultParamsReader.loadParamsInstance(
+                bestModelPath, self.sparkSession
+            )
             validationMetrics = metadata["validationMetrics"]
             persistSubModels = ("persistSubModels" in metadata) and 
metadata["persistSubModels"]
 
@@ -1236,7 +1238,7 @@ class 
TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
                 for paramIndex in range(len(estimatorParamMaps)):
                     modelPath = os.path.join(path, "subModels", 
f"{paramIndex}")
                     subModels[paramIndex] = 
DefaultParamsReader.loadParamsInstance(
-                        modelPath, self.sc
+                        modelPath, self.sparkSession
                     )
             else:
                 subModels = None
@@ -1273,7 +1275,9 @@ class TrainValidationSplitModelWriter(MLWriter):
             "validationMetrics": instance.validationMetrics,
             "persistSubModels": persistSubModels,
         }
-        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, 
extraMetadata=extraMetadata)
+        _ValidatorSharedReadWrite.saveImpl(
+            path, instance, self.sparkSession, extraMetadata=extraMetadata
+        )
         bestModelPath = os.path.join(path, "bestModel")
         cast(MLWritable, instance.bestModel).save(bestModelPath)
         if persistSubModels:
@@ -1473,7 +1477,7 @@ class TrainValidationSplit(
             subModels = [None for i in range(numModels)]
 
         tasks = map(
-            inheritable_thread_target,
+            inheritable_thread_target(dataset.sparkSession),
             _parallelFitTasks(est, train, eva, validation, epm, 
collectSubModelsParam),
         )
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
@@ -1529,6 +1533,7 @@ class TrainValidationSplit(
         return newTVS
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1537,6 +1542,7 @@ class TrainValidationSplit(
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> TrainValidationSplitReader:
         """Returns an MLReader instance for this class."""
         return TrainValidationSplitReader(cls)
@@ -1649,6 +1655,7 @@ class TrainValidationSplitModel(
         )
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1657,6 +1664,7 @@ class TrainValidationSplitModel(
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> TrainValidationSplitModelReader:
         """Returns an MLReader instance for this class."""
         return TrainValidationSplitModelReader(cls)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to