Repository: spark Updated Branches: refs/heads/master 10ef4f3e7 -> 750ed64cd
[SPARK-13930] [SQL] Apply fast serialization on collect limit operator ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-13930 Fast serialization was recently introduced for collecting DataFrame/Dataset results (#11664). The same technique can be applied to the collect limit operator too. ## How was this patch tested? Added a benchmark for collect limit to `BenchmarkWholeStageCodegen`. Without this patch: model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- collect limit 1 million 3413 / 3768 0.3 3255.0 1.0X collect limit 2 millions 9728 / 10440 0.1 9277.3 0.4X With this patch: model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- collect limit 1 million 833 / 1284 1.3 794.4 1.0X collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X Author: Liang-Chi Hsieh <[email protected]> Closes #11759 from viirya/execute-take. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/750ed64c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/750ed64c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/750ed64c Branch: refs/heads/master Commit: 750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d Parents: 10ef4f3 Author: Liang-Chi Hsieh <[email protected]> Authored: Thu Mar 17 23:24:44 2016 -0700 Committer: Davies Liu <[email protected]> Committed: Thu Mar 17 23:24:44 2016 -0700 ---------------------------------------------------------------------- .../apache/spark/sql/execution/SparkPlan.scala | 78 +++++++++++++------- .../execution/BenchmarkWholeStageCodegen.scala | 21 ++++++ 2 files changed, 71 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/750ed64c/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a392b53..010ed7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Runs this query returning the result as an array. + * Packing the UnsafeRows into byte array for faster serialization. + * The byte arrays are in the following format: + * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + * + * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + * compressed. */ - def executeCollect(): Array[InternalRow] = { - // Packing the UnsafeRows into byte array for faster serialization. 
- // The byte arrays are in the following format: - // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] - // - // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also - // compressed. - val byteArrayRdd = execute().mapPartitionsInternal { iter => + private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + execute().mapPartitionsInternal { iter => + var count = 0 val buffer = new Array[Byte](4 << 10) // 4K val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext) { + while (iter.hasNext && (n < 0 || count < n)) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) + count += 1 } out.writeInt(-1) out.flush() out.close() Iterator(bos.toByteArray) } + } - // Collect the byte arrays back to driver, then decode them as UnsafeRows. + /** + * Decode the byte arrays back to UnsafeRows and put them into buffer. + */ + private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = { val nFields = schema.length - val results = ArrayBuffer[InternalRow]() + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + buffer += row + sizeOfNextRow = ins.readInt() + } + } + + /** + * Runs this query returning the result as an array. 
+ */ + def executeCollect(): Array[InternalRow] = { + val byteArrayRdd = getByteArrayRdd() + + val results = ArrayBuffer[InternalRow]() byteArrayRdd.collect().foreach { bytes => - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val bis = new ByteArrayInputStream(bytes) - val ins = new DataInputStream(codec.compressedInputStream(bis)) - var sizeOfNextRow = ins.readInt() - while (sizeOfNextRow >= 0) { - val bs = new Array[Byte](sizeOfNextRow) - ins.readFully(bs) - val row = new UnsafeRow(nFields) - row.pointTo(bs, sizeOfNextRow) - results += row - sizeOfNextRow = ins.readInt() - } + decodeUnsafeRows(bytes, results) } results.toArray } @@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = execute().map(_.copy()) + val childRDD = getByteArrayRdd(n) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length @@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val left = n - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, + (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p) + + res.foreach { r => + decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf) + } - res.foreach(buf ++= _.take(n - buf.size)) partsScanned += p.size } - buf.toArray + if (buf.size > n) { + buf.take(n).toArray + } else { + buf.toArray + } } private[this] def isTesting: Boolean = sys.props.contains("spark.testing") http://git-wip-us.apache.org/repos/asf/spark/blob/750ed64c/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index b6051b0..d293ff6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { collect 4 millions 3193 / 3895 0.3 3044.7 0.1X */ } + + ignore("collect limit") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect limit", N) + benchmark.addCase("collect limit 1 million") { iter => + sqlContext.range(N * 4).limit(N).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + sqlContext.range(N * 4).limit(N * 2).collect() + } + benchmark.run() + + /** + model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) + collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect limit 1 million 833 / 1284 1.3 794.4 1.0X + collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X + */ + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
