Repository: spark
Updated Branches:
  refs/heads/master a1e40b1f5 -> 4262fb0d5


[SPARK-19070] Clean-up dataset actions

## What changes were proposed in this pull request?
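Dataset actions currently spin off a new `DataFrame` only to track query 
execution. This PR simplifies this code path by using 
`Dataset.queryExecution` directly. This PR also merges the typed and untyped 
action evaluation paths: `withCallback` and `withTypedCallback` are replaced 
by a single `withAction` helper, which also runs the action under a new SQL 
execution ID.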

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <[email protected]>

Closes #16466 from hvanhovell/SPARK-19070.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4262fb0d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4262fb0d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4262fb0d

Branch: refs/heads/master
Commit: 4262fb0d55aed1a023e1813e09deefda8a7ce26b
Parents: a1e40b1
Author: Herman van Hovell <[email protected]>
Authored: Wed Jan 4 23:47:58 2017 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Jan 4 23:47:58 2017 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 76 +++++++-------------
 1 file changed, 26 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4262fb0d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2a06f3c..fd75d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, 
QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateViewCommand, 
ExplainCommand, GlobalTempView, LocalTempView}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
@@ -2096,9 +2096,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df =>
-    df.collect(needCallback = false)
-  }
+  def head(n: Int): Array[T] = withAction("head", 
limit(n).queryExecution)(collectFromPlan)
 
   /**
    * Returns the first row.
@@ -2325,7 +2323,7 @@ class Dataset[T] private[sql](
   def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) 
: _*)
 
   /**
-   * Returns an array that contains all of [[Row]]s in this Dataset.
+   * Returns an array that contains all rows in this Dataset.
    *
    * Running collect requires moving all the data into the application's 
driver process, and
    * doing so on a very large dataset can crash the driver process with 
OutOfMemoryError.
@@ -2335,10 +2333,10 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def collect(): Array[T] = collect(needCallback = true)
+  def collect(): Array[T] = withAction("collect", 
queryExecution)(collectFromPlan)
 
   /**
-   * Returns a Java list that contains all of [[Row]]s in this Dataset.
+   * Returns a Java list that contains all rows in this Dataset.
    *
    * Running collect requires moving all the data into the application's 
driver process, and
    * doing so on a very large dataset can crash the driver process with 
OutOfMemoryError.
@@ -2346,27 +2344,13 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def collectAsList(): java.util.List[T] = withCallback("collectAsList", 
toDF()) { _ =>
-    withNewExecutionId {
-      val values = 
queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
-      java.util.Arrays.asList(values : _*)
-    }
-  }
-
-  private def collect(needCallback: Boolean): Array[T] = {
-    def execute(): Array[T] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
-    }
-
-    if (needCallback) {
-      withCallback("collect", toDF())(_ => execute())
-    } else {
-      execute()
-    }
+  def collectAsList(): java.util.List[T] = withAction("collectAsList", 
queryExecution) { plan =>
+    val values = collectFromPlan(plan)
+    java.util.Arrays.asList(values : _*)
   }
 
   /**
-   * Return an iterator that contains all of [[Row]]s in this Dataset.
+   * Return an iterator that contains all rows in this Dataset.
    *
    * The iterator will consume as much memory as the largest partition in this 
Dataset.
    *
@@ -2377,9 +2361,9 @@ class Dataset[T] private[sql](
    * @group action
    * @since 2.0.0
    */
-  def toLocalIterator(): java.util.Iterator[T] = 
withCallback("toLocalIterator", toDF()) { _ =>
-    withNewExecutionId {
-      
queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava
+  def toLocalIterator(): java.util.Iterator[T] = {
+    withAction("toLocalIterator", queryExecution) { plan =>
+      plan.executeToIterator().map(boundEnc.fromRow).asJava
     }
   }
 
@@ -2388,8 +2372,8 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def count(): Long = withCallback("count", groupBy().count()) { df =>
-    df.collect(needCallback = false).head.getLong(0)
+  def count(): Long = withAction("count", groupBy().count().queryExecution) { 
plan =>
+    plan.executeCollect().head.getLong(0)
   }
 
   /**
@@ -2762,38 +2746,30 @@ class Dataset[T] private[sql](
    * Wrap a Dataset action to track the QueryExecution and time cost, then 
report to the
    * user-registered callback functions.
    */
-  private def withCallback[U](name: String, df: DataFrame)(action: DataFrame 
=> U) = {
+  private def withAction[U](name: String, qe: QueryExecution)(action: 
SparkPlan => U) = {
     try {
-      df.queryExecution.executedPlan.foreach { plan =>
+      qe.executedPlan.foreach { plan =>
         plan.resetMetrics()
       }
       val start = System.nanoTime()
-      val result = action(df)
+      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
+        action(qe.executedPlan)
+      }
       val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, df.queryExecution, end - 
start)
+      sparkSession.listenerManager.onSuccess(name, qe, end - start)
       result
     } catch {
       case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, df.queryExecution, e)
+        sparkSession.listenerManager.onFailure(name, qe, e)
         throw e
     }
   }
 
-  private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: 
Dataset[A] => B) = {
-    try {
-      ds.queryExecution.executedPlan.foreach { plan =>
-        plan.resetMetrics()
-      }
-      val start = System.nanoTime()
-      val result = action(ds)
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, ds.queryExecution, end - 
start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, ds.queryExecution, e)
-        throw e
-    }
+  /**
+   * Collect all elements from a spark plan.
+   */
+  private def collectFromPlan(plan: SparkPlan): Array[T] = {
+    plan.executeCollect().map(boundEnc.fromRow)
   }
 
   private def sortInternal(global: Boolean, sortExprs: Seq[Column]): 
Dataset[T] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to