This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8a96d59fa810 [SPARK-54706][ML][CONNECT] Make DistributedLDAModel work with local file system
8a96d59fa810 is described below
commit 8a96d59fa810982d911641a2099fe0c662dd0247
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 17 10:33:39 2025 +0800
[SPARK-54706][ML][CONNECT] Make DistributedLDAModel work with local file system
### What changes were proposed in this pull request?
Make DistributedLDAModel work with the local file system: serialize the underlying mllib model's internal state (metadata, global topic totals, graph vertices, and graph edges) as DataFrames, so the model can be saved to and loaded from local paths.
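At a high level, the patch flattens the GraphX-backed state of the old mllib model into four DataFrames on save and rebuilds it on load, via the two new internal helpers in this patch. A minimal sketch of the round trip (simplified; the real writer also persists the ml-level metadata alongside):

```scala
// Save side: flatten the model into four DataFrames
// (metadata, global topic totals, graph vertices, graph edges).
val Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF) =
  oldDistributedModel.toInternals

// Load side: rebuild an equivalent model from the same four DataFrames.
val restored = OldDistributedLDAModel.fromInternals(
  Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF))
```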
### Why are the changes needed?
This is required to support DistributedLDAModel on Spark Connect, where model cache offloading saves models to local file system paths.
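For context, a sketch of the user-facing flow this unblocks on Connect, assuming a prepared `dataset` DataFrame with a features column (paths and variable names are illustrative):

```scala
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// The "em" optimizer yields a DistributedLDAModel, which previously could
// not be saved or loaded when Connect offloads models to local paths.
val lda = new LDA().setK(3).setMaxIter(5).setOptimizer("em")
val model = lda.fit(dataset).asInstanceOf[DistributedLDAModel]

model.write.overwrite().save("/tmp/distributed-lda")
val loaded = DistributedLDAModel.load("/tmp/distributed-lda")
```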
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Updated the existing tests in `LDASuite.scala` and `test_clustering.py`.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #53473 from zhengruifeng/connect_lda.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../scala/org/apache/spark/ml/clustering/LDA.scala | 50 +++++++++++-----
.../apache/spark/mllib/clustering/LDAModel.scala | 66 +++++++++++++++++++++-
.../org/apache/spark/ml/clustering/LDASuite.scala | 3 +-
python/pyspark/ml/tests/test_clustering.py | 2 -
.../apache/spark/sql/connect/ml/MLHandler.scala | 9 ---
5 files changed, 102 insertions(+), 28 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index c64d3a98c0a9..e4736718430b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -827,8 +827,12 @@ class DistributedLDAModel private[ml] (
}
private[spark] override def estimatedSize: Long = {
- // TODO: Implement this method.
- throw new UnsupportedOperationException
+ this.oldDistributedModel.toInternals.map {
+ case df: org.apache.spark.sql.classic.DataFrame =>
+ df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _)
+ case o => throw new UnsupportedOperationException(
+ s"Unsupported dataframe type: ${o.getClass.getName}")
+ }.sum
}
}
@@ -840,14 +844,23 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
+ val modelPath = new Path(path, "oldModel").toString
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
if (ReadWriteUtils.localSavingModeState.get()) {
- throw new UnsupportedOperationException(
- "DistributedLDAModel does not support saving to local filesystem
path."
- )
+ val Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF) =
+ instance.oldDistributedModel.toInternals
+
+ ReadWriteUtils.saveDataFrame(
+ new Path(modelPath, "old-metadata").toString, metadataDF)
+ ReadWriteUtils.saveDataFrame(
+ new Path(modelPath, "old-global-topic-totals").toString,
globalTopicTotalsDF)
+ ReadWriteUtils.saveDataFrame(
+ new Path(modelPath, "old-vertices").toString, verticesDF)
+ ReadWriteUtils.saveDataFrame(
+ new Path(modelPath, "old-edges").toString, edgesDF)
+ } else {
+ instance.oldDistributedModel.save(sc, modelPath)
}
- DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
- val modelPath = new Path(path, "oldModel").toString
- instance.oldDistributedModel.save(sc, modelPath)
}
}
@@ -856,14 +869,23 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
private val className = classOf[DistributedLDAModel].getName
override def load(path: String): DistributedLDAModel = {
- if (ReadWriteUtils.localSavingModeState.get()) {
- throw new UnsupportedOperationException(
- "DistributedLDAModel does not support loading from local filesystem
path."
- )
- }
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val modelPath = new Path(path, "oldModel").toString
- val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+ val oldModel = if (ReadWriteUtils.localSavingModeState.get()) {
+ val metadataDF = ReadWriteUtils.loadDataFrame(
+ new Path(modelPath, "old-metadata").toString, sparkSession)
+ val globalTopicTotalsDF = ReadWriteUtils.loadDataFrame(
+ new Path(modelPath, "old-global-topic-totals").toString,
sparkSession)
+ val verticesDF = ReadWriteUtils.loadDataFrame(
+ new Path(modelPath, "old-vertices").toString, sparkSession)
+ val edgesDF = ReadWriteUtils.loadDataFrame(
+ new Path(modelPath, "old-edges").toString, sparkSession)
+
+ OldDistributedLDAModel.fromInternals(
+ Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF))
+ } else {
+ OldDistributedLDAModel.load(sc, modelPath)
+ }
val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
oldModel, sparkSession, None)
LDAParams.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 10a81acede0c..ea5ed4f6fbc3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -31,7 +31,7 @@ import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.ArrayImplicits._
@@ -816,6 +816,36 @@ class DistributedLDAModel private[clustering] (
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
iterationTimes, gammaShape)
}
+
+ private[spark] def toInternals: Seq[DataFrame] = {
+ import DistributedLDAModel.SaveLoadV1_0._
+
+ val sc = graph.vertices.sparkContext
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+
+ val newMetadata = compact(render
+ (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
+ ("docConcentration" -> docConcentration.toArray.toImmutableArraySeq) ~
+ ("topicConcentration" -> topicConcentration) ~
+ ("iterationTimes" -> iterationTimes.toImmutableArraySeq) ~
+ ("gammaShape" -> gammaShape) ~
+ ("globalTopicTotals" -> globalTopicTotals.data.toImmutableArraySeq)))
+
+ val globalTopicTotalsDF = spark.createDataFrame(
+ Seq(Tuple1.apply(globalTopicTotals.data))).toDF("meta")
+
+ val metadataDF = spark.createDataFrame(
+ Seq(Tuple1.apply(newMetadata))).toDF("meta")
+
+ val verticesDF = spark.createDataFrame(
+ graph.vertices.map { case (ind, vertex) => ConnectVertexData(ind, vertex.data)})
+
+ val edgesDF = spark.createDataFrame(
+ graph.edges.map { case Edge(srcId, dstId, prop) => EdgeData(srcId, dstId, prop)})
+
+ Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF)
+ }
}
/**
@@ -834,6 +864,38 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
*/
private[clustering] val defaultGammaShape: Double = 100
+ private[spark] def fromInternals(internals: Seq[DataFrame]): DistributedLDAModel = {
+ val Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF) = internals
+ val spark = metadataDF.sparkSession
+
+ import spark.implicits._
+
+ implicit val formats: Formats = DefaultFormats
+ val metadata = parse(metadataDF.as[String].first())
+
+ val expectedK = (metadata \ "k").extract[Int]
+ val vocabSize = (metadata \ "vocabSize").extract[Int]
+ val docConcentration =
+ Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+ val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+ val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]].toArray
+ val gammaShape = (metadata \ "gammaShape").extract[Double]
+
+ val globalTopicTotals = new LDA.TopicCounts(
+ globalTopicTotalsDF.first().getSeq[Double](0).toArray)
+
+ val vertices: RDD[(VertexId, LDA.TopicCounts)] = verticesDF.rdd.map {
+ case row: Row => (row.getLong(0), new LDA.TopicCounts(row.getSeq[Double](1).toArray))
+ }
+ val edges: RDD[Edge[LDA.TokenCount]] = edgesDF.rdd.map {
+ case row: Row => Edge(row.getLong(0), row.getLong(1), row.getDouble(2))
+ }
+ val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
+
+ new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+ docConcentration, topicConcentration, iterationTimes, gammaShape)
+ }
+
private object SaveLoadV1_0 {
val thisFormatVersion = "1.0"
@@ -846,6 +908,8 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
// Store each term and document vertex with an id and the topicWeights.
case class VertexData(id: Long, topicWeights: Vector)
+ case class ConnectVertexData(id: Long, topicWeights: Array[Double])
+
// Store each edge with the source id, destination id and tokenCounts.
case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 4ae1d3ce24a6..a0223396da31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -274,8 +274,7 @@ class LDASuite extends MLTest with DefaultReadWriteTest {
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
- LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData,
- skipTestSaveLocal = true)
+ LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
test("EM LDA checkpointing: save last checkpoint") {
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index fbf012babcc3..d624b6398881 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -403,8 +403,6 @@ class ClusteringTestsMixin:
self.assertEqual(str(model), str(model2))
def test_distributed_lda(self):
- if is_remote():
- self.skipTest("Do not support Spark Connect.")
spark = self.spark
df = (
spark.createDataFrame(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 3a53aa77fde6..a4c3d2737052 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,15 +229,6 @@ private[connect] object MLHandler extends Logging {
} catch {
case _: UnsupportedOperationException => ()
}
- if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
- && estimator
- .asInstanceOf[org.apache.spark.ml.clustering.LDA]
- .getOptimizer
- .toLowerCase() == "em") {
- throw MlUnsupportedException(
- "LDA algorithm with 'em' optimizer is not supported " +
- "if Spark Connect model cache offloading is enabled.")
- }
}
EstimatorUtils.warningMessagesBuffer.set(new mutable.ArrayBuffer[String]())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]