Repository: spark
Updated Branches:
  refs/heads/branch-1.5 bca196754 -> 28bb97730


[SPARK-8231] [SQL] Add array_contains

This PR is based on #7580 , thanks to EntilZha

PR for work on https://issues.apache.org/jira/browse/SPARK-8231

Currently, I have an initial implementation of contains. Based on the discussion 
on JIRA, it should behave the same as Hive: 
https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java#L102-L128

Main points are:
1. If the array is empty, null, or the value is null, return false
2. If there is a type mismatch, throw error
3. If comparison is not supported, throw error

Closes #7580

Author: Pedro Rodriguez <[email protected]>
Author: Pedro Rodriguez <[email protected]>
Author: Davies Liu <[email protected]>

Closes #7949 from davies/array_contains and squashes the following commits:

d3c08bc [Davies Liu] use foreach() to avoid copy
bc3d1fe [Davies Liu] fix array_contains
719e37d [Davies Liu] Merge branch 'master' of github.com:apache/spark into 
array_contains
e352cf9 [Pedro Rodriguez] fixed diff from master
4d5b0ff [Pedro Rodriguez] added docs and another type check
ffc0591 [Pedro Rodriguez] fixed unit test
7a22deb [Pedro Rodriguez] Changed test to use strings instead of long/ints 
which are different between python 2 and 3
b5ffae8 [Pedro Rodriguez] fixed pyspark test
4e7dce3 [Pedro Rodriguez] added more docs
3082399 [Pedro Rodriguez] fixed unit test
46f9789 [Pedro Rodriguez] reverted change
d3ca013 [Pedro Rodriguez] Fixed type checking to match hive behavior, then 
added tests to ensure this
8528027 [Pedro Rodriguez] added more tests
686e029 [Pedro Rodriguez] fix scala style
d262e9d [Pedro Rodriguez] reworked type checking code and added more tests
2517a58 [Pedro Rodriguez] removed unused import
28b4f71 [Pedro Rodriguez] fixed bug with type conversions and re-added tests
12f8795 [Pedro Rodriguez] fix scala style checks
e8a20a9 [Pedro Rodriguez] added python df (broken atm)
65b562c [Pedro Rodriguez] made array_contains nullable false
33b45aa [Pedro Rodriguez] reordered test
9623c64 [Pedro Rodriguez] fixed test
4b4425b [Pedro Rodriguez] changed Arrays in tests to Seqs
72cb4b1 [Pedro Rodriguez] added checkInputTypes and docs
69c46fb [Pedro Rodriguez] added tests and codegen
9e0bfc4 [Pedro Rodriguez] initial attempt at implementation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28bb9773
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28bb9773
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28bb9773

Branch: refs/heads/branch-1.5
Commit: 28bb977302ff3077c82bb8ee7518eb36bddaf2b3
Parents: bca1967
Author: Pedro Rodriguez <[email protected]>
Authored: Tue Aug 4 22:32:21 2015 -0700
Committer: Davies Liu <[email protected]>
Committed: Tue Aug 4 22:35:45 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 17 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/collectionOperations.scala      | 78 +++++++++++++++++++-
 .../expressions/CollectionFunctionsSuite.scala  | 15 ++++
 .../scala/org/apache/spark/sql/functions.scala  |  8 ++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 49 +++++++++++-
 6 files changed, 163 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e65b14d..9f0d71d 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1311,6 +1311,23 @@ def array(*cols):
     return Column(jc)
 
 
