Repository: spark Updated Branches: refs/heads/master 59d24c226 -> 305e77cd8
[SPARK-8209][SQL] Add function conv cc chenghao-intel adrian-wang Author: zhichao.li <[email protected]> Closes #6872 from zhichao-li/conv and squashes the following commits: 6ef3b37 [zhichao.li] add unittest and comments 78d9836 [zhichao.li] polish dataframe api and add unittest e2bace3 [zhichao.li] update to use ImplicitCastInputTypes cbcad3f [zhichao.li] add function conv Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/305e77cd Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/305e77cd Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/305e77cd Branch: refs/heads/master Commit: 305e77cd83f3dbe680a920d5329c2e8c58452d5b Parents: 59d24c2 Author: zhichao.li <[email protected]> Authored: Fri Jul 17 09:32:27 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Fri Jul 17 09:32:27 2015 -0700 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 191 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 21 +- .../scala/org/apache/spark/sql/functions.scala | 18 ++ .../apache/spark/sql/MathExpressionsSuite.scala | 13 ++ 5 files changed, 242 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e0beafe..a451817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -99,6 +99,7 @@ object FunctionRegistry { expression[Ceil]("ceil"), expression[Ceil]("ceiling"), expression[Cos]("cos"), + expression[Conv]("conv"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 84b289c..7a543ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} @@ -139,6 +140,196 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") +/** + * Convert a num from one base to another + * @param numExpr the number to be converted + * @param fromBaseExpr from which base + * @param toBaseExpr to which base + */ +case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) + extends Expression with ImplicitCastInputTypes{ + + override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable + + override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + + override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) + + override def 
inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) + + /** Returns the result of evaluating this expression on a given input Row */ + override def eval(input: InternalRow): Any = { + val num = numExpr.eval(input) + val fromBase = fromBaseExpr.eval(input) + val toBase = toBaseExpr.eval(input) + if (num == null || fromBase == null || toBase == null) { + null + } else { + conv(num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int]) + } + } + + /** + * Returns the [[DataType]] of the result of evaluating this expression. It is + * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + override def dataType: DataType = StringType + + private val value = new Array[Byte](64) + + /** + * Divide x by m as if x is an unsigned 64-bit integer. Examples: + * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 + * unsignedLongDiv(0, 5) == 0 + * + * @param x is treated as unsigned + * @param m is treated as signed + */ + private def unsignedLongDiv(x: Long, m: Int): Long = { + if (x >= 0) { + x / m + } else { + // Let uval be the value of the unsigned long with the same bits as x + // Two's complement => x = uval - 2*MAX - 2 + // => uval = x + 2*MAX + 2 + // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c + (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m) + } + } + + /** + * Decode v into value[]. + * + * @param v is treated as an unsigned 64-bit integer + * @param radix must be between MIN_RADIX and MAX_RADIX + */ + private def decode(v: Long, radix: Int): Unit = { + var tmpV = v + Arrays.fill(value, 0.asInstanceOf[Byte]) + var i = value.length - 1 + while (tmpV != 0) { + val q = unsignedLongDiv(tmpV, radix) + value(i) = (tmpV - q * radix).asInstanceOf[Byte] + tmpV = q + i -= 1 + } + } + + /** + * Convert value[] into a long. On overflow, return -1 (as mySQL does). 
If a + * negative digit is found, ignore the suffix starting there. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first element that should be considered + * @return the result should be treated as an unsigned 64-bit integer. + */ + private def encode(radix: Int, fromPos: Int): Long = { + var v: Long = 0L + val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once + // val + // exceeds this value + var i = fromPos + while (i < value.length && value(i) >= 0) { + if (v >= bound) { + // Check for overflow + if (unsignedLongDiv(-1 - value(i), radix) < v) { + return -1 + } + } + v = v * radix + value(i) + i += 1 + } + return v + } + + /** + * Convert the bytes in value[] to the corresponding chars. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def byte2char(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while (i < value.length) { + value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert the chars in value[] to the corresponding integers. Convert invalid + * characters to -1. + * + * @param radix must be between MIN_RADIX and MAX_RADIX + * @param fromPos is the first nonzero element + */ + private def char2byte(radix: Int, fromPos: Int): Unit = { + var i = fromPos + while ( i < value.length) { + value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] + i += 1 + } + } + + /** + * Convert numbers between different number bases. If toBase>0 the result is + * unsigned, otherwise it is signed. 
+ * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv + */ + private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + if (n == null || fromBase == null || toBase == null || n.isEmpty) { + return null + } + + if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX + || Math.abs(toBase) < Character.MIN_RADIX + || Math.abs(toBase) > Character.MAX_RADIX) { + return null + } + + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) + + // Copy the digits in the right side of the array + var i = 1 + while (i <= n.length - first) { + value(value.length - i) = n(n.length - i) + i += 1 + } + char2byte(fromBase, value.length - n.length + first) + + // Do the conversion by going through a 64 bit integer + var v = encode(fromBase, value.length - n.length + first) + if (negative && toBase > 0) { + if (v < 0) { + v = -1 + } else { + v = -v + } + } + if (toBase < 0 && v < 0) { + v = -v + negative = true + } + decode(v, Math.abs(toBase)) + + // Find the first non-zero digit or the last digit if all are zero. 
+ val firstNonZeroPos = { + val firstNonZero = value.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else value.length - 1 + } + + byte2char(Math.abs(toBase), firstNonZeroPos) + + var resultStartPos = firstNonZeroPos + if (negative && toBase < 0) { + resultStartPos = firstNonZeroPos - 1 + value(resultStartPos) = '-' + } + UTF8String.fromBytes( Arrays.copyOfRange(value, resultStartPos, value.length)) + } +} + case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 52a874a..ca35c7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math.BigDecimal.RoundingMode - import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ + class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -95,6 +94,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), 
Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. + checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d6da284..fe511c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -69,6 +69,24 @@ object functions { def column(colName: String): Column = Column(colName) /** + * Convert a number from one base to another for the specified expressions + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(num: Column, fromBase: Int, toBase: Int): Column = + Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + + /** + * Convert a number from one base to another for the specified expressions + * + * @group math_funcs + * @since 1.5.0 + */ + def conv(numColName: String, fromBase: Int, toBase: Int): Column = + conv(Column(numColName), fromBase, toBase) + + /** * Creates a [[Column]] of literal value. * * The passed in object is returned directly if it is already a [[Column]]. 
http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 087126b..8eb3fec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -178,6 +178,19 @@ class MathExpressionsSuite extends QueryTest { Row(0.0, 1.0, 2.0)) } + test("conv") { + val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") + checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv("num", 10, 16)), Row("14D")) + checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) + checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) + checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) + checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4")) + checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) + checkAnswer( + df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // for overflow + } + test("floor") { testOneToOneMathFunction(floor, math.floor) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
