[SPARK-9242] [SQL] Audit UDAF interface. A few minor changes:
1. Improved documentation 2. Rename apply(distinct....) to distinct. 3. Changed MutableAggregationBuffer from a trait to an abstract class. 4. Renamed returnDataType to dataType to be more consistent with other expressions. And unrelated to UDAFs: 1. Renamed file names in expressions to use suffix "Expressions" to be more consistent. 2. Moved regexp related expressions out to its own file. 3. Renamed StringComparison => StringPredicate. Author: Reynold Xin <[email protected]> Closes #8321 from rxin/SPARK-9242. (cherry picked from commit 2f2686a73f5a2a53ca5b1023e0d7e0e6c9be5896) Signed-off-by: Reynold Xin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/321cb99c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/321cb99c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/321cb99c Branch: refs/heads/branch-1.5 Commit: 321cb99caa9e63e19eeec0d04fe9d425abdb7109 Parents: 56a37b0 Author: Reynold Xin <[email protected]> Authored: Wed Aug 19 17:35:41 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Wed Aug 19 17:35:48 2015 -0700 ---------------------------------------------------------------------- .../sql/catalyst/expressions/bitwise.scala | 126 -- .../expressions/bitwiseExpressions.scala | 126 ++ .../expressions/conditionalExpressions.scala | 421 ++++++ .../sql/catalyst/expressions/conditionals.scala | 421 ------ .../expressions/datetimeExpressions.scala | 899 ++++++++++++ .../expressions/datetimeFunctions.scala | 899 ------------ .../expressions/decimalExpressions.scala | 109 ++ .../catalyst/expressions/decimalFunctions.scala | 109 -- .../catalyst/expressions/jsonExpressions.scala | 309 ++++ .../catalyst/expressions/jsonFunctions.scala | 309 ---- .../spark/sql/catalyst/expressions/math.scala | 801 ----------- .../catalyst/expressions/mathExpressions.scala | 801 +++++++++++ .../catalyst/expressions/nullExpressions.scala | 275 ++++ 
.../catalyst/expressions/nullFunctions.scala | 275 ---- .../spark/sql/catalyst/expressions/random.scala | 98 -- .../expressions/randomExpressions.scala | 98 ++ .../expressions/regexpExpressions.scala | 346 +++++ .../expressions/stringExpressions.scala | 996 +++++++++++++ .../catalyst/expressions/stringOperations.scala | 1320 ------------------ .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../org/apache/spark/sql/UDFRegistration.scala | 1 + .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../org/apache/spark/sql/expressions/udaf.scala | 44 +- .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 2 +- .../spark/sql/hive/aggregate/MyDoubleSum.java | 2 +- 27 files changed, 4416 insertions(+), 4379 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala deleted file mode 100644 index a1e48c4..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types._ - - -/** - * A function that calculates bitwise and(&) of two numbers. - * - * Code generation inherited from BinaryArithmetic. - */ -case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = IntegralType - - override def symbol: String = "&" - - private lazy val and: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) -} - -/** - * A function that calculates bitwise or(|) of two numbers. - * - * Code generation inherited from BinaryArithmetic. 
- */ -case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = IntegralType - - override def symbol: String = "|" - - private lazy val or: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) -} - -/** - * A function that calculates bitwise xor of two numbers. - * - * Code generation inherited from BinaryArithmetic. - */ -case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = IntegralType - - override def symbol: String = "^" - - private lazy val xor: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - } - - protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) -} - -/** - * A function that calculates bitwise not(~) of a number. 
- */ -case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) - - override def dataType: DataType = child.dataType - - override def toString: String = s"~$child" - - private lazy val not: (Any) => Any = dataType match { - case ByteType => - ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] - case ShortType => - ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] - case IntegerType => - ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] - case LongType => - ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") - } - - protected override def nullSafeEval(input: Any): Any = not(input) -} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala new file mode 100644 index 0000000..a1e48c4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ + + +/** + * A function that calculates bitwise and(&) of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "&" + + private lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) +} + +/** + * A function that calculates bitwise or(|) of two numbers. + * + * Code generation inherited from BinaryArithmetic. 
+ */ +case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "|" + + private lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) +} + +/** + * A function that calculates bitwise xor of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "^" + + private lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) +} + +/** + * A function that calculates bitwise not(~) of a number. 
+ */ +case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) + + override def dataType: DataType = child.dataType + + override def toString: String = s"~$child" + + private lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + } + + protected override def nullSafeEval(input: Any): Any = not(input) +} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala new file mode 100644 index 0000000..d51f3d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{NullType, BooleanType, DataType} + + +case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) + extends Expression { + + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable + + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = trueValue.dataType + + override def eval(input: InternalRow): Any = { + if (true == predicate.eval(input)) { + trueValue.eval(input) + } else { + falseValue.eval(input) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + 
boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + + override def toString: String = s"if ($predicate) $trueValue else $falseValue" +} + +trait CaseWhenLike extends Expression { + + // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + // element is the value for the default catch-all case (if provided). + // Hence, `branches` consists of at least two elements, and can have an odd or even length. + def branches: Seq[Expression] + + @transient lazy val whenList = + branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = + branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + + // both then and else expressions should be considered. + def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") + } + } + + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + + override def nullable: Boolean = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. 
+ thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: InternalRow): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. 
+ while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + + override def toString: String = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. 
+ @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = key +: branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private def evalElse(input: InternalRow): Any = { + if (branchesArr.length % 2 == 0) { + null + } else { + branchesArr(branchesArr.length - 1).eval(input) + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: InternalRow): Any = { + val evaluatedKey = key.eval(input) + // If key is null, we can just return the else part or null if there is no else. + // If key is not null but doesn't match any when part, we need to return + // the else part or null if there is no else, according to Hive's semantics. + if (evaluatedKey != null) { + val len = branchesArr.length + var i = 0 + while (i < len - 1) { + if (evaluatedKey == branchesArr(i).eval(input)) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + } + evalElse(input) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else 
{ + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + if (!${keyEval.isNull}) { + $cases + } + $other + """ + } + + override def toString: String = { + s"CASE $key" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +/** + * A function that returns the least value of all parameters, skipping null values. + * It takes at least 2 parameters, and returns null iff all parameters are null. + */ +case class Least(children: Seq[Expression]) extends Expression { + + override def nullable: Boolean = children.forall(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got LEAST (${children.map(_.dataType)}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + } + } + + override def dataType: DataType = children.head.dataType + + override def eval(input: InternalRow): Any = { + children.foldLeft[Any](null)((r, c) => { + val evalc = c.eval(input) + if (evalc != null) { + if (r == null || ordering.lt(evalc, r)) evalc else r + } else { + r + } + }) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evalChildren = children.map(_.gen(ctx)) + def updateEval(i: Int): String = + s""" + if (!${evalChildren(i).isNull} && (${ev.isNull} || + ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { + ${ev.isNull} = false; + ${ev.primitive} = 
${evalChildren(i).primitive}; + } + """ + s""" + ${evalChildren.map(_.code).mkString("\n")} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${children.indices.map(updateEval).mkString("\n")} + """ + } +} + +/** + * A function that returns the greatest value of all parameters, skipping null values. + * It takes at least 2 parameters, and returns null iff all parameters are null. + */ +case class Greatest(children: Seq[Expression]) extends Expression { + + override def nullable: Boolean = children.forall(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + TypeCheckResult.TypeCheckFailure( + s"The expressions should all have the same type," + + s" got GREATEST (${children.map(_.dataType)}).") + } else { + TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) + } + } + + override def dataType: DataType = children.head.dataType + + override def eval(input: InternalRow): Any = { + children.foldLeft[Any](null)((r, c) => { + val evalc = c.eval(input) + if (evalc != null) { + if (r == null || ordering.gt(evalc, r)) evalc else r + } else { + r + } + }) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val evalChildren = children.map(_.gen(ctx)) + def updateEval(i: Int): String = + s""" + if (!${evalChildren(i).isNull} && (${ev.isNull} || + ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { + ${ev.isNull} = false; + ${ev.primitive} = ${evalChildren(i).primitive}; + } + """ + s""" + ${evalChildren.map(_.code).mkString("\n")} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)}; + ${children.indices.map(updateEval).mkString("\n")} + """ + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala deleted file mode 100644 index d51f3d3..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} - - -case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { - - override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil - override def nullable: Boolean = trueValue.nullable || falseValue.nullable - - override def checkInputDataTypes(): TypeCheckResult = { - if (predicate.dataType != BooleanType) { - TypeCheckResult.TypeCheckFailure( - s"type of predicate expression in If should be boolean, not ${predicate.dataType}") - } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + - s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - override def dataType: DataType = trueValue.dataType - - override def eval(input: InternalRow): Any = { - if (true == predicate.eval(input)) { - trueValue.eval(input) - } else { - falseValue.eval(input) - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val condEval = predicate.gen(ctx) - val trueEval = trueValue.gen(ctx) - val falseEval = falseValue.gen(ctx) - - s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.primitive}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.primitive} = ${trueEval.primitive}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.primitive} = ${falseEval.primitive}; - } - """ - } - - override def toString: String = s"if 
($predicate) $trueValue else $falseValue" -} - -trait CaseWhenLike extends Expression { - - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] - - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - - // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - - override def checkInputDataTypes(): TypeCheckResult = { - if (valueTypesEqual) { - checkTypesInternal() - } else { - TypeCheckResult.TypeCheckFailure( - "THEN and ELSE expressions should all be same type or coercible to a common type") - } - } - - protected def checkTypesInternal(): TypeCheckResult - - override def dataType: DataType = thenList.head.dataType - - override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) - } -} - -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. 
- @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if (whenList.forall(_.dataType == BooleanType)) { - TypeCheckResult.TypeCheckSuccess - } else { - val index = whenList.indexWhere(_.dataType != BooleanType) - TypeCheckResult.TypeCheckFailure( - s"WHEN expressions in CaseWhen should all be boolean type, " + - s"but the ${index + 1}th when expression's type is ${whenList(index)}") - } - } - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: InternalRow): Any = { - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${cond.primitive}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - $cases - $other - """ - } - - override def toString: String = { - "CASE" + branches.sliding(2, 
2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } -} - -// scalastyle:off -/** - * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = key +: branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if ((key +: whenList).map(_.dataType).distinct.size > 1) { - TypeCheckResult.TypeCheckFailure( - "key and WHEN expressions should all be same type or coercible to a common type") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - - private def evalElse(input: InternalRow): Any = { - if (branchesArr.length % 2 == 0) { - null - } else { - branchesArr(branchesArr.length - 1).eval(input) - } - } - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: InternalRow): Any = { - val evaluatedKey = key.eval(input) - // If key is null, we can just return the else part or null if there is no else. - // If key is not null but doesn't match any when part, we need to return - // the else part or null if there is no else, according to Hive's semantics. 
- if (evaluatedKey != null) { - val len = branchesArr.length - var i = 0 - while (i < len - 1) { - if (evaluatedKey == branchesArr(i).eval(input)) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - } - evalElse(input) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val keyEval = key.gen(ctx) - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.primitive} = ${res.primitive}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${keyEval.code} - if (!${keyEval.isNull}) { - $cases - } - $other - """ - } - - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } -} - -/** - * A function that returns the least value of all parameters, skipping null values. - * It takes at least 2 parameters, and returns null iff all parameters are null. 
- */ -case class Least(children: Seq[Expression]) extends Expression { - - override def nullable: Boolean = children.forall(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got LEAST (${children.map(_.dataType)}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) - } - } - - override def dataType: DataType = children.head.dataType - - override def eval(input: InternalRow): Any = { - children.foldLeft[Any](null)((r, c) => { - val evalc = c.eval(input) - if (evalc != null) { - if (r == null || ordering.lt(evalc, r)) evalc else r - } else { - r - } - }) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = - s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) { - ${ev.isNull} = false; - ${ev.primitive} = ${evalChildren(i).primitive}; - } - """ - s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} - """ - } -} - -/** - * A function that returns the greatest value of all parameters, skipping null values. - * It takes at least 2 parameters, and returns null iff all parameters are null. 
- */ -case class Greatest(children: Seq[Expression]) extends Expression { - - override def nullable: Boolean = children.forall(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got GREATEST (${children.map(_.dataType)}).") - } else { - TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) - } - } - - override def dataType: DataType = children.head.dataType - - override def eval(input: InternalRow): Any = { - children.foldLeft[Any](null)((r, c) => { - val evalc = c.eval(input) - if (evalc != null) { - if (r == null || ordering.gt(evalc, r)) evalc else r - } else { - r - } - }) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = - s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) { - ${ev.isNull} = false; - ${ev.primitive} = ${evalChildren(i).primitive}; - } - """ - s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} - """ - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala new file mode 100644 index 0000000..32dc9b7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -0,0 +1,899 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.text.SimpleDateFormat +import java.util.{Calendar, TimeZone} + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import scala.util.Try + +/** + * Returns the current date at the start of query evaluation. + * All calls of current_date within the same query return the same value. + * + * There is no code generation since this expression should get constant folded by the optimizer. 
+ */ +case class CurrentDate() extends LeafExpression with CodegenFallback { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = { + DateTimeUtils.millisToDays(System.currentTimeMillis()) + } +} + +/** + * Returns the current timestamp at the start of query evaluation. + * All calls of current_timestamp within the same query return the same value. + * + * There is no code generation since this expression should get constant folded by the optimizer. + */ +case class CurrentTimestamp() extends LeafExpression with CodegenFallback { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = TimestampType + + override def eval(input: InternalRow): Any = { + System.currentTimeMillis() * 1000L + } +} + +/** + * Adds a number of days to startdate. + */ +case class DateAdd(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] + d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd + $d;""" + }) + } +} + +/** + * Subtracts a number of days to startdate. 
+ */ +case class DateSub(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] - d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd - $d;""" + }) + } +} + +case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") + } +} + +case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") + } +} + +case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(timestamp: Any): Any = { + 
DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") + } +} + +case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") + } +} + + +case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getYear(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") + } +} + +case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getQuarter(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") + } +} + +case class Month(child: Expression) extends UnaryExpression with 
ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") + } +} + +case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + override protected def nullSafeEval(date: Any): Any = { + DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") + } +} + +case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient private lazy val c = { + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.setFirstDayOfWeek(Calendar.MONDAY) + c.setMinimalDaysInFirstWeek(4) + c + } + + override protected def nullSafeEval(date: Any): Any = { + c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + c.get(Calendar.WEEK_OF_YEAR) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val c = ctx.freshName("cal") + ctx.addMutableState(cal, c, + s""" + $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC")); + $c.setFirstDayOfWeek($cal.MONDAY); + $c.setMinimalDaysInFirstWeek(4); + """) + s""" + $c.setTimeInMillis($time * 1000L * 
3600L * 24L); + ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); + """ + }) + } +} + +case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression + with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + + override protected def nullSafeEval(timestamp: Any, format: Any): Any = { + val sdf = new SimpleDateFormat(format.toString) + UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + defineCodeGen(ctx, ev, (timestamp, format) => { + s"""UTF8String.fromString((new $sdf($format.toString())) + .format(new java.util.Date($timestamp / 1000)))""" + }) + } + + override def prettyName: String = "date_format" +} + +/** + * Converts time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), returns null if fail. + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". + * If no parameters provided, the first parameter will be current_timestamp. + * If the first parameter is a Date or Timestamp instead of String, we will ignore the + * second parameter. 
+ */ +case class UnixTimestamp(timeExp: Expression, format: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } + + def this() = { + this(CurrentTimestamp()) + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, DateType, TimestampType), StringType) + + override def dataType: DataType = LongType + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val t = left.eval(input) + if (t == null) { + null + } else { + left.dataType match { + case DateType => + DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L + case TimestampType => + t.asInstanceOf[Long] / 1000000L + case StringType if right.foldable => + if (constFormat != null) { + Try(new SimpleDateFormat(constFormat.toString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } else { + null + } + case StringType => + val f = format.eval(input) + if (f == null) { + null + } else { + val formatString = f.asInstanceOf[UTF8String].toString + Try(new SimpleDateFormat(formatString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + left.dataType match { + case StringType if right.foldable => + val sdf = classOf[SimpleDateFormat].getName + val fString = if (constFormat == null) null else constFormat.toString + val formatter = ctx.freshName("formatter") + if (fString == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + $sdf $formatter = new $sdf("$fString"); + ${ev.primitive} = + $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + case StringType => + val sdf = classOf[SimpleDateFormat].getName + nullSafeCodeGen(ctx, ev, (string, format) => { + s""" + try { + ${ev.primitive} = + (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + """ + }) + case TimestampType => + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${eval1.primitive} / 1000000L; + } + """ + case DateType => + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; + } + """ + } + } +} + +/** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. If the format is missing, using format like "1970-01-01 00:00:00". + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. 
+ */ +case class FromUnixTime(sec: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = sec + override def right: Expression = format + + def this(unix: Expression) = { + this(unix, Literal("yyyy-MM-dd HH:mm:ss")) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val time = left.eval(input) + if (time == null) { + null + } else { + if (format.foldable) { + if (constFormat == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( + new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } else { + val f = format.eval(input) + if (f == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat( + f.asInstanceOf[UTF8String].toString).format(new java.util.Date( + time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + if (format.foldable) { + if (constFormat == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val t = left.gen(ctx) + s""" + ${t.code} + boolean ${ev.isNull} = ${t.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( + new java.util.Date(${t.primitive} * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (seconds, f) => { + s""" + try { + ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( + new 
java.util.Date($seconds * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + }""".stripMargin + }) + } + } +} + +/** + * Returns the last day of the month which the date belongs to. + */ +case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def child: Expression = startDate + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def nullSafeEval(date: Any): Any = { + DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") + } + + override def prettyName: String = "last_day" +} + +/** + * Returns the first date which is later than startDate and named as dayOfWeek. + * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first + * Sunday later than 2015-07-27. + * + * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. 
+ */ +case class NextDay(startDate: Expression, dayOfWeek: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = dayOfWeek + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, dayOfW: Any): Any = { + val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) + if (dow == -1) { + null + } else { + val sd = start.asInstanceOf[Int] + DateTimeUtils.getNextDateForDayOfWeek(sd, dow) + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, dowS) => { + val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") + val dayOfWeekTerm = ctx.freshName("dayOfWeek") + if (dayOfWeek.foldable) { + val input = dayOfWeek.eval().asInstanceOf[UTF8String] + if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { + s""" + |${ev.isNull} = true; + """.stripMargin + } else { + val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) + s""" + |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); + """.stripMargin + } + } else { + s""" + |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); + |if ($dayOfWeekTerm == -1) { + | ${ev.isNull} = true; + |} else { + | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); + |} + """.stripMargin + } + }) + } + + override def prettyName: String = "next_day" +} + +/** + * Adds an interval to timestamp. 
+ */ +case class TimeAdd(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left + $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], itvl.months, itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" + }) + } +} + +/** + * Assumes given timestamp is UTC and converts to given timezone. + */ +case class FromUTCTimestamp(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + override def dataType: DataType = TimestampType + override def prettyName: String = "from_utc_timestamp" + + override def nullSafeEval(time: Any, timezone: Any): Any = { + DateTimeUtils.fromUTCTime(time.asInstanceOf[Long], + timezone.asInstanceOf[UTF8String].toString) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + if (right.foldable) { + val tz = right.eval() + if (tz == null) { + s""" + |boolean ${ev.isNull} = true; + |long ${ev.primitive} = 0; + """.stripMargin + } else { + val tzTerm = ctx.freshName("tz") + val tzClass = classOf[TimeZone].getName + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") + val eval = left.gen(ctx) + s""" + |${eval.code} + |boolean 
${ev.isNull} = ${eval.isNull}; + |long ${ev.primitive} = 0; + |if (!${ev.isNull}) { + | ${ev.primitive} = ${eval.primitive} + + | ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L; + |} + """.stripMargin + } + } else { + defineCodeGen(ctx, ev, (timestamp, format) => { + s"""$dtu.fromUTCTime($timestamp, $format.toString())""" + }) + } + } +} + +/** + * Subtracts an interval from timestamp. + */ +case class TimeSub(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + }) + } +} + +/** + * Returns the date that is num_months after start_date. 
+ */ +case class AddMonths(startDate: Expression, numMonths: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = numMonths + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, months: Any): Any = { + DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, m) => { + s"""$dtu.dateAddMonths($sd, $m)""" + }) + } +} + +/** + * Returns number of months between dates date1 and date2. + */ +case class MonthsBetween(date1: Expression, date2: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = date1 + override def right: Expression = date2 + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + + override def dataType: DataType = DoubleType + + override def nullSafeEval(t1: Any, t2: Any): Any = { + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => { + s"""$dtu.monthsBetween($l, $r)""" + }) + } +} + +/** + * Assumes given timestamp is in given timezone and converts to UTC. 
+ *
+ * The interpreted path delegates to DateTimeUtils.toUTCTime. When the time zone is a literal,
+ * the generated code resolves it once into a java.util.TimeZone and subtracts that zone's
+ * offset; a null literal time zone yields null.
+ */
+case class ToUTCTimestamp(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
+  override def dataType: DataType = TimestampType
+  override def prettyName: String = "to_utc_timestamp"
+
+  override def nullSafeEval(time: Any, timezone: Any): Any = {
+    DateTimeUtils.toUTCTime(time.asInstanceOf[Long],
+      timezone.asInstanceOf[UTF8String].toString)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    if (right.foldable) {
+      val tz = right.eval()
+      if (tz == null) {
+        s"""
+           |boolean ${ev.isNull} = true;
+           |long ${ev.primitive} = 0;
+         """.stripMargin
+      } else {
+        val tzTerm = ctx.freshName("tz")
+        val tzClass = classOf[TimeZone].getName
+        // The time zone id is a user-supplied literal spliced into a Java string literal in the
+        // generated source; escape backslashes and quotes so it cannot break out of the literal.
+        val tzLiteral = tz.toString.replace("\\", "\\\\").replace("\"", "\\\"")
+        ctx.addMutableState(tzClass, tzTerm,
+          s"""$tzTerm = $tzClass.getTimeZone("$tzLiteral");""")
+        val eval = left.gen(ctx)
+        s"""
+           |${eval.code}
+           |boolean ${ev.isNull} = ${eval.isNull};
+           |long ${ev.primitive} = 0;
+           |if (!${ev.isNull}) {
+           |  ${ev.primitive} = ${eval.primitive} -
+           |    ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L;
+           |}
+         """.stripMargin
+      }
+    } else {
+      defineCodeGen(ctx, ev, (timestamp, format) => {
+        s"""$dtu.toUTCTime($timestamp, $format.toString())"""
+      })
+    }
+  }
+}
+
+/**
+ * Returns the date part of a timestamp or string.
+ */
+case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  // The input is declared as DateType; implicit casting lets strings (in date or timestamp
+  // format) and TimestampType values through as dates, so evaluation is the identity.
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+  override def dataType: DataType = DateType
+
+  override def eval(input: InternalRow): Any = child.eval(input)
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, d => d)
+  }
+}
+
+/**
+ * Returns date truncated to the unit specified by the format.
+ *
+ * The unit is parsed with DateTimeUtils.parseTruncLevel; a level of -1 (an unrecognized
+ * format) produces null, as does a null date or a null format.
+ */
+case class TruncDate(date: Expression, format: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+  override def left: Expression = date
+  override def right: Expression = format
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+  override def dataType: DataType = DateType
+  override def prettyName: String = "trunc"
+
+  // Only used when `format` is foldable: the constant format is evaluated once, without an
+  // input row, and the parsed level is cached.
+  private lazy val truncLevel: Int =
+    DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
+
+  override def eval(input: InternalRow): Any = {
+    val level = if (format.foldable) {
+      truncLevel
+    } else {
+      // A non-foldable format (e.g. a column) varies per row, so it must be evaluated against
+      // the current input row. A null format is treated like an unknown one, which matches the
+      // null-safe generated code below.
+      val fmt = format.eval(input)
+      if (fmt == null) -1 else DateTimeUtils.parseTruncLevel(fmt.asInstanceOf[UTF8String])
+    }
+    if (level == -1) {
+      // unknown format
+      null
+    } else {
+      val d = date.eval(input)
+      if (d == null) {
+        null
+      } else {
+        DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+
+    if (format.foldable) {
+      if (truncLevel == -1) {
+        s"""
+          boolean ${ev.isNull} = true;
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+        """
+      } else {
+        val d = date.gen(ctx)
+        s"""
+          ${d.code}
+          boolean ${ev.isNull} = ${d.isNull};
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel);
+          }
+        """
+      }
+    } else {
+      nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+        val form = ctx.freshName("form")
+        s"""
+          int $form = $dtu.parseTruncLevel($fmt);
+          if ($form == -1) {
+            ${ev.isNull} = true;
+          } else {
+            ${ev.primitive} = $dtu.truncDate($dateVal, $form);
+          }
+        """
+      })
+    }
+  }
+}
+
+/**
+ * Returns the number of days from startDate to endDate, computed as the difference
+ * endDate - startDate of the two internal Int date values.
+ */
+case class DateDiff(endDate: Expression, startDate: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = endDate
+  override def right: Expression = startDate
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
+  override def dataType: DataType = IntegerType
+
+  override def nullSafeEval(end: Any, start: Any): Any = {
+    end.asInstanceOf[Int] - start.asInstanceOf[Int]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, (end, start) => s"$end - $start")
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
