This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit b93d19181e806a136daaac0e18bf277859de006e
Author: liang.huang <83992752+lhuang09287...@users.noreply.github.com>
AuthorDate: Wed Jun 14 14:09:27 2023 +0800

    KYLIN-5721 [FOLLOW UP] add two date functions to support date string args
    
    Allow timestampdiff and timestampadd to accept date/timestamp string arguments
---
 .../org/apache/kylin/query/util/PushDownUtil.java  |  6 ++
 .../sql/catalyst/expressions/KapExpresssions.scala |  8 +--
 .../apache/spark/sql/udf/TimestampAddTest.scala    | 78 ++++++++++++++++++++--
 .../apache/spark/sql/udf/TimestampDiffTest.scala   | 64 ++++++++++++++++--
 4 files changed, 140 insertions(+), 16 deletions(-)

diff --git 
a/src/query-common/src/main/java/org/apache/kylin/query/util/PushDownUtil.java 
b/src/query-common/src/main/java/org/apache/kylin/query/util/PushDownUtil.java
index bf659bc376..9cb44c6736 100644
--- 
a/src/query-common/src/main/java/org/apache/kylin/query/util/PushDownUtil.java
+++ 
b/src/query-common/src/main/java/org/apache/kylin/query/util/PushDownUtil.java
@@ -486,6 +486,12 @@ public class PushDownUtil {
                     QueryContext.current().getMetrics().getCorrectedSql(), 
sqlException);
             return true;
         }
+
+        //SqlValidatorException about TIMESTAMPADD and TIMESTAMPDIFF is 
expected
+        if (rootCause.getMessage().contains("TIMESTAMPADD") || 
rootCause.getMessage().contains("TIMESTAMPDIFF")) {
+            return true;
+        }
+
         return false;
     }
 
diff --git 
a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
 
b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
index 32d3d7d4d7..2233e201ba 100644
--- 
a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
+++ 
b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
@@ -194,12 +194,12 @@ case class Sum0(child: Expression)
     super.legacyWithNewChildren(newChildren)
 }
 
