This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 790779f  [SparkLoad]remove unncessary convert from dataframe to rdd (#4304)
790779f is described below

commit 790779fb6f93ca3152380a376215767b3c1eff38
Author: wangbo <506340...@qq.com>
AuthorDate: Thu Aug 13 23:37:38 2020 +0800

    [SparkLoad]remove unncessary convert from dataframe to rdd (#4304)
---
 .../org/apache/doris/load/loadv2/dpp/SparkDpp.java | 341 +++++++++++----------
 .../doris/load/loadv2/dpp/SparkRDDAggregator.java  |  75 +++--
 2 files changed, 208 insertions(+), 208 deletions(-)
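
At a glance, the change keeps the preprocessing pipeline in JavaPairRDD<List<Object>, Object[]> form
instead of converting between Dataset<Row> and RDD at every rollup level, which is what the removed
"TODO(wb) use rdd to avoid bitamp/hll serialize when calculate rollup" comment was asking for. The
standalone sketch below is illustrative only and not part of this commit: the class name, sample data
and the plain Long sum are hypothetical stand-ins for EncodeBaseAggregateTableFunction and
AggregateReduceFunction, but it shows the mapToPair/reduceByKey pattern the new processRDDAggregate
method is built on.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SparkSession;

    import scala.Tuple2;

    // Hypothetical demo class, not part of the Doris code base.
    public class PairRddAggregationSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .appName("pair-rdd-aggregation-sketch")
                    .master("local[*]")
                    .getOrCreate();
            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

            // (key columns, value columns) tuples, analogous to the
            // JavaPairRDD<List<Object>, Object[]> that SparkDpp now passes around.
            JavaPairRDD<List<Object>, Object[]> rows = jsc.parallelizePairs(Arrays.asList(
                    new Tuple2<>(Arrays.<Object>asList("1_0", "k1"), new Object[]{1L}),
                    new Tuple2<>(Arrays.<Object>asList("1_0", "k1"), new Object[]{2L}),
                    new Tuple2<>(Arrays.<Object>asList("1_1", "k2"), new Object[]{5L})));

            // "encode" step: in SparkDpp this is EncodeBaseAggregateTableFunction
            // calling SparkRDDAggregator.init(); here it is a pass-through.
            JavaPairRDD<List<Object>, Object[]> encoded =
                    rows.mapToPair(t -> new Tuple2<>(t._1(), t._2()));

            // aggregate the value columns per key without leaving the RDD world;
            // the real code plugs AggregateReduceFunction in here instead of a plain sum.
            JavaPairRDD<List<Object>, Object[]> aggregated = encoded.reduceByKey(
                    (a, b) -> new Object[]{(Long) a[0] + (Long) b[0]});

            aggregated.collect().forEach(t ->
                    System.out.println(t._1() + " -> " + Arrays.toString(t._2())));

            spark.stop();
        }
    }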

diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java
index 898bd9f..53c9111 100644
--- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java
+++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkDpp.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.load.loadv2.dpp;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.doris.common.SparkDppException;
 import org.apache.doris.load.loadv2.etl.EtlJobConfig;
 
@@ -38,10 +39,8 @@ import org.apache.spark.Partitioner;
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.ForeachPartitionFunction;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.sql.Column;
+import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -54,6 +53,7 @@ import org.apache.spark.sql.functions;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.LongAccumulator;
 
 import java.io.IOException;
@@ -70,12 +70,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 import org.apache.spark.util.SerializableConfiguration;
 import scala.Tuple2;
-import scala.collection.JavaConverters;
-import scala.collection.Seq;
 
 // This class is a Spark-based data preprocessing program,
 // which will make use of the distributed compute framework of spark to
@@ -125,81 +122,70 @@ public final class SparkDpp implements java.io.Serializable {
         this.serializableHadoopConf = new SerializableConfiguration(spark.sparkContext().hadoopConfiguration());
     }
 
-    private Dataset<Row> processRDDAggAndRepartition(Dataset<Row> dataframe, EtlJobConfig.EtlIndex currentIndexMeta) throws SparkDppException {
-        final boolean isDuplicateTable = !StringUtils.equalsIgnoreCase(currentIndexMeta.indexType, "AGGREGATE")
-                && !StringUtils.equalsIgnoreCase(currentIndexMeta.indexType, "UNIQUE");
-
-        // 1 make metadata for map/reduce
-        int keyLen = 0;
-        for (EtlJobConfig.EtlColumn etlColumn : currentIndexMeta.columns) {
-            keyLen = etlColumn.isKey ? keyLen + 1 : keyLen;
-        }
-
-        SparkRDDAggregator[] sparkRDDAggregators = new SparkRDDAggregator[currentIndexMeta.columns.size() - keyLen];
+    private JavaPairRDD<List<Object>, Object[]> processRDDAggregate(JavaPairRDD<List<Object>, Object[]> currentPairRDD, RollupTreeNode curNode,
+                                                      SparkRDDAggregator[] sparkRDDAggregators) throws SparkDppException {
+        final boolean isDuplicateTable = !StringUtils.equalsIgnoreCase(curNode.indexMeta.indexType, "AGGREGATE")
+                && !StringUtils.equalsIgnoreCase(curNode.indexMeta.indexType, "UNIQUE");
+
+        // Aggregate/UNIQUE table
+        if (!isDuplicateTable) {
+            // TODO(wb) set the reduce concurrency by statistic instead of hard code 200
+            int aggregateConcurrency = 200;
+
+            int idx = 0;
+            for (int i = 0 ; i < curNode.indexMeta.columns.size(); i++) {
+                if (!curNode.indexMeta.columns.get(i).isKey) {
+                    sparkRDDAggregators[idx] = SparkRDDAggregator.buildAggregator(curNode.indexMeta.columns.get(i));
+                    idx++;
+                }
+            }
 
-        for (int i = 0 ; i < currentIndexMeta.columns.size(); i++) {
-            if (!currentIndexMeta.columns.get(i).isKey && !isDuplicateTable) {
-                sparkRDDAggregators[i - keyLen] = SparkRDDAggregator.buildAggregator(currentIndexMeta.columns.get(i));
+            if (curNode.indexMeta.isBaseIndex) {
+                JavaPairRDD<List<Object>, Object[]> result = currentPairRDD.mapToPair(new EncodeBaseAggregateTableFunction(sparkRDDAggregators))
+                        .reduceByKey(new AggregateReduceFunction(sparkRDDAggregators), aggregateConcurrency);
+                return result;
+            } else {
+                JavaPairRDD<List<Object>, Object[]> result = currentPairRDD
+                        .mapToPair(new EncodeRollupAggregateTableFunction(
+                                getColumnIndexInParentRollup(curNode.keyColumnNames, curNode.valueColumnNames,
+                                        curNode.parent.keyColumnNames, curNode.parent.valueColumnNames)))
+                        .reduceByKey(new AggregateReduceFunction(sparkRDDAggregators), aggregateConcurrency);
+                return result;
+            }
+        // Duplicate Table
+        } else {
+            int idx = 0;
+            for (int i = 0; i < curNode.indexMeta.columns.size(); i++) {
+                if (!curNode.indexMeta.columns.get(i).isKey) {
+                    // duplicate table doesn't need aggregator
+                    // init a aggregator here just for keeping interface compatibility when writing data to HDFS
+                    sparkRDDAggregators[idx] = new DefaultSparkRDDAggregator();
+                    idx++;
+                }
+            }
+            if (curNode.indexMeta.isBaseIndex) {
+                return currentPairRDD;
+            } else {
+                return currentPairRDD.mapToPair(new EncodeRollupAggregateTableFunction(
+                        getColumnIndexInParentRollup(curNode.keyColumnNames, curNode.valueColumnNames,
+                        curNode.parent.keyColumnNames, curNode.parent.valueColumnNames)));
             }
         }
-
-        PairFunction<Row, List<Object>, Object[]> encodePairFunction = isDuplicateTable ?
-                // add 1 to include bucketId
-                new EncodeDuplicateTableFunction(keyLen + 1, currentIndexMeta.columns.size() - keyLen)
-                : new EncodeAggregateTableFunction(sparkRDDAggregators, keyLen + 1);
-
-        // 2 convert dataframe to rdd and  encode key and value
-        // TODO(wb) use rdd to avoid bitamp/hll serialize when calculate rollup
-        JavaPairRDD<List<Object>, Object[]> currentRollupRDD = dataframe.toJavaRDD().mapToPair(encodePairFunction);
-
-        // 3 do aggregate
-        // TODO(wb) set the reduce concurrency by statistic instead of hard code 200
-        int aggregateConcurrency = 200;
-        JavaPairRDD<List<Object>, Object[]> reduceResultRDD = isDuplicateTable ? currentRollupRDD
-                : currentRollupRDD.reduceByKey(new AggregateReduceFunction(sparkRDDAggregators), aggregateConcurrency);
-
-        // 4 repartition and finalize value column
-        JavaRDD<Row> finalRDD = reduceResultRDD
-                .repartitionAndSortWithinPartitions(new BucketPartitioner(bucketKeyMap), new BucketComparator())
-                .map(record -> {
-                    List<Object> keys = record._1;
-                    Object[] values = record._2;
-                    int size = keys.size() + values.length;
-                    Object[] result = new Object[size];
-
-                    for (int i = 0; i < keys.size(); i++) {
-                        result[i] = keys.get(i);
-                    }
-
-                    for (int i = keys.size(); i < size; i++) {
-                        int valueIdx = i - keys.size();
-                    result[i] = isDuplicateTable ? values[valueIdx] : sparkRDDAggregators[valueIdx].finalize(values[valueIdx]);
-                    }
-
-                    return RowFactory.create(result);
-                });
-
-        // 4 convert to dataframe
-        StructType tableSchemaWithBucketId = DppUtils.createDstTableSchema(currentIndexMeta.columns, true, true);
-        dataframe = spark.createDataFrame(finalRDD, tableSchemaWithBucketId);
-        return dataframe;
-
     }
 
     // write data to parquet file by using writing the parquet scheme of spark.
-    private void writePartitionedAndSortedDataframeToParquet(Dataset<Row> dataframe,
-                                                             String pathPattern,
+    private void writeRepartitionAndSortedRDDToParquet(JavaPairRDD<List<Object>, Object[]> resultRDD,
+                                                            String pathPattern,
                                                              long tableId,
-                                                             EtlJobConfig.EtlIndex indexMeta) throws SparkDppException {
-        StructType outputSchema = dataframe.schema();
-        StructType dstSchema = DataTypes.createStructType(
-                Arrays.asList(outputSchema.fields()).stream()
-                        .filter(field -> !field.name().equalsIgnoreCase(DppUtils.BUCKET_ID))
-                        .collect(Collectors.toList()));
+                                                             EtlJobConfig.EtlIndex indexMeta,
+                                                             SparkRDDAggregator[] sparkRDDAggregators) throws SparkDppException {
+        StructType dstSchema = DppUtils.createDstTableSchema(indexMeta.columns, false, true);
         ExpressionEncoder encoder = RowEncoder.apply(dstSchema);
-        dataframe.foreachPartition(new ForeachPartitionFunction<Row>() {
+
+        resultRDD.repartitionAndSortWithinPartitions(new BucketPartitioner(bucketKeyMap), new BucketComparator())
+        .foreachPartition(new VoidFunction<Iterator<Tuple2<List<Object>,Object[]>>>() {
             @Override
-            public void call(Iterator<Row> t) throws Exception {
+            public void call(Iterator<Tuple2<List<Object>, Object[]>> t) throws Exception {
                 // write the data to dst file
                 Configuration conf = new Configuration(serializableHadoopConf.value());
                 FileSystem fs = FileSystem.get(URI.create(etlJobConfig.outputPath), conf);
@@ -211,19 +197,24 @@ public final class SparkDpp implements java.io.Serializable {
                 String tmpPath = "";
 
                 while (t.hasNext()) {
-                    Row row = t.next();
-                    if (row.length() <= 1) {
-                        LOG.warn("invalid row:" + row);
+                    Tuple2<List<Object>, Object[]> pair = t.next();
+                    List<Object> keyColumns = pair._1();
+                    Object[] valueColumns = pair._2();
+                    if ((keyColumns.size() + valueColumns.length) <= 1) {
+                        LOG.warn("invalid row:" + pair);
                         continue;
                     }
 
 
-                    String curBucketKey = row.getString(0);
+                    String curBucketKey = keyColumns.get(0).toString();
                     List<Object> columnObjects = new ArrayList<>();
-                    for (int i = 1; i < row.length(); ++i) {
-                        Object columnValue = row.get(i);
-                        columnObjects.add(columnValue);
+                    for (int i = 1; i < keyColumns.size(); ++i) {
+                        columnObjects.add(keyColumns.get(i));
+                    }
+                    for (int i = 0; i < valueColumns.length; ++i) {
+                        columnObjects.add(sparkRDDAggregators[i].finalize(valueColumns[i]));
                     }
+
                     Row rowWithoutBucketKey = RowFactory.create(columnObjects.toArray());
                     // if the bucket key is new, it will belong to a new tablet
                     if (lastBucketKey == null || !curBucketKey.equals(lastBucketKey)) {
@@ -276,22 +267,21 @@ public final class SparkDpp implements java.io.Serializable {
                         throw ioe;
                     }
                 }
+
             }
-        });
-    }
+        });}
 
     // TODO(wb) one shuffle to calculate the rollup in the same level
     private void processRollupTree(RollupTreeNode rootNode,
-                                   Dataset<Row> rootDataframe,
-                                   long tableId, EtlJobConfig.EtlTable tableMeta,
-                                   EtlJobConfig.EtlIndex baseIndex) throws SparkDppException {
+                                   JavaPairRDD<List<Object>, Object[]> rootRDD,
+                                   long tableId, EtlJobConfig.EtlIndex baseIndex) throws SparkDppException {
         Queue<RollupTreeNode> nodeQueue = new LinkedList<>();
         nodeQueue.offer(rootNode);
         int currentLevel = 0;
         // level travel the tree
-        Map<Long, Dataset<Row>> parentDataframeMap = new HashMap<>();
-        parentDataframeMap.put(baseIndex.indexId, rootDataframe);
-        Map<Long, Dataset<Row>> childrenDataframeMap = new HashMap<>();
+        Map<Long, JavaPairRDD<List<Object>, Object[]>> parentRDDMap = new HashMap<>();
+        parentRDDMap.put(baseIndex.indexId, rootRDD);
+        Map<Long, JavaPairRDD<List<Object>, Object[]>> childrenRDDMap = new HashMap<>();
         String pathPattern = etlJobConfig.outputPath + "/" + etlJobConfig.outputFilePattern;
         while (!nodeQueue.isEmpty()) {
             RollupTreeNode curNode = nodeQueue.poll();
@@ -301,16 +291,16 @@ public final class SparkDpp implements java.io.Serializable {
                     nodeQueue.offer(child);
                 }
             }
-            Dataset<Row> curDataFrame = null;
+            JavaPairRDD<List<Object>, Object[]> curRDD = null;
             // column select for rollup
             if (curNode.level != currentLevel) {
-                for (Dataset<Row> dataframe : parentDataframeMap.values()) {
-                    dataframe.unpersist();
+                for (JavaPairRDD<List<Object>, Object[]> rdd : parentRDDMap.values()) {
+                    rdd.unpersist();
                 }
                 currentLevel = curNode.level;
-                parentDataframeMap.clear();
-                parentDataframeMap = childrenDataframeMap;
-                childrenDataframeMap = new HashMap<>();
+                parentRDDMap.clear();
+                parentRDDMap = childrenRDDMap;
+                childrenRDDMap = new HashMap<>();
             }
 
             long parentIndexId = baseIndex.indexId;
@@ -318,37 +308,59 @@ public final class SparkDpp implements 
java.io.Serializable {
                 parentIndexId = curNode.parent.indexId;
             }
 
-            Dataset<Row> parentDataframe = parentDataframeMap.get(parentIndexId);
-            List<Column> columns = new ArrayList<>();
-            List<Column> keyColumns = new ArrayList<>();
-            Column bucketIdColumn = new Column(DppUtils.BUCKET_ID);
-            keyColumns.add(bucketIdColumn);
-            columns.add(bucketIdColumn);
-            for (String keyName : curNode.keyColumnNames) {
-                columns.add(new Column(keyName));
-                keyColumns.add(new Column(keyName));
-            }
-            for (String valueName : curNode.valueColumnNames) {
-                columns.add(new Column(valueName));
-            }
-            Seq<Column> columnSeq = JavaConverters.asScalaIteratorConverter(columns.iterator()).asScala().toSeq();
-            curDataFrame = parentDataframe.select(columnSeq);
-            // aggregate and repartition
-            curDataFrame = processRDDAggAndRepartition(curDataFrame, curNode.indexMeta);
+            JavaPairRDD<List<Object>, Object[]> parentRDD = parentRDDMap.get(parentIndexId);
 
-            childrenDataframeMap.put(curNode.indexId, curDataFrame);
+            // aggregate
+            SparkRDDAggregator[] sparkRDDAggregators = new SparkRDDAggregator[curNode.valueColumnNames.size()];
+            curRDD = processRDDAggregate(parentRDD, curNode, sparkRDDAggregators);
+
+            childrenRDDMap.put(curNode.indexId, curRDD);
 
             if (curNode.children != null && curNode.children.size() > 1) {
                 // if the children number larger than 1, persist the dataframe for performance
-                curDataFrame.persist();
+                curRDD.persist(StorageLevel.MEMORY_AND_DISK());
+            }
+            // repartition and write to hdfs
+            writeRepartitionAndSortedRDDToParquet(curRDD, pathPattern, tableId, curNode.indexMeta, sparkRDDAggregators);
+        }
+    }
+
+    // get column index map from parent rollup to child rollup
+    // not consider bucketId here
+    private Pair<Integer[], Integer[]> getColumnIndexInParentRollup(List<String> childRollupKeyColumns, List<String> childRollupValueColumns,
+                                                                            List<String> parentRollupKeyColumns, List<String> parentRollupValueColumns) throws SparkDppException {
+        List<Integer> keyMap = new ArrayList<>();
+        List<Integer> valueMap = new ArrayList<>();
+        // find column index in parent rollup schema
+        for (int i = 0; i < childRollupKeyColumns.size(); i++) {
+            for (int j = 0; j < parentRollupKeyColumns.size(); j++) {
+                if (StringUtils.equalsIgnoreCase(childRollupKeyColumns.get(i), parentRollupKeyColumns.get(j))) {
+                    keyMap.add(j);
+                    break;
+                }
+            }
+        }
+
+        for (int i = 0; i < childRollupValueColumns.size(); i++) {
+            for (int j = 0; j < parentRollupValueColumns.size(); j++) {
+                if (StringUtils.equalsIgnoreCase(childRollupValueColumns.get(i), parentRollupValueColumns.get(j))) {
+                    valueMap.add(j);
+                    break;
+                }
             }
-            writePartitionedAndSortedDataframeToParquet(curDataFrame, pathPattern, tableId, curNode.indexMeta);
         }
+
+        if (keyMap.size() != childRollupKeyColumns.size() || valueMap.size() != childRollupValueColumns.size()) {
+            throw new SparkDppException(String.format("column map index from child to parent has error, key size src: %s, dst: %s; value size src: %s, dst: %s",
+                    childRollupKeyColumns.size(), keyMap.size(), childRollupValueColumns.size(), valueMap.size()));
+        }
+
+        return Pair.of(keyMap.toArray(new Integer[keyMap.size()]), valueMap.toArray(new Integer[valueMap.size()]));
     }
 
     // repartition dataframe by partitionid_bucketid
     // so data in the same bucket will be consecutive.
-    private Dataset<Row> repartitionDataframeByBucketId(SparkSession spark, Dataset<Row> dataframe,
+    private JavaPairRDD<List<Object>, Object[]> fillTupleWithPartitionColumn(SparkSession spark, Dataset<Row> dataframe,
                                                         EtlJobConfig.EtlPartitionInfo partitionInfo,
                                                         List<Integer> partitionKeyIndex,
                                                         List<Class> partitionKeySchema,
@@ -374,59 +386,48 @@ public final class SparkDpp implements java.io.Serializable {
         }
         // use PairFlatMapFunction instead of PairMapFunction because the there will be
         // 0 or 1 output row for 1 input row
-        JavaPairRDD<String, DppColumns> pairRDD = dataframe.javaRDD().flatMapToPair(
-                new PairFlatMapFunction<Row, String, DppColumns>() {
-                    @Override
-                    public Iterator<Tuple2<String, DppColumns>> call(Row row) {
-                        List<Object> columns = new ArrayList<>();
-                        List<Object> keyColumns = new ArrayList<>();
-                        for (String columnName : keyColumnNames) {
-                            Object columnObject = row.get(row.fieldIndex(columnName));
-                            columns.add(columnObject);
-                            keyColumns.add(columnObject);
-                        }
+        JavaPairRDD<List<Object>, Object[]> resultPairRDD = dataframe.toJavaRDD().flatMapToPair(new PairFlatMapFunction<Row, List<Object>, Object[]>() {
+            @Override
+            public Iterator<Tuple2<List<Object>, Object[]>> call(Row row) throws Exception {
+                List<Object> keyColumns = new ArrayList<>();
+                Object[] valueColumns = new Object[valueColumnNames.size()];
+                for (String columnName : keyColumnNames) {
+                    Object columnObject = row.get(row.fieldIndex(columnName));
+                    keyColumns.add(columnObject);
+                }
 
-                        for (String columnName : valueColumnNames) {
-                            columns.add(row.get(row.fieldIndex(columnName)));
-                        }
-                        DppColumns dppColumns = new DppColumns(columns);
-                        DppColumns key = new DppColumns(keyColumns);
-                        List<Tuple2<String, DppColumns>> result = new ArrayList<>();
-                        int pid = partitioner.getPartition(key);
-                        if (!validPartitionIndex.contains(pid)) {
-                            LOG.warn("invalid partition for row:" + row + ", pid:" + pid);
-                            abnormalRowAcc.add(1);
-                            LOG.info("abnormalRowAcc:" + abnormalRowAcc);
-                            if (abnormalRowAcc.value() < 5) {
-                                LOG.info("add row to invalidRows:" + row.toString());
-                                invalidRows.add(row.toString());
-                                LOG.info("invalid rows contents:" + invalidRows.value());
-                            }
-                        } else {
-                            long hashValue = DppUtils.getHashValue(row, distributeColumns, dstTableSchema);
-                            int bucketId = (int) ((hashValue & 0xffffffff) % partitionInfo.partitions.get(pid).bucketNum);
-                            long partitionId = partitionInfo.partitions.get(pid).partitionId;
-                            // bucketKey is partitionId_bucketId
-                            String bucketKey = partitionId + "_" + bucketId;
-                            Tuple2<String, DppColumns> newTuple = new Tuple2<String, DppColumns>(bucketKey, dppColumns);
-                            result.add(newTuple);
-                        }
-                        return result.iterator();
+                for (int i = 0; i < valueColumnNames.size(); i++) {
+                    valueColumns[i] = row.get(row.fieldIndex(valueColumnNames.get(i)));
+                }
+
+                DppColumns key = new DppColumns(keyColumns);
+                int pid = partitioner.getPartition(key);
+                List<Tuple2<List<Object>, Object[]>> result = new ArrayList<>();
+                if (!validPartitionIndex.contains(pid)) {
+                    LOG.warn("invalid partition for row:" + row + ", pid:" + pid);
+                    abnormalRowAcc.add(1);
+                    LOG.info("abnormalRowAcc:" + abnormalRowAcc);
+                    if (abnormalRowAcc.value() < 5) {
+                        LOG.info("add row to invalidRows:" + row.toString());
+                        invalidRows.add(row.toString());
+                        LOG.info("invalid rows contents:" + invalidRows.value());
                     }
-                });
-        // TODO(wb): using rdd instead of dataframe from here
-        JavaRDD<Row> resultRdd = pairRDD.map(record -> {
-                    String bucketKey = record._1;
-                    List<Object> row = new ArrayList<>();
-                    // bucketKey as the first key
-                    row.add(bucketKey);
-                    row.addAll(record._2.columns);
-                    return RowFactory.create(row.toArray());
+                } else {
+                    long hashValue = DppUtils.getHashValue(row, distributeColumns, dstTableSchema);
+                    int bucketId = (int) ((hashValue & 0xffffffff) % partitionInfo.partitions.get(pid).bucketNum);
+                    long partitionId = partitionInfo.partitions.get(pid).partitionId;
+                    // bucketKey is partitionId_bucketId
+                    String bucketKey = partitionId + "_" + bucketId;
+
+                    List<Object> tuple = new ArrayList<>();
+                    tuple.add(bucketKey);
+                    tuple.addAll(keyColumns);
+                    result.add(new Tuple2<>(tuple, valueColumns));
                 }
-        );
+                return result.iterator();
+            }
+        });
 
-        StructType tableSchemaWithBucketId = DppUtils.createDstTableSchema(baseIndex.columns, true, false);
-        dataframe = spark.createDataFrame(resultRdd, tableSchemaWithBucketId);
         // use bucket number as the parallel number
         int reduceNum = 0;
         for (EtlJobConfig.EtlPartition partition : partitionInfo.partitions) {
@@ -439,7 +440,7 @@ public final class SparkDpp implements java.io.Serializable {
         // print to system.out for easy to find log info
         System.out.println("print bucket key map:" + bucketKeyMap.toString());
 
-        return dataframe;
+        return resultPairRDD;
     }
 
     // do the etl process
@@ -815,7 +816,7 @@ public final class SparkDpp implements java.io.Serializable {
                 RollupTreeNode rootNode = rollupTreeParser.build(etlTable);
                 LOG.info("Start to process rollup tree:" + rootNode);
 
-                Dataset<Row> tableDataframe = null;
+                JavaPairRDD<List<Object>, Object[]> tablePairRDD = null;
                 for (EtlJobConfig.EtlFileGroup fileGroup : etlTable.fileGroups) {
                     List<String> filePaths = fileGroup.filePaths;
                     Dataset<Row> fileGroupDataframe = null;
@@ -838,18 +839,18 @@ public final class SparkDpp implements java.io.Serializable {
                         unselectedRowAcc.add(currentSize - originalSize);
                     }
 
-                    fileGroupDataframe = repartitionDataframeByBucketId(spark, fileGroupDataframe,
+                    JavaPairRDD<List<Object>, Object[]> ret = fillTupleWithPartitionColumn(spark, fileGroupDataframe,
                             partitionInfo, partitionKeyIndex,
                             partitionKeySchema, partitionRangeKeys,
                             keyColumnNames, valueColumnNames,
                             dstTableSchema, baseIndex, fileGroup.partitions);
-                    if (tableDataframe == null) {
-                        tableDataframe = fileGroupDataframe;
+                    if (tablePairRDD == null) {
+                        tablePairRDD = ret;
                     } else {
-                        tableDataframe.union(fileGroupDataframe);
+                        tablePairRDD.union(ret);
                     }
                 }
-                processRollupTree(rootNode, tableDataframe, tableId, etlTable, baseIndex);
+                processRollupTree(rootNode, tablePairRDD, tableId, baseIndex);
             }
             spark.stop();
         } catch (Exception exception) {
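
The write path above (writeRepartitionAndSortedRDDToParquet) now repartitions and sorts the pair RDD by
bucket key and consumes each partition with a VoidFunction, instead of going back through a DataFrame
and ForeachPartitionFunction. A minimal sketch of that Spark Java API pattern follows; it is not Doris
code: the HashPartitioner, sample records and println are hypothetical stand-ins for the real
BucketPartitioner, BucketComparator and per-bucket parquet writer.

    import java.util.Arrays;
    import java.util.Iterator;

    import org.apache.spark.HashPartitioner;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.sql.SparkSession;

    import scala.Tuple2;

    // Hypothetical demo class, not part of the Doris code base.
    public class RepartitionAndWriteSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .appName("repartition-and-write-sketch")
                    .master("local[2]")
                    .getOrCreate();
            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

            // key = "partitionId_bucketId", value = the remaining columns
            JavaPairRDD<String, Object[]> rows = jsc.parallelizePairs(Arrays.asList(
                    new Tuple2<>("1_0", new Object[]{"a", 1L}),
                    new Tuple2<>("1_1", new Object[]{"b", 2L}),
                    new Tuple2<>("1_0", new Object[]{"c", 3L})));

            // group records of the same bucket into the same partition and sort them,
            // then stream each partition once; SparkDpp opens a parquet writer per
            // bucket key at this point, while the sketch just prints the sorted records.
            rows.repartitionAndSortWithinPartitions(new HashPartitioner(2))
                .foreachPartition(new VoidFunction<Iterator<Tuple2<String, Object[]>>>() {
                    @Override
                    public void call(Iterator<Tuple2<String, Object[]>> it) {
                        while (it.hasNext()) {
                            Tuple2<String, Object[]> t = it.next();
                            System.out.println(t._1() + " -> " + Arrays.toString(t._2()));
                        }
                    }
                });

            spark.stop();
        }
    }
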
diff --git a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java
index bd6f0db..4682fdc 100644
--- a/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java
+++ b/fe/spark-dpp/src/main/java/org/apache/doris/load/loadv2/dpp/SparkRDDAggregator.java
@@ -18,12 +18,12 @@
 package org.apache.doris.load.loadv2.dpp;
 
 import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.doris.common.SparkDppException;
 import org.apache.doris.load.loadv2.etl.EtlJobConfig;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.sql.Row;
 import scala.Tuple2;
 
 import java.io.ByteArrayInputStream;
@@ -125,60 +125,59 @@ public abstract class SparkRDDAggregator<T> implements Serializable {
 
 }
 
-class EncodeDuplicateTableFunction extends EncodeAggregateTableFunction {
+// just used for duplicate table, default logic is enough
+class DefaultSparkRDDAggregator extends SparkRDDAggregator {
 
-    private int valueLen;
-
-    public EncodeDuplicateTableFunction(int keyLen, int valueLen) {
-        super(keyLen);
-        this.valueLen = valueLen;
+    @Override
+    Object update(Object v1, Object v2) {
+        return null;
     }
+}
 
-    @Override
-    public Tuple2<List<Object>, Object[]> call(Row row) throws Exception {
-        List<Object> keys = new ArrayList(keyLen);
-        Object[] values = new Object[valueLen];
+// just encode value column,used for base rollup
+class EncodeBaseAggregateTableFunction implements PairFunction<Tuple2<List<Object>, Object[]>, List<Object>, Object[]> {
 
-        for (int i = 0; i < keyLen; i++) {
-            keys.add(row.get(i));
-        }
+    private SparkRDDAggregator[] valueAggregators;
 
-        for (int i = keyLen; i < row.length(); i++) {
-            values[i - keyLen] = row.get(i);
-        }
+    public EncodeBaseAggregateTableFunction(SparkRDDAggregator[] valueAggregators) {
+        this.valueAggregators = valueAggregators;
+    }
 
-        return new Tuple2<>(keys, values);
+
+    @Override
+    public Tuple2<List<Object>, Object[]> call(Tuple2<List<Object>, Object[]> srcPair) throws Exception {
+        for (int i = 0; i < srcPair._2().length; i++) {
+            srcPair._2()[i] = valueAggregators[i].init(srcPair._2()[i]);
+        }
+        return srcPair;
     }
 }
 
-class EncodeAggregateTableFunction implements PairFunction<Row, List<Object>, Object[]> {
+// just map column from parent rollup index to child rollup index,used for child rollup
+class EncodeRollupAggregateTableFunction implements PairFunction<Tuple2<List<Object>, Object[]>, List<Object>, Object[]> {
 
-    private SparkRDDAggregator[] valueAggregators;
-    // include bucket id
-    protected int keyLen;
+    Pair<Integer[], Integer[]> columnIndexInParentRollup;
 
-    public EncodeAggregateTableFunction(int keyLen) {
-        this.keyLen = keyLen;
+    public EncodeRollupAggregateTableFunction(Pair<Integer[], Integer[]> columnIndexInParentRollup) {
+        this.columnIndexInParentRollup = columnIndexInParentRollup;
     }
 
-    public EncodeAggregateTableFunction(SparkRDDAggregator[] valueAggregators, int keyLen) {
-        this.valueAggregators = valueAggregators;
-        this.keyLen = keyLen;
-    }
-
-    // TODO(wb): use a custom class as key to instead of List to save space
     @Override
-    public Tuple2<List<Object>, Object[]> call(Row row) throws Exception {
-        List<Object> keys = new ArrayList(keyLen);
-        Object[] values = new Object[valueAggregators.length];
+    public Tuple2<List<Object>, Object[]> call(Tuple2<List<Object>, Object[]> parentRollupKeyValuePair) throws Exception {
+        Integer[] keyColumnIndexMap = columnIndexInParentRollup.getKey();
+        Integer[] valueColumnIndexMap = columnIndexInParentRollup.getValue();
+
+        List<Object> keys = new ArrayList();
+        Object[] values = new Object[valueColumnIndexMap.length];
 
-        for (int i = 0; i < keyLen; i++) {
-            keys.add(row.get(i));
+        // deal bucket_id column
+        keys.add(parentRollupKeyValuePair._1().get(0));
+        for (int i = 0; i < keyColumnIndexMap.length; i++) {
+            keys.add(parentRollupKeyValuePair._1().get(keyColumnIndexMap[i] + 1));
         }
 
-        for (int i = keyLen; i < row.size(); i++) {
-            int valueIdx = i - keyLen;
-            values[valueIdx] = valueAggregators[valueIdx].init(row.get(i));
+        for (int i = 0; i < valueColumnIndexMap.length; i++) {
+            values[i] = parentRollupKeyValuePair._2()[valueColumnIndexMap[i]];
         }
         return new Tuple2<>(keys, values);
     }
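
On the SparkRDDAggregator.java side, EncodeRollupAggregateTableFunction projects a parent rollup's
(keys, values) tuple onto a child rollup's column order using the Pair of index arrays produced by
getColumnIndexInParentRollup, keeping the bucket key in front. A hedged, self-contained sketch of that
projection idea follows; all names and data are hypothetical, and it uses case-sensitive indexOf for
brevity where the real code matches column names with equalsIgnoreCase.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.commons.lang3.tuple.Pair;

    // Hypothetical demo class, not part of the Doris code base.
    public class RollupIndexMappingSketch {

        // build "child column position -> parent column position" maps
        static Pair<Integer[], Integer[]> buildIndexMap(List<String> childKeys, List<String> childValues,
                                                        List<String> parentKeys, List<String> parentValues) {
            Integer[] keyMap = childKeys.stream().map(parentKeys::indexOf).toArray(Integer[]::new);
            Integer[] valueMap = childValues.stream().map(parentValues::indexOf).toArray(Integer[]::new);
            return Pair.of(keyMap, valueMap);
        }

        public static void main(String[] args) {
            // parent rollup has keys (k1, k2) and values (v1, v2); the child keeps only k2 and v2
            Pair<Integer[], Integer[]> map = buildIndexMap(
                    Arrays.asList("k2"), Arrays.asList("v2"),
                    Arrays.asList("k1", "k2"), Arrays.asList("v1", "v2"));

            // a parent tuple: the key list holds the bucket key first, then k1, k2
            List<Object> parentKeys = Arrays.asList("1_0", "a", "b");
            Object[] parentValues = new Object[]{10L, 20L};

            // project onto the child rollup's layout, keeping the bucket key in front
            List<Object> childKeys = new ArrayList<>();
            childKeys.add(parentKeys.get(0));
            for (Integer idx : map.getKey()) {
                childKeys.add(parentKeys.get(idx + 1)); // +1 skips the bucket key slot
            }
            Object[] childValues = new Object[map.getValue().length];
            for (int i = 0; i < childValues.length; i++) {
                childValues[i] = parentValues[map.getValue()[i]];
            }

            System.out.println(childKeys + " -> " + Arrays.toString(childValues));
            // prints: [1_0, b] -> [20]
        }
    }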


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
