Repository: spark Updated Branches: refs/heads/branch-1.5 56a37b01f -> 321cb99ca
http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala deleted file mode 100644 index ca044d3..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ /dev/null @@ -1,1320 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.text.DecimalFormat -import java.util.Arrays -import java.util.{Map => JMap, HashMap} -import java.util.Locale -import java.util.regex.{MatchResult, Pattern} - -import org.apache.commons.lang3.StringEscapeUtils - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.StringUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines expressions for string operations. -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -/** - * An expression that concatenates multiple input strings into a single string. - * If any input is null, concat returns null. - */ -case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) - override def dataType: DataType = StringType - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val evals = children.map(_.gen(ctx)) - val inputs = evals.map { eval => - s"${eval.isNull} ? null : ${eval.primitive}" - }.mkString(", ") - evals.map(_.code).mkString("\n") + s""" - boolean ${ev.isNull} = false; - UTF8String ${ev.primitive} = UTF8String.concat($inputs); - if (${ev.primitive} == null) { - ${ev.isNull} = true; - } - """ - } -} - - -/** - * An expression that concatenates multiple input strings or array of strings into a single string, - * using a given separator (the first child). - * - * Returns null if the separator is null. Otherwise, concat_ws skips all null values. - */ -case class ConcatWs(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes with CodegenFallback { - - require(children.nonEmpty, s"$prettyName requires at least one argument.") - - override def prettyName: String = "concat_ws" - - /** The 1st child (separator) is str, and rest are either str or array of str. */ - override def inputTypes: Seq[AbstractDataType] = { - val arrayOrStr = TypeCollection(ArrayType(StringType), StringType) - StringType +: Seq.fill(children.size - 1)(arrayOrStr) - } - - override def dataType: DataType = StringType - - override def nullable: Boolean = children.head.nullable - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - val flatInputs = children.flatMap { child => - child.eval(input) match { - case s: UTF8String => Iterator(s) - case arr: ArrayData => arr.toArray[UTF8String](StringType) - case null => Iterator(null.asInstanceOf[UTF8String]) - } - } - UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (children.forall(_.dataType == StringType)) { - // All children are strings. In that case we can construct a fixed size array. - val evals = children.map(_.gen(ctx)) - - val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" - }.mkString(", ") - - evals.map(_.code).mkString("\n") + s""" - UTF8String ${ev.primitive} = UTF8String.concatWs($inputs); - boolean ${ev.isNull} = ${ev.primitive} == null; - """ - } else { - // Contains a mix of strings and array<string>s. Fall back to interpreted mode for now. - super.genCode(ctx, ev) - } - } -} - - -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => - - def escape(v: String): String - def matches(regex: Pattern, str: String): Boolean - - override def dataType: DataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - // try cache the pattern for Literal - private lazy val cache: Pattern = right match { - case x @ Literal(value: String, StringType) => compile(value) - case _ => null - } - - protected def compile(str: String): Pattern = if (str == null) { - null - } else { - // Let it raise exception if couldn't compile the regex string - Pattern.compile(escape(str)) - } - - protected def pattern(str: String) = if (cache == null) compile(str) else cache - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val regex = pattern(input2.asInstanceOf[UTF8String].toString()) - if(regex == null) { - null - } else { - matches(regex, input1.asInstanceOf[UTF8String].toString()) - } - } -} - -/** - * Simple RegEx pattern matching function - */ -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = StringUtils.escapeLikeRegex(v) - - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); - """ - }) - } - } -} - - -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = v - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile(rightStr); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); - """ - }) - } - } -} - - -trait String2StringExpression extends ImplicitCastInputTypes { - self: UnaryExpression => - - def convert(v: UTF8String): UTF8String - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType) - - protected override def nullSafeEval(input: Any): Any = - convert(input.asInstanceOf[UTF8String]) -} - -/** - * A function that converts the characters of a string to uppercase. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns str with all characters changed to uppercase", - extended = "> SELECT _FUNC_('SparkSql');\n 'SPARKSQL'") -case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression { - - override def convert(v: UTF8String): UTF8String = v.toUpperCase - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") - } -} - -/** - * A function that converts the characters of a string to lowercase. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", - extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") -case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { - - override def convert(v: UTF8String): UTF8String = v.toLowerCase - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") - } -} - -/** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ImplicitCastInputTypes { - self: BinaryExpression => - - def compare(l: UTF8String, r: UTF8String): Boolean - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - protected override def nullSafeEval(input1: Any, input2: Any): Any = - compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) - - override def toString: String = s"$nodeName($left, $right)" -} - -/** - * A function that returns true if the string `left` contains the string `right`. - */ -case class Contains(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") - } -} - -/** - * A function that returns true if the string `left` starts with the string `right`. - */ -case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") - } -} - -/** - * A function that returns true if the string `left` ends with the string `right`. - */ -case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") - } -} - -object StringTranslate { - - def buildDict(matchingString: UTF8String, replaceString: UTF8String) - : JMap[Character, Character] = { - val matching = matchingString.toString() - val replace = replaceString.toString() - val dict = new HashMap[Character, Character]() - var i = 0 - while (i < matching.length()) { - val rep = if (i < replace.length()) replace.charAt(i) else '\0' - if (null == dict.get(matching.charAt(i))) { - dict.put(matching.charAt(i), rep) - } - i += 1 - } - dict - } -} - -/** - * A function translate any character in the `srcExpr` by a character in `replaceExpr`. - * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`. - * The translate will happen when any character in the string matching with the character - * in the `matchingExpr`. - */ -case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - @transient private var lastMatching: UTF8String = _ - @transient private var lastReplace: UTF8String = _ - @transient private var dict: JMap[Character, Character] = _ - - override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { - if (matchingEval != lastMatching || replaceEval != lastReplace) { - lastMatching = matchingEval.asInstanceOf[UTF8String].clone() - lastReplace = replaceEval.asInstanceOf[UTF8String].clone() - dict = StringTranslate.buildDict(lastMatching, lastReplace) - } - srcEval.asInstanceOf[UTF8String].translate(dict) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastMatching = ctx.freshName("lastMatching") - val termLastReplace = ctx.freshName("lastReplace") - val termDict = ctx.freshName("dict") - val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") - - nullSafeCodeGen(ctx, ev, (src, matching, replace) => { - val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" - } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" - } - s"""if ($check) { - // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); - } - ${ev.primitive} = ${src}.translate(${termDict}); - """ - }) - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil - override def prettyName: String = "translate" -} - -/** - * A function that returns the index (1-based) of the given string (left) in the comma- - * delimited list (right). Returns 0, if the string wasn't found or if the given - * string (left) contains a comma. - */ -case class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override protected def nullSafeEval(word: Any, set: Any): Any = - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (word, set) => - s"${ev.primitive} = $set.findInSet($word);" - ) - } - - override def dataType: DataType = IntegerType -} - -/** - * A function that trim the spaces from both ends for the specified string. - */ -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { - - def convert(v: UTF8String): UTF8String = v.trim() - - override def prettyName: String = "trim" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).trim()") - } -} - -/** - * A function that trim the spaces from left end for given string. - */ -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { - - def convert(v: UTF8String): UTF8String = v.trimLeft() - - override def prettyName: String = "ltrim" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") - } -} - -/** - * A function that trim the spaces from right end for given string. - */ -case class StringTrimRight(child: Expression) - extends UnaryExpression with String2StringExpression { - - def convert(v: UTF8String): UTF8String = v.trimRight() - - override def prettyName: String = "rtrim" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).trimRight()") - } -} - -/** - * A function that returns the position of the first occurrence of substr in the given string. - * Returns null if either of the arguments are null and - * returns 0 if substr could not be found in str. - * - * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. - */ -case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = substr - override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def nullSafeEval(string: Any, sub: Any): Any = { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 - } - - override def prettyName: String = "instr" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (l, r) => - s"($l).indexOf($r, 0) + 1") - } -} - -/** - * Returns the substring from string str before count occurrences of the delimiter delim. - * If count is positive, everything the left of the final delimiter (counting from left) is - * returned. If count is negative, every to the right of the final delimiter (counting from the - * right) is returned. substring_index performs a case-sensitive match when searching for delim. - */ -case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr) - override def prettyName: String = "substring_index" - - override def nullSafeEval(str: Any, delim: Any, count: Any): Any = { - str.asInstanceOf[UTF8String].subStringIndex( - delim.asInstanceOf[UTF8String], - count.asInstanceOf[Int]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") - } -} - -/** - * A function that returns the position of the first occurrence of substr - * in given string after position pos. - */ -case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { - - def this(substr: Expression, str: Expression) = { - this(substr, str, Literal(0)) - } - - override def children: Seq[Expression] = substr :: str :: start :: Nil - override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) - - override def eval(input: InternalRow): Any = { - val s = start.eval(input) - if (s == null) { - // if the start position is null, we need to return 0, (conform to Hive) - 0 - } else { - val r = substr.eval(input) - if (r == null) { - null - } else { - val l = str.eval(input) - if (l == null) { - null - } else { - l.asInstanceOf[UTF8String].indexOf( - r.asInstanceOf[UTF8String], - s.asInstanceOf[Int]) + 1 - } - } - } - } - - override def prettyName: String = "locate" -} - -/** - * Returns str, left-padded with pad to a length of len. - */ -case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - override def children: Seq[Expression] = str :: len :: pad :: Nil - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") - } - - override def prettyName: String = "lpad" -} - -/** - * Returns str, right-padded with pad to a length of len. - */ -case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - override def children: Seq[Expression] = str :: len :: pad :: Nil - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - - override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { - str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) - } - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") - } - - override def prettyName: String = "rpad" -} - -/** - * Returns the input formatted according do printf-style format strings - */ -case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { - - require(children.nonEmpty, "format_string() should take at least 1 argument") - - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children(0).nullable - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = - StringType :: List.fill(children.size - 1)(AnyDataType) - - override def eval(input: InternalRow): Any = { - val pattern = children(0).eval(input) - if (pattern == null) { - null - } else { - val sb = new StringBuffer() - val formatter = new java.util.Formatter(sb, Locale.US) - - val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef]) - formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) - - UTF8String.fromString(sb.toString) - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val pattern = children.head.gen(ctx) - - val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) - val argListCode = argListGen.map(_._2.code + "\n") - - val argListString = argListGen.foldLeft("")((s, v) => { - val nullSafeString = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { - // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" - } else { - s"(${v._2.isNull}) ? null : ${v._2.primitive}" - } - s + "," + nullSafeString - }) - - val form = ctx.freshName("formatter") - val formatter = classOf[java.util.Formatter].getName - val sb = ctx.freshName("sb") - val stringBuffer = classOf[StringBuffer].getName - s""" - ${pattern.code} - boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${argListCode.mkString} - $stringBuffer $sb = new $stringBuffer(); - $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); - $form.format(${pattern.primitive}.toString() $argListString); - ${ev.primitive} = UTF8String.fromString($sb.toString()); - } - """ - } - - override def prettyName: String = "format_string" -} - -/** - * Returns string, with the first letter of each word in uppercase. - * Words are delimited by whitespace. - */ -case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def inputTypes: Seq[DataType] = Seq(StringType) - override def dataType: DataType = StringType - - override def nullSafeEval(string: Any): Any = { - string.asInstanceOf[UTF8String].toTitleCase - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") - } -} - -/** - * Returns the string which repeat the given string value n times. - */ -case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = times - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType) - - override def nullSafeEval(string: Any, n: Any): Any = { - string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) - } - - override def prettyName: String = "repeat" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") - } -} - -/** - * Returns the reversed given string. - */ -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - -/** - * Returns a n spaces string. - */ -case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(IntegerType) - - override def nullSafeEval(s: Any): Any = { - val length = s.asInstanceOf[Int] - UTF8String.blankString(if (length < 0) 0 else length) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (length) => - s"""${ev.primitive} = UTF8String.blankString(($length < 0) ? 0 : $length);""") - } - - override def prettyName: String = "space" -} - -/** - * Splits str around pat (pattern is a regular expression). - */ -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = pattern - override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) - new GenericArrayData(strings.asInstanceOf[Array[Any]]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => - // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") - } - - override def prettyName: String = "split" -} - -object Substring { - def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = { - if (pos > bytes.length) { - return Array[Byte]() - } - - var start = if (pos > 0) { - pos - 1 - } else if (pos < 0) { - bytes.length + pos - } else { - 0 - } - - val end = if ((bytes.length - start) < len) { - bytes.length - } else { - start + len - } - - start = Math.max(start, 0) // underflow - if (start < end) { - Arrays.copyOfRange(bytes, start, end) - } else { - Array[Byte]() - } - } -} -/** - * A function that takes a substring of its first argument starting at a given position. - * Defined for String and Binary types. - */ -case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - def this(str: Expression, pos: Expression) = { - this(str, pos, Literal(Integer.MAX_VALUE)) - } - - override def dataType: DataType = str.dataType - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) - - override def children: Seq[Expression] = str :: pos :: len :: Nil - - override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { - str.dataType match { - case StringType => string.asInstanceOf[UTF8String] - .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) - case BinaryType => Substring.subStringBinarySQL(string.asInstanceOf[Array[Byte]], - pos.asInstanceOf[Int], len.asInstanceOf[Int]) - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - - val cls = classOf[Substring].getName - defineCodeGen(ctx, ev, (string, pos, len) => { - str.dataType match { - case StringType => s"$string.substringSQL($pos, $len)" - case BinaryType => s"$cls.subStringBinarySQL($string, $pos, $len)" - } - }) - } -} - -/** - * A function that return the length of the given string or binary expression. - */ -case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - - protected override def nullSafeEval(value: Any): Any = child.dataType match { - case StringType => value.asInstanceOf[UTF8String].numChars - case BinaryType => value.asInstanceOf[Array[Byte]].length - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - child.dataType match { - case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") - case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") - } - } -} - -/** - * A function that return the Levenshtein distance between the two given strings. - */ -case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) - - override def dataType: DataType = IntegerType - protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = - leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (left, right) => - s"${ev.primitive} = $left.levenshteinDistance($right);") - } -} - -/** - * A function that return soundex code of the given string expression. - */ -case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { - - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = Seq(StringType) - - override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"$c.soundex()") - } -} - -/** - * Returns the numeric value of the first character of str. - */ -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) - - protected override def nullSafeEval(string: Any): Any = { - val bytes = string.asInstanceOf[UTF8String].getBytes - if (bytes.length > 0) { - bytes(0).asInstanceOf[Int] - } else { - 0 - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (child) => { - val bytes = ctx.freshName("bytes") - s""" - byte[] $bytes = $child.getBytes(); - if ($bytes.length > 0) { - ${ev.primitive} = (int) $bytes[0]; - } else { - ${ev.primitive} = 0; - } - """}) - } -} - -/** - * Converts the argument from binary to a base 64 string. - */ -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(BinaryType) - - protected override def nullSafeEval(bytes: Any): Any = { - UTF8String.fromBytes( - org.apache.commons.codec.binary.Base64.encodeBase64( - bytes.asInstanceOf[Array[Byte]])) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (child) => { - s"""${ev.primitive} = UTF8String.fromBytes( - org.apache.commons.codec.binary.Base64.encodeBase64($child)); - """}) - } - -} - -/** - * Converts the argument from a base 64 string to BINARY. - */ -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def dataType: DataType = BinaryType - override def inputTypes: Seq[DataType] = Seq(StringType) - - protected override def nullSafeEval(string: Any): Any = - org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (child) => { - s""" - ${ev.primitive} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); - """}) - } -} - -/** - * Decodes the first argument into a String using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - */ -case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = bin - override def right: Expression = charset - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val fromCharset = input2.asInstanceOf[UTF8String].toString - UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (bytes, charset) => - s""" - try { - ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); - } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.Platform.throwException(e); - } - """) - } -} - -/** - * Encodes the first argument into a BINARY using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. -*/ -case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = value - override def right: Expression = charset - override def dataType: DataType = BinaryType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val toCharset = input2.asInstanceOf[UTF8String].toString - input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (string, charset) => - s""" - try { - ${ev.primitive} = $string.toString().getBytes($charset.toString()); - } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.Platform.throwException(e); - }""") - } -} - -/** - * Replace all substrings of str that match regexp with rep. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. - */ -case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - // last replacement string, we don't want to convert a UTF8String => java.langString every time. - @transient private var lastReplacement: String = _ - @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private val result: StringBuffer = new StringBuffer - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) - - UTF8String.fromString(result.toString) - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = subject :: regexp :: rep :: Nil - override def prettyName: String = "regexp_replace" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") - - val classNamePattern = classOf[Pattern].getCanonicalName - val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - - ctx.addMutableState("UTF8String", - termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, - termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", - termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") - - nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!$rep.equals(${termLastReplacementInUTF8})) { - // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); - ${ev.isNull} = false; - """ - }) - } -} - -/** - * Extract a specific(idx) group identified by a Java regex. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. - */ -case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - def this(s: Expression, r: Expression) = this(s, r, Literal(1)) - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } else { - UTF8String.EMPTY_UTF8 - } - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = subject :: regexp :: idx :: Nil - override def prettyName: String = "regexp_extract" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - val classNamePattern = classOf[Pattern].getCanonicalName - - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - - nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - java.util.regex.Matcher m = - ${termPattern}.matcher($subject.toString()); - if (m.find()) { - java.util.regex.MatchResult mr = m.toMatchResult(); - ${ev.primitive} = UTF8String.fromString(mr.group($idx)); - ${ev.isNull} = false; - } else { - ${ev.primitive} = UTF8String.EMPTY_UTF8; - ${ev.isNull} = false; - }""" - }) - } -} - -/** - * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, - * and returns the result as a string. If D is 0, the result has no decimal point or - * fractional part. - */ -case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { - - override def left: Expression = x - override def right: Expression = d - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) - - // Associated with the pattern, for the last d value, and we will update the - // pattern (DecimalFormat) once the new coming d value differ with the last one. - @transient - private var lastDValue: Int = -100 - - // A cached DecimalFormat, for performance concern, we will change it - // only if the d value changed. - @transient - private val pattern: StringBuffer = new StringBuffer() - - @transient - private val numberFormat: DecimalFormat = new DecimalFormat("") - - override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - if (dValue != lastDValue) { - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } - } - val dFormat = new DecimalFormat(pattern.toString) - lastDValue = dValue - - numberFormat.applyPattern(dFormat.toPattern) - } - - x.dataType match { - case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) - case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) - case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) - case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) - case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) - case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) - case _: DecimalType => - UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (num, d) => { - - def typeHelper(p: String): String = { - x.dataType match { - case _ : DecimalType => s"""$p.toJavaBigDecimal()""" - case _ => s"$p" - } - } - - val sb = classOf[StringBuffer].getName - val df = classOf[DecimalFormat].getName - val lastDValue = ctx.freshName("lastDValue") - val pattern = ctx.freshName("pattern") - val numberFormat = ctx.freshName("numberFormat") - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") - ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") - - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); - } - } - $df $dFormat = new $df($pattern.toString()); - $lastDValue = $d; - $numberFormat.applyPattern($dFormat.toPattern()); - } - ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.primitive} = null; - ${ev.isNull} = true; - } - """ - }) - } - - override def prettyName: String = "format_number" -} http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 42457d5..854463d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -372,7 +372,7 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => e } - case e: StringComparison => e.children match { + case e: StringPredicate => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 426dc27..99e3b13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -673,7 +673,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) } - test("number format") { + test("format_number / FormatNumber") { checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1f27056..fc4d093 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -56,6 +56,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { /** * Register a user-defined aggregate function (UDAF). + * * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. * @return the registered UDAF. http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 7619f3e..d43d3dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -304,7 +304,7 @@ private[sql] case class ScalaUDAF( override def nullable: Boolean = true - override def dataType: DataType = udaf.returnDataType + override def dataType: DataType = udaf.dataType override def deterministic: Boolean = udaf.deterministic http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 5180871..258afad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} @@ -26,7 +25,7 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * The abstract class for implementing user-defined aggregate functions. + * The base class for implementing user-defined aggregate functions (UDAF). */ @Experimental abstract class UserDefinedAggregateFunction extends Serializable { @@ -67,22 +66,35 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. */ - def returnDataType: DataType + def dataType: DataType - /** Indicates if this function is deterministic. */ + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ def deterministic: Boolean /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer - * still store initial values. + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. */ def initialize(buffer: MutableAggregationBuffer): Unit - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ def update(buffer: MutableAggregationBuffer, input: Row): Unit - /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** @@ -92,7 +104,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { def evaluate(buffer: Row): Any /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ @scala.annotation.varargs def apply(exprs: Column*): Column = { @@ -105,16 +117,16 @@ abstract class UserDefinedAggregateFunction extends Serializable { } /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. - * If `isDistinct` is true, this UDAF is working on distinct input values. + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. */ @scala.annotation.varargs - def apply(isDistinct: Boolean, exprs: Column*): Column = { + def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression2( ScalaUDAF(exprs.map(_.expr), this), Complete, - isDistinct = isDistinct) + isDistinct = true) Column(aggregateExpression) } } @@ -122,9 +134,11 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * :: Experimental :: * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. */ @Experimental -trait MutableAggregationBuffer extends Row { +abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ def update(i: Int, value: Any): Unit http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 21b053f..a30dfa5 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -92,7 +92,7 @@ public class JavaDataFrameSuite { DataFrame aggregatedDF = df.groupBy() .agg( - udaf.apply(true, col("value")), + udaf.distinct(col("value")), udaf.apply(col("value")), registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index a2247e3..2961b80 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -65,7 +65,7 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } http://git-wip-us.apache.org/repos/asf/spark/blob/321cb99c/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index da29e24..c71882a 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -60,7 +60,7 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
