Repository: spark
Updated Branches:
  refs/heads/master b84d4b4df -> f04b5672c
[SPARK-7289] handle project -> limit -> sort efficiently

make the `TakeOrdered` strategy and operator more general, such that it can
optionally handle a projection when necessary

Author: Wenchen Fan <[email protected]>

Closes #6780 from cloud-fan/limit and squashes the following commits:

34aa07b [Wenchen Fan] revert
07d5456 [Wenchen Fan] clean closure
20821ec [Wenchen Fan] fix
3676a82 [Wenchen Fan] address comments
b558549 [Wenchen Fan] address comments
214842b [Wenchen Fan] fix style
2d8be83 [Wenchen Fan] add LimitPushDown
948f740 [Wenchen Fan] fix existing


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f04b5672
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f04b5672
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f04b5672

Branch: refs/heads/master
Commit: f04b5672c5a5562f8494df3b0df23235285c9e9e
Parents: b84d4b4
Author: Wenchen Fan <[email protected]>
Authored: Wed Jun 24 13:28:50 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Wed Jun 24 13:28:50 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 52 ++++++++++----------
 .../catalyst/optimizer/UnionPushdownSuite.scala |  4 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |  2 +-
 .../apache/spark/sql/execution/SparkPlan.scala  |  1 -
 .../spark/sql/execution/SparkStrategies.scala   |  8 ++-
 .../spark/sql/execution/basicOperators.scala    | 27 +++++++---
 .../spark/sql/execution/PlannerSuite.scala      |  6 +++
 .../org/apache/spark/sql/hive/HiveContext.scala |  2 +-
 8 files changed, 62 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 98b4476..bfd2428 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer {
     Batch("Distinct", FixedPoint(100),
       ReplaceDistinctWithAggregate) ::
     Batch("Operator Optimizations", FixedPoint(100),
-      UnionPushdown,
-      CombineFilters,
+      // Operator push down
+      UnionPushDown,
+      PushPredicateThroughJoin,
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
       ColumnPruning,
+      // Operator combine
       ProjectCollapsing,
+      CombineFilters,
       CombineLimits,
+      // Constant folding
       NullPropagation,
       OptimizeIn,
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
-      PushPredicateThroughJoin,
       RemovePositive,
       SimplifyFilters,
       SimplifyCasts,
@@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer {
 }
 
 /**
-  * Pushes operations to either side of a Union.
-  */
-object UnionPushdown extends Rule[LogicalPlan] {
+ * Pushes operations to either side of a Union.
+ */
+object UnionPushDown extends Rule[LogicalPlan] {
 
   /**
-    * Maps Attributes from the left side to the corresponding Attribute on the right side.
-    */
-  def buildRewrites(union: Union): AttributeMap[Attribute] = {
+   * Maps Attributes from the left side to the corresponding Attribute on the right side.
+   */
+  private def buildRewrites(union: Union): AttributeMap[Attribute] = {
     assert(union.left.output.size == union.right.output.size)
     AttributeMap(union.left.output.zip(union.right.output))
   }
 
   /**
-    * Rewrites an expression so that it can be pushed to the right side of a Union operator.
-    * This method relies on the fact that the output attributes of a union are always equal
-    * to the left child's output.
-    */
-  def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = {
+   * Rewrites an expression so that it can be pushed to the right side of a Union operator.
+   * This method relies on the fact that the output attributes of a union are always equal
+   * to the left child's output.
+   */
+  private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
     val result = e transform {
       case a: Attribute => rewrites(a)
     }
@@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
   }
 }
 
-
 /**
 * Attempts to eliminate the reading of unneeded columns from the query plan using the following
 * transformations:
@@ -117,7 +119,6 @@
 *  - Aggregate
 *  - Project <- Join
 *  - LeftSemiJoin
- *  - Performing alias substitution.
 */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
+    // Push down project through limit, so that we may have chance to push it further.
     case Project(projectList, Limit(exp, child)) =>
       Limit(exp, Project(projectList, child))
 
-    // push down project if possible when the child is sort
+    // Push down project if possible when the child is sort
     case p @ Project(projectList, s @ Sort(_, _, grandChild))
         if s.references.subsetOf(p.outputSet) =>
       s.copy(child = Project(projectList, grandChild))
@@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
 }
 
 /**
- * Combines two adjacent [[Project]] operators into one, merging the
- * expressions into one single expression.
+ * Combines two adjacent [[Project]] operators into one and perform alias substitution,
+ * merging the expressions into one single expression.
 */
 object ProjectCollapsing extends Rule[LogicalPlan] {
 
@@ -222,10 +224,10 @@
 object LikeSimplification extends Rule[LogicalPlan] {
   // if guards below protect from escapes on trailing %.
   // Cases like "something\%" are not optimized, but this does not affect correctness.
-  val startsWith = "([^_%]+)%".r
-  val endsWith = "%([^_%]+)".r
-  val contains = "%([^_%]+)%".r
-  val equalTo = "([^_%]*)".r
+  private val startsWith = "([^_%]+)%".r
+  private val endsWith = "%([^_%]+)".r
+  private val contains = "%([^_%]+)%".r
+  private val equalTo = "([^_%]*)".r
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case Like(l, Literal(utf, StringType)) =>
@@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
         grandChild))
   }
 
-  def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+  private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
     condition transform {
       case a: AttributeReference => sourceAliases.getOrElse(a, a)
     }
@@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   import Decimal.MAX_LONG_DIGITS
 
   /** Maximum number of decimal digits representable precisely in a Double */
-  val MAX_DOUBLE_DIGITS = 15
+  private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
index 35f50be..ec37948 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 
-class UnionPushdownSuite extends PlanTest {
+class UnionPushDownSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", Once,
         EliminateSubQueries) ::
       Batch("Union Pushdown", Once,
-        UnionPushdown) :: Nil
+        UnionPushDown) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 04fc798..5708df8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       experimental.extraStrategies ++ (
       DataSourceStrategy ::
       DDLStrategy ::
-      TakeOrdered ::
+      TakeOrderedAndProject ::
       HashAggregation ::
       LeftSemiJoin ::
       HashJoin ::


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8d302..47f56b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     log.debug(
       s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
-
       GenerateMutableProjection.generate(expressions, inputSchema)
     } else {
       () => new InterpretedMutableProjection(expressions, inputSchema)


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ff1cc2..21912cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   protected lazy val singleRowRdd =
     sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
 
-  object TakeOrdered extends Strategy {
+  object TakeOrderedAndProject extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+             IntegerLiteral(limit),
+             logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
       case _ => Nil
     }
   }


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 7aedd63..647c4ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
   @transient lazy val buildProjection = newMutableProjection(projectList, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
-    val resuableProjection = buildProjection()
-    iter.map(resuableProjection)
+    val reusableProjection = buildProjection()
+    iter.map(reusableProjection)
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan)
 
 /**
  * :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
 */
 @DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 
@@ -160,8 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
 
-  private def collectData(): Array[InternalRow] =
-    child.execute().map(_.copy()).takeOrdered(limit)(ord)
+  // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+  @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+  private def collectData(): Array[InternalRow] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+    projection.map(data.map(_)).getOrElse(data)
+  }
 
   override def executeCollect(): Array[Row] = {
     val converter = CatalystTypeConverters.createToScalaConverter(schema)


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5854ab4..3dd2413 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
+
+  test("efficient limit -> project -> sort") {
+    val query = testData.sort('key).select('value).limit(2).logicalPlan
+    val planned = planner.TakeOrderedAndProject(query)
+    assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/f04b5672/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cf05c6c..8021f91 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
         HiveCommandStrategy(self),
         HiveDDLStrategy,
         DDLStrategy,
-        TakeOrdered,
+        TakeOrderedAndProject,
         ParquetOperations,
         InMemoryScans,
         ParquetConversion, // Must be before HiveTableScans
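
----------------------------------------------------------------------

Below are a few standalone sketches of the mechanisms this patch touches.
They are toy models under stated assumptions, written for illustration; all
type and object names in them are invented, not code from Spark.

UnionPushDown's two helpers, now private, split the work like this:
buildRewrites pairs the left child's output attributes with the right
child's positionally, and pushToRight swaps every left attribute in an
expression for its right-side counterpart. A sketch with plain Maps, where
the Attribute case class and UnionPushDownDemo are stand-ins:

    object UnionPushDownDemo extends App {
      // Toy attribute: a name plus an expression id, for illustration only.
      case class Attribute(name: String, id: Int)

      // Mirrors buildRewrites: pair left and right outputs by position.
      def buildRewrites(
          leftOutput: Seq[Attribute],
          rightOutput: Seq[Attribute]): Map[Attribute, Attribute] = {
        assert(leftOutput.size == rightOutput.size)
        leftOutput.zip(rightOutput).toMap
      }

      // Mirrors pushToRight for the simplest "expression": a bag of attribute
      // references. Every reference must be rewritable, as in the real rule.
      def pushToRight(
          references: Seq[Attribute],
          rewrites: Map[Attribute, Attribute]): Seq[Attribute] =
        references.map(rewrites)

      val leftOut = Seq(Attribute("a", 1), Attribute("b", 2))
      val rightOut = Seq(Attribute("a", 3), Attribute("b", 4))
      // A condition on the left's `a` becomes one on the right's `a`.
      println(pushToRight(Seq(Attribute("a", 1)), buildRewrites(leftOut, rightOut)))
      // prints: List(Attribute(a,3))
    }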
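
The new ColumnPruning case ("Push down project through limit") rewrites a
Project sitting on top of a Limit into a Limit on top of a Project. Combined
with the existing project-through-sort case, that normalizes a plan into the
limit -> project -> sort shape that the new TakeOrderedAndProject strategy
matches. A sketch of the rewrite with toy plan nodes (Plan, Scan,
PushdownDemo and friends are invented here):

    // Toy logical-plan nodes, for illustration only, not the Catalyst classes.
    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Sort(order: String, child: Plan) extends Plan
    case class Project(columns: Seq[String], child: Plan) extends Plan
    case class Limit(n: Int, child: Plan) extends Plan

    object PushdownDemo extends App {
      // Mirrors: case Project(projectList, Limit(exp, child)) =>
      //            Limit(exp, Project(projectList, child))
      def pushProjectThroughLimit(plan: Plan): Plan = plan match {
        case Project(cols, Limit(n, child)) => Limit(n, Project(cols, child))
        case other => other
      }

      val before = Project(Seq("value"), Limit(2, Sort("key", Scan("testData"))))
      // Prints Limit(2,Project(List(value),Sort(key,Scan(testData)))), the
      // limit -> project -> sort shape the strategy plans as one operator.
      println(pushProjectThroughLimit(before))
    }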
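
The ProjectCollapsing scaladoc now carries the "alias substitution" wording
that was dropped from ColumnPruning's doc: when two adjacent Projects merge,
references to the lower project's aliases are replaced by the aliases'
defining expressions. A toy rendering of that substitution (Expr, Attr, Lit,
Add and CollapseDemo are stand-ins for Catalyst expressions):

    // Toy expression nodes, for illustration only.
    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Lit(value: Int) extends Expr
    case class Add(left: Expr, right: Expr) extends Expr

    object CollapseDemo extends App {
      // Replace references to lower-project aliases by their definitions.
      def substitute(e: Expr, aliases: Map[String, Expr]): Expr = e match {
        case Attr(n)   => aliases.getOrElse(n, e)
        case Add(l, r) => Add(substitute(l, aliases), substitute(r, aliases))
        case lit: Lit  => lit
      }

      val lower = Map("a" -> Add(Attr("x"), Lit(1))) // lower project: a := x + 1
      val upperB = Add(Attr("a"), Lit(2))            // upper project: b := a + 2
      // The two projects collapse into one computing b := (x + 1) + 2.
      println(substitute(upperB, lower)) // Add(Add(Attr(x),Lit(1)),Lit(2))
    }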
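
The four regexes made private in LikeSimplification are copied verbatim
below; a self-contained way to see which rewrite a LIKE pattern would get.
The simplify helper and the printed names are illustrative, not Catalyst API:

    object LikeDemo extends App {
      // Verbatim from the diff above.
      val startsWith = "([^_%]+)%".r
      val endsWith = "%([^_%]+)".r
      val contains = "%([^_%]+)%".r
      val equalTo = "([^_%]*)".r

      def simplify(pattern: String): String = pattern match {
        case startsWith(prefix) => s"StartsWith($prefix)"
        case endsWith(suffix)   => s"EndsWith($suffix)"
        case contains(infix)    => s"Contains($infix)"
        case equalTo(literal)   => s"EqualTo($literal)"
        case _                  => "keep the Like expression"
      }

      println(simplify("abc%"))  // StartsWith(abc)
      println(simplify("%abc"))  // EndsWith(abc)
      println(simplify("%abc%")) // Contains(abc)
      println(simplify("abc"))   // EqualTo(abc)
      println(simplify("a_c"))   // keep the Like expression (underscore wildcard)
    }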
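
Finally, the contract of the new TakeOrderedAndProject operator: collectData
runs a distributed top-k (RDD.takeOrdered under a RowOrdering) and only then
applies the optional projection to the collected result, so the projection
touches at most `limit` rows. A plain-collections sketch of that contract;
Row, byKey, valueOnly and the demo object are invented names:

    object TakeOrderedAndProjectDemo extends App {
      type Row = Seq[Any]

      // Take the `limit` smallest rows under `ord`, then project if asked.
      // (RDD.takeOrdered keeps a bounded heap per partition; a full sort is
      // used here only for brevity.)
      def takeOrderedAndProject(
          rows: Seq[Row],
          limit: Int,
          ord: Ordering[Row],
          projection: Option[Row => Row]): Seq[Row] = {
        val data = rows.sorted(ord).take(limit)
        projection.map(p => data.map(p)).getOrElse(data)
      }

      val data = Seq[Row](Seq(3, "c"), Seq(1, "a"), Seq(2, "b"))
      val byKey = Ordering.by[Row, Int](_.head.asInstanceOf[Int])
      val valueOnly: Row => Row = row => Seq(row(1))

      // Same shape as the new PlannerSuite query:
      //   testData.sort('key).select('value).limit(2)
      println(takeOrderedAndProject(data, 2, byKey, Some(valueOnly)))
      // prints: List(List(a), List(b))
    }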
