This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new abe837d5df33 [SPARK-50844][ML][CONNECT] Make model be loaded by 
ServiceLoader when loading
abe837d5df33 is described below

commit abe837d5df33d93da07f9786aff81b57c0f1d9ab
Author: Bobby Wang <[email protected]>
AuthorDate: Tue Jan 21 09:47:57 2025 +0800

    [SPARK-50844][ML][CONNECT] Make model be loaded by ServiceLoader when 
loading
    
    ### What changes were proposed in this pull request?
    
    Currently ml connect discovers Estimators, Evaluators and Transformers
    (not including Models) via ServiceLoader, which could make it more secure. This PR
    adds public no-arg constructors to some Models so that models can also be
    loaded by ServiceLoader.
    
    ### Why are the changes needed?
    
    This PR is a follow-up of https://github.com/apache/spark/pull/49503; it is
    safer to discover a Model via ServiceLoader when loading a model.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    
    The existing tests pass
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49569 from wbo4958/ml.model.serviceloader.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 79347e08e2e77d4f85faf65d7ca2016f7d8ffa99)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       | 20 +++++++
 .../ml/classification/DecisionTreeClassifier.scala |  4 ++
 .../spark/ml/classification/GBTClassifier.scala    |  5 ++
 .../ml/classification/LogisticRegression.scala     |  4 ++
 .../ml/classification/RandomForestClassifier.scala |  4 ++
 .../spark/ml/clustering/BisectingKMeans.scala      |  4 ++
 .../org/apache/spark/ml/clustering/KMeans.scala    |  5 ++
 .../org/apache/spark/ml/recommendation/ALS.scala   |  4 ++
 .../ml/regression/DecisionTreeRegressor.scala      |  4 ++
 .../apache/spark/ml/regression/GBTRegressor.scala  |  5 ++
 .../spark/ml/regression/LinearRegression.scala     |  4 ++
 .../ml/regression/RandomForestRegressor.scala      |  4 ++
 .../main/scala/org/apache/spark/ml/tree/Node.scala |  5 ++
 .../apache/spark/sql/connect/ml/MLHandler.scala    | 14 +++--
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  | 62 +++++++++++++++++-----
 .../services/org.apache.spark.ml.Transformer       |  6 +--
 .../org/apache/spark/sql/connect/ml/MLHelper.scala |  2 +
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  |  6 +--
 18 files changed, 138 insertions(+), 24 deletions(-)

diff --git 
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer 
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 004ec8aeff8c..392115be98ba 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,3 +18,23 @@
 # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml 
non-model transformer.
 # So register the supported transformer here if you're trying to add a new one.
 org.apache.spark.ml.feature.VectorAssembler
+
+# Models that Spark Connect ML can load via ServiceLoader
+# classification
+org.apache.spark.ml.classification.LogisticRegressionModel
+org.apache.spark.ml.classification.DecisionTreeClassificationModel
+org.apache.spark.ml.classification.RandomForestClassificationModel
+org.apache.spark.ml.classification.GBTClassificationModel
+
+# regression
+org.apache.spark.ml.regression.LinearRegressionModel
+org.apache.spark.ml.regression.DecisionTreeRegressionModel
+org.apache.spark.ml.regression.RandomForestRegressionModel
+org.apache.spark.ml.regression.GBTRegressionModel
+
+# clustering
+org.apache.spark.ml.clustering.KMeansModel
+org.apache.spark.ml.clustering.BisectingKMeansModel
+
+# recommendation
+org.apache.spark.ml.recommendation.ALSModel
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index c5f1d7f39b6b..761741e7f42d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -192,6 +192,10 @@ class DecisionTreeClassificationModel private[ml] (
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Node.dummyNode, 0, 0)
+
   override def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3910beda3d0a..8ed52d5e09e0 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -272,6 +272,11 @@ class GBTClassificationModel private[ml](
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], 
_treeWeights: Array[Double]) =
     this(uid, _trees, _treeWeights, -1, 2)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("gbtc"),
