This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 847199fb6d9 [SPARK-45929][SQL] Support groupingSets operation in dataframe api
847199fb6d9 is described below

commit 847199fb6d95910ef624815cfad0be2f8ab8d9d7
Author: JacobZheng0927 <[email protected]>
AuthorDate: Tue Nov 21 10:41:17 2023 +0900

    [SPARK-45929][SQL] Support groupingSets operation in dataframe api
    
    ### What changes were proposed in this pull request?
    Add a `groupingSets` method to the Dataset API.
    
    `SELECT col1, col2, col3, sum(col4) FROM t GROUP BY col1, col2, col3 GROUPING SETS ((col1, col2), ())`
    This SQL can be expressed equivalently with the new API as:
    `df.groupingSets(Seq(Seq($"col1", $"col2"), Seq()), $"col1", $"col2", $"col3").sum("col4")`
    
    ### Why are the changes needed?
    Currently, grouping sets can only be used in Spark SQL; the feature is not available when developing with the Dataset API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. This PR introduces the `groupingSets` method in the Dataset API.
    
    ### How was this patch tested?
    Tests added in `DataFrameAggregateSuite.scala`.
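
    For reference, a typical way to run just this suite locally with the standard
    Spark sbt workflow (a usage hint, not part of the patch):

        build/sbt "sql/testOnly *DataFrameAggregateSuite"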
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43813 from JacobZheng0927/SPARK-45929.
    
    Authored-by: JacobZheng0927 <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 27 ++++++++++++++++++++++
 .../spark/sql/RelationalGroupedDataset.scala       | 10 ++++++++
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 15 ++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5a372f9a0f9..062c4c6bcad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1825,6 +1825,33 @@ class Dataset[T] private[sql](
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType)
   }
 
+  /**
+   * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets,
+   * so we can run aggregation on them.
+   * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns grouped by the specified grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, grouped by the specified grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
+  def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
+    RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      RelationalGroupedDataset.GroupingSetsType(groupingSets.map(_.map(_.expr))))
+  }
+
   /**
    * Groups the Dataset using the specified columns, so that we can run aggregation on them.
    * See [[RelationalGroupedDataset]] for all the available aggregate functions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7e15c0baf52..bf1b2814270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -80,6 +80,11 @@ class RelationalGroupedDataset protected[sql](
         Dataset.ofRows(
           df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))),
             aliasedAgg, df.logicalPlan))
+      case RelationalGroupedDataset.GroupingSetsType(groupingSets) =>
+        Dataset.ofRows(
+          df.sparkSession,
+          Aggregate(Seq(GroupingSets(groupingSets, groupingExprs)),
+            aliasedAgg, df.logicalPlan))
       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         Dataset.ofRows(
@@ -732,6 +737,11 @@ private[sql] object RelationalGroupedDataset {
    */
   private[sql] object RollupType extends GroupType
 
+  /**
+   * To indicate it's the GROUPING SETS
+   */
+  private[sql] case class GroupingSetsType(groupingSets: Seq[Seq[Expression]]) extends GroupType
+
   /**
    * To indicate it's the PIVOT
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c8eea985c10..3691d76d251 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -161,6 +161,21 @@ class DataFrameAggregateSuite extends QueryTest
     assert(cube0.where("date IS NULL").count() > 0)
   }
 
+  test("SPARK-45929 support grouping set operation in dataframe api") {
+    checkAnswer(
+      courseSales
+        .groupingSets(
+          Seq(Seq(Column("course"), Column("year")), Seq()),
+          Column("course"),
+          Column("year"))
+        .agg(sum(Column("earnings")), grouping_id()),
+      Row("Java", 2012, 20000.0, 0) ::
+        Row("Java", 2013, 30000.0, 0) ::
+        Row("dotNET", 2012, 15000.0, 0) ::
+        Row("dotNET", 2013, 48000.0, 0) ::
+        Row(null, null, 113000.0, 3) :: Nil)
+  }
+
   test("grouping and grouping_id") {
     checkAnswer(
       courseSales.cube("course", "year")

