This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 40ccc0182eee [SPARK-54752][ML][CONNECT][TESTS] Test model offloading of LDA and FPGrowth
40ccc0182eee is described below

commit 40ccc0182eee1fd42553af49a267ce9c4cd3ba95
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 18 19:01:06 2025 +0800

    [SPARK-54752][ML][CONNECT][TESTS] Test model offloading of LDA and FPGrowth
    
    ### What changes were proposed in this pull request?
    Test model offloading of LDA and FPGrowth
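    
    For orientation, a condensed sketch of the round trip each new test asserts (assumes a Spark Connect session `spark`; the underscore-prefixed client helpers are internal APIs used verbatim in this file, and the data is illustrative):
    
    ```python
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.classification import LinearSVC

    df = spark.createDataFrame(
        [(1.0, Vectors.dense(1.0, 2.0)), (0.0, Vectors.dense(-1.0, -2.0))],
        ["label", "features"],
    )
    model = LinearSVC(maxIter=1).fit(df)

    # the fitted model is held in the server-side ML cache
    assert len(spark.client._get_ml_cache_info()) >= 1

    # evict (offload) it without deleting it ...
    spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)

    # ... later use should reload it transparently
    model.predict(Vectors.dense(1.0, 2.0))
    ```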
    
    ### Why are the changes needed?
    To improve test coverage: the two models contain DataFrame-based coefficients, which differ from the previous test cases (see the sketch below).
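    
    A minimal, hypothetical sketch of such a DataFrame-backed attribute (assumes a running SparkSession `spark`; not part of this patch):
    
    ```python
    # Unlike LinearRegressionModel.coefficients (a local Vector), the models
    # exercised here expose DataFrame-valued attributes materialized server-side.
    from pyspark.sql import functions as sf
    from pyspark.ml.fpm import FPGrowth

    df = spark.createDataFrame(
        [["a b c"], ["a b"], ["b c"]], ["items"]
    ).select(sf.split("items", " ").alias("items"))

    model = FPGrowth(minSupport=0.5, minConfidence=0.5).fit(df)

    # both attributes are DataFrames, not local arrays:
    model.freqItemsets.show()      # columns: items, freq
    model.associationRules.show()  # antecedent, consequent, confidence, ...
    ```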
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #53524 from zhengruifeng/test_lda_offload.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../tests/connect/test_connect_model_offloading.py | 206 +++++++++++++++++----
 1 file changed, 170 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/ml/tests/connect/test_connect_model_offloading.py b/python/pyspark/ml/tests/connect/test_connect_model_offloading.py
index aa0e569e3f75..92ca8f808fb8 100644
--- a/python/pyspark/ml/tests/connect/test_connect_model_offloading.py
+++ b/python/pyspark/ml/tests/connect/test_connect_model_offloading.py
@@ -18,6 +18,7 @@ import unittest
 
 import numpy as np
 
+from pyspark.sql import functions as sf
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.classification import (
     LinearSVC,
@@ -29,6 +30,16 @@ from pyspark.ml.regression import (
     LinearRegressionSummary,
     LinearRegressionTrainingSummary,
 )
+from pyspark.ml.clustering import (
+    LDA,
+    LDAModel,
+    LocalLDAModel,
+    DistributedLDAModel,
+)
+from pyspark.ml.fpm import (
+    FPGrowth,
+    FPGrowthModel,
+)
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
@@ -58,39 +69,32 @@ class ModelOffloadingTests(ReusedConnectTestCase):
 
         model = svc.fit(df)
 
+        def check_model(m):
+            self.assertEqual(svc.uid, m.uid)
+            self.assertEqual(m.numClasses, 2)
+            self.assertEqual(m.predict(vec), 1.0)
+
+            self.assertTrue(m.hasSummary)
+            summary = m.summary()
+
+            self.assertIsInstance(summary, LinearSVCSummary)
+            self.assertIsInstance(summary, LinearSVCTrainingSummary)
+            self.assertEqual(summary.labels, [0.0, 1.0])
+
         # model is cached!
         # 'id: xxx, obj: class org.apache.spark.ml.classification.LinearSVCModel, size: xxx'
         cached = self.spark.client._get_ml_cache_info()
         self.assertEqual(len(cached), 1, cached)
         self.assertIn("class org.apache.spark.ml.classification.LinearSVCModel", cached[0])
 
-        self.assertEqual(svc.uid, model.uid)
-        self.assertEqual(model.numClasses, 2)
-        self.assertEqual(model.predict(vec), 1.0)
-
-        self.assertTrue(model.hasSummary)
-        summary = model.summary()
-
-        self.assertIsInstance(summary, LinearSVCSummary)
-        self.assertIsInstance(summary, LinearSVCTrainingSummary)
-        self.assertEqual(summary.labels, [0.0, 1.0])
+        check_model(model)
 
         # model is offloaded!
         self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
-
         cached = self.spark.client._get_ml_cache_info()
         self.assertEqual(len(cached), 0, cached)
 
-        self.assertEqual(svc.uid, model.uid)
-        self.assertEqual(model.numClasses, 2)
-        self.assertEqual(model.predict(vec), 1.0)
-
-        self.assertTrue(model.hasSummary)
-        summary = model.summary()
-
-        self.assertIsInstance(summary, LinearSVCSummary)
-        self.assertIsInstance(summary, LinearSVCTrainingSummary)
-        self.assertEqual(summary.labels, [0.0, 1.0])
+        check_model(model)
 
     def test_linear_regression_offloading(self):
         # force clean up the ml cache
@@ -122,35 +126,165 @@ class ModelOffloadingTests(ReusedConnectTestCase):
 
         model = lr.fit(df)
 
+        def check_model(m):
+            self.assertEqual(lr.uid, m.uid)
+            self.assertEqual(m.numFeatures, 2)
+            self.assertTrue(np.allclose(m.predict(vec), 0.21249999999999963, atol=1e-4))
+
+            summary = m.summary
+            self.assertTrue(isinstance(summary, LinearRegressionSummary))
+            self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
+            self.assertEqual(summary.predictions.count(), 4)
+
         # model is cached!
         # 'id: xxx, obj: class org.apache.spark.ml.regression.LinearRegressionModel, size: xxx'
         cached = self.spark.client._get_ml_cache_info()
         self.assertEqual(len(cached), 1, cached)
         self.assertIn("class org.apache.spark.ml.regression.LinearRegressionModel", cached[0])
 
-        self.assertEqual(lr.uid, model.uid)
-        self.assertEqual(model.numFeatures, 2)
-        self.assertTrue(np.allclose(model.predict(vec), 0.21249999999999963, atol=1e-4))
-
-        summary = model.summary
-        self.assertTrue(isinstance(summary, LinearRegressionSummary))
-        self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
-        self.assertEqual(summary.predictions.count(), 4)
+        check_model(model)
 
         # model is offloaded!
+        self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 0, cached)
+
+        check_model(model)
+
+    def test_lda_offloading(self):
+        # force clean up the ml cache
+        self.spark.client._cleanup_ml_cache()
+
+        df = (
+            self.spark.createDataFrame(
+                [
+                    [1, Vectors.dense([0.0, 1.0])],
+                    [2, Vectors.sparse(2, {0: 1.0})],
+                ],
+                ["id", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("id")
+        )
+
+        lda = LDA(k=2, optimizer="em", seed=1)
+        lda.setMaxIter(1)
+
+        self.assertEqual(lda.getK(), 2)
+        self.assertEqual(lda.getOptimizer(), "em")
+        self.assertEqual(lda.getMaxIter(), 1)
+        self.assertEqual(lda.getSeed(), 1)
+
+        model = lda.fit(df)
+
+        def check_model(m):
+            self.assertEqual(lda.uid, m.uid)
+            self.assertIsInstance(m, LDAModel)
+            self.assertNotIsInstance(m, LocalLDAModel)
+            self.assertIsInstance(m, DistributedLDAModel)
+            self.assertTrue(m.isDistributed())
+            self.assertEqual(m.vocabSize(), 2)
+
+            output = m.transform(df)
+            expected_cols = ["id", "features", "topicDistribution"]
+            self.assertEqual(output.columns, expected_cols)
+            self.assertEqual(output.count(), 2)
+
+        # model is cached!
+        # 'id: xxx, obj: class org.apache.spark.ml.clustering.DistributedLDAModel, size: xxx'
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 1, cached)
+        self.assertIn("class org.apache.spark.ml.clustering.DistributedLDAModel", cached[0])
+
+        check_model(model)
 
+        # both model and local_model are cached!
+        local_model = model.toLocal()
+        # 'id: xxx, obj: class org.apache.spark.ml.clustering.LocalLDAModel, size: xxx'
+        # 'id: xxx, obj: class org.apache.spark.ml.clustering.DistributedLDAModel, size: xxx'
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 2, cached)
+        self.assertTrue(
+            any("class org.apache.spark.ml.clustering.LocalLDAModel" in c for c in cached)
+        )
+        self.assertTrue(
+            any("class org.apache.spark.ml.clustering.DistributedLDAModel" in c for c in cached)
+        )
+
+        def check_local_model(m):
+            self.assertIsInstance(m, LDAModel)
+            self.assertIsInstance(m, LocalLDAModel)
+            self.assertNotIsInstance(m, DistributedLDAModel)
+            self.assertFalse(m.isDistributed())
+            self.assertEqual(m.vocabSize(), 2)
+
+            output = m.transform(df)
+            expected_cols = ["id", "features", "topicDistribution"]
+            self.assertEqual(output.columns, expected_cols)
+            self.assertEqual(output.count(), 2)
+
+        check_local_model(local_model)
+
+        # both model and local_model are offloaded!
+        self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+        self.spark.client._delete_ml_cache([local_model._java_obj._ref_id], evict_only=True)
         cached = self.spark.client._get_ml_cache_info()
         self.assertEqual(len(cached), 0, cached)
 
-        self.assertEqual(lr.uid, model.uid)
-        self.assertEqual(model.numFeatures, 2)
-        self.assertTrue(np.allclose(model.predict(vec), 0.21249999999999963, atol=1e-4))
+        check_model(model)
+        check_local_model(local_model)
+
+    def test_fp_growth_offloading(self):
+        # force clean up the ml cache
+        self.spark.client._cleanup_ml_cache()
+
+        df = self.spark.createDataFrame(
+            [
+                ["r z h k p"],
+                ["z y x w v u t s"],
+                ["s x o n r"],
+                ["x z y m t s q e"],
+                ["z"],
+                ["x z y r q t p"],
+            ],
+            ["items"],
+        ).select(sf.split("items", " ").alias("items"))
+
+        fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
+        fp.setNumPartitions(1)
+
+        model = fp.fit(df)
+
+        def check_model(m):
+            self.assertIsInstance(m, FPGrowthModel)
+            self.assertEqual(fp.uid, m.uid)
+            self.assertEqual(m.freqItemsets.columns, ["items", "freq"])
+            self.assertEqual(m.freqItemsets.count(), 54)
+
+            self.assertEqual(
+                m.associationRules.columns,
+                ["antecedent", "consequent", "confidence", "lift", "support"],
+            )
+            self.assertEqual(m.associationRules.count(), 89)
+
+            output = m.transform(df)
+            self.assertEqual(output.columns, ["items", "prediction"])
+            self.assertEqual(output.count(), 6)
+
+        # model is cached!
+        # 'id: xxx, obj: class org.apache.spark.ml.fpm.FPGrowthModel, size: xxx'
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 1, cached)
+        self.assertIn("class org.apache.spark.ml.fpm.FPGrowthModel", cached[0])
+
+        check_model(model)
+
+        # model is offloaded!
+        self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
+        cached = self.spark.client._get_ml_cache_info()
+        self.assertEqual(len(cached), 0, cached)
 
-        summary = model.summary
-        self.assertTrue(isinstance(summary, LinearRegressionSummary))
-        self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
-        self.assertEqual(summary.predictions.count(), 4)
+        check_model(model)
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
