This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d90507d7865 [SPARK-51282][SPARK-51422][ML][FOLLOW-UP] Replace UDF with builtin functions
9d90507d7865 is described below

commit 9d90507d78656dcb39ff0021b1dfb44545eec004
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 20 09:00:15 2025 +0800

    [SPARK-51282][SPARK-51422][ML][FOLLOW-UP] Replace UDF with builtin functions
    
    
    ### What changes were proposed in this pull request?
    Make the Scala-side changes corresponding to https://github.com/apache/spark/pull/50041 and https://github.com/apache/spark/pull/50184.
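    
    For illustration, a minimal sketch (not part of this commit) of the kind of rewrite applied here: a row-level UDF that wraps an accumulated `Array[Double]` column into a dense vector is replaced by the builtin `array_to_vector` helper from `org.apache.spark.ml.functions`, so the conversion runs as a column expression instead of a Scala closure. The session, DataFrame, and column names below are hypothetical.
    
    ```scala
    import org.apache.spark.ml.functions.array_to_vector
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, udf}
    
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._
    
    // Per-class scores accumulated as a plain array column.
    val df = Seq(Array(0.1, 0.9), Array(0.7, 0.3)).toDF("preds")
    
    // Before: a UDF materializes each row in a Scala closure.
    val toVec = udf { preds: Array[Double] => Vectors.dense(preds) }
    val before = df.withColumn("rawPrediction", toVec(col("preds")))
    
    // After: the builtin helper performs the same conversion without a UDF.
    val after = df.withColumn("rawPrediction", array_to_vector(col("preds")))
    after.show(truncate = false)
    ```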
    
    ### Why are the changes needed?
    For parity between Scala and Python.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests and newly added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #50321 from zhengruifeng/ml_ovr_transform.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../apache/spark/ml/classification/OneVsRest.scala | 18 ++++-----
 .../main/scala/org/apache/spark/ml/functions.scala | 46 +++++++++++++++++++++-
 .../org/apache/spark/mllib/util/MLUtils.scala      | 27 +++++++------
 .../scala/org/apache/spark/ml/FunctionsSuite.scala | 30 +++++++++++++-
 .../org/apache/spark/mllib/util/MLUtilsSuite.scala |  4 +-
 5 files changed, 98 insertions(+), 27 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index eb85791e9dbf..b3a512caa0c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.{LogKeys, MDC}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.functions._
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
@@ -41,6 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.ThreadUtils
 
 private[ml] trait ClassifierTypeTrait {
@@ -194,7 +195,6 @@ final class OneVsRestModel private[ml] (
     val accColName = "mbc$acc" + UUID.randomUUID().toString
     val newDataset = dataset.withColumn(accColName, lit(Array.emptyDoubleArray))
     val columns = newDataset.schema.fieldNames.map(col)
-    val updateUDF = udf { (preds: Array[Double], pred: Vector) => preds :+ pred(1) }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = !dataset.isStreaming && dataset.storageLevel == StorageLevel.NONE
@@ -212,10 +212,9 @@ final class OneVsRestModel private[ml] (
       if (isProbModel) {
         tmpModel.asInstanceOf[ProbabilisticClassificationModel[_, _]].setProbabilityCol("")
       }
-
-      import org.apache.spark.util.ArrayImplicits._
       tmpModel.transform(df)
-        .withColumn(accColName, updateUDF(col(accColName), col(tmpRawPredName)))
+        .withColumn(accColName, array_append(
+          col(accColName), vector_get(col(tmpRawPredName), lit(1))))
         .select(columns.toImmutableArraySeq: _*)
     }
 
@@ -228,18 +227,17 @@ final class OneVsRestModel private[ml] (
 
     if (getRawPredictionCol.nonEmpty) {
       // output the RawPrediction as vector
-      val rawPredictionUDF = udf { preds: Array[Double] => Vectors.dense(preds) }
       predictionColNames :+= getRawPredictionCol
-      predictionColumns :+= rawPredictionUDF(col(accColName))
+      predictionColumns :+= array_to_vector(col(accColName))
         .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
     }
 
     if (getPredictionCol.nonEmpty) {
       // output the index of the classifier with highest confidence as prediction
-      val labelUDF = udf { (preds: Array[Double]) => preds.indices.maxBy(preds.apply).toDouble }
       predictionColNames :+= getPredictionCol
-      predictionColumns :+= labelUDF(col(accColName))
-        .as(getPredictionCol, labelMetadata)
+
+      predictionColumns :+= array_argmax(col(accColName))
+        .cast(DoubleType).as(getPredictionCol, labelMetadata)
     }
 
     aggregatedDataset
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 48bd1a7207e9..07db59e53ba7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.ml
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.sql.{functions => sf}
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.functions.lit
 
 // scalastyle:off
 @Since("3.0.0")
@@ -34,7 +34,7 @@ object functions {
    * @since 3.0.0
    */
   def vector_to_array(v: Column, dtype: String = "float64"): Column =
-    Column.internalFn("vector_to_array", v, lit(dtype))
+    Column.internalFn("vector_to_array", v, sf.lit(dtype))
 
   /**
   * Converts a column of array of numeric type into a column of dense vectors in MLlib.
@@ -43,4 +43,46 @@ object functions {
    * @since 3.1.0
    */
   def array_to_vector(v: Column): Column = Column.internalFn("array_to_vector", v)
+
+  private[ml] def array_binary_search(a: Column, v: Column): Column =
+    Column.internalFn("array_binary_search", a, v)
+
+  // input: vector, output: double
+  private[ml] def vector_get(v: Column, index: Column): Column = {
+    val unwrapped = sf.unwrap_udt(v)
+    val isDense = unwrapped.getField("type") === sf.lit(1)
+    val values = unwrapped.getField("values")
+    val size = sf.when(isDense, sf.array_size(values)).otherwise(unwrapped.getField("size"))
+    val sparseIdx = array_binary_search(unwrapped.getField("indices"), index)
+
+    sf.when(index >= 0 && index < size,
+      sf.when(isDense, sf.get(values, index))
+        .when(sparseIdx >= 0, sf.get(values, sparseIdx))
+        .otherwise(sf.lit(0.0))
+    ).otherwise(
+      sf.raise_error(sf.printf(
+        sf.lit(s"Vector index must be in [0, %s), but got %s"), size, index)
+      )
+    )
+  }
+
+  // input: array<double>, output: int
+  private[ml] def array_argmax(arr: Column): Column = {
+    sf.aggregate(
+      arr,
+      sf.struct(
+        sf.lit(Double.NegativeInfinity).alias("v"), // max value
+        sf.lit(-1).alias("i"),              // index of max value
+        sf.lit(0).alias("j")),              // current index
+      (acc, vv) => {
+        val v = acc.getField("v")
+        val i = acc.getField("i")
+        val j = acc.getField("j")
+        sf.when((!vv.isNaN) && (!vv.isNull) && (vv > v),
+            sf.struct(vv.alias("v"), j.alias("i"), j + 1))
+          .otherwise(sf.struct(v.alias("v"), i.alias("i"), j + 1))
+      },
+      acc => acc.getField("i")
+    )
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 1257d2ccfbfb..b8fcb1ffcbfe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -258,20 +258,23 @@ object MLUtils extends Logging {
    */
   @Since("3.1.0")
   def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
-    val foldCol = df.col(foldColName)
-    val checker = udf { foldNum: Int =>
-      // Valid fold number is in range [0, numFolds).
-      if (foldNum < 0 || foldNum >= numFolds) {
-        throw new SparkException(s"Fold number must be in range [0, $numFolds), but got $foldNum.")
-      }
-      true
-    }
+    val checked = df.withColumn(
+      foldColName,
+      when((lit(0) <= col(foldColName)) && (col(foldColName) < lit(numFolds)), col(foldColName))
+        .otherwise(
+          raise_error(
+            printf(
+              lit(s"Fold number must be in range [0, $numFolds), but got %s."), col(foldColName))
+          )
+        )
+    )
+
     (0 until numFolds).map { fold =>
-      val training = df
-        .filter(checker(foldCol) && foldCol =!= fold)
+      val training = checked
+        .filter(col(foldColName) =!= fold)
         .drop(foldColName).rdd
-      val validation = df
-        .filter(checker(foldCol) && foldCol === fold)
+      val validation = checked
+        .filter(col(foldColName) === fold)
         .drop(foldColName).rdd
       if (training.isEmpty()) {
         throw new SparkException(s"The training data at fold $fold is empty.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
index 32986a1345d6..7fcb1d2fbfbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml
 
 import org.apache.spark.SparkException
-import org.apache.spark.ml.functions.{array_to_vector, vector_to_array}
+import org.apache.spark.ml.functions._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.MLTest
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
@@ -102,4 +102,32 @@ class FunctionsSuite extends MLTest {
     val resultVec3 = df3.select(array_to_vector(col("c1"))).collect()(0)(0).asInstanceOf[Vector]
     assert(resultVec3 === Vectors.dense(Array(1.0, 2.0)))
   }
+
+  test("test get_vector") {
+    val df = Seq(
+      (Vectors.dense(1.0, 2.0, 3.0), 0),
+      (Vectors.dense(1.0, 2.0, 3.0), 1),
+      (Vectors.dense(1.0, 2.0, 3.0), 2),
+      (Vectors.sparse(3, Seq((0, -1.0))), 0),
+      (Vectors.sparse(3, Seq((0, -1.0))), 1),
+      (Vectors.sparse(3, Seq((0, -1.0))), 2)
+    ).toDF("vec", "idx")
+
+    val result = df.select(vector_get(col("vec"), col("idx"))).as[Double].collect()
+    assert(result === Array(1.0, 2.0, 3.0, -1.0, 0.0, 0.0))
+  }
+
+  test("test array_argmax") {
+    val df = Seq(
+      Tuple1.apply(Array(1.0, 2.0, 3.0)),
+      Tuple1.apply(Array(1.0, 3.0, 2.0)),
+      Tuple1.apply(Array(3.0, 2.0, 1.0)),
+      Tuple1.apply(Array(1.0, 3.0, 3.0)),
+      Tuple1.apply(Array(3.0, 3.0, 3.0)),
+      Tuple1.apply(Array.emptyDoubleArray)
+    ).toDF("arr")
+
+    val result = df.select(array_argmax(col("arr"))).as[Int].collect()
+    assert(result === Array(2, 1, 0, 1, 0, -1))
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 1a02e26b9260..c4dc3b171572 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -24,7 +24,7 @@ import scala.io.Source
 
 import com.google.common.io.Files
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
@@ -378,7 +378,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("kFold with fold column: invalid fold numbers") {
     val data = sc.parallelize(Seq(0, 1, 2), 2).toDF( "fold")
-    val err1 = intercept[SparkException] {
+    val err1 = intercept[SparkRuntimeException] {
       kFold(data, 2, "fold")(0)._1.collect()
     }
     assert(err1.getMessage.contains("Fold number must be in range [0, 2), but got 2."))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
