This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris-spark-connector.git
commit 921a2caa836f1621557e3d56d67689e05f64786e Author: Youngwb <yangwenbo_mail...@163.com> AuthorDate: Fri Jan 10 14:11:15 2020 +0800 Convert from arrow to rowbatch (#2723) For #2722 In our test environment, the Doris cluster used 1 FE and 7 BEs (32C+128G). When using the spark-doris connector to query a table containing 67 columns, it took about 1 hour for the query to return 69 million rows of data. After the improvement, the same query took 2.5 minutes, and the query performance was significantly improved. --- .../apache/doris/spark/serialization/RowBatch.java | 136 +++++++++++++++------ 1 file changed, 96 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java index 668e72d..d710fbb 100644 --- a/src/main/java/org/apache/doris/spark/serialization/RowBatch.java +++ b/src/main/java/org/apache/doris/spark/serialization/RowBatch.java @@ -73,6 +73,7 @@ public class RowBatch { private int offsetInOneBatch = 0; private int rowCountInOneBatch = 0; private int readRowCount = 0; + private List<Row> rowBatch = new ArrayList<>(); private final ArrowStreamReader arrowStreamReader; private final VectorSchemaRoot root; private List<FieldVector> fieldVectors; @@ -115,6 +116,11 @@ public class RowBatch { } offsetInOneBatch = 0; rowCountInOneBatch = root.getRowCount(); + // init the rowBatch + for (int i = 0; i < rowCountInOneBatch; ++i) { + rowBatch.add(new Row(fieldVectors.size())); + } + convertArrowToRowBatch(); return true; } } catch (IOException e) { @@ -128,98 +134,135 @@ public class RowBatch { return false; } - public List<Object> next() throws DorisException { + private void addValueToRow(int rowIndex, Object obj) { + if (rowIndex > rowCountInOneBatch) { + String errMsg = "Get row offset: " + rowIndex + " larger than row size: " + + rowCountInOneBatch; + logger.error(errMsg); + throw new NoSuchElementException(errMsg); + } + 
rowBatch.get(rowIndex).put(obj); + } + + public void convertArrowToRowBatch() throws DorisException { try { - if (!hasNext()) { - String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch; - logger.error(errMsg); - throw new NoSuchElementException(errMsg); - } - Row row = new Row(fieldVectors.size()); - for (int j = 0; j < fieldVectors.size(); j++) { - FieldVector curFieldVector = fieldVectors.get(j); + for (int col = 0; col < fieldVectors.size(); col++) { + FieldVector curFieldVector = fieldVectors.get(col); Types.MinorType mt = curFieldVector.getMinorType(); - if (curFieldVector.isNull(offsetInOneBatch)) { - row.put(null); - continue; - } - final String currentType = schema.get(j).getType(); + final String currentType = schema.get(col).getType(); switch (currentType) { case "NULL_TYPE": - row.put(null); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + addValueToRow(rowIndex, null); + } break; case "BOOLEAN": Preconditions.checkArgument(mt.equals(Types.MinorType.BIT), typeMismatchMessage(currentType, mt)); BitVector bitVector = (BitVector) curFieldVector; - int bit = bitVector.get(offsetInOneBatch); - row.put(bit != 0); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0; + addValueToRow(rowIndex, fieldValue); + } break; case "TINYINT": Preconditions.checkArgument(mt.equals(Types.MinorType.TINYINT), typeMismatchMessage(currentType, mt)); TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector; - row.put(tinyIntVector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = tinyIntVector.isNull(rowIndex) ? 
null : tinyIntVector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "SMALLINT": Preconditions.checkArgument(mt.equals(Types.MinorType.SMALLINT), typeMismatchMessage(currentType, mt)); SmallIntVector smallIntVector = (SmallIntVector) curFieldVector; - row.put(smallIntVector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "INT": Preconditions.checkArgument(mt.equals(Types.MinorType.INT), typeMismatchMessage(currentType, mt)); IntVector intVector = (IntVector) curFieldVector; - row.put(intVector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "BIGINT": Preconditions.checkArgument(mt.equals(Types.MinorType.BIGINT), typeMismatchMessage(currentType, mt)); BigIntVector bigIntVector = (BigIntVector) curFieldVector; - row.put(bigIntVector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "FLOAT": Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4), typeMismatchMessage(currentType, mt)); Float4Vector float4Vector = (Float4Vector) curFieldVector; - row.put(float4Vector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = float4Vector.isNull(rowIndex) ? 
null : float4Vector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "TIME": case "DOUBLE": Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT8), typeMismatchMessage(currentType, mt)); Float8Vector float8Vector = (Float8Vector) curFieldVector; - row.put(float8Vector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "BINARY": Preconditions.checkArgument(mt.equals(Types.MinorType.VARBINARY), typeMismatchMessage(currentType, mt)); VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector; - row.put(varBinaryVector.get(offsetInOneBatch)); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex); + addValueToRow(rowIndex, fieldValue); + } break; case "DECIMAL": Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR), typeMismatchMessage(currentType, mt)); VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector; - String decimalValue = new String(varCharVectorForDecimal.get(offsetInOneBatch)); - Decimal decimal = new Decimal(); - try { - decimal.set(new scala.math.BigDecimal(new BigDecimal(decimalValue))); - } catch (NumberFormatException e) { - String errMsg = "Decimal response result '" + decimalValue + "' is illegal."; - logger.error(errMsg, e); - throw new DorisException(errMsg); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + if (varCharVectorForDecimal.isNull(rowIndex)) { + addValueToRow(rowIndex, null); + continue; + } + String decimalValue = new String(varCharVectorForDecimal.get(rowIndex)); + Decimal decimal = new Decimal(); + try { + decimal.set(new scala.math.BigDecimal(new BigDecimal(decimalValue))); + } catch (NumberFormatException e) { + String errMsg = "Decimal response result '" + 
decimalValue + "' is illegal."; + logger.error(errMsg, e); + throw new DorisException(errMsg); + } + addValueToRow(rowIndex, decimal); } - row.put(decimal); break; case "DECIMALV2": Preconditions.checkArgument(mt.equals(Types.MinorType.DECIMAL), typeMismatchMessage(currentType, mt)); DecimalVector decimalVector = (DecimalVector) curFieldVector; - Decimal decimalV2 = Decimal.apply(decimalVector.getObject(offsetInOneBatch)); - row.put(decimalV2); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + if (decimalVector.isNull(rowIndex)) { + addValueToRow(rowIndex, null); + continue; + } + Decimal decimalV2 = Decimal.apply(decimalVector.getObject(rowIndex)); + addValueToRow(rowIndex, decimalV2); + } break; case "DATE": case "DATETIME": @@ -229,23 +272,36 @@ public class RowBatch { Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR), typeMismatchMessage(currentType, mt)); VarCharVector varCharVector = (VarCharVector) curFieldVector; - String value = new String(varCharVector.get(offsetInOneBatch)); - row.put(value); + for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) { + if (varCharVector.isNull(rowIndex)) { + addValueToRow(rowIndex, null); + continue; + } + String value = new String(varCharVector.get(rowIndex)); + addValueToRow(rowIndex, value); + } break; default: - String errMsg = "Unsupported type " + schema.get(j).getType(); + String errMsg = "Unsupported type " + schema.get(col).getType(); logger.error(errMsg); throw new DorisException(errMsg); } } - offsetInOneBatch++; - return row.getCols(); } catch (Exception e) { close(); throw e; } } + public List<Object> next() throws DorisException { + if (!hasNext()) { + String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch; + logger.error(errMsg); + throw new NoSuchElementException(errMsg); + } + return rowBatch.get(offsetInOneBatch++).getCols(); + } + private String typeMismatchMessage(final String sparkType, final Types.MinorType 
arrowType) { final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s."; return String.format(messageTemplate, sparkType, arrowType.name()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org