This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a96d59fa810 [SPARK-54706][ML][CONNECT] Make DistributedLDAModel work with local file system
8a96d59fa810 is described below

commit 8a96d59fa810982d911641a2099fe0c662dd0247
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 17 10:33:39 2025 +0800

    [SPARK-54706][ML][CONNECT] Make DistributedLDAModel work with local file system
    
    ### What changes were proposed in this pull request?
    Make `DistributedLDAModel` support saving to and loading from local file system paths (the local-saving mode used by Spark Connect), and implement its `estimatedSize`.
    
    ### Why are the changes needed?
    This is required to support `DistributedLDAModel` (i.e. LDA with the `em` optimizer) on Spark Connect.
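    
    As a rough illustration (not part of this patch), the round trip this enables looks like the sketch below; `dataset` and the output path are placeholders for a DataFrame with a `features` vector column and a local directory:
    
    ```scala
    import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}
    
    // Fitting LDA with the "em" optimizer yields a DistributedLDAModel.
    val lda = new LDA().setK(3).setMaxIter(5).setOptimizer("em")
    val model = lda.fit(dataset).asInstanceOf[DistributedLDAModel]
    
    // Saving/loading used to throw UnsupportedOperationException in
    // local-saving mode (the mode used by Spark Connect); it should now work.
    model.write.overwrite().save("/tmp/distributed-lda-model")
    val restored = DistributedLDAModel.load("/tmp/distributed-lda-model")
    ```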
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Updated existing tests in `LDASuite` and `test_clustering.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53473 from zhengruifeng/connect_lda.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../scala/org/apache/spark/ml/clustering/LDA.scala | 50 +++++++++++-----
 .../apache/spark/mllib/clustering/LDAModel.scala   | 66 +++++++++++++++++++++-
 .../org/apache/spark/ml/clustering/LDASuite.scala  |  3 +-
 python/pyspark/ml/tests/test_clustering.py         |  2 -
 .../apache/spark/sql/connect/ml/MLHandler.scala    |  9 ---
 5 files changed, 102 insertions(+), 28 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index c64d3a98c0a9..e4736718430b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -827,8 +827,12 @@ class DistributedLDAModel private[ml] (
   }
 
   private[spark] override def estimatedSize: Long = {
-    // TODO: Implement this method.
-    throw new UnsupportedOperationException
+    this.oldDistributedModel.toInternals.map {
+      case df: org.apache.spark.sql.classic.DataFrame =>
+        df.toArrowBatchRdd.map(_.length.toLong).reduce(_ + _)
+      case o => throw new UnsupportedOperationException(
+        s"Unsupported dataframe type: ${o.getClass.getName}")
+    }.sum
   }
 }
 
@@ -840,14 +844,23 @@ object DistributedLDAModel extends 
MLReadable[DistributedLDAModel] {
   class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
+      val modelPath = new Path(path, "oldModel").toString
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "DistributedLDAModel does not support saving to local filesystem path."
-        )
+        val Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF) =
+          instance.oldDistributedModel.toInternals
+
+        ReadWriteUtils.saveDataFrame(
+          new Path(modelPath, "old-metadata").toString, metadataDF)
+        ReadWriteUtils.saveDataFrame(
+          new Path(modelPath, "old-global-topic-totals").toString, globalTopicTotalsDF)
+        ReadWriteUtils.saveDataFrame(
+          new Path(modelPath, "old-vertices").toString, verticesDF)
+        ReadWriteUtils.saveDataFrame(
+          new Path(modelPath, "old-edges").toString, edgesDF)
+      } else {
+        instance.oldDistributedModel.save(sc, modelPath)
       }
-      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
-      val modelPath = new Path(path, "oldModel").toString
-      instance.oldDistributedModel.save(sc, modelPath)
     }
   }
 
@@ -856,14 +869,23 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
     private val className = classOf[DistributedLDAModel].getName
 
     override def load(path: String): DistributedLDAModel = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "DistributedLDAModel does not support loading from local filesystem path."
-        )
-      }
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val modelPath = new Path(path, "oldModel").toString
-      val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+      val oldModel = if (ReadWriteUtils.localSavingModeState.get()) {
+        val metadataDF = ReadWriteUtils.loadDataFrame(
+          new Path(modelPath, "old-metadata").toString, sparkSession)
+        val globalTopicTotalsDF = ReadWriteUtils.loadDataFrame(
+          new Path(modelPath, "old-global-topic-totals").toString, sparkSession)
+        val verticesDF = ReadWriteUtils.loadDataFrame(
+          new Path(modelPath, "old-vertices").toString, sparkSession)
+        val edgesDF = ReadWriteUtils.loadDataFrame(
+          new Path(modelPath, "old-edges").toString, sparkSession)
+
+        OldDistributedLDAModel.fromInternals(
+          Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF))
+      } else {
+        OldDistributedLDAModel.load(sc, modelPath)
+      }
       val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
         oldModel, sparkSession, None)
       LDAParams.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 10a81acede0c..ea5ed4f6fbc3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -31,7 +31,7 @@ import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
 import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.ArrayImplicits._
 
@@ -816,6 +816,36 @@ class DistributedLDAModel private[clustering] (
       sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
       iterationTimes, gammaShape)
   }
+
+  private[spark] def toInternals: Seq[DataFrame] = {
+    import DistributedLDAModel.SaveLoadV1_0._
+
+    val sc = graph.vertices.sparkContext
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+
+    val newMetadata = compact(render
+    (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+      ("k" -> k) ~ ("vocabSize" -> vocabSize) ~
+      ("docConcentration" -> docConcentration.toArray.toImmutableArraySeq) ~
+      ("topicConcentration" -> topicConcentration) ~
+      ("iterationTimes" -> iterationTimes.toImmutableArraySeq) ~
+      ("gammaShape" -> gammaShape) ~
+      ("globalTopicTotals" -> globalTopicTotals.data.toImmutableArraySeq)))
+
+    val globalTopicTotalsDF = spark.createDataFrame(
+      Seq(Tuple1.apply(globalTopicTotals.data))).toDF("meta")
+
+    val metadataDF = spark.createDataFrame(
+      Seq(Tuple1.apply(newMetadata))).toDF("meta")
+
+    val verticesDF = spark.createDataFrame(
+      graph.vertices.map { case (ind, vertex) => ConnectVertexData(ind, vertex.data)})
+
+    val edgesDF = spark.createDataFrame(
+      graph.edges.map { case Edge(srcId, dstId, prop) => EdgeData(srcId, dstId, prop)})
+
+    Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF)
+  }
 }
 
 /**
@@ -834,6 +864,38 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
    */
   private[clustering] val defaultGammaShape: Double = 100
 
+  private[spark] def fromInternals(internals: Seq[DataFrame]): DistributedLDAModel = {
+    val Seq(metadataDF, globalTopicTotalsDF, verticesDF, edgesDF) = internals
+    val spark = metadataDF.sparkSession
+
+    import spark.implicits._
+
+    implicit val formats: Formats = DefaultFormats
+    val metadata = parse(metadataDF.as[String].first())
+
+    val expectedK = (metadata \ "k").extract[Int]
+    val vocabSize = (metadata \ "vocabSize").extract[Int]
+    val docConcentration =
+      Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray)
+    val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+    val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]].toArray
+    val gammaShape = (metadata \ "gammaShape").extract[Double]
+
+    val globalTopicTotals = new LDA.TopicCounts(
+      globalTopicTotalsDF.first().getSeq[Double](0).toArray)
+
+    val vertices: RDD[(VertexId, LDA.TopicCounts)] = verticesDF.rdd.map {
+      case row: Row => (row.getLong(0), new LDA.TopicCounts(row.getSeq[Double](1).toArray))
+    }
+    val edges: RDD[Edge[LDA.TokenCount]] = edgesDF.rdd.map {
+      case row: Row => Edge(row.getLong(0), row.getLong(1), row.getDouble(2))
+    }
+    val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
+
+    new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+      docConcentration, topicConcentration, iterationTimes, gammaShape)
+  }
+
   private object SaveLoadV1_0 {
 
     val thisFormatVersion = "1.0"
@@ -846,6 +908,8 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     // Store each term and document vertex with an id and the topicWeights.
     case class VertexData(id: Long, topicWeights: Vector)
 
+    case class ConnectVertexData(id: Long, topicWeights: Array[Double])
+
     // Store each edge with the source id, destination id and tokenCounts.
     case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 4ae1d3ce24a6..a0223396da31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -274,8 +274,7 @@ class LDASuite extends MLTest with DefaultReadWriteTest {
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
-      LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData,
-      skipTestSaveLocal = true)
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
 
   test("EM LDA checkpointing: save last checkpoint") {
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index fbf012babcc3..d624b6398881 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -403,8 +403,6 @@ class ClusteringTestsMixin:
             self.assertEqual(str(model), str(model2))
 
     def test_distributed_lda(self):
-        if is_remote():
-            self.skipTest("Do not support Spark Connect.")
         spark = self.spark
         df = (
             spark.createDataFrame(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 3a53aa77fde6..a4c3d2737052 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -229,15 +229,6 @@ private[connect] object MLHandler extends Logging {
           } catch {
             case _: UnsupportedOperationException => ()
           }
-          if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
-            && estimator
-              .asInstanceOf[org.apache.spark.ml.clustering.LDA]
-              .getOptimizer
-              .toLowerCase() == "em") {
-            throw MlUnsupportedException(
-              "LDA algorithm with 'em' optimizer is not supported " +
-                "if Spark Connect model cache offloading is enabled.")
-          }
         }
 
         EstimatorUtils.warningMessagesBuffer.set(new mutable.ArrayBuffer[String]())

