Repository: spark
Updated Branches:
  refs/heads/master 42dea3acf -> ba3309684


[SPARK-9068][SQL] refactor the implicit type cast code

based on https://github.com/apache/spark/pull/7348

Author: Wenchen Fan <[email protected]>

Closes #7420 from cloud-fan/type-check and squashes the following commits:

7633fa9 [Wenchen Fan] revert
fe169b0 [Wenchen Fan] improve test
03b70da [Wenchen Fan] enhance implicit type cast


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ba330968
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ba330968
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ba330968

Branch: refs/heads/master
Commit: ba33096846dc8061e97a7bf8f3b46f899d530159
Parents: 42dea3a
Author: Wenchen Fan <[email protected]>
Authored: Wed Jul 15 22:27:39 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Wed Jul 15 22:27:39 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 33 +++------
 .../sql/catalyst/expressions/Expression.scala   | 20 +++---
 .../sql/catalyst/expressions/arithmetic.scala   |  2 -
 .../sql/catalyst/expressions/bitwise.scala      |  8 +--
 .../sql/catalyst/expressions/conditionals.scala |  4 +-
 .../spark/sql/types/AbstractDataType.scala      | 45 ++++--------
 .../org/apache/spark/sql/types/ArrayType.scala  |  2 +-
 .../org/apache/spark/sql/types/DataType.scala   |  2 +-
 .../apache/spark/sql/types/DecimalType.scala    |  2 +-
 .../org/apache/spark/sql/types/MapType.scala    |  2 +-
 .../org/apache/spark/sql/types/StructType.scala |  2 +-
 .../analysis/ExpressionTypeCheckingSuite.scala  | 75 +++++++++-----------
 .../analysis/HiveTypeCoercionSuite.scala        | 10 ++-
 13 files changed, 81 insertions(+), 126 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2508791..50db7d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -675,10 +675,10 @@ object HiveTypeCoercion {
       case b @ BinaryOperator(left, right) if left.dataType != right.dataType 
=>
         findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { 
commonType =>
           if (b.inputType.acceptsType(commonType)) {
-            // If the expression accepts the tighest common type, cast to that.
+            // If the expression accepts the tightest common type, cast to 
that.
             val newLeft = if (left.dataType == commonType) left else 
Cast(left, commonType)
             val newRight = if (right.dataType == commonType) right else 
Cast(right, commonType)
-            b.makeCopy(Array(newLeft, newRight))
+            b.withNewChildren(Seq(newLeft, newRight))
           } else {
             // Otherwise, don't do anything with the expression.
             b
@@ -697,7 +697,7 @@ object HiveTypeCoercion {
         // general implicit casting.
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { 
case (in, expected) =>
           if (in.dataType == NullType && !expected.acceptsType(NullType)) {
-            Cast(in, expected.defaultConcreteType)
+            Literal.create(null, expected.defaultConcreteType)
           } else {
             in
           }
@@ -719,27 +719,22 @@ object HiveTypeCoercion {
       @Nullable val ret: Expression = (inType, expectedType) match {
 
         // If the expected type is already a parent of the input type, no need 
to cast.
-        case _ if expectedType.isSameType(inType) => e
+        case _ if expectedType.acceptsType(inType) => e
 
         // Cast null type (usually from null literals) into target types
         case (NullType, target) => Cast(e, target.defaultConcreteType)
 
-        // If the function accepts any numeric type (i.e. the ADT 
`NumericType`) and the input is
-        // already a number, leave it as is.
-        case (_: NumericType, NumericType) => e
-
         // If the function accepts any numeric type and the input is a string, 
we follow the hive
         // convention and cast that input into a double
         case (StringType, NumericType) => Cast(e, 
NumericType.defaultConcreteType)
 
-        // Implicit cast among numeric types
+        // Implicit cast among numeric types. When we reach here, input type 
is not acceptable.
+
         // If input is a numeric type but not decimal, and we expect a decimal 
type,
         // cast the input to unlimited precision decimal.
-        case (_: NumericType, DecimalType) if 
!inType.isInstanceOf[DecimalType] =>
-          Cast(e, DecimalType.Unlimited)
+        case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
         // For any other numeric types, implicitly cast to each other, e.g. 
long -> int, int -> long
-        case (_: NumericType, target: NumericType) if e.dataType != target => 
Cast(e, target)
-        case (_: NumericType, target: NumericType) => e
+        case (_: NumericType, target: NumericType) => Cast(e, target)
 
         // Implicit cast between date time types
         case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -753,15 +748,9 @@ object HiveTypeCoercion {
         case (StringType, BinaryType) => Cast(e, BinaryType)
         case (any, StringType) if any != StringType => Cast(e, StringType)
 
-        // Type collection.
-        // First see if we can find our input type in the type collection. If 
we can, then just
-        // use the current expression; otherwise, find the first one we can 
implicitly cast.
-        case (_, TypeCollection(types)) =>
-          if (types.exists(_.isSameType(inType))) {
-            e
-          } else {
-            types.flatMap(implicitCast(e, _)).headOption.orNull
-          }
+        // When we reach here, input type is not acceptable for any types in 
this type collection,
+        // try to find the first one we can implicitly cast.
+        case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, 
_)).headOption.orNull
 
         // Else, just return the same input expression
         case _ => null

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8766731..a655cc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression 
with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    // First call the checker for ExpectsInputTypes, and then check whether 
left and right have
-    // the same type.
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        if (left.dataType != right.dataType) {
-          TypeCheckResult.TypeCheckFailure(s"differing types in 
'$prettyString' " +
-            s"(${left.dataType.simpleString} and 
${right.dataType.simpleString}).")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case TypeCheckResult.TypeCheckFailure(msg) => 
TypeCheckResult.TypeCheckFailure(msg)
+    // First check whether left and right have the same type, then check if 
the type is acceptable.
+    if (left.dataType != right.dataType) {
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+    } else if (!inputType.acceptsType(left.dataType)) {
+      TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts 
${inputType.simpleString} type," +
+        s" not ${left.dataType.simpleString}")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 394ef55..382cbe3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) 
extends BinaryArithmetic {
   }
 
   override def symbol: String = "max"
-  override def prettyName: String = symbol
 }
 
 case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic 
{
@@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) 
extends BinaryArithmetic {
   }
 
   override def symbol: String = "min"
-  override def prettyName: String = symbol
 }
 
 case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
index af1abbc..a1e48c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  */
 case class BitwiseAnd(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "&"
 
@@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) 
extends BinaryArithme
  */
 case class BitwiseOr(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "|"
 
@@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) 
extends BinaryArithmet
  */
 case class BitwiseXor(left: Expression, right: Expression) extends 
BinaryArithmetic {
 
-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType
 
   override def symbol: String = "^"
 
@@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) 
extends BinaryArithme
  */
 case class BitwiseNot(child: Expression) extends UnaryExpression with 
ExpectsInputTypes {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
+  override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
 
   override def dataType: DataType = child.dataType
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index c7f039e..9162b73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, 
falseValue: Expressi
       TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not 
${predicate.dataType}")
     } else if (trueValue.dataType != falseValue.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in If (${trueValue.dataType} and 
${falseValue.dataType}).")
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${trueValue.dataType.simpleString} and 
${falseValue.dataType.simpleString}).")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index f5715f7..076d7b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
   private[sql] def defaultConcreteType: DataType
 
   /**
-   * Returns true if this data type is the same type as `other`.  This is 
different that equality
-   * as equality will also consider data type parametrization, such as decimal 
precision.
+   * Returns true if `other` is an acceptable input type for a function that 
expects this,
+   * possibly abstract DataType.
    *
    * {{{
    *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
-   *
-   *   // this should return false
-   *   NumericType.isSameType(DecimalType(10, 2))
-   * }}}
-   */
-  private[sql] def isSameType(other: DataType): Boolean
-
-  /**
-   * Returns true if `other` is an acceptable input type for a function that 
expectes this,
-   * possibly abstract, DataType.
-   *
-   * {{{
-   *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
+   *   DecimalType.acceptsType(DecimalType(10, 2))
    *
    *   // this should return true as well
    *   NumericType.acceptsType(DecimalType(10, 2))
    * }}}
    */
-  private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
+  private[sql] def acceptsType(other: DataType): Boolean
 
   /** Readable string representation for the type. */
   private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: 
Seq[AbstractDataType])
 
   override private[sql] def defaultConcreteType: DataType = 
types.head.defaultConcreteType
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean =
-    types.exists(_.isSameType(other))
+    types.exists(_.acceptsType(other))
 
   override private[sql] def simpleString: String = {
     types.map(_.simpleString).mkString("(", " or ", ")")
@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
     TimestampType, DateType,
     StringType, BinaryType)
 
-  /**
-   * Types that can be used in bitwise operations.
-   */
-  val Bitwise = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType)
-
   def apply(types: AbstractDataType*): TypeCollection = new 
TypeCollection(types)
 
   def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ 
match {
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "any"
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = true
 }
 
@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
 
   override private[sql] def simpleString: String = "numeric"
 
-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = 
other.isInstanceOf[NumericType]
 }
 
 
