Repository: spark
Updated Branches:
refs/heads/branch-2.0 7ef1d1c61 -> a04975457
[SPARK-16289][SQL] Implement posexplode table generating function
This PR implements the `posexplode` table-generating function. Currently, the
master branch raises the following exception for a `map` argument, which differs
from Hive's behavior.
**Before**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
org.apache.spark.sql.AnalysisException: No handler for Hive UDF ...
posexplode() takes an array as a parameter; line 1 pos 7
```
**After**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
+---+---+-----+
|pos|key|value|
+---+---+-----+
| 0| a| 1|
| 1| b| 2|
+---+---+-----+
```
For an `array` argument, the result after this change is the same as before.
```
scala> sql("select posexplode(array(1, 2, 3))").show
+---+---+
|pos|col|
+---+---+
| 0| 1|
| 1| 2|
| 2| 3|
+---+---+
```
Passes the Jenkins tests, including the newly added test cases.
Author: Dongjoon Hyun <[email protected]>
Closes #13971 from dongjoon-hyun/SPARK-16289.
(cherry picked from commit 46395db80e3304e3f3a1ebdc8aadb8f2819b48b4)
Signed-off-by: Reynold Xin <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a0497545
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a0497545
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a0497545
Branch: refs/heads/branch-2.0
Commit: a049754577aa78a5a26b38821233861a4dfd8e8a
Parents: 7ef1d1c
Author: Dongjoon Hyun <[email protected]>
Authored: Thu Jun 30 12:03:54 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Jul 7 21:05:31 2016 -0700
----------------------------------------------------------------------
R/pkg/NAMESPACE | 1 +
R/pkg/R/functions.R | 17 ++++
R/pkg/R/generics.R | 4 +
R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +-
python/pyspark/sql/functions.py | 21 +++++
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../sql/catalyst/expressions/generators.scala | 66 +++++++++++---
.../analysis/ExpressionTypeCheckingSuite.scala | 2 +
.../expressions/GeneratorExpressionSuite.scala | 71 +++++++++++++++
.../scala/org/apache/spark/sql/Column.scala | 1 +
.../scala/org/apache/spark/sql/functions.scala | 8 ++
.../spark/sql/ColumnExpressionSuite.scala | 60 -------------
.../spark/sql/GeneratorFunctionSuite.scala | 92 ++++++++++++++++++++
.../spark/sql/hive/HiveSessionCatalog.scala | 2 +-
14 files changed, 276 insertions(+), 72 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 9fd2568..bc3aceb 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -235,6 +235,7 @@ exportMethods("%in%",
"over",
"percent_rank",
"pmod",
+ "posexplode",
"quarter",
"rand",
"randn",
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 09e5afa..52d46f9 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2934,3 +2934,20 @@ setMethod("sort_array",
jc <- callJStatic("org.apache.spark.sql.functions", "sort_array",
x@jc, asc)
column(jc)
})
+
+#' posexplode
+#'
+#' Creates a new row for each element with position in the given array or map
column.
+#'
+#' @rdname posexplode
+#' @name posexplode
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{posexplode(df$c)}
+#' @note posexplode since 2.1.0
+setMethod("posexplode",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "posexplode",
x@jc)
+ column(jc)
+ })
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index b0f25de..e4ec508 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1054,6 +1054,10 @@ setGeneric("percent_rank", function(x) {
standardGeneric("percent_rank") })
#' @export
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
+#' @rdname posexplode
+#' @export
+setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
+
#' @rdname quarter
#' @export
setGeneric("quarter", function(x) { standardGeneric("quarter") })
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R
b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 755aded..bd7b5f0 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1085,7 +1085,7 @@ test_that("column functions", {
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) +
md5(c)
- c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
+ c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) +
monotonically_increasing_id()
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) +
sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1feca6e..92d709e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1637,6 +1637,27 @@ def explode(col):
return Column(jc)
+@since(2.1)
+def posexplode(col):
+ """Returns a new row for each element with position in the given array or
map.
+
+ >>> from pyspark.sql import Row
+ >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a":
"b"})])
+ >>> eDF.select(posexplode(eDF.intlist)).collect()
+ [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
+
+ >>> eDF.select(posexplode(eDF.mapfield)).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| b|
+ +---+---+-----+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.posexplode(_to_java_column(col))
+ return Column(jc)
+
+
@ignore_unicode_prefix
@since(1.6)
def get_json_object(col, path):
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 27c3a09..346cdd8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -177,6 +177,7 @@ object FunctionRegistry {
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
+ expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 12c3564..4e91cc5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,13 +94,10 @@ case class UserDefinedGenerator(
}
/**
- * Given an input array produces a sequence of rows for each value in the
array.
+ * A base class for Explode and PosExplode
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(a) - Separates the elements of array a into multiple rows,
or the elements of a map into multiple rows and columns.")
-// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends UnaryExpression with Generator
with CodegenFallback {
+abstract class ExplodeBase(child: Expression, position: Boolean)
+ extends UnaryExpression with Generator with CodegenFallback with
Serializable {
override def children: Seq[Expression] = child :: Nil
@@ -115,9 +112,26 @@ case class Explode(child: Expression) extends
UnaryExpression with Generator wit
// hive-compatible default alias for explode function ("col" for array,
"key", "value" for map)
override def elementSchema: StructType = child.dataType match {
- case ArrayType(et, containsNull) => new StructType().add("col", et,
containsNull)
+ case ArrayType(et, containsNull) =>
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("col", et, containsNull)
+ } else {
+ new StructType()
+ .add("col", et, containsNull)
+ }
case MapType(kt, vt, valueContainsNull) =>
- new StructType().add("key", kt, false).add("value", vt,
valueContainsNull)
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ } else {
+ new StructType()
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ }
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +143,7 @@ case class Explode(child: Expression) extends
UnaryExpression with Generator wit
} else {
val rows = new Array[InternalRow](inputArray.numElements())
inputArray.foreach(et, (i, e) => {
- rows(i) = InternalRow(e)
+ rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
})
rows
}
@@ -141,7 +155,7 @@ case class Explode(child: Expression) extends
UnaryExpression with Generator wit
val rows = new Array[InternalRow](inputMap.numElements())
var i = 0
inputMap.foreach(kt, vt, (k, v) => {
- rows(i) = InternalRow(k, v)
+ rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
i += 1
})
rows
@@ -149,3 +163,35 @@ case class Explode(child: Expression) extends
UnaryExpression with Generator wit
}
}
}
+
+/**
+ * Given an input array produces a sequence of rows for each value in the
array.
+ *
+ * {{{
+ * SELECT explode(array(10,20)) ->
+ * 10
+ * 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows,
or the elements of map a into multiple rows and columns.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
+// scalastyle:on line.size.limit
+case class Explode(child: Expression) extends ExplodeBase(child, position =
false)
+
+/**
+ * Given an input array produces a sequence of rows for each position and
value in the array.
+ *
+ * {{{
+ * SELECT posexplode(array(10,20)) ->
+ * 0 10
+ * 1 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows
with positions, or the elements of a map into multiple rows and columns with
positions.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
+// scalastyle:on line.size.limit
+case class PosExplode(child: Expression) extends ExplodeBase(child, position =
true)
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 54436ea..76e42d9 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(new Murmur3Hash(Nil), "function hash requires at least one
argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
+ assertError(PosExplode('intField),
+ "input to function explode should be array or map type")
}
test("check types for CreateNamedStruct") {
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
new file mode 100644
index 0000000..2aba841
--- /dev/null
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+
+class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
{
+ private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]):
Unit = {
+ assert(actual.eval(null).toSeq === expected)
+ }
+
+ private final val int_array = Seq(1, 2, 3)
+ private final val str_array = Seq("a", "b", "c")
+
+ test("explode") {
+ val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
+ val str_correct_answer = Seq(
+ Seq(UTF8String.fromString("a")),
+ Seq(UTF8String.fromString("b")),
+ Seq(UTF8String.fromString("c")))
+
+ checkTuple(
+ Explode(CreateArray(Seq.empty)),
+ Seq.empty)
+
+ checkTuple(
+ Explode(CreateArray(int_array.map(Literal(_)))),
+ int_correct_answer.map(InternalRow.fromSeq(_)))
+
+ checkTuple(
+ Explode(CreateArray(str_array.map(Literal(_)))),
+ str_correct_answer.map(InternalRow.fromSeq(_)))
+ }
+
+ test("posexplode") {
+ val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
+ val str_correct_answer = Seq(
+ Seq(0, UTF8String.fromString("a")),
+ Seq(1, UTF8String.fromString("b")),
+ Seq(2, UTF8String.fromString("c")))
+
+ checkTuple(
+ PosExplode(CreateArray(Seq.empty)),
+ Seq.empty)
+
+ checkTuple(
+ PosExplode(CreateArray(int_array.map(Literal(_)))),
+ int_correct_answer.map(InternalRow.fromSeq(_)))
+
+ checkTuple(
+ PosExplode(CreateArray(str_array.map(Literal(_)))),
+ str_correct_answer.map(InternalRow.fromSeq(_)))
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9f35107..a46d194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends
Logging {
// Leave an unaliased generator with an empty list of names since the
analyzer will generate
// the correct defaults after the nested expression's type has been
resolved.
case explode: Explode => MultiAlias(explode, Nil)
+ case explode: PosExplode => MultiAlias(explode, Nil)
case jt: JsonTuple => MultiAlias(jt, Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e8bd489..c8782df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2722,6 +2722,14 @@ object functions {
def explode(e: Column): Column = withExpr { Explode(e.expr) }
/**
+ * Creates a new row for each element with position in the given array or
map column.
+ *
+ * @group collection_funcs
+ * @since 2.1.0
+ */
+ def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
+
+ /**
* Extracts json object from a json string based on json path specified, and
returns json string
* of the extracted json object. It will return null if the input json
string is invalid.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a66c83d..a170fae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with
SharedSQLContext {
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key")
=== "value")
}
- test("single explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- checkAnswer(
- df.select(explode('intList)),
- Row(1) :: Row(2) :: Row(3) :: Nil)
- }
-
- test("explode and other columns") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select($"a", explode('intList)),
- Row(1, 1) ::
- Row(1, 2) ::
- Row(1, 3) :: Nil)
-
- checkAnswer(
- df.select($"*", explode('intList)),
- Row(1, Seq(1, 2, 3), 1) ::
- Row(1, Seq(1, 2, 3), 2) ::
- Row(1, Seq(1, 2, 3), 3) :: Nil)
- }
-
- test("aliased explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select(explode('intList).as('int)).select('int),
- Row(1) :: Row(2) :: Row(3) :: Nil)
-
- checkAnswer(
- df.select(explode('intList).as('int)).select(sum('int)),
- Row(6) :: Nil)
- }
-
- test("explode on map") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map)),
- Row("a", "b"))
- }
-
- test("explode on map with aliases") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1",
"value1"),
- Row("a", "b"))
- }
-
- test("self join explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- val exploded = df.select(explode('intList).as('i))
-
- checkAnswer(
- exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
- Row(3) :: Nil)
- }
-
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
new file mode 100644
index 0000000..1f0ef34
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("single explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("single posexplode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(posexplode('intList)),
+ Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
+ }
+
+ test("explode and other columns") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) :: Nil)
+
+ checkAnswer(
+ df.select($"*", explode('intList)),
+ Row(1, Seq(1, 2, 3), 1) ::
+ Row(1, Seq(1, 2, 3), 2) ::
+ Row(1, Seq(1, 2, 3), 3) :: Nil)
+ }
+
+ test("aliased explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
+ test("explode on map") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map)),
+ Row("a", "b"))
+ }
+
+ test("explode on map with aliases") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1",
"value1"),
+ Row("a", "b"))
+ }
+
+ test("self join explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ val exploded = df.select(explode('intList).as('i))
+
+ checkAnswer(
+ exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+ Row(3) :: Nil)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/a0497545/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 4c986b0..9fe0bf4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -241,6 +241,6 @@ private[sql] class HiveSessionCatalog(
"xpath_number", "xpath_short", "xpath_string",
// table generating function
- "inline", "posexplode"
+ "inline"
)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]