This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new 20b1228  [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
20b1228 is described below

commit 20b12282ce4735353e3d86d6b1c079e206200f64
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Thu Jan 9 15:32:09 2025 +0800

    [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
---
 .../doris/spark/client/DorisFrontendClient.java    |  6 ++-
 .../spark/client/read/AbstractThriftReader.java    | 18 ++++-----
 .../apache/doris/spark/client/read/RowBatch.java   | 43 +++++++++++-----------
 .../apache/doris/spark/sql/ScalaDorisRowRDD.scala  |  1 -
 .../apache/doris/spark/util/SchemaConvertors.scala |  1 -
 5 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
index 67b510e..ce6b289 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
@@ -316,7 +316,11 @@ public class DorisFrontendClient implements Serializable {
                 throw new DorisException();
             }
             String entity = EntityUtils.toString(response.getEntity());
-            return MAPPER.readValue(extractEntity(entity, "data").traverse(), QueryPlan.class);
+            JsonNode dataJsonNode = extractEntity(entity, "data");
+            if (dataJsonNode.get("exception") != null) {
+                throw new DorisException("query plan failed, exception: " + dataJsonNode.get("exception").asText());
+            }
+            return MAPPER.readValue(dataJsonNode.traverse(), QueryPlan.class);
         } catch (Exception e) {
             throw new RuntimeException("query plan request failed", e);
         }
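The guard added above inspects the extracted "data" node before deserializing it into a QueryPlan, so a frontend error surfaces as a DorisException instead of a confusing Jackson failure. A minimal, self-contained sketch of the same Jackson pattern; the class name and the sample payload below are illustrative only, not taken from the connector:

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class QueryPlanGuardSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    public static void main(String[] args) throws Exception {
        // Hypothetical FE response: a failed plan carries an "exception" field.
        String entity = "{\"exception\":\"table not found\",\"status\":400}";
        JsonNode data = MAPPER.readTree(entity);
        if (data.get("exception") != null) {
            // The connector throws DorisException here; a plain exception keeps this sketch dependency-free.
            throw new IllegalStateException(
                    "query plan failed, exception: " + data.get("exception").asText());
        }
        // Otherwise the node would be deserialized, e.g. MAPPER.readValue(data.traverse(), QueryPlan.class).
    }
}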
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index d5f9f88..608e30c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -88,8 +88,9 @@ public abstract class AbstractThriftReader extends DorisReader {
         this.contextId = scanOpenResult.getContextId();
         Schema schema = getDorisSchema();
         this.dorisSchema = processDorisSchema(partition, schema);
-        logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
-
+        if (logger.isDebugEnabled()) {
+            logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
+        }
         if (isAsync) {
             int blockingQueueSize = config.getValue(DorisOptions.DORIS_DESERIALIZE_QUEUE_SIZE);
             this.rowBatchQueue = new ArrayBlockingQueue<>(blockingQueueSize);
@@ -241,22 +242,21 @@ public abstract class AbstractThriftReader extends DorisReader {
         Schema tableSchema = frontend.getTableSchema(partition.getDatabase(), partition.getTable());
         Map<String, Field> fieldTypeMap = tableSchema.getProperties().stream()
                 .collect(Collectors.toMap(Field::getName, Function.identity()));
+        Map<String, Field> scanTypeMap = originSchema.getProperties().stream()
+                .collect(Collectors.toMap(Field::getName, Function.identity()));
         String[] readColumns = partition.getReadColumns();
         List<Field> newFieldList = new ArrayList<>();
-        int offset = 0;
-        for (int i = 0; i < readColumns.length; i++) {
-            String readColumn = readColumns[i];
-            if (!fieldTypeMap.containsKey(readColumn) && readColumn.contains(" AS ")) {
+        for (String readColumn : readColumns) {
+            if (readColumn.contains(" AS ")) {
                 int asIdx = readColumn.indexOf(" AS ");
                 String realColumn = readColumn.substring(asIdx + 4).trim().replaceAll("`", "");
-                if (fieldTypeMap.containsKey(realColumn)
+                if (fieldTypeMap.containsKey(realColumn) && scanTypeMap.containsKey(realColumn)
                         && ("BITMAP".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType())
                         || "HLL".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType()))) {
                     newFieldList.add(new Field(realColumn, TPrimitiveType.VARCHAR.name(), null, 0, 0, null));
-                    offset++;
                 }
             } else {
-                newFieldList.add(originSchema.getProperties().get(i + offset));
+                newFieldList.add(scanTypeMap.get(readColumn.trim().replaceAll("`", "")));
             }
         }
         processedSchema.setProperties(newFieldList);
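The heart of the fix is the processDorisSchema change above: instead of indexing the scan schema positionally with a hand-maintained offset, the processed schema is rebuilt by looking each requested read column up by name in a map built from the scan schema, so the field order always matches the order of the columns Spark asked for, which is also the order of the serialized data. A minimal sketch of that name-based lookup, using plain strings instead of the connector's Schema/Field types; the column names and types below are made up:

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class NameBasedSchemaSketch {

    public static void main(String[] args) {
        // Made-up scan schema as the backend reports it (name -> type),
        // not necessarily in the order the columns were requested.
        Map<String, String> scanTypeMap = new LinkedHashMap<>();
        scanTypeMap.put("c2", "INT");
        scanTypeMap.put("c1", "VARCHAR");

        // Columns requested by Spark, possibly backtick-quoted.
        String[] readColumns = {"`c1`", "`c2`"};

        // Rebuild the processed schema strictly in read-column order,
        // resolving each column by name rather than by position.
        List<String> processed = new ArrayList<>();
        for (String readColumn : readColumns) {
            String name = readColumn.trim().replaceAll("`", "");
            processed.add(name + ":" + scanTypeMap.get(name));
        }
        System.out.println(processed); // [c1:VARCHAR, c2:INT]
    }
}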
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index c1613e8..840825c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -194,6 +194,7 @@ public class RowBatch implements Serializable {
 
             FieldVector curFieldVector = fieldVectors.get(col);
             MinorType mt = curFieldVector.getMinorType();
+            final String colName = schema.get(col).getName();
             final String currentType = schema.get(col).getType();
             switch (currentType) {
                 case "NULL_TYPE":
@@ -203,7 +204,7 @@
                     break;
                 case "BOOLEAN":
                     Preconditions.checkArgument(mt.equals(MinorType.BIT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     BitVector bitVector = (BitVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
@@ -212,7 +213,7 @@
                     break;
                 case "TINYINT":
                     Preconditions.checkArgument(mt.equals(MinorType.TINYINT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = tinyIntVector.isNull(rowIndex) ? null : tinyIntVector.get(rowIndex);
@@ -221,7 +222,7 @@
                     break;
                 case "SMALLINT":
                     Preconditions.checkArgument(mt.equals(MinorType.SMALLINT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     SmallIntVector smallIntVector = (SmallIntVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
@@ -230,7 +231,7 @@
                     break;
                 case "INT":
                     Preconditions.checkArgument(mt.equals(MinorType.INT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     IntVector intVector = (IntVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
@@ -239,7 +240,7 @@
                     break;
                 case "BIGINT":
                     Preconditions.checkArgument(mt.equals(MinorType.BIGINT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     BigIntVector bigIntVector = (BigIntVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex);
@@ -248,7 +249,7 @@
                     break;
                 case "LARGEINT":
                     Preconditions.checkArgument(mt.equals(MinorType.FIXEDSIZEBINARY) ||
-                            mt.equals(MinorType.VARCHAR), typeMismatchMessage(currentType, mt));
+                            mt.equals(MinorType.VARCHAR), typeMismatchMessage(colName, currentType, mt));
                     if (mt.equals(MinorType.FIXEDSIZEBINARY)) {
                         FixedSizeBinaryVector largeIntVector = (FixedSizeBinaryVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -276,7 +277,7 @@
                     break;
                 case "IPV4":
                     Preconditions.checkArgument(mt.equals(MinorType.UINT4) || mt.equals(MinorType.INT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     BaseIntVector ipv4Vector;
                     if (mt.equals(MinorType.INT)) {
                         ipv4Vector = (IntVector) curFieldVector;
@@ -291,7 +292,7 @@
                     break;
                 case "FLOAT":
                     Preconditions.checkArgument(mt.equals(MinorType.FLOAT4),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     Float4Vector float4Vector = (Float4Vector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = float4Vector.isNull(rowIndex) ? null : float4Vector.get(rowIndex);
@@ -301,7 +302,7 @@
                 case "TIME":
                 case "DOUBLE":
                     Preconditions.checkArgument(mt.equals(MinorType.FLOAT8),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     Float8Vector float8Vector = (Float8Vector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex);
@@ -310,7 +311,7 @@
                     break;
                 case "BINARY":
                     Preconditions.checkArgument(mt.equals(MinorType.VARBINARY),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
@@ -319,7 +320,7 @@
                     break;
                 case "DECIMAL":
                     Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (varCharVectorForDecimal.isNull(rowIndex)) {
@@ -343,7 +344,7 @@
                 case "DECIMAL64":
                 case "DECIMAL128I":
                     Preconditions.checkArgument(mt.equals(MinorType.DECIMAL),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     DecimalVector decimalVector = (DecimalVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (decimalVector.isNull(rowIndex)) {
@@ -357,7 +358,7 @@
                 case "DATE":
                 case "DATEV2":
                     Preconditions.checkArgument(mt.equals(MinorType.VARCHAR)
-                            || mt.equals(MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
+                            || mt.equals(MinorType.DATEDAY), typeMismatchMessage(colName, currentType, mt));
                     if (mt.equals(MinorType.VARCHAR)) {
                         VarCharVector date = (VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -417,7 +418,7 @@
                 case "JSONB":
                 case "VARIANT":
                     Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     VarCharVector varCharVector = (VarCharVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (varCharVector.isNull(rowIndex)) {
@@ -430,7 +431,7 @@
                     break;
                 case "IPV6":
                     Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     VarCharVector ipv6VarcharVector = (VarCharVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (ipv6VarcharVector.isNull(rowIndex)) {
@@ -444,7 +445,7 @@
                     break;
                 case "ARRAY":
                     Preconditions.checkArgument(mt.equals(MinorType.LIST),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     ListVector listVector = (ListVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (listVector.isNull(rowIndex)) {
@@ -457,7 +458,7 @@
                     break;
                 case "MAP":
                     Preconditions.checkArgument(mt.equals(MinorType.MAP),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     MapVector mapVector = (MapVector) curFieldVector;
                     UnionMapReader reader = mapVector.getReader();
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -476,7 +477,7 @@
                     break;
                 case "STRUCT":
                     Preconditions.checkArgument(mt.equals(MinorType.STRUCT),
-                            typeMismatchMessage(currentType, mt));
+                            typeMismatchMessage(colName, currentType, mt));
                     StructVector structVector = (StructVector) curFieldVector;
                     for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                         if (structVector.isNull(rowIndex)) {
@@ -508,9 +509,9 @@
         return rowBatch.get(offsetInRowBatch++).getCols();
     }
 
-    private String typeMismatchMessage(final String sparkType, final MinorType arrowType) {
-        final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
-        return String.format(messageTemplate, sparkType, arrowType.name());
+    private String typeMismatchMessage(final String columnName, final String sparkType, final MinorType arrowType) {
+        final String messageTemplate = "Spark type for column %1$s is %2$s, but arrow type is %3$s.";
+        return String.format(messageTemplate, columnName, sparkType, arrowType.name());
     }
 
     public int getReadRowCount() {
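A small note on the RowBatch change above: typeMismatchMessage now takes the column name, so a type mismatch points at the offending column instead of only naming the types. A standalone rendering of the new template with a hypothetical column; it assumes the Arrow vector library is on the classpath for the MinorType enum:

import org.apache.arrow.vector.types.Types.MinorType;

public class TypeMismatchMessageSketch {

    // Mirrors the patched template in RowBatch.typeMismatchMessage.
    static String typeMismatchMessage(String columnName, String sparkType, MinorType arrowType) {
        return String.format("Spark type for column %1$s is %2$s, but arrow type is %3$s.",
                columnName, sparkType, arrowType.name());
    }

    public static void main(String[] args) {
        // Hypothetical example: a BOOLEAN column arriving as an Arrow VARCHAR vector.
        System.out.println(typeMismatchMessage("is_active", "BOOLEAN", MinorType.VARCHAR));
    }
}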
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
index 9713bf3..0e6038d 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
@@ -38,7 +38,6 @@ private[spark] class ScalaDorisRowRDDIterator(context: TaskContext,
   extends AbstractDorisRDDIterator[Row](context, partition) {
 
   override def initReader(config: DorisConfig): Unit = {
-    config.setProperty(DorisOptions.DORIS_READ_FIELDS, schema.map(f => s"`${f.name}`").mkString(","))
     config.getValue(DorisOptions.READ_MODE).toLowerCase match {
       case "thrift" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowThriftReader].getName)
       case "arrow" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowFlightSqlReader].getName)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index f85eb28..303aa1f 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -67,7 +67,6 @@ object SchemaConvertors {
   def convertToSchema(tscanColumnDescs: Seq[TScanColumnDesc]): Schema = {
     val schema = new Schema(tscanColumnDescs.length)
     tscanColumnDescs.foreach(desc => {
-      // println(desc.getName + " " + desc.getType.name())
       schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))
     })
     schema

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org