This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6382a3be8f6 [SPARK-38591][SQL] Add flatMapSortedGroups and cogroupSorted
6382a3be8f6 is described below

commit 6382a3be8f6a937412c4e23a92ab440f7ba80bdf
Author: Enrico Minack <[email protected]>
AuthorDate: Fri Jan 20 13:21:29 2023 +0800

    [SPARK-38591][SQL] Add flatMapSortedGroups and cogroupSorted
    
    ### What changes were proposed in this pull request?
    This adds a sorted version of `Dataset.groupByKey(…).flatMapGroups(…)` and `Dataset.groupByKey(…).cogroup(…)`.
    
    ### Why are the changes needed?
    The existing methods `KeyValueGroupedDataset.flatMapGroups` and `KeyValueGroupedDataset.cogroup` provide iterators of rows for each group key.
    
    Sorting entire groups inside `flatMapGroups` / `cogroup` requires materialising all rows, which defeats the purpose of an iterator in the first place. The methods `flatMapGroups` and `cogroup` have the great advantage that they work with groups that are _too large to fit into the memory of a single executor_. Sorting them in the user function breaks this property.
    
    [org.apache.spark.sql.KeyValueGroupedDataset](https://github.com/apache/spark/blob/47485a3c2df3201c838b939e82d5b26332e2d858/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala#L134-L137):
    > Internally, the implementation will spill to disk if any given group is too large to fit into
    > memory.  However, users must take care to avoid materializing the whole iterator for a group
    > (for example, by calling `toList`) unless they are sure that this is possible given the memory
    > constraints of their cluster.
    
    The implementations of `KeyValueGroupedDataset.flatMapGroups` and `KeyValueGroupedDataset.cogroup` already sort each partition according to the group key. By additionally sorting by some data columns, the iterator can be guaranteed to provide the rows of each group in a well-defined order.
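
    For illustration, a minimal sketch of the new Scala API (adapted from the tests added in `DatasetSuite`; it assumes a `spark-shell` session with `spark.implicits._` in scope):

    ```scala
    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2)).toDF("key", "seq", "value")

    // Group by the string key; within each group the iterator now yields rows
    // sorted ascending by "seq", without materialising the group in memory.
    val sorted = ds.groupByKey(row => row.getString(0))
      .flatMapSortedGroups($"seq") { (key, rows) =>
        Iterator(key -> rows.map(_.getInt(2)).mkString(", "))
      }
    // yields ("a", "10, 20") and ("b", "2, 1")
    ```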
    
    ### Does this PR introduce _any_ user-facing change?
    This adds `KeyValueGroupedDataset.flatMapSortedGroups` and `KeyValueGroupedDataset.cogroupSorted`, which guarantee the order of the group iterators.
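
    For example, a sketch of the new `cogroupSorted` (also adapted from the added tests; each side takes its own sort expressions, ascending on the left and descending on the right here):

    ```scala
    val leftDs = Seq(1 -> "a", 3 -> "xyz", 3 -> "abc").toDS()
    val rightDs = Seq(3 -> "w", 3 -> "a").toDS()

    // Left group iterators arrive ascending by _2, right ones descending by _2.
    val cogrouped = leftDs.groupByKey(_._1)
      .cogroupSorted(rightDs.groupByKey(_._1))(leftDs("_2"))(rightDs("_2").desc) {
        (key, lefts, rights) =>
          Iterator(key -> (lefts.map(_._2).mkString + "#" + rights.map(_._2).mkString))
      }
    // yields 1 -> "a#" and 3 -> "abcxyz#wa"
    ```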
    
    ### How was this patch tested?
    Tests have been added to `DatasetSuite` and `JavaDatasetSuite`.
    
    Closes #39640 from EnricoMi/branch-sorted-groups-and-cogroups.
    
    Authored-by: Enrico Minack <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |   6 +-
 .../spark/sql/catalyst/plans/logical/object.scala  |  12 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  | 148 ++++++++++++++++++++-
 .../spark/sql/execution/SparkStrategies.scala      |  11 +-
 .../org/apache/spark/sql/execution/objects.scala   |  15 ++-
 .../streaming/FlatMapGroupsWithStateExec.scala     |   6 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java     |  47 +++++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  68 ++++++++++
 8 files changed, 291 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 86d85abc6f3..ecd1ed94ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -425,7 +425,9 @@ package object dsl {
           leftGroup: Seq[Attribute],
           rightGroup: Seq[Attribute],
           leftAttr: Seq[Attribute],
-          rightAttr: Seq[Attribute]
+          rightAttr: Seq[Attribute],
+          leftOrder: Seq[SortOrder] = Nil,
+          rightOrder: Seq[SortOrder] = Nil
         ): LogicalPlan = {
         CoGroup.apply[Key, Left, Right, Result](
           func,
@@ -433,6 +435,8 @@ package object dsl {
           rightGroup,
           leftAttr,
           rightAttr,
+          leftOrder,
+          rightOrder,
           logicalPlan,
           otherPlan)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e5fe07e2d95..b27c650cfb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -389,6 +389,7 @@ object MapGroups {
       func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
+      dataOrder: Seq[SortOrder],
       child: LogicalPlan): LogicalPlan = {
     val mapped = new MapGroups(
       func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -396,6 +397,7 @@ object MapGroups {
       UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
       groupingAttributes,
       dataAttributes,
+      dataOrder,
       CatalystSerde.generateObjAttr[U],
       child)
     CatalystSerde.serialize[U](mapped)
@@ -405,7 +407,8 @@ object MapGroups {
 /**
  * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
  * Func is invoked with an object representation of the grouping key and an iterator containing the
- * object representation of all the rows with that key.
+ * object representation of all the rows with that key. Given an additional `dataOrder`, data in
+ * the iterator will be sorted accordingly. That sorting does not add computational complexity.
  *
  * @param keyDeserializer used to extract the key object for each group.
  * @param valueDeserializer used to extract the items in the iterator from an input row.
@@ -416,6 +419,7 @@ case class MapGroups(
     valueDeserializer: Expression,
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
+    dataOrder: Seq[SortOrder],
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer {
   override protected def withNewChildInternal(newChild: LogicalPlan): MapGroups =
@@ -649,6 +653,8 @@ object CoGroup {
       rightGroup: Seq[Attribute],
       leftAttr: Seq[Attribute],
       rightAttr: Seq[Attribute],
+      leftOrder: Seq[SortOrder],
+      rightOrder: Seq[SortOrder],
       left: LogicalPlan,
       right: LogicalPlan): LogicalPlan = {
     require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
@@ -664,6 +670,8 @@ object CoGroup {
       rightGroup,
       leftAttr,
       rightAttr,
+      leftOrder,
+      rightOrder,
       CatalystSerde.generateObjAttr[OUT],
       left,
       right)
@@ -684,6 +692,8 @@ case class CoGroup(
     rightGroup: Seq[Attribute],
     leftAttr: Seq[Attribute],
     rightAttr: Seq[Attribute],
+    leftOrder: Seq[SortOrder],
+    rightOrder: Seq[SortOrder],
     outputObjAttr: Attribute,
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectProducer {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index add692f57d2..4d2377b9b96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, CreateStruct, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
@@ -145,6 +145,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
         f,
         groupingAttributes,
         dataAttributes,
+        Seq.empty,
         logicalPlan))
   }
 
@@ -171,6 +172,83 @@ class KeyValueGroupedDataset[K, V] private[sql](
     flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }
 
+  /**
+   * (Scala-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and a sorted iterator that contains all of the elements in the group.
+   * The function can return an iterator containing elements of an arbitrary type which will be
+   * returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator
+   * is sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.4.0
+   */
+  def flatMapSortedGroups[U : Encoder](
+      sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder => expr
+        case expr: Expression => SortOrder(expr, Ascending)
+      }
+    }
+
+    Dataset[U](
+      sparkSession,
+      MapGroups(
+        f,
+        groupingAttributes,
+        dataAttributes,
+        sortOrder,
+        logicalPlan
+      )
+    )
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and a sorted iterator that contains all of the elements in the group.
+   * The function can return an iterator containing elements of an arbitrary type which will be
+   * returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator
+   * is sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.4.0
+   */
+  def flatMapSortedGroups[U](
+      sortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(sortExprs: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  }
+
   /**
    * (Scala-specific)
    * Applies the given function to each group of data.  For each unique group, the function will
@@ -753,6 +831,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
         other.groupingAttributes,
         this.dataAttributes,
         other.dataAttributes,
+        Seq.empty,
+        Seq.empty,
         this.logicalPlan,
         other.logicalPlan))
   }
@@ -773,6 +853,72 @@ class KeyValueGroupedDataset[K, V] private[sql](
    cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
   }
 
+  /**
+   * (Scala-specific)
+   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
+   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators
+   * are sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.4.0
+   */
+  def cogroupSorted[U, R : Encoder](
+      other: KeyValueGroupedDataset[K, U])(
+      thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    def toSortOrder(col: Column): SortOrder = col.expr match {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)
+    }
+
+    val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder)
+    val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder)
+
+    implicit val uEncoder = other.vExprEnc
+    Dataset[R](
+      sparkSession,
+      CoGroup(
+        f,
+        this.groupingAttributes,
+        other.groupingAttributes,
+        this.dataAttributes,
+        other.dataAttributes,
+        thisSortOrder,
+        otherSortOrder,
+        this.logicalPlan,
+        other.logicalPlan))
+  }
+
+  /**
+   * (Java-specific)
+   * Applies the given function to each sorted cogrouped data.  For each unique group, the function
+   * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`.  The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators
+   * are sorted according to the given sort expressions. That sorting does not add
+   * computational complexity.
+   *
+   * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.4.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
+  }
+
   override def toString: String = {
     val builder = new StringBuilder
     val kFields = kExprEnc.schema.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 110fe45cc12..cd4485e3822 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -806,8 +806,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+      case logical.MapGroups(f, key, value, grouping, data, order, objAttr, child) =>
+        execution.MapGroupsExec(
+          f, key, value, grouping, data, order, objAttr, planLater(child)
+        ) :: Nil
       case logical.FlatMapGroupsWithState(
           f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
           isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
@@ -821,9 +823,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
         throw new UnsupportedOperationException(
           "applyInPandasWithState is unsupported in batch query. Use 
applyInPandas instead.")
-      case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
+      case logical.CoGroup(
+          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder, oAttr, left, right) =>
         execution.CoGroupExec(
-          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
+          f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder, oAttr,
           planLater(left), planLater(right)) :: Nil
 
       case r @ logical.Repartition(numPartitions, shuffle, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 869d3fe9790..bda592ff929 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -391,7 +391,8 @@ case class AppendColumnsWithObjectExec(
 
 /**
  * Groups the input rows together and calls the function with each group and an iterator containing
- * all elements in the group.  The result of this function is flattened before being output.
+ * all elements in the group. The iterator is sorted according to `dataOrder` if given.
+ * The result of this function is flattened before being output.
  */
 case class MapGroupsExec(
     func: (Any, Iterator[Any]) => TraversableOnce[Any],
@@ -399,6 +400,7 @@ case class MapGroupsExec(
     valueDeserializer: Expression,
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
+    dataOrder: Seq[SortOrder],
     outputObjAttr: Attribute,
     child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
 
@@ -408,7 +410,7 @@ case class MapGroupsExec(
     ClusteredDistribution(groupingAttributes) :: Nil
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)) ++ dataOrder)
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
@@ -438,6 +440,7 @@ object MapGroupsExec {
       valueDeserializer: Expression,
       groupingAttributes: Seq[Attribute],
       dataAttributes: Seq[Attribute],
+      dataOrder: Seq[SortOrder],
       outputObjAttr: Attribute,
       timeoutConf: GroupStateTimeout,
       child: SparkPlan): MapGroupsExec = {
@@ -449,7 +452,7 @@ object MapGroupsExec {
       func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent))
     }
     new MapGroupsExec(f, keyDeserializer, valueDeserializer,
-      groupingAttributes, dataAttributes, outputObjAttr, child)
+      groupingAttributes, dataAttributes, dataOrder, outputObjAttr, child)
   }
 }
 
@@ -623,6 +626,8 @@ case class CoGroupExec(
     rightGroup: Seq[Attribute],
     leftAttr: Seq[Attribute],
     rightAttr: Seq[Attribute],
+    leftOrder: Seq[SortOrder],
+    rightOrder: Seq[SortOrder],
     outputObjAttr: Attribute,
     left: SparkPlan,
     right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
@@ -631,7 +636,9 @@ case class CoGroupExec(
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+    (leftGroup.map(SortOrder(_, Ascending)) ++ leftOrder) ::
+      (rightGroup.map(SortOrder(_, Ascending)) ++ rightOrder) ::
+      Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 138029e76c1..760681e81c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -510,12 +510,12 @@ object FlatMapGroupsWithStateExec {
       }
       CoGroupExec(
         func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
-        initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
-        child, initialState)
+        initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, Seq.empty, Seq.empty,
+        outputObjAttr, child, initialState)
     } else {
       MapGroupsExec(
         userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
-        dataAttributes, outputObjAttr, timeoutConf, child)
+        dataAttributes, Seq.empty, outputObjAttr, timeoutConf, child)
     }
   }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 120c95aa866..228b7855142 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -302,7 +302,7 @@ public class JavaDatasetSuite implements Serializable {
   }
 
   @Test
-  public void testGroupBy() {
+  public void testGroupByKey() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
     KeyValueGroupedDataset<Integer, String> grouped =
@@ -327,9 +327,21 @@ public class JavaDatasetSuite implements Serializable {
           }
           return Collections.singletonList(sb.toString()).iterator();
         },
-      Encoders.STRING());
+        Encoders.STRING());
 
     Assert.assertEquals(asSet("1a", "3foobar"), 
toSet(flatMapped.collectAsList()));
+    Dataset<String> flatMapSorted = grouped.flatMapSortedGroups(
+        new Column[] { ds.col("value") },
+        (FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        },
+        Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3barfoo"), 
toSet(flatMapSorted.collectAsList()));
 
     Dataset<String> mapped2 = grouped.mapGroupsWithState(
        (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
@@ -352,10 +364,10 @@ public class JavaDatasetSuite implements Serializable {
           }
           return Collections.singletonList(sb.toString()).iterator();
         },
-      OutputMode.Append(),
-      Encoders.LONG(),
-      Encoders.STRING(),
-      GroupStateTimeout.NoTimeout());
+        OutputMode.Append(),
+        Encoders.LONG(),
+        Encoders.STRING(),
+        GroupStateTimeout.NoTimeout());
 
     Assert.assertEquals(asSet("1a", "3foobar"), 
toSet(flatMapped2.collectAsList()));
 
@@ -366,7 +378,7 @@ public class JavaDatasetSuite implements Serializable {
       asSet(tuple2(1, "a"), tuple2(3, "foobar")),
       toSet(reduced.collectAsList()));
 
-    List<Integer> data2 = Arrays.asList(2, 6, 10);
+    List<Integer> data2 = Arrays.asList(2, 6, 7, 10);
     Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
     KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
         (MapFunction<Integer, Integer>) v -> v / 2,
@@ -387,7 +399,26 @@ public class JavaDatasetSuite implements Serializable {
       },
       Encoders.STRING());
 
-    Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList()));
+    Assert.assertEquals(asSet("1a#2", "3foobar#67", "5#10"), toSet(cogrouped.collectAsList()));
+
+    Dataset<String> cogroupSorted = grouped.cogroupSorted(
+      grouped2,
+      new Column[] { ds.col("value") },
+      new Column[] { ds2.col("value").desc() },
+      (CoGroupFunction<Integer, String, Integer, String>) (key, left, right) -> {
+        StringBuilder sb = new StringBuilder(key.toString());
+        while (left.hasNext()) {
+          sb.append(left.next());
+        }
+        sb.append("#");
+        while (right.hasNext()) {
+          sb.append(right.next());
+        }
+        return Collections.singletonList(sb.toString()).iterator();
+      },
+      Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a#2", "3barfoo#76", "5#10"), 
toSet(cogroupSorted.collectAsList()));
   }
 
   @Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d298d7129c7..8b48d7e7827 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -573,6 +573,38 @@ class DatasetSuite extends QueryTest
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, flatMapSorted") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupByKey(v => (v.getString(0), "word"))
+    val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)")) {
+      (g, iter) => Iterator(g._1, iter.mkString(", "))
+    }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a", "[a,1,10], [a,2,20]",
+      "b", "[b,1,2], [b,2,1]",
+      "c", "[c,1,1]"
+    )
+  }
+
+  test("groupBy function, flatMapSorted desc") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupByKey(v => (v.getString(0), "word"))
+    val aggregated = grouped.flatMapSortedGroups($"seq".desc, 
expr("length(key)")) {
+      (g, iter) => Iterator(g._1, iter.mkString(", "))
+    }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a", "[a,2,20], [a,1,10]",
+      "b", "[b,2,1], [b,1,2]",
+      "c", "[c,1,1]"
+    )
+  }
+
   test("groupBy function, mapValues, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val keyValue = ds.groupByKey(_._1).mapValues(_._2)
@@ -727,6 +759,42 @@ class DatasetSuite extends QueryTest
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
 
+  test("cogroup sorted") {
+    val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> 
"ijk").toDS()
+    val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> 
"y").toDS()
+    val groupedLeft = left.groupByKey(_._1)
+    val groupedRight = right.groupByKey(_._1)
+
+    val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy")
+    val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy")
+    val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 -> "hello#xyz")
+    val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz")
+    val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx")
+
+    val leftOrder = Seq(left("_2"))
+    val rightOrder = Seq(right("_2"))
+    val leftDescOrder = Seq(left("_2").desc)
+    val rightDescOrder = Seq(right("_2").desc)
+    val none = Seq.empty
+
+    Seq(
+      ("neither", none, none, neitherSortedExpected),
+      ("left", leftOrder, none, leftSortedExpected),
+      ("right", none, rightOrder, rightSortedExpected),
+      ("both", leftOrder, rightOrder, bothSortedExpected),
+      ("both desc", leftDescOrder, rightDescOrder, bothDescSortedExpected)
+    ).foreach { case (label, leftOrder, rightOrder, expected) =>
+      withClue(s"$label sorted") {
+        val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) {
+          (key, left, right) =>
+            Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
+        }
+
+        checkDatasetUnorderly(cogrouped, expected.toList: _*)
+      }
+    }
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