+    Array(new DecisionTreeRegressionModel), Array(0.0))
+
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 745cb61bb7aa..7432473e4aa9 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1076,6 +1076,10 @@ class LogisticRegressionModel private[spark] (
     this(uid, new DenseMatrix(1, coefficients.size, coefficients.toArray, 
isTransposed = true),
       Vectors.dense(intercept), 2, isMultinomial = false)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("logreg"), 
Vectors.zeros(0), 0)
+
   /**
    * A vector of model coefficients for "binomial" logistic regression. If 
this model was trained
    * using the "multinomial" family then an exception is thrown.
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 9295425f9d6b..0833ad0d402b 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -255,6 +255,10 @@ class RandomForestClassificationModel private[ml] (
       numClasses: Int) =
     this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Array(new DecisionTreeClassificationModel), 0, 
0)
+
   @Since("1.4.0")
   override def trees: Array[DecisionTreeClassificationModel] = _trees
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index b4f1565362b0..d0e5cb42c41c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -96,6 +96,10 @@ class BisectingKMeansModel private[ml] (
   extends Model[BisectingKMeansModel] with BisectingKMeansParams with 
MLWritable
   with HasTrainingSummary[BisectingKMeansSummary] {
 
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("bisecting-kmeans"),
+    new MLlibBisectingKMeansModel(null))
+
   @Since("3.0.0")
   lazy val numFeatures: Int = parentModel.clusterCenters.head.size
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 50fb18bb620a..17d34a277af2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -138,6 +138,11 @@ class KMeansModel private[ml] (
   extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
     with HasTrainingSummary[KMeansSummary] {
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("kmeans"),
+    new MLlibKMeansModel(clusterCenters = null))
+
   @Since("3.0.0")
   lazy val numFeatures: Int = parentModel.clusterCenters.head.size
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 5899bf891ec9..4120e16794a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -280,6 +280,10 @@ class ALSModel private[ml] (
     @transient val itemFactors: DataFrame)
   extends Model[ALSModel] with ALSModelParams with MLWritable {
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("als"), 0, null, null)
+
   /** @group setParam */
   @Since("1.4.0")
   def setUserCol(value: String): this.type = set(userCol, value)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index dace99f214b1..2c692d33a38d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -187,6 +187,10 @@ class DecisionTreeRegressionModel private[ml] (
   private[ml] def this(rootNode: Node, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Node.dummyNode, 0)
+
   override def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 732bfcbd671e..c2c672a7fa60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -242,6 +242,11 @@ class GBTRegressionModel private[ml](
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], 
_treeWeights: Array[Double]) =
     this(uid, _trees, _treeWeights, -1)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("gbtr"),
+    Array(new DecisionTreeRegressionModel), Array(0.0))
+
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 4f74dd734e8f..2afcb52dbb4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -702,6 +702,10 @@ class LinearRegressionModel private[ml] (
   private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
     this(uid, coefficients, intercept, 1.0)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Identifiable.randomUID("linReg"), 
Vectors.zeros(0), 0.0, 0.0)
+
   override val numFeatures: Int = coefficients.size
 
   /**
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 4135afb5ed0b..b0409c916a05 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -212,6 +212,10 @@ class RandomForestRegressionModel private[ml] (
   private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: 
Int) =
     this(Identifiable.randomUID("rfr"), trees, numFeatures)
 
+  // For ml connect only
+  @Since("4.0.0")
+  private[ml] def this() = this(Array(new DecisionTreeRegressionModel), 0)
+
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index dfa4961d9ffb..697d98953839 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -105,6 +105,11 @@ private[ml] object Node {
         split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
     }
   }
+
+  // A dummy node used for ml connect only
+  val dummyNode: Node = {
+    new LeafNode(0.0, 0.0, ImpurityCalculator.getCalculator("gini", 
Array.empty, 0))
+  }
 }
 
 /**
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index cc89079aeca3..c66a2e7004b9 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -204,7 +204,7 @@ private[connect] object MLHandler extends Logging {
         val path = mlCommand.getRead.getPath
 
         if (operator.getType == proto.MlOperator.OperatorType.MODEL) {
-          val model = MLUtils.load(sessionHolder, name, 
path).asInstanceOf[Model[_]]
+          val model = MLUtils.loadTransformer(sessionHolder, name, path)
           val id = mlCache.register(model)
           proto.MlCommandResult
             .newBuilder()
@@ -218,15 +218,21 @@ private[connect] object MLHandler extends Logging {
 
         } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR 
||
           operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
-          val operator = MLUtils.load(sessionHolder, name, 
path).asInstanceOf[Params]
+          val mlOperator = {
+            if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
+              MLUtils.loadEstimator(sessionHolder, name, 
path).asInstanceOf[Params]
+            } else {
+              MLUtils.loadEvaluator(sessionHolder, name, 
path).asInstanceOf[Params]
+            }
+          }
           proto.MlCommandResult
             .newBuilder()
             .setOperatorInfo(
               proto.MlCommandResult.MlOperatorInfo
                 .newBuilder()
                 .setName(name)
-                .setUid(operator.uid)
-                .setParams(Serializer.serializeParams(operator)))
+                .setUid(mlOperator.uid)
+                .setParams(Serializer.serializeParams(mlOperator)))
             .build()
         } else {
           throw MlUnsupportedException(s"${operator.getType} read not 
supported")
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 000a01c232bd..86dd013b9d98 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -29,13 +29,13 @@ import org.apache.spark.ml.{Estimator, Transformer}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
-import org.apache.spark.ml.util.{MLReadable, MLWritable}
+import org.apache.spark.ml.util.MLWritable
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
LiteralValueProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SessionHolder
-import org.apache.spark.util.{SparkClassUtils, Utils}
+import org.apache.spark.util.Utils
 
 private[ml] object MLUtils {
 
@@ -352,28 +352,62 @@ private[ml] object MLUtils {
   }
 
   /**
-   * Call "load" function on the ML operator given the operator name
+   * Load an ML component (Estimator, Transformer, or Evaluator) from the 
given path.
    *
+   * @param sessionHolder
+   *   the session holder
    * @param className
    *   the ML operator name
    * @param path
    *   the path to be loaded
+   * @param operatorClass
+   *   the class type of the ML operator (Estimator, Transformer, or Evaluator)
+   * @tparam T
+   *   the type of the ML operator
    * @return
-   *   the ML instance
+   *   the instance of the ML operator
    */
