This is an automated email from the ASF dual-hosted git repository.

tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e62d247  [SPARK-32585][SQL] Support scala enumeration in ScalaReflection
e62d247 is described below

commit e62d24717eb774f1c7adfd0fbe39640b96bc661d
Author: ulysses <youxi...@weidian.com>
AuthorDate: Thu Oct 1 15:58:01 2020 -0400

    [SPARK-32585][SQL] Support scala enumeration in ScalaReflection
    
    ### What changes were proposed in this pull request?
    
    Add code in `ScalaReflection` to support Scala enumerations, mapping an enumeration type to string type in Spark.
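    
    The mapping relies on the standard `Enumeration` API: on the serialize path `toString` turns a value into its name, and on the deserialize path `withName` resolves the name back to the value. A minimal sketch of that round trip in plain Scala, using an enumeration like the `FooEnum` added in the tests:
    
    ```
    object FooEnum extends Enumeration {
      type FooEnum = Value
      val E1, E2 = Value
    }
    
    val serialized = FooEnum.E1.toString       // "E1" -- the name Spark stores as a string
    val deserialized = FooEnum.withName("E1")  // FooEnum.E1, recovered from the name
    assert(serialized == "E1" && deserialized == FooEnum.E1)
    ```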
    
    ### Why are the changes needed?
    
    We already support Java enums but fail on Scala enumerations; it is better to keep the behavior consistent.
    
    Here is an example.
    
    ```
    package test
    
    object TestEnum extends Enumeration {
      type TestEnum = Value
      val E1, E2, E3 = Value
    }
    import TestEnum._
    case class TestClass(i: Int, e: TestEnum)
    
    import test._
    Seq(TestClass(1, TestEnum.E1)).toDS
    ```
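    
    The snippet assumes a `SparkSession` with its implicit conversions imported; a minimal sketch of that setup (the app name here is arbitrary):
    
    ```
    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder().master("local[*]").appName("enum-demo").getOrCreate()
    import spark.implicits._  // brings toDS into scope for local Seqs
    ```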
    
    Before this PR
    ```
    Exception in thread "main" java.lang.UnsupportedOperationException: No Encoder found for test.TestEnum.TestEnum
    - field (class: "scala.Enumeration.Value", name: "e")
    - root class: "test.TestClass"
      at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)
      at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:882)
      at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:881)
    ```
    
    After this PR
    `org.apache.spark.sql.Dataset[test.TestClass] = [i: int, e: string]`
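    
    Collecting the Dataset recovers the original enumeration values, since each stored string is passed back through `TestEnum.withName`; roughly, under the setup sketched above:
    
    ```
    val ds = Seq(TestClass(1, TestEnum.E1)).toDS
    ds.collect()  // Array(TestClass(1,E1)) -- the "E1" string is mapped back to TestEnum.E1
    ```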
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can now create a Dataset from a case class that includes a Scala enumeration field.
    
    ### How was this patch tested?
    
    Added tests in `ScalaReflectionSuite`, `ExpressionEncoderSuite`, and `DatasetSuite`.
    
    Closes #29403 from ulysses-you/SPARK-32585.
    
    Authored-by: ulysses <youxi...@weidian.com>
    Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 28 ++++++++++++++++++++++
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  | 15 ++++++++++++
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 10 +++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 15 +++++++++++-
 4 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index a9c8b0b..c65e181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
 
 
 /**
@@ -377,6 +378,23 @@ object ScalaReflection extends ScalaReflection {
           expressions.Literal.create(null, ObjectType(cls)),
           newInstance
         )
+
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        // package example
+        // object Foo extends Enumeration {
+        //  type Foo = Value
+        //  val E1, E2 = Value
+        // }
+        // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
+        // we can call example.Foo.withName to deserialize string to enumeration.
+        val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
+        val cls = mirror.runtimeClass(parent)
+        StaticInvoke(
+          cls,
+          ObjectType(getClassFromType(t)),
+          "withName",
+          createDeserializerForString(path, false) :: Nil,
+          returnNullable = false)
     }
   }
 
@@ -561,6 +579,14 @@ object ScalaReflection extends ScalaReflection {
         }
         createSerializerForObject(inputObject, fields)
 
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        createSerializerForString(
+          Invoke(
+            inputObject,
+            "toString",
+            ObjectType(classOf[java.lang.String]),
+            returnNullable = false))
+
       case _ =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -738,6 +764,8 @@ object ScalaReflection extends ScalaReflection {
             val Schema(dataType, nullable) = schemaFor(fieldType)
             StructField(fieldName, dataType, nullable)
           }), nullable = true)
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        Schema(StringType, nullable = true)
       case other =>
        throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index b981a50..e8c7aed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.FooEnum.FooEnum
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
@@ -90,6 +91,13 @@ case class FooWithAnnotation(f1: String @FooAnnotation, f2: Option[String] @FooA
 
 case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
 
+object FooEnum extends Enumeration {
+  type FooEnum = Value
+  val E1, E2 = Value
+}
+
+case class FooClassWithEnum(i: Int, e: FooEnum)
+
 object TestingUDT {
   @SQLUserDefinedType(udt = classOf[NestedStructUDT])
   class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -437,4 +445,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
       StructField("f2", StringType))))
    assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    assert(serializerFor[FooClassWithEnum].dataType == StructType(Seq(
+      StructField("i", IntegerType, false),
+      StructField("e", StringType, true))))
+    assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 6a094d4..f2598a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.{Encoder, Encoders}
-import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -389,6 +389,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
    assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
   }
 
+  encodeDecodeTest((1, FooEnum.E1), "Tuple with Int and scala Enum")
+  encodeDecodeTest((null, FooEnum.E1, FooEnum.E2), "Tuple with Null and scala Enum")
+  encodeDecodeTest(Seq(FooEnum.E1, null), "Seq with scala Enum")
+  encodeDecodeTest(Map("key" -> FooEnum.E1), "Map with String key and scala Enum")
+  encodeDecodeTest(Map(FooEnum.E1 -> "value"), "Map with scala Enum key and String value")
+  encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
+  encodeDecodeTest(FooEnum.E1, "scala Enum")
+
  // Scala / Java big decimals ----------------------------------------------------------
 
   encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4923e8b..3c914ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._
 
 import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.sql.catalyst.ScroogeLikeExample
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1926,6 +1926,19 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    checkDataset(
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+
+    // test null
+    checkDataset(
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+  }
 }
 
 object AssertExecutionId {

