Repository: spark Updated Branches: refs/heads/master 7a75ee1c9 -> 70d495dce
[SPARK-18624][SQL] Implicit cast ArrayType(InternalType) ## What changes were proposed in this pull request? Currently `ImplicitTypeCasts` doesn't handle casts between `ArrayType`s, which is inconvenient. We should add a rule to enable casting from `ArrayType(InternalType)` to `ArrayType(newInternalType)`. Goals: 1. Add a rule to `ImplicitTypeCasts` to enable casting between `ArrayType`s; 2. Simplify `Percentile` and `ApproximatePercentile`. ## How was this patch tested? Updated test cases in `TypeCoercionSuite`. Author: jiangxingbo <[email protected]> Closes #16057 from jiangxb1987/implicit-cast-complex-types. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70d495dc Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70d495dc Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70d495dc Branch: refs/heads/master Commit: 70d495dcecce8617b7099fc599fe7c43d7eae66e Parents: 7a75ee1 Author: jiangxingbo <[email protected]> Authored: Mon Dec 19 21:20:47 2016 +0100 Committer: Herman van Hovell <[email protected]> Committed: Mon Dec 19 21:20:47 2016 +0100 ---------------------------------------------------------------------- .../sql/catalyst/analysis/TypeCoercion.scala | 57 +++++++++++++------- .../spark/sql/catalyst/expressions/Cast.scala | 6 +-- .../aggregate/ApproximatePercentile.scala | 19 +++---- .../expressions/aggregate/Percentile.scala | 14 ++--- .../catalyst/analysis/TypeCoercionSuite.scala | 45 ++++++++++++++-- 5 files changed, 92 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6662a9e..cd73f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -673,48 +673,69 @@ object TypeCoercion { * If the expression has an incompatible type that cannot be implicitly cast, return None. */ def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - val inType = e.dataType + implicitCast(e.dataType, expectedType).map { dt => + if (dt == e.dataType) e else Cast(e, dt) + } + } + private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = { // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. // We wrap immediately an Option after this. - @Nullable val ret: Expression = (inType, expectedType) match { - + @Nullable val ret: DataType = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.acceptsType(inType) => e + case _ if expectedType.acceptsType(inType) => inType // Cast null type (usually from null literals) into target types - case (NullType, target) => Cast(e, target.defaultConcreteType) + case (NullType, target) => target.defaultConcreteType // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double - case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) + case (StringType, NumericType) => NumericType.defaultConcreteType // Implicit cast among numeric types. When we reach here, input type is not acceptable. // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to decimal. 
- case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d)) + case (d: NumericType, DecimalType) => DecimalType.forType(d) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) => Cast(e, target) + case (_: NumericType, target: NumericType) => target // Implicit cast between date time types - case (DateType, TimestampType) => Cast(e, TimestampType) - case (TimestampType, DateType) => Cast(e, DateType) + case (DateType, TimestampType) => TimestampType + case (TimestampType, DateType) => DateType // Implicit cast from/to string - case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT) - case (StringType, target: NumericType) => Cast(e, target) - case (StringType, DateType) => Cast(e, DateType) - case (StringType, TimestampType) => Cast(e, TimestampType) - case (StringType, BinaryType) => Cast(e, BinaryType) + case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT + case (StringType, target: NumericType) => target + case (StringType, DateType) => DateType + case (StringType, TimestampType) => TimestampType + case (StringType, BinaryType) => BinaryType // Cast any atomic type to string. - case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType) + case (any: AtomicType, StringType) if any != StringType => StringType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. - case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull + case (_, TypeCollection(types)) => + types.flatMap(implicitCast(inType, _)).headOption.orNull + + // Implicit cast between array types. + // + // Compare the nullabilities of the from type and the to type, check whether the cast of + // the nullability is resolvable by the following rules: + // 1. If the nullability of the to type is true, the cast is always allowed; + // 2. 
If the nullability of the to type is false, and the nullability of the from type is + // true, the cast is never allowed; + // 3. If the nullabilities of both the from type and the to type are false, the cast is + // allowed only when Cast.forceNullable(fromType, toType) is false. + case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) => + implicitCast(fromType, toType).map(ArrayType(_, true)).orNull + + case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null + + case (ArrayType(fromType, false), ArrayType(toType: DataType, false)) + if !Cast.forceNullable(fromType, toType) => + implicitCast(fromType, toType).map(ArrayType(_, false)).orNull - // Else, just return the same input expression case _ => null } Option(ret) http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4db1ae6..741730e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,9 +89,7 @@ object Cast { case _ => false } - private def resolvableNullability(from: Boolean, to: Boolean) = !from || to - - private def forceNullable(from: DataType, to: DataType) = (from, to) match { + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -110,6 +108,8 @@ object Cast { case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } + + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to } /** Cast the child expression to the target data type. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 01792ae..0e71442 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -86,23 +86,16 @@ case class ApproximatePercentile( private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] override def inputTypes: Seq[AbstractDataType] = { - Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) + Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType) } // Mark as lazy so that percentageExpression is not evaluated during tree transformation. 
- private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = { - (percentageExpression.dataType, percentageExpression.eval()) match { + private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = + percentageExpression.eval() match { // Rule ImplicitTypeCasts can cast other numeric types to double - case (_, num: Double) => (false, Array(num)) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - (true, numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) - }) - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + case num: Double => (false, Array(num)) + case arrayData: ArrayData => (true, arrayData.toDoubleArray()) } - } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() @@ -162,7 +155,7 @@ case class ApproximatePercentile( override def nullable: Boolean = true override def dataType: DataType = { - if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType } override def prettyName: String = "percentile_approx" http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index b51b553..2f68195 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -77,15 +77,9 @@ case class Percentile( private lazy val 
returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType] @transient - private lazy val percentages = - (percentageExpression.dataType, percentageExpression.eval()) match { - case (_, num: Double) => Seq(num) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages") + private lazy val percentages = percentageExpression.eval() match { + case num: Double => Seq(num) + case arrayData: ArrayData => arrayData.toDoubleArray().toSeq } override def children: Seq[Expression] = child :: percentageExpression :: Nil @@ -99,7 +93,7 @@ case class Percentile( } override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match { - case _: ArrayType => Seq(NumericType, ArrayType) + case _: ArrayType => Seq(NumericType, ArrayType(DoubleType)) case _ => Seq(NumericType, DoubleType) } http://git-wip-us.apache.org/repos/asf/spark/blob/70d495dc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 590c9d5..dbb1e3e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -57,14 +57,43 @@ class TypeCoercionSuite extends PlanTest { // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { - val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, 
from), to) - assert(got.map(_.dataType) == Option(expected), + // Check default value + val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + assert(DataType.equalsIgnoreCompatibleNullability( + castDefault.map(_.dataType).getOrElse(null), expected), + s"Failed to cast $from to $to") + + // Check null value + val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + assert(DataType.equalsIgnoreCaseAndNullability( + castNull.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") } private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { - val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) - assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") + // Check default value + val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") + + // Check null value + val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") + } + + private def default(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + CreateArray(Seq(Literal.default(internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) + case _ => Literal.default(dataType) + } + + private def createNull(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + CreateArray(Seq(Literal.create(null, internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) + case _ => Literal.create(null, dataType) } val integralTypes: Seq[DataType] = @@ -196,7 
+225,13 @@ class TypeCoercionSuite extends PlanTest { test("implicit type cast - ArrayType(StringType)") { val checkedType = ArrayType(StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + val nonCastableTypes = + complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) + checkTypeCasting(checkedType, + castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_))) + nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _)) + shouldNotCast(ArrayType(DoubleType, containsNull = false), + ArrayType(LongType, containsNull = false)) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
