This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new d18c899be771 [SPARK-50929][ML][PYTHON][CONNECT] Support `LDA` on Connect
d18c899be771 is described below
commit d18c899be7714aa3bf63118a989078a3c32091bb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 27 09:19:52 2025 +0800
[SPARK-50929][ML][PYTHON][CONNECT] Support `LDA` on Connect
### What changes were proposed in this pull request?
Support `LDA` on Connect
### Why are the changes needed?
feature parity
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49679 from zhengruifeng/ml_connect_lda.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit b6b00e87b00be9c8ca7103d006c900caf0cb032b)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Estimator | 1 +
.../services/org.apache.spark.ml.Transformer | 2 +
.../scala/org/apache/spark/ml/clustering/LDA.scala | 4 +
python/pyspark/ml/clustering.py | 1 +
python/pyspark/ml/tests/test_clustering.py | 139 ++++++++++++++++++++-
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 13 ++
6 files changed, 159 insertions(+), 1 deletion(-)
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 9c1a1f5a19a6..97526bf1a0c0 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -37,6 +37,7 @@ org.apache.spark.ml.regression.GBTRegressor
org.apache.spark.ml.clustering.KMeans
org.apache.spark.ml.clustering.BisectingKMeans
org.apache.spark.ml.clustering.GaussianMixture
+org.apache.spark.ml.clustering.LDA
# recommendation
org.apache.spark.ml.recommendation.ALS
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 3f1ae52aaaf6..c6faa54c147b 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -53,6 +53,8 @@ org.apache.spark.ml.regression.GBTRegressionModel
org.apache.spark.ml.clustering.KMeansModel
org.apache.spark.ml.clustering.BisectingKMeansModel
org.apache.spark.ml.clustering.GaussianMixtureModel
+org.apache.spark.ml.clustering.DistributedLDAModel
+org.apache.spark.ml.clustering.LocalLDAModel
# recommendation
org.apache.spark.ml.recommendation.ALSModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index b3d3c84db051..3fce96fbfbb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -617,6 +617,8 @@ class LocalLDAModel private[ml] (
sparkSession: SparkSession)
extends LDAModel(uid, vocabSize, sparkSession) {
+ private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null)
+
oldLocalModel.setSeed(getSeed)
@Since("1.6.0")
@@ -713,6 +715,8 @@ class DistributedLDAModel private[ml] (
private var oldLocalModelOption: Option[OldLocalLDAModel])
extends LDAModel(uid, vocabSize, sparkSession) {
+  private[ml] def this() = this(Identifiable.randomUID("lda"), -1, null, null, None)
+
override private[clustering] def oldLocalModel: OldLocalLDAModel = {
if (oldLocalModelOption.isEmpty) {
oldLocalModelOption = Some(oldDistributedModel.toLocal)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 6cd508a9e950..8166cd41c834 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -1511,6 +1511,7 @@ class LDAModel(JavaModel, _LDAParams):
return self._call_java("logPerplexity", dataset)
@since("2.0.0")
+ @try_remote_attribute_relation
def describeTopics(self, maxTermsPerTopic: int = 10) -> DataFrame:
"""
Return the topics described by their top-weighted terms.
diff --git a/python/pyspark/ml/tests/test_clustering.py
b/python/pyspark/ml/tests/test_clustering.py
index e6013d10fa8e..9a26b746f027 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -20,7 +20,7 @@ import unittest
import numpy as np
-from pyspark.ml.linalg import Vectors
+from pyspark.ml.linalg import Vectors, SparseVector
from pyspark.sql import SparkSession
from pyspark.ml.clustering import (
KMeans,
@@ -32,6 +32,10 @@ from pyspark.ml.clustering import (
GaussianMixture,
GaussianMixtureModel,
GaussianMixtureSummary,
+ LDA,
+ LDAModel,
+ LocalLDAModel,
+ DistributedLDAModel,
)
@@ -264,6 +268,139 @@ class ClusteringTestsMixin:
model2 = GaussianMixtureModel.load(d)
self.assertEqual(str(model), str(model2))
+ def test_local_lda(self):
+ spark = self.spark
+ df = (
+ spark.createDataFrame(
+ [
+ [1, Vectors.dense([0.0, 1.0])],
+ [2, SparseVector(2, {0: 1.0})],
+ ],
+ ["id", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("id")
+ )
+
+ lda = LDA(k=2, optimizer="online", seed=1)
+ lda.setMaxIter(1)
+ self.assertEqual(lda.getK(), 2)
+ self.assertEqual(lda.getOptimizer(), "online")
+ self.assertEqual(lda.getMaxIter(), 1)
+ self.assertEqual(lda.getSeed(), 1)
+
+ model = lda.fit(df)
+ self.assertEqual(lda.uid, model.uid)
+ self.assertIsInstance(model, LDAModel)
+ self.assertIsInstance(model, LocalLDAModel)
+ self.assertNotIsInstance(model, DistributedLDAModel)
+ self.assertFalse(model.isDistributed())
+
+ dc = model.estimatedDocConcentration()
+ self.assertTrue(np.allclose(dc.toArray(), [0.5, 0.5], atol=1e-4), dc)
+ topics = model.topicsMatrix()
+ self.assertTrue(
+ np.allclose(
+                topics.toArray(), [[1.20296728, 1.15740442], [0.99357675, 1.02993164]], atol=1e-4
+ ),
+ topics,
+ )
+
+ ll = model.logLikelihood(df)
+ self.assertTrue(np.allclose(ll, -3.2125122434040088, atol=1e-4), ll)
+ lp = model.logPerplexity(df)
+ self.assertTrue(np.allclose(lp, 1.6062561217020044, atol=1e-4), lp)
+ dt = model.describeTopics()
+ self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"])
+ self.assertEqual(dt.count(), 2)
+
+ # LocalLDAModel specific methods
+ self.assertEqual(model.vocabSize(), 2)
+
+ output = model.transform(df)
+ expected_cols = ["id", "features", "topicDistribution"]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 2)
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="local_lda") as d:
+ lda.write().overwrite().save(d)
+ lda2 = LDA.load(d)
+ self.assertEqual(str(lda), str(lda2))
+
+ model.write().overwrite().save(d)
+ model2 = LocalLDAModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
+ def test_distributed_lda(self):
+ spark = self.spark
+ df = (
+ spark.createDataFrame(
+ [
+ [1, Vectors.dense([0.0, 1.0])],
+ [2, SparseVector(2, {0: 1.0})],
+ ],
+ ["id", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("id")
+ )
+
+ lda = LDA(k=2, optimizer="em", seed=1)
+ lda.setMaxIter(1)
+
+ self.assertEqual(lda.getK(), 2)
+ self.assertEqual(lda.getOptimizer(), "em")
+ self.assertEqual(lda.getMaxIter(), 1)
+ self.assertEqual(lda.getSeed(), 1)
+
+ model = lda.fit(df)
+ self.assertEqual(lda.uid, model.uid)
+ self.assertIsInstance(model, LDAModel)
+ self.assertNotIsInstance(model, LocalLDAModel)
+ self.assertIsInstance(model, DistributedLDAModel)
+
+ dc = model.estimatedDocConcentration()
+ self.assertTrue(np.allclose(dc.toArray(), [26.0, 26.0], atol=1e-4), dc)
+ topics = model.topicsMatrix()
+ self.assertTrue(
+ np.allclose(
+                topics.toArray(), [[0.39149926, 0.60850074], [0.60991237, 0.39008763]], atol=1e-4
+ ),
+ topics,
+ )
+
+ ll = model.logLikelihood(df)
+ self.assertTrue(np.allclose(ll, -3.719138517085772, atol=1e-4), ll)
+ lp = model.logPerplexity(df)
+ self.assertTrue(np.allclose(lp, 1.859569258542886, atol=1e-4), lp)
+
+ dt = model.describeTopics()
+ self.assertEqual(dt.columns, ["topic", "termIndices", "termWeights"])
+ self.assertEqual(dt.count(), 2)
+
+ # DistributedLDAModel specific methods
+ ll = model.trainingLogLikelihood()
+ self.assertTrue(np.allclose(ll, -1.3847360462201639, atol=1e-4), ll)
+ lp = model.logPrior()
+ self.assertTrue(np.allclose(lp, -69.59963186898915, atol=1e-4), lp)
+ model.getCheckpointFiles()
+
+ output = model.transform(df)
+ expected_cols = ["id", "features", "topicDistribution"]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 2)
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="distributed_lda") as d:
+ lda.write().overwrite().save(d)
+ lda2 = LDA.load(d)
+ self.assertEqual(str(lda), str(lda2))
+
+ model.write().overwrite().save(d)
+ model2 = DistributedLDAModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index fbcbf8f3f204..9bf3c632b219 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -594,6 +594,19 @@ private[ml] object MLUtils {
classOf[GaussianMixtureModel],
Set("predict", "numFeatures", "weights", "gaussians", "predictProbability", "gaussiansDF")),
(classOf[GaussianMixtureSummary], Set("probability", "probabilityCol", "logLikelihood")),
+ (
+ classOf[LDAModel],
+ Set(
+ "estimatedDocConcentration",
+ "topicsMatrix",
+ "isDistributed",
+ "logLikelihood",
+ "logPerplexity",
+ "describeTopics")),
+ (classOf[LocalLDAModel], Set("vocabSize")),
+ (
+ classOf[DistributedLDAModel],
+ Set("trainingLogLikelihood", "logPrior", "getCheckpointFiles")),
// Recommendation Models
(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]