This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 906c646a9473 [SPARK-54741][ML][CONNECT][TESTS] Restore ClusteringParityTests
906c646a9473 is described below
commit 906c646a9473e69119e03c55e08b751fa7aba6d4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 18 09:39:53 2025 +0800
[SPARK-54741][ML][CONNECT][TESTS] Restore ClusteringParityTests
### What changes were proposed in this pull request?
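Restore ClusteringParityTests: re-enable the skipped Spark Connect parity suite, wrap the remote result of `DistributedLDAModel.toLocal` in `RemoteModelRef`, and drop the model-offloading (ML cache eviction) checks from the KMeans and GaussianMixture tests.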
### Why are the changes needed?
To recover test coverage; the suite had been skipped due to flakiness (SPARK-52764).
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #53514 from zhengruifeng/restore_clu.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/clustering.py | 7 +++++--
.../pyspark/ml/tests/connect/test_parity_clustering.py | 2 --
python/pyspark/ml/tests/test_clustering.py | 18 ------------------
3 files changed, 5 insertions(+), 22 deletions(-)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 0fc2b34d1748..f6543e707680 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -1542,9 +1542,12 @@ class DistributedLDAModel(LDAModel, JavaMLReadable["DistributedLDAModel"], JavaM
.. warning:: This involves collecting a large :py:func:`topicsMatrix`
to the driver.
"""
- model = LocalLDAModel(self._call_java("toLocal"))
if is_remote():
- return model
+ from pyspark.ml.util import RemoteModelRef
+
+ return LocalLDAModel(RemoteModelRef(self._call_java("toLocal")))
+
+ model = LocalLDAModel(self._call_java("toLocal"))
# SPARK-10931: Temporary fix to be removed once LDAModel defines Params
model._create_params_from_java()
diff --git a/python/pyspark/ml/tests/connect/test_parity_clustering.py b/python/pyspark/ml/tests/connect/test_parity_clustering.py
index bbfd2a2aea80..99714b0d6962 100644
--- a/python/pyspark/ml/tests/connect/test_parity_clustering.py
+++ b/python/pyspark/ml/tests/connect/test_parity_clustering.py
@@ -21,8 +21,6 @@ from pyspark.ml.tests.test_clustering import ClusteringTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-# TODO(SPARK-52764): Re-enable this test after fixing the flakiness.
-@unittest.skip("Disabled due to flakiness, should be enabled after fixing the issue")
class ClusteringParityTests(ClusteringTestsMixin, ReusedConnectTestCase):
pass
diff --git a/python/pyspark/ml/tests/test_clustering.py
b/python/pyspark/ml/tests/test_clustering.py
index d624b6398881..c1ec03b5ecc2 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -37,7 +37,6 @@ from pyspark.ml.clustering import (
DistributedLDAModel,
PowerIterationClustering,
)
-from pyspark.sql import is_remote
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -107,18 +106,6 @@ class ClusteringTestsMixin:
# check summary before model offloading occurs
check_summary()
- if is_remote():
- self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
- # check summary "try_remote_call" path after model offloading
occurs
- self.assertEqual(model.summary.numIter, 2)
-
- self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
- # check summary "invoke_remote_attribute_relation" path after
model offloading occurs
- self.assertEqual(model.summary.cluster.count(), 6)
-
- self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
- check_summary()
-
# save & load
with tempfile.TemporaryDirectory(prefix="kmeans_model") as d:
km.write().overwrite().save(d)
@@ -323,11 +310,6 @@ class ClusteringTestsMixin:
self.assertEqual(summary.probability.columns, ["probability"])
self.assertEqual(summary.predictions.count(), 6)
- check_summary()
- if is_remote():
- self.spark.client._delete_ml_cache([model._java_obj._ref_id],
evict_only=True)
- check_summary()
-
# save & load
with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d:
gmm.write().overwrite().save(d)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]