-  def load(sessionHolder: SessionHolder, className: String, path: String): 
Object = {
+  private def loadOperator[T](
+      sessionHolder: SessionHolder,
+      className: String,
+      path: String,
+      operatorClass: Class[T]): T = {
     val name = replaceOperator(sessionHolder, className)
-
-    // It's the companion object of the corresponding spark operators to load.
-    val objectCls = SparkClassUtils.classForName(name + "$")
-    val mlReadableClass = classOf[MLReadable[_]]
-    // Make sure it is implementing MLReadable
-    if (!mlReadableClass.isAssignableFrom(objectCls)) {
-      throw MlUnsupportedException(s"$name must implement MLReadable.")
+    val operators = loadOperators(operatorClass)
+    if (operators.isEmpty || !operators.contains(name)) {
+      throw MlUnsupportedException(s"Unsupported read for $name")
     }
+    operators(name)
+      .getMethod("load", classOf[String])
+      .invoke(null, path)
+      .asInstanceOf[T]
+  }
+
+  /**
+   * Load an estimator from the specified path.
+   */
+  def loadEstimator(
+      sessionHolder: SessionHolder,
+      className: String,
+      path: String): Estimator[_] = {
+    loadOperator(sessionHolder, className, path, classOf[Estimator[_]])
+  }
 
-    val loadedMethod = SparkClassUtils.classForName(name).getMethod("load", 
classOf[String])
-    loadedMethod.invoke(null, path)
+  /**
+   * Load a transformer from the specified path.
+   */
+  def loadTransformer(
+      sessionHolder: SessionHolder,
+      className: String,
+      path: String): Transformer = {
+    loadOperator(sessionHolder, className, path, classOf[Transformer])
+  }
+
+  /**
+   * Load an evaluator from the specified path.
+   */
+  def loadEvaluator(sessionHolder: SessionHolder, className: String, path: 
String): Evaluator = {
+    loadOperator(sessionHolder, className, path, classOf[Evaluator])
   }
 
   // Since we're using reflection way to get the attribute, in order not to
diff --git 
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer 
b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
similarity index 84%
copy from 
mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
copy to 
sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
index 004ec8aeff8c..92d3a7018054 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ 
b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -15,6 +15,6 @@
 # limitations under the License.
 #
 
-# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml 
non-model transformer.
-# So register the supported transformer here if you're trying to add a new one.
-org.apache.spark.ml.feature.VectorAssembler
+# Spark Connect ML uses ServiceLoader to find out the supported Spark ML transformers (including models).
+# So register the supported transformer here if you're trying to add a new one.
+org.apache.spark.sql.connect.ml.MyLogisticRegressionModel
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index 844e85fa03b6..ef5b8a59a58b 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -149,6 +149,8 @@ class MyLogisticRegressionModel(
     with HasFakedParam
     with DefaultParamsWritable {
 
+  private[spark] def this() = this("MyLogisticRegressionModel", 1.0f, 1.0f)
+
   def setFakeParam(v: Int): this.type = set(fakeParam, v)
 
   def setMaxIter(v: Int): this.type = set(maxIter, v)
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 49dcd7dbe9ad..aee0759d0d3a 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -322,7 +322,7 @@ class MLSuite extends MLHelper {
     }
   }
 
-  test("ML operator must implement MLReadable for loading") {
+  test("Model must be registered into ServiceLoader when loading") {
     val thrown = intercept[MlUnsupportedException] {
       val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
       val readCmd = proto.MlCommand
@@ -339,8 +339,8 @@ class MLSuite extends MLHelper {
       MLHandler.handleMlCommand(sessionHolder, readCmd)
     }
     assert(
-      
thrown.message.contains("org.apache.spark.sql.connect.ml.NotImplementingMLReadble
 " +
-        "must implement MLReadable"))
+      thrown.message.contains("Unsupported read for " +
+        "org.apache.spark.sql.connect.ml.NotImplementingMLReadble"))
   }
 
   test("RegressionEvaluator works") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to