Repository: spark
Updated Branches:
  refs/heads/master 81012546e -> f31527227


[SPARK-11946][SQL] Audit pivot API for 1.6.

Currently pivot's signature looks like

```scala
@scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData

@scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData
```

I think we can remove the one that takes "Column" types, since callers should 
always be passing in literals. It'd also be clearer if the values were not 
varargs, but rather a Seq or java.util.List.

I also made similar changes for Python.

Author: Reynold Xin <[email protected]>

Closes #9929 from rxin/SPARK-11946.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f3152722
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f3152722
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f3152722

Branch: refs/heads/master
Commit: f3152722791b163fa66597b3684009058195ba33
Parents: 8101254
Author: Reynold Xin <[email protected]>
Authored: Tue Nov 24 12:54:37 2015 -0800
Committer: Reynold Xin <[email protected]>
Committed: Tue Nov 24 12:54:37 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/scheduler/DAGScheduler.scala   |   1 -
 python/pyspark/sql/group.py                     |  12 +-
 .../sql/catalyst/expressions/literals.scala     |   1 +
 .../org/apache/spark/sql/GroupedData.scala      | 154 +++++++++++--------
 .../apache/spark/sql/JavaDataFrameSuite.java    |  16 ++
 .../apache/spark/sql/DataFramePivotSuite.scala  |  21 +--
 .../org/apache/spark/sql/test/SQLTestData.scala |   1 +
 7 files changed, 125 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ae725b4..77a184d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1574,7 +1574,6 @@ class DAGScheduler(
   }
 
   def stop() {
-    logInfo("Stopping DAGScheduler")
     messageScheduler.shutdownNow()
     eventProcessLoop.stop()
     taskScheduler.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 227f40b..d8ed7eb 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -168,20 +168,24 @@ class GroupedData(object):
         """
 
     @since(1.6)
-    def pivot(self, pivot_col, *values):
+    def pivot(self, pivot_col, values=None):
         """Pivots a column of the current DataFrame and preform the specified 
aggregation.
 
         :param pivot_col: Column to pivot
         :param values: Optional list of values of pivotColumn that will be 
translated to columns in
             the output data frame. If values are not provided the method with 
do an immediate call
             to .distinct() on the pivot column.
-        >>> df4.groupBy("year").pivot("course", "dotNET", 
"Java").sum("earnings").collect()
+
+        >>> df4.groupBy("year").pivot("course", ["dotNET", 
"Java"]).sum("earnings").collect()
         [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, 
dotNET=48000, Java=30000)]
+
         >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, 
