This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 23662621d08f [SPARK-51791][ML] `ImputerModel` stores coefficients with arrays instead of dataframe
23662621d08f is described below

commit 23662621d08fdf1756aba508c9cf17dabb4750b6
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Apr 15 14:49:40 2025 +0800

    [SPARK-51791][ML] `ImputerModel` stores coefficients with arrays instead of dataframe
    
    ### What changes were proposed in this pull request?
    
    `ImputerModel` now stores its coefficients (the per-column surrogate values) in plain arrays instead of a one-row DataFrame.
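
    A minimal, standalone sketch of the new storage scheme (illustrative only: `SurrogateStore` and the sample values are hypothetical, not part of this patch):

    ```scala
    // Surrogates are kept as two parallel arrays and looked up by index,
    // replacing the previous Map built from a one-row DataFrame.
    case class SurrogateStore(columnNames: Array[String], surrogates: Array[Double]) {
      def surrogateFor(inputCol: String): Double =
        surrogates(columnNames.indexOf(inputCol))
    }

    val store = SurrogateStore(Array("age", "income"), Array(29.5, 52000.0))
    assert(store.surrogateFor("income") == 52000.0)
    ```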
    
    ### Why are the changes needed?
    
    To be compatible with the default size estimation: a model that holds only plain arrays can be sized directly, while one holding a DataFrame reference cannot.
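
    As a hedged illustration, assuming "default size estimation" refers to measuring a model's in-memory footprint with Spark's `SizeEstimator` (the sample arrays are hypothetical):

    ```scala
    import org.apache.spark.util.SizeEstimator

    // Plain JVM arrays have a concrete footprint that SizeEstimator can
    // walk object-by-object; a DataFrame is only a handle to a query plan
    // plus a session, so walking it says nothing about the stored values.
    val columnNames = Array("age", "income")
    val surrogates = Array(29.5, 52000.0)
    println(SizeEstimator.estimate(columnNames)) // a few hundred bytes
    println(SizeEstimator.estimate(surrogates))
    ```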
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Updated the existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #50578 from zhengruifeng/ml_imputer_model_x.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/ml/feature/Imputer.scala      | 39 +++++++++++-----------
 .../org/apache/spark/ml/feature/ImputerSuite.scala | 12 ++-----
 2 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 4e169ab178b9..b9fb20d14933 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -203,11 +203,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
         s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
         s"missingValue(${$(missingValue)})")
     }
-
-    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results.toImmutableArraySeq)))
-    val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false)))
-    val surrogateDF = spark.createDataFrame(rows, schema)
-    copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
+    copyValues(new ImputerModel(uid, inputColumns, results).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -241,13 +237,22 @@ object Imputer extends DefaultParamsReadable[Imputer] {
 @Since("2.2.0")
 class ImputerModel private[ml] (
     @Since("2.2.0") override val uid: String,
-    @Since("2.2.0") val surrogateDF: DataFrame)
+    @Since("4.1.0") private[ml] val columnNames: Array[String],
+    @Since("4.1.0") private[ml] val surrogates: Array[Double])
   extends Model[ImputerModel] with ImputerParams with MLWritable {
 
   import ImputerModel._
 
   // For ml connect only
-  private[ml] def this() = this("", null)
+  private[ml] def this() = this("", Array.empty, Array.emptyDoubleArray)
+
+  @Since("2.2.0")
+  def surrogateDF: DataFrame = {
+    val spark = SparkSession.builder().getOrCreate()
+    val rows = java.util.List.of[Row](Row.fromSeq(surrogates.toImmutableArraySeq))
+    val schema = StructType(columnNames.map(c => StructField(c, DoubleType, nullable = false)))
+    spark.createDataFrame(rows, schema)
+  }
 
   /** @group setParam */
   @Since("3.0.0")
@@ -263,19 +268,12 @@ class ImputerModel private[ml] (
   /** @group setParam */
   def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
 
-  @transient private lazy val surrogates = {
-    val row = surrogateDF.head()
-    row.schema.fieldNames.zipWithIndex
-      .map { case (name, index) => (name, row.getDouble(index)) }
-      .toMap
-  }
-
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val (inputColumns, outputColumns) = getInOutCols()
 
     val newCols = inputColumns.map { inputCol =>
-      val surrogate = surrogates(inputCol)
+      val surrogate = surrogates(columnNames.indexOf(inputCol))
       val inputType = SchemaUtils.getSchemaFieldType(dataset.schema, inputCol)
       val ic = col(inputCol).cast(DoubleType)
       when(ic.isNull, surrogate)
@@ -291,7 +289,7 @@ class ImputerModel private[ml] (
   }
 
   override def copy(extra: ParamMap): ImputerModel = {
-    val copied = new ImputerModel(uid, surrogateDF)
+    val copied = new ImputerModel(uid, columnNames, surrogates)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -326,8 +324,11 @@ object ImputerModel extends MLReadable[ImputerModel] {
     override def load(path: String): ImputerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val surrogateDF = sparkSession.read.parquet(dataPath)
-      val model = new ImputerModel(metadata.uid, surrogateDF)
+      val row = sparkSession.read.parquet(dataPath).head()
+      val (columnNames, surrogates) = row.schema.fieldNames.zipWithIndex
+        .map { case (name, index) => (name, row.getDouble(index)) }
+        .unzip
+      val model = new ImputerModel(metadata.uid, columnNames, surrogates)
       metadata.getAndSetParams(model)
       model
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index c149618f6066..a2fc333535fc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -298,12 +298,8 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("ImputerModel read/write") {
-    val spark = this.spark
-    import spark.implicits._
-    val surrogateDF = Seq(1.234).toDF("myInputCol")
-
     val instance = new ImputerModel(
-      "myImputer", surrogateDF)
+      "myImputer", Array("myInputCol"), Array(1.234))
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
     val newInstance = testDefaultReadWrite(instance)
@@ -312,12 +308,8 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("Single Column: ImputerModel read/write") {
-    val spark = this.spark
-    import spark.implicits._
-    val surrogateDF = Seq(1.234).toDF("myInputCol")
-
     val instance = new ImputerModel(
-      "myImputer", surrogateDF)
+      "myImputer", Array("myInputCol"), Array(1.234))
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
     val newInstance = testDefaultReadWrite(instance)
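
For callers of the public `surrogateDF` accessor, the one-row DataFrame is now rebuilt on demand from the two arrays. A standalone sketch mirroring the patched method (the column names and values here are hypothetical):

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Rebuild the one-row surrogate DataFrame from the stored arrays,
// as the patched ImputerModel.surrogateDF does.
val spark = SparkSession.builder().master("local[1]").getOrCreate()
val columnNames = Array("age", "income")
val surrogates = Array(29.5, 52000.0)
val rows = java.util.List.of[Row](Row.fromSeq(surrogates.toSeq))
val schema = StructType(columnNames.map(c => StructField(c, DoubleType, nullable = false)))
spark.createDataFrame(rows, schema).show()
```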