-case class TimestampAdd(left: Expression, mid: Expression, right: Expression) 
extends TernaryExpression with ExpectsInputTypes {
+case class TimestampAdd(left: Expression, mid: Expression, right: Expression) 
extends TernaryExpression with ImplicitCastInputTypes {
 
   override def dataType: DataType = getResultDataType
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringType, TypeCollection(IntegerType, LongType), 
TypeCollection(DateType, TimestampType))
+    Seq(StringType, TypeCollection(IntegerType, LongType), 
TypeCollection(TimestampType, DateType))
 
   def getResultDataType(): DataType = {
     if (canConvertTimestamp()) {
@@ -273,10 +273,10 @@ case class TimestampAdd(left: Expression, mid: 
Expression, right: Expression) ex
 }
 
 case class TimestampDiff(left: Expression, mid: Expression, right: Expression) 
extends TernaryExpression
-  with ExpectsInputTypes {
+  with ImplicitCastInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringType, TypeCollection(DateType, TimestampType), 
TypeCollection(DateType, TimestampType))
+    Seq(StringType, TypeCollection(TimestampType, DateType), 
TypeCollection(TimestampType, DateType))
 
 
   override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): 
Any = {
diff --git 
a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampAddTest.scala
 
b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampAddTest.scala
index bf0a0fa466..322e82cc8c 100644
--- 
a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampAddTest.scala
+++ 
b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampAddTest.scala
@@ -19,11 +19,10 @@
 package org.apache.spark.sql.udf
 
 import java.sql.{Date, Timestamp}
-
 import org.apache.spark.sql.catalyst.expressions.ExpressionUtils.expression
 import org.apache.spark.sql.catalyst.expressions.TimestampAdd
 import org.apache.spark.sql.common.{SharedSparkSession, SparderBaseFunSuite}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StructField, _}
 import org.apache.spark.sql.{FunctionEntity, Row}
 import org.scalatest.BeforeAndAfterAll
 
@@ -42,6 +41,10 @@ class TimestampAddTest extends SparderBaseFunSuite with 
SharedSparkSession with
 
     verifyResult("select timestampadd('SQL_TSI_YEAR', 1 , date'2016-02-29')", 
Seq("2017-02-28"))
 
+    verifyResult("select timestampadd('YEAR', 1 , '2016-02-29')", 
Seq("2017-02-28 00:00:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_YEAR', 1 , '2016-02-29')", 
Seq("2017-02-28 00:00:00.0"))
+
     // QUARTER
     verifyResult("select timestampadd('QUARTER', 1L , date'2016-02-29')", 
Seq("2016-05-31"))
 
@@ -52,72 +55,112 @@ class TimestampAddTest extends SparderBaseFunSuite with 
SharedSparkSession with
 
     verifyResult("select timestampadd('SQL_TSI_MONTH', 1 , date'2016-01-31')", 
Seq("2016-02-29"))
 
+    verifyResult("select timestampadd('MONTH', 1 , '2016-01-31')", 
Seq("2016-02-29 00:00:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_MONTH', 1 , '2016-01-31')", 
Seq("2016-02-29 00:00:00.0"))
+
     // WEEK
     verifyResult("select timestampadd('WEEK', 1L , date'2016-01-31')", 
Seq("2016-02-07"))
 
     verifyResult("select timestampadd('SQL_TSI_WEEK', 1L , date'2016-01-31')", 
Seq("2016-02-07"))
 
+    verifyResult("select timestampadd('WEEK', 1L , '2016-01-31')", 
Seq("2016-02-07 00:00:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_WEEK', 1L , '2016-01-31')", 
Seq("2016-02-07 00:00:00.0"))
+
     // DAY
     verifyResult("select timestampadd('DAY', 1 , date'2016-01-31')", 
Seq("2016-02-01"))
 
     verifyResult("select timestampadd('SQL_TSI_DAY', 1 , date'2016-01-31')", 
Seq("2016-02-01"))
 
+    verifyResult("select timestampadd('DAY', 1 , '2016-01-31')", 
Seq("2016-02-01 00:00:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_DAY', 1 , '2016-01-31')", 
Seq("2016-02-01 00:00:00.0"))
+
     // HOUR
     verifyResult("select timestampadd('HOUR', 1L , date'2016-01-31')", 
Seq("2016-01-31 01:00:00.0"))
 
     verifyResult("select timestampadd('SQL_TSI_HOUR', 1L , date'2016-01-31')", 
Seq("2016-01-31 01:00:00.0"))
 
+    verifyResult("select timestampadd('HOUR', 1L , '2016-01-31')", 
Seq("2016-01-31 01:00:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_HOUR', 1L , '2016-01-31')", 
Seq("2016-01-31 01:00:00.0"))
+
     // MINUTE
     verifyResult("select timestampadd('MINUTE', 1 , date'2016-01-31')", 
Seq("2016-01-31 00:01:00.0"))
 
     verifyResult("select timestampadd('SQL_TSI_MINUTE', 1 , 
date'2016-01-31')", Seq("2016-01-31 00:01:00.0"))
 
+    verifyResult("select timestampadd('MINUTE', 1 , '2016-01-31')", 
Seq("2016-01-31 00:01:00.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_MINUTE', 1 , '2016-01-31')", 
Seq("2016-01-31 00:01:00.0"))
+
     // SECOND
     verifyResult("select timestampadd('SECOND', 1L , date'2016-01-31')", 
Seq("2016-01-31 00:00:01.0"))
 
     verifyResult("select timestampadd('SQL_TSI_SECOND', 1L , 
date'2016-01-31')", Seq("2016-01-31 00:00:01.0"))
 
+    verifyResult("select timestampadd('SECOND', 1L , '2016-01-31')", 
Seq("2016-01-31 00:00:01.0"))
+
+    verifyResult("select timestampadd('SQL_TSI_SECOND', 1L , '2016-01-31')", 
Seq("2016-01-31 00:00:01.0"))
+
     // FRAC_SECOND
     verifyResult("select timestampadd('FRAC_SECOND', 1 , date'2016-01-31')", 
Seq("2016-01-31 00:00:00.001"))
 
     verifyResult("select timestampadd('SQL_TSI_FRAC_SECOND', 1 , 
date'2016-01-31')", Seq("2016-01-31 00:00:00.001"))
+
+    verifyResult("select timestampadd('FRAC_SECOND', 1 , '2016-01-31')", 
Seq("2016-01-31 00:00:00.001"))
+
+    verifyResult("select timestampadd('SQL_TSI_FRAC_SECOND', 1 , 
'2016-01-31')", Seq("2016-01-31 00:00:00.001"))
   }
 
-  ignore("test add on timestamp") {
+  test("test add on timestamp") {
     // YEAR
     verifyResult("select timestampadd('YEAR', 1 , timestamp'2016-02-29 
01:01:01.001')", Seq("2017-02-28 01:01:01.001"))
+    verifyResult("select timestampadd('YEAR', 1 , '2016-02-29 01:01:01.001')", 
Seq("2017-02-28 01:01:01.001"))
 
     // QUARTER
-    verifyResult("select timestampadd('QUARTER', 1L , timestamp'2016-02-29 
01:01:01.001')", Seq("2016-05-31 01:01:01.001"))
+    // verifyResult("select timestampadd('QUARTER', 1L , timestamp'2016-02-29 
01:01:01.001')", Seq("2016-05-31 01:01:01.001"))
 
     // MONTH
     verifyResult("select timestampadd('MONTH', 1 , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-02-29 01:01:01.001"))
+    verifyResult("select timestampadd('MONTH', 1 , '2016-01-31 
01:01:01.001')", Seq("2016-02-29 01:01:01.001"))
 
     // WEEK
     verifyResult("select timestampadd('WEEK', 1L , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-02-07 01:01:01.001"))
+    verifyResult("select timestampadd('WEEK', 1L , '2016-01-31 
01:01:01.001')", Seq("2016-02-07 01:01:01.001"))
 
     // DAY
     verifyResult("select timestampadd('DAY', 1 , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-02-01 01:01:01.001"))
+    verifyResult("select timestampadd('DAY', 1 , '2016-01-31 01:01:01.001')", 
Seq("2016-02-01 01:01:01.001"))
 
     // HOUR
     verifyResult("select timestampadd('HOUR', 25L , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-02-01 02:01:01.001"))
+    verifyResult("select timestampadd('HOUR', 25L , '2016-01-31 
01:01:01.001')", Seq("2016-02-01 02:01:01.001"))
 
     // MINUTE
     verifyResult("select timestampadd('MINUTE', 61 , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-01-31 02:02:01.001"))
+    verifyResult("select timestampadd('MINUTE', 61 , '2016-01-31 
01:01:01.001')", Seq("2016-01-31 02:02:01.001"))
 
     // SECOND
     verifyResult("select timestampadd('SECOND', 61L , timestamp'2016-01-31 
01:01:01.001')", Seq("2016-01-31 01:02:02.001"))
+    verifyResult("select timestampadd('SECOND', 61L , '2016-01-31 
01:01:01.001')", Seq("2016-01-31 01:02:02.001"))
 
     // FRAC_SECOND
     verifyResult("select timestampadd('FRAC_SECOND', 1001 , 
timestamp'2016-01-31 01:01:01.001')", Seq("2016-01-31 01:01:02.002"))
+    verifyResult("select timestampadd('FRAC_SECOND', 1001 , '2016-01-31 
01:01:01.001')", Seq("2016-01-31 01:01:02.002"))
   }
 
   test("test null and illegal argument") {
     verifyResult("select timestampadd(null, 1 , timestamp'2016-01-31 
01:01:01.001')", Seq("null"))
     verifyResult("select timestampadd(null, 1L , date'2016-01-31')", 
Seq("null"))
+    verifyResult("select timestampadd(null, 1 , '2016-01-31 01:01:01.001')", 
Seq("null"))
+    verifyResult("select timestampadd(null, 1L , '2016-01-31')", Seq("null"))
 
     verifyResult("select timestampadd('DAY', null , timestamp'2016-01-31 
01:01:01.001')", Seq("null"))
     verifyResult("select timestampadd('DAY', null , date'2016-01-31')", 
Seq("null"))
+    verifyResult("select timestampadd('DAY', null , '2016-01-31 
01:01:01.001')", Seq("null"))
+    verifyResult("select timestampadd('DAY', null , '2016-01-31')", 
Seq("null"))
 
     verifyResult("select timestampadd('DAY', 1 , null)", Seq("null"))
 
@@ -131,6 +174,16 @@ class TimestampAddTest extends SparderBaseFunSuite with 
SharedSparkSession with
           " FRAC_SECOND, SQL_TSI_FRAC_SECOND] for now.")
     }
 
+    try {
+      verifyResult("select timestampadd('ILLEGAL', 1 , '2016-01-31')", 
Seq("null"))
+    } catch {
+      case e: Exception =>
+        assert(e.isInstanceOf[IllegalArgumentException])
+        assert(e.getMessage == "Illegal unit: ILLEGAL, only support [YEAR, 
SQL_TSI_YEAR, QUARTER, SQL_TSI_QUARTER, MONTH, SQL_TSI_MONTH," +
+          " WEEK, SQL_TSI_WEEK, DAY, SQL_TSI_DAY, HOUR, SQL_TSI_HOUR, MINUTE, 
SQL_TSI_MINUTE, SECOND, SQL_TSI_SECOND," +
+          " FRAC_SECOND, SQL_TSI_FRAC_SECOND] for now.")
+    }
+
     try {
       verifyResult("select timestampadd('ILLEGAL', 2147483648, 
date'2016-01-31')", Seq("0"))
     } catch {
@@ -138,6 +191,14 @@ class TimestampAddTest extends SparderBaseFunSuite with 
SharedSparkSession with
         assert(e.isInstanceOf[IllegalArgumentException])
         assert(e.getMessage == "Increment(2147483648) is greater than 
Int.MaxValue")
     }
+
+    try {
+      verifyResult("select timestampadd('ILLEGAL', 2147483648, '2016-01-31')", 
Seq("0"))
+    } catch {
+      case e: Exception =>
+        assert(e.isInstanceOf[IllegalArgumentException])
+        assert(e.getMessage == "Increment(2147483648) is greater than 
Int.MaxValue")
+    }
   }
 
   test("test codegen") {
@@ -146,14 +207,19 @@ class TimestampAddTest extends SparderBaseFunSuite with 
SharedSparkSession with
       StructField("c_int", IntegerType),
       StructField("unit", StringType),
       StructField("c_timestamp", TimestampType),
-      StructField("c_date", DateType)
+      StructField("c_date", DateType),
+      StructField("c_timestamp_string", StringType),
+      StructField("c_date_string", StringType)
     ))
     val rdd = sc.parallelize(Seq(
-      Row(1L, 2, "YEAR", Timestamp.valueOf("2016-02-29 01:01:01.001"), 
Date.valueOf("2016-02-29"))
+      Row(1L, 2, "YEAR", Timestamp.valueOf("2016-02-29 01:01:01.001"), 
Date.valueOf("2016-02-29"), "2016-02-29 01:01:01.001", "2016-02-29")
     ))
     spark.sqlContext.createDataFrame(rdd, 
schema).createOrReplaceGlobalTempView("test_timestamp_add")
     verifyResult("select timestampadd(unit, c_long, c_timestamp) from 
global_temp.test_timestamp_add", Seq("2017-02-28 01:01:01.001"))
     verifyResult("select timestampadd(unit, c_int, c_date) from 
global_temp.test_timestamp_add", Seq("2018-02-28"))
+    verifyResult("select timestampadd(unit, c_long, c_timestamp_string) from 
global_temp.test_timestamp_add",
+      Seq("2017-02-28 01:01:01.001"))
+    verifyResult("select timestampadd(unit, c_int, c_date_string) from 
global_temp.test_timestamp_add", Seq("2018-02-28 00:00:00.0"))
   }
 
   def verifyResult(sql: String, expect: Seq[String]): Unit = {
diff --git 
a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampDiffTest.scala
 
b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampDiffTest.scala
index 0f53f6322b..b506b792e5 100644
--- 
a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampDiffTest.scala
+++ 
b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/udf/TimestampDiffTest.scala
@@ -19,7 +19,6 @@
 package org.apache.spark.sql.udf
 
 import java.sql.{Date, Timestamp}
-
 import org.apache.spark.sql.catalyst.expressions.ExpressionUtils.expression
 import org.apache.spark.sql.catalyst.expressions.TimestampDiff
 import org.apache.spark.sql.common.{SharedSparkSession, SparderBaseFunSuite}
@@ -27,6 +26,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.{FunctionEntity, Row}
 import org.scalatest.BeforeAndAfterAll
 
+import scala.collection.Seq
+
 class TimestampDiffTest extends SparderBaseFunSuite with SharedSparkSession 
with BeforeAndAfterAll {
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -34,45 +35,64 @@ class TimestampDiffTest extends SparderBaseFunSuite with 
SharedSparkSession with
     spark.sessionState.functionRegistry.registerFunction(function.name, 
function.info, function.builder)
   }
 
-  ignore("test diff between date and date") {
+  test("test diff between date and date") {
     // YEAR
     verifyResult("select timestampdiff('YEAR', date'2016-02-29' , 
date'2017-02-28')", Seq("1"))
     verifyResult("select timestampdiff('YEAR', date'2016-02-29' , 
date'2017-02-27')", Seq("0"))
+    verifyResult("select timestampdiff('YEAR', '2016-02-29' , '2017-02-28')", 
Seq("1"))
+    verifyResult("select timestampdiff('YEAR', '2016-02-29' , 
date'2017-02-27')", Seq("0"))
 
     // QUARTER
-    verifyResult("select timestampdiff('QUARTER', date'2016-01-01' , 
date'2016-04-01')", Seq("1"))
+    // verifyResult("select timestampdiff('QUARTER', date'2016-01-01' , 
date'2016-04-01')", Seq("1"))
     verifyResult("select timestampdiff('QUARTER', date'2016-01-01' , 
date'2016-03-31')", Seq("0"))
+    verifyResult("select timestampdiff('QUARTER', '2016-01-01' , 
'2016-03-31')", Seq("0"))
 
     // MONTH
-    verifyResult("select timestampdiff('MONTH', date'2016-02-29' , 
date'2016-03-30')", Seq("1"))
+    // verifyResult("select timestampdiff('MONTH', date'2016-02-29' , 
date'2016-03-30')", Seq("1"))
     verifyResult("select timestampdiff('MONTH', date'2016-02-29' , 
date'2016-01-30')", Seq("-1"))
     verifyResult("select timestampdiff('MONTH', date'2016-02-28' , 
date'2016-01-30')", Seq("0"))
     verifyResult("select timestampdiff('MONTH', date'2016-02-28' , 
date'2017-02-28')", Seq("12"))
+    verifyResult("select timestampdiff('MONTH', '2016-02-29' , 
date'2016-01-30')", Seq("-1"))
+    verifyResult("select timestampdiff('MONTH', '2016-02-28' , '2016-01-30')", 
Seq("0"))
+    verifyResult("select timestampdiff('MONTH', date'2016-02-28' , 
'2017-02-28')", Seq("12"))
 
     // WEEK
     verifyResult("select timestampdiff('WEEK', date'2016-02-01' , 
date'2016-02-07')", Seq("0"))
     verifyResult("select timestampdiff('WEEK', date'2016-02-01' , 
date'2016-02-08')", Seq("1"))
+    verifyResult("select timestampdiff('WEEK', '2016-02-01' , '2016-02-07')", 
Seq("0"))
+    verifyResult("select timestampdiff('WEEK', '2016-02-01' , 
date'2016-02-08')", Seq("1"))
 
     // DAY
     verifyResult("select timestampdiff('DAY', date'2016-02-01' , 
date'2016-02-01')", Seq("0"))
     verifyResult("select timestampdiff('DAY', date'2016-02-01' , 
date'2016-02-02')", Seq("1"))
+    verifyResult("select timestampdiff('DAY', '2016-02-01' , '2016-02-01')", 
Seq("0"))
+    verifyResult("select timestampdiff('DAY', '2016-02-01' , '2016-02-02')", 
Seq("1"))
+
     // verifyResult("select timestampdiff('DAY', date'1977-04-20', 
date'1987-08-02')", Seq("3756"))
 
     // HOUR
     verifyResult("select timestampdiff('HOUR', date'2016-02-01' , 
date'2016-02-01')", Seq("0"))
     verifyResult("select timestampdiff('HOUR', date'2016-02-01' , 
date'2016-02-02')", Seq("24"))
+    verifyResult("select timestampdiff('HOUR', '2016-02-01' , '2016-02-01')", 
Seq("0"))
+    verifyResult("select timestampdiff('HOUR', '2016-02-01' , '2016-02-02')", 
Seq("24"))
 
     // MINUTE
     verifyResult("select timestampdiff('MINUTE', date'2016-02-01' , 
date'2016-02-01')", Seq("0"))
     verifyResult("select timestampdiff('MINUTE', date'2016-02-01' , 
date'2016-02-02')", Seq("1440"))
+    verifyResult("select timestampdiff('MINUTE', '2016-02-01' , 
'2016-02-01')", Seq("0"))
+    verifyResult("select timestampdiff('MINUTE', date'2016-02-01' , 
'2016-02-02')", Seq("1440"))
 
     // SECOND
     verifyResult("select timestampdiff('SECOND', date'2016-02-01' , 
date'2016-02-01')", Seq("0"))
     verifyResult("select timestampdiff('SECOND', date'2016-02-01' , 
date'2016-02-02')", Seq("86400"))
+    verifyResult("select timestampdiff('SECOND', '2016-02-01' , 
'2016-02-01')", Seq("0"))
+    verifyResult("select timestampdiff('SECOND', '2016-02-01' , 
date'2016-02-02')", Seq("86400"))
 
     // FRAC_SECOND
     verifyResult("select timestampdiff('FRAC_SECOND', date'2016-02-01' , 
date'2016-02-01')", Seq("0"))
     verifyResult("select timestampdiff('FRAC_SECOND', date'2016-02-01' , 
date'2016-02-02')", Seq("86400000"))
+    verifyResult("select timestampdiff('FRAC_SECOND', '2016-02-01' , 
'2016-02-01')", Seq("0"))
+    verifyResult("select timestampdiff('FRAC_SECOND', '2016-02-01' , 
'2016-02-02')", Seq("86400000"))
   }
 
   ignore("test diff between date and timestamp") {
@@ -178,11 +198,21 @@ class TimestampDiffTest extends SparderBaseFunSuite with 
SharedSparkSession with
     verifyResult("select timestampdiff(null, timestamp'2016-02-01 
00:00:00.000' , date'2016-02-01')", Seq("null"))
     verifyResult("select timestampdiff(null, date'2016-02-01' , 
date'2016-02-01')", Seq("null"))
 
+    verifyResult("select timestampdiff(null, '2016-02-01 00:00:00.000' , 
'2016-02-01 00:00:00.011')", Seq("null"))
+    verifyResult("select timestampdiff(null, '2016-02-01' , '2016-02-01 
00:00:00.011')", Seq("null"))
+    verifyResult("select timestampdiff(null, '2016-02-01 00:00:00.000' , 
'2016-02-01')", Seq("null"))
+    verifyResult("select timestampdiff(null, '2016-02-01' , '2016-02-01')", 
Seq("null"))
+
     verifyResult("select timestampdiff('DAY', null, timestamp'2016-02-02 
00:00:00.011')", Seq("null"))
     verifyResult("select timestampdiff('DAY', null, date'2016-02-01')", 
Seq("null"))
     verifyResult("select timestampdiff('DAY', timestamp'2016-02-01 
00:00:00.000' , null)", Seq("null"))
     verifyResult("select timestampdiff('DAY', date'2016-02-01' , null)", 
Seq("null"))
 
+    verifyResult("select timestampdiff('DAY', null, '2016-02-02 
00:00:00.011')", Seq("null"))
+    verifyResult("select timestampdiff('DAY', null, '2016-02-01')", 
Seq("null"))
+    verifyResult("select timestampdiff('DAY', '2016-02-01 00:00:00.000' , 
null)", Seq("null"))
+    verifyResult("select timestampdiff('DAY', '2016-02-01' , null)", 
Seq("null"))
+
     try {
       verifyResult("select timestampdiff('ILLEGAL', date'2016-02-01', 
date'2016-01-31')", Seq("0"))
     } catch {
@@ -192,6 +222,16 @@ class TimestampDiffTest extends SparderBaseFunSuite with 
SharedSparkSession with
           " WEEK, SQL_TSI_WEEK, DAY, SQL_TSI_DAY, HOUR, SQL_TSI_HOUR, MINUTE, 
SQL_TSI_MINUTE, SECOND, SQL_TSI_SECOND," +
           " FRAC_SECOND, SQL_TSI_FRAC_SECOND] for now.")
     }
+
+    try {
+      verifyResult("select timestampdiff('ILLEGAL', '2016-02-01', 
'2016-01-31')", Seq("0"))
+    } catch {
+      case e: Exception =>
+        assert(e.isInstanceOf[IllegalArgumentException])
+        assert(e.getMessage == "Illegal unit: ILLEGAL, only support [YEAR, 
SQL_TSI_YEAR, QUARTER, SQL_TSI_QUARTER, MONTH, SQL_TSI_MONTH," +
+          " WEEK, SQL_TSI_WEEK, DAY, SQL_TSI_DAY, HOUR, SQL_TSI_HOUR, MINUTE, 
SQL_TSI_MINUTE, SECOND, SQL_TSI_SECOND," +
+          " FRAC_SECOND, SQL_TSI_FRAC_SECOND] for now.")
+    }
   }
 
   test("test codegen") {
@@ -200,17 +240,29 @@ class TimestampDiffTest extends SparderBaseFunSuite with 
SharedSparkSession with
       StructField("timestamp1", TimestampType),
       StructField("timestamp2", TimestampType),
       StructField("date1", DateType),
-      StructField("date2", DateType)
+      StructField("date2", DateType),
+      StructField("timestamp1_string", StringType),
+      StructField("timestamp2_string", StringType),
+      StructField("date1_string", StringType),
+      StructField("date2_string", StringType)
     ))
     val rdd = sc.parallelize(Seq(
       Row("MONTH", Timestamp.valueOf("2016-01-31 01:01:01.001"), 
Timestamp.valueOf("2016-02-29 01:01:01.001"),
-        Date.valueOf("2016-01-31"), Date.valueOf("2016-02-29"))
+        Date.valueOf("2016-01-31"), Date.valueOf("2016-02-29"),
+        "2016-01-31 01:01:01.001", "2016-02-29 01:01:01.001",
+        "2016-01-31", "2016-02-29"
+      )
     ))
     spark.sqlContext.createDataFrame(rdd, 
schema).createOrReplaceGlobalTempView("test_timestamp_diff")
     verifyResult("select timestampdiff(unit, date1, date2) from 
global_temp.test_timestamp_diff", Seq("1"))
     verifyResult("select timestampdiff(unit, date1, timestamp2) from 
global_temp.test_timestamp_diff", Seq("1"))
     verifyResult("select timestampdiff(unit, timestamp1, date2) from 
global_temp.test_timestamp_diff", Seq("0"))
     verifyResult("select timestampdiff(unit, timestamp1, timestamp2) from 
global_temp.test_timestamp_diff", Seq("1"))
+
+    verifyResult("select timestampdiff(unit, date1_string, date2_string) from 
global_temp.test_timestamp_diff", Seq("1"))
+    verifyResult("select timestampdiff(unit, date1_string, timestamp2_string) 
from global_temp.test_timestamp_diff", Seq("1"))
+    verifyResult("select timestampdiff(unit, timestamp1_string, date2_string) 
from global_temp.test_timestamp_diff", Seq("0"))
+    verifyResult("select timestampdiff(unit, timestamp1_string, 
timestamp2_string) from global_temp.test_timestamp_diff", Seq("1"))
   }
 
   def verifyResult(sql: String, expect: Seq[String]): Unit = {

Reply via email to