dotNET=48000)]
         """
-        jgd = self._jdf.pivot(_to_java_column(pivot_col),
-                              _to_seq(self.sql_ctx._sc, values, 
_create_column_from_literal))
+        if values is None:
+            jgd = self._jdf.pivot(pivot_col)
+        else:
+            jgd = self._jdf.pivot(pivot_col, values)
         return GroupedData(jgd, self.sql_ctx)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e34fd49..68ec688 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -44,6 +44,7 @@ object Literal {
     case a: Array[Byte] => Literal(a, BinaryType)
     case i: CalendarInterval => Literal(i, CalendarIntervalType)
     case null => Literal(null, NullType)
+    case v: Literal => v
     case _ =>
       throw new RuntimeException("Unsupported literal type " + v.getClass + " 
" + v)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 63dd7fb..ee7150c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,7 +25,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, 
Aggregate}
-import org.apache.spark.sql.types.{StringType, NumericType}
+import org.apache.spark.sql.types.NumericType
 
 
 /**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
   }
 
   /**
-    * (Scala-specific) Pivots a column of the current [[DataFrame]] and 
preform the specified
-    * aggregation.
-    * {{{
-    *   // Compute the sum of earnings for each year by course with each 
course as a separate column
-    *   df.groupBy($"year").pivot($"course", "dotNET", 
"Java").agg(sum($"earnings"))
-    *   // Or without specifying column values
-    *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
-    * }}}
-    * @param pivotColumn Column to pivot
-    * @param values Optional list of values of pivotColumn that will be 
translated to columns in the
-    *               output data frame. If values are not provided the method 
with do an immediate
-    *               call to .distinct() on the pivot column.
-    * @since 1.6.0
-    */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType 
match {
-    case _: GroupedData.PivotType =>
-      throw new UnsupportedOperationException("repeated pivots are not 
supported")
-    case GroupedData.GroupByType =>
-      val pivotValues = if (values.nonEmpty) {
-        values.map {
-          case Column(literal: Literal) => literal
-          case other =>
-            throw new UnsupportedOperationException(
-              s"The values of a pivot must be literals, found $other")
-        }
-      } else {
-        // This is to prevent unintended OOM errors when the number of 
distinct values is large
-        val maxValues = 
df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
-        // Get the distinct values of the column and sort them so its 
consistent
-        val values = df.select(pivotColumn)
-          .distinct()
-          .sort(pivotColumn)
-          .map(_.get(0))
-          .take(maxValues + 1)
-          .map(Literal(_)).toSeq
-        if (values.length > maxValues) {
-          throw new RuntimeException(
-            s"The pivot column $pivotColumn has more than $maxValues distinct 
values, " +
-              "this could indicate an error. " +
-              "If this was intended, set \"" + 
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
-              s"to at least the number of distinct values of the pivot 
column.")
-        }
-        values
-      }
-      new GroupedData(df, groupingExprs, 
GroupedData.PivotType(pivotColumn.expr, pivotValues))
-    case _ =>
-      throw new UnsupportedOperationException("pivot is only supported after a 
groupBy")
+   * Pivots a column of the current [[DataFrame]] and preform the specified 
aggregation.
+   * There are two versions of pivot function: one that requires the caller to 
specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more 
concise but less
+   * efficient, because Spark needs to first compute the list of distinct 
values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", 
"Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String): GroupedData = {
+    // This is to prevent unintended OOM errors when the number of distinct 
values is large
+    val maxValues = 
df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+    // Get the distinct values of the column and sort them so its consistent
+    val values = df.select(pivotColumn)
+      .distinct()
+      .sort(pivotColumn)
+      .map(_.get(0))
+      .take(maxValues + 1)
+      .toSeq
+
+    if (values.length > maxValues) {
+      throw new AnalysisException(
+        s"The pivot column $pivotColumn has more than $maxValues distinct 
values, " +
+          "this could indicate an error. " +
+          s"If this was intended, set 
${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+          "to at least the number of distinct values of the pivot column.")
+    }
+
+    pivot(pivotColumn, values)
   }
 
   /**
-    * Pivots a column of the current [[DataFrame]] and preform the specified 
aggregation.
-    * {{{
-    *   // Compute the sum of earnings for each year by course with each 
course as a separate column
-    *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
-    *   // Or without specifying column values
-    *   df.groupBy("year").pivot("course").sum("earnings")
-    * }}}
-    * @param pivotColumn Column to pivot
-    * @param values Optional list of values of pivotColumn that will be 
translated to columns in the
-    *               output data frame. If values are not provided the method 
with do an immediate
-    *               call to .distinct() on the pivot column.
-    * @since 1.6.0
-    */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: String, values: Any*): GroupedData = {
-    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
-    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+   * Pivots a column of the current [[DataFrame]] and preform the specified 
aggregation.
+   * There are two versions of pivot function: one that requires the caller to 
specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more 
concise but less
+   * efficient, because Spark needs to first compute the list of distinct 
values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", 
"Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the 
output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
+    groupType match {
+      case GroupedData.GroupByType =>
+        new GroupedData(
+          df,
+          groupingExprs,
+          GroupedData.PivotType(df.resolve(pivotColumn), 
values.map(Literal.apply)))
+      case _: GroupedData.PivotType =>
+        throw new UnsupportedOperationException("repeated pivots are not 
supported")
+      case _ =>
+        throw new UnsupportedOperationException("pivot is only supported after 
a groupBy")
+    }
+  }
+
+  /**
+   * Pivots a column of the current [[DataFrame]] and preform the specified 
aggregation.
+   * There are two versions of pivot function: one that requires the caller to 
specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more 
concise but less
+   * efficient, because Spark needs to first compute the list of distinct 
values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", 
"Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the 
output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
+    pivot(pivotColumn, values.asScala)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 567bddd..a12fed3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -282,4 +282,20 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(1, actual[1].getLong(0));
     Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
   }
+
+  @Test
+  public void pivot() {
+    DataFrame df = context.table("courseSales");
+    Row[] actual = df.groupBy("year")
+      .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+      .agg(sum("earnings")).orderBy("year").collect();
+
+    Assert.assertEquals(2012, actual[0].getInt(0));
+    Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual[1].getInt(0));
+    Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 0c23d14..fc53aba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with 
SharedSQLContext{
 
   test("pivot courses with literals") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
       Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     )
@@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with 
SharedSQLContext{
 
   test("pivot year with literals") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year", lit(2012), 
lit(2013)).agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 
2013)).agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot courses with literals and multiple aggregations") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy($"year")
+        .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
       Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
         Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with 
SharedSQLContext{
 
   test("pivot year with string values (cast)") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", "2012", 
"2013").sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq("2012", 
"2013")).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot year with int values") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 
2013)).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
@@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with 
SharedSQLContext{
   test("pivot courses with no values") {
     // Note Java comes before dotNet in sorted order
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
       Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
     )
   }
 
   test("pivot year with no values") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
-  test("pivot max values inforced") {
+  test("pivot max values enforced") {
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
-    intercept[RuntimeException](
-      courseSales.groupBy($"year").pivot($"course")
+    intercept[AnalysisException](
+      courseSales.groupBy("year").pivot("course")
     )
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
       SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)

http://git-wip-us.apache.org/repos/asf/spark/blob/f3152722/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index abad0d7..83c63e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
     person
     salary
     complexData
+    courseSales
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to