Repository: spark
Updated Branches:
  refs/heads/master 6fc1e72d9 -> 8bd812132
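The change below is mechanical: the per-expression code-generation hook that subclasses override is renamed from genCode to doGenCode, and call sites that previously requested a child's generated code via gen(ctx) now call genCode(ctx). As rough orientation, here is a minimal, self-contained Scala sketch of the resulting shape; ExprCode, CodegenContext, and the simplified Not below are stand-ins of mine, not Spark's actual classes, and the assumption that genCode(ctx) delegates to doGenCode is inferred from the call sites in this diff rather than shown in it.

// Simplified mirror of the renamed API surface; not the actual Catalyst classes.
case class ExprCode(code: String, isNull: String, value: String)

class CodegenContext {
  private var counter = 0
  def freshName(prefix: String): String = { counter += 1; s"$prefix$counter" }
}

abstract class Expression {
  // Callers now ask for generated code via genCode(ctx) (previously gen(ctx)).
  final def genCode(ctx: CodegenContext): ExprCode = {
    val ev = ExprCode(code = "", isNull = ctx.freshName("isNull"), value = ctx.freshName("value"))
    ev.copy(code = doGenCode(ctx, ev))  // assumed delegation, inferred from the call sites below
  }

  // Subclasses now override doGenCode (previously genCode) to emit Java source text.
  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String
}

case class Not(child: Expression) extends Expression {
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
    val c = child.genCode(ctx)  // child code comes from genCode, not gen
    s"""
       |${c.code}
       |boolean ${ev.isNull} = ${c.isNull};
       |boolean ${ev.value} = !(${c.value});
     """.stripMargin
  }
}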
http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 38f1210..b15a77a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -99,7 +99,7 @@ case class Not(child: Expression)
   protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean]
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"!($c)")
   }
@@ -157,9 +157,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val valueGen = value.gen(ctx)
-    val listGen = list.map(_.gen(ctx))
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val valueGen = value.genCode(ctx)
+    val listGen = list.map(_.genCode(ctx))
     val listCode = listGen.map(x =>
       s"""
         if (!${ev.value}) {
@@ -216,10 +216,10 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   def getHSet(): Set[Any] = hset
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val setName = classOf[Set[Any]].getName
     val InSetName = classOf[InSet].getName
-    val childGen = child.gen(ctx)
+    val childGen = child.genCode(ctx)
     ctx.references += this
     val hsetTerm = ctx.freshName("hset")
     val hasNullTerm = ctx.freshName("hasNull")
@@ -274,9 +274,9 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)
     // The result should be `false`, if any of them is `false` whenever the other is null or not.
     if (!left.nullable && !right.nullable) {
@@ -339,9 +339,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)
     // The result should be `true`, if any of them is `true` whenever the other is null or not.
     if (!left.nullable && !right.nullable) {
@@ -379,7 +379,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 abstract class BinaryComparison extends BinaryOperator with Predicate {
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     if (ctx.isPrimitiveType(left.dataType)
         && left.dataType != BooleanType // java boolean doesn't support > or < operator
         && left.dataType != FloatType
@@ -428,7 +428,7 @@ case class EqualTo(left: Expression, right: Expression)
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
   }
 }
@@ -464,9 +464,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val eval1 = left.gen(ctx)
-    val eval2 = right.gen(ctx)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)
     val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
     ev.isNull = "false"
     eval1.code + eval2.code + s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 1ec092a..1eed24d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -67,7 +67,7 @@ case class Rand(seed: Long) extends RDG {
     case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
   })
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
     ctx.addMutableState(className, rngTerm,
@@ -92,7 +92,7 @@ case class Randn(seed: Long) extends RDG {
     case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
   })
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
     ctx.addMutableState(className, rngTerm,

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 85a5429..4f5b85d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -78,7 +78,7 @@ case class Like(left: Expression, right: Expression)
   override def toString: String = s"$left LIKE $right"
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val patternClass = classOf[Pattern].getName
     val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
     val pattern = ctx.freshName("pattern")
@@ -92,7 +92,7 @@ case class Like(left: Expression, right: Expression)
       s"""$pattern = ${patternClass}.compile("$regexStr");""")
       // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
-      val eval = left.gen(ctx)
+      val eval = left.genCode(ctx)
       s"""
         ${eval.code}
        boolean ${ev.isNull} = ${eval.isNull};
@@ -128,7 +128,7 @@ case class RLike(left: Expression, right: Expression)
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
   override def toString: String = s"$left RLIKE $right"
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val patternClass = classOf[Pattern].getName
     val pattern = ctx.freshName("pattern")
@@ -141,7 +141,7 @@ case class RLike(left: Expression, right: Expression)
       s"""$pattern = ${patternClass}.compile("$regexStr");""")
       // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
-      val eval = left.gen(ctx)
+      val eval = left.genCode(ctx)
       s"""
         ${eval.code}
        boolean ${ev.isNull} = ${eval.isNull};
@@ -188,7 +188,7 @@ case class StringSplit(str: Expression, pattern: Expression)
     new GenericArrayData(strings.asInstanceOf[Array[Any]])
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val arrayClass = classOf[GenericArrayData].getName
     nullSafeCodeGen(ctx, ev, (str, pattern) =>
       // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
@@ -247,7 +247,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def children: Seq[Expression] = subject :: regexp :: rep :: Nil
   override def prettyName: String = "regexp_replace"
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val termLastRegex = ctx.freshName("lastRegex")
     val termPattern = ctx.freshName("pattern")
@@ -330,7 +330,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
   override def prettyName: String = "regexp_extract"
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val termLastRegex = ctx.freshName("lastRegex")
     val termPattern = ctx.freshName("pattern")
     val classNamePattern = classOf[Pattern].getCanonicalName

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index a174826..8c15357 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -51,8 +51,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
     UTF8String.concat(inputs : _*)
   }
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val evals = children.map(_.gen(ctx))
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val evals = children.map(_.genCode(ctx))
     val inputs = evals.map { eval =>
       s"${eval.isNull} ? null : ${eval.value}"
     }.mkString(", ")
@@ -106,10 +106,10 @@ case class ConcatWs(children: Seq[Expression])
     UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
   }
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     if (children.forall(_.dataType == StringType)) {
       // All children are strings. In that case we can construct a fixed size array.
-      val evals = children.map(_.gen(ctx))
+      val evals = children.map(_.genCode(ctx))
       val inputs = evals.map { eval =>
         s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
@@ -124,7 +124,7 @@ case class ConcatWs(children: Seq[Expression])
     val varargNum = ctx.freshName("varargNum")
     val idxInVararg = ctx.freshName("idxInVararg")
-    val evals = children.map(_.gen(ctx))
+    val evals = children.map(_.genCode(ctx))
     val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
       child.dataType match {
         case StringType =>
@@ -185,7 +185,7 @@ case class Upper(child: Expression)
   override def convert(v: UTF8String): UTF8String = v.toUpperCase
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
   }
 }
@@ -200,7 +200,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
   override def convert(v: UTF8String): UTF8String = v.toLowerCase
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
   }
 }
@@ -225,7 +225,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes {
 case class Contains(left: Expression, right: Expression)
     extends BinaryExpression with StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
   }
 }
@@ -236,7 +236,7 @@ case class Contains(left: Expression, right: Expression)
 case class StartsWith(left: Expression, right: Expression)
     extends BinaryExpression with StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
   }
 }
@@ -247,7 +247,7 @@ case class StartsWith(left: Expression, right: Expression)
 case class EndsWith(left: Expression, right: Expression)
     extends BinaryExpression with StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
   }
 }
@@ -298,7 +298,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
     srcEval.asInstanceOf[UTF8String].translate(dict)
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     val termLastMatching = ctx.freshName("lastMatching")
     val termLastReplace = ctx.freshName("lastReplace")
     val termDict = ctx.freshName("dict")
@@ -351,7 +351,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
   override protected def nullSafeEval(word: Any, set: Any): Any =
     set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (word, set) =>
       s"${ev.value} = $set.findInSet($word);"
     )
@@ -375,7 +375,7 @@ case class StringTrim(child: Expression)
   override def prettyName: String = "trim"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).trim()")
   }
 }
@@ -393,7 +393,7 @@ case class StringTrimLeft(child: Expression)
   override def prettyName: String = "ltrim"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).trimLeft()")
   }
 }
@@ -411,7 +411,7 @@ case class StringTrimRight(child: Expression)
   override def prettyName: String = "rtrim"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).trimRight()")
   }
 }
@@ -440,7 +440,7 @@ case class StringInstr(str: Expression, substr: Expression)
   override def prettyName: String = "instr"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1")
   }
@@ -475,7 +475,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
       count.asInstanceOf[Int])
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
   }
 }
@@ -524,10 +524,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
     }
   }
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val substrGen = substr.gen(ctx)
-    val strGen = str.gen(ctx)
-    val startGen = start.gen(ctx)
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val substrGen = substr.genCode(ctx)
+    val strGen = str.genCode(ctx)
+    val startGen = start.genCode(ctx)
     s"""
       int ${ev.value} = 0;
       boolean ${ev.isNull} = false;
@@ -571,7 +571,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
     str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
   }
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)")
   }
@@ -597,7 +597,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
     str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
   }
-  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)")
   }
@@ -638,10 +638,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val pattern = children.head.gen(ctx)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val pattern = children.head.genCode(ctx)
-    val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
+    val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx)))
     val argListCode = argListGen.map(_._2.code + "\n")
     val argListString = argListGen.foldLeft("")((s, v) => {
@@ -694,7 +694,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI
   override def nullSafeEval(string: Any): Any = {
     string.asInstanceOf[UTF8String].toLowerCase.toTitleCase
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
   }
 }
@@ -719,7 +719,7 @@ case class StringRepeat(str: Expression, times: Expression)
   override def prettyName: String = "repeat"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)")
   }
 }
@@ -735,7 +735,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
   override def prettyName: String = "reverse"
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"($c).reverse()")
   }
 }
@@ -757,7 +757,7 @@ case class StringSpace(child: Expression)
     UTF8String.blankString(if (length < 0) 0 else length)
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (length) =>
       s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""")
   }
@@ -799,7 +799,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, (string, pos, len) => {
       str.dataType match {
@@ -825,7 +825,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy
     case BinaryType => value.asInstanceOf[Array[Byte]].length
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     child.dataType match {
       case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
       case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
@@ -848,7 +848,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
   protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
     leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (left, right) =>
       s"${ev.value} = $left.levenshteinDistance($right);")
   }
@@ -868,7 +868,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
   override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex()
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     defineCodeGen(ctx, ev, c => s"$c.soundex()")
   }
 }
@@ -894,7 +894,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (child) => {
       val bytes = ctx.freshName("bytes")
       s"""
@@ -924,7 +924,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
       bytes.asInstanceOf[Array[Byte]]))
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (child) => {
       s"""${ev.value} = UTF8String.fromBytes(
            org.apache.commons.codec.binary.Base64.encodeBase64($child));
@@ -945,7 +945,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
   protected override def nullSafeEval(string: Any): Any =
     org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (child) => {
       s"""
         ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString());
@@ -973,7 +973,7 @@ case class Decode(bin: Expression, charset: Expression)
     UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (bytes, charset) =>
       s"""
         try {
@@ -1005,7 +1005,7 @@ case class Encode(value: Expression, charset: Expression)
     input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (string, charset) =>
      s"""
        try {
@@ -1088,7 +1088,7 @@ case class FormatNumber(x: Expression, d: Expression)
     }
   }
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
     nullSafeCodeGen(ctx, ev, (num, d) => {
       def typeHelper(p: String): String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
index ff34b1e3..de410b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
@@ -35,8 +35,8 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpres
   override def eval(input: InternalRow): Any = value
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    Literal.create(value, dataType).genCode(ctx, ev)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    Literal.create(value, dataType).doGenCode(ctx, ev)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 392c48f..3dc2aa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -192,7 +192,7 @@ private[sql] case class RowDataSourceScan(
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.gen(ctx))
+    val columnsRowInput = exprRows.map(_.genCode(ctx))
     val inputRow = if (outputUnsafeRows) row else null
     s"""
        |while ($input.hasNext()) {

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index bd23b7e..cc0382e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -149,7 +149,7 @@ case class Expand(
       val firstExpr = projections.head(col)
       if (sameOutput(col)) {
         // This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).gen(ctx)
+        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
@@ -166,7 +166,7 @@ case class Expand(
     var updateCode = ""
     for (col <- exprs.indices) {
       if (!sameOutput(col)) {
-        val ev = BindReferences.bindReference(exprs(col), child.output).gen(ctx)
+        val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
         updateCode +=
           s"""
              |${ev.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 29acc38..12d08c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -118,7 +118,7 @@ trait CodegenSupport extends SparkPlan {
       ctx.currentVars = null
       ctx.INPUT_ROW = row
       output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
       }
     } else {
       assert(outputVars != null)

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f585759..d819a65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -165,7 +165,7 @@ case class TungstenAggregate(
       ctx.addMutableState("boolean", isNull, "")
       ctx.addMutableState(ctx.javaType(e.dataType), value, "")
       // The initial expression should not access any column
-      val ev = e.gen(ctx)
+      val ev = e.genCode(ctx)
       val initVars = s"""
         | $isNull = ${ev.isNull};
        | $value = ${ev.value};
@@ -179,13 +179,13 @@ case class TungstenAggregate(
       // evaluate aggregate results
       ctx.currentVars = bufVars
       val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
       }
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
       val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
       }
       (resultVars, s"""
         |$evaluateAggResults
@@ -196,7 +196,7 @@ case class TungstenAggregate(
       (bufVars, "")
     } else {
       // no aggregate function, the result should be literals
-      val resultVars = resultExpressions.map(_.gen(ctx))
+      val resultVars = resultExpressions.map(_.genCode(ctx))
       (resultVars, evaluateVariables(resultVars))
     }
@@ -240,7 +240,7 @@ case class TungstenAggregate(
     }
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
-    val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
+    val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx))
     // aggregate buffer should be updated atomic
     val updates = aggVals.zipWithIndex.map { case (ev, i) =>
       s"""
@@ -394,25 +394,25 @@ case class TungstenAggregate(
     ctx.currentVars = null
     ctx.INPUT_ROW = keyTerm
     val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
-      BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      BoundReference(i, e.dataType, e.nullable).genCode(ctx)
     }
     val evaluateKeyVars = evaluateVariables(keyVars)
     ctx.INPUT_ROW = bufferTerm
     val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
-      BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      BoundReference(i, e.dataType, e.nullable).genCode(ctx)
     }
     val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
     val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-      BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
+      BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
     }
     val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
     val resultVars = resultExpressions.map { e =>
-      BindReferences.bindReference(e, inputAttrs).gen(ctx)
+      BindReferences.bindReference(e, inputAttrs).genCode(ctx)
     }
     s"""
     $evaluateKeyVars
@@ -437,7 +437,7 @@ case class TungstenAggregate(
     ctx.INPUT_ROW = keyTerm
     ctx.currentVars = null
     val eval = resultExpressions.map{ e =>
-      BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+      BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
     }
     consume(ctx, eval)
   }
@@ -576,7 +576,7 @@ case class TungstenAggregate(
     // generate hash code for key
     val hashExpr = Murmur3Hash(groupingExpressions, 42)
     ctx.currentVars = input
-    val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
+    val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx)
     val inputAttr = aggregateBufferAttributes ++ child.output
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
@@ -613,7 +613,8 @@ case class TungstenAggregate(
     val updateRowInVectorizedHashMap: Option[String] = {
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = vectorizedRowBuffer
-        val vectorizedRowEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+        val vectorizedRowEvals =
+          updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
         val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
           val dt = updateExpr(i).dataType
           ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
@@ -663,7 +664,7 @@ case class TungstenAggregate(
     val updateRowInUnsafeRowMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val unsafeRowBufferEvals =
-        updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+        updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
         val dt = updateExpr(i).dataType
         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 344aaff..c689fc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -53,7 +53,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
-    val resultVars = exprs.map(_.gen(ctx))
+    val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
@@ -122,7 +122,7 @@ case class Filter(condition: Expression, child: SparkPlan)
       val evaluated = evaluateRequiredVariables(child.output, in, c.references)
     // Generate the code for the predicate.
-    val ev = ExpressionCanonicalizer.execute(bound).gen(ctx)
+    val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
     val nullCheck = if (bound.nullable) {
       s"${ev.isNull} || "
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a8f8541..b94b0d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -118,7 +118,7 @@ case class BroadcastHashJoin(
     ctx.currentVars = input
     if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val ev = streamedKeys.head.gen(ctx)
+      val ev = streamedKeys.head.genCode(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
@@ -134,7 +134,7 @@ case class BroadcastHashJoin(
     ctx.currentVars = null
     ctx.INPUT_ROW = matched
     buildPlan.output.zipWithIndex.map { case (a, i) =>
-      val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
+      val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
       if (joinType == Inner) {
         ev
       } else {
@@ -170,7 +170,8 @@ case class BroadcastHashJoin(
       val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       // filter the output via condition
       ctx.currentVars = input ++ buildVars
-      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
       s"""
          |$eval
          |${ev.code}
@@ -244,7 +245,8 @@ case class BroadcastHashJoin(
       // evaluate the variables from build side that used by condition
       val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       ctx.currentVars = input ++ buildVars
-      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
       s"""
          |boolean $conditionPassed = true;
          |${eval.trim}

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 0e7b2f2..443a7b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -226,7 +226,7 @@ case class SortMergeJoin(
       keys: Seq[Expression],
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
-    keys.map(BindReferences.bindReference(_, input).gen(ctx))
+    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
   }
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
@@ -376,7 +376,7 @@ case class SortMergeJoin(
   private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
     ctx.INPUT_ROW = rightRow
     right.output.zipWithIndex.map { case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable).gen(ctx)
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
     }
   }
@@ -427,7 +427,7 @@ case class SortMergeJoin(
     val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
     // Generate code for condition
     ctx.currentVars = leftVars ++ rightVars
-    val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+    val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
     // evaluate the columns those used by condition before loop
     val before = s"""
       |boolean $loaded = false;

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d2ab18e..784b1e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -48,7 +48,7 @@ case class DeserializeToObject(
     val bound = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(deserializer, child.output))
     ctx.currentVars = input
-    val resultVars = bound.gen(ctx) :: Nil
+    val resultVars = bound.genCode(ctx) :: Nil
     consume(ctx, resultVars)
   }
@@ -82,7 +82,7 @@ case class SerializeFromObject(
       ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
     }
     ctx.currentVars = input
-    val resultVars = bound.map(_.gen(ctx))
+    val resultVars = bound.map(_.genCode(ctx))
     consume(ctx, resultVars)
   }
@@ -173,13 +173,13 @@ case class MapElements(
     val bound = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(callFunc, child.output))
     ctx.currentVars = input
-    val evaluated = bound.gen(ctx)
+    val evaluated = bound.genCode(ctx)
     val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
     val outputFields = serializer.map(_ transform {
       case _: BoundReference => resultObj
     })
-    val resultVars = outputFields.map(_.gen(ctx))
+    val resultVars = outputFields.map(_.genCode(ctx))
     s"""
       ${evaluated.code}
       ${consume(ctx, resultVars)}

http://git-wip-us.apache.org/repos/asf/spark/blob/8bd81213/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 4b3091b..03defc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -54,8 +54,8 @@ case class ScalarSubquery(
   override def eval(input: InternalRow): Any = result
-  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    Literal.create(result, dataType).genCode(ctx, ev)
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = {
+    Literal.create(result, dataType).doGenCode(ctx, ev)
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
