Repository: spark
Updated Branches:
  refs/heads/master 7a75ee1c9 -> 70d495dce


[SPARK-18624][SQL] Implicit cast ArrayType(InternalType)

## What changes were proposed in this pull request?

Currently `ImplicitTypeCasts` doesn't handle casts between `ArrayType`s, which
is inconvenient. This change adds a rule that enables implicit casting from
`ArrayType(InternalType)` to `ArrayType(newInternalType)`.

Goals:
1. Add a rule to `ImplicitTypeCasts` to enable casting between `ArrayType`s;
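2. Simplify `Percentile` and `ApproximatePercentile` by declaring their percentage input as `ArrayType(DoubleType)`, so the new rule casts numeric arrays to double arrays and the hand-written per-element conversion can be removed.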

## How was this patch tested?

Updated test cases in `TypeCoercionSuite`.

Author: jiangxingbo <[email protected]>

Closes #16057 from jiangxb1987/implicit-cast-complex-types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70d495dc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70d495dc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70d495dc

Branch: refs/heads/master
Commit: 70d495dcecce8617b7099fc599fe7c43d7eae66e
Parents: 7a75ee1
Author: jiangxingbo <[email protected]>
Authored: Mon Dec 19 21:20:47 2016 +0100
Committer: Herman van Hovell <[email protected]>
Committed: Mon Dec 19 21:20:47 2016 +0100

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 57 +++++++++++++-------
 .../spark/sql/catalyst/expressions/Cast.scala   |  6 +--
 .../aggregate/ApproximatePercentile.scala       | 19 +++----
 .../expressions/aggregate/Percentile.scala      | 14 ++---
 .../catalyst/analysis/TypeCoercionSuite.scala   | 45 ++++++++++++++--
 5 files changed, 92 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 6662a9e..cd73f9c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -673,48 +673,69 @@ object TypeCoercion {
      * If the expression has an incompatible type that cannot be implicitly 
cast, return None.
      */
     def implicitCast(e: Expression, expectedType: AbstractDataType): 
Option[Expression] = {
-      val inType = e.dataType
+      implicitCast(e.dataType, expectedType).map { dt =>
+        if (dt == e.dataType) e else Cast(e, dt)
+      }
+    }
 
+    private def implicitCast(inType: DataType, expectedType: 
AbstractDataType): Option[DataType] = {
       // Note that ret is nullable to avoid typing a lot of Some(...) in this 
local scope.
       // We wrap immediately an Option after this.
-      @Nullable val ret: Expression = (inType, expectedType) match {
-
+      @Nullable val ret: DataType = (inType, expectedType) match {
         // If the expected type is already a parent of the input type, no need 
to cast.
-        case _ if expectedType.acceptsType(inType) => e
+        case _ if expectedType.acceptsType(inType) => inType
 
         // Cast null type (usually from null literals) into target types
-        case (NullType, target) => Cast(e, target.defaultConcreteType)
+        case (NullType, target) => target.defaultConcreteType
 
         // If the function accepts any numeric type and the input is a string, 
we follow the hive
         // convention and cast that input into a double
-        case (StringType, NumericType) => Cast(e, 
NumericType.defaultConcreteType)
+        case (StringType, NumericType) => NumericType.defaultConcreteType
 
         // Implicit cast among numeric types. When we reach here, input type 
is not acceptable.
 
         // If input is a numeric type but not decimal, and we expect a decimal 
type,
         // cast the input to decimal.
-        case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
+        case (d: NumericType, DecimalType) => DecimalType.forType(d)
         // For any other numeric types, implicitly cast to each other, e.g. 
long -> int, int -> long
-        case (_: NumericType, target: NumericType) => Cast(e, target)
+        case (_: NumericType, target: NumericType) => target
 
         // Implicit cast between date time types
-        case (DateType, TimestampType) => Cast(e, TimestampType)
-        case (TimestampType, DateType) => Cast(e, DateType)
+        case (DateType, TimestampType) => TimestampType
+        case (TimestampType, DateType) => DateType
 
         // Implicit cast from/to string
-        case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
-        case (StringType, target: NumericType) => Cast(e, target)
-        case (StringType, DateType) => Cast(e, DateType)
-        case (StringType, TimestampType) => Cast(e, TimestampType)
-        case (StringType, BinaryType) => Cast(e, BinaryType)
+        case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
+        case (StringType, target: NumericType) => target
+        case (StringType, DateType) => DateType
+        case (StringType, TimestampType) => TimestampType
+        case (StringType, BinaryType) => BinaryType
         // Cast any atomic type to string.
-        case (any: AtomicType, StringType) if any != StringType => Cast(e, 
StringType)
+        case (any: AtomicType, StringType) if any != StringType => StringType
 
         // When we reach here, input type is not acceptable for any types in 
this type collection,
         // try to find the first one we can implicitly cast.
-        case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, 
_)).headOption.orNull
+        case (_, TypeCollection(types)) =>
+          types.flatMap(implicitCast(inType, _)).headOption.orNull
+
+        // Implicit cast between array types.
+        //
+        // Compare the nullabilities of the from type and the to type, check 
whether the cast of
+        // the nullability is resolvable by the following rules:
+        // 1. If the nullability of the to type is true, the cast is always 
allowed;
+        // 2. If the nullability of the to type is false, and the nullability 
of the from type is
+        // true, the cast is never allowed;
+        // 3. If the nullabilities of both the from type and the to type are 
false, the cast is
+        // allowed only when Cast.forceNullable(fromType, toType) is false.
+        case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
+          implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
+
+        case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) 
=> null
+
+        case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
+            if !Cast.forceNullable(fromType, toType) =>
+          implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
 
-        // Else, just return the same input expression
         case _ => null
       }
       Option(ret)

http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4db1ae6..741730e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -89,9 +89,7 @@ object Cast {
     case _ => false
   }
 
-  private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
-
-  private def forceNullable(from: DataType, to: DataType) = (from, to) match {
+  def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
     case (NullType, _) => true
     case (_, _) if from == to => false
 
@@ -110,6 +108,8 @@ object Cast {
     case (_: FractionalType, _: IntegralType) => true  // NaN, infinity
     case _ => false
   }
+
+  private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
 }
 
 /** Cast the child expression to the target data type. */

http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 01792ae..0e71442 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -86,23 +86,16 @@ case class ApproximatePercentile(
   private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
+    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), 
IntegerType)
   }
 
   // Mark as lazy so that percentageExpression is not evaluated during tree 
transformation.
-  private lazy val (returnPercentileArray: Boolean, percentages: 
Array[Double]) = {
-    (percentageExpression.dataType, percentageExpression.eval()) match {
+  private lazy val (returnPercentileArray: Boolean, percentages: 
Array[Double]) =
+    percentageExpression.eval() match {
       // Rule ImplicitTypeCasts can cast other numeric types to double
-      case (_, num: Double) => (false, Array(num))
-      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
-         val numericArray = arrayData.toObjectArray(baseType)
-        (true, numericArray.map { x =>
-          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
-        })
-      case other =>
-        throw new AnalysisException(s"Invalid data type ${other._1} for 
parameter percentage")
+      case num: Double => (false, Array(num))
+      case arrayData: ArrayData => (true, arrayData.toDoubleArray())
     }
-  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
     val defaultCheck = super.checkInputDataTypes()
@@ -162,7 +155,7 @@ case class ApproximatePercentile(
   override def nullable: Boolean = true
 
   override def dataType: DataType = {
-    if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
+    if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
   }
 
   override def prettyName: String = "percentile_approx"

http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index b51b553..2f68195 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -77,15 +77,9 @@ case class Percentile(
   private lazy val returnPercentileArray = 
percentageExpression.dataType.isInstanceOf[ArrayType]
 
   @transient
-  private lazy val percentages =
-    (percentageExpression.dataType, percentageExpression.eval()) match {
-      case (_, num: Double) => Seq(num)
-      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
-        val numericArray = arrayData.toObjectArray(baseType)
-        numericArray.map { x =>
-          
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq
-      case other =>
-        throw new AnalysisException(s"Invalid data type ${other._1} for 
parameter percentages")
+  private lazy val percentages = percentageExpression.eval() match {
+      case num: Double => Seq(num)
+      case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
   }
 
   override def children: Seq[Expression] = child :: percentageExpression :: Nil
@@ -99,7 +93,7 @@ case class Percentile(
   }
 
   override def inputTypes: Seq[AbstractDataType] = 
percentageExpression.dataType match {
-    case _: ArrayType => Seq(NumericType, ArrayType)
+    case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
     case _ => Seq(NumericType, DoubleType)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 590c9d5..dbb1e3e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -57,14 +57,43 @@ class TypeCoercionSuite extends PlanTest {
   // scalastyle:on line.size.limit
 
   private def shouldCast(from: DataType, to: AbstractDataType, expected: 
DataType): Unit = {
-    val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, 
from), to)
-    assert(got.map(_.dataType) == Option(expected),
+    // Check default value
+    val castDefault = 
TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
+    assert(DataType.equalsIgnoreCompatibleNullability(
+      castDefault.map(_.dataType).getOrElse(null), expected),
+      s"Failed to cast $from to $to")
+
+    // Check null value
+    val castNull = 
TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
+    assert(DataType.equalsIgnoreCaseAndNullability(
+      castNull.map(_.dataType).getOrElse(null), expected),
       s"Failed to cast $from to $to")
   }
 
   private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
-    val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, 
from), to)
-    assert(got.isEmpty, s"Should not be able to cast $from to $to, but got 
$got")
+    // Check default value
+    val castDefault = 
TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
+    assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but 
got $castDefault")
+
+    // Check null value
+    val castNull = 
TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
+    assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but 
got $castNull")
+  }
+
+  private def default(dataType: DataType): Expression = dataType match {
+    case ArrayType(internalType: DataType, _) =>
+      CreateArray(Seq(Literal.default(internalType)))
+    case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
+      CreateMap(Seq(Literal.default(keyDataType), 
Literal.default(valueDataType)))
+    case _ => Literal.default(dataType)
+  }
+
+  private def createNull(dataType: DataType): Expression = dataType match {
+    case ArrayType(internalType: DataType, _) =>
+      CreateArray(Seq(Literal.create(null, internalType)))
+    case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
+      CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, 
valueDataType)))
+    case _ => Literal.create(null, dataType)
   }
 
   val integralTypes: Seq[DataType] =
@@ -196,7 +225,13 @@ class TypeCoercionSuite extends PlanTest {
 
   test("implicit type cast - ArrayType(StringType)") {
     val checkedType = ArrayType(StringType)
-    checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
+    val nonCastableTypes =
+      complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
+    checkTypeCasting(checkedType,
+      castableTypes = 
allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_)))
+    nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _))
+    shouldNotCast(ArrayType(DoubleType, containsNull = false),
+      ArrayType(LongType, containsNull = false))
     shouldNotCast(checkedType, DecimalType)
     shouldNotCast(checkedType, NumericType)
     shouldNotCast(checkedType, IntegralType)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to