Repository: spark
Updated Branches:
refs/heads/master 838cb4583 -> e58c4cb3c
[SPARK-14227][SQL] Add method for printing out generated code for debugging
## What changes were proposed in this pull request?
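This adds a `debugCodegen()` method to the `org.apache.spark.sql.execution.debug` package. It prints to stdout the generated code of every WholeStageCodegen subtree in a query's executed plan. This makes it easy to inspect whole-stage codegen output while debugging query execution.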
## How was this patch tested?
Unit test (`DebuggingSuite`, "debugCodegen") and manual testing. Example output:
```
scala> import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.debug._
scala> sqlContext.range(100).groupBy("id").count().orderBy("id").debugCodegen()
Found 3 WholeStageCodegen subtrees.
== Subtree 1 / 3 ==
WholeStageCodegen
: +- TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
: +- Range 0, 1, 1, 100, [id#0L]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ /** Codegened pipeline for:
/* 006 */ * TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
/* 007 */ +- Range 0, 1, 1, 100, [id#0L]
/* 008 */ */
/* 009 */ final class GeneratedIterator extends
org.apache.spark.sql.execution.BufferedRowIterator {
/* 010 */ private Object[] references;
/* 011 */ private boolean agg_initAgg;
/* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate
agg_plan;
/* 013 */ private
org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
/* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter
agg_sorter;
/* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter;
/* 016 */ private org.apache.spark.sql.execution.metric.LongSQLMetric
range_numOutputRows;
/* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue
range_metricValue;
/* 018 */ private boolean range_initRange;
/* 019 */ private long range_partitionEnd;
/* 020 */ private long range_number;
/* 021 */ private boolean range_overflow;
/* 022 */ private scala.collection.Iterator range_input;
/* 023 */ private UnsafeRow range_result;
/* 024 */ private
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 025 */ private
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
range_rowWriter;
/* 026 */ private UnsafeRow agg_result;
/* 027 */ private
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 028 */ private
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 029 */ private
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner
agg_unsafeRowJoiner;
/* 030 */ private org.apache.spark.sql.execution.metric.LongSQLMetric
wholestagecodegen_numOutputRows;
/* 031 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue
wholestagecodegen_metricValue;
/* 032 */
/* 033 */ public GeneratedIterator(Object[] references) {
/* 034 */ this.references = references;
/* 035 */ }
/* 036 */
/* 037 */ public void init(scala.collection.Iterator inputs[]) {
/* 038 */ agg_initAgg = false;
/* 039 */ this.agg_plan =
(org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0];
/* 040 */ agg_hashMap = agg_plan.createHashMap();
/* 041 */
/* 042 */ this.range_numOutputRows =
(org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 043 */ range_metricValue =
(org.apache.spark.sql.execution.metric.LongSQLMetricValue)
range_numOutputRows.localValue();
/* 044 */ range_initRange = false;
/* 045 */ range_partitionEnd = 0L;
/* 046 */ range_number = 0L;
/* 047 */ range_overflow = false;
/* 048 */ range_input = inputs[0];
/* 049 */ range_result = new UnsafeRow(1);
/* 050 */ this.range_holder = new
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 051 */ this.range_rowWriter = new
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder,
1);
/* 052 */ agg_result = new UnsafeRow(1);
/* 053 */ this.agg_holder = new
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */ this.agg_rowWriter = new
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder,
1);
/* 055 */ agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner();
/* 056 */ this.wholestagecodegen_numOutputRows =
(org.apache.spark.sql.execution.metric.LongSQLMetric) references[2];
/* 057 */ wholestagecodegen_metricValue =
(org.apache.spark.sql.execution.metric.LongSQLMetricValue)
wholestagecodegen_numOutputRows.localValue();
/* 058 */ }
/* 059 */
/* 060 */ private void agg_doAggregateWithKeys() throws java.io.IOException {
/* 061 */ /*** PRODUCE: Range 0, 1, 1, 100, [id#0L] */
/* 062 */
/* 063 */ // initialize Range
/* 064 */ if (!range_initRange) {
/* 065 */ range_initRange = true;
/* 066 */ if (range_input.hasNext()) {
/* 067 */ initRange(((InternalRow) range_input.next()).getInt(0));
/* 068 */ } else {
/* 069 */ return;
/* 070 */ }
/* 071 */ }
/* 072 */
/* 073 */ while (!range_overflow && range_number < range_partitionEnd) {
/* 074 */ long range_value = range_number;
/* 075 */ range_number += 1L;
/* 076 */ if (range_number < range_value ^ 1L < 0) {
/* 077 */ range_overflow = true;
/* 078 */ }
/* 079 */
/* 080 */ /*** CONSUME: TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
*/
/* 081 */
/* 082 */ // generate grouping key
/* 083 */ agg_rowWriter.write(0, range_value);
/* 084 */ /* hash(input[0, bigint], 42) */
/* 085 */ int agg_value1 = 42;
/* 086 */
/* 087 */ agg_value1 =
org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(range_value, agg_value1);
/* 088 */ UnsafeRow agg_aggBuffer = null;
/* 089 */ if (true) {
/* 090 */ // try to get the buffer from hash map
/* 091 */ agg_aggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 092 */ }
/* 093 */ if (agg_aggBuffer == null) {
/* 094 */ if (agg_sorter == null) {
/* 095 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
/* 096 */ } else {
/* 097 */
agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
/* 098 */ }
/* 099 */
/* 100 */ // the hash map had be spilled, it should have enough memory
now,
/* 101 */ // try to allocate buffer again.
/* 102 */ agg_aggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 103 */ if (agg_aggBuffer == null) {
/* 104 */ // failed to allocate the first page
/* 105 */ throw new OutOfMemoryError("No enough memory for
aggregation");
/* 106 */ }
/* 107 */ }
/* 108 */
/* 109 */ // evaluate aggregate function
/* 110 */ /* (input[0, bigint] + 1) */
/* 111 */ /* input[0, bigint] */
/* 112 */ long agg_value4 = agg_aggBuffer.getLong(0);
/* 113 */
/* 114 */ long agg_value3 = -1L;
/* 115 */ agg_value3 = agg_value4 + 1L;
/* 116 */ // update aggregate buffer
/* 117 */ agg_aggBuffer.setLong(0, agg_value3);
/* 118 */
/* 119 */ if (shouldStop()) return;
/* 120 */ }
/* 121 */
/* 122 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter);
/* 123 */ }
/* 124 */
/* 125 */ private void initRange(int idx) {
/* 126 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 127 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(1L);
/* 128 */ java.math.BigInteger numElement =
java.math.BigInteger.valueOf(100L);
/* 129 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 130 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 131 */
/* 132 */ java.math.BigInteger st =
index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 133 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) >
0) {
/* 134 */ range_number = Long.MAX_VALUE;
/* 135 */ } else if
(st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 136 */ range_number = Long.MIN_VALUE;
/* 137 */ } else {
/* 138 */ range_number = st.longValue();
/* 139 */ }
/* 140 */
/* 141 */ java.math.BigInteger end =
index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 142 */ .multiply(step).add(start);
/* 143 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) >
0) {
/* 144 */ range_partitionEnd = Long.MAX_VALUE;
/* 145 */ } else if
(end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 146 */ range_partitionEnd = Long.MIN_VALUE;
/* 147 */ } else {
/* 148 */ range_partitionEnd = end.longValue();
/* 149 */ }
/* 150 */
/* 151 */ range_metricValue.add((range_partitionEnd - range_number) / 1L);
/* 152 */ }
/* 153 */
/* 154 */ protected void processNext() throws java.io.IOException {
/* 155 */ /*** PRODUCE: TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
*/
/* 156 */
/* 157 */ if (!agg_initAgg) {
/* 158 */ agg_initAgg = true;
/* 159 */ agg_doAggregateWithKeys();
/* 160 */ }
/* 161 */
/* 162 */ // output the result
/* 163 */ while (agg_mapIter.next()) {
/* 164 */ wholestagecodegen_metricValue.add(1);
/* 165 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
/* 166 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue();
/* 167 */
/* 168 */ UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey,
agg_aggBuffer1);
/* 169 */
/* 170 */ /*** CONSUME: WholeStageCodegen */
/* 171 */
/* 172 */ append(agg_resultRow);
/* 173 */
/* 174 */ if (shouldStop()) return;
/* 175 */ }
/* 176 */
/* 177 */ agg_mapIter.close();
/* 178 */ if (agg_sorter == null) {
/* 179 */ agg_hashMap.free();
/* 180 */ }
/* 181 */ }
/* 182 */ }
== Subtree 2 / 3 ==
WholeStageCodegen
: +- Sort [id#0L ASC], true, 0
: +- INPUT
+- Exchange rangepartitioning(id#0L ASC, 200), None
+- WholeStageCodegen
: +- TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L])
: +- INPUT
+- Exchange hashpartitioning(id#0L, 200), None
+- WholeStageCodegen
: +- TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
: +- Range 0, 1, 1, 100, [id#0L]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ /** Codegened pipeline for:
/* 006 */ * Sort [id#0L ASC], true, 0
/* 007 */ +- INPUT
/* 008 */ */
/* 009 */ final class GeneratedIterator extends
org.apache.spark.sql.execution.BufferedRowIterator {
/* 010 */ private Object[] references;
/* 011 */ private boolean sort_needToSort;
/* 012 */ private org.apache.spark.sql.execution.Sort sort_plan;
/* 013 */ private org.apache.spark.sql.execution.UnsafeExternalRowSorter
sort_sorter;
/* 014 */ private org.apache.spark.executor.TaskMetrics sort_metrics;
/* 015 */ private scala.collection.Iterator<UnsafeRow> sort_sortedIter;
/* 016 */ private scala.collection.Iterator inputadapter_input;
/* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetric
sort_dataSize;
/* 018 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue
sort_metricValue;
/* 019 */ private org.apache.spark.sql.execution.metric.LongSQLMetric
sort_spillSize;
/* 020 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue
sort_metricValue1;
/* 021 */
/* 022 */ public GeneratedIterator(Object[] references) {
/* 023 */ this.references = references;
/* 024 */ }
/* 025 */
/* 026 */ public void init(scala.collection.Iterator inputs[]) {
/* 027 */ sort_needToSort = true;
/* 028 */ this.sort_plan = (org.apache.spark.sql.execution.Sort)
references[0];
/* 029 */ sort_sorter = sort_plan.createSorter();
/* 030 */ sort_metrics = org.apache.spark.TaskContext.get().taskMetrics();
/* 031 */
/* 032 */ inputadapter_input = inputs[0];
/* 033 */ this.sort_dataSize =
(org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 034 */ sort_metricValue =
(org.apache.spark.sql.execution.metric.LongSQLMetricValue)
sort_dataSize.localValue();
/* 035 */ this.sort_spillSize =
(org.apache.spark.sql.execution.metric.LongSQLMetric) references[2];
/* 036 */ sort_metricValue1 =
(org.apache.spark.sql.execution.metric.LongSQLMetricValue)
sort_spillSize.localValue();
/* 037 */ }
/* 038 */
/* 039 */ private void sort_addToSorter() throws java.io.IOException {
/* 040 */ /*** PRODUCE: INPUT */
/* 041 */
/* 042 */ while (inputadapter_input.hasNext()) {
/* 043 */ InternalRow inputadapter_row = (InternalRow)
inputadapter_input.next();
/* 044 */ /*** CONSUME: Sort [id#0L ASC], true, 0 */
/* 045 */
/* 046 */ sort_sorter.insertRow((UnsafeRow)inputadapter_row);
/* 047 */ if (shouldStop()) return;
/* 048 */ }
/* 049 */
/* 050 */ }
/* 051 */
/* 052 */ protected void processNext() throws java.io.IOException {
/* 053 */ /*** PRODUCE: Sort [id#0L ASC], true, 0 */
/* 054 */ if (sort_needToSort) {
/* 055 */ sort_addToSorter();
/* 056 */ Long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 057 */ sort_sortedIter = sort_sorter.sort();
/* 058 */ sort_metricValue.add(sort_sorter.getPeakMemoryUsage());
/* 059 */ sort_metricValue1.add(sort_metrics.memoryBytesSpilled() -
sort_spillSizeBefore);
/* 060 */
sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 061 */ sort_needToSort = false;
/* 062 */ }
/* 063 */
/* 064 */ while (sort_sortedIter.hasNext()) {
/* 065 */ UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 066 */
/* 067 */ /*** CONSUME: WholeStageCodegen */
/* 068 */
/* 069 */ append(sort_outputRow);
/* 070 */
/* 071 */ if (shouldStop()) return;
/* 072 */ }
/* 073 */ }
/* 074 */ }
== Subtree 3 / 3 ==
WholeStageCodegen
: +- TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L])
: +- INPUT
+- Exchange hashpartitioning(id#0L, 200), None
+- WholeStageCodegen
: +- TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
: +- Range 0, 1, 1, 100, [id#0L]
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ /** Codegened pipeline for:
/* 006 */ * TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L])
/* 007 */ +- INPUT
/* 008 */ */
/* 009 */ final class GeneratedIterator extends
org.apache.spark.sql.execution.BufferedRowIterator {
/* 010 */ private Object[] references;
/* 011 */ private boolean agg_initAgg;
/* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate
agg_plan;
/* 013 */ private
org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
/* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter
agg_sorter;
/* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter;
/* 016 */ private scala.collection.Iterator inputadapter_input;
/* 017 */ private UnsafeRow agg_result;
/* 018 */ private
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 019 */ private
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 020 */ private UnsafeRow agg_result1;
/* 021 */ private
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder1;
/* 022 */ private
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
agg_rowWriter1;
/* 023 */ private org.apache.spark.sql.execution.metric.LongSQLMetric
wholestagecodegen_numOutputRows;
/* 024 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue
wholestagecodegen_metricValue;
/* 025 */
/* 026 */ public GeneratedIterator(Object[] references) {
/* 027 */ this.references = references;
/* 028 */ }
/* 029 */
/* 030 */ public void init(scala.collection.Iterator inputs[]) {
/* 031 */ agg_initAgg = false;
/* 032 */ this.agg_plan =
(org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0];
/* 033 */ agg_hashMap = agg_plan.createHashMap();
/* 034 */
/* 035 */ inputadapter_input = inputs[0];
/* 036 */ agg_result = new UnsafeRow(1);
/* 037 */ this.agg_holder = new
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 038 */ this.agg_rowWriter = new
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder,
1);
/* 039 */ agg_result1 = new UnsafeRow(2);
/* 040 */ this.agg_holder1 = new
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 0);
/* 041 */ this.agg_rowWriter1 = new
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder1,
2);
/* 042 */ this.wholestagecodegen_numOutputRows =
(org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 043 */ wholestagecodegen_metricValue =
(org.apache.spark.sql.execution.metric.LongSQLMetricValue)
wholestagecodegen_numOutputRows.localValue();
/* 044 */ }
/* 045 */
/* 046 */ private void agg_doAggregateWithKeys() throws java.io.IOException {
/* 047 */ /*** PRODUCE: INPUT */
/* 048 */
/* 049 */ while (inputadapter_input.hasNext()) {
/* 050 */ InternalRow inputadapter_row = (InternalRow)
inputadapter_input.next();
/* 051 */ /*** CONSUME: TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */
/* 052 */ /* input[0, bigint] */
/* 053 */ long inputadapter_value = inputadapter_row.getLong(0);
/* 054 */ /* input[1, bigint] */
/* 055 */ long inputadapter_value1 = inputadapter_row.getLong(1);
/* 056 */
/* 057 */ // generate grouping key
/* 058 */ agg_rowWriter.write(0, inputadapter_value);
/* 059 */ /* hash(input[0, bigint], 42) */
/* 060 */ int agg_value1 = 42;
/* 061 */
/* 062 */ agg_value1 =
org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(inputadapter_value,
agg_value1);
/* 063 */ UnsafeRow agg_aggBuffer = null;
/* 064 */ if (true) {
/* 065 */ // try to get the buffer from hash map
/* 066 */ agg_aggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 067 */ }
/* 068 */ if (agg_aggBuffer == null) {
/* 069 */ if (agg_sorter == null) {
/* 070 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
/* 071 */ } else {
/* 072 */
agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
/* 073 */ }
/* 074 */
/* 075 */ // the hash map had be spilled, it should have enough memory
now,
/* 076 */ // try to allocate buffer again.
/* 077 */ agg_aggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 078 */ if (agg_aggBuffer == null) {
/* 079 */ // failed to allocate the first page
/* 080 */ throw new OutOfMemoryError("No enough memory for
aggregation");
/* 081 */ }
/* 082 */ }
/* 083 */
/* 084 */ // evaluate aggregate function
/* 085 */ /* (input[0, bigint] + input[2, bigint]) */
/* 086 */ /* input[0, bigint] */
/* 087 */ long agg_value4 = agg_aggBuffer.getLong(0);
/* 088 */
/* 089 */ long agg_value3 = -1L;
/* 090 */ agg_value3 = agg_value4 + inputadapter_value1;
/* 091 */ // update aggregate buffer
/* 092 */ agg_aggBuffer.setLong(0, agg_value3);
/* 093 */ if (shouldStop()) return;
/* 094 */ }
/* 095 */
/* 096 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter);
/* 097 */ }
/* 098 */
/* 099 */ protected void processNext() throws java.io.IOException {
/* 100 */ /*** PRODUCE: TungstenAggregate(key=[id#0L],
functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */
/* 101 */
/* 102 */ if (!agg_initAgg) {
/* 103 */ agg_initAgg = true;
/* 104 */ agg_doAggregateWithKeys();
/* 105 */ }
/* 106 */
/* 107 */ // output the result
/* 108 */ while (agg_mapIter.next()) {
/* 109 */ wholestagecodegen_metricValue.add(1);
/* 110 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
/* 111 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue();
/* 112 */
/* 113 */ /* input[0, bigint] */
/* 114 */ long agg_value6 = agg_aggKey.getLong(0);
/* 115 */ /* input[0, bigint] */
/* 116 */ long agg_value7 = agg_aggBuffer1.getLong(0);
/* 117 */
/* 118 */ /*** CONSUME: WholeStageCodegen */
/* 119 */
/* 120 */ agg_rowWriter1.write(0, agg_value6);
/* 121 */
/* 122 */ agg_rowWriter1.write(1, agg_value7);
/* 123 */ append(agg_result1);
/* 124 */
/* 125 */ if (shouldStop()) return;
/* 126 */ }
/* 127 */
/* 128 */ agg_mapIter.close();
/* 129 */ if (agg_sorter == null) {
/* 130 */ agg_hashMap.free();
/* 131 */ }
/* 132 */ }
/* 133 */ }
```
rxin
Author: Eric Liang <[email protected]>
Closes #12025 from ericl/spark-14227.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e58c4cb3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e58c4cb3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e58c4cb3
Branch: refs/heads/master
Commit: e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a
Parents: 838cb45
Author: Eric Liang <[email protected]>
Authored: Tue Mar 29 13:31:51 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Mar 29 13:31:51 2016 -0700
----------------------------------------------------------------------
.../spark/sql/execution/WholeStageCodegen.scala | 13 +++++-
.../spark/sql/execution/debug/package.scala | 46 +++++++++++++++++---
.../sql/execution/debug/DebuggingSuite.scala | 7 +++
3 files changed, 60 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e58c4cb3/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1b13c8f..da3ee46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -297,7 +297,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegen.PIPELINE_DURATION_METRIC))
- override def doExecute(): RDD[InternalRow] = {
+ /**
+ * Generates code for this subtree.
+ *
+ * @return the tuple of the codegen context and the actual generated source.
+ */
+ def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
@@ -334,6 +339,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
logDebug(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
+ (ctx, cleanedSource)
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ val (ctx, cleanedSource) = doCodeGen()
+ val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
http://git-wip-us.apache.org/repos/asf/spark/blob/e58c4cb3/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5e573b3..9916482 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
@@ -41,6 +41,13 @@ import org.apache.spark.sql.internal.SQLConf
*/
package object debug {
+ /** Helper function to evade the println() linter. */
+ private def debugPrint(msg: String): Unit = {
+ // scalastyle:off println
+ println(msg)
+ // scalastyle:on println
+ }
+
/**
* Augments [[SQLContext]] with debug methods.
*/
@@ -62,12 +69,41 @@ package object debug {
visited += new TreeNodeRef(s)
DebugNode(s)
}
- logDebug(s"Results returned: ${debugPlan.execute().count()}")
+ debugPrint(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
case d: DebugNode => d.dumpStats()
case _ =>
}
}
+
+ /**
+ * Prints to stdout all the generated code found in this plan (i.e. the output of each
+ * WholeStageCodegen subtree).
+ */
+ def debugCodegen(): Unit = {
+ debugPrint(debugCodegenString())
+ }
+
+ /** Visible for testing. */
+ def debugCodegenString(): String = {
+ val plan = query.queryExecution.executedPlan
+ val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]()
+ plan transform {
+ case s: WholeStageCodegen =>
+ codegenSubtrees += s
+ s
+ case s => s
+ }
+ var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n"
+ for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) {
+ output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n"
+ output += s
+ output += "\nGenerated code:\n"
+ val (_, source) = s.doCodeGen()
+ output += s"${CodeFormatter.format(source)}\n"
+ }
+ output
+ }
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
@@ -99,11 +135,11 @@ package object debug {
val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
def dumpStats(): Unit = {
- logDebug(s"== ${child.simpleString} ==")
- logDebug(s"Tuples output: ${tupleCount.value}")
+ debugPrint(s"== ${child.simpleString} ==")
+ debugPrint(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
- logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
+ debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e58c4cb3/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 2218947..979265e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -25,4 +25,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
test("DataFrame.debug()") {
testData.debug()
}
+
+ test("debugCodegen") {
+ val res = sqlContext.range(10).groupBy("id").count().debugCodegenString()
+ assert(res.contains("Subtree 1 / 2"))
+ assert(res.contains("Subtree 2 / 2"))
+ assert(res.contains("Object[]"))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]