This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a28977e25801 [SPARK-51711][ML][PYTHON][CONNECT] Propagates the active remote spark session to new threads to fix CrossValidator a28977e25801 is described below commit a28977e25801bf30ea962b55339da6d2eed24863 Author: Xi Lyu <xi....@databricks.com> AuthorDate: Wed Apr 23 19:38:37 2025 +0800 [SPARK-51711][ML][PYTHON][CONNECT] Propagates the active remote spark session to new threads to fix CrossValidator ### What changes were proposed in this pull request? In SparkML with Spark Connect, the `_parallelFitTasks` fails when running `CrossValidator` fitting, as the active remote spark session is not properly propagated to the new threads. Before the PR, this code will fail in the line `cvModel = cv.fit(data)`: ``` from pyspark.ml.classification import RandomForestClassifier from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.tuning import ParamGridBuilder, CrossValidator from pyspark.ml.linalg import Vectors data = spark.createDataFrame([ (Vectors.dense(1.0, 2.0), 0), (Vectors.dense(2.0, 3.0), 1), (Vectors.dense(1.5, 2.5), 0), (Vectors.dense(3.0, 4.0), 1), (Vectors.dense(1.1, 2.1), 0), (Vectors.dense(2.5, 3.5), 1), ], ["features", "label"]) rf = RandomForestClassifier(labelCol="label", featuresCol="features") evaluator = BinaryClassificationEvaluator(labelCol="label") paramGrid = (ParamGridBuilder() .addGrid(rf.maxDepth, [2]) .addGrid(rf.numTrees, [5, 10]) .build()) cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3) cvModel = cv.fit(data) bestModel = cvModel.bestModel print(f"Best maxDepth: {bestModel.getMaxDepth()}") print(f"Best maxBins: {bestModel.getMaxBins()}") print(f"Best numTrees: {bestModel.getNumTrees}") ``` It fails because the active remote spark session is not properly set on that thread: ``` File ~/spark/python/pyspark/ml/util.py:250, in try_remote_call.<locals>.wrapped(self, name, *args) 247 from pyspark.ml.connect.serialize import serialize, deserialize 249 session = SparkSession.getActiveSession() --> 250 assert session is not None 251 assert isinstance(self._java_obj, str) 252 methods, obj_ref = _extract_id_methods(self._java_obj) AssertionError: ``` With this fix, the above code snippet works correctly. ### Why are the changes needed? It fixes a bug with CrossValidator fitting. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50507 from xi-db/fix-parallelFitTasks. Lead-authored-by: Xi Lyu <xi....@databricks.com> Co-authored-by: Xi Lyu <159039256+xi...@users.noreply.github.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/ml/connect/tuning.py | 2 +- python/pyspark/ml/tests/test_tuning.py | 25 ++++++++++++++++++++++++- python/pyspark/util.py | 6 ++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 2bbc63ef4dc2..1ef055d25007 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -434,7 +434,7 @@ class CrossValidator( tasks = _parallelFitTasks(est, train, eva, validation, epm) if not is_remote(): - tasks = list(map(inheritable_thread_target, tasks)) + tasks = list(map(inheritable_thread_target(dataset.sparkSession), tasks)) for j, metric in pool.imap_unordered(lambda f: f(), tasks): metrics_all[i][j] = metric diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index 0b79373f9991..5885761272a2 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -22,7 +22,7 @@ import numpy as np from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.linalg import Vectors -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import LogisticRegression, RandomForestClassifier from pyspark.ml.tuning import ( ParamGridBuilder, CrossValidator, @@ -94,6 +94,7 @@ class TuningTestsMixin: self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator())) self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator())) + @unittest.skip("Disabled due to a Python side reference count issue in _parallelFitTasks.") def test_cross_validator(self): dataset = self.spark.createDataFrame( [ @@ -246,6 +247,28 @@ class TuningTestsMixin: with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"): cv.fit(dataset_with_folds) + def test_crossvalidator_with_random_forest_classifier(self): + dataset = self.spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0), 0), + (Vectors.dense(2.0, 3.0), 1), + (Vectors.dense(1.5, 2.5), 0), + (Vectors.dense(3.0, 4.0), 1), + (Vectors.dense(1.1, 2.1), 0), + (Vectors.dense(2.5, 3.5), 1), + ], + ["features", "label"], + ) + rf = RandomForestClassifier(labelCol="label", featuresCol="features") + evaluator = BinaryClassificationEvaluator(labelCol="label") + paramGrid = ( + ParamGridBuilder().addGrid(rf.maxDepth, [2]).addGrid(rf.numTrees, [5, 10]).build() + ) + cv = CrossValidator( + estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3 + ) + cv.fit(dataset) + class TuningTests(TuningTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 2a2bd8d6aea7..605f4f070b5a 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -397,6 +397,12 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non @functools.wraps(ff) def inner(*args: Any, **kwargs: Any) -> Any: + # Propagates the active remote spark session to the current thread. + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + + RemoteSparkSession._set_default_and_active_session( + session # type: ignore[arg-type] + ) # Set thread locals in child thread. for attr, value in session_client_thread_local_attrs: setattr( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org