This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d9816b709003 [SPARK-51873][ML] For OneVsRest algorithm, allow using save / load to replace cache
d9816b709003 is described below

commit d9816b709003f47f2e9d29461856a4676a06b6f0
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Wed Apr 23 22:00:35 2025 +0800

    [SPARK-51873][ML] For OneVsRest algorithm, allow using save / load to replace cache

    ### What changes were proposed in this pull request?

    For OneVsRest algorithm, allow using save / load to replace cache

    ### Why are the changes needed?

    Dataframe persisting is not well supported in certain cases, so we need a replacement.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50672 from WeichenXu123/one-vs-rest-cache.

    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/classification.py        | 127 +++++++++++++++--------
 python/pyspark/ml/tests/test_algorithms.py |  24 +++++-
 python/pyspark/ml/tests/test_tuning.py     |  24 ++++--
 python/pyspark/ml/util.py                  |  13 ++-
 4 files changed, 113 insertions(+), 75 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index df2cbd34a35d..3f9e3fa37f72 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -89,6 +89,7 @@ from pyspark.ml.util import (
     try_remote_read,
     try_remote_write,
     try_remote_attribute_relation,
+    _cache_spark_dataset,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
 from pyspark.ml.common import inherit_doc
@@ -3603,46 +3604,47 @@ class OneVsRest(
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
-        def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
-            indices = iter(range(numClasses))
-
-            def trainSingleClass() -> Tuple[int, CM]:
-                index = next(indices)
-
-                binaryLabelCol = "mc2b$" + str(index)
-                trainingDataset = multiclassLabeled.withColumn(
-                    binaryLabelCol,
-                    F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-                )
-                paramMap = dict(
-                    [
-                        (classifier.labelCol, binaryLabelCol),
-                        (classifier.featuresCol, featuresCol),
-                        (classifier.predictionCol, predictionCol),
-                    ]
-                )
-                if weightCol:
-                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-                return index, classifier.fit(trainingDataset, paramMap)
-
-            return [trainSingleClass] * numClasses
-
-        tasks = map(
-            inheritable_thread_target(dataset.sparkSession),
-            _oneClassFitTasks(numClasses),
-        )
-        pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
-
-        subModels = [None] * numClasses
-        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
-            assert subModels is not None
-            subModels[j] = subModel
+        with _cache_spark_dataset(
+            multiclassLabeled,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as multiclassLabeled:
+
+            def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
+                indices = iter(range(numClasses))
+
+                def trainSingleClass() -> Tuple[int, CM]:
+                    index = next(indices)
+
+                    binaryLabelCol = "mc2b$" + str(index)
+                    trainingDataset = multiclassLabeled.withColumn(
+                        binaryLabelCol,
+                        F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                    )
+                    paramMap = dict(
+                        [
+                            (classifier.labelCol, binaryLabelCol),
+                            (classifier.featuresCol, featuresCol),
+                            (classifier.predictionCol, predictionCol),
+                        ]
+                    )
+                    if weightCol:
+                        paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                    return index, classifier.fit(trainingDataset, paramMap)
+
+                return [trainSingleClass] * numClasses
+
+            tasks = map(
+                inheritable_thread_target(dataset.sparkSession),
+                _oneClassFitTasks(numClasses),
+            )
+            pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

-        if handlePersistence:
-            multiclassLabeled.unpersist()
+            subModels = [None] * numClasses
+            for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+                assert subModels is not None
+                subModels[j] = subModel

         return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))

@@ -3868,32 +3870,31 @@ class OneVsRestModel(
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            newDataset.persist(StorageLevel.MEMORY_AND_DISK)
-
-        # update the accumulator column with the result of prediction of models
-        aggregatedDataset = newDataset
-        for index, model in enumerate(self.models):
-            rawPredictionCol = self.getRawPredictionCol()
-
-            columns = origCols + [rawPredictionCol, accColName]
-
-            # add temporary column to store intermediate scores and update
-            tmpColName = "mbc$tmp" + str(uuid.uuid4())
-            transformedDataset = model.transform(aggregatedDataset).select(*columns)
-            updatedDataset = transformedDataset.withColumn(
-                tmpColName,
-                F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
-            )
-            newColumns = origCols + [tmpColName]
-
-            # switch out the intermediate column with the accumulator column
-            aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
-                tmpColName, accColName
-            )
+        with _cache_spark_dataset(
+            newDataset,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as newDataset:
+            # update the accumulator column with the result of prediction of models
+            aggregatedDataset = newDataset
+            for index, model in enumerate(self.models):
+                rawPredictionCol = self.getRawPredictionCol()
+
+                columns = origCols + [rawPredictionCol, accColName]
+
+                # add temporary column to store intermediate scores and update
+                tmpColName = "mbc$tmp" + str(uuid.uuid4())
+                transformedDataset = model.transform(aggregatedDataset).select(*columns)
+                updatedDataset = transformedDataset.withColumn(
+                    tmpColName,
+                    F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
+                )
+                newColumns = origCols + [tmpColName]

-        if handlePersistence:
-            newDataset.unpersist()
+                # switch out the intermediate column with the accumulator column
+                aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+                    tmpColName, accColName
+                )

         if self.getRawPredictionCol():
             aggregatedDataset = aggregatedDataset.withColumn(

diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index d0e2600a9a8b..0f5deab4e093 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 from shutil import rmtree
 import tempfile
 import unittest
@@ -154,6 +154,28 @@ class OneVsRestTests(SparkSessionTestCase):
         ovr2 = OneVsRest(classifier=dt, weightCol="weight")
         self.assertIsNotNone(ovr2.fit(df))

+    def test_tmp_dfs_cache(self):
+        from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
+
+        with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+            os.environ[_SPARKML_TEMP_DFS_PATH] = d
+            try:
+                df = self.spark.createDataFrame(
+                    [
+                        (0.0, Vectors.dense(1.0, 0.8)),
+                        (1.0, Vectors.sparse(2, [], [])),
+                        (2.0, Vectors.dense(0.5, 0.5)),
+                    ],
+                    ["label", "features"],
+                )
+                lr = LogisticRegression(maxIter=5, regParam=0.01)
+                ovr = OneVsRest(classifier=lr, parallelism=1)
+                model = ovr.fit(df)
+                model.transform(df)
+                assert len(os.listdir(d)) == 0
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+

 class KMeansTests(SparkSessionTestCase):
     def test_kmeans_cosine_distance(self):

diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 5885761272a2..947c599b3cf2 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -73,12 +73,15 @@ class TuningTestsMixin:
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            tvs_model2 = tvs.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
-            self.assertTrue(
-                np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
-            )
+            try:
+                tvs_model2 = tvs.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
+                self.assertTrue(
+                    np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
+                )
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)

         # save & load
         with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
@@ -131,9 +134,12 @@ class TuningTestsMixin:
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            model2 = cv.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            try:
+                model2 = cv.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)

         output = model.transform(dataset)
         self.assertEqual(

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index c3e4b1ba87f9..6abadec74e63 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -41,6 +41,7 @@ from pyspark import since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
 from pyspark.sql.utils import is_remote
+from pyspark.storagelevel import StorageLevel
 from pyspark.util import VersionUtils

 if TYPE_CHECKING:
@@ -1138,7 +1139,15 @@ def _remove_dfs_dir(path: str, spark_session: "SparkSession") -> None:


 @contextmanager
-def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
+def _cache_spark_dataset(
+    dataset: "DataFrame",
+    storageLevel: "StorageLevel" = StorageLevel.MEMORY_AND_DISK_DESER,
+    enable: bool = True,
+) -> Iterator[Any]:
+    if not enable:
+        yield dataset
+        return
+
     spark_session = dataset._session
     tmp_dfs_path = os.environ.get(_SPARKML_TEMP_DFS_PATH)

@@ -1150,7 +1159,7 @@ def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
             finally:
                 _remove_dfs_dir(tmp_cache_path, spark_session)
         else:
-            dataset.cache()
+            dataset.persist(storageLevel)
             try:
             yield dataset
         finally:
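
A minimal usage sketch of the new code path, mirroring the test_tmp_dfs_cache test added by this patch: pointing the _SPARKML_TEMP_DFS_PATH environment variable at a scratch directory makes _cache_spark_dataset save and reload the intermediate DataFrame instead of persisting it. The local SparkSession and the local temporary directory below are illustrative assumptions only; on a real cluster the path would normally be a DFS location reachable by every executor.

    import os
    import tempfile

    from pyspark.ml.classification import LogisticRegression, OneVsRest
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    # Illustrative scratch location; in practice this should usually be a DFS
    # path (HDFS, object store, ...) that all executors can read.
    with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
        os.environ[_SPARKML_TEMP_DFS_PATH] = d
        try:
            df = spark.createDataFrame(
                [
                    (0.0, Vectors.dense(1.0, 0.8)),
                    (1.0, Vectors.sparse(2, [], [])),
                    (2.0, Vectors.dense(0.5, 0.5)),
                ],
                ["label", "features"],
            )
            lr = LogisticRegression(maxIter=5, regParam=0.01)
            ovr = OneVsRest(classifier=lr, parallelism=1)
            # fit() saves/loads the intermediate training DataFrame instead of persisting it
            model = ovr.fit(df)
            model.transform(df)
            # the temporary files are removed again once fit()/transform() return
            assert len(os.listdir(d)) == 0
        finally:
            os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)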