This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 09b0548  [SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection
09b0548 is described below

commit 09b05487b78fdd4b5d4cde114f5f4440a076bf10
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Sun Jan 13 23:54:19 2019 +0100

[SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection

## What changes were proposed in this pull request?

When creating some unsafe projections, Spark rebuilds the map of schema attributes once for each expression in the projection. Some file format readers create one unsafe projection per input file; others create one per task. ProjectExec also creates one unsafe projection per task. As a result, for wide queries on wide tables, Spark might build the map of schema attributes hundreds of thousands of times.

This PR changes two functions to reuse the same AttributeSeq instance when creating BoundReference objects for each expression in the projection. This avoids the repeated rebuilding of the map of schema attributes.

### Benchmarks

The time saved by this PR depends on the size of the schema, the size of the projection, the number of input files (or number of file splits), the number of tasks, and the file format. I chose a couple of example cases.

In the following tests, I ran the query
```sql
select * from table where id1 = 1
```
Matching rows are about 0.2% of the table.

#### Orc table

6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.772306 min | 1.487267 min | 16.082943%

#### Orc table

6000 columns, 500K rows, *17* input files

baseline | pr | improvement
----|----|----
1.656400 min | 1.423550 min | 14.057595%

#### Orc table

60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.299878 min | 0.290339 min | 3.180926%

#### Parquet table

6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
1.478306 min | 1.373728 min | 7.074165%

Note: The parquet reader does not create an unsafe projection. However, the filter operation in the query causes the planner to add a ProjectExec, which does create an unsafe projection for each task. So these results have nothing to do with Parquet itself.

#### Parquet table

60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
0.245006 min | 0.242200 min | 1.145099%

#### CSV table

6000 columns, 500K rows, 34 input files

baseline | pr | improvement
----|----|----
2.390117 min | 2.182778 min | 8.674844%

#### CSV table

60 columns, 50M rows, 34 input files

baseline | pr | improvement
----|----|----
1.520911 min | 1.510211 min | 0.703526%

## How was this patch tested?

SQL unit tests
Python core and SQL tests

Closes #23392 from bersprockets/norebuild.
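To make the pattern concrete, the sketch below contrasts binding each expression against a `Seq[Attribute]` (which is implicitly converted to a fresh `AttributeSeq` per call) with binding all expressions against a single, reused `AttributeSeq` via the `BindReferences.bindReferences` helper added in this commit. This is an illustrative sketch only, not code from the commit: it assumes spark-catalyst (including this change) is on the classpath, and the schema width, column names, and object name are made-up examples.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.IntegerType

// Illustrative sketch only; names and sizes are assumptions, not part of the commit.
object BindOnceSketch {
  def main(args: Array[String]): Unit = {
    // A wide schema standing in for a projection's child.output.
    val schema: Seq[Attribute] =
      (0 until 6000).map(i => AttributeReference(s"c$i", IntegerType)())
    val exprs: Seq[Expression] = schema

    // Before: each bindReference call implicitly wraps the Seq[Attribute] in a new
    // AttributeSeq, so the attribute-lookup map is rebuilt once per expression.
    val boundPerExpression = exprs.map(e => BindReferences.bindReference(e, schema))

    // After: build the AttributeSeq once and bind every expression against it.
    val attrSeq: AttributeSeq = schema
    val boundOnce = bindReferences(exprs, attrSeq)

    // Both approaches produce the same BoundReference expressions; only the
    // amount of repeated map-building work differs.
    assert(boundPerExpression == boundOnce)
  }
}
```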
Authored-by: Bruce Robbins <bersprock...@gmail.com>
Signed-off-by: Herman van Hovell <hvanhov...@databricks.com>
---
 .../sql/catalyst/expressions/BoundAttribute.scala  |  9 +++++
 .../expressions/InterpretedMutableProjection.scala |  3 +-
 .../sql/catalyst/expressions/Projection.scala      |  9 +++--
 .../codegen/GenerateMutableProjection.scala        |  3 +-
 .../expressions/codegen/GenerateOrdering.scala     |  5 ++-
 .../codegen/GenerateSafeProjection.scala           |  3 +-
 .../codegen/GenerateUnsafeProjection.scala         |  3 +-
 .../spark/sql/catalyst/expressions/ordering.scala  |  3 +-
 .../spark/sql/catalyst/expressions/package.scala   |  7 ----
 .../apache/spark/sql/execution/ExpandExec.scala    |  5 ++-
 .../execution/aggregate/AggregationIterator.scala  |  3 +-
 .../execution/aggregate/HashAggregateExec.scala    | 45 +++++++++++-----------
 .../sql/execution/basicPhysicalOperators.scala     |  3 +-
 .../execution/datasources/FileFormatWriter.scala   |  6 +--
 .../spark/sql/execution/joins/HashJoin.scala       |  6 +--
 .../sql/execution/joins/SortMergeJoinExec.scala    |  3 +-
 .../sql/execution/window/WindowFunctionFrame.scala |  8 ++--
 17 files changed, 68 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369..7ae5924 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
   }
+
+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 122a564..5c8aa4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
  */
 class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   private[this] val buffer = new Array[Any](expressions.size)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b48f7ba..eaaf94b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,7 @@ package
org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType} */ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = - this(expressions.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(expressions, inputSchema)) override def initialize(partitionIndex: Int): Unit = { expressions.foreach(_.foreach { @@ -99,7 +100,7 @@ object MutableProjection * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } @@ -162,7 +163,7 @@ object UnsafeProjection * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } @@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d588e7f..838bd1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp // MutableProjection is not accessible in Java @@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) def generate( expressions: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 283fd2a..b66b80a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import 
org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) /** * Creates a code gen ordering for sorting this schema, in ascending order. @@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[InternalRow] with KryoSerializable { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = - this(ordering.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(ordering, inputSchema)) @transient private[this] var generatedOrdering = GenerateOrdering.generate(ordering) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3977866..e285398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) private def createCodeForStruct( ctx: CodegenContext, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 0ecd0de..fb1d8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) def generate( expressions: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala 
index e24a3de..c8d6671 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.types._ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = - this(ordering.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(ordering, inputSchema)) def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index bf18e8b..932c364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -86,13 +86,6 @@ package object expressions { } /** - * A helper function to bind given expressions to an input schema. - */ - def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = { - exprs.map(BindReferences.bindReference(_, inputSchema)) - } - - /** * Helper functions for working with `Seq[Attribute]`. */ implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 5b4edf5..85f4914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -145,11 +145,12 @@ case class ExpandExec( // Part 1: declare variables for each column // If a column has the same value for all output rows, then we also generate its computation // right after declaration. Otherwise its value is computed in the part 2. + lazy val attributeSeq: AttributeSeq = child.output val outputColumns = output.indices.map { col => val firstExpr = projections.head(col) if (sameOutput(col)) { // This column is the same across all output rows. Just generate code for it here. 
- BindReferences.bindReference(firstExpr, child.output).genCode(ctx) + BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx) } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") @@ -170,7 +171,7 @@ case class ExpandExec( var updateCode = "" for (col <- exprs.indices) { if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx) + val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx) updateCode += s""" |${ev.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 98c4a51..a1fb23d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -77,6 +77,7 @@ abstract class AggregationIterator( val expressionsLength = expressions.length val functions = new Array[AggregateFunction](expressionsLength) var i = 0 + val inputAttributeSeq: AttributeSeq = inputAttributes while (i < expressionsLength) { val func = expressions(i).aggregateFunction val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { @@ -86,7 +87,7 @@ abstract class AggregationIterator( // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, inputAttributes) + BindReferences.bindReference(func, inputAttributeSeq) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 2355d30..220a4b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -199,15 +200,13 @@ case class HashAggregateExec( val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { // evaluate aggregate results ctx.currentVars = bufVars - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) - } + val aggResults = bindReferences( + functions.map(_.evaluateExpression), + aggregateBufferAttributes).map(_.genCode(ctx)) val evaluateAggResults = evaluateVariables(aggResults) // evaluate result expressions ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) - } + val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx)) (resultVars, s""" |$evaluateAggResults |${evaluateVariables(resultVars)} @@ -264,7 +263,7 @@ case class HashAggregateExec( } } ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { @@ -456,16 +455,16 @@ case class HashAggregateExec( val evaluateBufferVars = evaluateVariables(bufferVars) // evaluate the aggregation result ctx.currentVars = bufferVars - val aggResults = declFunctions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) - } + val aggResults = bindReferences( + declFunctions.map(_.evaluateExpression), + aggregateBufferAttributes).map(_.genCode(ctx)) val evaluateAggResults = evaluateVariables(aggResults) // generate the final result ctx.currentVars = keyVars ++ aggResults val inputAttrs = groupingAttributes ++ aggregateAttributes - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).genCode(ctx) - } + val resultVars = bindReferences[Expression]( + resultExpressions, + inputAttrs).map(_.genCode(ctx)) s""" $evaluateKeyVars $evaluateBufferVars @@ -494,9 +493,9 @@ case class HashAggregateExec( ctx.currentVars = keyVars ++ resultBufferVars val inputAttrs = resultExpressions.map(_.toAttribute) - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).genCode(ctx) - } + val resultVars = bindReferences[Expression]( + resultExpressions, + inputAttrs).map(_.genCode(ctx)) s""" $evaluateKeyVars $evaluateResultBufferVars @@ -506,9 +505,9 @@ case class HashAggregateExec( // generate result based on grouping 
key ctx.INPUT_ROW = keyTerm ctx.currentVars = null - val eval = resultExpressions.map{ e => - BindReferences.bindReference(e, groupingAttributes).genCode(ctx) - } + val eval = bindReferences[Expression]( + resultExpressions, + groupingAttributes).map(_.genCode(ctx)) consume(ctx, eval) } ctx.addNewFunction(funcName, @@ -730,9 +729,9 @@ case class HashAggregateExec( private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( - ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + ctx, bindReferences[Expression](groupingExpressions, child.output)) val fastRowKeys = ctx.generateExpressions( - groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") @@ -825,7 +824,7 @@ case class HashAggregateExec( val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { @@ -849,7 +848,7 @@ case class HashAggregateExec( if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2570b36..318dca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output)) + val exprs = bindReferences[Expression](projectList, child.output) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. 
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 774fe38..260ad97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} @@ -145,9 +146,8 @@ object FileFormatWriter extends Logging { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. - val orderingExpr = requiredOrdering - .map(SortOrder(_, Ascending)) - .map(BindReferences.bindReference(_, outputSpec.outputColumns)) + val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) SortExec( orderingExpr, global = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 1aef5f6..5ee4c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{RowIterator, SparkPlan} @@ -63,9 +64,8 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = HashJoin.rewriteKeyExpr(rightKeys) - .map(BindReferences.bindReference(_, right.output)) + val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output) + val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d7d3f6d..f829f07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ @@ -393,7 +394,7 @@ case class SortMergeJoinExec( input: Seq[Attribute]): Seq[ExprCode] = { ctx.INPUT_ROW = row ctx.currentVars = null - keys.map(BindReferences.bindReference(_, input).genCode(ctx)) + bindReferences(keys, input).map(_.genCode(ctx)) } private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index a560189..d5f2ffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -21,6 +21,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray @@ -103,9 +104,8 @@ final class OffsetWindowFunctionFrame( private[this] val projection = { // Collect the expressions and bind them. val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - BindReferences.bindReference(e.input, inputAttrs) - } + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences( + expressions.toSeq.map(_.input), inputAttrs) // Create the projection. newMutableProjection(boundExpressions, Nil).target(target) @@ -114,7 +114,7 @@ final class OffsetWindowFunctionFrame( /** Create the projection used when the offset row DOES NOT exists. */ private[this] val fillDefaultValue = { // Collect the expressions and bind them. - val inputAttrs = inputSchema.map(_.withNullability(true)) + val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true)) val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => if (e.default == null || e.default.foldable && e.default.eval() == null) { // The default value is null. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org