This is an automated email from the ASF dual-hosted git repository.

jiafengzheng pushed a commit to branch branch-for-flink-before-1.13
in repository https://gitbox.apache.org/repos/asf/incubator-doris-flink-connector.git
The following commit(s) were added to refs/heads/branch-for-flink-before-1.13 by this push:
     new 2ef45b8  [Bug-1.13] Fix row type decimal convert bug (#27)
2ef45b8 is described below

commit 2ef45b84b96405715cef09aec048573fd018bd19
Author: aiwenmo <32723967+aiwe...@users.noreply.github.com>
AuthorDate: Fri Apr 15 11:01:45 2022 +0800

    [Bug-1.13] Fix row type decimal convert bug (#27)

    * [Bug-1.13] Fix row type decimal convert bug
---
 .../apache/doris/flink/serialization/RowBatch.java |   6 +-
 .../doris/flink/table/DorisDynamicTableSource.java |   5 +-
 .../doris/flink/table/DorisRowDataInputFormat.java |  46 +++++-
 .../doris/flink/serialization/TestRowBatch.java    | 160 +++++++++++----------
 4 files changed, 128 insertions(+), 89 deletions(-)

diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
index 3337637..827ec81 100644
--- a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
+++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
@@ -37,8 +37,6 @@ import org.apache.doris.flink.exception.DorisException;
 import org.apache.doris.flink.rest.models.Schema;
 import org.apache.doris.thrift.TScanBatchResult;
 
-import org.apache.flink.table.data.DecimalData;
-import org.apache.flink.table.data.StringData;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -243,7 +241,7 @@ public class RowBatch {
                             continue;
                         }
                         BigDecimal value = decimalVector.getObject(rowIndex).stripTrailingZeros();
-                        addValueToRow(rowIndex, DecimalData.fromBigDecimal(value, value.precision(), value.scale()));
+                        addValueToRow(rowIndex, value);
                     }
                     break;
                 case "DATE":
@@ -261,7 +259,7 @@ public class RowBatch {
                             continue;
                         }
                         String value = new String(varCharVector.get(rowIndex));
-                        addValueToRow(rowIndex, StringData.fromString(value));
+                        addValueToRow(rowIndex, value);
                     }
                     break;
                 default:
diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
index 0262677..689aa47 100644
--- a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
+++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisDynamicTableSource.java
@@ -33,6 +33,8 @@ import org.apache.flink.table.connector.source.ScanTableSource;
 import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
 import org.apache.flink.table.connector.source.abilities.SupportsProjectionPushDown;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.logical.RowType;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -80,7 +82,8 @@ public final class DorisDynamicTableSource implements ScanTableSource, LookupTab
                 .setPassword(options.getPassword())
                 .setTableIdentifier(options.getTableIdentifier())
                 .setPartitions(dorisPartitions)
-                .setReadOptions(readOptions);
+                .setReadOptions(readOptions)
+                .setRowType((RowType) physicalSchema.toRowDataType().getLogicalType());
         return InputFormatProvider.of(builder.build());
     }
 
diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
index c75a88f..eeb63ba 100644
--- a/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
+++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/table/DorisRowDataInputFormat.java
@@ -14,6 +14,7 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
+
 package org.apache.doris.flink.table;
 
 import org.apache.doris.flink.cfg.DorisOptions;
@@ -29,16 +30,23 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.io.InputSplitAssigner;
+import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
 
 import java.io.IOException;
+import java.math.BigDecimal;
 import java.sql.PreparedStatement;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 /**
  * InputFormat for {@link DorisDynamicTableSource}.
  */
@@ -56,10 +64,13 @@ public class DorisRowDataInputFormat extends RichInputFormat<RowData, DorisTable
     private ScalaValueReader scalaValueReader;
     private transient boolean hasNext;
 
-    public DorisRowDataInputFormat(DorisOptions options, List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions) {
+    private RowType rowType;
+
+    public DorisRowDataInputFormat(DorisOptions options, List<PartitionDefinition> dorisPartitions, DorisReadOptions readOptions, RowType rowType) {
         this.options = options;
         this.dorisPartitions = dorisPartitions;
         this.readOptions = readOptions;
+        this.rowType = rowType;
     }
 
     @Override
@@ -136,15 +147,30 @@ public class DorisRowDataInputFormat extends RichInputFormat<RowData, DorisTable
             return null;
         }
         List next = (List) scalaValueReader.next();
-        GenericRowData genericRowData = new GenericRowData(next.size());
-        for (int i = 0; i < next.size(); i++) {
-            genericRowData.setField(i, next.get(i));
+        GenericRowData genericRowData = new GenericRowData(rowType.getFieldCount());
+        for (int i = 0; i < next.size() && i < rowType.getFieldCount(); i++) {
+            Object value = deserialize(rowType.getTypeAt(i), next.get(i));
+            genericRowData.setField(i, value);
         }
         //update hasNext after we've read the record
         hasNext = scalaValueReader.hasNext();
         return genericRowData;
     }
 
+    private Object deserialize(LogicalType type, Object val) {
+        switch (type.getTypeRoot()) {
+            case DECIMAL:
+                final DecimalType decimalType = ((DecimalType) type);
+                final int precision = decimalType.getPrecision();
+                final int scala = decimalType.getScale();
+                return DecimalData.fromBigDecimal((BigDecimal) val, precision, scala);
+            case VARCHAR:
+                return StringData.fromString((String) val);
+            default:
+                return val;
+        }
+    }
+
     @Override
     public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException {
         return cachedStatistics;
@@ -182,6 +208,7 @@ public class DorisRowDataInputFormat extends RichInputFormat<RowData, DorisTable
 
         private DorisOptions.Builder optionsBuilder;
         private List<PartitionDefinition> partitions;
        private DorisReadOptions readOptions;
+        private RowType rowType;
 
         public Builder() {
@@ -218,9 +245,14 @@ public class DorisRowDataInputFormat extends RichInputFormat<RowData, DorisTable
             return this;
         }
 
+        public Builder setRowType(RowType rowType) {
+            this.rowType = rowType;
+            return this;
+        }
+
         public DorisRowDataInputFormat build() {
             return new DorisRowDataInputFormat(
-                    optionsBuilder.build(), partitions, readOptions
+                    optionsBuilder.build(), partitions, readOptions, rowType
             );
         }
     }
diff --git a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
index 0f45aaa..f2bf878 100644
--- a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
+++ b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
@@ -44,7 +44,7 @@ import org.apache.doris.thrift.TStatusCode;
 import org.apache.flink.shaded.guava18.com.google.common.collect.ImmutableList;
 import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
 import org.apache.flink.table.data.DecimalData;
-import org.apache.flink.table.data.StringData;
+
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -79,23 +79,23 @@ public class TestRowBatch {
         childrenBuilder.add(new Field("k8", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null));
         childrenBuilder.add(new Field("k10", FieldType.nullable(new ArrowType.Utf8()), null));
         childrenBuilder.add(new Field("k11", FieldType.nullable(new ArrowType.Utf8()), null));
-        childrenBuilder.add(new Field("k5", FieldType.nullable(new ArrowType.Decimal(9,2)), null));
+        childrenBuilder.add(new Field("k5", FieldType.nullable(new ArrowType.Decimal(9, 2)), null));
         childrenBuilder.add(new Field("k6", FieldType.nullable(new ArrowType.Utf8()), null));
 
         VectorSchemaRoot root = VectorSchemaRoot.create(
-                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
-                new RootAllocator(Integer.MAX_VALUE));
+            new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+            new RootAllocator(Integer.MAX_VALUE));
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
         ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
-                root,
-                new DictionaryProvider.MapDictionaryProvider(),
-                outputStream);
+            root,
+            new DictionaryProvider.MapDictionaryProvider(),
+            outputStream);
         arrowStreamWriter.start();
 
         root.setRowCount(3);
 
         FieldVector vector = root.getVector("k0");
-        BitVector bitVector = (BitVector)vector;
+        BitVector bitVector = (BitVector) vector;
         bitVector.setInitialCapacity(3);
         bitVector.allocateNew(3);
         bitVector.setSafe(0, 1);
@@ -104,7 +104,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k1");
-        TinyIntVector tinyIntVector = (TinyIntVector)vector;
+        TinyIntVector tinyIntVector = (TinyIntVector) vector;
         tinyIntVector.setInitialCapacity(3);
         tinyIntVector.allocateNew(3);
         tinyIntVector.setSafe(0, 1);
@@ -113,7 +113,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k2");
-        SmallIntVector smallIntVector = (SmallIntVector)vector;
+        SmallIntVector smallIntVector = (SmallIntVector) vector;
         smallIntVector.setInitialCapacity(3);
         smallIntVector.allocateNew(3);
         smallIntVector.setSafe(0, 1);
@@ -122,7 +122,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k3");
-        IntVector intVector = (IntVector)vector;
+        IntVector intVector = (IntVector) vector;
         intVector.setInitialCapacity(3);
         intVector.allocateNew(3);
         intVector.setSafe(0, 1);
@@ -131,7 +131,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k4");
-        BigIntVector bigIntVector = (BigIntVector)vector;
+        BigIntVector bigIntVector = (BigIntVector) vector;
         bigIntVector.setInitialCapacity(3);
         bigIntVector.allocateNew(3);
         bigIntVector.setSafe(0, 1);
@@ -140,7 +140,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k5");
-        DecimalVector decimalVector = (DecimalVector)vector;
+        DecimalVector decimalVector = (DecimalVector) vector;
         decimalVector.setInitialCapacity(3);
         decimalVector.allocateNew();
         decimalVector.setIndexDefined(0);
@@ -152,7 +152,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k6");
-        VarCharVector charVector = (VarCharVector)vector;
+        VarCharVector charVector = (VarCharVector) vector;
         charVector.setInitialCapacity(3);
         charVector.allocateNew();
         charVector.setIndexDefined(0);
@@ -167,7 +167,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k8");
-        Float8Vector float8Vector = (Float8Vector)vector;
+        Float8Vector float8Vector = (Float8Vector) vector;
         float8Vector.setInitialCapacity(3);
         float8Vector.allocateNew(3);
         float8Vector.setSafe(0, 1.1);
@@ -176,7 +176,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k9");
-        Float4Vector float4Vector = (Float4Vector)vector;
+        Float4Vector float4Vector = (Float4Vector) vector;
         float4Vector.setInitialCapacity(3);
         float4Vector.allocateNew(3);
         float4Vector.setSafe(0, 1.1f);
@@ -185,7 +185,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k10");
-        VarCharVector datecharVector = (VarCharVector)vector;
+        VarCharVector datecharVector = (VarCharVector) vector;
         datecharVector.setInitialCapacity(3);
         datecharVector.allocateNew();
         datecharVector.setIndexDefined(0);
@@ -200,7 +200,7 @@ public class TestRowBatch {
         vector.setValueCount(3);
 
         vector = root.getVector("k11");
-        VarCharVector timecharVector = (VarCharVector)vector;
+        VarCharVector timecharVector = (VarCharVector) vector;
         timecharVector.setInitialCapacity(3);
         timecharVector.allocateNew();
         timecharVector.setIndexDefined(0);
@@ -227,71 +227,74 @@ public class TestRowBatch {
         scanBatchResult.setRows(outputStream.toByteArray());
 
         String schemaStr = "{\"properties\":[{\"type\":\"BOOLEAN\",\"name\":\"k0\",\"comment\":\"\"},"
-                + "{\"type\":\"TINYINT\",\"name\":\"k1\",\"comment\":\"\"},{\"type\":\"SMALLINT\",\"name\":\"k2\","
-                + "\"comment\":\"\"},{\"type\":\"INT\",\"name\":\"k3\",\"comment\":\"\"},{\"type\":\"BIGINT\","
-                + "\"name\":\"k4\",\"comment\":\"\"},{\"type\":\"FLOAT\",\"name\":\"k9\",\"comment\":\"\"},"
-                + "{\"type\":\"DOUBLE\",\"name\":\"k8\",\"comment\":\"\"},{\"type\":\"DATE\",\"name\":\"k10\","
-                + "\"comment\":\"\"},{\"type\":\"DATETIME\",\"name\":\"k11\",\"comment\":\"\"},"
-                + "{\"name\":\"k5\",\"scale\":\"0\",\"comment\":\"\","
-                + "\"type\":\"DECIMAL\",\"precision\":\"9\",\"aggregation_type\":\"\"},{\"type\":\"CHAR\",\"name\":\"k6\",\"comment\":\"\",\"aggregation_type\":\"REPLACE_IF_NOT_NULL\"}],"
-                + "\"status\":200}";
+            + "{\"type\":\"TINYINT\",\"name\":\"k1\",\"comment\":\"\"},{\"type\":\"SMALLINT\",\"name\":\"k2\","
+            + "\"comment\":\"\"},{\"type\":\"INT\",\"name\":\"k3\",\"comment\":\"\"},{\"type\":\"BIGINT\","
+            + "\"name\":\"k4\",\"comment\":\"\"},{\"type\":\"FLOAT\",\"name\":\"k9\",\"comment\":\"\"},"
+            + "{\"type\":\"DOUBLE\",\"name\":\"k8\",\"comment\":\"\"},{\"type\":\"DATE\",\"name\":\"k10\","
+            + "\"comment\":\"\"},{\"type\":\"DATETIME\",\"name\":\"k11\",\"comment\":\"\"},"
+            + "{\"name\":\"k5\",\"scale\":\"0\",\"comment\":\"\","
+            + "\"type\":\"DECIMAL\",\"precision\":\"9\",\"aggregation_type\":\"\"},{\"type\":\"CHAR\",\"name\":\"k6\",\"comment\":\"\",\"aggregation_type\":\"REPLACE_IF_NOT_NULL\"}],"
+            + "\"status\":200}";
 
         Schema schema = RestService.parseSchema(schemaStr, logger);
 
         RowBatch rowBatch = new RowBatch(scanBatchResult, schema).readArrow();
 
         List<Object> expectedRow1 = Lists.newArrayList(
-                Boolean.TRUE,
-                (byte) 1,
-                (short) 1,
-                1,
-                1L,
-                (float) 1.1,
-                (double) 1.1,
-                StringData.fromString("2008-08-08"),
-                StringData.fromString("2008-08-08 00:00:00"),
-                DecimalData.fromBigDecimal(new BigDecimal(12.34), 4, 2),
-                StringData.fromString("char1")
+            Boolean.TRUE,
+            (byte) 1,
+            (short) 1,
+            1,
+            1L,
+            (float) 1.1,
+            (double) 1.1,
+            "2008-08-08",
+            "2008-08-08 00:00:00",
+            DecimalData.fromBigDecimal(new BigDecimal(12.34), 4, 2),
+            "char1"
         );
 
         List<Object> expectedRow2 = Arrays.asList(
-                Boolean.FALSE,
-                (byte) 2,
-                (short) 2,
-                null,
-                2L,
-                (float) 2.2,
-                (double) 2.2,
-                StringData.fromString("1900-08-08"),
-                StringData.fromString("1900-08-08 00:00:00"),
-                DecimalData.fromBigDecimal(new BigDecimal(88.88), 4, 2),
-                StringData.fromString("char2")
+            Boolean.FALSE,
+            (byte) 2,
+            (short) 2,
+            null,
+            2L,
+            (float) 2.2,
+            (double) 2.2,
+            "1900-08-08",
+            "1900-08-08 00:00:00",
+            DecimalData.fromBigDecimal(new BigDecimal(88.88), 4, 2),
+            "char2"
        );
 
         List<Object> expectedRow3 = Arrays.asList(
-                Boolean.TRUE,
-                (byte) 3,
-                (short) 3,
-                3,
-                3L,
-                (float) 3.3,
-                (double) 3.3,
-                StringData.fromString("2100-08-08"),
-                StringData.fromString("2100-08-08 00:00:00"),
-                DecimalData.fromBigDecimal(new BigDecimal(10.22), 4, 2),
-                StringData.fromString("char3")
+            Boolean.TRUE,
+            (byte) 3,
+            (short) 3,
+            3,
+            3L,
+            (float) 3.3,
+            (double) 3.3,
+            "2100-08-08",
+            "2100-08-08 00:00:00",
+            DecimalData.fromBigDecimal(new BigDecimal(10.22), 4, 2),
+            "char3"
        );
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
+        actualRow1.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow1.get(9), 4, 2));
         Assert.assertEquals(expectedRow1, actualRow1);
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
+        actualRow2.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow2.get(9), 4, 2));
         Assert.assertEquals(expectedRow2, actualRow2);
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow3 = rowBatch.next();
+        actualRow3.set(9, DecimalData.fromBigDecimal((BigDecimal) actualRow3.get(9), 4, 2));
         Assert.assertEquals(expectedRow3, actualRow3);
 
         Assert.assertFalse(rowBatch.hasNext());
@@ -310,13 +313,13 @@ public class TestRowBatch {
         childrenBuilder.add(new Field("k7", FieldType.nullable(new ArrowType.Binary()), null));
 
         VectorSchemaRoot root = VectorSchemaRoot.create(
-                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
-                new RootAllocator(Integer.MAX_VALUE));
+            new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+            new RootAllocator(Integer.MAX_VALUE));
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
         ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
-                root,
-                new DictionaryProvider.MapDictionaryProvider(),
-                outputStream);
+            root,
+            new DictionaryProvider.MapDictionaryProvider(),
+            outputStream);
         arrowStreamWriter.start();
 
         root.setRowCount(3);
@@ -356,15 +359,15 @@ public class TestRowBatch {
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow0, (byte[])actualRow0.get(0));
+        Assert.assertArrayEquals(binaryRow0, (byte[]) actualRow0.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow1, (byte[])actualRow1.get(0));
+        Assert.assertArrayEquals(binaryRow1, (byte[]) actualRow1.get(0));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertArrayEquals(binaryRow2, (byte[])actualRow2.get(0));
+        Assert.assertArrayEquals(binaryRow2, (byte[]) actualRow2.get(0));
 
         Assert.assertFalse(rowBatch.hasNext());
 
         thrown.expect(NoSuchElementException.class);
@@ -378,13 +381,13 @@ public class TestRowBatch {
         childrenBuilder.add(new Field("k7", FieldType.nullable(new ArrowType.Decimal(27, 9)), null));
 
         VectorSchemaRoot root = VectorSchemaRoot.create(
-                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
-                new RootAllocator(Integer.MAX_VALUE));
+            new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+            new RootAllocator(Integer.MAX_VALUE));
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
         ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
-                root,
-                new DictionaryProvider.MapDictionaryProvider(),
-                outputStream);
+            root,
+            new DictionaryProvider.MapDictionaryProvider(),
+            outputStream);
         arrowStreamWriter.start();
 
         root.setRowCount(3);
@@ -411,8 +414,8 @@ public class TestRowBatch {
         scanBatchResult.setRows(outputStream.toByteArray());
 
         String schemaStr = "{\"properties\":[{\"type\":\"DECIMALV2\",\"scale\": 0,"
-                + "\"precision\": 9, \"name\":\"k7\",\"comment\":\"\"}], "
-                + "\"status\":200}";
+            + "\"precision\": 9, \"name\":\"k7\",\"comment\":\"\"}], "
+            + "\"status\":200}";
 
         Schema schema = RestService.parseSchema(schemaStr, logger);
 
@@ -420,16 +423,19 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow0 = rowBatch.next();
-        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(12.340000000), 11, 9), actualRow0.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(12.340000000), 11, 9),
+                DecimalData.fromBigDecimal((BigDecimal) actualRow0.get(0), 11, 9));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(88.880000000), 11, 9), actualRow1.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(88.880000000), 11, 9),
+                DecimalData.fromBigDecimal((BigDecimal) actualRow1.get(0), 11, 9));
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(10.000000000),11, 9), actualRow2.get(0));
+        Assert.assertEquals(DecimalData.fromBigDecimal(new BigDecimal(10.000000000), 11, 9),
+                DecimalData.fromBigDecimal((BigDecimal) actualRow2.get(0), 11, 9));
 
         Assert.assertFalse(rowBatch.hasNext());
 
         thrown.expect(NoSuchElementException.class);

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
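For readers following this change: RowBatch now hands back raw java.math.BigDecimal and java.lang.String values, and DorisRowDataInputFormat converts them into Flink's internal DecimalData/StringData using the RowType that the table source pushes down through the new Builder setRowType(...) hook. Below is a minimal, self-contained sketch of that conversion pattern. It is illustrative only; the class and method names are made up for the example and use plain Flink table types rather than the connector classes.

import java.math.BigDecimal;

import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.data.StringData;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

// Hypothetical standalone class; mirrors the idea of the new
// DorisRowDataInputFormat#deserialize without any connector dependencies.
public class DecimalConvertSketch {

    static Object toInternal(LogicalType type, Object value) {
        switch (type.getTypeRoot()) {
            case DECIMAL:
                // Wrap the raw BigDecimal with the precision/scale declared in the table schema.
                DecimalType decimalType = (DecimalType) type;
                return DecimalData.fromBigDecimal(
                        (BigDecimal) value, decimalType.getPrecision(), decimalType.getScale());
            case VARCHAR:
                // Wrap the raw String into Flink's internal string representation.
                return StringData.fromString((String) value);
            default:
                // Other types are passed through unchanged, as in the commit.
                return value;
        }
    }

    public static void main(String[] args) {
        // DECIMAL(9, 2) column: the raw BigDecimal becomes DecimalData.
        System.out.println(toInternal(new DecimalType(9, 2), new BigDecimal("12.34")));
        // VARCHAR column: the raw String becomes StringData.
        System.out.println(toInternal(new VarCharType(20), "char1"));
    }
}

In the connector itself the LogicalType for each column comes from rowType.getTypeAt(i) inside nextRecord(), with the RowType derived from physicalSchema.toRowDataType().getLogicalType() in DorisDynamicTableSource.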