This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6382a3be8f6 [SPARK-38591][SQL] Add flatMapSortedGroups and cogroupSorted
6382a3be8f6 is described below
commit 6382a3be8f6a937412c4e23a92ab440f7ba80bdf
Author: Enrico Minack <[email protected]>
AuthorDate: Fri Jan 20 13:21:29 2023 +0800
[SPARK-38591][SQL] Add flatMapSortedGroups and cogroupSorted
### What changes were proposed in this pull request?
This adds sorted versions of `Dataset.groupByKey(…).flatMapGroups(…)` and
`Dataset.groupByKey(…).cogroup(…)`.
### Why are the changes needed?
The existing methods `KeyValueGroupedDataset.flatMapGroups` and
`KeyValueGroupedDataset.cogroup` provide an iterator of rows for each group key.
Sorting entire groups inside `flatMapGroups` / `cogroup` requires
materialising all rows, which defeats the purpose of an iterator in the first
place. The methods `flatMapGroups` and `cogroup` have the great advantage that they
work with groups that are _too large to fit into the memory of one executor_.
Sorting them in the user function breaks this property.
[org.apache.spark.sql.KeyValueGroupedDataset](https://github.com/apache/spark/blob/47485a3c2df3201c838b939e82d5b26332e2d858/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala#L134-L137):
> Internally, the implementation will spill to disk if any given group is too large to fit into
> memory. However, users must take care to avoid materializing the whole iterator for a group
> (for example, by calling `toList`) unless they are sure that this is possible given the memory
> constraints of their cluster.
The implementations of `KeyValueGroupedDataset.flatMapGroups` and
`KeyValueGroupedDataset.cogroup` already sort each partition according to the
group key. By additionally sorting by the given data columns, the iterator can
be guaranteed to provide the rows of each group in that order.
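For illustration, a minimal sketch of the difference (the `SparkSession`, dataset and column names below are hypothetical, not part of this patch):

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

// Hypothetical dataset with columns (key, seq, value).
val ds = Seq(("a", 2, 20), ("a", 1, 10), ("b", 1, 2)).toDF("key", "seq", "value")

// Sorting inside the user function materializes the whole group in memory:
ds.groupByKey(_.getString(0)).flatMapGroups { (key, rows) =>
  val sorted = rows.toSeq.sortBy(_.getInt(1))  // whole group held in memory
  Iterator(key -> sorted.map(_.getInt(2)).mkString(","))
}

// With flatMapSortedGroups the ordering is folded into the existing
// per-partition sort, so the iterator already arrives sorted and can be
// consumed lazily:
ds.groupByKey(_.getString(0)).flatMapSortedGroups($"seq") { (key, rows) =>
  Iterator(key -> rows.map(_.getInt(2)).mkString(","))
}
```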
### Does this PR introduce _any_ user-facing change?
This adds `KeyValueGroupedDataset.flatMapSortedGroups` and
`KeyValueGroupedDataset.cogroupSorted`, which guarantee the order of the group
iterators.
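As an illustrative sketch of the new cogroup variant (the datasets and sort columns below are hypothetical, modelled on the added `DatasetSuite` test):

```scala
import spark.implicits._  // assumes an active SparkSession named `spark`

val left  = Seq(1 -> "b", 1 -> "a", 2 -> "c").toDS()
val right = Seq(1 -> "y", 1 -> "x").toDS()

// The left iterator arrives sorted ascending by `_2`, the right one descending.
val cogrouped = left.groupByKey(_._1)
  .cogroupSorted(right.groupByKey(_._1))(left("_2"))(right("_2").desc) {
    (key, ls, rs) =>
      Iterator(key -> (ls.map(_._2).mkString + "#" + rs.map(_._2).mkString))
  }
// Expected contents (in some order): (1, "ab#yx"), (2, "c#")
```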
### How was this patch tested?
Tests have been added to `DatasetSuite` and `JavaDatasetSuite`.
Closes #39640 from EnricoMi/branch-sorted-groups-and-cogroups.
Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/catalyst/dsl/package.scala | 6 +-
.../spark/sql/catalyst/plans/logical/object.scala | 12 +-
.../apache/spark/sql/KeyValueGroupedDataset.scala | 148 ++++++++++++++++++++-
.../spark/sql/execution/SparkStrategies.scala | 11 +-
.../org/apache/spark/sql/execution/objects.scala | 15 ++-
.../streaming/FlatMapGroupsWithStateExec.scala | 6 +-
.../org/apache/spark/sql/JavaDatasetSuite.java | 47 +++++--
.../scala/org/apache/spark/sql/DatasetSuite.scala | 68 ++++++++++
8 files changed, 291 insertions(+), 22 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 86d85abc6f3..ecd1ed94ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -425,7 +425,9 @@ package object dsl {
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
- rightAttr: Seq[Attribute]
+ rightAttr: Seq[Attribute],
+ leftOrder: Seq[SortOrder] = Nil,
+ rightOrder: Seq[SortOrder] = Nil
): LogicalPlan = {
CoGroup.apply[Key, Left, Right, Result](
func,
@@ -433,6 +435,8 @@ package object dsl {
rightGroup,
leftAttr,
rightAttr,
+ leftOrder,
+ rightOrder,
logicalPlan,
otherPlan)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e5fe07e2d95..b27c650cfb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -389,6 +389,7 @@ object MapGroups {
func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ dataOrder: Seq[SortOrder],
child: LogicalPlan): LogicalPlan = {
val mapped = new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -396,6 +397,7 @@ object MapGroups {
UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
groupingAttributes,
dataAttributes,
+ dataOrder,
CatalystSerde.generateObjAttr[U],
child)
CatalystSerde.serialize[U](mapped)
@@ -405,7 +407,8 @@ object MapGroups {
/**
* Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
* Func is invoked with an object representation of the grouping key an iterator containing the
- * object representation of all the rows with that key.
+ * object representation of all the rows with that key. Given an additional `dataOrder`, data in
+ * the iterator will be sorted accordingly. That sorting does not add computational complexity.
*
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
@@ -416,6 +419,7 @@ case class MapGroups(
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ dataOrder: Seq[SortOrder],
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
override protected def withNewChildInternal(newChild: LogicalPlan): MapGroups =
@@ -649,6 +653,8 @@ object CoGroup {
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
+ leftOrder: Seq[SortOrder],
+ rightOrder: Seq[SortOrder],
left: LogicalPlan,
right: LogicalPlan): LogicalPlan = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
@@ -664,6 +670,8 @@ object CoGroup {
rightGroup,
leftAttr,
rightAttr,
+ leftOrder,
+ rightOrder,
CatalystSerde.generateObjAttr[OUT],
left,
right)
@@ -684,6 +692,8 @@ case class CoGroup(
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
+ leftOrder: Seq[SortOrder],
+ rightOrder: Seq[SortOrder],
outputObjAttr: Attribute,
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode with ObjectProducer {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index add692f57d2..4d2377b9b96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, CreateStruct, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
@@ -145,6 +145,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
f,
groupingAttributes,
dataAttributes,
+ Seq.empty,
logicalPlan))
}
@@ -171,6 +172,83 @@ class KeyValueGroupedDataset[K, V] private[sql](
flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
}
+ /**
+ * (Scala-specific)
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and a sorted iterator that contains all of the elements in the group.
+ * The function can return an iterator containing elements of an arbitrary type which will be
+ * returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+ * to be sorted according to the given sort expressions. That sorting does not add
+ * computational complexity.
+ *
+ * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+ * @since 3.4.0
+ */
+ def flatMapSortedGroups[U : Encoder](
+ sortExprs: Column*)(
+ f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+ col.expr match {
+ case expr: SortOrder => expr
+ case expr: Expression => SortOrder(expr, Ascending)
+ }
+ }
+
+ Dataset[U](
+ sparkSession,
+ MapGroups(
+ f,
+ groupingAttributes,
+ dataAttributes,
+ sortOrder,
+ logicalPlan
+ )
+ )
+ }
+
+ /**
+ * (Java-specific)
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and a sorted iterator that contains all of the elements in the group.
+ * The function can return an iterator containing elements of an arbitrary type which will be
+ * returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator
+ * to be sorted according to the given sort expressions. That sorting does not add
+ * computational complexity.
+ *
+ * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+ * @since 3.4.0
+ */
+ def flatMapSortedGroups[U](
+ SortExprs: Array[Column],
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = {
+ flatMapSortedGroups(SortExprs: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder)
+ }
+
/**
* (Scala-specific)
* Applies the given function to each group of data. For each unique group, the function will
@@ -753,6 +831,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
other.groupingAttributes,
this.dataAttributes,
other.dataAttributes,
+ Seq.empty,
+ Seq.empty,
this.logicalPlan,
other.logicalPlan))
}
@@ -773,6 +853,72 @@ class KeyValueGroupedDataset[K, V] private[sql](
cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
}
+ /**
+ * (Scala-specific)
+ * Applies the given function to each sorted cogrouped data. For each unique group, the function
+ * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+ * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+ * of an arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
+ * to be sorted according to the given sort expressions. That sorting does not add
+ * computational complexity.
+ *
+ * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+ * @since 3.4.0
+ */
+ def cogroupSorted[U, R : Encoder](
+ other: KeyValueGroupedDataset[K, U])(
+ thisSortExprs: Column*)(
+ otherSortExprs: Column*)(
+ f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+ def toSortOrder(col: Column): SortOrder = col.expr match {
+ case expr: SortOrder => expr
+ case expr: Expression => SortOrder(expr, Ascending)
+ }
+
+ val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder)
+ val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder)
+
+ implicit val uEncoder = other.vExprEnc
+ Dataset[R](
+ sparkSession,
+ CoGroup(
+ f,
+ this.groupingAttributes,
+ other.groupingAttributes,
+ this.dataAttributes,
+ other.dataAttributes,
+ thisSortOrder,
+ otherSortOrder,
+ this.logicalPlan,
+ other.logicalPlan))
+ }
+
+ /**
+ * (Java-specific)
+ * Applies the given function to each sorted cogrouped data. For each unique group, the function
+ * will be passed the grouping key and 2 sorted iterators containing all elements in the group
+ * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+ * of an arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators
+ * to be sorted according to the given sort expressions. That sorting does not add
+ * computational complexity.
+ *
+ * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+ * @since 3.4.0
+ */
+ def cogroupSorted[U, R](
+ other: KeyValueGroupedDataset[K, U],
+ thisSortExprs: Array[Column],
+ otherSortExprs: Array[Column],
+ f: CoGroupFunction[K, V, U, R],
+ encoder: Encoder[R]): Dataset[R] = {
+ cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+ (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
+ }
+
override def toString: String = {
val builder = new StringBuilder
val kFields = kExprEnc.schema.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 110fe45cc12..cd4485e3822 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -806,8 +806,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
- case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
- execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+ case logical.MapGroups(f, key, value, grouping, data, order, objAttr, child) =>
+ execution.MapGroupsExec(
+ f, key, value, grouping, data, order, objAttr, planLater(child)
+ ) :: Nil
case logical.FlatMapGroupsWithState(
f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
@@ -821,9 +823,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new UnsupportedOperationException(
"applyInPandasWithState is unsupported in batch query. Use
applyInPandas instead.")
- case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr,
oAttr, left, right) =>
+ case logical.CoGroup(
+ f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder,
oAttr, left, right) =>
execution.CoGroupExec(
- f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
+ f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder,
oAttr,
planLater(left), planLater(right)) :: Nil
case r @ logical.Repartition(numPartitions, shuffle, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 869d3fe9790..bda592ff929 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -391,7 +391,8 @@ case class AppendColumnsWithObjectExec(
/**
* Groups the input rows together and calls the function with each group and an iterator containing
- * all elements in the group. The result of this function is flattened before being output.
+ * all elements in the group. The iterator is sorted according to `dataOrder` if given.
+ * The result of this function is flattened before being output.
*/
case class MapGroupsExec(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
@@ -399,6 +400,7 @@ case class MapGroupsExec(
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ dataOrder: Seq[SortOrder],
outputObjAttr: Attribute,
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
@@ -408,7 +410,7 @@ case class MapGroupsExec(
ClusteredDistribution(groupingAttributes) :: Nil
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)) ++ dataOrder)
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
@@ -438,6 +440,7 @@ object MapGroupsExec {
valueDeserializer: Expression,
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ dataOrder: Seq[SortOrder],
outputObjAttr: Attribute,
timeoutConf: GroupStateTimeout,
child: SparkPlan): MapGroupsExec = {
@@ -449,7 +452,7 @@ object MapGroupsExec {
func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent))
}
new MapGroupsExec(f, keyDeserializer, valueDeserializer,
- groupingAttributes, dataAttributes, outputObjAttr, child)
+ groupingAttributes, dataAttributes, dataOrder, outputObjAttr, child)
}
}
@@ -623,6 +626,8 @@ case class CoGroupExec(
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
+ leftOrder: Seq[SortOrder],
+ rightOrder: Seq[SortOrder],
outputObjAttr: Attribute,
left: SparkPlan,
right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
@@ -631,7 +636,9 @@ case class CoGroupExec(
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+ (leftGroup.map(SortOrder(_, Ascending)) ++ leftOrder) ::
+ (rightGroup.map(SortOrder(_, Ascending)) ++ rightOrder) ::
+ Nil
override protected def doExecute(): RDD[InternalRow] = {
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 138029e76c1..760681e81c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -510,12 +510,12 @@ object FlatMapGroupsWithStateExec {
}
CoGroupExec(
func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
- initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
- child, initialState)
+ initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, Seq.empty, Seq.empty,
+ outputObjAttr, child, initialState)
} else {
MapGroupsExec(
userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
- dataAttributes, outputObjAttr, timeoutConf, child)
+ dataAttributes, Seq.empty, outputObjAttr, timeoutConf, child)
}
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 120c95aa866..228b7855142 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -302,7 +302,7 @@ public class JavaDatasetSuite implements Serializable {
}
@Test
- public void testGroupBy() {
+ public void testGroupByKey() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
KeyValueGroupedDataset<Integer, String> grouped =
@@ -327,9 +327,21 @@ public class JavaDatasetSuite implements Serializable {
}
return Collections.singletonList(sb.toString()).iterator();
},
- Encoders.STRING());
+ Encoders.STRING());
Assert.assertEquals(asSet("1a", "3foobar"),
toSet(flatMapped.collectAsList()));
+ Dataset<String> flatMapSorted = grouped.flatMapSortedGroups(
+ new Column[] { ds.col("value") },
+ (FlatMapGroupsFunction<Integer, String, String>) (key, values) -> {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ },
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a", "3barfoo"),
toSet(flatMapSorted.collectAsList()));
Dataset<String> mapped2 = grouped.mapGroupsWithState(
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key,
values, s) -> {
@@ -352,10 +364,10 @@ public class JavaDatasetSuite implements Serializable {
}
return Collections.singletonList(sb.toString()).iterator();
},
- OutputMode.Append(),
- Encoders.LONG(),
- Encoders.STRING(),
- GroupStateTimeout.NoTimeout());
+ OutputMode.Append(),
+ Encoders.LONG(),
+ Encoders.STRING(),
+ GroupStateTimeout.NoTimeout());
Assert.assertEquals(asSet("1a", "3foobar"),
toSet(flatMapped2.collectAsList()));
@@ -366,7 +378,7 @@ public class JavaDatasetSuite implements Serializable {
asSet(tuple2(1, "a"), tuple2(3, "foobar")),
toSet(reduced.collectAsList()));
- List<Integer> data2 = Arrays.asList(2, 6, 10);
+ List<Integer> data2 = Arrays.asList(2, 6, 7, 10);
Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT());
KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(
(MapFunction<Integer, Integer>) v -> v / 2,
@@ -387,7 +399,26 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"),
toSet(cogrouped.collectAsList()));
+ Assert.assertEquals(asSet("1a#2", "3foobar#67", "5#10"),
toSet(cogrouped.collectAsList()));
+
+ Dataset<String> cogroupSorted = grouped.cogroupSorted(
+ grouped2,
+ new Column[] { ds.col("value") },
+ new Column[] { ds2.col("value").desc() },
+ (CoGroupFunction<Integer, String, Integer, String>) (key, left, right)
-> {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (left.hasNext()) {
+ sb.append(left.next());
+ }
+ sb.append("#");
+ while (right.hasNext()) {
+ sb.append(right.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ },
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a#2", "3barfoo#76", "5#10"),
toSet(cogroupSorted.collectAsList()));
}
@Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d298d7129c7..8b48d7e7827 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -573,6 +573,38 @@ class DatasetSuite extends QueryTest
"a", "30", "b", "3", "c", "1")
}
+ test("groupBy function, flatMapSorted") {
+ val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c",
1, 1))
+ .toDF("key", "seq", "value")
+ val grouped = ds.groupByKey(v => (v.getString(0), "word"))
+ val aggregated = grouped.flatMapSortedGroups($"seq", expr("length(key)")) {
+ (g, iter) => Iterator(g._1, iter.mkString(", "))
+ }
+
+ checkDatasetUnorderly(
+ aggregated,
+ "a", "[a,1,10], [a,2,20]",
+ "b", "[b,1,2], [b,2,1]",
+ "c", "[c,1,1]"
+ )
+ }
+
+ test("groupBy function, flatMapSorted desc") {
+ val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c",
1, 1))
+ .toDF("key", "seq", "value")
+ val grouped = ds.groupByKey(v => (v.getString(0), "word"))
+ val aggregated = grouped.flatMapSortedGroups($"seq".desc, expr("length(key)")) {
+ (g, iter) => Iterator(g._1, iter.mkString(", "))
+ }
+
+ checkDatasetUnorderly(
+ aggregated,
+ "a", "[a,2,20], [a,1,10]",
+ "b", "[b,2,1], [b,1,2]",
+ "c", "[c,1,1]"
+ )
+ }
+
test("groupBy function, mapValues, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val keyValue = ds.groupByKey(_._1).mapValues(_._2)
@@ -727,6 +759,42 @@ class DatasetSuite extends QueryTest
1 -> "a", 2 -> "bc", 3 -> "d")
}
+ test("cogroup sorted") {
+ val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 ->
"ijk").toDS()
+ val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 ->
"y").toDS()
+ val groupedLeft = left.groupByKey(_._1)
+ val groupedRight = right.groupByKey(_._1)
+
+ val neitherSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#wa", 5 -> "hello#xzy")
+ val leftSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#wa", 5 -> "hello#xzy")
+ val rightSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzabcijk#aw", 5 -> "hello#xyz")
+ val bothSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "abcijkxyz#aw", 5 -> "hello#xyz")
+ val bothDescSortedExpected = Seq(1 -> "a#", 2 -> "#q", 3 -> "xyzijkabc#wa", 5 -> "hello#zyx")
+
+ val leftOrder = Seq(left("_2"))
+ val rightOrder = Seq(right("_2"))
+ val leftDescOrder = Seq(left("_2").desc)
+ val rightDescOrder = Seq(right("_2").desc)
+ val none = Seq.empty
+
+ Seq(
+ ("neither", none, none, neitherSortedExpected),
+ ("left", leftOrder, none, leftSortedExpected),
+ ("right", none, rightOrder, rightSortedExpected),
+ ("both", leftOrder, rightOrder, bothSortedExpected),
+ ("both desc", leftDescOrder, rightDescOrder, bothDescSortedExpected)
+ ).foreach { case (label, leftOrder, rightOrder, expected) =>
+ withClue(s"$label sorted") {
+ val cogrouped = groupedLeft.cogroupSorted(groupedRight)(leftOrder: _*)(rightOrder: _*) {
+ (key, left, right) =>
+ Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
+ }
+
+ checkDatasetUnorderly(cogrouped, expected.toList: _*)
+ }
+ }
+ }
+
test("SPARK-34806: observation on datasets") {
val namedObservation = Observation("named")
val unnamedObservation = Observation()