This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new dd771ccd2845 [SPARK-50901][ML][PYTHON][CONNECT] Support Transformer `VectorAssembler`
dd771ccd2845 is described below
commit dd771ccd28450c25159d5b4d391cd7acbe3e32da
Author: Bobby Wang <[email protected]>
AuthorDate: Wed Jan 22 09:28:55 2025 +0800
[SPARK-50901][ML][PYTHON][CONNECT] Support Transformer `VectorAssembler`
### What changes were proposed in this pull request?
This PR adds support for transformers in ML on Connect. Currently,
`VectorAssembler` is fully supported.
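
As an illustration (not part of the patch), a minimal sketch of what this
enables from a Connect client, assuming an active Spark Connect session bound
to `spark` and made-up column names:

    from pyspark.ml.feature import VectorAssembler

    df = spark.createDataFrame([(5.0, 6.0, 7.0)], ["a", "b", "c"])
    assembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    assembler.transform(df).show()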
### Why are the changes needed?
For feature parity between Spark Connect and Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes, new algorithms are supported on Connect.
### How was this patch tested?
The newly added tests pass.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49588 from wbo4958/transformer.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 3450184d332c3ff6203a200df0dddeced7ec9fd4)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 1 +
python/pyspark/ml/connect/readwrite.py | 48 +++---
python/pyspark/ml/tests/test_feature.py | 44 ++++++
.../apache/spark/sql/connect/ml/MLHandler.scala | 16 +-
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 24 +++
.../services/org.apache.spark.ml.Transformer | 1 +
.../spark/sql/connect/ml/MLBackendSuite.scala | 159 +++++++-------------
.../org/apache/spark/sql/connect/ml/MLHelper.scala | 160 +++++++++++++++++++-
.../org/apache/spark/sql/connect/ml/MLSuite.scala | 161 +++++++--------------
9 files changed, 374 insertions(+), 240 deletions(-)
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 4b029ae610d7..a25c03ed2b8e 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -17,6 +17,7 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
# So register the supported transformer here if you're trying to add a new one.
+########### Transformers
org.apache.spark.ml.feature.VectorAssembler
########### Model for loading
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 1f514c653aa0..41ae66d32108 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import cast, Type, TYPE_CHECKING
+from typing import cast, Type, TYPE_CHECKING, Union
import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -37,7 +37,7 @@ class RemoteMLWriter(MLWriter):
            raise RuntimeError("Accessing SparkContext is not supported on Connect")
def save(self, path: str) -> None:
- from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.sql.connect.session import SparkSession
@@ -57,35 +57,29 @@ class RemoteMLWriter(MLWriter):
should_overwrite=self.shouldOverwrite,
options=self.optionMap,
)
- elif isinstance(self._instance, JavaEstimator):
- estimator = cast("JavaEstimator", self._instance)
- params = serialize_ml_params(estimator, session.client)
- assert isinstance(estimator._java_obj, str)
- writer = pb2.MlCommand.Write(
- operator=pb2.MlOperator(
-                    name=estimator._java_obj, uid=estimator.uid, type=pb2.MlOperator.ESTIMATOR
- ),
- params=params,
- path=path,
- should_overwrite=self.shouldOverwrite,
- options=self.optionMap,
- )
- elif isinstance(self._instance, JavaEvaluator):
- evaluator = cast("JavaEvaluator", self._instance)
- params = serialize_ml_params(evaluator, session.client)
- assert isinstance(evaluator._java_obj, str)
+ else:
+ operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
+ if isinstance(self._instance, JavaEstimator):
+ ml_type = pb2.MlOperator.ESTIMATOR
+ operator = cast("JavaEstimator", self._instance)
+ elif isinstance(self._instance, JavaEvaluator):
+ ml_type = pb2.MlOperator.EVALUATOR
+ operator = cast("JavaEvaluator", self._instance)
+ elif isinstance(self._instance, JavaTransformer):
+ ml_type = pb2.MlOperator.TRANSFORMER
+ operator = cast("JavaTransformer", self._instance)
+ else:
+                raise NotImplementedError(f"Unsupported writing for {self._instance}")
+
+ params = serialize_ml_params(operator, session.client)
+ assert isinstance(operator._java_obj, str)
writer = pb2.MlCommand.Write(
- operator=pb2.MlOperator(
-                    name=evaluator._java_obj, uid=evaluator.uid, type=pb2.MlOperator.EVALUATOR
- ),
+                operator=pb2.MlOperator(name=operator._java_obj, uid=operator.uid, type=ml_type),
params=params,
path=path,
should_overwrite=self.shouldOverwrite,
options=self.optionMap,
)
- else:
-            raise NotImplementedError(f"Unsupported writing for {self._instance}")
-
command = pb2.Command()
command.ml_command.write.CopyFrom(writer)
session.client.execute_command(command)
@@ -98,7 +92,7 @@ class RemoteMLReader(MLReader[RL]):
def load(self, path: str) -> RL:
from pyspark.sql.connect.session import SparkSession
- from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
session = SparkSession.getActiveSession()
@@ -116,6 +110,8 @@ class RemoteMLReader(MLReader[RL]):
ml_type = pb2.MlOperator.ESTIMATOR
elif issubclass(self._clazz, JavaEvaluator):
ml_type = pb2.MlOperator.EVALUATOR
+ elif issubclass(self._clazz, JavaTransformer):
+ ml_type = pb2.MlOperator.TRANSFORMER
else:
            raise ValueError(f"Unsupported reading for {java_qualified_class_name}")
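
With the reader/writer changes above, a plain (non-model) transformer can
round-trip through save/load on Connect as well. A minimal sketch, assuming
an active Connect session; the path below is illustrative:

    from pyspark.ml.feature import VectorAssembler

    assembler = VectorAssembler(inputCols=["a", "b"], outputCol="features")
    assembler.write().overwrite().save("/tmp/vec_assembler")  # illustrative path
    loaded = VectorAssembler.load("/tmp/vec_assembler")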
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index a46fdd22e2bc..51c7a3631e1b 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -40,6 +40,7 @@ from pyspark.ml.feature import (
StringIndexerModel,
TargetEncoder,
VectorSizeHint,
+ VectorAssembler,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
@@ -48,6 +49,49 @@ from pyspark.testing.mlutils import check_params, SparkSessionTestCase
class FeatureTestsMixin:
+ def test_vector_assembler(self):
+ # Create a DataFrame
+ df = (
+ self.spark.createDataFrame(
+ [
+ (1, 5.0, 6.0, 7.0),
+ (2, 1.0, 2.0, None),
+ (3, 3.0, float("nan"), 4.0),
+ ],
+ ["index", "a", "b", "c"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("index")
+ )
+
+ # Initialize VectorAssembler
+        vec_assembler = VectorAssembler(outputCol="features").setInputCols(["a", "b", "c"])
+ output = vec_assembler.transform(df)
+ self.assertEqual(output.columns, ["index", "a", "b", "c", "features"])
+        self.assertEqual(output.head().features, Vectors.dense([5.0, 6.0, 7.0]))
+
+ # Set custom parameters and transform the DataFrame
+        params = {vec_assembler.inputCols: ["b", "a"], vec_assembler.outputCol: "vector"}
+ self.assertEqual(
+            vec_assembler.transform(df, params).head().vector, Vectors.dense([6.0, 5.0])
+ )
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="read_write") as tmp_dir:
+ vec_assembler.write().overwrite().save(tmp_dir)
+ vec_assembler2 = VectorAssembler.load(tmp_dir)
+ self.assertEqual(str(vec_assembler), str(vec_assembler2))
+
+ # Initialize a new VectorAssembler with handleInvalid="keep"
+ vec_assembler3 = VectorAssembler(
+            inputCols=["a", "b", "c"], outputCol="features", handleInvalid="keep"
+ )
+ self.assertEqual(vec_assembler3.transform(df).count(), 3)
+
+ # Update handleInvalid to "skip" and transform the DataFrame
+ vec_assembler3.setParams(handleInvalid="skip")
+ self.assertEqual(vec_assembler3.transform(df).count(), 1)
+
def test_standard_scaler(self):
df = (
self.spark.createDataFrame(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index c66a2e7004b9..ea6303937bc3 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -191,6 +191,15 @@ private[connect] object MLHandler extends Logging {
            case other => throw MlUnsupportedException(s"Evaluator $other is not writable")
}
+ case proto.MlOperator.OperatorType.TRANSFORMER =>
+ val transformer =
+            MLUtils.getTransformer(sessionHolder, writer.getOperator, params)
+ transformer match {
+            case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
+ case other =>
+              throw MlUnsupportedException(s"Transformer $other is not writable")
+ }
+
case _ =>
            throw MlUnsupportedException(s"Operator $operatorName is not supported")
}
@@ -217,12 +226,15 @@ private[connect] object MLHandler extends Logging {
.build()
    } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR ||
- operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
+ operator.getType == proto.MlOperator.OperatorType.EVALUATOR ||
+ operator.getType == proto.MlOperator.OperatorType.TRANSFORMER) {
val mlOperator = {
if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
          MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
-        } else {
+        } else if (operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
          MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
+        } else {
+          MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
}
}
proto.MlCommandResult
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 04dbb60cb1ed..34a0317f55af 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -338,6 +338,30 @@ private[ml] object MLUtils {
getInstance[Transformer](name, uid, transformers, Some(params))
}
+ /**
+ * Get the Transformer instance according to the proto information
+ *
+ * @param sessionHolder
+ * session holder to hold the Spark Connect session state
+ * @param operator
+ * MlOperator information
+ * @param params
+ * The optional parameters of the transformer
+ * @return
+ * the transformer
+ */
+ def getTransformer(
+ sessionHolder: SessionHolder,
+ operator: proto.MlOperator,
+ params: Option[proto.MlParams]): Transformer = {
+ val name = replaceOperator(sessionHolder, operator.getName)
+ val uid = operator.getUid
+
+    // Load the transformers by ServiceLoader every time
+ val transformers = loadOperators(classOf[Transformer])
+ getInstance[Transformer](name, uid, transformers, params)
+ }
+
/**
* Get the Evaluator instance according to the proto information
*
diff --git a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
index 92d3a7018054..e74b087fa8da 100644
--- a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,3 +18,4 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
# So register the supported estimator here if you're trying to add a new one.
org.apache.spark.sql.connect.ml.MyLogisticRegressionModel
+org.apache.spark.sql.connect.ml.MyVectorAssembler
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
index 7cd95f9f657d..5b2b5e6dd793 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.connect.ml
-import java.io.File
+import scala.jdk.CollectionConverters.ListHasAsScala
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
-import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.util.Utils
@@ -79,43 +78,12 @@ class MLBackendSuite extends MLHelper {
assert(model.intercept == 3.5f)
assert(model.coefficients == 4.6f)
- // read/write
- val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
- try {
-      val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
- val writeCmd = proto.MlCommand
- .newBuilder()
- .setWrite(
- proto.MlCommand.Write
- .newBuilder()
- .setOperator(getLogisticRegressionBuilder)
- .setParams(getMaxIterBuilder)
- .setPath(path)
- .setShouldOverwrite(true))
- .build()
- MLHandler.handleMlCommand(sessionHolder, writeCmd)
+      val ret = readWrite(sessionHolder, getLogisticRegressionBuilder, getMaxIterBuilder)
- val readCmd = proto.MlCommand
- .newBuilder()
- .setRead(
- proto.MlCommand.Read
- .newBuilder()
- .setOperator(getLogisticRegressionBuilder)
- .setPath(path))
- .build()
-
- val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
-      assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger
-          == 2)
-      assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
-          == 101010)
-    } finally {
-      Utils.deleteRecursively(tempDir)
-    }
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger == 2)
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger == 101010)
}
}
@@ -138,37 +106,11 @@ class MLBackendSuite extends MLHelper {
.setParams(getMaxIterBuilder))
.build()
val fitRet = MLHandler.handleMlCommand(sessionHolder, fitCommand)
- val modelId = fitRet.getOperatorInfo.getObjRef.getId
-
- // Write a model
-    val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
- val writeCmd = proto.MlCommand
- .newBuilder()
- .setWrite(
- proto.MlCommand.Write
- .newBuilder()
- .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
- .setPath(path)
- .setShouldOverwrite(true))
- .build()
- MLHandler.handleMlCommand(sessionHolder, writeCmd)
-
- // read a model
- val readCmd = proto.MlCommand
- .newBuilder()
- .setRead(
- proto.MlCommand.Read
- .newBuilder()
- .setOperator(proto.MlOperator
- .newBuilder()
-                .setName("org.apache.spark.ml.classification.LogisticRegressionModel")
- .setType(proto.MlOperator.OperatorType.MODEL))
- .setPath(path))
- .build()
- val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-
assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-
assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
+ val ret = readWrite(
+ sessionHolder,
+ fitRet.getOperatorInfo.getObjRef.getId,
+ "org.apache.spark.ml.classification.LogisticRegressionModel")
      assert(
        ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger
          == 2)
@@ -203,43 +145,54 @@ class MLBackendSuite extends MLHelper {
val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
assert(evalResult.getParam.getDouble == 1.11)
- // read/write
- val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
- try {
-      val path = new File(tempDir, Identifiable.randomUID("Evaluator")).getPath
- val writeCmd = proto.MlCommand
- .newBuilder()
- .setWrite(
- proto.MlCommand.Write
- .newBuilder()
- .setOperator(getRegressorEvaluator)
- .setParams(getMetricName)
- .setPath(path)
- .setShouldOverwrite(true))
- .build()
- MLHandler.handleMlCommand(sessionHolder, writeCmd)
+ val ret = readWrite(sessionHolder, getRegressorEvaluator, getMetricName)
- val readCmd = proto.MlCommand
- .newBuilder()
- .setRead(
- proto.MlCommand.Read
- .newBuilder()
- .setOperator(getRegressorEvaluator)
- .setPath(path))
- .build()
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
+ == "mae")
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+ == 101010)
+ }
+ }
- val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
-      assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
-          == "mae")
-      assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
-          == 101010)
- } finally {
- Utils.deleteRecursively(tempDir)
- }
+ test("ML backend: transformer works") {
+ withSparkConf(
+ Connect.CONNECT_ML_BACKEND_CLASSES.key ->
+ "org.apache.spark.sql.connect.ml.MyMlBackend") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+ val transformerRelation = proto.MlRelation
+ .newBuilder()
+ .setTransform(
+ proto.MlRelation.Transform
+ .newBuilder()
+ .setTransformer(getVectorAssembler)
+ .setParams(getVectorAssemblerParams)
+ .setInput(createMultiColumnLocalRelationProto))
+ .build()
+
+      val transRet = MLHandler.transformMLRelation(transformerRelation, sessionHolder)
+      // MyVectorAssembler overrides the transform function to append a "new" column
+      Seq("a", "b", "c", "new").foreach(n => assert(transRet.schema.names.contains(n)))
+
+      val ret = readWrite(sessionHolder, getVectorAssembler, getVectorAssemblerParams)
+ assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("handleInvalid").getString
+ == "skip")
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+ == 101010)
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap
+ .get("inputCols")
+ .getArray
+ .getElementsList
+ .asScala
+ .map(_.getString)
+ .toArray sameElements Array("a", "b", "c"))
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index ef5b8a59a58b..9383794b38dc 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -17,15 +17,16 @@
package org.apache.spark.sql.connect.ml
+import java.io.File
import java.util.Optional
import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
-import org.apache.spark.ml.param.shared.HasMaxIter
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasMaxIter, HasOutputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLReadable, MLReader}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.InternalRow
@@ -33,7 +34,10 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connect.planner.SparkConnectPlanTest
import org.apache.spark.sql.connect.plugin.MLBackendPlugin
-import org.apache.spark.sql.types.{DoubleType, FloatType, Metadata, StructField, StructType}
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, Metadata, StructField, StructType}
+import org.apache.spark.util.Utils
trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
@@ -74,6 +78,32 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
createLocalRelationProto(schema, inputRows)
}
+ def createMultiColumnLocalRelationProto: proto.Relation = {
+ val rows = Seq(InternalRow(1, 0, 3))
+ val schema = StructType(
+ Seq(
+ StructField("a", IntegerType),
+ StructField("b", IntegerType),
+ StructField("c", IntegerType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+ createLocalRelationProto(schema, inputRows)
+ }
+
+ def getLogisticRegression: proto.MlOperator.Builder =
+ proto.MlOperator
+ .newBuilder()
+ .setName("org.apache.spark.ml.classification.LogisticRegression")
+ .setUid("LogisticRegression")
+ .setType(proto.MlOperator.OperatorType.ESTIMATOR)
+
+ def getMaxIter: proto.MlParams.Builder =
+ proto.MlParams
+ .newBuilder()
+      .putParams("maxIter", proto.Expression.Literal.newBuilder().setInteger(2).build())
+
def getRegressorEvaluator: proto.MlOperator.Builder =
proto.MlOperator
.newBuilder()
@@ -96,6 +126,109 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
.addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
.build()
}
+
+ def getArrayStrings: proto.Expression.Literal =
+ proto.Expression.Literal
+ .newBuilder()
+ .setArray(
+ proto.Expression.Literal.Array
+ .newBuilder()
+ .setElementType(proto.DataType
+ .newBuilder()
+ .setString(proto.DataType.String.getDefaultInstance)
+ .build())
+ .addElements(proto.Expression.Literal.newBuilder().setString("a"))
+ .addElements(proto.Expression.Literal.newBuilder().setString("b"))
+ .addElements(proto.Expression.Literal.newBuilder().setString("c"))
+ .build())
+ .build()
+
+ def getVectorAssembler: proto.MlOperator.Builder =
+ proto.MlOperator
+ .newBuilder()
+ .setUid("vec")
+ .setName("org.apache.spark.ml.feature.VectorAssembler")
+ .setType(proto.MlOperator.OperatorType.TRANSFORMER)
+
+ def getVectorAssemblerParams: proto.MlParams.Builder =
+ proto.MlParams
+ .newBuilder()
+      .putParams("handleInvalid", proto.Expression.Literal.newBuilder().setString("skip").build())
+      .putParams("outputCol", proto.Expression.Literal.newBuilder().setString("features").build())
+ .putParams("inputCols", getArrayStrings)
+
+ def readWrite(
+ sessionHolder: SessionHolder,
+ operator: proto.MlOperator.Builder,
+ params: proto.MlParams.Builder): proto.MlCommandResult = {
+ // read/write
+ val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+ try {
+ val path = new File(tempDir, Identifiable.randomUID("test")).getPath
+ val writeCmd = proto.MlCommand
+ .newBuilder()
+ .setWrite(
+ proto.MlCommand.Write
+ .newBuilder()
+ .setOperator(operator)
+ .setParams(params)
+ .setPath(path)
+ .setShouldOverwrite(true))
+ .build()
+ MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+ val readCmd = proto.MlCommand
+ .newBuilder()
+ .setRead(
+ proto.MlCommand.Read
+ .newBuilder()
+ .setOperator(operator)
+ .setPath(path))
+ .build()
+
+ MLHandler.handleMlCommand(sessionHolder, readCmd)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ def readWrite(
+ sessionHolder: SessionHolder,
+ modelId: String,
+ clsName: String): proto.MlCommandResult = {
+ val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+ try {
+ val path = new File(tempDir, Identifiable.randomUID("test")).getPath
+ val writeCmd = proto.MlCommand
+ .newBuilder()
+ .setWrite(
+ proto.MlCommand.Write
+ .newBuilder()
+ .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+ .setPath(path)
+ .setShouldOverwrite(true))
+ .build()
+ MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+ val readCmd = proto.MlCommand
+ .newBuilder()
+ .setRead(
+ proto.MlCommand.Read
+ .newBuilder()
+ .setOperator(
+ proto.MlOperator
+ .newBuilder()
+ .setName(clsName)
+ .setType(proto.MlOperator.OperatorType.MODEL))
+ .setPath(path))
+ .build()
+
+ MLHandler.handleMlCommand(sessionHolder, readCmd)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
}
class MyMlBackend extends MLBackendPlugin {
@@ -108,6 +241,8 @@ class MyMlBackend extends MLBackendPlugin {
Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegressionModel")
case "org.apache.spark.ml.evaluation.RegressionEvaluator" =>
Optional.of("org.apache.spark.sql.connect.ml.MyRegressionEvaluator")
+ case "org.apache.spark.ml.feature.VectorAssembler" =>
+ Optional.of("org.apache.spark.sql.connect.ml.MyVectorAssembler")
case _ => Optional.empty()
}
}
@@ -117,6 +252,25 @@ trait HasFakedParam extends Params {
  final val fakeParam: IntParam = new IntParam(this, "fakeParam", "faked parameter")
}
+class MyVectorAssembler(override val uid: String)
+ extends Transformer
+ with HasInputCols
+ with HasOutputCol
+ with HasHandleInvalid
+ with HasFakedParam
+ with DefaultParamsWritable {
+ set(fakeParam, 101010)
+ private[spark] def this() = this(Identifiable.randomUID("MyVectorAssembler"))
+ override def transform(dataset: Dataset[_]): DataFrame =
+ dataset.withColumn("new", lit(1))
+ override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
+ override def transformSchema(schema: StructType): StructType = schema
+}
+
+object MyVectorAssembler extends DefaultParamsReadable[MyVectorAssembler] {
+ override def load(path: String): MyVectorAssembler = super.load(path)
+}
+
class MyRegressionEvaluator(override val uid: String)
extends Evaluator
with DefaultParamsWritable
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index aee0759d0d3a..c3ab6248be8f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.connect.ml
-import java.io.File
+import scala.jdk.CollectionConverters.ListHasAsScala
import org.apache.spark.connect.proto
import org.apache.spark.ml.classification.LogisticRegressionModel
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.service.SessionHolder
-import org.apache.spark.util.Utils
trait FakeArrayParams extends Params {
final val arrayString: StringArrayParam =
@@ -76,21 +75,7 @@ class MLSuite extends MLHelper {
.putParams("double",
proto.Expression.Literal.newBuilder().setDouble(1.0).build())
.putParams("int",
proto.Expression.Literal.newBuilder().setInteger(10).build())
.putParams("float",
proto.Expression.Literal.newBuilder().setFloat(10.0f).build())
- .putParams(
- "arrayString",
- proto.Expression.Literal
- .newBuilder()
- .setArray(
- proto.Expression.Literal.Array
- .newBuilder()
- .setElementType(proto.DataType
- .newBuilder()
- .setString(proto.DataType.String.getDefaultInstance)
- .build())
-              .addElements(proto.Expression.Literal.newBuilder().setString("hello"))
-              .addElements(proto.Expression.Literal.newBuilder().setString("world"))
- .build())
- .build())
+ .putParams("arrayString", getArrayStrings)
.putParams(
"arrayInt",
proto.Expression.Literal
@@ -127,7 +112,7 @@ class MLSuite extends MLHelper {
assert(fakedML.getFloat === 10.0)
assert(fakedML.getArrayInt === Array(1, 2))
assert(fakedML.getArrayDouble === Array(11.0, 12.0))
- assert(fakedML.getArrayString === Array("hello", "world"))
+ assert(fakedML.getArrayString === Array("a", "b", "c"))
assert(fakedML.getBoolean === true)
assert(fakedML.getDouble === 1.0)
}
@@ -139,29 +124,21 @@ class MLSuite extends MLHelper {
proto.MlCommand.Fit
.newBuilder()
.setDataset(createLocalRelationProto)
- .setEstimator(
- proto.MlOperator
- .newBuilder()
- .setName("org.apache.spark.ml.classification.LogisticRegression")
- .setUid("LogisticRegression")
- .setType(proto.MlOperator.OperatorType.ESTIMATOR))
- .setParams(
- proto.MlParams
- .newBuilder()
- .putParams(
- "maxIter",
- proto.Expression.Literal
- .newBuilder()
- .setInteger(2)
- .build())))
+ .setEstimator(getLogisticRegression)
+ .setParams(getMaxIter))
.build()
val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
fitResult.getOperatorInfo.getObjRef.getId
}
+ // Estimator/Model works
test("LogisticRegression works") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ // estimator read/write
+ val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter)
+    assert(ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger == 2)
+
def verifyModel(modelId: String, hasSummary: Boolean = false): Unit = {
val model = sessionHolder.mlCache.get(modelId)
// Model is cached
@@ -248,48 +225,15 @@ class MLSuite extends MLHelper {
}
}
- try {
- val modelId = trainLogisticRegressionModel(sessionHolder)
-
- verifyModel(modelId, true)
-
- // read/write
- val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
- try {
-      val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
- val writeCmd = proto.MlCommand
- .newBuilder()
- .setWrite(
- proto.MlCommand.Write
- .newBuilder()
- .setPath(path)
- .setObjRef(proto.ObjectRef.newBuilder().setId(modelId)))
- .build()
- MLHandler.handleMlCommand(sessionHolder, writeCmd)
-
- val readCmd = proto.MlCommand
- .newBuilder()
- .setRead(
- proto.MlCommand.Read
- .newBuilder()
- .setOperator(
- proto.MlOperator
- .newBuilder()
-                  .setName("org.apache.spark.ml.classification.LogisticRegressionModel")
- .setType(proto.MlOperator.OperatorType.MODEL))
- .setPath(path))
- .build()
-
- val readResult = MLHandler.handleMlCommand(sessionHolder, readCmd)
- verifyModel(readResult.getOperatorInfo.getObjRef.getId)
-
- } finally {
- Utils.deleteRecursively(tempDir)
- }
-
- } finally {
- sessionHolder.mlCache.clear()
- }
+ val modelId = trainLogisticRegressionModel(sessionHolder)
+ verifyModel(modelId, hasSummary = true)
+
+ // model read/write
+ val ret1 = readWrite(
+ sessionHolder,
+ modelId,
+ "org.apache.spark.ml.classification.LogisticRegressionModel")
+ verifyModel(ret1.getOperatorInfo.getObjRef.getId)
}
test("Exception: Unsupported ML operator") {
@@ -365,37 +309,42 @@ class MLSuite extends MLHelper {
evalResult.getParam.getDouble > 2.841 &&
evalResult.getParam.getDouble < 2.843)
- // read/write
- val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
- try {
-      val path = new File(tempDir, Identifiable.randomUID("RegressionEvaluator")).getPath
- val writeCmd = proto.MlCommand
- .newBuilder()
- .setWrite(
- proto.MlCommand.Write
- .newBuilder()
- .setOperator(getRegressorEvaluator)
- .setParams(getMetricName)
- .setPath(path)
- .setShouldOverwrite(true))
- .build()
- MLHandler.handleMlCommand(sessionHolder, writeCmd)
+ val ret = readWrite(sessionHolder, getRegressorEvaluator, getMetricName)
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString ==
+ "mae")
+ }
- val readCmd = proto.MlCommand
- .newBuilder()
- .setRead(
- proto.MlCommand.Read
- .newBuilder()
- .setOperator(getRegressorEvaluator)
- .setPath(path))
- .build()
+ // Transformer works
+ test("VectorAssembler works") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
- val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
- assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString ==
- "mae")
- } finally {
- Utils.deleteRecursively(tempDir)
- }
+ val transformerRelation = proto.MlRelation
+ .newBuilder()
+ .setTransform(
+ proto.MlRelation.Transform
+ .newBuilder()
+ .setTransformer(getVectorAssembler)
+ .setParams(getVectorAssemblerParams)
+ .setInput(createMultiColumnLocalRelationProto))
+ .build()
+
+      val transRet = MLHandler.transformMLRelation(transformerRelation, sessionHolder)
+      Seq("a", "b", "c", "features").foreach(n => assert(transRet.schema.names.contains(n)))
+ assert(transRet.schema("features").dataType.isInstanceOf[VectorUDT])
+ val rows = transRet.collect()
+ assert(rows.mkString(",") === "[1,0,3,[1.0,0.0,3.0]]")
+
+      val ret = readWrite(sessionHolder, getVectorAssembler, getVectorAssemblerParams)
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("outputCol").getString == "features")
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("handleInvalid").getString == "skip")
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap
+ .get("inputCols")
+ .getArray
+ .getElementsList
+ .asScala
+ .map(_.getString)
+ .toArray sameElements Array("a", "b", "c"))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]