This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ff1ac5  [SPARK-25353][SQL] executeTake in SparkPlan is modified to 
avoid unnecessary decoding.
2ff1ac5 is described below

commit 2ff1ac5d9f95576bb6998b989483bb060ab42228
Author: Dooyoung Hwang <[email protected]>
AuthorDate: Tue Jul 2 20:55:24 2019 +0800

    [SPARK-25353][SQL] executeTake in SparkPlan is modified to avoid 
unnecessary decoding.
    
    ## What changes were proposed in this pull request?
    In some cases, executeTake in SparkPlan could decode more than necessary.
    
    For example, with the odd/even partitioning below, the total row count 
from the partitions will be 100, although the limit is 51. 'executeTake' in 
SparkPlan decodes all of them, so 49 of those rows are decoded 
unnecessarily.
    
    ```scala
    spark.sparkContext.parallelize((0 until 100).map(i => (i, 1))).toDF()
          .repartitionByRange(2, $"_1" % 2).limit(51).collect()
    ```
    
    By using an iterator of the Scala collection, we can ensure that at 
most n rows are decoded.
    
    ## How was this patch tested?
    Existing unit tests that call the limit function of DataFrame.
    
    testOnly *SQLQuerySuite
    testOnly *DataFrameSuite
    
    Closes #22347 from Dooyoung-Hwang/refactor_execute_take.
    
    Authored-by: Dooyoung Hwang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/sql/execution/SparkPlan.scala | 34 ++++++++++++----------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 6deb90c..2baf2e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -409,12 +409,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] 
with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = getByteArrayRdd(n).map(_._2)
+    val childRDD = getByteArrayRdd(n)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
     var partsScanned = 0
-    while (buf.size < n && partsScanned < totalParts) {
+    while (buf.length < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this 
number to be
       // greater than totalParts because we actually cap it at totalParts in 
runJob.
       var numPartsToTry = 1L
@@ -426,28 +426,32 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] 
with Logging with Serializ
         if (buf.isEmpty) {
           numPartsToTry = partsScanned * limitScaleUpFactor
         } else {
-          val left = n - buf.size
+          val left = n - buf.length
           // As left > 0, numPartsToTry is always >= 1
-          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
+          numPartsToTry = Math.ceil(1.5 * left * partsScanned / 
buf.length).toInt
           numPartsToTry = Math.min(numPartsToTry, partsScanned * 
limitScaleUpFactor)
         }
       }
 
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, 
totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res = sc.runJob(childRDD,
-        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else 
Array.empty[Byte], p)
-
-      buf ++= res.flatMap(decodeUnsafeRows)
-
+      val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
+        if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)
+
+      var i = 0
+      while (buf.length < n && i < res.length) {
+        val rows = decodeUnsafeRows(res(i)._2)
+        val rowsToTake = if (n - buf.length >= res(i)._1) {
+          rows.toArray
+        } else {
+          rows.take(n - buf.length).toArray
+        }
+        buf ++= rowsToTake
+        i += 1
+      }
       partsScanned += p.size
     }
-
-    if (buf.size > n) {
-      buf.take(n).toArray
-    } else {
-      buf.toArray
-    }
+    buf.toArray
   }
 
   protected def newMutableProjection(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to