-private[sql] object IntegralType {
+private[sql] object IntegralType extends AbstractDataType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -198,6 +171,12 @@ private[sql] object IntegralType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
+
+  override private[sql] def defaultConcreteType: DataType = IntegerType
+
+  override private[sql] def simpleString: String = "integral"
+
+  override private[sql] def acceptsType(other: DataType): Boolean = 
other.isInstanceOf[IntegralType]
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 76ca7a8..5094058 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = 
ArrayType(NullType, containsNull = true)
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[ArrayType]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index da83a7f..2d133ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = this
 
-  override private[sql] def isSameType(other: DataType): Boolean = this == 
other
+  override private[sql] def acceptsType(other: DataType): Boolean = this == 
other
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index a1cafea..377c75f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = Unlimited
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[DecimalType]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ddead10..ac34b64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = apply(NullType, 
NullType)
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[MapType]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b809740..2ef97a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -307,7 +307,7 @@ object StructType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = new StructType
 
-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[StructType]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a4ce182..ed0d20e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{TypeCollection, StringType}
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
   def assertErrorForDifferingTypes(expr: Expression): Unit = {
     assertError(expr,
-      s"differing types in '${expr.prettyString}' (int and boolean)")
-  }
-
-  def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): 
Unit = {
-    val e = intercept[AnalysisException] {
-      assertSuccess(expr)
-    }
-    assert(e.getMessage.contains(errorMessage))
+      s"differing types in '${expr.prettyString}'")
   }
 
   test("check types for unary arithmetic") {
     assertError(UnaryMinus('stringField), "expected to be of type numeric")
     assertError(Abs('stringField), "expected to be of type numeric")
-    assertError(BitwiseNot('stringField), "type (boolean or tinyint or 
smallint or int or bigint)")
+    assertError(BitwiseNot('stringField), "expected to be of type integral")
   }
 
-  ignore("check types for binary arithmetic") {
+  test("check types for binary arithmetic") {
     // We will cast String to Double for binary arithmetic
     assertSuccess(Add('intField, 'stringField))
     assertSuccess(Subtract('intField, 'stringField))
@@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
     assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
 
-    assertError(Add('booleanField, 'booleanField), "operator + accepts numeric 
type")
-    assertError(Subtract('booleanField, 'booleanField), "operator - accepts 
numeric type")
-    assertError(Multiply('booleanField, 'booleanField), "operator * accepts 
numeric type")
-    assertError(Divide('booleanField, 'booleanField), "operator / accepts 
numeric type")
-    assertError(Remainder('booleanField, 'booleanField), "operator % accepts 
numeric type")
+    assertError(Add('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
+    assertError(Remainder('booleanField, 'booleanField), "accepts numeric 
type")
 
-    assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts 
integral type")
-    assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts 
integral type")
-    assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts 
integral type")
+    assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral 
type")
+    assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral 
type")
+    assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral 
type")
 
-    assertError(MaxOf('complexField, 'complexField), "function maxOf accepts 
non-complex type")
-    assertError(MinOf('complexField, 'complexField), "function minOf accepts 
non-complex type")
+    assertError(MaxOf('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(MinOf('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
   }
 
-  ignore("check types for predicates") {
+  test("check types for predicates") {
     // We will cast String to Double for binary comparison
     assertSuccess(EqualTo('intField, 'stringField))
     assertSuccess(EqualNullSafe('intField, 'stringField))
@@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(EqualTo('intField, 'booleanField))
     assertSuccess(EqualNullSafe('intField, 'booleanField))
 
-    assertError(EqualTo('intField, 'complexField), "differing types")
-    assertError(EqualNullSafe('intField, 'complexField), "differing types")
-
+    assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
+    assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
     assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
     assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(
-      LessThan('complexField, 'complexField), "operator < accepts non-complex 
type")
-    assertError(
-      LessThanOrEqual('complexField, 'complexField), "operator <= accepts 
non-complex type")
-    assertError(
-      GreaterThan('complexField, 'complexField), "operator > accepts 
non-complex type")
-    assertError(
-      GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts 
non-complex type")
+    assertError(LessThan('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(LessThanOrEqual('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(GreaterThan('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
+    assertError(GreaterThanOrEqual('complexField, 'complexField),
+      s"accepts ${TypeCollection.Ordered.simpleString} type")
 
-    assertError(
-      If('intField, 'stringField, 'stringField),
+    assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")
     assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
 
@@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for ROUND") {
-    assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
-      "data type mismatch: argument 2 is expected to be of type int")
-    assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
-      "data type mismatch: argument 2 is expected to be of type int")
     assertSuccess(Round(Literal(null), Literal(null)))
-    assertError(Round('booleanField, 'intField),
-      "data type mismatch: argument 1 is expected to be of type numeric")
+    assertSuccess(Round('intField, Literal(1)))
+
+    assertError(Round('intField, 'intField), "Only foldable Expression is 
allowed")
+    assertError(Round('intField, 'booleanField), "expected to be of type int")
+    assertError(Round('intField, 'complexField), "expected to be of type int")
+    assertError(Round('booleanField, 'intField), "expected to be of type 
numeric")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba330968/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 8e9b20a..d0fd033 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest {
 
     ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
       NumericTypeUnaryExpression(Literal.create(null, NullType)),
-      NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), 
DoubleType)))
+      NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
   }
 
   test("cast NullType for binary operators") {
@@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest {
 
     ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
       NumericTypeBinaryOperator(Literal.create(null, NullType), 
Literal.create(null, NullType)),
-      NumericTypeBinaryOperator(
-        Cast(Literal.create(null, NullType), DoubleType),
-        Cast(Literal.create(null, NullType), DoubleType)))
+      NumericTypeBinaryOperator(Literal.create(null, DoubleType), 
Literal.create(null, DoubleType)))
   }
 
   test("coalesce casts") {
@@ -345,14 +343,14 @@ object HiveTypeCoercionSuite {
   }
 
   case class AnyTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with ExpectsInputTypes {
+    extends BinaryOperator {
     override def dataType: DataType = NullType
     override def inputType: AbstractDataType = AnyDataType
     override def symbol: String = "anytype"
   }
 
   case class NumericTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with ExpectsInputTypes {
+    extends BinaryOperator {
     override def dataType: DataType = NullType
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "numerictype"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to