Repository: spark
Updated Branches:
  refs/heads/master 03377d252 -> 6996bd2e8


[SPARK-8264][SQL] add substring_index function

This PR is based on #7533, thanks to zhichao-li.

Closes #7533

Author: zhichao.li <[email protected]>
Author: Davies Liu <[email protected]>

Closes #7843 from davies/str_index and squashes the following commits:

391347b [Davies Liu] add python api
3ce7802 [Davies Liu] fix substringIndex
f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index
515519b [zhichao.li] add foldable and remove null checking
9546991 [zhichao.li] scala style
67c253a [zhichao.li] hide some apis and clean code
b19b013 [zhichao.li] add codegen and clean code
ac863e9 [zhichao.li] reduce the calling of numChars
12e108f [zhichao.li] refine unittest
d92951b [zhichao.li] add lastIndexOf
52d7b03 [zhichao.li] add substring_index function


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6996bd2e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6996bd2e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6996bd2e

Branch: refs/heads/master
Commit: 6996bd2e81bf6597dcda499d9a9a80927a43e30f
Parents: 03377d2
Author: zhichao.li <[email protected]>
Authored: Fri Jul 31 21:18:01 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Jul 31 21:18:01 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 19 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../catalyst/expressions/stringOperations.scala | 25 ++++++
 .../expressions/StringExpressionsSuite.scala    | 31 ++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 12 ++-
 .../apache/spark/sql/StringFunctionsSuite.scala | 57 ++++++++++++++
 .../apache/spark/unsafe/types/UTF8String.java   | 80 +++++++++++++++++++-
 .../spark/unsafe/types/UTF8StringSuite.java     | 38 ++++++++++
 8 files changed, 261 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb9926c..89a2a5c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -921,6 +921,25 @@ def trunc(date, format):
 
 
 @since(1.5)
+@ignore_unicode_prefix
+def substring_index(str, delim, count):
+    """
+    Returns the substring from string str before count occurrences of the 
delimiter delim.
+    If count is positive, everything to the left of the final delimiter (counting 
from left) is
+    returned. If count is negative, everything to the right of the final delimiter 
(counting from the
+    right) is returned. substring_index performs a case-sensitive match when 
searching for delim.
+
+    >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
+    >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
+    [Row(s=u'a.b')]
+    >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
+    [Row(s=u'b.c.d')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.substring_index(_to_java_column(str), 
delim, count))
+
+
+@since(1.5)
 def size(col):
     """
     Collection function: returns the length of the array or map stored in the 
