Repository: spark
Updated Branches:
  refs/heads/master c392a9efa -> 47c874bab
[SPARK-8237] [SQL] Add misc function sha2

JIRA: https://issues.apache.org/jira/browse/SPARK-8237

Author: Liang-Chi Hsieh <[email protected]>

Closes #6934 from viirya/expr_sha2 and squashes the following commits:

35e0bb3 [Liang-Chi Hsieh] For comments.
68b5284 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2
8573aff [Liang-Chi Hsieh] Remove unnecessary Product.
ee61e06 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2
59e41aa [Liang-Chi Hsieh] Add misc function: sha2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47c874ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47c874ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47c874ba

Branch: refs/heads/master
Commit: 47c874babe7779c7a2f32e0b891503ef6bebcab0
Parents: c392a9e
Author: Liang-Chi Hsieh <[email protected]>
Authored: Thu Jun 25 22:07:37 2015 -0700
Committer: Davies Liu <[email protected]>
Committed: Thu Jun 25 22:07:37 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                  | 19 ++++
 .../catalyst/analysis/FunctionRegistry.scala     |  1 +
 .../spark/sql/catalyst/expressions/misc.scala    | 98 +++++++++++++++++++-
 .../expressions/MiscFunctionsSuite.scala         | 14 ++-
 .../scala/org/apache/spark/sql/functions.scala   | 20 ++++
 .../spark/sql/DataFrameFunctionsSuite.scala      | 17 ++++
 6 files changed, 165 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index cfa87ae..7d3d036 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@ __all__ = [
     'monotonicallyIncreasingId',
     'rand',
     'randn',
+    'sha2',
     'sparkPartitionId',
     'struct',
     'udf',
@@ -363,6 +364,24 @@ def randn(seed=None):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def sha2(col, numBits):
+    """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
+    and SHA-512). The numBits indicates the desired bit length of the result, which must have a
+    value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+
+    >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
+    >>> digests[0]
+    Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
+    >>> digests[1]
+    Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
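
For reference (not part of the commit): the hex strings expected by the doctest above are ordinary SHA-256 digests, and they can be reproduced with commons-codec, the same library the JVM-side implementation delegates to for the 256/384/512-bit cases. A minimal sketch, assuming the shared doctest DataFrame `df` holds the names 'Alice' and 'Bob'; the object name below is illustrative:

import org.apache.commons.codec.digest.DigestUtils

// Minimal sketch: hash the UTF-8 bytes of each name with SHA-256 and print the
// lowercase hex digest, which is what the Python sha2() wrapper returns.
object Sha2DoctestCheck {
  def main(args: Array[String]): Unit = {
    Seq("Alice", "Bob").foreach { name =>
      println(s"$name -> ${DigestUtils.sha256Hex(name)}")
    }
  }
}
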
http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5fb3369..457948a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -135,6 +135,7 @@ object FunctionRegistry {
 
     // misc functions
     expression[Md5]("md5"),
+    expression[Sha2]("sha2"),
 
     // aggregate functions
     expression[Average]("avg"),


http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4bee8cb..e80706f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.security.MessageDigest
+import java.security.NoSuchAlgorithmException
+
 import org.apache.commons.codec.digest.DigestUtils
 
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, StringType, DataType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -44,7 +47,96 @@ case class Md5(child: Expression)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     defineCodeGen(ctx, ev, c =>
-      "org.apache.spark.unsafe.types.UTF8String.fromString" +
-        s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+      s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+  }
+}
+
+/**
+ * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384,
+ * and SHA-512) and returns it as a hex string. The first argument is the string or binary
+ * to be hashed. The second argument indicates the desired bit length of the result, which
+ * must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is
+ * supported starting from Java 8. If asking for an unsupported SHA function, the return
+ * value is NULL. If either argument is NULL or the hash length is not one of the permitted
+ * values, the return value is NULL.
+ */
+case class Sha2(left: Expression, right: Expression)
+  extends BinaryExpression with Serializable with ExpectsInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def toString: String = s"SHA2($left, $right)"
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+
+  override def eval(input: InternalRow): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        val bitLength = evalE2.asInstanceOf[Int]
+        val input = evalE1.asInstanceOf[Array[Byte]]
+        bitLength match {
+          case 224 =>
+            // DigestUtils doesn't support SHA-224 now
+            try {
+              val md = MessageDigest.getInstance("SHA-224")
+              md.update(input)
+              UTF8String.fromBytes(md.digest())
+            } catch {
+              // SHA-224 is not supported on the system, return null
+              case noa: NoSuchAlgorithmException => null
+            }
+          case 256 | 0 =>
+            UTF8String.fromString(DigestUtils.sha256Hex(input))
+          case 384 =>
+            UTF8String.fromString(DigestUtils.sha384Hex(input))
+          case 512 =>
+            UTF8String.fromString(DigestUtils.sha512Hex(input))
+          case _ => null
+        }
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval1 = left.gen(ctx)
+    val eval2 = right.gen(ctx)
+    val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
+
+    s"""
+      ${eval1.code}
+      boolean ${ev.isNull} = ${eval1.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${eval2.code}
+        if (!${eval2.isNull}) {
+          if (${eval2.primitive} == 224) {
+            try {
+              java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+              md.update(${eval1.primitive});
+              ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+            } catch (java.security.NoSuchAlgorithmException e) {
+              ${ev.isNull} = true;
+            }
+          } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 384) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 512) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
+          } else {
+            ${ev.isNull} = true;
+          }
+        } else {
+          ${ev.isNull} = true;
+        }
+      }
+    """
+  }
 }
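
Read apart from Catalyst, the interpreted eval path and the generated Java above implement the same bit-length dispatch. Below is a standalone sketch of that dispatch, not the committed code: the helper name `sha2Hex`, the `Option` return in place of SQL NULL, and the hex encoding of the 224-bit branch via commons-codec's `Hex` are illustrative choices.

import java.security.{MessageDigest, NoSuchAlgorithmException}

import org.apache.commons.codec.binary.Hex
import org.apache.commons.codec.digest.DigestUtils

object Sha2Sketch {
  // Returns Some(hex digest) for supported bit lengths, None otherwise,
  // mirroring the expression's NULL semantics for unsupported input.
  def sha2Hex(bytes: Array[Byte], numBits: Int): Option[String] = numBits match {
    case 224 =>
      // commons-codec has no SHA-224 helper, so fall back to MessageDigest;
      // SHA-224 may be unavailable on older JVMs, in which case yield None.
      try {
        val md = MessageDigest.getInstance("SHA-224")
        md.update(bytes)
        Some(Hex.encodeHexString(md.digest()))
      } catch {
        case _: NoSuchAlgorithmException => None
      }
    case 256 | 0 => Some(DigestUtils.sha256Hex(bytes))
    case 384 => Some(DigestUtils.sha384Hex(bytes))
    case 512 => Some(DigestUtils.sha512Hex(bytes))
    case _ => None // unsupported bit length
  }

  def main(args: Array[String]): Unit = {
    println(sha2Hex("ABC".getBytes("UTF-8"), 256))  // same value as DigestUtils.sha256Hex("ABC")
    println(sha2Hex("ABC".getBytes("UTF-8"), 1024)) // None
  }
}

One difference worth noting: for the 224-bit case the committed eval path returns UTF8String.fromBytes(md.digest()) directly, while this sketch hex-encodes the digest in the same way as the other branches.
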
http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 48b8413..38482c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.commons.codec.digest.DigestUtils
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{StringType, BinaryType}
+import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
 
 class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
   }
 
+  test("sha2") {
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
+    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
+      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
+    // unsupported bit length
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 38d9085..355ce0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1414,6 +1414,26 @@ object functions {
    */
   def md5(columnName: String): Column = md5(Column(columnName))
 
+  /**
+   * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(e: Column, numBits: Int): Column = {
+    require(Seq(0, 224, 256, 384, 512).contains(numBits),
+      s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)")
+    Sha2(e.expr, lit(numBits).expr)
+  }
+
+  /**
+   * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits)
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
   //////////////////////////////////////////////////////////////////////////////////////////////


http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8b53b38..8baed57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
   }
 
+  test("misc sha2 function") {
+    val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
+    checkAnswer(
+      df.select(sha2($"a", 256), sha2("b", 256)),
+      Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+        "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
+
+    checkAnswer(
+      df.selectExpr("sha2(a, 256)", "sha2(b, 256)"),
+      Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+        "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
+
+    intercept[IllegalArgumentException] {
+      df.select(sha2($"a", 1024))
+    }
+  }
+
   test("string length function") {
     checkAnswer(
       nullStrings.select(strlen($"s"), strlen("s")),
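
The tests above exercise the new DataFrame API end to end. A spark-shell style usage sketch (not part of the commit), assuming a Spark 1.5-era SQLContext named `sqlContext` with its implicits imported:

import org.apache.spark.sql.functions.sha2

// Assumed setup: `sqlContext` already exists; its implicits provide toDF and $"...".
import sqlContext.implicits._

val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")

// Column-based API: SHA-256 of a string column and of a binary column.
df.select(sha2($"a", 256), sha2($"b", 256)).show()

// SQL expression form, resolved through the new FunctionRegistry entry.
df.selectExpr("sha2(a, 256)", "sha2(b, 256)").show()

// numBits is validated eagerly by the Scala API; an unsupported length such as
// 1024 throws IllegalArgumentException before any job runs.
// df.select(sha2($"a", 1024))
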
