Repository: spark
Updated Branches:
  refs/heads/master a91784fb6 -> bd94ea4c8


[SPARK-14175][SQL] whole stage codegen interface refactor

## What changes were proposed in this pull request?

1. merge consumeChild into consume()
2. always generate code for input variables and UnsafeRow; a plan can use either 
of them.

## How was this patch tested?

Existing tests.

Author: Davies Liu <[email protected]>

Closes #11975 from davies/gen_refactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd94ea4c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd94ea4c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd94ea4c

Branch: refs/heads/master
Commit: bd94ea4c80f4fc18f4000346d7c6717539846efb
Parents: a91784f
Author: Davies Liu <[email protected]>
Authored: Sat Mar 26 11:03:05 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sat Mar 26 11:03:05 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/ExistingRDD.scala       |   3 +-
 .../org/apache/spark/sql/execution/Expand.scala |   2 +-
 .../org/apache/spark/sql/execution/Sort.scala   |  26 +---
 .../spark/sql/execution/WholeStageCodegen.scala | 153 +++++++------------
 .../execution/aggregate/TungstenAggregate.scala |   2 +-
 .../spark/sql/execution/basicOperators.scala    |   4 +-
 .../spark/sql/execution/debug/package.scala     |   2 +-
 .../sql/execution/joins/BroadcastHashJoin.scala |   2 +-
 .../org/apache/spark/sql/execution/limit.scala  |   2 +-
 9 files changed, 72 insertions(+), 124 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 3e2c799..815ff01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -271,7 +271,8 @@ private[sql] case class DataSourceScan(
       |   }
       | }""".stripMargin)
 
-    val exprRows = output.zipWithIndex.map(x => new BoundReference(x._2, 
x._1.dataType, true))
+    val exprRows =
+      output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, 
x._1.nullable))
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns2 = exprRows.map(_.gen(ctx))

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 05627ba..bd23b7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     /*
      * When the projections list looks like:
      *   expr1A, exprB, expr1C

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index b4dd770..efd8760 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -98,6 +98,8 @@ case class Sort(
     }
   }
 
+  override def usedInputs: AttributeSet = AttributeSet(Seq.empty)
+
   override def upstreams(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].upstreams()
   }
@@ -105,8 +107,6 @@ case class Sort(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
-  override def preferUnsafeRow: Boolean = true
-
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.freshName("needToSort")
     ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -158,22 +158,10 @@ case class Sort(
      """.stripMargin.trim
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
-    if (row != null) {
-      s"$sorterVariable.insertRow((UnsafeRow)$row);"
-    } else {
-      val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable)
-      }
-
-      ctx.currentVars = input
-      val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
-
-      s"""
-         | // Convert the input attributes to an UnsafeRow and add it to the 
sorter
-         | ${code.code}
-         | $sorterVariable.insertRow(${code.value});
-       """.stripMargin.trim
-    }
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    s"""
+       |${row.code}
+       |$sorterVariable.insertRow((UnsafeRow)${row.value});
+     """.stripMargin
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 0be0b80..1b13c8f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -69,11 +69,6 @@ trait CodegenSupport extends SparkPlan {
   protected var parent: CodegenSupport = null
 
   /**
-    * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume.
-    */
-  def preferUnsafeRow: Boolean = false
-
-  /**
     * Returns all the RDDs of InternalRow which generates the input rows.
     *
     * Note: right now we support up to two RDDs.
@@ -114,13 +109,52 @@ trait CodegenSupport extends SparkPlan {
   protected def doProduce(ctx: CodegenContext): String
 
   /**
-    * Consume the columns generated from current SparkPlan, call it's parent.
+    * Consume the generated columns or row from current SparkPlan, call it's 
parent's doConsume().
     */
-  final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = 
null): String = {
-    if (input != null) {
-      assert(input.length == output.length)
+  final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: 
String = null): String = {
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        assert(outputVars != null)
+        assert(outputVars.length == output.length)
+        // outputVars will be used to generate the code for UnsafeRow, so we 
should copy them
+        outputVars.map(_.copy())
+      }
+    val rowVar = if (row != null) {
+      ExprCode("", "false", row)
+    } else {
+      if (outputVars.nonEmpty) {
+        val colExprs = output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable)
+        }
+        val evaluateInputs = evaluateVariables(outputVars)
+        // generate the code to create a UnsafeRow
+        ctx.currentVars = outputVars
+        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+        val code = s"""
+          |$evaluateInputs
+          |${ev.code.trim}
+         """.stripMargin.trim
+        ExprCode(code, "false", ev.value)
+      } else {
+        // There is no columns
+        ExprCode("", "false", "unsafeRow")
+      }
     }