column.

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3f61a9a..ee44cbc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -199,6 +199,7 @@ object FunctionRegistry {
     expression[StringSplit]("split"),
     expression[Substring]("substr"),
     expression[Substring]("substring"),
+    expression[SubstringIndex]("substring_index"),
     expression[StringTrim]("trim"),
     expression[UnBase64]("unbase64"),
     expression[Upper]("ucase"),

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 160e72f..5dd387a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -422,6 +422,31 @@ case class StringInstr(str: Expression, substr: Expression)
 }
 
 /**
+ * Returns the substring from string str before count occurrences of the 
delimiter delim.
+ * If count is positive, everything to the left of the final delimiter (counting 
from left) is
+ * returned. If count is negative, everything to the right of the final delimiter 
(counting from the
+ * right) is returned. substring_index performs a case-sensitive match when 
searching for delim.
+ */
+case class SubstringIndex(strExpr: Expression, delimExpr: Expression, 
countExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, 
IntegerType)
+  override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
+  override def prettyName: String = "substring_index"
+
+  override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
+    str.asInstanceOf[UTF8String].subStringIndex(
+      delim.asInstanceOf[UTF8String],
+      count.asInstanceOf[Int])
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    defineCodeGen(ctx, ev, (str, delim, count) => 
s"$str.subStringIndex($delim, $count)")
+  }
+}
+
+/**
  * A function that returns the position of the first occurrence of substr
  * in given string after position pos.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index fb72fe1..ad87ab3 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(s.substring(0), "example", row)
   }
 
+  test("string substring_index function") {
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), 
"www.apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), 
"www.apache")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), 
"www")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), 
"www.apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), 
"apache.org")
+    checkEvaluation(
+      SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), 
"org")
+    checkEvaluation(
+      SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
+    checkEvaluation(
+      SubstringIndex(Literal.create(null, StringType), Literal("."), 
Literal(-2)), null)
+    checkEvaluation(SubstringIndex(
+        Literal("www.apache.org"), Literal.create(null, StringType), 
Literal(-2)), null)
+    // non-ASCII (multi-byte UTF-8) string and delimiter
+    // scalastyle:off
+    checkEvaluation(
+      SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), 
Literal(2)), "大千世界大")
+    // scalastyle:on
+    checkEvaluation(
+      SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), 
"www||apache")
+  }
+
   test("LIKE literal Regular Expression") {
     checkEvaluation(Literal.create(null, StringType).like("a"), null)
     checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, 
StringType)), null)

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 89ffa9c..57bb00a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1788,8 +1788,18 @@ object functions {
   def instr(str: Column, substring: String): Column = StringInstr(str.expr, 
lit(substring).expr)
 
   /**
-   * Locate the position of the first occurrence of substr in a string column.
+   * Returns the substring from string str before count occurrences of the 
delimiter delim.
+   * If count is positive, everything to the left of the final delimiter 
(counting from left) is
+   * returned. If count is negative, everything to the right of the final delimiter 
(counting from the
+   * right) is returned. substring_index performs a case-sensitive match when 
searching for delim.
    *
+   * @group string_funcs
+   */
+  def substring_index(str: Column, delim: String, count: Int): Column =
+    SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+
+  /**
+   * Locate the position of the first occurrence of substr.
    * NOTE: The position is not zero based, but 1 based index, returns 0 if 
substr
    * could not be found in str.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index b7f073c..628da95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest {
       Row(1))
   }
 
+  test("string substring_index function") {
+    val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
+    checkAnswer(
+      df.select(substring_index($"a", ".", 3)),
+      Row("www.apache.org"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 2)),
+      Row("www.apache"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 1)),
+      Row("www"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 0)),
+      Row(""))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -1)),
+      Row("org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -2)),
+      Row("apache.org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -3)),
+      Row("www.apache.org"))
+    // empty str yields an empty string
+    checkAnswer(
+      df.select(substring_index(lit(""), ".", 1)),
+      Row(""))
+    // empty delim yields an empty string
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), "", 1)),
+      Row(""))
+    // delim absent from str: the whole str is returned
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), "#", 1)),
+      Row("www.apache.org"))
+    // multi-char delim, counted from the left and from the right
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", 2)),
+      Row("www||apache"))
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", -2)),
+      Row("apache||org"))
+    // null str or null delim yields null
+    checkAnswer(
+      df.select(substring_index(lit(null), "||", 2)),
+      Row(null))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), null, 2)),
+      Row(null))
+    // non-ASCII (multi-byte UTF-8) string and delimiter, via SQL expression
+    // scalastyle:off
+    checkAnswer(
+      df.selectExpr("""substring_index("大千世界大千世界", "千", 
2)"""),
+      Row("大千世界大"))
+    // scalastyle:on
+  }
+
   test("string locate function") {
     val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9d4998f..2561c1c 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -198,7 +198,7 @@ public final class UTF8String implements 
Comparable<UTF8String>, Serializable {
    */
   public UTF8String substring(final int start, final int until) {
     if (until <= start || start >= numBytes) {
-      return fromBytes(new byte[0]);
+      return UTF8String.EMPTY_UTF8;
     }
 
     int i = 0;
@@ -407,6 +407,84 @@ public final class UTF8String implements 
Comparable<UTF8String>, Serializable {
   }
 
   /**
+   * Returns the byte index of the first occurrence of `str` at or after byte index `start` (scanning left to right), or -1 if there is none.
+   */
+  private int find(UTF8String str, int start) {
+    assert (str.numBytes > 0);
+    while (start <= numBytes - str.numBytes) {
+      if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, 
str.offset, str.numBytes)) {
+        return start;
+      }
+      start += 1;
+    }
+    return -1;
+  }
+
+  /**
+   * Returns the byte index of the last occurrence of `str` at or before byte index `start` (scanning right to left), or -1; no upper-bound check, so callers must pass start <= numBytes - str.numBytes.
+   */
+  private int rfind(UTF8String str, int start) {
+    assert (str.numBytes > 0);
+    while (start >= 0) {
+      if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, 
str.offset, str.numBytes)) {
+        return start;
+      }
+      start -= 1;
+    }
+    return -1;
+  }
+
+  /**
+   * Returns the substring from string str before count occurrences of the 
delimiter delim.
+   * If count is positive, everything to the left of the final delimiter 
(counting from left) is
+   * returned. If count is negative, everything to the right of the final delimiter 
(counting from the
+   * right) is returned. subStringIndex performs a case-sensitive match when 
searching for delim.
+   */
+  public UTF8String subStringIndex(UTF8String delim, int count) {
+    if (delim.numBytes == 0 || count == 0) {  // empty delim or zero count: empty result
+      return EMPTY_UTF8;
+    }
+    if (count > 0) {
+      int idx = -1;
+      while (count > 0) {
+        idx = find(delim, idx + 1);
+        if (idx >= 0) {
+          count --;
+        } else {
+          // fewer than count delimiters: return the whole string
+          return this;
+        }
+      }
+      if (idx == 0) {
+        return EMPTY_UTF8;
+      }
+      byte[] bytes = new byte[idx];
+      copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
+      return fromBytes(bytes);
+
+    } else {
+      int idx = numBytes - delim.numBytes + 1;
+      // keep count negative: -Integer.MIN_VALUE overflows and would skip the loop (new byte[-1])
+      while (count < 0) {
+        idx = rfind(delim, idx - 1);
+        if (idx >= 0) {
+          count ++;
+        } else {
+          // fewer than -count delimiters: return the whole string
+          return this;
+        }
+      }
+      if (idx + delim.numBytes == numBytes) {
+        return EMPTY_UTF8;
+      }
+      int size = numBytes - delim.numBytes - idx;
+      byte[] bytes = new byte[size];
+      copyMemory(base, offset + idx + delim.numBytes, bytes, 
BYTE_ARRAY_OFFSET, size);
+      return fromBytes(bytes);
+    }
+  }
+
+  /**
    * Returns str, right-padded with pad to a length of len
    * For example:
    *   ('hi', 5, '??') =&gt; 'hi???'

http://git-wip-us.apache.org/repos/asf/spark/blob/6996bd2e/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git 
a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java 
b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index c565210..43eed70 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -241,6 +241,44 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void substring_index() {
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 3));
+    assertEquals(fromString("www.apache"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 2));
+    assertEquals(fromString("www"),
+      fromString("www.apache.org").subStringIndex(fromString("."), 1));
+    assertEquals(fromString(""),
+      fromString("www.apache.org").subStringIndex(fromString("."), 0));
+    assertEquals(fromString("org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -1));
+    assertEquals(fromString("apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -2));
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("."), -3));
+    // empty str yields an empty string
+    assertEquals(fromString(""),
+      fromString("").subStringIndex(fromString("."), 1));
+    // empty delim yields an empty string
+    assertEquals(fromString(""),
+      fromString("www.apache.org").subStringIndex(fromString(""), 1));
+    // delim absent from str: the whole str is returned
+    assertEquals(fromString("www.apache.org"),
+      fromString("www.apache.org").subStringIndex(fromString("#"), 2));
+    // multi-char delim, counted from the left and from the right
+    assertEquals(fromString("www||apache"),
+      fromString("www||apache||org").subStringIndex(fromString("||"), 2));
+    assertEquals(fromString("apache||org"),
+      fromString("www||apache||org").subStringIndex(fromString("||"), -2));
+    // non-ASCII (multi-byte UTF-8) string and delimiter
+    assertEquals(fromString("大千世界大"),
+      fromString("大千世界大千世界").subStringIndex(fromString("千"), 
2));
+    // overlapping delimiter matches are each counted
+    assertEquals(fromString("||"), 
fromString("||||||").subStringIndex(fromString("|||"), 3));
+    assertEquals(fromString("|||"), 
fromString("||||||").subStringIndex(fromString("|||"), -4));
+  }
+
+  @Test
   public void reverse() {
     assertEquals(fromString("olleh"), fromString("hello").reverse());
     assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to