http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala deleted file mode 100644 index 32dc9b7..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ /dev/null @@ -1,899 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} - -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - -import scala.util.Try - -/** - * Returns the current date at the start of query evaluation. - * All calls of current_date within the same query return the same value. - * - * There is no code generation since this expression should get constant folded by the optimizer. - */ -case class CurrentDate() extends LeafExpression with CodegenFallback { - override def foldable: Boolean = true - override def nullable: Boolean = false - - override def dataType: DataType = DateType - - override def eval(input: InternalRow): Any = { - DateTimeUtils.millisToDays(System.currentTimeMillis()) - } -} - -/** - * Returns the current timestamp at the start of query evaluation. - * All calls of current_timestamp within the same query return the same value. - * - * There is no code generation since this expression should get constant folded by the optimizer. - */ -case class CurrentTimestamp() extends LeafExpression with CodegenFallback { - override def foldable: Boolean = true - override def nullable: Boolean = false - - override def dataType: DataType = TimestampType - - override def eval(input: InternalRow): Any = { - System.currentTimeMillis() * 1000L - } -} - -/** - * Adds a number of days to startdate. - */ -case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = startDate - override def right: Expression = days - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) - - override def dataType: DataType = DateType - - override def nullSafeEval(start: Any, d: Any): Any = { - start.asInstanceOf[Int] + d.asInstanceOf[Int] - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.primitive} = $sd + $d;""" - }) - } -} - -/** - * Subtracts a number of days to startdate. - */ -case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = startDate - override def right: Expression = days - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) - - override def dataType: DataType = DateType - - override def nullSafeEval(start: Any, d: Any): Any = { - start.asInstanceOf[Int] - d.asInstanceOf[Int] - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (sd, d) => { - s"""${ev.primitive} = $sd - $d;""" - }) - } -} - -case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") - } -} - -case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") - } -} - -case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") - } -} - -case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(date: Any): Any = { - DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") - } -} - - -case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(date: Any): Any = { - DateTimeUtils.getYear(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") - } -} - -case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(date: Any): Any = { - DateTimeUtils.getQuarter(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") - } -} - -case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(date: Any): Any = { - DateTimeUtils.getMonth(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") - } -} - -case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - override protected def nullSafeEval(date: Any): Any = { - DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") - } -} - -case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType - - @transient private lazy val c = { - val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) - c.setFirstDayOfWeek(Calendar.MONDAY) - c.setMinimalDaysInFirstWeek(4) - c - } - - override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.WEEK_OF_YEAR) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, time => { - val cal = classOf[Calendar].getName - val c = ctx.freshName("cal") - ctx.addMutableState(cal, c, - s""" - $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC")); - $c.setFirstDayOfWeek($cal.MONDAY); - $c.setMinimalDaysInFirstWeek(4); - """) - s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.primitive} = $c.get($cal.WEEK_OF_YEAR); - """ - }) - } -} - -case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - - override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val sdf = new SimpleDateFormat(format.toString) - UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val sdf = classOf[SimpleDateFormat].getName - defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString((new $sdf($format.toString())) - .format(new java.util.Date($timestamp / 1000)))""" - }) - } - - override def prettyName: String = "date_format" -} - -/** - * Converts time string with given pattern - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), returns null if fail. - * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. - * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". - * If no parameters provided, the first parameter will be current_timestamp. - * If the first parameter is a Date or Timestamp instead of String, we will ignore the - * second parameter. - */ -case class UnixTimestamp(timeExp: Expression, format: Expression) - extends BinaryExpression with ExpectsInputTypes { - - override def left: Expression = timeExp - override def right: Expression = format - - def this(time: Expression) = { - this(time, Literal("yyyy-MM-dd HH:mm:ss")) - } - - def this() = { - this(CurrentTimestamp()) - } - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, DateType, TimestampType), StringType) - - override def dataType: DataType = LongType - - private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - - override def eval(input: InternalRow): Any = { - val t = left.eval(input) - if (t == null) { - null - } else { - left.dataType match { - case DateType => - DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L - case TimestampType => - t.asInstanceOf[Long] / 1000000L - case StringType if right.foldable => - if (constFormat != null) { - Try(new SimpleDateFormat(constFormat.toString).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) - } else { - null - } - case StringType => - val f = format.eval(input) - if (f == null) { - null - } else { - val formatString = f.asInstanceOf[UTF8String].toString - Try(new SimpleDateFormat(formatString).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) - } - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - left.dataType match { - case StringType if right.foldable => - val sdf = classOf[SimpleDateFormat].getName - val fString = if (constFormat == null) null else constFormat.toString - val formatter = ctx.freshName("formatter") - if (fString == null) { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } else { - val eval1 = left.gen(ctx) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - try { - $sdf $formatter = new $sdf("$fString"); - ${ev.primitive} = - $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; - } catch (java.lang.Throwable e) { - ${ev.isNull} = true; - } - } - """ - } - case StringType => - val sdf = classOf[SimpleDateFormat].getName - nullSafeCodeGen(ctx, ev, (string, format) => { - s""" - try { - ${ev.primitive} = - (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; - } catch (java.lang.Throwable e) { - ${ev.isNull} = true; - } - """ - }) - case TimestampType => - val eval1 = left.gen(ctx) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = ${eval1.primitive} / 1000000L; - } - """ - case DateType => - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val eval1 = left.gen(ctx) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; - } - """ - } - } -} - -/** - * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string - * representing the timestamp of that moment in the current system time zone in the given - * format. If the format is missing, using format like "1970-01-01 00:00:00". - * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. - */ -case class FromUnixTime(sec: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = sec - override def right: Expression = format - - def this(unix: Expression) = { - this(unix, Literal("yyyy-MM-dd HH:mm:ss")) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) - - private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - - override def eval(input: InternalRow): Any = { - val time = left.eval(input) - if (time == null) { - null - } else { - if (format.foldable) { - if (constFormat == null) { - null - } else { - Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( - new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) - } - } else { - val f = format.eval(input) - if (f == null) { - null - } else { - Try(UTF8String.fromString(new SimpleDateFormat( - f.asInstanceOf[UTF8String].toString).format(new java.util.Date( - time.asInstanceOf[Long] * 1000L)))).getOrElse(null) - } - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val sdf = classOf[SimpleDateFormat].getName - if (format.foldable) { - if (constFormat == null) { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } else { - val t = left.gen(ctx) - s""" - ${t.code} - boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - try { - ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( - new java.util.Date(${t.primitive} * 1000L))); - } catch (java.lang.Throwable e) { - ${ev.isNull} = true; - } - } - """ - } - } else { - nullSafeCodeGen(ctx, ev, (seconds, f) => { - s""" - try { - ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( - new java.util.Date($seconds * 1000L))); - } catch (java.lang.Throwable e) { - ${ev.isNull} = true; - }""".stripMargin - }) - } - } -} - -/** - * Returns the last day of the month which the date belongs to. - */ -case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { - override def child: Expression = startDate - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = DateType - - override def nullSafeEval(date: Any): Any = { - DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") - } - - override def prettyName: String = "last_day" -} - -/** - * Returns the first date which is later than startDate and named as dayOfWeek. - * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first - * Sunday later than 2015-07-27. - * - * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. - */ -case class NextDay(startDate: Expression, dayOfWeek: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = startDate - override def right: Expression = dayOfWeek - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) - - override def dataType: DataType = DateType - - override def nullSafeEval(start: Any, dayOfW: Any): Any = { - val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) - if (dow == -1) { - null - } else { - val sd = start.asInstanceOf[Int] - DateTimeUtils.getNextDateForDayOfWeek(sd, dow) - } - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (sd, dowS) => { - val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") - val dayOfWeekTerm = ctx.freshName("dayOfWeek") - if (dayOfWeek.foldable) { - val input = dayOfWeek.eval().asInstanceOf[UTF8String] - if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { - s""" - |${ev.isNull} = true; - """.stripMargin - } else { - val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) - s""" - |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); - """.stripMargin - } - } else { - s""" - |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); - |if ($dayOfWeekTerm == -1) { - | ${ev.isNull} = true; - |} else { - | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); - |} - """.stripMargin - } - }) - } - - override def prettyName: String = "next_day" -} - -/** - * Adds an interval to timestamp. - */ -case class TimeAdd(start: Expression, interval: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = start - override def right: Expression = interval - - override def toString: String = s"$left + $right" - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) - - override def dataType: DataType = TimestampType - - override def nullSafeEval(start: Any, interval: Any): Any = { - val itvl = interval.asInstanceOf[CalendarInterval] - DateTimeUtils.timestampAddInterval( - start.asInstanceOf[Long], itvl.months, itvl.microseconds) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" - }) - } -} - -/** - * Assumes given timestamp is UTC and converts to given timezone. - */ -case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def dataType: DataType = TimestampType - override def prettyName: String = "from_utc_timestamp" - - override def nullSafeEval(time: Any, timezone: Any): Any = { - DateTimeUtils.fromUTCTime(time.asInstanceOf[Long], - timezone.asInstanceOf[UTF8String].toString) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (right.foldable) { - val tz = right.eval() - if (tz == null) { - s""" - |boolean ${ev.isNull} = true; - |long ${ev.primitive} = 0; - """.stripMargin - } else { - val tzTerm = ctx.freshName("tz") - val tzClass = classOf[TimeZone].getName - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - val eval = left.gen(ctx) - s""" - |${eval.code} - |boolean ${ev.isNull} = ${eval.isNull}; - |long ${ev.primitive} = 0; - |if (!${ev.isNull}) { - | ${ev.primitive} = ${eval.primitive} + - | ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L; - |} - """.stripMargin - } - } else { - defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.fromUTCTime($timestamp, $format.toString())""" - }) - } - } -} - -/** - * Subtracts an interval from timestamp. - */ -case class TimeSub(start: Expression, interval: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = start - override def right: Expression = interval - - override def toString: String = s"$left - $right" - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) - - override def dataType: DataType = TimestampType - - override def nullSafeEval(start: Any, interval: Any): Any = { - val itvl = interval.asInstanceOf[CalendarInterval] - DateTimeUtils.timestampAddInterval( - start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" - }) - } -} - -/** - * Returns the date that is num_months after start_date. - */ -case class AddMonths(startDate: Expression, numMonths: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = startDate - override def right: Expression = numMonths - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) - - override def dataType: DataType = DateType - - override def nullSafeEval(start: Any, months: Any): Any = { - DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (sd, m) => { - s"""$dtu.dateAddMonths($sd, $m)""" - }) - } -} - -/** - * Returns number of months between dates date1 and date2. - */ -case class MonthsBetween(date1: Expression, date2: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = date1 - override def right: Expression = date2 - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) - - override def dataType: DataType = DoubleType - - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r)""" - }) - } -} - -/** - * Assumes given timestamp is in given timezone and converts to UTC. - */ -case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def dataType: DataType = TimestampType - override def prettyName: String = "to_utc_timestamp" - - override def nullSafeEval(time: Any, timezone: Any): Any = { - DateTimeUtils.toUTCTime(time.asInstanceOf[Long], - timezone.asInstanceOf[UTF8String].toString) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (right.foldable) { - val tz = right.eval() - if (tz == null) { - s""" - |boolean ${ev.isNull} = true; - |long ${ev.primitive} = 0; - """.stripMargin - } else { - val tzTerm = ctx.freshName("tz") - val tzClass = classOf[TimeZone].getName - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - val eval = left.gen(ctx) - s""" - |${eval.code} - |boolean ${ev.isNull} = ${eval.isNull}; - |long ${ev.primitive} = 0; - |if (!${ev.isNull}) { - | ${ev.primitive} = ${eval.primitive} - - | ${tzTerm}.getOffset(${eval.primitive} / 1000) * 1000L; - |} - """.stripMargin - } - } else { - defineCodeGen(ctx, ev, (timestamp, format) => { - s"""$dtu.toUTCTime($timestamp, $format.toString())""" - }) - } - } -} - -/** - * Returns the date part of a timestamp or string. - */ -case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - // Implicit casting of spark will accept string in both date and timestamp format, as - // well as TimestampType. - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = DateType - - override def eval(input: InternalRow): Any = child.eval(input) - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, d => d) - } -} - -/** - * Returns date truncated to the unit specified by the format. - */ -case class TruncDate(date: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - override def left: Expression = date - override def right: Expression = format - - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) - override def dataType: DataType = DateType - override def prettyName: String = "trunc" - - private lazy val truncLevel: Int = - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - - override def eval(input: InternalRow): Any = { - val level = if (format.foldable) { - truncLevel - } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - } - if (level == -1) { - // unknown format - null - } else { - val d = date.eval(input) - if (d == null) { - null - } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - - if (format.foldable) { - if (truncLevel == -1) { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } else { - val d = date.gen(ctx) - s""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel); - } - """ - } - } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") - s""" - int $form = $dtu.parseTruncLevel($fmt); - if ($form == -1) { - ${ev.isNull} = true; - } else { - ${ev.primitive} = $dtu.truncDate($dateVal, $form); - } - """ - }) - } - } -} - -/** - * Returns the number of days from startDate to endDate. - */ -case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = endDate - override def right: Expression = startDate - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) - override def dataType: DataType = IntegerType - - override def nullSafeEval(end: Any, start: Any): Any = { - end.asInstanceOf[Int] - start.asInstanceOf[Int] - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (end, start) => s"$end - $start") - } -}
http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala new file mode 100644 index 0000000..b7be12f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types._ + +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class UnscaledValue(child: Expression) extends UnaryExpression { + + override def dataType: DataType = LongType + override def toString: String = s"UnscaledValue($child)" + + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toUnscaledLong + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } +} + +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ +case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { + + override def dataType: DataType = DecimalType(precision, scale) + override def toString: String = s"MakeDecimal($child,$precision,$scale)" + + protected override def nullSafeEval(input: Any): Any = + Decimal(input.asInstanceOf[Long], precision, scale) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + s""" + ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; + """ + }) + } +} + +/** + * An expression used to wrap the children when promote the precision of DecimalType to avoid + * promote multiple times. + */ +case class PromotePrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def prettyName: String = "promote_precision" +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not, returns null if not. + */ +case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Decimal].clone() + if (d.changePrecision(dataType.precision, dataType.scale)) { + d + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + val tmp = ctx.freshName("tmp") + s""" + | Decimal $tmp = $eval.clone(); + | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { + | ${ev.primitive} = $tmp; + | } else { + | ${ev.isNull} = true; + | } + """.stripMargin + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" +} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala deleted file mode 100644 index b7be12f..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.types._ - -/** - * Return the unscaled Long value of a Decimal, assuming it fits in a Long. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. - */ -case class UnscaledValue(child: Expression) extends UnaryExpression { - - override def dataType: DataType = LongType - override def toString: String = s"UnscaledValue($child)" - - protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[Decimal].toUnscaledLong - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") - } -} - -/** - * Create a Decimal from an unscaled Long value. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. - */ -case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { - - override def dataType: DataType = DecimalType(precision, scale) - override def toString: String = s"MakeDecimal($child,$precision,$scale)" - - protected override def nullSafeEval(input: Any): Any = - Decimal(input.asInstanceOf[Long], precision, scale) - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); - ${ev.isNull} = ${ev.primitive} == null; - """ - }) - } -} - -/** - * An expression used to wrap the children when promote the precision of DecimalType to avoid - * promote multiple times. - */ -case class PromotePrecision(child: Expression) extends UnaryExpression { - override def dataType: DataType = child.dataType - override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" - override def prettyName: String = "promote_precision" -} - -/** - * Rounds the decimal to given scale and check whether the decimal can fit in provided precision - * or not, returns null if not. - */ -case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { - - override def nullable: Boolean = true - - override def nullSafeEval(input: Any): Any = { - val d = input.asInstanceOf[Decimal].clone() - if (d.changePrecision(dataType.precision, dataType.scale)) { - d - } else { - null - } - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - val tmp = ctx.freshName("tmp") - s""" - | Decimal $tmp = $eval.clone(); - | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { - | ${ev.primitive} = $tmp; - | } else { - | ${ev.isNull} = true; - | } - """.stripMargin - }) - } - - override def toString: String = s"CheckOverflow($child, $dataType)" -} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala new file mode 100644 index 0000000..23bfa18 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.io.{StringWriter, ByteArrayOutputStream} + +import com.fasterxml.jackson.core._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{StringType, DataType} +import org.apache.spark.unsafe.types.UTF8String + +import scala.util.parsing.combinator.RegexParsers + +private[this] sealed trait PathInstruction +private[this] object PathInstruction { + private[expressions] case object Subscript extends PathInstruction + private[expressions] case object Wildcard extends PathInstruction + private[expressions] case object Key extends PathInstruction + private[expressions] case class Index(index: Long) extends PathInstruction + private[expressions] case class Named(name: String) extends PathInstruction +} + +private[this] sealed trait WriteStyle +private[this] object WriteStyle { + private[expressions] case object RawStyle extends WriteStyle + private[expressions] case object QuotedStyle extends WriteStyle + private[expressions] case object FlattenStyle extends WriteStyle +} + +private[this] object JsonPathParser extends RegexParsers { + import PathInstruction._ + + def root: Parser[Char] = '$' + + def long: Parser[Long] = "\\d+".r ^? { + case x => x.toLong + } + + // parse `[*]` and `[123]` subscripts + def subscript: Parser[List[PathInstruction]] = + for { + operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index) <~ ']' + } yield { + Subscript :: operand :: Nil + } + + // parse `.name` or `['name']` child expressions + def named: Parser[List[PathInstruction]] = + for { + name <- '.' ~> "[^\\.\\[]+".r | "[\\'" ~> "[^\\'\\?]+" <~ "\\']" + } yield { + Key :: Named(name) :: Nil + } + + // child wildcards: `..`, `.*` or `['*']` + def wildcard: Parser[List[PathInstruction]] = + (".*" | "['*']") ^^^ List(Wildcard) + + def node: Parser[List[PathInstruction]] = + wildcard | + named | + subscript + + val expression: Parser[List[PathInstruction]] = { + phrase(root ~> rep(node) ^^ (x => x.flatten)) + } + + def parse(str: String): Option[List[PathInstruction]] = { + this.parseAll(expression, str) match { + case Success(result, _) => + Some(result) + + case NoSuccess(msg, next) => + None + } + } +} + +private[this] object GetJsonObject { + private val jsonFactory = new JsonFactory() + + // Enabled for Hive compatibility + jsonFactory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS) +} + +/** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + */ +case class GetJsonObject(json: Expression, path: Expression) + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + + import GetJsonObject._ + import PathInstruction._ + import WriteStyle._ + import com.fasterxml.jackson.core.JsonToken._ + + override def left: Expression = json + override def right: Expression = path + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def dataType: DataType = StringType + override def prettyName: String = "get_json_object" + + @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val jsonStr = json.eval(input).asInstanceOf[UTF8String] + if (jsonStr == null) { + return null + } + + val parsed = if (path.foldable) { + parsedPath + } else { + parsePath(path.eval(input).asInstanceOf[UTF8String]) + } + + if (parsed.isDefined) { + try { + val parser = jsonFactory.createParser(jsonStr.getBytes) + val output = new ByteArrayOutputStream() + val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8) + parser.nextToken() + val matched = evaluatePath(parser, generator, RawStyle, parsed.get) + generator.close() + if (matched) { + UTF8String.fromBytes(output.toByteArray) + } else { + null + } + } catch { + case _: JsonProcessingException => null + } + } else { + null + } + } + + private def parsePath(path: UTF8String): Option[List[PathInstruction]] = { + if (path != null) { + JsonPathParser.parse(path.toString) + } else { + None + } + } + + // advance to the desired array index, assumes to start at the START_ARRAY token + private def arrayIndex(p: JsonParser, f: () => Boolean): Long => Boolean = { + case _ if p.getCurrentToken == END_ARRAY => + // terminate, nothing has been written + false + + case 0 => + // we've reached the desired index + val dirty = f() + + while (p.nextToken() != END_ARRAY) { + // advance the token stream to the end of the array + p.skipChildren() + } + + dirty + + case i if i > 0 => + // skip this token and evaluate the next + p.skipChildren() + p.nextToken() + arrayIndex(p, f)(i - 1) + } + + /** + * Evaluate a list of JsonPath instructions, returning a bool that indicates if any leaf nodes + * have been written to the generator + */ + private def evaluatePath( + p: JsonParser, + g: JsonGenerator, + style: WriteStyle, + path: List[PathInstruction]): Boolean = { + (p.getCurrentToken, path) match { + case (VALUE_STRING, Nil) if style == RawStyle => + // there is no array wildcard or slice parent, emit this string without quotes + if (p.hasTextCharacters) { + g.writeRaw(p.getTextCharacters, p.getTextOffset, p.getTextLength) + } else { + g.writeRaw(p.getText) + } + true + + case (START_ARRAY, Nil) if style == FlattenStyle => + // flatten this array into the parent + var dirty = false + while (p.nextToken() != END_ARRAY) { + dirty |= evaluatePath(p, g, style, Nil) + } + dirty + + case (_, Nil) => + // general case: just copy the child tree verbatim + g.copyCurrentStructure(p) + true + + case (START_OBJECT, Key :: xs) => + var dirty = false + while (p.nextToken() != END_OBJECT) { + if (dirty) { + // once a match has been found we can skip other fields + p.skipChildren() + } else { + dirty = evaluatePath(p, g, style, xs) + } + } + dirty + + case (START_ARRAY, Subscript :: Wildcard :: Subscript :: Wildcard :: xs) => + // special handling for the non-structure preserving double wildcard behavior in Hive + var dirty = false + g.writeStartArray() + while (p.nextToken() != END_ARRAY) { + dirty |= evaluatePath(p, g, FlattenStyle, xs) + } + g.writeEndArray() + dirty + + case (START_ARRAY, Subscript :: Wildcard :: xs) if style != QuotedStyle => + // retain Flatten, otherwise use Quoted... cannot use Raw within an array + val nextStyle = style match { + case RawStyle => QuotedStyle + case FlattenStyle => FlattenStyle + case QuotedStyle => throw new IllegalStateException() + } + + // temporarily buffer child matches, the emitted json will need to be + // modified slightly if there is only a single element written + val buffer = new StringWriter() + val flattenGenerator = jsonFactory.createGenerator(buffer) + flattenGenerator.writeStartArray() + + var dirty = 0 + while (p.nextToken() != END_ARRAY) { + // track the number of array elements and only emit an outer array if + // we've written more than one element, this matches Hive's behavior + dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0) + } + flattenGenerator.writeEndArray() + flattenGenerator.close() + + val buf = buffer.getBuffer + if (dirty > 1) { + g.writeRawValue(buf.toString) + } else if (dirty == 1) { + // remove outer array tokens + g.writeRawValue(buf.substring(1, buf.length()-1)) + } // else do not write anything + + dirty > 0 + + case (START_ARRAY, Subscript :: Wildcard :: xs) => + var dirty = false + g.writeStartArray() + while (p.nextToken() != END_ARRAY) { + // wildcards can have multiple matches, continually update the dirty count + dirty |= evaluatePath(p, g, QuotedStyle, xs) + } + g.writeEndArray() + + dirty + + case (START_ARRAY, Subscript :: Index(idx) :: (xs@Subscript :: Wildcard :: _)) => + p.nextToken() + // we're going to have 1 or more results, switch to QuotedStyle + arrayIndex(p, () => evaluatePath(p, g, QuotedStyle, xs))(idx) + + case (START_ARRAY, Subscript :: Index(idx) :: xs) => + p.nextToken() + arrayIndex(p, () => evaluatePath(p, g, style, xs))(idx) + + case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => + // exact field match + p.nextToken() + evaluatePath(p, g, style, xs) + + case (FIELD_NAME, Wildcard :: xs) => + // wildcard field match + p.nextToken() + evaluatePath(p, g, style, xs) + + case _ => + p.skipChildren() + false + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala deleted file mode 100644 index 23bfa18..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala +++ /dev/null @@ -1,309 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.io.{StringWriter, ByteArrayOutputStream} - -import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StringType, DataType} -import org.apache.spark.unsafe.types.UTF8String - -import scala.util.parsing.combinator.RegexParsers - -private[this] sealed trait PathInstruction -private[this] object PathInstruction { - private[expressions] case object Subscript extends PathInstruction - private[expressions] case object Wildcard extends PathInstruction - private[expressions] case object Key extends PathInstruction - private[expressions] case class Index(index: Long) extends PathInstruction - private[expressions] case class Named(name: String) extends PathInstruction -} - -private[this] sealed trait WriteStyle -private[this] object WriteStyle { - private[expressions] case object RawStyle extends WriteStyle - private[expressions] case object QuotedStyle extends WriteStyle - private[expressions] case object FlattenStyle extends WriteStyle -} - -private[this] object JsonPathParser extends RegexParsers { - import PathInstruction._ - - def root: Parser[Char] = '$' - - def long: Parser[Long] = "\\d+".r ^? { - case x => x.toLong - } - - // parse `[*]` and `[123]` subscripts - def subscript: Parser[List[PathInstruction]] = - for { - operand <- '[' ~> ('*' ^^^ Wildcard | long ^^ Index) <~ ']' - } yield { - Subscript :: operand :: Nil - } - - // parse `.name` or `['name']` child expressions - def named: Parser[List[PathInstruction]] = - for { - name <- '.' ~> "[^\\.\\[]+".r | "[\\'" ~> "[^\\'\\?]+" <~ "\\']" - } yield { - Key :: Named(name) :: Nil - } - - // child wildcards: `..`, `.*` or `['*']` - def wildcard: Parser[List[PathInstruction]] = - (".*" | "['*']") ^^^ List(Wildcard) - - def node: Parser[List[PathInstruction]] = - wildcard | - named | - subscript - - val expression: Parser[List[PathInstruction]] = { - phrase(root ~> rep(node) ^^ (x => x.flatten)) - } - - def parse(str: String): Option[List[PathInstruction]] = { - this.parseAll(expression, str) match { - case Success(result, _) => - Some(result) - - case NoSuccess(msg, next) => - None - } - } -} - -private[this] object GetJsonObject { - private val jsonFactory = new JsonFactory() - - // Enabled for Hive compatibility - jsonFactory.enable(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS) -} - -/** - * Extracts json object from a json string based on json path specified, and returns json string - * of the extracted json object. It will return null if the input json string is invalid. - */ -case class GetJsonObject(json: Expression, path: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - - import GetJsonObject._ - import PathInstruction._ - import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ - - override def left: Expression = json - override def right: Expression = path - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def dataType: DataType = StringType - override def prettyName: String = "get_json_object" - - @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) - - override def eval(input: InternalRow): Any = { - val jsonStr = json.eval(input).asInstanceOf[UTF8String] - if (jsonStr == null) { - return null - } - - val parsed = if (path.foldable) { - parsedPath - } else { - parsePath(path.eval(input).asInstanceOf[UTF8String]) - } - - if (parsed.isDefined) { - try { - val parser = jsonFactory.createParser(jsonStr.getBytes) - val output = new ByteArrayOutputStream() - val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8) - parser.nextToken() - val matched = evaluatePath(parser, generator, RawStyle, parsed.get) - generator.close() - if (matched) { - UTF8String.fromBytes(output.toByteArray) - } else { - null - } - } catch { - case _: JsonProcessingException => null - } - } else { - null - } - } - - private def parsePath(path: UTF8String): Option[List[PathInstruction]] = { - if (path != null) { - JsonPathParser.parse(path.toString) - } else { - None - } - } - - // advance to the desired array index, assumes to start at the START_ARRAY token - private def arrayIndex(p: JsonParser, f: () => Boolean): Long => Boolean = { - case _ if p.getCurrentToken == END_ARRAY => - // terminate, nothing has been written - false - - case 0 => - // we've reached the desired index - val dirty = f() - - while (p.nextToken() != END_ARRAY) { - // advance the token stream to the end of the array - p.skipChildren() - } - - dirty - - case i if i > 0 => - // skip this token and evaluate the next - p.skipChildren() - p.nextToken() - arrayIndex(p, f)(i - 1) - } - - /** - * Evaluate a list of JsonPath instructions, returning a bool that indicates if any leaf nodes - * have been written to the generator - */ - private def evaluatePath( - p: JsonParser, - g: JsonGenerator, - style: WriteStyle, - path: List[PathInstruction]): Boolean = { - (p.getCurrentToken, path) match { - case (VALUE_STRING, Nil) if style == RawStyle => - // there is no array wildcard or slice parent, emit this string without quotes - if (p.hasTextCharacters) { - g.writeRaw(p.getTextCharacters, p.getTextOffset, p.getTextLength) - } else { - g.writeRaw(p.getText) - } - true - - case (START_ARRAY, Nil) if style == FlattenStyle => - // flatten this array into the parent - var dirty = false - while (p.nextToken() != END_ARRAY) { - dirty |= evaluatePath(p, g, style, Nil) - } - dirty - - case (_, Nil) => - // general case: just copy the child tree verbatim - g.copyCurrentStructure(p) - true - - case (START_OBJECT, Key :: xs) => - var dirty = false - while (p.nextToken() != END_OBJECT) { - if (dirty) { - // once a match has been found we can skip other fields - p.skipChildren() - } else { - dirty = evaluatePath(p, g, style, xs) - } - } - dirty - - case (START_ARRAY, Subscript :: Wildcard :: Subscript :: Wildcard :: xs) => - // special handling for the non-structure preserving double wildcard behavior in Hive - var dirty = false - g.writeStartArray() - while (p.nextToken() != END_ARRAY) { - dirty |= evaluatePath(p, g, FlattenStyle, xs) - } - g.writeEndArray() - dirty - - case (START_ARRAY, Subscript :: Wildcard :: xs) if style != QuotedStyle => - // retain Flatten, otherwise use Quoted... cannot use Raw within an array - val nextStyle = style match { - case RawStyle => QuotedStyle - case FlattenStyle => FlattenStyle - case QuotedStyle => throw new IllegalStateException() - } - - // temporarily buffer child matches, the emitted json will need to be - // modified slightly if there is only a single element written - val buffer = new StringWriter() - val flattenGenerator = jsonFactory.createGenerator(buffer) - flattenGenerator.writeStartArray() - - var dirty = 0 - while (p.nextToken() != END_ARRAY) { - // track the number of array elements and only emit an outer array if - // we've written more than one element, this matches Hive's behavior - dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0) - } - flattenGenerator.writeEndArray() - flattenGenerator.close() - - val buf = buffer.getBuffer - if (dirty > 1) { - g.writeRawValue(buf.toString) - } else if (dirty == 1) { - // remove outer array tokens - g.writeRawValue(buf.substring(1, buf.length()-1)) - } // else do not write anything - - dirty > 0 - - case (START_ARRAY, Subscript :: Wildcard :: xs) => - var dirty = false - g.writeStartArray() - while (p.nextToken() != END_ARRAY) { - // wildcards can have multiple matches, continually update the dirty count - dirty |= evaluatePath(p, g, QuotedStyle, xs) - } - g.writeEndArray() - - dirty - - case (START_ARRAY, Subscript :: Index(idx) :: (xs@Subscript :: Wildcard :: _)) => - p.nextToken() - // we're going to have 1 or more results, switch to QuotedStyle - arrayIndex(p, () => evaluatePath(p, g, QuotedStyle, xs))(idx) - - case (START_ARRAY, Subscript :: Index(idx) :: xs) => - p.nextToken() - arrayIndex(p, () => evaluatePath(p, g, style, xs))(idx) - - case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => - // exact field match - p.nextToken() - evaluatePath(p, g, style, xs) - - case (FIELD_NAME, Wildcard :: xs) => - // wildcard field match - p.nextToken() - evaluatePath(p, g, style, xs) - - case _ => - p.skipChildren() - false - } - } -} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
