This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 23662621d08f [SPARK-51791][ML] `ImputerModel` stores coefficients with arrays instead of dataframe
23662621d08f is described below

commit 23662621d08fdf1756aba508c9cf17dabb4750b6
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Apr 15 14:49:40 2025 +0800

    [SPARK-51791][ML] `ImputerModel` stores coefficients with arrays instead of dataframe

    ### What changes were proposed in this pull request?
    `ImputerModel` stores coefficients with arrays instead of dataframe

    ### Why are the changes needed?
    to be compatible with the default size estimation

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    updated tests

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #50578 from zhengruifeng/ml_imputer_model_x.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
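The rationale above concerns measuring the model's in-memory size: a DataFrame-typed field is a handle to a lazy query plan (and, transitively, to the session), so a generic object-graph walker cannot assign it a meaningful footprint, whereas plain `Array[String]`/`Array[Double]` fields can be measured directly. A minimal sketch of the contrast, assuming the "default size estimation" walks the JVM object graph in the style of org.apache.spark.util.SizeEstimator (the commit does not name the estimator):

    import org.apache.spark.util.SizeEstimator

    // Flat JVM arrays: the footprint is fully determined by the object graph.
    val columnNames = Array("a", "b", "c")
    val surrogates = Array(1.0, 2.0, 3.0)
    println(SizeEstimator.estimate(columnNames)) // small, deterministic (tens to a few hundred bytes)
    println(SizeEstimator.estimate(surrogates))  // roughly array header + 3 * 8 bytes of payload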
---
 .../org/apache/spark/ml/feature/Imputer.scala      | 39 +++++++++++-----------
 .../org/apache/spark/ml/feature/ImputerSuite.scala | 12 ++-----
 2 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 4e169ab178b9..b9fb20d14933 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -203,11 +203,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
         s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
         s"missingValue(${$(missingValue)})")
     }
-
-    val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results.toImmutableArraySeq)))
-    val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false)))
-    val surrogateDF = spark.createDataFrame(rows, schema)
-    copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
+    copyValues(new ImputerModel(uid, inputColumns, results).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -241,13 +237,22 @@ object Imputer extends DefaultParamsReadable[Imputer] {
 @Since("2.2.0")
 class ImputerModel private[ml] (
     @Since("2.2.0") override val uid: String,
-    @Since("2.2.0") val surrogateDF: DataFrame)
+    @Since("4.1.0") private[ml] val columnNames: Array[String],
+    @Since("4.1.0") private[ml] val surrogates: Array[Double])
   extends Model[ImputerModel] with ImputerParams with MLWritable {
 
   import ImputerModel._
 
   // For ml connect only
-  private[ml] def this() = this("", null)
+  private[ml] def this() = this("", Array.empty, Array.emptyDoubleArray)
+
+  @Since("2.2.0")
+  def surrogateDF: DataFrame = {
+    val spark = SparkSession.builder().getOrCreate()
+    val rows = java.util.List.of[Row](Row.fromSeq(surrogates.toImmutableArraySeq))
+    val schema = StructType(columnNames.map(c => StructField(c, DoubleType, nullable = false)))
+    spark.createDataFrame(rows, schema)
+  }
 
   /** @group setParam */
   @Since("3.0.0")
@@ -263,19 +268,12 @@ class ImputerModel private[ml] (
   /** @group setParam */
   def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
 
-  @transient private lazy val surrogates = {
-    val row = surrogateDF.head()
-    row.schema.fieldNames.zipWithIndex
-      .map { case (name, index) => (name, row.getDouble(index)) }
-      .toMap
-  }
-
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val (inputColumns, outputColumns) = getInOutCols()
 
     val newCols = inputColumns.map { inputCol =>
-      val surrogate = surrogates(inputCol)
+      val surrogate = surrogates(columnNames.indexOf(inputCol))
       val inputType = SchemaUtils.getSchemaFieldType(dataset.schema, inputCol)
       val ic = col(inputCol).cast(DoubleType)
       when(ic.isNull, surrogate)
@@ -291,7 +289,7 @@ class ImputerModel private[ml] (
   }
 
   override def copy(extra: ParamMap): ImputerModel = {
-    val copied = new ImputerModel(uid, surrogateDF)
+    val copied = new ImputerModel(uid, columnNames, surrogates)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -326,8 +324,11 @@ object ImputerModel extends MLReadable[ImputerModel] {
     override def load(path: String): ImputerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val surrogateDF = sparkSession.read.parquet(dataPath)
-      val model = new ImputerModel(metadata.uid, surrogateDF)
+      val row = sparkSession.read.parquet(dataPath).head()
+      val (columnNames, surrogates) = row.schema.fieldNames.zipWithIndex
+        .map { case (name, index) => (name, row.getDouble(index)) }
+        .unzip
+      val model = new ImputerModel(metadata.uid, columnNames, surrogates)
       metadata.getAndSetParams(model)
       model
     }
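For callers nothing changes: `surrogateDF` keeps its public signature and still yields the same single-row frame, now materialized on demand from `(columnNames, surrogates)`. A short usage sketch of that unchanged behavior, with hypothetical column names and data, and an active session assumed in scope as `spark`:

    import org.apache.spark.ml.feature.Imputer

    // Toy data: column "b" contains one missing value (NaN).
    val df = spark.createDataFrame(Seq(
      (1.0, 4.0), (2.0, Double.NaN), (3.0, 6.0)
    )).toDF("a", "b")

    val model = new Imputer()
      .setInputCols(Array("a", "b"))
      .setOutputCols(Array("a_imp", "b_imp"))
      .setStrategy("mean")
      .fit(df)

    // One row of surrogates, rebuilt lazily from the stored arrays:
    // a -> 2.0, b -> 5.0 (column means, NaN excluded).
    model.surrogateDF.show()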
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index c149618f6066..a2fc333535fc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -298,12 +298,8 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("ImputerModel read/write") {
-    val spark = this.spark
-    import spark.implicits._
-    val surrogateDF = Seq(1.234).toDF("myInputCol")
-
     val instance = new ImputerModel(
-      "myImputer", surrogateDF)
+      "myImputer", Array("myInputCol"), Array(1.234))
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
     val newInstance = testDefaultReadWrite(instance)
@@ -312,12 +308,8 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("Single Column: ImputerModel read/write") {
-    val spark = this.spark
-    import spark.implicits._
-    val surrogateDF = Seq(1.234).toDF("myInputCol")
-
     val instance = new ImputerModel(
-      "myImputer", surrogateDF)
+      "myImputer", Array("myInputCol"), Array(1.234))
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
     val newInstance = testDefaultReadWrite(instance)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org