This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 20b1228  [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
20b1228 is described below

commit 20b12282ce4735353e3d86d6b1c079e206200f64
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Thu Jan 9 15:32:09 2025 +0800

    [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
---
 .../doris/spark/client/DorisFrontendClient.java    |  6 ++-
 .../spark/client/read/AbstractThriftReader.java    | 18 ++++-----
 .../apache/doris/spark/client/read/RowBatch.java   | 43 +++++++++++-----------
 .../apache/doris/spark/sql/ScalaDorisRowRDD.scala  |  1 -
 .../apache/doris/spark/util/SchemaConvertors.scala |  1 -
 5 files changed, 36 insertions(+), 33 deletions(-)
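
Note on the failure mode, as a hedged sketch: when read columns are rewritten (for example BITMAP/HLL columns fetched as "expr AS `col`"), the scan schema returned for the query can order its fields differently from the requested column list, so pairing request index i with schema index i deserializes a column against the wrong Arrow field. The fix keys fields by column name instead. A minimal, self-contained illustration in Java (the Field record and all names and data here are hypothetical, not the connector's types):

    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;

    public class ColumnOrderSketch {
        // Hypothetical stand-in for the connector's Field type.
        record Field(String name, String type) {}

        public static void main(String[] args) {
            // Scan schema as returned, in an order that differs from the request.
            List<Field> scanSchema = List.of(new Field("b", "INT"), new Field("a", "VARCHAR"));
            String[] readColumns = {"`a`", "`b`"};

            // Positional pairing: request index 0 ("a") lands on scan field 0 ("b").
            Field byIndex = scanSchema.get(0);

            // Name-based pairing: key the scan schema by field name, then look up.
            Map<String, Field> byName = scanSchema.stream()
                    .collect(Collectors.toMap(Field::name, Function.identity()));
            Field byLookup = byName.get(readColumns[0].replaceAll("`", ""));

            System.out.println("by index: " + byIndex + ", by name: " + byLookup);
        }
    }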

diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
index 67b510e..ce6b289 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
@@ -316,7 +316,11 @@ public class DorisFrontendClient implements Serializable {
                     throw new DorisException();
                 }
                 String entity = EntityUtils.toString(response.getEntity());
-                return MAPPER.readValue(extractEntity(entity, "data").traverse(), QueryPlan.class);
+                JsonNode dataJsonNode = extractEntity(entity, "data");
+                if (dataJsonNode.get("exception") != null) {
+                    throw new DorisException("query plan failed, exception: " + dataJsonNode.get("exception").asText());
+                }
+                return MAPPER.readValue(dataJsonNode.traverse(), QueryPlan.class);
             } catch (Exception e) {
                 throw new RuntimeException("query plan request failed", e);
             }
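
A hedged sketch of the same fail-fast pattern with a plain Jackson ObjectMapper (the payload string is made up; the real check runs on the FE's query-plan response):

    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;

    public class FailFastJsonSketch {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();
            // Hypothetical failure payload of the shape the guard checks for.
            JsonNode data = mapper.readTree("{\"exception\":\"mock FE error\"}");
            if (data.get("exception") != null) {
                // Surface the server-side message instead of failing later
                // while mapping an error payload into QueryPlan.
                throw new IllegalStateException(
                        "query plan failed, exception: " + data.get("exception").asText());
            }
        }
    }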
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index d5f9f88..608e30c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -88,8 +88,9 @@ public abstract class AbstractThriftReader extends DorisReader {
         this.contextId = scanOpenResult.getContextId();
         Schema schema = getDorisSchema();
         this.dorisSchema = processDorisSchema(partition, schema);
-        logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
-
+        if (logger.isDebugEnabled()) {
+            logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
+        }
         if (isAsync) {
             int blockingQueueSize = config.getValue(DorisOptions.DORIS_DESERIALIZE_QUEUE_SIZE);
             this.rowBatchQueue = new ArrayBlockingQueue<>(blockingQueueSize);
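
The isDebugEnabled() guard keeps the two schema objects from being stringified and concatenated when DEBUG is off. If the logger is SLF4J (an assumption), the parameterized form defers rendering the same way in one line:

    // Arguments are only rendered when DEBUG is enabled.
    logger.debug("origin thrift read Schema: {}, processed schema: {}", schema, dorisSchema);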
@@ -241,22 +242,21 @@ public abstract class AbstractThriftReader extends DorisReader {
         Schema tableSchema = frontend.getTableSchema(partition.getDatabase(), partition.getTable());
         Map<String, Field> fieldTypeMap = tableSchema.getProperties().stream()
                 .collect(Collectors.toMap(Field::getName, Function.identity()));
+        Map<String, Field> scanTypeMap = originSchema.getProperties().stream()
+                .collect(Collectors.toMap(Field::getName, Function.identity()));
         String[] readColumns = partition.getReadColumns();
         List<Field> newFieldList = new ArrayList<>();
-        int offset = 0;
-        for (int i = 0; i < readColumns.length; i++) {
-            String readColumn = readColumns[i];
-            if (!fieldTypeMap.containsKey(readColumn) && readColumn.contains(" AS ")) {
+        for (String readColumn : readColumns) {
+            if (readColumn.contains(" AS ")) {
                 int asIdx = readColumn.indexOf(" AS ");
                 String realColumn = readColumn.substring(asIdx + 4).trim().replaceAll("`", "");
-                if (fieldTypeMap.containsKey(realColumn)
+                if (fieldTypeMap.containsKey(realColumn) && scanTypeMap.containsKey(realColumn)
                         && ("BITMAP".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType())
                         || "HLL".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType()))) {
                     newFieldList.add(new Field(realColumn, TPrimitiveType.VARCHAR.name(), null, 0, 0, null));
-                    offset++;
                 }
             } else {
-                newFieldList.add(originSchema.getProperties().get(i + offset));
+                newFieldList.add(scanTypeMap.get(readColumn.trim().replaceAll("`", "")));
             }
         }
         processedSchema.setProperties(newFieldList);
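
In the rewritten loop both the table schema and the scan schema are keyed by field name, and an aliased read column recovers its real name from the text after " AS " with backticks stripped, so nothing depends on positional offsets anymore. A standalone sketch of just that alias parsing (the input string is illustrative):

    public class AliasParseSketch {
        public static void main(String[] args) {
            // e.g. a BITMAP column rewritten to a string-valued expression with an alias
            String readColumn = "bitmap_to_string(`uv`) AS `uv`";
            int asIdx = readColumn.indexOf(" AS ");
            String realColumn = readColumn.substring(asIdx + 4).trim().replaceAll("`", "");
            System.out.println(realColumn); // prints: uv
        }
    }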
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index c1613e8..840825c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -194,6 +194,7 @@ public class RowBatch implements Serializable {
                 FieldVector curFieldVector = fieldVectors.get(col);
                 MinorType mt = curFieldVector.getMinorType();
 
+                final String colName = schema.get(col).getName();
                 final String currentType = schema.get(col).getType();
                 switch (currentType) {
                     case "NULL_TYPE":
@@ -203,7 +204,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "BOOLEAN":
                         Preconditions.checkArgument(mt.equals(MinorType.BIT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BitVector bitVector = (BitVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
@@ -212,7 +213,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "TINYINT":
                         
Preconditions.checkArgument(mt.equals(MinorType.TINYINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         TinyIntVector tinyIntVector = (TinyIntVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = tinyIntVector.isNull(rowIndex) 
? null : tinyIntVector.get(rowIndex);
@@ -221,7 +222,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "SMALLINT":
                         
Preconditions.checkArgument(mt.equals(MinorType.SMALLINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         SmallIntVector smallIntVector = (SmallIntVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = 
smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
@@ -230,7 +231,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "INT":
                         Preconditions.checkArgument(mt.equals(MinorType.INT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         IntVector intVector = (IntVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
@@ -239,7 +240,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "BIGINT":
                         
Preconditions.checkArgument(mt.equals(MinorType.BIGINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BigIntVector bigIntVector = (BigIntVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = bigIntVector.isNull(rowIndex) 
? null : bigIntVector.get(rowIndex);
@@ -248,7 +249,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "LARGEINT":
                         
Preconditions.checkArgument(mt.equals(MinorType.FIXEDSIZEBINARY) ||
-                                mt.equals(MinorType.VARCHAR), 
typeMismatchMessage(currentType, mt));
+                                mt.equals(MinorType.VARCHAR), 
typeMismatchMessage(colName, currentType, mt));
                         if (mt.equals(MinorType.FIXEDSIZEBINARY)) {
                             FixedSizeBinaryVector largeIntVector = 
(FixedSizeBinaryVector) curFieldVector;
                             for (int rowIndex = 0; rowIndex < 
rowCountInOneBatch; rowIndex++) {
@@ -276,7 +277,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "IPV4":
                         Preconditions.checkArgument(mt.equals(MinorType.UINT4) || mt.equals(MinorType.INT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BaseIntVector ipv4Vector;
                         if (mt.equals(MinorType.INT)) {
                             ipv4Vector = (IntVector) curFieldVector;
@@ -291,7 +292,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "FLOAT":
                         
Preconditions.checkArgument(mt.equals(MinorType.FLOAT4),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         Float4Vector float4Vector = (Float4Vector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = float4Vector.isNull(rowIndex) 
? null : float4Vector.get(rowIndex);
@@ -301,7 +302,7 @@ public class RowBatch implements Serializable {
                     case "TIME":
                     case "DOUBLE":
                         
Preconditions.checkArgument(mt.equals(MinorType.FLOAT8),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         Float8Vector float8Vector = (Float8Vector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = float8Vector.isNull(rowIndex) 
? null : float8Vector.get(rowIndex);
@@ -310,7 +311,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "BINARY":
                         
Preconditions.checkArgument(mt.equals(MinorType.VARBINARY),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarBinaryVector varBinaryVector = (VarBinaryVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             Object fieldValue = 
varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
@@ -319,7 +320,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "DECIMAL":
                         
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector varCharVectorForDecimal = 
(VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             if (varCharVectorForDecimal.isNull(rowIndex)) {
@@ -343,7 +344,7 @@ public class RowBatch implements Serializable {
                     case "DECIMAL64":
                     case "DECIMAL128I":
                         
Preconditions.checkArgument(mt.equals(MinorType.DECIMAL),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         DecimalVector decimalVector = (DecimalVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             if (decimalVector.isNull(rowIndex)) {
@@ -357,7 +358,7 @@ public class RowBatch implements Serializable {
                     case "DATE":
                     case "DATEV2":
                         
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR)
-                                || mt.equals(MinorType.DATEDAY), 
typeMismatchMessage(currentType, mt));
+                                || mt.equals(MinorType.DATEDAY), 
typeMismatchMessage(colName, currentType, mt));
                         if (mt.equals(MinorType.VARCHAR)) {
                             VarCharVector date = (VarCharVector) 
curFieldVector;
                             for (int rowIndex = 0; rowIndex < 
rowCountInOneBatch; rowIndex++) {
@@ -417,7 +418,7 @@ public class RowBatch implements Serializable {
                     case "JSONB":
                     case "VARIANT":
                         
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector varCharVector = (VarCharVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             if (varCharVector.isNull(rowIndex)) {
@@ -430,7 +431,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "IPV6":
                         
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector ipv6VarcharVector = (VarCharVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             if (ipv6VarcharVector.isNull(rowIndex)) {
@@ -444,7 +445,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "ARRAY":
                         Preconditions.checkArgument(mt.equals(MinorType.LIST),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         ListVector listVector = (ListVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (listVector.isNull(rowIndex)) {
@@ -457,7 +458,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "MAP":
                         Preconditions.checkArgument(mt.equals(MinorType.MAP),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         MapVector mapVector = (MapVector) curFieldVector;
                         UnionMapReader reader = mapVector.getReader();
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -476,7 +477,7 @@ public class RowBatch implements Serializable {
                         break;
                     case "STRUCT":
                         
Preconditions.checkArgument(mt.equals(MinorType.STRUCT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         StructVector structVector = (StructVector) 
curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
                             if (structVector.isNull(rowIndex)) {
@@ -508,9 +509,9 @@ public class RowBatch implements Serializable {
         return rowBatch.get(offsetInRowBatch++).getCols();
     }
 
-    private String typeMismatchMessage(final String sparkType, final MinorType arrowType) {
-        final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
-        return String.format(messageTemplate, sparkType, arrowType.name());
+    private String typeMismatchMessage(final String columnName, final String sparkType, final MinorType arrowType) {
+        final String messageTemplate = "Spark type for column %1$s is %2$s, but arrow type is %3$s.";
+        return String.format(messageTemplate, columnName, sparkType, arrowType.name());
     }
 
     public int getReadRowCount() {
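
With the column name threaded through, a mismatch report now names the offending column; for example (values are illustrative):

    // String.format("Spark type for column %1$s is %2$s, but arrow type is %3$s.",
    //         "user_id", "BIGINT", "VARCHAR")
    // yields: "Spark type for column user_id is BIGINT, but arrow type is VARCHAR."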
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
index 9713bf3..0e6038d 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
@@ -38,7 +38,6 @@ private[spark] class ScalaDorisRowRDDIterator(context: TaskContext,
   extends AbstractDorisRDDIterator[Row](context, partition) {
 
   override def initReader(config: DorisConfig): Unit = {
-    config.setProperty(DorisOptions.DORIS_READ_FIELDS, schema.map(f => s"`${f.name}`").mkString(","))
     config.getValue(DorisOptions.READ_MODE).toLowerCase match {
       case "thrift" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowThriftReader].getName)
       case "arrow" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowFlightSqlReader].getName)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index f85eb28..303aa1f 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -67,7 +67,6 @@ object SchemaConvertors {
   def convertToSchema(tscanColumnDescs: Seq[TScanColumnDesc]): Schema = {
     val schema = new Schema(tscanColumnDescs.length)
     tscanColumnDescs.foreach(desc => {
-      // println(desc.getName + " " + desc.getType.name())
       schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))
     })
     schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
