This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git

The following commit(s) were added to refs/heads/master by this push:
     new 54ac94b  [fix](connector) Fixed writing issues in arrow format (#270)
54ac94b is described below

commit 54ac94bc5067588c22391306d082017d63a22d65
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Tue Feb 25 11:14:42 2025 +0800

    [fix](connector) Fixed writing issues in arrow format (#270)
---
 .../spark/client/read/AbstractThriftReader.java    |  2 +-
 .../apache/doris/spark/client/read/RowBatch.java   | 22 ++++++--
 .../client/write/AbstractStreamLoadProcessor.java  | 20 ++++---
 .../spark/client/write/StreamLoadProcessor.java    |  9 ++--
 .../org/apache/doris/spark/load/StreamLoader.scala |  2 +
 .../apache/doris/spark/util/RowConvertors.scala    |  5 +-
 .../apache/doris/spark/util/SchemaConvertors.scala |  4 +-
 .../doris/spark/client/read/RowBatchTest.java      | 61 +++++++++++++++++-----
 8 files changed, 94 insertions(+), 31 deletions(-)

diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index 7fdb1cf..f200b38 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -110,7 +110,7 @@ public abstract class AbstractThriftReader extends DorisReader {
             this.rowBatchQueue = null;
             this.asyncThread = null;
         }
-        this.datetimeJava8ApiEnabled = false;
+        this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
     }
 
     private void runAsync() throws DorisException, InterruptedException {
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index 759678d..ba61d6a 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -59,10 +59,12 @@ import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.sql.Timestamp;
 import java.time.Instant;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
+import java.time.ZoneOffset;
 import java.time.format.DateTimeFormatter;
 import java.time.format.DateTimeFormatterBuilder;
 import java.time.temporal.ChronoField;
@@ -72,6 +74,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Objects;
+import java.util.TimeZone;
 
 /**
  * row batch data container.
@@ -403,8 +406,15 @@ public class RowBatch implements Serializable {
                                     addValueToRow(rowIndex, null);
                                     continue;
                                 }
-                                String value = new String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
-                                addValueToRow(rowIndex, value);
+                                String stringValue = completeMilliseconds(new String(varCharVector.get(rowIndex),
+                                        StandardCharsets.UTF_8));
+                                LocalDateTime dateTime = LocalDateTime.parse(stringValue, dateTimeV2Formatter);
+                                if (datetimeJava8ApiEnabled) {
+                                    Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+                                    addValueToRow(rowIndex, instant);
+                                } else {
+                                    addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
+                                }
                             }
                         } else if (curFieldVector instanceof TimeStampVector) {
                            TimeStampVector timeStampVector = (TimeStampVector) curFieldVector;
@@ -414,8 +424,12 @@ public class RowBatch implements Serializable {
                                     continue;
                                 }
                                LocalDateTime dateTime = getDateTime(rowIndex, timeStampVector);
-                                String formatted = DATE_TIME_FORMATTER.format(dateTime);
-                                addValueToRow(rowIndex, formatted);
+                                if (datetimeJava8ApiEnabled) {
+                                    Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+                                    addValueToRow(rowIndex, instant);
+                                } else {
+                                    addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
+                                }
                             }
                         } else {
                            String errMsg = String.format("Unsupported type for DATETIMEV2, minorType %s, class is %s",
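
For context on the read-path change above: the DATETIMEV2 branch no longer hands Spark a formatted String; it now materializes either a java.time.Instant or a java.sql.Timestamp. A minimal standalone sketch, assuming DEFAULT_ZONE_ID is the JVM default zone and datetimeJava8ApiEnabled mirrors Spark's spark.sql.datetime.java8API.enabled setting (both assumptions for illustration, not taken verbatim from the commit):

    import java.sql.Timestamp;
    import java.time.Instant;
    import java.time.LocalDateTime;
    import java.time.ZoneId;

    class DateTimeV2ReadSketch {
        // Assumption: stands in for the connector's DEFAULT_ZONE_ID (the JVM default zone).
        private static final ZoneId DEFAULT_ZONE_ID = ZoneId.systemDefault();

        // Returns the object Spark expects for a TimestampType column, depending on the Java 8 API flag.
        static Object toSparkValue(LocalDateTime dateTime, boolean datetimeJava8ApiEnabled) {
            if (datetimeJava8ApiEnabled) {
                return dateTime.atZone(DEFAULT_ZONE_ID).toInstant();   // java.time.Instant
            }
            return Timestamp.valueOf(dateTime);                        // java.sql.Timestamp
        }
    }
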
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
index 8c9e859..484653e 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
@@ -48,9 +48,9 @@ import java.io.IOException;
 import java.io.PipedInputStream;
 import java.io.PipedOutputStream;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -106,7 +106,7 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>,
 
     private boolean isFirstRecordOfBatch = true;
 
-    private final List<R> recordBuffer = new ArrayList<>();
+    private final List<R> recordBuffer = new LinkedList<>();
 
     private static final int arrowBufferSize = 1000;
 
@@ -161,6 +161,12 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>,
 
     @Override
     public String stop() throws Exception {
+        // arrow format need to send all buffer data before stop
+        if (!recordBuffer.isEmpty() && "arrow".equalsIgnoreCase(format)) {
+            List<R> rs = new LinkedList<>(recordBuffer);
+            recordBuffer.clear();
+            output.write(toArrowFormat(rs));
+        }
         output.close();
         CloseableHttpResponse res = requestFuture.get();
         if (res.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
@@ -239,13 +245,13 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>,
             case "json":
                 return toStringFormat(row, format);
             case "arrow":
-                recordBuffer.add(row);
+                recordBuffer.add(copy(row));
                 if (recordBuffer.size() < arrowBufferSize) {
                     return new byte[0];
                 } else {
-                    R[] dataArray = (R[]) recordBuffer.toArray();
+                    LinkedList<R> rs = new LinkedList<>(recordBuffer);
                     recordBuffer.clear();
-                    return toArrowFormat(dataArray);
+                    return toArrowFormat(rs);
                 }
             default:
                throw new IllegalArgumentException("Unsupported stream load format: " + format);
@@ -263,7 +269,7 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>,
 
     public abstract String stringify(R row, String format);
 
-    public abstract byte[] toArrowFormat(R[] rowArray) throws IOException;
+    public abstract byte[] toArrowFormat(List<R> rows) throws IOException;
 
     public abstract String getWriteFields() throws OptionRequiredException;
 
@@ -364,4 +370,6 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>,
         }
     }
 
+    protected abstract R copy(R row);
+
 }
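
A condensed sketch of the arrow buffering introduced above (class and member names here are simplified stand-ins, not the connector's actual API): each record is defensively copied into a LinkedList buffer, a batch is serialized once the buffer reaches the batch size, and stop() now flushes the remaining tail batch before the output stream is closed:

    import java.io.IOException;
    import java.util.LinkedList;
    import java.util.List;

    // Simplified stand-in for the arrow buffering in AbstractStreamLoadProcessor.
    abstract class ArrowBufferSketch<R> {
        private static final int ARROW_BUFFER_SIZE = 1000;
        private final List<R> recordBuffer = new LinkedList<>();

        // Per record: buffer a copy and only emit bytes once a full batch is available.
        byte[] onRecord(R row) throws IOException {
            recordBuffer.add(copy(row));
            if (recordBuffer.size() < ARROW_BUFFER_SIZE) {
                return new byte[0];
            }
            return drainToArrow();
        }

        // On stop(): flush whatever is still buffered before closing the output.
        byte[] onStop() throws IOException {
            return recordBuffer.isEmpty() ? new byte[0] : drainToArrow();
        }

        private byte[] drainToArrow() throws IOException {
            List<R> rs = new LinkedList<>(recordBuffer);
            recordBuffer.clear();
            return toArrowFormat(rs);
        }

        abstract byte[] toArrowFormat(List<R> rows) throws IOException;

        abstract R copy(R row);
    }
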
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
index e5ac4fa..2f787a5 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
@@ -34,6 +34,7 @@ import org.apache.spark.sql.util.ArrowUtils;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.util.List;
 
 public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow> {
 
@@ -50,7 +51,7 @@ public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow
     }
 
     @Override
-    public byte[] toArrowFormat(InternalRow[] rowArray) throws IOException {
+    public byte[] toArrowFormat(List<InternalRow> rowArray) throws IOException {
         Schema arrowSchema = ArrowUtils.toArrowSchema(schema, "UTC");
         VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, new RootAllocator(Integer.MAX_VALUE));
         ArrowWriter arrowWriter = ArrowWriter.create(root);
@@ -112,6 +113,8 @@ public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow
         this.schema = schema;
     }
 
-
-
+    @Override
+    protected InternalRow copy(InternalRow row) {
+        return row.copy();
+    }
 }
\ No newline at end of file
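
The new copy(InternalRow) hook exists because Spark commonly reuses one InternalRow instance across iterator.next() calls; buffering the live reference for a later arrow batch could leave every buffered entry pointing at the last row's contents. A small hypothetical illustration (the helper class below is not part of the connector):

    import org.apache.spark.sql.catalyst.InternalRow;

    import java.util.Iterator;
    import java.util.List;

    // Hypothetical helper showing why buffered rows must be copies rather than live references.
    final class CopyOnBufferSketch {
        static void bufferAll(Iterator<InternalRow> iterator, List<InternalRow> buffer) {
            while (iterator.hasNext()) {
                InternalRow live = iterator.next();
                buffer.add(live.copy());   // snapshot of the current values
                // buffer.add(live);       // unsafe: a later next() may reuse and overwrite this instance
            }
        }
    }
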
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
index 73e8c9b..109fa5b 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
@@ -269,6 +269,8 @@ class StreamLoader(settings: SparkSettings, isStreaming: Boolean) extends Loader
    */
   private def buildLoadRequest(iterator: Iterator[InternalRow], schema: StructType, label: String): HttpUriRequest = {
 
+    iterator.next().copy()
+
     currentLoadUrl = URLs.streamLoad(getNode, database, table, enableHttps)
     val put = new HttpPut(currentLoadUrl)
     addCommonHeader(put)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
index 31b7196..b75d1ce 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 import java.sql.{Date, Timestamp}
-import java.time.LocalDate
+import java.time.{Instant, LocalDate}
 import scala.collection.JavaConverters.mapAsScalaMapConverter
 import scala.collection.mutable
 
@@ -110,7 +110,8 @@ object RowConvertors {
   def convertValue(v: Any, dataType: DataType, datetimeJava8ApiEnabled: Boolean): Any = {
     dataType match {
       case StringType => UTF8String.fromString(v.asInstanceOf[String])
-      case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(v.asInstanceOf[String]))
+      case TimestampType if datetimeJava8ApiEnabled => DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
+      case TimestampType => DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
       case DateType if datetimeJava8ApiEnabled => v.asInstanceOf[LocalDate].toEpochDay.toInt
       case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
       case _: MapType =>
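
The two TimestampType branches above give Spark an epoch-microsecond value, either from a java.time.Instant (when the Java 8 datetime API is enabled) or from a java.sql.Timestamp. A rough standalone equivalent of that encoding, ignoring the overflow handling that Spark's DateTimeUtils performs (a sketch for illustration, not the connector's code):

    import java.sql.Timestamp;
    import java.time.Instant;

    // Sketch of the epoch-microsecond encoding Spark uses internally for TimestampType.
    final class TimestampMicrosSketch {
        static long fromInstant(Instant instant) {
            // whole seconds plus sub-second nanos, combined into microseconds since the epoch
            return Math.addExact(Math.multiplyExact(instant.getEpochSecond(), 1_000_000L),
                    instant.getNano() / 1_000L);
        }

        static long fromTimestamp(Timestamp ts) {
            // Timestamp.toInstant() keeps the full fractional-second precision
            return fromInstant(ts.toInstant());
        }
    }
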
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index 303aa1f..e694083 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -37,8 +37,8 @@ object SchemaConvertors {
       case "DOUBLE" => DataTypes.DoubleType
       case "DATE" => DataTypes.DateType
       case "DATEV2" => DataTypes.DateType
-      case "DATETIME" => DataTypes.StringType
-      case "DATETIMEV2" => DataTypes.StringType
+      case "DATETIME" => DataTypes.TimestampType
+      case "DATETIMEV2" => DataTypes.TimestampType
       case "BINARY" => DataTypes.BinaryType
       case "DECIMAL" => DecimalType(precision, scale)
       case "CHAR" => DataTypes.StringType
diff --git a/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
index 05123b4..acc7712 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
@@ -56,12 +56,10 @@ import org.apache.doris.sdk.thrift.TStatusCode;
 import org.apache.doris.spark.exception.DorisException;
 import org.apache.doris.spark.rest.RestService;
 import org.apache.doris.spark.rest.models.Schema;
-import org.apache.spark.sql.internal.SQLConf;
-import org.apache.spark.sql.internal.SQLConf$;
 import org.apache.spark.sql.types.Decimal;
 import static org.hamcrest.core.StringStartsWith.startsWith;
 import org.junit.Assert;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -75,6 +73,7 @@ import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.sql.Timestamp;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
@@ -275,7 +274,7 @@ public class RowBatchTest {
                 (float) 1.1,
                 (double) 1.1,
                 Date.valueOf("2008-08-08"),
-                "2008-08-08 00:00:00",
+                Timestamp.valueOf("2008-08-08 00:00:00"),
                 Decimal.apply(1234L, 4, 2),
                 "char1"
         );
@@ -289,7 +288,7 @@ public class RowBatchTest {
                 (float) 2.2,
                 (double) 2.2,
                 Date.valueOf("1900-08-08"),
-                "1900-08-08 00:00:00",
+                Timestamp.valueOf("1900-08-08 00:00:00"),
                 Decimal.apply(8888L, 4, 2),
                 "char2"
         );
@@ -303,7 +302,7 @@ public class RowBatchTest {
                 (float) 3.3,
                 (double) 3.3,
                 Date.valueOf("2100-08-08"),
-                "2100-08-08 00:00:00",
+                Timestamp.valueOf("2100-08-08 00:00:00"),
                 Decimal.apply(10L, 2, 0),
                 "char3"
         );
@@ -831,16 +830,16 @@ public class RowBatchTest {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(0));
-        Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(1));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"), actualRow0.get(0));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"), actualRow0.get(1));
 
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertEquals("2024-03-20 00:00:01", actualRow1.get(0));
-        Assert.assertEquals("2024-03-20 00:00:00.123", actualRow1.get(1));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:01"), actualRow1.get(0));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123"), actualRow1.get(1));
 
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals("2024-03-20 00:00:02", actualRow2.get(0));
-        Assert.assertEquals("2024-03-20 00:00:00.123456", actualRow2.get(1));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:02"), actualRow2.get(0));
+        Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123456"), actualRow2.get(1));
 
 
         Assert.assertFalse(rowBatch.hasNext());
@@ -1169,6 +1168,10 @@ public class RowBatchTest {
         ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
         childrenBuilder.add(new Field("k0", FieldType.nullable(new 
ArrowType.Utf8()), null));
         childrenBuilder.add(new Field("k1", FieldType.nullable(new 
ArrowType.Date(DateUnit.DAY)), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MICROSECOND,
+                null)), null));
+        childrenBuilder.add(new Field("k3", FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MICROSECOND,
+                null)), null));
 
         VectorSchemaRoot root = VectorSchemaRoot.create(
                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
@@ -1202,6 +1205,32 @@ public class RowBatchTest {
         date2Vector.setSafe(0, (int) date);
         vector.setValueCount(1);
 
+        LocalDateTime localDateTime = LocalDateTime.of(2025, 2, 24,
+                0, 0, 0, 123000000);
+        long second = localDateTime.atZone(ZoneId.systemDefault()).toEpochSecond();
+        int nano = localDateTime.getNano();
+
+        vector = root.getVector("k2");
+        TimeStampMicroVector datetimeV2Vector = (TimeStampMicroVector) vector;
+        datetimeV2Vector.setInitialCapacity(1);
+        datetimeV2Vector.allocateNew();
+        datetimeV2Vector.setIndexDefined(0);
+        datetimeV2Vector.setSafe(0, second * 1000000 + nano / 1000);
+        vector.setValueCount(1);
+
+        LocalDateTime localDateTime1 = LocalDateTime.of(2025, 2, 24,
+                1, 2, 3, 123456000);
+        long second1 = localDateTime1.atZone(ZoneId.systemDefault()).toEpochSecond();
+        int nano1 = localDateTime1.getNano();
+
+        vector = root.getVector("k3");
+        TimeStampMicroVector datetimeV2Vector1 = (TimeStampMicroVector) vector;
+        datetimeV2Vector1.setInitialCapacity(1);
+        datetimeV2Vector1.allocateNew();
+        datetimeV2Vector1.setIndexDefined(0);
+        datetimeV2Vector1.setSafe(0, second1 * 1000000 + nano1 / 1000);
+        vector.setValueCount(1);
+
         arrowStreamWriter.writeBatch();
 
         arrowStreamWriter.end();
@@ -1217,7 +1246,9 @@ public class RowBatchTest {
 
         String schemaStr = "{\"properties\":[" +
                 "{\"type\":\"DATE\",\"name\":\"k0\",\"comment\":\"\"}, " +
-                "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}" +
+                "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}," +
+                "{\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}," +
+                "{\"type\":\"DATETIMEV2\",\"name\":\"k3\",\"comment\":\"\"}" +
                 "], \"status\":200}";
 
         Schema schema = RestService.parseSchema(schemaStr, logger);
@@ -1228,6 +1259,8 @@ public class RowBatchTest {
         List<Object> actualRow0 = rowBatch1.next();
         Assert.assertEquals(Date.valueOf("2025-01-01"), actualRow0.get(0));
         Assert.assertEquals(Date.valueOf("2025-02-01"), actualRow0.get(1));
+        Assert.assertEquals(Timestamp.valueOf("2025-02-24 00:00:00.123"), actualRow0.get(2));
+        Assert.assertEquals(Timestamp.valueOf("2025-02-24 01:02:03.123456"), actualRow0.get(3));
 
         Assert.assertFalse(rowBatch1.hasNext());
 
@@ -1237,6 +1270,8 @@ public class RowBatchTest {
         List<Object> actualRow01 = rowBatch2.next();
         Assert.assertEquals(LocalDate.of(2025,1,1), actualRow01.get(0));
         Assert.assertEquals(localDate, actualRow01.get(1));
+        Assert.assertEquals(localDateTime.atZone(ZoneId.systemDefault()).toInstant(), actualRow01.get(2));
+        Assert.assertEquals(localDateTime1.atZone(ZoneId.systemDefault()).toInstant(), actualRow01.get(3));
 
         Assert.assertFalse(rowBatch2.hasNext());
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
