This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9816b709003 [SPARK-51873][ML] For OneVsRest algorithm, allow using save / load to replace cache
d9816b709003 is described below

commit d9816b709003f47f2e9d29461856a4676a06b6f0
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Wed Apr 23 22:00:35 2025 +0800

    [SPARK-51873][ML] For OneVsRest algorithm, allow using save / load to replace cache
    
    ### What changes were proposed in this pull request?
    
    For the OneVsRest algorithm, allow using save / load to replace caching.
    
    ### Why are the changes needed?
    
    DataFrame persisting is not well supported in certain cases, so we need a replacement.
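    
    A minimal usage sketch, mirroring the unit test added in this patch (it assumes
    an active SparkSession named `spark`; the temporary directory is illustrative,
    and `_SPARKML_TEMP_DFS_PATH` is the internal option defined in `pyspark.ml.util`):
    
    ```python
    import os
    import tempfile
    
    from pyspark.ml.classification import LogisticRegression, OneVsRest
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
    
    # With the temp-DFS path set, _cache_spark_dataset saves / loads the
    # intermediate dataset under this directory instead of persisting it.
    tmp_dir = tempfile.mkdtemp(prefix="ml_tmp_dir")
    os.environ[_SPARKML_TEMP_DFS_PATH] = tmp_dir
    try:
        df = spark.createDataFrame(
            [
                (0.0, Vectors.dense(1.0, 0.8)),
                (1.0, Vectors.sparse(2, [], [])),
                (2.0, Vectors.dense(0.5, 0.5)),
            ],
            ["label", "features"],
        )
        ovr = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=0.01))
        model = ovr.fit(df)          # per-class training data is saved / loaded
        model.transform(df).count()  # the transform path uses the same mechanism
        assert len(os.listdir(tmp_dir)) == 0  # temporary datasets are cleaned up
    finally:
        os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
    ```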
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50672 from WeichenXu123/one-vs-rest-cache.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/classification.py        | 127 +++++++++++++++--------------
 python/pyspark/ml/tests/test_algorithms.py |  24 +++++-
 python/pyspark/ml/tests/test_tuning.py     |  24 ++++--
 python/pyspark/ml/util.py                  |  13 ++-
 4 files changed, 113 insertions(+), 75 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index df2cbd34a35d..3f9e3fa37f72 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -89,6 +89,7 @@ from pyspark.ml.util import (
     try_remote_read,
     try_remote_write,
     try_remote_attribute_relation,
+    _cache_spark_dataset,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
 from pyspark.ml.common import inherit_doc
@@ -3603,46 +3604,47 @@ class OneVsRest(
 
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
 
-        def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
-            indices = iter(range(numClasses))
-
-            def trainSingleClass() -> Tuple[int, CM]:
-                index = next(indices)
-
-                binaryLabelCol = "mc2b$" + str(index)
-                trainingDataset = multiclassLabeled.withColumn(
-                    binaryLabelCol,
-                    F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-                )
-                paramMap = dict(
-                    [
-                        (classifier.labelCol, binaryLabelCol),
-                        (classifier.featuresCol, featuresCol),
-                        (classifier.predictionCol, predictionCol),
-                    ]
-                )
-                if weightCol:
-                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-                return index, classifier.fit(trainingDataset, paramMap)
-
-            return [trainSingleClass] * numClasses
-
-        tasks = map(
-            inheritable_thread_target(dataset.sparkSession),
-            _oneClassFitTasks(numClasses),
-        )
-        pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
-
-        subModels = [None] * numClasses
-        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
-            assert subModels is not None
-            subModels[j] = subModel
+        with _cache_spark_dataset(
+            multiclassLabeled,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as multiclassLabeled:
+
+            def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
+                indices = iter(range(numClasses))
+
+                def trainSingleClass() -> Tuple[int, CM]:
+                    index = next(indices)
+
+                    binaryLabelCol = "mc2b$" + str(index)
+                    trainingDataset = multiclassLabeled.withColumn(
+                        binaryLabelCol,
+                        F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                    )
+                    paramMap = dict(
+                        [
+                            (classifier.labelCol, binaryLabelCol),
+                            (classifier.featuresCol, featuresCol),
+                            (classifier.predictionCol, predictionCol),
+                        ]
+                    )
+                    if weightCol:
+                        paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                    return index, classifier.fit(trainingDataset, paramMap)
+
+                return [trainSingleClass] * numClasses
+
+            tasks = map(
+                inheritable_thread_target(dataset.sparkSession),
+                _oneClassFitTasks(numClasses),
+            )
+            pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
 
-        if handlePersistence:
-            multiclassLabeled.unpersist()
+            subModels = [None] * numClasses
+            for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+                assert subModels is not None
+                subModels[j] = subModel
 
         return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))
 
@@ -3868,32 +3870,31 @@ class OneVsRestModel(
 
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            newDataset.persist(StorageLevel.MEMORY_AND_DISK)
-
-        # update the accumulator column with the result of prediction of models
-        aggregatedDataset = newDataset
-        for index, model in enumerate(self.models):
-            rawPredictionCol = self.getRawPredictionCol()
-
-            columns = origCols + [rawPredictionCol, accColName]
-
-            # add temporary column to store intermediate scores and update
-            tmpColName = "mbc$tmp" + str(uuid.uuid4())
-            transformedDataset = model.transform(aggregatedDataset).select(*columns)
-            updatedDataset = transformedDataset.withColumn(
-                tmpColName,
-                F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
-            )
-            newColumns = origCols + [tmpColName]
-
-            # switch out the intermediate column with the accumulator column
-            aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
-                tmpColName, accColName
-            )
+        with _cache_spark_dataset(
+            newDataset,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as newDataset:
+            # update the accumulator column with the result of prediction of models
+            aggregatedDataset = newDataset
+            for index, model in enumerate(self.models):
+                rawPredictionCol = self.getRawPredictionCol()
+
+                columns = origCols + [rawPredictionCol, accColName]
+
+                # add temporary column to store intermediate scores and update
+                tmpColName = "mbc$tmp" + str(uuid.uuid4())
+                transformedDataset = model.transform(aggregatedDataset).select(*columns)
+                updatedDataset = transformedDataset.withColumn(
+                    tmpColName,
+                    F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
+                )
+                newColumns = origCols + [tmpColName]
 
-        if handlePersistence:
-            newDataset.unpersist()
+                # switch out the intermediate column with the accumulator column
+                aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+                    tmpColName, accColName
+                )
 
         if self.getRawPredictionCol():
             aggregatedDataset = aggregatedDataset.withColumn(
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index d0e2600a9a8b..0f5deab4e093 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 from shutil import rmtree
 import tempfile
 import unittest
@@ -154,6 +154,28 @@ class OneVsRestTests(SparkSessionTestCase):
         ovr2 = OneVsRest(classifier=dt, weightCol="weight")
         self.assertIsNotNone(ovr2.fit(df))
 
+    def test_tmp_dfs_cache(self):
+        from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
+
+        with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+            os.environ[_SPARKML_TEMP_DFS_PATH] = d
+            try:
+                df = self.spark.createDataFrame(
+                    [
+                        (0.0, Vectors.dense(1.0, 0.8)),
+                        (1.0, Vectors.sparse(2, [], [])),
+                        (2.0, Vectors.dense(0.5, 0.5)),
+                    ],
+                    ["label", "features"],
+                )
+                lr = LogisticRegression(maxIter=5, regParam=0.01)
+                ovr = OneVsRest(classifier=lr, parallelism=1)
+                model = ovr.fit(df)
+                model.transform(df)
+                assert len(os.listdir(d)) == 0
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+
 
 class KMeansTests(SparkSessionTestCase):
     def test_kmeans_cosine_distance(self):
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 5885761272a2..947c599b3cf2 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -73,12 +73,15 @@ class TuningTestsMixin:
 
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            tvs_model2 = tvs.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
-            self.assertTrue(
-                np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
-            )
+            try:
+                tvs_model2 = tvs.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
+                self.assertTrue(
+                    np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
+                )
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
@@ -131,9 +134,12 @@ class TuningTestsMixin:
 
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            model2 = cv.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            try:
+                model2 = cv.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
 
         output = model.transform(dataset)
         self.assertEqual(
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index c3e4b1ba87f9..6abadec74e63 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -41,6 +41,7 @@ from pyspark import since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
 from pyspark.sql.utils import is_remote
+from pyspark.storagelevel import StorageLevel
 from pyspark.util import VersionUtils
 
 if TYPE_CHECKING:
@@ -1138,7 +1139,15 @@ def _remove_dfs_dir(path: str, spark_session: "SparkSession") -> None:
 
 
 @contextmanager
-def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
+def _cache_spark_dataset(
+    dataset: "DataFrame",
+    storageLevel: "StorageLevel" = StorageLevel.MEMORY_AND_DISK_DESER,
+    enable: bool = True,
+) -> Iterator[Any]:
+    if not enable:
+        yield dataset
+        return
+
     spark_session = dataset._session
     tmp_dfs_path = os.environ.get(_SPARKML_TEMP_DFS_PATH)
 
@@ -1150,7 +1159,7 @@ def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
         finally:
             _remove_dfs_dir(tmp_cache_path, spark_session)
     else:
-        dataset.cache()
+        dataset.persist(storageLevel)
         try:
             yield dataset
         finally:

