Repository: spark Updated Branches: refs/heads/master 2082a4956 -> dcaa01661
[SPARK-13897][SQL] RelationalGroupedDataset and KeyValueGroupedDataset ## What changes were proposed in this pull request? Previously, Dataset.groupBy returned a GroupedData, and Dataset.groupByKey returned a GroupedDataset. The naming is very similar, and unfortunately does not convey the real differences between the two. Assume we are grouping by some keys (K). groupByKey is a key-value style group by, in which the schema of the returned dataset is a tuple of just two fields: key and value. groupBy, on the other hand, is a relational style group by, in which the schema of the returned dataset is flattened and contains |K| + |V| fields. This pull request renames GroupedData to RelationalGroupedDataset and GroupedDataset to KeyValueGroupedDataset to make that distinction explicit. This pull request also removes the experimental tag from RelationalGroupedDataset. It has been with DataFrame since 1.3, and we have enough confidence now to stabilize it. ## How was this patch tested? This is a rename to improve API understandability. It should be covered by all existing tests. Author: Reynold Xin <[email protected]> Closes #11841 from rxin/SPARK-13897. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dcaa0166 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dcaa0166 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dcaa0166 Branch: refs/heads/master Commit: dcaa016610ac2c11d7dd01803f3515b02ab32e64 Parents: 2082a49 Author: Reynold Xin <[email protected]> Authored: Sat Mar 19 11:23:14 2016 -0700 Committer: Reynold Xin <[email protected]> Committed: Sat Mar 19 11:23:14 2016 -0700 ---------------------------------------------------------------------- project/MimaExcludes.scala | 1 + .../scala/org/apache/spark/sql/Dataset.scala | 56 +-- .../org/apache/spark/sql/GroupedData.scala | 417 ------------------- .../org/apache/spark/sql/GroupedDataset.scala | 337 --------------- .../spark/sql/KeyValueGroupedDataset.scala | 336 +++++++++++++++ .../spark/sql/RelationalGroupedDataset.scala | 414 ++++++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 8 +- 7 files changed, 786 insertions(+), 783 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b38eec3..9a091bf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -315,6 +315,7 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 39f7f35..6e7d208 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1036,7 +1036,7 @@ class Dataset[T] private[sql]( /** * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. See - * [[GroupedData]] for all the available aggregate functions. + * [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns grouped by department. @@ -1053,14 +1053,14 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType) + def groupBy(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) } /** * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns rolluped by department and group. 
@@ -1077,14 +1077,14 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType) + def rollup(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType) } /** * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns cubed by department and group. @@ -1101,11 +1101,13 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType) + def cube(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType) + } /** * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of groupBy that can only group by existing columns using column names * (i.e. cannot construct expressions). 
@@ -1124,9 +1126,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedData = { + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType) } /** @@ -1156,18 +1159,18 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. * * @group typedrel * @since 2.0.0 */ @Experimental - def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = { + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) - new GroupedDataset( + new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, @@ -1177,14 +1180,15 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]] + * expressions. 
* * @group typedrel * @since 2.0.0 */ @Experimental @scala.annotation.varargs - def groupByKey(cols: Column*): GroupedDataset[Row, T] = { + def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = { val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = sqlContext.executePlan(withKey) @@ -1192,7 +1196,7 @@ class Dataset[T] private[sql]( val dataAttributes = executed.analyzed.output.dropRight(cols.size) val keyAttributes = executed.analyzed.output.takeRight(cols.size) - new GroupedDataset( + new KeyValueGroupedDataset( RowEncoder(keyAttributes.toStructType), encoderFor[T], executed, @@ -1203,19 +1207,19 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. * * @group typedrel * @since 2.0.0 */ @Experimental - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) /** * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of rollup that can only group by existing columns using column names * (i.e. cannot construct expressions). 
@@ -1235,15 +1239,16 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def rollup(col1: String, cols: String*): GroupedData = { + def rollup(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType) } /** * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of cube that can only group by existing columns using column names * (i.e. cannot construct expressions). @@ -1262,9 +1267,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def cube(col1: String, cols: String*): GroupedData = { + def cube(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala deleted file mode 100644 index 04d277b..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ /dev/null @@ -1,417 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} -import org.apache.spark.sql.catalyst.util.usePrettyExpression -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType - -/** - * :: Experimental :: - * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. - * - * The main method is the agg function, which has multiple variants. This class also contains - * convenience some first order statistics such as mean, sum for convenience. 
- * - * @since 1.3.0 - */ -@Experimental -class GroupedData protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupedData.GroupType) { - - private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { - val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - groupingExprs ++ aggExprs - } else { - aggExprs - } - - val aliasedAgg = aggregates.map(alias) - - groupType match { - case GroupedData.GroupByType => - Dataset.newDataFrame( - df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) - case GroupedData.RollupType => - Dataset.newDataFrame( - df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.CubeType => - Dataset.newDataFrame( - df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.PivotType(pivotCol, values) => - val aliasedGrps = groupingExprs.map(alias) - Dataset.newDataFrame( - df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) - } - } - - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - private[this] def alias(expr: Expression): NamedExpression = expr match { - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() - } - - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) - : DataFrame = { - - val columnExprs = if (colNames.isEmpty) { - // No columns specified. Use all numeric columns. - df.numericColumns - } else { - // Make sure all specified columns are numeric. - colNames.map { colName => - val namedExpr = df.resolve(colName) - if (!namedExpr.dataType.isInstanceOf[NumericType]) { - throw new AnalysisException( - s""""$colName" is not a numeric column. 
""" + - "Aggregation function can only be applied on a numeric column.") - } - namedExpr - } - } - toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) - } - - private[this] def strToExpr(expr: String): (Expression => Expression) = { - val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase match { - // We special handle a few cases that have alias that are not in function registry. - case "avg" | "average" | "mean" => - UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) - case "stddev" | "std" => - UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) - // Also special handle count because we need to take care count(*). - case "count" | "size" => - // Turn count(*) into count(1) - inputExpr match { - case s: Star => Count(Literal(1)).toAggregateExpression() - case _ => Count(inputExpr).toAggregateExpression() - } - case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) - } - } - (inputExpr: Expression) => exprToFunc(inputExpr) - } - - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 1.3.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - agg((aggExpr +: aggExprs).toMap) - } - - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
- * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToExpr(expr)(df(colName).expr) - }.toSeq) - } - - /** - * (Java-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } - - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. 
- * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) - } - - /** - * Count the number of rows for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * - * @since 1.3.0 - */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) - - /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) - } - - /** - * Compute the max value for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the max values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Max) - } - - /** - * Compute the mean value for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the mean values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) - } - - /** - * Compute the min value for each numeric column for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the min values for them. 
- * - * @since 1.3.0 - */ - @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Min) - } - - /** - * Compute the sum for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Sum) - } - - /** - * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @param pivotColumn Name of the column to pivot. - * @since 1.6.0 - */ - def pivot(pivotColumn: String): GroupedData = { - // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) - // Get the distinct values of the column and sort them so its consistent - val values = df.select(pivotColumn) - .distinct() - .sort(pivotColumn) // ensure that the output columns are in a consistent logical order - .rdd - .map(_.get(0)) - .take(maxValues + 1) - .toSeq - - if (values.length > maxValues) { - throw new AnalysisException( - s"The pivot column $pivotColumn has more than $maxValues distinct values, " + - "this could indicate an error. 
" + - s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + - "to at least the number of distinct values of the pivot column.") - } - - pivot(pivotColumn, values) - } - - /** - * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 - */ - def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { - groupType match { - case GroupedData.GroupByType => - new GroupedData( - df, - groupingExprs, - GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) - case _: GroupedData.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") - } - } - - /** - * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. 
- * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 - */ - def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { - pivot(pivotColumn, values.asScala) - } -} - - -/** - * Companion object for GroupedData. - */ -private[sql] object GroupedData { - - def apply( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) - } - - /** - * The Grouping Type - */ - private[sql] trait GroupType - - /** - * To indicate it's the GroupBy - */ - private[sql] object GroupByType extends GroupType - - /** - * To indicate it's the CUBE - */ - private[sql] object CubeType extends GroupType - - /** - * To indicate it's the ROLLUP - */ - private[sql] object RollupType extends GroupType - - /** - * To indicate it's the PIVOT - */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType -} http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala deleted file mode 100644 index a8700de..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution - -/** - * :: Experimental :: - * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not - * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing - * [[Dataset]]. - * - * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, - * making this change to the class hierarchy would break some function signatures. As such, this - * class should be considered a preview of the final API. Changes will be made to the interface - * after Spark 1.6. 
- * - * @since 1.6.0 - */ -@Experimental -class GroupedDataset[K, V] private[sql]( - kEncoder: Encoder[K], - vEncoder: Encoder[V], - val queryExecution: QueryExecution, - private val dataAttributes: Seq[Attribute], - private val groupingAttributes: Seq[Attribute]) extends Serializable { - - // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders - // when constructing new logical plans that will operate on the output of the current - // queryexecution. - - private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedVEncoder = encoderFor(vEncoder) - - private val resolvedKEncoder = - unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedVEncoder = - unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) - - private def logicalPlan = queryExecution.analyzed - private def sqlContext = queryExecution.sqlContext - - private def groupedData = - new GroupedData( - Dataset.newDataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) - - /** - * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified - * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. - * - * @since 1.6.0 - */ - def keyAs[L : Encoder]: GroupedDataset[L, V] = - new GroupedDataset( - encoderFor[L], - unresolvedVEncoder, - queryExecution, - dataAttributes, - groupingAttributes) - - /** - * Returns a [[Dataset]] that contains each unique key. - * - * @since 1.6.0 - */ - def keys: Dataset[K] = { - Dataset[K]( - sqlContext, - Distinct( - Project(groupingAttributes, logicalPlan))) - } - - /** - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. 
The - * function can return an iterator containing elements of an arbitrary type which will be returned - * as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { - Dataset[U]( - sqlContext, - MapGroups( - f, - groupingAttributes, - dataAttributes, - logicalPlan)) - } - - /** - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an iterator containing elements of an arbitrary type which will be returned - * as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. 
- * - * @since 1.6.0 - */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) - } - - /** - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroups(func) - } - - /** - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. 
- * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @since 1.6.0 - */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava))(encoder) - } - - /** - * Reduces the elements of each group of data using the specified binary function. - * The given function must be commutative and associative or the result may be non-deterministic. - * - * @since 1.6.0 - */ - def reduce(f: (V, V) => V): Dataset[(K, V)] = { - val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - - implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) - flatMapGroups(func) - } - - /** - * Reduces the elements of each group of data using the specified binary function. - * The given function must be commutative and associative or the result may be non-deterministic. - * - * @since 1.6.0 - */ - def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduce(f.call _) - } - - // This is here to prevent us from adding overloads that would be ambiguous. - @scala.annotation.varargs - private def agg(exprs: Column*): DataFrame = - groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) - - private def withEncoder(c: Column): Column = c match { - case tc: TypedColumn[_, _] => - tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) - case _ => c - } - - /** - * Internal helper function for building typed aggregations that return tuples. For simplicity - * and code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. 
- * TODO: does not handle aggregations that return nonflat results, - */ - protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) - val namedColumns = - columns.map( - _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (resolvedKEncoder.flat) { - assert(groupingAttributes.length == 1) - groupingAttributes.head - } else { - Alias(CreateStruct(groupingAttributes), "key")() - } - val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) - val execution = new QueryExecution(sqlContext, aggregate) - - new Dataset( - sqlContext, - execution, - ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) - } - - /** - * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key - * and the result of computing this aggregation over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. - * - * @since 1.6.0 - */ - def agg[U1, U2, U3]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key - * and the result of computing these aggregations over all elements in the group. 
- * - * @since 1.6.0 - */ - def agg[U1, U2, U3, U4]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] - - /** - * Returns a [[Dataset]] that contains a tuple with each key and the number of items present - * for that key. - * - * @since 1.6.0 - */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) - - /** - * Applies the given function to each cogrouped data. For each unique group, the function will - * be passed the grouping key and 2 iterators containing all elements in the group from - * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an - * arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 1.6.0 - */ - def cogroup[U, R : Encoder]( - other: GroupedDataset[K, U])( - f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.unresolvedVEncoder - Dataset[R]( - sqlContext, - CoGroup( - f, - this.groupingAttributes, - other.groupingAttributes, - this.dataAttributes, - other.dataAttributes, - this.logicalPlan, - other.logicalPlan)) - } - - /** - * Applies the given function to each cogrouped data. For each unique group, the function will - * be passed the grouping key and 2 iterators containing all elements in the group from - * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an - * arbitrary type which will be returned as a new [[Dataset]]. 
- * - * @since 1.6.0 - */ - def cogroup[U, R]( - other: GroupedDataset[K, U], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala new file mode 100644 index 0000000..f0f9682 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * :: Experimental :: + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an existing + * [[Dataset]]. + * + * @since 2.0.0 + */ +@Experimental +class KeyValueGroupedDataset[K, V] private[sql]( + kEncoder: Encoder[K], + vEncoder: Encoder[V], + val queryExecution: QueryExecution, + private val dataAttributes: Seq[Attribute], + private val groupingAttributes: Seq[Attribute]) extends Serializable { + + // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders + // when constructing new logical plans that will operate on the output of the current + // queryexecution. + + private implicit val unresolvedKEncoder = encoderFor(kEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) + + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + + private def logicalPlan = queryExecution.analyzed + private def sqlContext = queryExecution.sqlContext + + private def groupedData = { + new RelationalGroupedDataset( + Dataset.newDataFrame(sqlContext, logicalPlan), + groupingAttributes, + RelationalGroupedDataset.GroupByType) + } + + /** + * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the + * specified type.
The mapping of key columns to the type follows the same rules as `as` on + * [[Dataset]]. + * + * @since 1.6.0 + */ + def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = + new KeyValueGroupedDataset( + encoderFor[L], + unresolvedVEncoder, + queryExecution, + dataAttributes, + groupingAttributes) + + /** + * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 + */ + def keys: Dataset[K] = { + Dataset[K]( + sqlContext, + Distinct( + Project(groupingAttributes, logicalPlan))) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + Dataset[U]( + sqlContext, + MapGroups( + f, + groupingAttributes, + dataAttributes, + logicalPlan)) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. 
The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. 
+ * + * @since 1.6.0 + */ + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroups(func) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) + + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. 
+ * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduce(f.call _) + } + + // This is here to prevent us from adding overloads that would be ambiguous. + @scala.annotation.varargs + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) + + private def withEncoder(c: Column): Column = c match { + case tc: TypedColumn[_, _] => + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) + case _ => c + } + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + * TODO: does not handle aggregations that return nonflat results, + */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map( + _.withInputType(resolvedVEncoder, dataAttributes).named) + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) + groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) + val execution = new QueryExecution(sqlContext, aggregate) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key + * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. 
+ * + * @since 1.6.0 + */ + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present + * for that key. + * + * @since 1.6.0 + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]())) + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. 
+ * + * @since 1.6.0 + */ + def cogroup[U, R : Encoder]( + other: KeyValueGroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + implicit val uEncoder = other.unresolvedVEncoder + Dataset[R]( + sqlContext, + CoGroup( + f, + this.groupingAttributes, + other.groupingAttributes, + this.dataAttributes, + other.dataAttributes, + this.logicalPlan, + other.logicalPlan)) + } + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R]( + other: KeyValueGroupedDataset[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala new file mode 100644 index 0000000..521032a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.NumericType + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. + * + * The main method is the agg function, which has multiple variants. This class also contains + * some first-order statistics such as mean and sum for convenience.
+ * + * @since 2.0.0 + */ +class RelationalGroupedDataset protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: RelationalGroupedDataset.GroupType) { + + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } + + val aliasedAgg = aggregates.map(alias) + + groupType match { + case RelationalGroupedDataset.GroupByType => + Dataset.newDataFrame( + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.RollupType => + Dataset.newDataFrame( + df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.CubeType => + Dataset.newDataFrame( + df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + Dataset.newDataFrame( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + } + } + + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + } + + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) + : DataFrame = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric. 
+ colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be applied on a numeric column.") + } + namedExpr + } + } + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). + case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } + } + (inputExpr: Expression) => exprToFunc(inputExpr) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + * + * @since 1.3.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + agg((aggExpr +: aggExprs).toMap) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. 
The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: Map[String, String]): DataFrame = { + toDF(exprs.map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }.toSeq) + } + + /** + * (Java-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.asScala.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Note that this function by + * default retains the grouping columns in its output. To not retain grouping columns, set + * `spark.sql.retainGroupColumns` to false. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: + * import org.apache.spark.sql.functions._ + * df.groupBy("department").agg(max("age"), sum("expense")) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.groupBy("department").agg(max("age"), sum("expense")); + * }}} + * + * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. 
To change + * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. + * {{{ + * // Scala, 1.3.x: + * df.groupBy("department").agg($"department", max("age"), sum("expense")) + * + * // Java, 1.3.x: + * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); + * }}} + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = { + toDF((expr +: exprs).map(_.expr)) + } + + /** + * Count the number of rows for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * + * @since 1.3.0 + */ + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) + + /** + * Compute the average value for each numeric column for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Average) + } + + /** + * Compute the max value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Max) + } + + /** + * Compute the mean value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Average) + } + + /** + * Compute the min value for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Min) + } + + /** + * Compute the sum for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Sum) + } + + /** + * Pivots a column of the current [[DataFrame]] and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot.
+ * @since 1.6.0 + */ + def pivot(pivotColumn: String): RelationalGroupedDataset = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so it's consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order + .rdd + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) + } + + /** + * Pivots a column of the current [[DataFrame]] and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. 
+ * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + groupType match { + case RelationalGroupedDataset.GroupByType => + new RelationalGroupedDataset( + df, + groupingExprs, + RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: RelationalGroupedDataset.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(pivotColumn, values.asScala) + } +} + + +/** + * Companion object for RelationalGroupedDataset. 
+ */ +private[sql] object RelationalGroupedDataset { + + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): RelationalGroupedDataset = { + new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) + } + + /** + * The Grouping Type + */ + private[sql] trait GroupType + + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType + + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType +} http://git-wip-us.apache.org/repos/asf/spark/blob/dcaa0166/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 79b6e61..4b8b0d9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { + KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); - 
GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { + KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { @Override public Integer call(Integer v) throws Exception { return v / 2; @@ -249,7 +249,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupByColumn() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = + KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); Dataset<String> mapped = grouped.mapGroups( @@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable { Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); - GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( + KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { @Override public String call(Tuple2<String, Integer> value) throws Exception { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
