Repository: spark Updated Branches: refs/heads/master d4d762f27 -> 408e64b28
[SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation fix some comments and code style for https://github.com/apache/spark/pull/7458 Author: Wenchen Fan <[email protected]> Closes #7619 from cloud-fan/agg-clean and squashes the following commits: 3925457 [Wenchen Fan] one more... cc78357 [Wenchen Fan] one more cleanup 26f6a93 [Wenchen Fan] some minor cleanup for the new aggregation Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/408e64b2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/408e64b2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/408e64b2 Branch: refs/heads/master Commit: 408e64b284ef8bd6796d815b5eb603312d090b74 Parents: d4d762f Author: Wenchen Fan <[email protected]> Authored: Thu Jul 23 23:40:01 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Thu Jul 23 23:40:01 2015 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../expressions/aggregate/interfaces.scala | 18 ++--- .../apache/spark/sql/execution/Exchange.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../aggregate/sortBasedIterators.scala | 82 ++++++-------------- .../spark/sql/execution/aggregate/utils.scala | 10 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++- 7 files changed, 46 insertions(+), 89 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cadbc5..e916887 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -533,7 +533,7 @@ class Analyzer( case min: Min if isDistinct => min // For other aggregate functions, DISTINCT keyword is not supported for now. // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other if isDistinct => + case other: AggregateExpression1 if isDistinct => failAnalysis(s"$name does not support DISTINCT keyword.") // If it does not have DISTINCT keyword, we will return it as is. case other => other http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d3fee1a..10bd19c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCod import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction1]]. */ +/** The mode of an [[AggregateFunction2]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -42,8 +42,8 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers - * containing intermediate results for this function and the generate final result. + * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ @@ -85,12 +85,12 @@ private[sql] case class AggregateExpression2( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferemces = mode match { + val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq case PartialMerge | Final => aggregateFunction.bufferAttributes } - AttributeSet(childReferemces) + AttributeSet(childReferences) } override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" @@ -99,10 +99,8 @@ private[sql] case class AggregateExpression2( abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { - self: Product => - /** An aggregate function is not foldable. */ - override def foldable: Boolean = false + final override def foldable: Boolean = false /** * The offset of this function's buffer in the underlying buffer shared with other functions. http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d31e265..41a0c51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -224,13 +224,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // compatible. // TODO: ASSUMES TRANSITIVITY? def compatible: Boolean = - !operator.children + operator.children .map(_.outputPartitioning) .sliding(2) - .map { + .forall { case Seq(a) => true case Seq(a, b) => a.compatibleWith(b) - }.exists(!_) + } // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f54aa20..eb4be19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -190,12 +190,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sqlContext.conf.codegenEnabled).isDefined } - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { - case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => false - case _ => true + Seq(IntegerType, LongType).contains(exprs.head.dataType) => true + case _ => false } def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index ce1cbdc..b8e95a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -67,13 +67,6 @@ private[sql] abstract class SortAggregationIterator( functions } - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { - aggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } - // Positions of those non-algebraic aggregate functions in aggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are non-algebraic aggregate functions. @@ -91,6 +84,10 @@ private[sql] abstract class SortAggregationIterator( positions.toArray } + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) + // This is used to project expressions for the grouping expressions. protected val groupGenerator = newMutableProjection(groupingExpressions, inputAttributes)() @@ -179,8 +176,6 @@ private[sql] abstract class SortAggregationIterator( // For the below compare method, we do not need to make a copy of groupingKey. val groupingKey = groupGenerator(currentRow) // Check if the current row belongs the current input row. - currentGroupingKey.equals(groupingKey) - if (currentGroupingKey == groupingKey) { processRow(currentRow) } else { @@ -288,10 +283,7 @@ class PartialSortAggregationIterator( // This projection is used to update buffer values for all AlgebraicAggregates. private val algebraicUpdateProjection = { - val bufferSchema = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -348,19 +340,14 @@ class PartialMergeSortAggregationIterator( inputAttributes, inputIter) { - private val placeholderAttribtues = + private val placeholderAttributes = Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { val bufferSchemata = - placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -444,13 +431,8 @@ class FinalSortAggregationIterator( // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -462,13 +444,8 @@ class FinalSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -599,11 +576,10 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Final. - private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = finalAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // All aggregate functions with mode Complete. private val completeAggregateFunctions: Array[AggregateFunction2] = { @@ -617,11 +593,10 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Complete. - private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = completeAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // This projection is used to merge buffer values for all AlgebraicAggregates with mode // Final. @@ -633,13 +608,9 @@ class FinalAndCompleteSortAggregationIterator( val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) val bufferSchemata = - offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } ++ completeOffsetAttributes + offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeOffsetAttributes ++ offsetAttributes ++ + finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes val mergeExpressions = placeholderExpressions ++ finalAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions @@ -658,10 +629,8 @@ class FinalAndCompleteSortAggregationIterator( val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) val bufferSchema = - offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + offsetAttributes ++ finalOffsetAttributes ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions @@ -673,13 +642,8 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1cb2771..5bbe6c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -191,10 +191,7 @@ object Utils { } val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Partial, isDistinct) - } + val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } @@ -208,10 +205,7 @@ object Utils { child) // 2. Create an Aggregate Operator for final aggregations. - val finalAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Final, isDistinct) - } + val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ab8dce6..95a1106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1518,18 +1518,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-8945: add and subtract expressions for interval type") { import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) checkAnswer(df.select(df("i") + new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) checkAnswer(df.select(df("i") - new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) // unary minus checkAnswer(df.select(-df("i")), - Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
