Repository: spark
Updated Branches:
  refs/heads/master 8228b0630 -> a54438cb2


[SPARK-16285][SQL] Implement sentences SQL function

## What changes were proposed in this pull request?

This PR implements the `sentences` SQL function.

## How was this patch tested?

Passes the Jenkins tests, including new test cases.

Author: Dongjoon Hyun <[email protected]>

Closes #14004 from dongjoon-hyun/SPARK_16285.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a54438cb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a54438cb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a54438cb

Branch: refs/heads/master
Commit: a54438cb23c80f7c7fc35da273677c39317cb1a5
Parents: 8228b06
Author: Dongjoon Hyun <[email protected]>
Authored: Fri Jul 8 17:05:24 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Jul 8 17:05:24 2016 +0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/stringExpressions.scala         | 68 +++++++++++++++++++-
 .../expressions/StringExpressionsSuite.scala    | 23 +++++++
 .../apache/spark/sql/StringFunctionsSuite.scala | 20 ++++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |  2 +-
 5 files changed, 111 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a54438cb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f6ebcae..842c9c6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -296,6 +296,7 @@ object FunctionRegistry {
     expression[RLike]("rlike"),
     expression[StringRPad]("rpad"),
     expression[StringTrimRight]("rtrim"),
+    expression[Sentences]("sentences"),
     expression[SoundEx]("soundex"),
     expression[StringSpace]("space"),
     expression[StringSplit]("split"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a54438cb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b0df957..894e12d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
 import java.util.{HashMap, Locale, Map => JMap}
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression)
 
   override def prettyName: String = "format_number"
 }
+
+/**
+ * Splits a string into an array of sentences, where each sentence is an array 
of words.
+ * The 'lang' and 'country' arguments are optional; the default locale is used 
if either is null.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of 
words.",
+  extended = "> SELECT _FUNC_('Hi there! Good morning.');\n  [['Hi','there'], 
['Good','morning']]")
+case class Sentences(
+    str: Expression,
+    language: Expression = Literal(""),
+    country: Expression = Literal(""))
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  def this(str: Expression) = this(str, Literal(""), Literal(""))
+  def this(str: Expression, language: Expression) = this(str, language, 
Literal(""))
+
+  override def nullable: Boolean = true
+  override def dataType: DataType =
+    ArrayType(ArrayType(StringType, containsNull = false), containsNull = 
false)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, 
StringType)
+  override def children: Seq[Expression] = str :: language :: country :: Nil
+
+  override def eval(input: InternalRow): Any = {
+    val string = str.eval(input)
+    if (string == null) {
+      null
+    } else {
+      val languageStr = language.eval(input).asInstanceOf[UTF8String]
+      val countryStr = country.eval(input).asInstanceOf[UTF8String]
+      val locale = if (languageStr != null && countryStr != null) {
+        new Locale(languageStr.toString, countryStr.toString)
+      } else {
+        Locale.getDefault
+      }
+      getSentences(string.asInstanceOf[UTF8String].toString, locale)
+    }
+  }
+
+  private def getSentences(sentences: String, locale: Locale) = {
+    val bi = BreakIterator.getSentenceInstance(locale)
+    bi.setText(sentences)
+    var idx = 0
+    val result = new ArrayBuffer[GenericArrayData]
+    while (bi.next != BreakIterator.DONE) {
+      val sentence = sentences.substring(idx, bi.current)
+      idx = bi.current
+
+      val wi = BreakIterator.getWordInstance(locale)
+      var widx = 0
+      wi.setText(sentence)
+      val words = new ArrayBuffer[UTF8String]
+      while (wi.next != BreakIterator.DONE) {
+        val word = sentence.substring(widx, wi.current)
+        widx = wi.current
+        if (Character.isLetterOrDigit(word.charAt(0))) words += 
UTF8String.fromString(word)
+      }
+      result += new GenericArrayData(words)
+    }
+    new GenericArrayData(result)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a54438cb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 5f01561..256ce85 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -725,4 +725,27 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
     checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
   }
+
+  test("Sentences") {
+    val nullString = Literal.create(null, StringType)
+    checkEvaluation(Sentences(nullString, nullString, nullString), null)
+    checkEvaluation(Sentences(nullString, nullString), null)
+    checkEvaluation(Sentences(nullString), null)
+    checkEvaluation(Sentences(Literal.create(null, NullType)), null)
+    checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
+    checkEvaluation(Sentences("", nullString), Seq.empty)
+    checkEvaluation(Sentences(""), Seq.empty)
+
+    val answer = Seq(
+      Seq("Hi", "there"),
+      Seq("The", "price", "was"),
+      Seq("But", "not", "now"))
+
+    checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not 
now."), answer)
+    checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not 
now.", "en"), answer)
+    checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not 
now.", "en", "US"),
+      answer)
+    checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not 
now.", "XXX", "YYY"),
+      answer)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a54438cb/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3edd988..044ac22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -349,4 +349,24 @@ class StringFunctionsSuite extends QueryTest with 
SharedSQLContext {
       df2.filter("b>0").selectExpr("format_number(a, b)"),
       Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: 
Row("3.00") :: Nil)
   }
+
+  test("string sentences function") {
+    val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", 
"US"))
+      .toDF("str", "language", "country")
+
+    checkAnswer(
+      df.selectExpr("sentences(str, language, country)"),
+      Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", 
"not", "now"))))
+
+    // Type coercion
+    checkAnswer(
+      df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
+      Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))
+
+    // Argument number exception
+    val m = intercept[AnalysisException] {
+      df.selectExpr("sentences()")
+    }.getMessage
+    assert(m.contains("Invalid number of arguments for function sentences"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a54438cb/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index fdc4c18..6f05f0f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
   // str_to_map, windowingtablefunction.
   private val hiveFunctions = Seq(
     "hash", "java_method", "histogram_numeric",
-    "parse_url", "percentile", "percentile_approx", "reflect", "sentences", 
"str_to_map",
+    "parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
     "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
     "xpath_number", "xpath_short", "xpath_string"
   )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to