This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8af0331bf578 [SPARK-54742][ML][CONNECT][TESTS] Add tests for connect 
ML model offloading
8af0331bf578 is described below

commit 8af0331bf5783fe0bbb94515b39f2790389b362b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 18 10:51:23 2025 +0800

    [SPARK-54742][ML][CONNECT][TESTS] Add tests for connect ML model offloading
    
    ### What changes were proposed in this pull request?
    Add tests for connect ML model offloading
    
    ### Why are the changes needed?
    The old ones were prone to deadlock and were removed.
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #53515 from zhengruifeng/test_offload.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 .../tests/connect/test_connect_model_offloading.py | 165 +++++++++++++++++++++
 2 files changed, 166 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ee2a953a7402..e6079ba21e9d 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1233,6 +1233,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_connect_classification",
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
+        "pyspark.ml.tests.connect.test_connect_model_offloading",
         "pyspark.ml.tests.connect.test_parity_als",
         "pyspark.ml.tests.connect.test_parity_fpm",
         "pyspark.ml.tests.connect.test_parity_classification",
diff --git a/python/pyspark/ml/tests/connect/test_connect_model_offloading.py 
b/python/pyspark/ml/tests/connect/test_connect_model_offloading.py
new file mode 100644
index 000000000000..aa0e569e3f75
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_connect_model_offloading.py
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import (
+    LinearSVC,
+    LinearSVCSummary,
+    LinearSVCTrainingSummary,
+)
+from pyspark.ml.regression import (
+    LinearRegression,
+    LinearRegressionSummary,
+    LinearRegressionTrainingSummary,
+)
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ModelOffloadingTests(ReusedConnectTestCase):
+    def test_linear_svc_offloading(self):
+        # force clean up the ml cache
+        self.spark.client._cleanup_ml_cache()
+
+        df = (
+            self.spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
+        vec = Vectors.dense(0.0, 5.0)
+
+        svc = LinearSVC(maxIter=1, regParam=1.0)
+        self.assertEqual(svc.getMaxIter(), 1)
+        self.assertEqual(svc.getRegParam(), 1.0)
+
+        model = svc.fit(df)
+
+        # model is cached!
+        # 'id: xxx, obj: class 
org.apache.spark.ml.classification.LinearSVCModel, size: xxx'
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 1, cached)
+        self.assertIn("class 
org.apache.spark.ml.classification.LinearSVCModel", cached[0])
+
+        self.assertEqual(svc.uid, model.uid)
+        self.assertEqual(model.numClasses, 2)
+        self.assertEqual(model.predict(vec), 1.0)
+
+        self.assertTrue(model.hasSummary)
+        summary = model.summary()
+
+        self.assertIsInstance(summary, LinearSVCSummary)
+        self.assertIsInstance(summary, LinearSVCTrainingSummary)
+        self.assertEqual(summary.labels, [0.0, 1.0])
+
+        # model is offloaded!
+        self.spark.client._delete_ml_cache([model._java_obj._ref_id], 
evict_only=True)
+
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 0, cached)
+
+        self.assertEqual(svc.uid, model.uid)
+        self.assertEqual(model.numClasses, 2)
+        self.assertEqual(model.predict(vec), 1.0)
+
+        self.assertTrue(model.hasSummary)
+        summary = model.summary()
+
+        self.assertIsInstance(summary, LinearSVCSummary)
+        self.assertIsInstance(summary, LinearSVCTrainingSummary)
+        self.assertEqual(summary.labels, [0.0, 1.0])
+
+    def test_linear_regression_offloading(self):
+        # force clean up the ml cache
+        self.spark.client._cleanup_ml_cache()
+
+        df = (
+            self.spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
+        vec = Vectors.dense(0.0, 5.0)
+
+        lr = LinearRegression(
+            regParam=0.0,
+            maxIter=2,
+            solver="normal",
+            weightCol="weight",
+        )
+        self.assertEqual(lr.getRegParam(), 0)
+        self.assertEqual(lr.getMaxIter(), 2)
+
+        model = lr.fit(df)
+
+        # model is cached!
+        # 'id: xxx, obj: class 
org.apache.spark.ml.regression.LinearRegressionModel, size: xxx'
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 1, cached)
+        self.assertIn("class 
org.apache.spark.ml.regression.LinearRegressionModel", cached[0])
+
+        self.assertEqual(lr.uid, model.uid)
+        self.assertEqual(model.numFeatures, 2)
+        self.assertTrue(np.allclose(model.predict(vec), 0.21249999999999963, 
atol=1e-4))
+
+        summary = model.summary
+        self.assertTrue(isinstance(summary, LinearRegressionSummary))
+        self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
+        self.assertEqual(summary.predictions.count(), 4)
+
+        # model is offloaded!
+        self.spark.client._delete_ml_cache([model._java_obj._ref_id], 
evict_only=True)
+
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 0, cached)
+
+        self.assertEqual(lr.uid, model.uid)
+        self.assertEqual(model.numFeatures, 2)
+        self.assertTrue(np.allclose(model.predict(vec), 0.21249999999999963, 
atol=1e-4))
+
+        summary = model.summary
+        self.assertTrue(isinstance(summary, LinearRegressionSummary))
+        self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
+        self.assertEqual(summary.predictions.count(), 4)
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_connect_model_offloading import *  # 
noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to