This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ba74bf9b509 [SPARK-50898][ML][PYTHON][CONNECT] Support `FPGrowth` on connect
3ba74bf9b509 is described below

commit 3ba74bf9b509e1cddbda6bb4849782e26fa840ed
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 21 16:03:13 2025 +0800

    [SPARK-50898][ML][PYTHON][CONNECT] Support `FPGrowth` on connect
    
    ### What changes were proposed in this pull request?
    Support `FPGrowth` on connect
    
    ### Why are the changes needed?
    for feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new algorithms supported on connect
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49579 from zhengruifeng/ml_connect_fpm.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |  2 +
 .../services/org.apache.spark.ml.Estimator         |  4 +
 .../services/org.apache.spark.ml.Transformer       |  3 +
 .../scala/org/apache/spark/ml/fpm/FPGrowth.scala   |  2 +
 python/pyspark/ml/fpm.py                           |  4 +-
 .../pyspark/ml/tests/connect/test_parity_fpm.py    | 30 +++----
 python/pyspark/ml/tests/test_fpm.py                | 94 ++++++++++++++++++++++
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  4 +-
 8 files changed, 124 insertions(+), 19 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index cacd4a83bbe4..5fd3f7377276 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -664,6 +664,7 @@ pyspark_ml = Module(
         # unittests
         "pyspark.ml.tests.test_algorithms",
         "pyspark.ml.tests.test_als",
+        "pyspark.ml.tests.test_fpm",
         "pyspark.ml.tests.test_base",
         "pyspark.ml.tests.test_evaluation",
         "pyspark.ml.tests.test_feature",
@@ -1119,6 +1120,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
         "pyspark.ml.tests.connect.test_parity_als",
+        "pyspark.ml.tests.connect.test_parity_fpm",
         "pyspark.ml.tests.connect.test_parity_classification",
         "pyspark.ml.tests.connect.test_parity_regression",
         "pyspark.ml.tests.connect.test_parity_clustering",
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index a7d7d3da9df3..4046cca07dc0 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -39,3 +39,7 @@ org.apache.spark.ml.clustering.BisectingKMeans
 
 # recommendation
 org.apache.spark.ml.recommendation.ALS
+
+
+# fpm
+org.apache.spark.ml.fpm.FPGrowth
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 392115be98ba..7c10796f9a87 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -38,3 +38,6 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
 
 # recommendation
 org.apache.spark.ml.recommendation.ALSModel
+
+# fpm
+org.apache.spark.ml.fpm.FPGrowthModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index d054ea8ebdb4..d90124c62d54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -223,6 +223,8 @@ class FPGrowthModel private[ml] (
     private val numTrainingRecords: Long)
   extends Model[FPGrowthModel] with FPGrowthParams with MLWritable {
 
+  private[ml] def this() = this(Identifiable.randomUID("fpgrowth"), null, Map.empty, 0L)
+
   /** @group setParam */
   @Since("2.2.0")
   def setMinConfidence(value: Double): this.type = set(minConfidence, value)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 72fcfccf19e4..c068b5f26ba8 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -20,7 +20,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark import keyword_only, since
 from pyspark.sql import DataFrame
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable
+from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
 
@@ -126,6 +126,7 @@ class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["
 
     @property
     @since("2.2.0")
+    @try_remote_attribute_relation
     def freqItemsets(self) -> DataFrame:
         """
         DataFrame with two columns:
@@ -136,6 +137,7 @@ class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["
 
     @property
     @since("2.2.0")
+    @try_remote_attribute_relation
     def associationRules(self) -> DataFrame:
         """
         DataFrame with four columns:
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/python/pyspark/ml/tests/connect/test_parity_fpm.py
similarity index 50%
copy from mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
copy to python/pyspark/ml/tests/connect/test_parity_fpm.py
index a7d7d3da9df3..85ceba87a2f5 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/python/pyspark/ml/tests/connect/test_parity_fpm.py
@@ -15,27 +15,23 @@
 # limitations under the License.
 #
 
-# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
-# So register the supported estimator here if you're trying to add a new one.
+import unittest
 
-# classification
-org.apache.spark.ml.classification.LogisticRegression
-org.apache.spark.ml.classification.DecisionTreeClassifier
-org.apache.spark.ml.classification.RandomForestClassifier
-org.apache.spark.ml.classification.GBTClassifier
+from pyspark.ml.tests.test_fpm import FPMTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-# regression
-org.apache.spark.ml.regression.LinearRegression
-org.apache.spark.ml.regression.DecisionTreeRegressor
-org.apache.spark.ml.regression.RandomForestRegressor
-org.apache.spark.ml.regression.GBTRegressor
+class FPMParityTests(FPMTestsMixin, ReusedConnectTestCase):
+    pass
 
 
-# clustering
-org.apache.spark.ml.clustering.KMeans
-org.apache.spark.ml.clustering.BisectingKMeans
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_fpm import *  # noqa: F401
 
+    try:
+        import xmlrunner  # type: ignore[import]
 
-# recommendation
-org.apache.spark.ml.recommendation.ALS
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
new file mode 100644
index 000000000000..8db35158978d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+from pyspark.sql import SparkSession
+import pyspark.sql.functions as sf
+from pyspark.ml.fpm import (
+    FPGrowth,
+    FPGrowthModel,
+)
+
+
+class FPMTestsMixin:
+    def test_fp_growth(self):
+        df = self.spark.createDataFrame(
+            [
+                ["r z h k p"],
+                ["z y x w v u t s"],
+                ["s x o n r"],
+                ["x z y m t s q e"],
+                ["z"],
+                ["x z y r q t p"],
+            ],
+            ["items"],
+        ).select(sf.split("items", " ").alias("items"))
+
+        fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
+        fp.setNumPartitions(1)
+        self.assertEqual(fp.getMinSupport(), 0.2)
+        self.assertEqual(fp.getMinConfidence(), 0.7)
+        self.assertEqual(fp.getNumPartitions(), 1)
+
+        # Estimator save & load
+        with tempfile.TemporaryDirectory(prefix="fp_growth") as d:
+            fp.write().overwrite().save(d)
+            fp2 = FPGrowth.load(d)
+            self.assertEqual(str(fp), str(fp2))
+
+        model = fp.fit(df)
+
+        self.assertEqual(model.freqItemsets.columns, ["items", "freq"])
+        self.assertEqual(model.freqItemsets.count(), 54)
+
+        self.assertEqual(
+            model.associationRules.columns,
+            ["antecedent", "consequent", "confidence", "lift", "support"],
+        )
+        self.assertEqual(model.associationRules.count(), 89)
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["items", "prediction"])
+        self.assertEqual(output.count(), 6)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="fp_growth_model") as d:
+            model.write().overwrite().save(d)
+            model2 = FPGrowthModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
+
+class FPMTests(FPMTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_fpm import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4e93aec47ef0..b85bc6771f8e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -500,7 +500,9 @@ private[ml] object MLUtils {
     "recommendForAllUsers", // ALSModel
     "recommendForAllItems", // ALSModel
     "recommendForUserSubset", // ALSModel
-    "recommendForItemSubset" // ALSModel
+    "recommendForItemSubset", // ALSModel
+    "associationRules", // FPGrowthModel
+    "freqItemsets" // FPGrowthModel
   )
 
   def invokeMethodAllowed(obj: Object, methodName: String): Object = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to