This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new 54ac94b  [fix](connector) Fixed writing issues in arrow format (#270)
54ac94b is described below

commit 54ac94bc5067588c22391306d082017d63a22d65
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Tue Feb 25 11:14:42 2025 +0800

    [fix](connector) Fixed writing issues in arrow format (#270)
---
 .../spark/client/read/AbstractThriftReader.java     |  2 +-
 .../apache/doris/spark/client/read/RowBatch.java    | 22 ++++++--
 .../client/write/AbstractStreamLoadProcessor.java   | 20 ++++---
 .../spark/client/write/StreamLoadProcessor.java     |  9 ++--
 .../org/apache/doris/spark/load/StreamLoader.scala  |  2 +
 .../apache/doris/spark/util/RowConvertors.scala     |  5 +-
 .../apache/doris/spark/util/SchemaConvertors.scala  |  4 +-
 .../doris/spark/client/read/RowBatchTest.java       | 61 +++++++++++++++++-----
 8 files changed, 94 insertions(+), 31 deletions(-)

diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index 7fdb1cf..f200b38 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -110,7 +110,7 @@ public abstract class AbstractThriftReader extends DorisReader {
             this.rowBatchQueue = null;
             this.asyncThread = null;
         }
-        this.datetimeJava8ApiEnabled = false;
+        this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
     }
 
     private void runAsync() throws DorisException, InterruptedException {
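For context on the one-line fix above: the flag had been hardcoded to false, so the reader ignored Spark's spark.sql.datetime.java8API.enabled setting. The rest of the commit threads the flag through the read path so DATETIME values surface as java.time.Instant when the Java 8 API is enabled and as java.sql.Timestamp otherwise. A minimal sketch of that dispatch (class and method names below are illustrative, not the connector's):

    import java.sql.Timestamp;
    import java.time.Instant;
    import java.time.LocalDateTime;
    import java.time.ZoneId;

    // Illustrative sketch: how a parsed DATETIME value is surfaced to Spark
    // depending on the Java 8 datetime API setting.
    class DateTimeDispatchSketch {
        private static final ZoneId DEFAULT_ZONE_ID = ZoneId.systemDefault();

        static Object toSparkValue(LocalDateTime dateTime, boolean datetimeJava8ApiEnabled) {
            if (datetimeJava8ApiEnabled) {
                // Java 8 API on: Spark expects java.time.Instant for TimestampType.
                return dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
            }
            // Java 8 API off: Spark expects java.sql.Timestamp.
            return Timestamp.valueOf(dateTime);
        }
    }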
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index 759678d..ba61d6a 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -59,10 +59,12 @@ import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.sql.Timestamp;
 import java.time.Instant;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
+import java.time.ZoneOffset;
 import java.time.format.DateTimeFormatter;
 import java.time.format.DateTimeFormatterBuilder;
 import java.time.temporal.ChronoField;
@@ -72,6 +74,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Objects;
+import java.util.TimeZone;
 
 /**
  * row batch data container.
@@ -403,8 +406,15 @@ public class RowBatch implements Serializable {
                         addValueToRow(rowIndex, null);
                         continue;
                     }
-                    String value = new String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
-                    addValueToRow(rowIndex, value);
+                    String stringValue = completeMilliseconds(new String(varCharVector.get(rowIndex),
+                            StandardCharsets.UTF_8));
+                    LocalDateTime dateTime = LocalDateTime.parse(stringValue, dateTimeV2Formatter);
+                    if (datetimeJava8ApiEnabled) {
+                        Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+                        addValueToRow(rowIndex, instant);
+                    } else {
+                        addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
+                    }
                 }
             } else if (curFieldVector instanceof TimeStampVector) {
                 TimeStampVector timeStampVector = (TimeStampVector) curFieldVector;
@@ -414,8 +424,12 @@
                         continue;
                     }
                     LocalDateTime dateTime = getDateTime(rowIndex, timeStampVector);
-                    String formatted = DATE_TIME_FORMATTER.format(dateTime);
-                    addValueToRow(rowIndex, formatted);
+                    if (datetimeJava8ApiEnabled) {
+                        Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+                        addValueToRow(rowIndex, instant);
+                    } else {
+                        addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
+                    }
                 }
             } else {
                 String errMsg = String.format("Unsupported type for DATETIMEV2, minorType %s, class is %s",
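A note on the VarCharVector branch above: the raw string is now padded via completeMilliseconds and parsed before conversion, since Doris may send DATETIMEV2 values with zero to six fractional digits. A rough, self-contained sketch of parsing such values (the formatter below is an assumption for illustration, not the connector's actual dateTimeV2Formatter):

    import java.time.LocalDateTime;
    import java.time.format.DateTimeFormatter;
    import java.time.format.DateTimeFormatterBuilder;
    import java.time.temporal.ChronoField;

    // Sketch: accept "yyyy-MM-dd HH:mm:ss" with an optional fraction of 0-6 digits.
    class DateTimeV2ParseSketch {
        static final DateTimeFormatter FORMATTER = new DateTimeFormatterBuilder()
                .appendPattern("yyyy-MM-dd HH:mm:ss")
                .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true)
                .toFormatter();

        public static void main(String[] args) {
            // Both precisions parse to LocalDateTime; the reader then converts
            // to Instant or Timestamp depending on the Java 8 API flag.
            System.out.println(LocalDateTime.parse("2024-03-20 00:00:00", FORMATTER));
            System.out.println(LocalDateTime.parse("2024-03-20 00:00:00.123456", FORMATTER));
        }
    }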
IllegalArgumentException("Unsupported stream load format: " + format); @@ -263,7 +269,7 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>, public abstract String stringify(R row, String format); - public abstract byte[] toArrowFormat(R[] rowArray) throws IOException; + public abstract byte[] toArrowFormat(List<R> rows) throws IOException; public abstract String getWriteFields() throws OptionRequiredException; @@ -364,4 +370,6 @@ public abstract class AbstractStreamLoadProcessor<R> implements DorisWriter<R>, } } + protected abstract R copy(R row); + } diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java index e5ac4fa..2f787a5 100644 --- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java +++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java @@ -34,6 +34,7 @@ import org.apache.spark.sql.util.ArrowUtils; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.List; public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow> { @@ -50,7 +51,7 @@ public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow } @Override - public byte[] toArrowFormat(InternalRow[] rowArray) throws IOException { + public byte[] toArrowFormat(List<InternalRow> rowArray) throws IOException { Schema arrowSchema = ArrowUtils.toArrowSchema(schema, "UTC"); VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, new RootAllocator(Integer.MAX_VALUE)); ArrowWriter arrowWriter = ArrowWriter.create(root); @@ -112,6 +113,8 @@ public class StreamLoadProcessor extends AbstractStreamLoadProcessor<InternalRow this.schema = schema; } - - + @Override + protected InternalRow copy(InternalRow row) { + return row.copy(); + } } \ No newline at end of file diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala index 73e8c9b..109fa5b 100644 --- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala +++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala @@ -269,6 +269,8 @@ class StreamLoader(settings: SparkSettings, isStreaming: Boolean) extends Loader */ private def buildLoadRequest(iterator: Iterator[InternalRow], schema: StructType, label: String): HttpUriRequest = { + iterator.next().copy() + currentLoadUrl = URLs.streamLoad(getNode, database, table, enableHttps) val put = new HttpPut(currentLoadUrl) addCommonHeader(put) diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala index 31b7196..b75d1ce 100644 --- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala +++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ import 
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
index 31b7196..b75d1ce 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 import java.sql.{Date, Timestamp}
-import java.time.LocalDate
+import java.time.{Instant, LocalDate}
 
 import scala.collection.JavaConverters.mapAsScalaMapConverter
 import scala.collection.mutable
@@ -110,7 +110,8 @@ object RowConvertors {
   def convertValue(v: Any, dataType: DataType, datetimeJava8ApiEnabled: Boolean): Any = {
     dataType match {
       case StringType => UTF8String.fromString(v.asInstanceOf[String])
-      case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(v.asInstanceOf[String]))
+      case TimestampType if datetimeJava8ApiEnabled => DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
+      case TimestampType => DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
       case DateType if datetimeJava8ApiEnabled => v.asInstanceOf[LocalDate].toEpochDay.toInt
       case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
       case _: MapType =>
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index 303aa1f..e694083 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -37,8 +37,8 @@ object SchemaConvertors {
     case "DOUBLE" => DataTypes.DoubleType
     case "DATE" => DataTypes.DateType
     case "DATEV2" => DataTypes.DateType
-    case "DATETIME" => DataTypes.StringType
-    case "DATETIMEV2" => DataTypes.StringType
+    case "DATETIME" => DataTypes.TimestampType
+    case "DATETIMEV2" => DataTypes.TimestampType
     case "BINARY" => DataTypes.BinaryType
     case "DECIMAL" => DecimalType(precision, scale)
     case "CHAR" => DataTypes.StringType
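With DATETIME and DATETIMEV2 now mapped to TimestampType, convertValue must produce Spark's internal microsecond count from either an Instant (Java 8 API on) or a Timestamp (off). A sketch of the equivalent arithmetic, independent of Spark's DateTimeUtils:

    import java.sql.Timestamp;
    import java.time.Instant;

    // Sketch: converting timestamp values to the microsecond-precision long
    // that Spark uses internally for TimestampType.
    class TimestampMicrosSketch {
        static long instantToMicros(Instant instant) {
            // Whole seconds scaled to micros, plus the sub-second part.
            return Math.addExact(
                    Math.multiplyExact(instant.getEpochSecond(), 1_000_000L),
                    instant.getNano() / 1_000L);
        }

        static long timestampToMicros(Timestamp ts) {
            return instantToMicros(ts.toInstant());
        }

        public static void main(String[] args) {
            Instant i = Instant.parse("2025-02-24T00:00:00.123456Z");
            System.out.println(instantToMicros(i)); // 1740355200123456
        }
    }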
Date.valueOf("1900-08-08"), - "1900-08-08 00:00:00", + Timestamp.valueOf("1900-08-08 00:00:00"), Decimal.apply(8888L, 4, 2), "char2" ); @@ -303,7 +302,7 @@ public class RowBatchTest { (float) 3.3, (double) 3.3, Date.valueOf("2100-08-08"), - "2100-08-08 00:00:00", + Timestamp.valueOf("2100-08-08 00:00:00"), Decimal.apply(10L, 2, 0), "char3" ); @@ -831,16 +830,16 @@ public class RowBatchTest { Assert.assertTrue(rowBatch.hasNext()); List<Object> actualRow0 = rowBatch.next(); - Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(0)); - Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(1)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"), actualRow0.get(0)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"), actualRow0.get(1)); List<Object> actualRow1 = rowBatch.next(); - Assert.assertEquals("2024-03-20 00:00:01", actualRow1.get(0)); - Assert.assertEquals("2024-03-20 00:00:00.123", actualRow1.get(1)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:01"), actualRow1.get(0)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123"), actualRow1.get(1)); List<Object> actualRow2 = rowBatch.next(); - Assert.assertEquals("2024-03-20 00:00:02", actualRow2.get(0)); - Assert.assertEquals("2024-03-20 00:00:00.123456", actualRow2.get(1)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:02"), actualRow2.get(0)); + Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123456"), actualRow2.get(1)); Assert.assertFalse(rowBatch.hasNext()); @@ -1169,6 +1168,10 @@ public class RowBatchTest { ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder(); childrenBuilder.add(new Field("k0", FieldType.nullable(new ArrowType.Utf8()), null)); childrenBuilder.add(new Field("k1", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null)); + childrenBuilder.add(new Field("k2", FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, + null)), null)); + childrenBuilder.add(new Field("k3", FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, + null)), null)); VectorSchemaRoot root = VectorSchemaRoot.create( new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null), @@ -1202,6 +1205,32 @@ public class RowBatchTest { date2Vector.setSafe(0, (int) date); vector.setValueCount(1); + LocalDateTime localDateTime = LocalDateTime.of(2025, 2, 24, + 0, 0, 0, 123000000); + long second = localDateTime.atZone(ZoneId.systemDefault()).toEpochSecond(); + int nano = localDateTime.getNano(); + + vector = root.getVector("k2"); + TimeStampMicroVector datetimeV2Vector = (TimeStampMicroVector) vector; + datetimeV2Vector.setInitialCapacity(1); + datetimeV2Vector.allocateNew(); + datetimeV2Vector.setIndexDefined(0); + datetimeV2Vector.setSafe(0, second * 1000000 + nano / 1000); + vector.setValueCount(1); + + LocalDateTime localDateTime1 = LocalDateTime.of(2025, 2, 24, + 1, 2, 3, 123456000); + long second1 = localDateTime1.atZone(ZoneId.systemDefault()).toEpochSecond(); + int nano1 = localDateTime1.getNano(); + + vector = root.getVector("k3"); + TimeStampMicroVector datetimeV2Vector1 = (TimeStampMicroVector) vector; + datetimeV2Vector1.setInitialCapacity(1); + datetimeV2Vector1.allocateNew(); + datetimeV2Vector1.setIndexDefined(0); + datetimeV2Vector1.setSafe(0, second1 * 1000000 + nano1 / 1000); + vector.setValueCount(1); + arrowStreamWriter.writeBatch(); arrowStreamWriter.end(); @@ -1217,7 +1246,9 @@ public class RowBatchTest { String schemaStr = "{\"properties\":[" + "{\"type\":\"DATE\",\"name\":\"k0\",\"comment\":\"\"}, 
" + - "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}" + + "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}," + + "{\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}," + + "{\"type\":\"DATETIMEV2\",\"name\":\"k3\",\"comment\":\"\"}" + "], \"status\":200}"; Schema schema = RestService.parseSchema(schemaStr, logger); @@ -1228,6 +1259,8 @@ public class RowBatchTest { List<Object> actualRow0 = rowBatch1.next(); Assert.assertEquals(Date.valueOf("2025-01-01"), actualRow0.get(0)); Assert.assertEquals(Date.valueOf("2025-02-01"), actualRow0.get(1)); + Assert.assertEquals(Timestamp.valueOf("2025-02-24 00:00:00.123"), actualRow0.get(2)); + Assert.assertEquals(Timestamp.valueOf("2025-02-24 01:02:03.123456"), actualRow0.get(3)); Assert.assertFalse(rowBatch1.hasNext()); @@ -1237,6 +1270,8 @@ public class RowBatchTest { List<Object> actualRow01 = rowBatch2.next(); Assert.assertEquals(LocalDate.of(2025,1,1), actualRow01.get(0)); Assert.assertEquals(localDate, actualRow01.get(1)); + Assert.assertEquals(localDateTime.atZone(ZoneId.systemDefault()).toInstant(), actualRow01.get(2)); + Assert.assertEquals(localDateTime1.atZone(ZoneId.systemDefault()).toInstant(), actualRow01.get(3)); Assert.assertFalse(rowBatch2.hasNext()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org