This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cbc8892e83d1 [SPARK-50900][ML][CONNECT] Add VectorUDT and MatrixUDT to ProtoDataTypes
cbc8892e83d1 is described below

commit cbc8892e83d16fbbfc421c4759a840b45c8be0f4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 21 11:28:29 2025 +0800

    [SPARK-50900][ML][CONNECT] Add VectorUDT and MatrixUDT to ProtoDataTypes
    
    ### What changes were proposed in this pull request?
    Add VectorUDT and MatrixUDT to ProtoDataTypes
    
    ### Why are the changes needed?
    1. To avoid recreating the protobuf messages.
    2. For the two built-in UDTs, the `sqlType` field can be ignored.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is an internal change.
    
    ### How was this patch tested?
    Existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49580 from zhengruifeng/connect_proto_udt.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../connect/common/DataTypeProtoConverter.scala    | 37 ++++++++++++++--------
 .../spark/sql/connect/common/ProtoDataTypes.scala  | 24 ++++++++++++++
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 3577ca228b03..8c83ad3d1f55 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -302,20 +302,31 @@ object DataTypeProtoConverter {
 
       case udt: UserDefinedType[_] =>
         // Scala/Java UDT
-        val builder = proto.DataType.UDT.newBuilder()
-        builder
-          .setType("udt")
-          .setJvmClass(udt.getClass.getName)
-          .setSqlType(toConnectProtoType(udt.sqlType))
-
-        if (udt.pyUDT != null) {
-          builder.setPythonClass(udt.pyUDT)
-        }
+        udt.getClass.getName match {
+          // To avoid making connect-common depend on ml,
+          // we use class name to identify VectorUDT and MatrixUDT.
+          case "org.apache.spark.ml.linalg.VectorUDT" =>
+            ProtoDataTypes.VectorUDT
+
+          case "org.apache.spark.ml.linalg.MatrixUDT" =>
+            ProtoDataTypes.MatrixUDT
+
+          case className =>
+            val builder = proto.DataType.UDT.newBuilder()
+            builder
+              .setType("udt")
+              .setJvmClass(className)
+              .setSqlType(toConnectProtoType(udt.sqlType))
 
-        proto.DataType
-          .newBuilder()
-          .setUdt(builder.build())
-          .build()
+            if (udt.pyUDT != null) {
+              builder.setPythonClass(udt.pyUDT)
+            }
+
+            proto.DataType
+              .newBuilder()
+              .setUdt(builder.build())
+              .build()
+        }
 
       case _ =>
         throw InvalidPlanInput(s"Does not support convert ${t.typeName} to 
connect proto types.")
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
index 81ac069705fa..c19d7f12d69b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
@@ -109,4 +109,28 @@ private[sql] object ProtoDataTypes {
       .newBuilder()
       .setVariant(proto.DataType.Variant.getDefaultInstance)
       .build()
+
+  val VectorUDT: proto.DataType =
+    proto.DataType
+      .newBuilder()
+      .setUdt(
+        proto.DataType.UDT
+          .newBuilder()
+          .setType("udt")
+          .setJvmClass("org.apache.spark.ml.linalg.VectorUDT")
+          .setPythonClass("pyspark.ml.linalg.VectorUDT")
+          .build())
+      .build()
+
+  val MatrixUDT: proto.DataType =
+    proto.DataType
+      .newBuilder()
+      .setUdt(
+        proto.DataType.UDT
+          .newBuilder()
+          .setType("udt")
+          .setJvmClass("org.apache.spark.ml.linalg.MatrixUDT")
+          .setPythonClass("pyspark.ml.linalg.MatrixUDT")
+          .build())
+      .build()
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to