This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9d90507d7865 [SPARK-51282][SPARK-51422][ML][FOLLOW-UP] Replace UDF with builtin functions

9d90507d7865 is described below

commit 9d90507d78656dcb39ff0021b1dfb44545eec004
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 20 09:00:15 2025 +0800

    [SPARK-51282][SPARK-51422][ML][FOLLOW-UP] Replace UDF with builtin functions

    ### What changes were proposed in this pull request?
    Make the Scala-side changes corresponding to https://github.com/apache/spark/pull/50041 and https://github.com/apache/spark/pull/50184.

    ### Why are the changes needed?
    For parity between the Scala and Python implementations.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests and added tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #50321 from zhengruifeng/ml_ovr_transform.
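For readers who want to see the "UDF -> builtin functions" pattern outside the OneVsRest internals, here is a minimal illustrative sketch (not part of this commit). It expresses the two operations the patch rewrites -- appending the positive-class score to an accumulator array and then taking the argmax -- using only existing public Spark APIs (vector_to_array, element_at, array_append, array_max, array_position; array_append requires Spark 3.4+). The commit itself uses the new private[ml] helpers vector_get and array_argmax instead, and the column and object names below are made up for the example.

// Illustrative sketch only, assuming a recent Spark version; names are hypothetical.
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object UdfToBuiltinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("udf-to-builtin").getOrCreate()
    import spark.implicits._

    // A binary classifier's raw predictions plus an accumulator column of earlier scores.
    val df = Seq(
      (Vectors.dense(0.2, 0.8), Array(0.1)),
      (Vectors.dense(0.9, 0.1), Array(0.7))
    ).toDF("rawPrediction", "acc")

    // Old style (conceptually): udf { (preds: Array[Double], pred: Vector) => preds :+ pred(1) }
    // New style: append element 1 of the vector using builtin column functions only.
    // element_at is 1-based, so index 2 corresponds to pred(1).
    val appended = df.withColumn(
      "acc", array_append(col("acc"), element_at(vector_to_array(col("rawPrediction")), 2)))

    // Old style (conceptually): udf { preds: Array[Double] => preds.indices.maxBy(preds.apply).toDouble }
    // New style: 0-based index of the maximum via array_position/array_max
    // (ties resolve to the first occurrence, like maxBy).
    val predicted = appended.withColumn(
      "prediction", (array_position(col("acc"), array_max(col("acc"))) - 1).cast("double"))

    predicted.show(truncate = false)
    spark.stop()
  }
}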
Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../apache/spark/ml/classification/OneVsRest.scala | 18 ++++-----
 .../main/scala/org/apache/spark/ml/functions.scala | 46 +++++++++++++++++++++-
 .../org/apache/spark/mllib/util/MLUtils.scala      | 27 +++++++------
 .../scala/org/apache/spark/ml/FunctionsSuite.scala | 30 +++++++++++++-
 .../org/apache/spark/mllib/util/MLUtilsSuite.scala |  4 +-
 5 files changed, 98 insertions(+), 27 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index eb85791e9dbf..b3a512caa0c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.{LogKeys, MDC}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.functions._
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
@@ -41,6 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.ThreadUtils
 
 private[ml] trait ClassifierTypeTrait {
@@ -194,7 +195,6 @@ final class OneVsRestModel private[ml] (
     val accColName = "mbc$acc" + UUID.randomUUID().toString
     val newDataset = dataset.withColumn(accColName, lit(Array.emptyDoubleArray))
     val columns = newDataset.schema.fieldNames.map(col)
-    val updateUDF = udf { (preds: Array[Double], pred: Vector) => preds :+ pred(1) }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = !dataset.isStreaming && dataset.storageLevel == StorageLevel.NONE
@@ -212,10 +212,9 @@ final class OneVsRestModel private[ml] (
         if (isProbModel) {
           tmpModel.asInstanceOf[ProbabilisticClassificationModel[_, _]].setProbabilityCol("")
         }
-
-        import org.apache.spark.util.ArrayImplicits._
         tmpModel.transform(df)
-          .withColumn(accColName, updateUDF(col(accColName), col(tmpRawPredName)))
+          .withColumn(accColName, array_append(
+            col(accColName), vector_get(col(tmpRawPredName), lit(1))))
           .select(columns.toImmutableArraySeq: _*)
       }
 
@@ -228,18 +227,17 @@ final class OneVsRestModel private[ml] (
 
     if (getRawPredictionCol.nonEmpty) {
       // output the RawPrediction as vector
-      val rawPredictionUDF = udf { preds: Array[Double] => Vectors.dense(preds) }
       predictionColNames :+= getRawPredictionCol
-      predictionColumns :+= rawPredictionUDF(col(accColName))
+      predictionColumns :+= array_to_vector(col(accColName))
         .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
     }
 
     if (getPredictionCol.nonEmpty) {
       // output the index of the classifier with highest confidence as prediction
-      val labelUDF = udf { (preds: Array[Double]) => preds.indices.maxBy(preds.apply).toDouble }
       predictionColNames :+= getPredictionCol
-      predictionColumns :+= labelUDF(col(accColName))
-        .as(getPredictionCol, labelMetadata)
+
+      predictionColumns :+= array_argmax(col(accColName))
+        .cast(DoubleType).as(getPredictionCol, labelMetadata)
     }
 
     aggregatedDataset
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 48bd1a7207e9..07db59e53ba7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.sql.{functions => sf}
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.functions.lit
 
 // scalastyle:off
 @Since("3.0.0")
@@ -34,7 +34,7 @@ object functions {
    * @since 3.0.0
    */
   def vector_to_array(v: Column, dtype: String = "float64"): Column =
-    Column.internalFn("vector_to_array", v, lit(dtype))
+    Column.internalFn("vector_to_array", v, sf.lit(dtype))
 
   /**
    * Converts a column of array of numeric type into a column of dense vectors in MLlib.
@@ -43,4 +43,46 @@ object functions {
    * @since 3.1.0
    */
   def array_to_vector(v: Column): Column = Column.internalFn("array_to_vector", v)
+
+  private[ml] def array_binary_search(a: Column, v: Column): Column =
+    Column.internalFn("array_binary_search", a, v)
+
+  // input: vector, output: double
+  private[ml] def vector_get(v: Column, index: Column): Column = {
+    val unwrapped = sf.unwrap_udt(v)
+    val isDense = unwrapped.getField("type") === sf.lit(1)
+    val values = unwrapped.getField("values")
+    val size = sf.when(isDense, sf.array_size(values)).otherwise(unwrapped.getField("size"))
+    val sparseIdx = array_binary_search(unwrapped.getField("indices"), index)
+
+    sf.when(index >= 0 && index < size,
+      sf.when(isDense, sf.get(values, index))
+        .when(sparseIdx >= 0, sf.get(values, sparseIdx))
+        .otherwise(sf.lit(0.0))
+    ).otherwise(
+      sf.raise_error(sf.printf(
+        sf.lit(s"Vector index must be in [0, %s), but got %s"), size, index)
+      )
+    )
+  }
+
+  // input: array<double>, output: int
+  private[ml] def array_argmax(arr: Column): Column = {
+    sf.aggregate(
+      arr,
+      sf.struct(
+        sf.lit(Double.NegativeInfinity).alias("v"), // max value
+        sf.lit(-1).alias("i"), // index of max value
+        sf.lit(0).alias("j")), // current index
+      (acc, vv) => {
+        val v = acc.getField("v")
+        val i = acc.getField("i")
+        val j = acc.getField("j")
+        sf.when((!vv.isNaN) && (!vv.isNull) && (vv > v),
+          sf.struct(vv.alias("v"), j.alias("i"), j + 1))
+          .otherwise(sf.struct(v.alias("v"), i.alias("i"), j + 1))
+      },
+      acc => acc.getField("i")
+    )
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 1257d2ccfbfb..b8fcb1ffcbfe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -258,20 +258,23 @@ object MLUtils extends Logging {
    */
   @Since("3.1.0")
   def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
-    val foldCol = df.col(foldColName)
-    val checker = udf { foldNum: Int =>
-      // Valid fold number is in range [0, numFolds).
-      if (foldNum < 0 || foldNum >= numFolds) {
-        throw new SparkException(s"Fold number must be in range [0, $numFolds), but got $foldNum.")
-      }
-      true
-    }
+    val checked = df.withColumn(
+      foldColName,
+      when((lit(0) <= col(foldColName)) && (col(foldColName) < lit(numFolds)), col(foldColName))
+        .otherwise(
+          raise_error(
+            printf(
+              lit(s"Fold number must be in range [0, $numFolds), but got %s."), col(foldColName))
+          )
+        )
+    )
+
     (0 until numFolds).map { fold =>
-      val training = df
-        .filter(checker(foldCol) && foldCol =!= fold)
+      val training = checked
+        .filter(col(foldColName) =!= fold)
         .drop(foldColName).rdd
-      val validation = df
-        .filter(checker(foldCol) && foldCol === fold)
+      val validation = checked
+        .filter(col(foldColName) === fold)
         .drop(foldColName).rdd
       if (training.isEmpty()) {
         throw new SparkException(s"The training data at fold $fold is empty.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
index 32986a1345d6..7fcb1d2fbfbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml
 
 import org.apache.spark.SparkException
-import org.apache.spark.ml.functions.{array_to_vector, vector_to_array}
+import org.apache.spark.ml.functions._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.MLTest
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
@@ -102,4 +102,32 @@ class FunctionsSuite extends MLTest {
     val resultVec3 = df3.select(array_to_vector(col("c1"))).collect()(0)(0).asInstanceOf[Vector]
     assert(resultVec3 === Vectors.dense(Array(1.0, 2.0)))
   }
+
+  test("test get_vector") {
+    val df = Seq(
+      (Vectors.dense(1.0, 2.0, 3.0), 0),
+      (Vectors.dense(1.0, 2.0, 3.0), 1),
+      (Vectors.dense(1.0, 2.0, 3.0), 2),
+      (Vectors.sparse(3, Seq((0, -1.0))), 0),
+      (Vectors.sparse(3, Seq((0, -1.0))), 1),
+      (Vectors.sparse(3, Seq((0, -1.0))), 2)
+    ).toDF("vec", "idx")
+
+    val result = df.select(vector_get(col("vec"), col("idx"))).as[Double].collect()
+    assert(result === Array(1.0, 2.0, 3.0, -1.0, 0.0, 0.0))
+  }
+
+  test("test array_argmax") {
+    val df = Seq(
+      Tuple1.apply(Array(1.0, 2.0, 3.0)),
+      Tuple1.apply(Array(1.0, 3.0, 2.0)),
+      Tuple1.apply(Array(3.0, 2.0, 1.0)),
+      Tuple1.apply(Array(1.0, 3.0, 3.0)),
+      Tuple1.apply(Array(3.0, 3.0, 3.0)),
+      Tuple1.apply(Array.emptyDoubleArray)
+    ).toDF("arr")
+
+    val result = df.select(array_argmax(col("arr"))).as[Int].collect()
+    assert(result === Array(2, 1, 0, 1, 0, -1))
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 1a02e26b9260..c4dc3b171572 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -24,7 +24,7 @@ import scala.io.Source
 
 import com.google.common.io.Files
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
@@ -378,7 +378,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("kFold with fold column: invalid fold numbers") {
     val data = sc.parallelize(Seq(0, 1, 2), 2).toDF("fold")
-    val err1 = intercept[SparkException] {
+    val err1 = intercept[SparkRuntimeException] {
       kFold(data, 2, "fold")(0)._1.collect()
     }
     assert(err1.getMessage.contains("Fold number must be in range [0, 2), but got 2."))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org