This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 7216099  [SPARK-30993][SQL][2.4] Use its sql type for UDT when 
checking the type of length (fixed/var) or mutable
7216099 is described below

commit 72160991dc524456a136bca3b0c86359f66f37c4
Author: Jungtaek Lim (HeartSaVioR) <[email protected]>
AuthorDate: Tue Mar 3 17:50:43 2020 +0800

    [SPARK-30993][SQL][2.4] Use its sql type for UDT when checking the type of 
length (fixed/var) or mutable
    
    ### What changes were proposed in this pull request?
    
    This patch fixes a bug in UnsafeRow, which fails to handle UDT 
specifically in `isFixedLength` and `isMutable`. These methods don't check the 
underlying SQL type of a UDT, so they always treat a UDT as variable-length and 
non-mutable.
    
    This causes no issue when a UDT is used to represent a complex type, but 
when a UDT represents a type that maps to a fixed-length SQL type, it opens the 
door to correctness issues, as this information sometimes decides how the value 
should be handled.
    
    We got a report on the user mailing list suspecting that mapGroupsWithState 
handles UDT incorrectly, but after some investigation the problem turned out to 
come from GenerateUnsafeRowJoiner in the shuffle phase.
    
    
https://github.com/apache/spark/blob/0e2ca11d80c3921387d7b077cb64c3a0c06b08d7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala#L32-L43
    
    Here, updating the position should not happen for a fixed-length column, 
but due to this bug, the value of a UDT whose SQL type is fixed-length would be 
modified, which actually corrupts the value.
    
    ### Why are the changes needed?
    
    Misclassifying the length type (fixed/variable) of a UDT can corrupt the 
value when the row is passed as input to GenerateUnsafeRowJoiner, which causes 
a correctness issue.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT added.
    
    Closes #27761 from HeartSaVioR/SPARK-30993-branch-2.4.
    
    Authored-by: Jungtaek Lim (HeartSaVioR) <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  |  8 +++++
 .../codegen/GenerateUnsafeRowJoinerSuite.scala     | 41 ++++++++++++++++++++-
 .../apache/spark/sql/UserDefinedTypeSuite.scala    | 42 ++++++++++++++++++++++
 3 files changed, 90 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index ee2b67a..a2440d9 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -95,6 +95,10 @@ public final class UnsafeRow extends InternalRow implements 
Externalizable, Kryo
   }
 
   public static boolean isFixedLength(DataType dt) {
+    if (dt instanceof UserDefinedType) {
+      return isFixedLength(((UserDefinedType) dt).sqlType());
+    }
+
     if (dt instanceof DecimalType) {
       return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
     } else {
@@ -103,6 +107,10 @@ public final class UnsafeRow extends InternalRow 
implements Externalizable, Kryo
   }
 
   public static boolean isMutable(DataType dt) {
+    if (dt instanceof UserDefinedType) {
+      return isMutable(((UserDefinedType) dt).sqlType());
+    }
+
     return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index 75c6bee..a5057d0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
+import java.time.{LocalDateTime, ZoneOffset}
+
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, 
UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -99,6 +101,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     testConcatOnce(N, N, variable)
   }
 
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type 
shouldn't be corrupted") {
+    val schema1 = new StructType(Array(
+      StructField("date", new WrappedDateTimeUDT),
+      StructField("s", StringType),
+      StructField("i", IntegerType)))
+    val proj1 = UnsafeProjection.create(schema1.fields.map(_.dataType))
+    val intRow1 = new GenericInternalRow(Array[Any](
+      LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      UTF8String.fromString("hello"), 1))
+
+    val schema2 = new StructType(Array(StructField("i", IntegerType)))
+    val proj2 = UnsafeProjection.create(schema2.fields.map(_.dataType))
+    val intRow2 = new GenericInternalRow(Array[Any](2))
+
+    testConcat(schema1, proj1.apply(intRow1), schema2, proj2.apply(intRow2))
+  }
+
   private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: 
Seq[DataType]): Unit = {
     for (i <- 0 until 10) {
       testConcatOnce(numFields1, numFields2, candidateTypes)
@@ -203,3 +222,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
   }
 
 }
+
+private[sql] case class WrappedDateTime(dt: LocalDateTime)
+
+private[sql] class WrappedDateTimeUDT extends UserDefinedType[WrappedDateTime] 
{
+  override def sqlType: DataType = LongType
+
+  override def serialize(obj: WrappedDateTime): Long = {
+    obj.dt.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  def deserialize(datum: Any): WrappedDateTime = datum match {
+    case value: Long =>
+      val v = LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+      WrappedDateTime(v)
+  }
+
+  override def userClass: Class[WrappedDateTime] = classOf[WrappedDateTime]
+
+  private[spark] override def asNullable: WrappedDateTimeUDT = this
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index cc8b600..4e74e92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.time.{LocalDateTime, ZoneOffset}
+
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
@@ -145,6 +147,30 @@ private[spark] class ExampleSubTypeUDT extends 
UserDefinedType[IExampleSubType]
   override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
 }
 
+private[sql] case class FooWithDate(date: LocalDateTime, s: String, i: Int)
+
+private[sql] object FooWithDate {
+  def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate = {
+    FooWithDate(b.date, a.s + b.s, a.i)
+  }
+}
+
+private[sql] class LocalDateTimeUDT extends UserDefinedType[LocalDateTime] {
+  override def sqlType: DataType = LongType
+
+  override def serialize(obj: LocalDateTime): Long = {
+    obj.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  def deserialize(datum: Any): LocalDateTime = datum match {
+    case value: Long => LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+  }
+
+  override def userClass: Class[LocalDateTime] = classOf[LocalDateTime]
+
+  private[spark] override def asNullable: LocalDateTimeUDT = this
+}
+
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with 
ParquetTest
     with ExpressionEvalHelper {
   import testImplicits._
@@ -315,4 +341,20 @@ class UserDefinedTypeSuite extends QueryTest with 
SharedSQLContext with ParquetT
     val ret = Cast(Literal(data, udt), StringType, None)
     checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)")
   }
+
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type 
shouldn't be corrupted") {
+    UDTRegistration.register(classOf[LocalDateTime].getName, 
classOf[LocalDateTimeUDT].getName)
+
+    // remove sub-second part as we only use a second-based timestamp during 
serde
+    val date = 
LocalDateTime.ofEpochSecond(LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      0, ZoneOffset.UTC)
+    val inputDS = List(FooWithDate(date, "Foo", 1), FooWithDate(date, "Foo", 
3),
+      FooWithDate(date, "Foo", 3)).toDS()
+    val agg = inputDS.groupByKey(x => x.i).mapGroups { (_, iter) =>
+      iter.reduce(FooWithDate.concatFoo)
+    }
+    val result = agg.collect()
+
+    assert(result.toSet === Set(FooWithDate(date, "FooFoo", 3), 
FooWithDate(date, "Foo", 1)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to