This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 09b0548  [SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection
09b0548 is described below

commit 09b05487b78fdd4b5d4cde114f5f4440a076bf10
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Sun Jan 13 23:54:19 2019 +0100

    [SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection
    
    ## What changes were proposed in this pull request?
    
    When creating some unsafe projections, Spark rebuilds the map of schema
    attributes once for each expression in the projection. Some file format
    readers create one unsafe projection per input file, others create one per
    task. ProjectExec also creates one unsafe projection per task. As a result,
    for wide queries on wide tables, Spark might build the map of schema
    attributes hundreds of thousands of times.
    
    This PR changes two functions to reuse the same AttributeSeq instance when
    creating BoundReference objects for each expression in the projection. This
    avoids the repeated rebuilding of the map of schema attributes.
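    
    As a rough illustration of the pattern, the following self-contained Scala sketch
    uses hypothetical stand-in classes (not Spark's actual `AttributeSeq`/`BindReferences`
    API): the name-to-ordinal lookup is built once per projection and shared across all
    bindings, instead of being rebuilt for every expression.
    
    ```scala
    // Hypothetical stand-ins, for illustration only -- not the Spark API.
    case class Attr(name: String)
    case class BoundRef(ordinal: Int)

    // Analogous to AttributeSeq: constructing it builds the lookup map,
    // which is the step this PR avoids repeating per expression.
    class Schema(attrs: Seq[Attr]) {
      private val ordinalByName: Map[String, Int] =
        attrs.zipWithIndex.map { case (a, i) => a.name -> i }.toMap
      def bind(name: String): BoundRef = BoundRef(ordinalByName(name))
    }

    object BindOnceExample extends App {
      val attrs = (0 until 6000).map(i => Attr(s"col$i"))
      val projectedColumns = attrs.map(_.name)

      // Old shape: a fresh Schema (and its lookup map) is built for every expression.
      val boundPerExpr = projectedColumns.map(name => new Schema(attrs).bind(name))

      // New shape: one Schema instance is built and reused for the whole projection.
      val shared = new Schema(attrs)
      val boundOnce = projectedColumns.map(shared.bind)

      assert(boundPerExpr == boundOnce)
    }
    ```
    
    In the actual change, the shared instance is an AttributeSeq, and the new
    BindReferences.bindReferences helper (added in BoundAttribute.scala below)
    performs the per-expression binding against it.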
    
    ### Benchmarks
    
    The time saved by this PR depends on the size of the schema, the size of the
    projection, the number of input files (or number of file splits), the number
    of tasks, and the file format. I chose a couple of example cases.
    
    In the following tests, I ran the query
    ```sql
    select * from table where id1 = 1
    ```
    
    Matching rows are about 0.2% of the table.
    
    #### Orc table 6000 columns, 500K rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    1.772306 min | 1.487267 min | 16.082943%
    
    #### Orc table 6000 columns, 500K rows, *17* input files
    
    baseline | pr | improvement
    ----|----|----
    1.656400 min | 1.423550 min | 14.057595%
    
    #### Orc table 60 columns, 50M rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    0.299878 min | 0.290339 min | 3.180926%
    
    #### Parquet table 6000 columns, 500K rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    1.478306 min | 1.373728 min | 7.074165%
    
    Note: The Parquet reader does not create an unsafe projection. However, the
    filter operation in the query causes the planner to add a ProjectExec, which
    does create an unsafe projection for each task. So these results have nothing
    to do with Parquet itself.
    
    #### Parquet table 60 columns, 50M rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    0.245006 min | 0.242200 min | 1.145099%
    
    #### CSV table 6000 columns, 500K rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    2.390117 min | 2.182778 min | 8.674844%
    
    #### CSV table 60 columns, 50M rows, 34 input files
    
    baseline | pr | improvement
    ----|----|----
    1.520911 min | 1.510211 min | 0.703526%
    
    ## How was this patch tested?
    
    SQL unit tests
    Python core and SQL tests
    
    Closes #23392 from bersprockets/norebuild.
    
    Authored-by: Bruce Robbins <bersprock...@gmail.com>
    Signed-off-by: Herman van Hovell <hvanhov...@databricks.com>
---
 .../sql/catalyst/expressions/BoundAttribute.scala  |  9 +++++
 .../expressions/InterpretedMutableProjection.scala |  3 +-
 .../sql/catalyst/expressions/Projection.scala      |  9 +++--
 .../codegen/GenerateMutableProjection.scala        |  3 +-
 .../expressions/codegen/GenerateOrdering.scala     |  5 ++-
 .../codegen/GenerateSafeProjection.scala           |  3 +-
 .../codegen/GenerateUnsafeProjection.scala         |  3 +-
 .../spark/sql/catalyst/expressions/ordering.scala  |  3 +-
 .../spark/sql/catalyst/expressions/package.scala   |  7 ----
 .../apache/spark/sql/execution/ExpandExec.scala    |  5 ++-
 .../execution/aggregate/AggregationIterator.scala  |  3 +-
 .../execution/aggregate/HashAggregateExec.scala    | 45 +++++++++++-----------
 .../sql/execution/basicPhysicalOperators.scala     |  3 +-
 .../execution/datasources/FileFormatWriter.scala   |  6 +--
 .../spark/sql/execution/joins/HashJoin.scala       |  6 +--
 .../sql/execution/joins/SortMergeJoinExec.scala    |  3 +-
 .../sql/execution/window/WindowFunctionFrame.scala |  8 ++--
 17 files changed, 68 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369..7ae5924 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
       }
    }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
   }
+
+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 122a564..5c8aa4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
  */
 class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   private[this] val buffer = new Array[Any](expressions.size)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b48f7ba..eaaf94b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(expressions, inputSchema))
 
   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -162,7 +163,7 @@ object UnsafeProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d588e7f..838bd1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 // MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 283fd2a..b66b80a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 3977866..e285398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   private def createCodeForStruct(
       ctx: CodegenContext,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0ecd0de..fb1d8a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
 
@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index e24a3de..c8d6671 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types._
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index bf18e8b..932c364 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -86,13 +86,6 @@ package object expressions  {
   }
 
   /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }
-
-  /**
    * Helper functions for working with `Seq[Attribute]`.
    */
   implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 5b4edf5..85f4914 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -145,11 +145,12 @@ case class ExpandExec(
     // Part 1: declare variables for each column
    // If a column has the same value for all output rows, then we also generate its computation
     // right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
     val outputColumns = output.indices.map { col =>
       val firstExpr = projections.head(col)
       if (sameOutput(col)) {
        // This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
@@ -170,7 +171,7 @@ case class ExpandExec(
       var updateCode = ""
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
           updateCode +=
             s"""
                |${ev.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 98c4a51..a1fb23d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
     val expressionsLength = expressions.length
     val functions = new Array[AggregateFunction](expressionsLength)
     var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
     while (i < expressionsLength) {
       val func = expressions(i).aggregateFunction
      val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
          // this function is Partial or Complete because we will call eval of this
          // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
         case _ =>
          // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 2355d30..220a4b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
    val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
       (resultVars, s"""
         |$evaluateAggResults
         |${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@ case class HashAggregateExec(
       val evaluateBufferVars = evaluateVariables(bufferVars)
       // evaluate the aggregation result
       ctx.currentVars = bufferVars
-      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        declFunctions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // generate the final result
       ctx.currentVars = keyVars ++ aggResults
       val inputAttrs = groupingAttributes ++ aggregateAttributes
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateBufferVars
@@ -494,9 +493,9 @@ case class HashAggregateExec(
 
       ctx.currentVars = keyVars ++ resultBufferVars
       val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateResultBufferVars
@@ -506,9 +505,9 @@ case class HashAggregateExec(
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](
+        resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
       consume(ctx, eval)
     }
     ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@ case class HashAggregateExec(
  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@ case class HashAggregateExec(
 
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
      val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@ case class HashAggregateExec(
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
-          val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+          val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
          val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
           val effectiveCodes = subExprs.codes.mkString("\n")
           val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 2570b36..318dca0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 774fe38..260ad97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
        // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
        // the physical plan may have different attribute ids due to optimizer removing some
        // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = requiredOrdering
-          .map(SortOrder(_, Ascending))
-          .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+        val orderingExpr = bindReferences(
+          requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
         SortExec(
           orderingExpr,
           global = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1aef5f6..5ee4c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d7d3f6d..f829f07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
@@ -393,7 +394,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+    bindReferences(keys, input).map(_.genCode(ctx))
   }
 
  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index a560189..d5f2ffa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
@@ -103,9 +104,8 @@ final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
-      BindReferences.bindReference(e.input, inputAttrs)
-    }
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
+      expressions.toSeq.map(_.input), inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)
@@ -114,7 +114,7 @@ final class OffsetWindowFunctionFrame(
   /** Create the projection used when the offset row DOES NOT exists. */
   private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
      if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.

