This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a523d66755c0 [SPARK-51963][ML] Simplify IndexToString.transform
a523d66755c0 is described below

commit a523d66755c09b42dbcac2d304363ffe7c715ed6
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 1 07:42:35 2025 +0800

    [SPARK-51963][ML] Simplify IndexToString.transform

    ### What changes were proposed in this pull request?
    Simplify IndexToString.transform.

    ### Why are the changes needed?
    The logic is simple enough to express with built-in SQL functions instead of a UDF.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #50767 from zhengruifeng/ml_sql_string_ind.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/ml/feature/StringIndexer.scala | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 6518b0d9cf92..30b8c813188f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{get => fget, printf => fprintf, _}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.VersionUtils.majorMinorVersion
@@ -586,17 +586,11 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String)
     } else {
       $(labels)
     }
-    val indexer = udf { index: Double =>
-      val idx = index.toInt
-      if (0 <= idx && idx < values.length) {
-        values(idx)
-      } else {
-        throw new SparkException(s"Unseen index: $index ??")
-      }
-    }
-    val outputColName = $(outputCol)
-    dataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
+
+    val idxCol = col($(inputCol)).cast(IntegerType)
+    val valCol = when(lit(0) <= idxCol && idxCol < lit(values.length), fget(lit(values), idxCol))
+      .otherwise(raise_error(fprintf(lit("Unseen index: %s ??"), idxCol.cast(StringType))))
+    dataset.withColumn($(outputCol), valCol)
   }
 
   @Since("1.5.0")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
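
For readers who want to try the same when/get/raise_error pattern outside Spark's source tree, below is a minimal, self-contained sketch. The SparkSession setup, the labels array, and the labelIndex column are hypothetical illustration, not part of the commit; the sketch assumes a Spark version that ships get (3.4+) and printf (3.5+) in org.apache.spark.sql.functions.

// Hypothetical, self-contained illustration of the pattern used in this commit;
// spark, labels, df and labelIndex are made up for the example.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{get => fget, printf => fprintf, col, lit, raise_error, when}
import org.apache.spark.sql.types.{IntegerType, StringType}

object IndexLookupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("index-lookup-sketch").getOrCreate()
    import spark.implicits._

    val labels = Array("a", "b", "c")
    val df = Seq(0.0, 2.0, 1.0).toDF("labelIndex")

    // Cast the double index to int, then look the label up in an array literal;
    // an out-of-range index raises an error instead of silently producing null.
    val idx = col("labelIndex").cast(IntegerType)
    val label = when(lit(0) <= idx && idx < lit(labels.length), fget(lit(labels), idx))
      .otherwise(raise_error(fprintf(lit("Unseen index: %s ??"), idx.cast(StringType))))

    df.withColumn("label", label).show()
    spark.stop()
  }
}

Compared with calling get alone, which quietly returns null for an out-of-range index, the when/otherwise wrapper keeps the fail-fast behaviour of the old UDF for unseen indices.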