This is an automated email from the ASF dual-hosted git repository.

tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e62d247  [SPARK-32585][SQL] Support scala enumeration in ScalaReflection
e62d247 is described below

commit e62d24717eb774f1c7adfd0fbe39640b96bc661d
Author: ulysses <youxi...@weidian.com>
AuthorDate: Thu Oct 1 15:58:01 2020 -0400

[SPARK-32585][SQL] Support scala enumeration in ScalaReflection

### What changes were proposed in this pull request?

Add code in `ScalaReflection` to support Scala enumerations, mapping the enumeration type to Spark's string type.

### Why are the changes needed?

We already support Java enums but fail on Scala enumerations; it is better to keep the behavior consistent. Here is an example:

```
package test

object TestEnum extends Enumeration {
  type TestEnum = Value
  val E1, E2, E3 = Value
}

import TestEnum._
case class TestClass(i: Int, e: TestEnum)

import test._
Seq(TestClass(1, TestEnum.E1)).toDS
```

Before this PR:

```
Exception in thread "main" java.lang.UnsupportedOperationException: No Encoder found for test.TestEnum.TestEnum
- field (class: "scala.Enumeration.Value", name: "e")
- root class: "test.TestClass"
  at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)
  at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
  at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:882)
  at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:881)
```

After this PR:

`org.apache.spark.sql.Dataset[test.TestClass] = [i: int, e: string]`

### Does this PR introduce _any_ user-facing change?

Yes, users can now create a Dataset from a case class that includes a Scala enumeration field.

### How was this patch tested?

Added tests.

Closes #29403 from ulysses-you/SPARK-32585.

Authored-by: ulysses <youxi...@weidian.com>
Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>
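For orientation before the diff: the patch encodes an `Enumeration#Value` as its `toString` and decodes it by calling `withName` on the enclosing enumeration object. Below is a minimal, self-contained sketch of that round-trip in plain Scala; the `Color` enumeration is hypothetical and not part of the patch:

```
// Hypothetical Color enumeration; toString/withName are the point here.
object Color extends Enumeration {
  type Color = Value
  val Red, Green, Blue = Value
}

object RoundTrip extends App {
  // Serializer side: the enum value becomes the plain string Spark stores.
  val stored: String = Color.Red.toString            // "Red"
  // Deserializer side: the enclosing object turns the string back into a value.
  val restored: Color.Color = Color.withName(stored) // Color.Red
  assert(restored == Color.Red)
  // Note: withName throws NoSuchElementException for an unknown name.
}
```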
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 28 ++++++++++++++++++++++
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  | 15 ++++++++++++
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 10 +++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 15 +++++++++++-
 4 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index a9c8b0b..c65e181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils


 /**
@@ -377,6 +378,23 @@ object ScalaReflection extends ScalaReflection {
           expressions.Literal.create(null, ObjectType(cls)),
           newInstance
         )
+
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        // package example
+        // object Foo extends Enumeration {
+        //   type Foo = Value
+        //   val E1, E2 = Value
+        // }
+        // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
+        // we can call example.Foo.withName to deserialize string to enumeration.
+        val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
+        val cls = mirror.runtimeClass(parent)
+        StaticInvoke(
+          cls,
+          ObjectType(getClassFromType(t)),
+          "withName",
+          createDeserializerForString(path, false) :: Nil,
+          returnNullable = false)
     }
   }

@@ -561,6 +579,14 @@ object ScalaReflection extends ScalaReflection {
         }
         createSerializerForObject(inputObject, fields)

+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        createSerializerForString(
+          Invoke(
+            inputObject,
+            "toString",
+            ObjectType(classOf[java.lang.String]),
+            returnNullable = false))
+
       case _ =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -738,6 +764,8 @@ object ScalaReflection extends ScalaReflection {
             val Schema(dataType, nullable) = schemaFor(fieldType)
             StructField(fieldName, dataType, nullable)
           }), nullable = true)
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        Schema(StringType, nullable = true)
       case other =>
         throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }
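The subtle step in the deserializer above is recovering the enclosing enumeration object from the Value type's prefix (`TypeRef.pre`). Here is a small standalone sketch of that trick, assuming a hypothetical `Fruit` enumeration and scala-reflect on the classpath; the patch feeds the resulting class into `StaticInvoke`, whereas this sketch simply calls `withName` through a module mirror:

```
import scala.reflect.runtime.universe._

// Hypothetical enumeration standing in for a user-defined type.
object Fruit extends Enumeration {
  type Fruit = Value
  val Apple, Banana = Value
}

object PrefixDemo extends App {
  // The full name of the value type is "Fruit.Fruit"; its *prefix* is the
  // singleton type of the enclosing object, Fruit.type, which owns withName.
  val tpe = typeOf[Fruit.Fruit]
  val pre = tpe.asInstanceOf[TypeRef].pre

  // Recover the Fruit singleton through a module mirror and deserialize.
  val mirror = runtimeMirror(getClass.getClassLoader)
  val module = mirror.reflectModule(pre.termSymbol.asModule).instance
  val restored = module.asInstanceOf[Enumeration].withName("Apple")
  assert(restored == Fruit.Apple)
}
```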
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index b981a50..e8c7aed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import scala.reflect.runtime.universe.TypeTag

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.FooEnum.FooEnum
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
@@ -90,6 +91,13 @@ case class FooWithAnnotation(f1: String @FooAnnotation, f2: Option[String] @FooA

 case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)

+object FooEnum extends Enumeration {
+  type FooEnum = Value
+  val E1, E2 = Value
+}
+
+case class FooClassWithEnum(i: Int, e: FooEnum)
+
 object TestingUDT {
   @SQLUserDefinedType(udt = classOf[NestedStructUDT])
   class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -437,4 +445,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
         StructField("f2", StringType))))
     assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    assert(serializerFor[FooClassWithEnum].dataType == StructType(Seq(
+      StructField("i", IntegerType, false),
+      StructField("e", StringType, true))))
+    assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 6a094d4..f2598a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag

 import org.apache.spark.sql.{Encoder, Encoders}
-import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -389,6 +389,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
   }

+  encodeDecodeTest((1, FooEnum.E1), "Tuple with Int and scala Enum")
+  encodeDecodeTest((null, FooEnum.E1, FooEnum.E2), "Tuple with Null and scala Enum")
+  encodeDecodeTest(Seq(FooEnum.E1, null), "Seq with scala Enum")
+  encodeDecodeTest(Map("key" -> FooEnum.E1), "Map with String key and scala Enum")
+  encodeDecodeTest(Map(FooEnum.E1 -> "value"), "Map with scala Enum key and String value")
+  encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
+  encodeDecodeTest(FooEnum.E1, "scala Enum")
+
   // Scala / Java big decimals ----------------------------------------------------------

   encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4923e8b..3c914ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._

 import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.sql.catalyst.ScroogeLikeExample
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1926,6 +1926,19 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    checkDataset(
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+
+    // test null
+    checkDataset(
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+  }
 }

 object AssertExecutionId {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org