+@since(1.5)
+def array_contains(col, value):
+    """
+    Collection function: returns True if the array contains the given value. 
The collection
+    elements and value must be of the same type.
+
+    :param col: name of column containing array
+    :param value: value to check for in array
+
+    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df.select(array_contains(df.data, "a")).collect()
+    [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_contains(_to_java_column(col), 
value))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 43e3e9b..94c355f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -240,6 +240,7 @@ object FunctionRegistry {
     // collection functions
     expression[Size]("size"),
     expression[SortArray]("sort_array"),
+    expression[ArrayContains]("array_contains"),
 
     // misc functions
     expression[Crc32]("crc32"),

http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 6ccb565..646afa4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Comparator
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, 
CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.expressions.codegen.{
+  CodegenFallback, CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
@@ -115,3 +116,76 @@ case class SortArray(base: Expression, ascendingOrder: 
Expression)
 
   override def prettyName: String = "sort_array"
 }
+
+/**
+ * Checks if the array (left) has the element (right)
+ */
+case class ArrayContains(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = BooleanType
+
+  override def inputTypes: Seq[AbstractDataType] = right.dataType match {
+    case NullType => Seq()
+    case _ => left.dataType match {
+      case n @ ArrayType(element, _) => Seq(n, element)
+      case _ => Seq()
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (right.dataType == NullType) {
+      TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as 
arguments")
+    } else if (!left.dataType.isInstanceOf[ArrayType]
+      || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
+      TypeCheckResult.TypeCheckFailure(
+        "Arguments must be an array followed by a value of same type as the 
array members")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Boolean = {
+    val arr = left.eval(input)
+    if (arr == null) {
+      false
+    } else {
+      val value = right.eval(input)
+      if (value == null) {
+        false
+      } else {
+        arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+          if (v == value) return true
+        )
+        false
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    val arrGen = left.gen(ctx)
+    val elementGen = right.gen(ctx)
+    val i = ctx.freshName("i")
+    val getValue = ctx.getValue(arrGen.primitive, right.dataType, i)
+    s"""
+      ${arrGen.code}
+      boolean ${ev.isNull} = false;
+      boolean ${ev.primitive} = false;
+      if (!${arrGen.isNull}) {
+        ${elementGen.code}
+        if (!${elementGen.isNull}) {
+          for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) {
+            if (${ctx.genEqual(right.dataType, elementGen.primitive, 
getValue)}) {
+              ${ev.primitive} = true;
+              break;
+            }
+          }
+        }
+      }
+     """
+  }
+
+  override def prettyName: String = "array_contains"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 2c7e85c..95f0e38 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -65,4 +65,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
 
     checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
   }
+
+  test("Array contains") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+
+    checkEvaluation(ArrayContains(a0, Literal(1)), true)
+    checkEvaluation(ArrayContains(a0, Literal(0)), false)
+    checkEvaluation(ArrayContains(a0, Literal(null)), false)
+
+    checkEvaluation(ArrayContains(a1, Literal("")), true)
+    checkEvaluation(ArrayContains(a1, Literal(null)), false)
+
+    checkEvaluation(ArrayContains(a2, Literal(null)), false)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bff7017..5a10c38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2121,6 +2121,14 @@ object functions {
   
//////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
+   * Returns true if the array contains the value.
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def array_contains(column: Column, value: Any): Column =
+    ArrayContains(column.expr, Literal(value))
+
+  /**
    * Creates a new row for each element in the given array or map column.
    *
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/28bb9773/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6137527..03116a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -323,9 +323,9 @@ class DataFrameFunctionsSuite extends QueryTest {
 
   test("array size function") {
     val df = Seq(
-      (Array[Int](1, 2), "x"),
-      (Array[Int](), "y"),
-      (Array[Int](1, 2, 3), "z")
+      (Seq[Int](1, 2), "x"),
+      (Seq[Int](), "y"),
+      (Seq[Int](1, 2, 3), "z")
     ).toDF("a", "b")
     checkAnswer(
       df.select(size($"a")),
@@ -352,4 +352,47 @@ class DataFrameFunctionsSuite extends QueryTest {
       Seq(Row(2), Row(0), Row(3))
     )
   }
+
+  test("array contains function") {
+    val df = Seq(
+      (Seq[Int](1, 2), "x"),
+      (Seq[Int](), "x")
+    ).toDF("a", "b")
+
+    // Simple test cases
+    checkAnswer(
+      df.select(array_contains(df("a"), 1)),
+      Seq(Row(true), Row(false))
+    )
+    checkAnswer(
+      df.selectExpr("array_contains(a, 1)"),
+      Seq(Row(true), Row(false))
+    )
+    checkAnswer(
+      df.select(array_contains(array(lit(2), lit(null)), 1)),
+      Seq(Row(false), Row(false))
+    )
+
+    // In hive, this errors because null has no type information
+    intercept[AnalysisException] {
+      df.select(array_contains(df("a"), null))
+    }
+    intercept[AnalysisException] {
+      df.selectExpr("array_contains(a, null)")
+    }
+    intercept[AnalysisException] {
+      df.selectExpr("array_contains(null, 1)")
+    }
+
+    // In Hive, if either argument has a matching type but a null value, 
return false, even if
+    // the first argument array contains a null and the second argument is null
+    checkAnswer(
+      df.selectExpr("array_contains(array(array(1), null)[1], 1)"),
+      Seq(Row(false), Row(false))
+    )
+    checkAnswer(
+      df.selectExpr("array_contains(array(0, null), array(1, null)[1])"),
+      Seq(Row(false), Row(false))
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to