-    parent.consumeChild(ctx, this, input, row)
+
+    ctx.freshNamePrefix = parent.variablePrefix
+    val evaluated = evaluateRequiredVariables(output, inputVars, 
parent.usedInputs)
+    s"""
+       |
+       |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
+       |${evaluated}
+       |${parent.doConsume(ctx, inputVars, rowVar)}
+     """.stripMargin
   }
 
   /**
@@ -160,47 +194,6 @@ trait CodegenSupport extends SparkPlan {
   def usedInputs: AttributeSet = references
 
   /**
-   * Consume the columns generated from its child, call doConsume() or emit 
the rows.
-   *
-   * An operator could generate variables for the output, or a row, either one 
could be null.
-   *
-   * If the row is not null, we create variables to access the columns that 
are actually used by
-   * current plan before calling doConsume().
-   */
-  def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-    ctx.freshNamePrefix = variablePrefix
-    val inputVars =
-      if (row != null) {
-        ctx.currentVars = null
-        ctx.INPUT_ROW = row
-        child.output.zipWithIndex.map { case (attr, i) =>
-          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
-        }
-      } else {
-        input
-      }
-
-    val evaluated =
-      if (row != null && preferUnsafeRow) {
-        // Current plan can consume UnsafeRows directly.
-        ""
-      } else {
-        evaluateRequiredVariables(child.output, inputVars, usedInputs)
-      }
-
-    s"""
-       |
-       |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
-       |${evaluated}
-       |${doConsume(ctx, inputVars, row)}
-     """.stripMargin
-  }
-
-  /**
     * Generate the Java source code to process the rows from child SparkPlan.
     *
     * This should be override by subclass to support codegen.
@@ -210,8 +203,10 @@ trait CodegenSupport extends SparkPlan {
     *   # code to evaluate the predicate expression, result is isNull1 and 
value2
     *   if (isNull1 || !value2) continue;
     *   # call consume(), which will call parent.doConsume()
+    *
+    * Note: A plan can either consume the rows as UnsafeRow (row), or a list 
of variables (input).
     */
-  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): 
String = {
     throw new UnsupportedOperationException
   }
 }
@@ -245,16 +240,11 @@ case class InputAdapter(child: SparkPlan) extends 
UnaryNode with CodegenSupport
     val input = ctx.freshName("input")
     // Right now, InputAdapter is only used when there is one upstream.
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = 
inputs[0];")
-
-    val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, 
x._1.dataType, true))
     val row = ctx.freshName("row")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))
     s"""
        | while ($input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
-       |   ${consume(ctx, columns, row).trim}
+       |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;
        | }
      """.stripMargin
@@ -282,18 +272,15 @@ object WholeStageCodegen {
   *     |
   *  doExecute() --------->   upstreams() -------> upstreams() ------> 
execute()
   *     |
-  *      ----------------->   produce()
+  *     +----------------->   produce()
   *                             |
   *                          doProduce()  -------> produce()
   *                                                   |
   *                                                doProduce()
   *                                                   |
-  *                                                consume()
-  *                        consumeChild() <-----------|
+  *                         doConsume() <--------- consume()
   *                             |
-  *                          doConsume()
-  *                             |
-  *  consumeChild()  <-----  consume()
+  *  doConsume()  <--------  consume()
   *
   * SparkPlan A should override doProduce() and doConsume().
   *
@@ -392,44 +379,16 @@ case class WholeStageCodegen(child: SparkPlan) extends 
UnaryNode with CodegenSup
     throw new UnsupportedOperationException
   }
 
-  override def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     val doCopy = if (ctx.copyResult) {
       ".copy()"
     } else {
       ""
     }
-    if (row != null) {
-      // There is an UnsafeRow already
-      s"""
-         |append($row$doCopy);
-       """.stripMargin.trim
-    } else {
-      assert(input != null)
-      if (input.nonEmpty) {
-        val colExprs = output.zipWithIndex.map { case (attr, i) =>
-          BoundReference(i, attr.dataType, attr.nullable)
-        }
-        val evaluateInputs = evaluateVariables(input)
-        // generate the code to create a UnsafeRow
-        ctx.currentVars = input
-        val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-        s"""
-           |$evaluateInputs
-           |${code.code.trim}
-           |append(${code.value}$doCopy);
-         """.stripMargin.trim
-      } else {
-        // There is no columns
-        s"""
-           |append(unsafeRow);
-         """.stripMargin.trim
-      }
-    }
+    s"""
+      |${row.code}
+      |append(${row.value}$doCopy);
+     """.stripMargin.trim
   }
 
   override def innerChildren: Seq[SparkPlan] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 28945a5..7c215d1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -139,7 +139,7 @@ case class TungstenAggregate(
     }
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     if (groupingExpressions.isEmpty) {
       doConsumeWithoutKeys(ctx, input)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ee3f1d7..70e04d0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: 
SparkPlan)
     references.filter(a => usedMoreThanOnce.contains(a.exprId))
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, 
child.output)))
     ctx.currentVars = input
@@ -107,7 +107,7 @@ case class Filter(condition: Expression, child: SparkPlan)
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // filter out the nulls

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index d5ce124..5e573b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -137,7 +137,7 @@ package object debug {
       child.asInstanceOf[CodegenSupport].produce(ctx, this)
     }
 
-    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
       consume(ctx, input)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index aa2da28..f5b083c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -110,7 +110,7 @@ case class BroadcastHashJoin(
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     joinType match {
       case Inner => codegenInner(ctx, input)
       case LeftOuter | RightOuter => codegenOuter(ctx, input)

http://git-wip-us.apache.org/repos/asf/spark/blob/bd94ea4c/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ca624a5..9643b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
String): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
     val stopEarly = ctx.freshName("stopEarly")
     ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to