This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0356ac009472 [SPARK-40876][SQL] Widening type promotion from integers to decimal in Parquet vectorized reader
0356ac009472 is described below

commit 0356ac00947282b1a0885ad7eaae1e25e43671fe
Author: Johan Lasperas <[email protected]>
AuthorDate: Tue Jan 23 12:37:18 2024 -0800

    [SPARK-40876][SQL] Widening type promotion from integers to decimal in Parquet vectorized reader
    
    ### What changes were proposed in this pull request?
    This is a follow-up from https://github.com/apache/spark/pull/44368 and https://github.com/apache/spark/pull/44513, implementing an additional type promotion from integers to decimals in the Parquet vectorized reader, bringing it to parity with the non-vectorized reader in that regard.
    
    ### Why are the changes needed?
    This allows reading Parquet files that have different schemas and mix decimals and integers - e.g. reading files containing either `Decimal(15, 2)` or `INT32` as `Decimal(15, 2)` - as long as the requested decimal type is large enough to accommodate the integer values without precision loss.
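    As an illustration, the new check in `ParquetVectorUpdaterFactory.isDecimalTypeMatched` amounts to requiring that the requested decimal has enough integer digits to hold any value of the Parquet physical type (10 digits for `INT32`, 20 for `INT64`, the precisions of `DecimalType.IntDecimal` and `DecimalType.LongDecimal`). A minimal sketch of that rule, not the actual reader code:
    ```
      // Hypothetical helper mirroring the matching rule; the name is illustrative only.
      def fitsIntegerInDecimal(parquetIsInt64: Boolean, precision: Int, scale: Int): Boolean = {
        val integerDigits = precision - scale
        if (parquetIsInt64) integerDigits >= 20 else integerDigits >= 10
      }

      fitsIntegerInDecimal(parquetIsInt64 = false, precision = 12, scale = 0) // true: Decimal(12, 0) holds any INT32
      fitsIntegerInDecimal(parquetIsInt64 = false, precision = 9, scale = 0)  // false: Int.MaxValue would not fit
    ```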
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the following now succeeds when using the vectorized Parquet reader:
    ```
      Seq(20).toDF($"a".cast(IntegerType)).write.parquet(path)
      spark.read.schema("a decimal(12, 0)").parquet(path).collect()
    ```
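    For the same file, requesting a decimal that cannot represent every possible `INT32` value is still rejected by the vectorized reader (illustrative example):
    ```
      // Still unsupported: decimal(9, 0) has only 9 integer digits, fewer than the 10 an INT32 may need.
      spark.read.schema("a decimal(9, 0)").parquet(path).collect()
    ```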
    It failed before with the vectorized reader and succeeded with the non-vectorized reader.
    
    ### How was this patch tested?
    - Tests added to `ParquetTypeWideningSuite`
    - Updated relevant `ParquetQuerySuite` test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44803 from johanl-db/SPARK-40876-widening-promotion-int-to-decimal.
    
    Authored-by: Johan Lasperas <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../parquet/ParquetVectorUpdaterFactory.java       |  39 ++++++-
 .../parquet/VectorizedColumnReader.java            |   7 +-
 .../datasources/parquet/ParquetQuerySuite.scala    |   8 +-
 .../parquet/ParquetTypeWideningSuite.scala         | 123 ++++++++++++++++++---
 4 files changed, 150 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 0d8713b58cec..f369688597b9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -1407,7 +1407,11 @@ public class ParquetVectorUpdaterFactory {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }
 
     @Override
@@ -1436,14 +1440,18 @@ public class ParquetVectorUpdaterFactory {
     }
   }
 
-private static class LongToDecimalUpdater extends DecimalUpdater {
+  private static class LongToDecimalUpdater extends DecimalUpdater {
     private final int parquetScale;
 
-   LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }
 
     @Override
@@ -1641,6 +1649,12 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
     return typeAnnotation instanceof DateLogicalTypeAnnotation;
   }
 
+  private static boolean isSignedIntAnnotation(LogicalTypeAnnotation typeAnnotation) {
+    if (!(typeAnnotation instanceof IntLogicalTypeAnnotation)) return false;
+    IntLogicalTypeAnnotation intAnnotation = (IntLogicalTypeAnnotation) typeAnnotation;
+    return intAnnotation.isSigned();
+  }
+
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
     DecimalType requestedType = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
@@ -1652,6 +1666,20 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
       int scaleIncrease = requestedType.scale() - parquetType.getScale();
       int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
       return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+      // Allow reading signed integers (which may be un-annotated) as decimal as long as the
+      // requested decimal type is large enough to represent all possible values.
+      PrimitiveType.PrimitiveTypeName typeName =
+        descriptor.getPrimitiveType().getPrimitiveTypeName();
+      int integerPrecision = requestedType.precision() - requestedType.scale();
+      switch (typeName) {
+        case INT32:
+          return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
+        case INT64:
+          return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
+        default:
+          return false;
+      }
     }
     return false;
   }
@@ -1662,6 +1690,9 @@ private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
     if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
       DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
       return decimalType.getScale() == d.scale();
+    } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+      // Consider signed integers (which may be un-annotated) as having scale 0.
+      return d.scale() == 0;
     }
     return false;
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index d580023bc877..731c78cf9450 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -153,10 +153,9 @@ public class VectorizedColumnReader {
     // rebasing.
     switch (typeName) {
       case INT32: {
-        boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
-          (isDate && sparkType == TimestampNTZType) ||
+          sparkType == TimestampNTZType ||
           (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
@@ -164,7 +163,7 @@ public class VectorizedColumnReader {
         break;
       }
       case INT64: {
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index b306a526818e..b8a6cb5d0712 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
 
       withAllParquetReaders {
         // We can read the decimal parquet field with a larger precision, if scale is the same.
-        val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
-        checkAnswer(readParquet(schema, path), df)
+        val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
+        checkAnswer(readParquet(schema1, path), df)
+        val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
+        checkAnswer(readParquet(schema2, path), df)
       }
 
       withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
 
       withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
         checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
+        checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
         checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
         checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 
123456.0"))
         checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
         checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
+        checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
         val e = intercept[SparkException] {
           readParquet("d DECIMAL(3, 2)", path).collect()
         }.getCause
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
index 7b8357e20774..6302c2703619 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.io.File
 
 import org.apache.hadoop.fs.Path
+import org.apache.parquet.column.{Encoding, ParquetProperties}
 import org.apache.parquet.format.converter.ParquetMetadataConverter
 import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
 
@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}
 
 class ParquetTypeWideningSuite
     extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
     if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
       assertAllParquetFilesDictionaryEncoded(dir)
     }
+
+    // Check which encoding was used when writing Parquet V2 files.
+    val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
+      .contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
+    if (isParquetV2) {
+      if (dictionaryEnabled) {
+        assertParquetV2Encoding(dir, Encoding.PLAIN)
+      } else if (DecimalType.is64BitDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
+      } else if (DecimalType.isByteArrayDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
+      }
+    }
     df
   }
 
@@ -145,6 +160,27 @@ class ParquetTypeWideningSuite
     }
   }
 
+  /**
+   * Asserts that all parquet files in the given directory have all their columns encoded with the
+   * given encoding.
+   */
+  private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.getEncodings.contains(expected_encoding),
+            s"Expected column '${col.getPath.toDotString}' to use encoding 
$expected_encoding " +
+            s"but found ${col.getEncodings}.")
+        }
+      }
+    }
+  }
+
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@ class ParquetTypeWideningSuite
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, 
TimestampNTZType)
     )
   }
-    test(s"parquet widening conversion $fromType -> $toType") {
-      checkAllParquetReaders(values, fromType, toType, expectError = false)
-    }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
+      (Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, 
DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, 
DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
+      (Seq("1", Long.MaxValue.toString), LongType, 
DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Byte.MaxValue.toString), ByteType, 
DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Short.MaxValue.toString), ShortType, 
DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, 
DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Long.MaxValue.toString), LongType, 
DecimalType(LongDecimal.precision + 1, 1))
+    )
+  }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }
 
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
       (Seq("1.23", "10.34"), DoubleType, FloatType),
       (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("1", "10"), LongType, DoubleType),
       (Seq("1", "10"), LongType, DateType),
       (Seq("1", "10"), IntegerType, TimestampType),
       (Seq("1", "10"), IntegerType, TimestampNTZType),
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
     )
   }
-    test(s"unsupported parquet conversion $fromType -> $toType") {
-      checkAllParquetReaders(values, fromType, toType, expectError = true)
-    }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = true)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
+      // Parquet stores byte, short, int values as INT32, which then requires using a decimal that
+      // can hold at least 4 byte integers.
+      (Seq("1", "2"), ByteType, DecimalType(1, 0)),
+      (Seq("1", "2"), ByteType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ShortDecimal),
+      (Seq("1", "2"), IntegerType, ShortDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
+      (Seq("1", "2"), LongType, IntDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
+    )
+  }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType,
+      expectError =
+      // parquet-mr allows reading decimals into a smaller precision decimal type without
+      // checking for overflows. See test below checking for the overflow case in parquet-mr.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }
 
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
@@ -201,17 +290,17 @@ class ParquetTypeWideningSuite
     Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
       Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
   }
-    test(
-      s"parquet decimal precision change Decimal($fromPrecision, 2) -> 
Decimal($toPrecision, 2)") {
-      checkAllParquetReaders(
-        values = Seq("1.23", "10.34"),
-        fromType = DecimalType(fromPrecision, 2),
-        toType = DecimalType(toPrecision, 2),
-        expectError = fromPrecision > toPrecision &&
-          // parquet-mr allows reading decimals into a smaller precision decimal type without
-          // checking for overflows. See test below checking for the overflow case in parquet-mr.
-          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
-    }
+  test(
+    s"parquet decimal precision change Decimal($fromPrecision, 2) -> 
Decimal($toPrecision, 2)") {
+    checkAllParquetReaders(
+      values = Seq("1.23", "10.34"),
+      fromType = DecimalType(fromPrecision, 2),
+      toType = DecimalType(toPrecision, 2),
+      expectError = fromPrecision > toPrecision &&
+        // parquet-mr allows reading decimals into a smaller precision decimal type without
+        // checking for overflows. See test below checking for the overflow case in parquet-mr.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }
 
   for {
     ((fromPrecision, fromScale), (toPrecision, toScale)) <-


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
