This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 37ee99256b76 [SPARK-53535][SQL] Fix missing structs always being assumed as nulls
37ee99256b76 is described below
commit 37ee99256b7630a2943b1bc39c5b63fda9118923
Author: Ziya Mukhtarov <[email protected]>
AuthorDate: Tue Oct 21 07:47:00 2025 +0800
[SPARK-53535][SQL] Fix missing structs always being assumed as nulls
### What changes were proposed in this pull request?
Currently, if all fields of a struct mentioned in the read schema are
missing in a Parquet file, the reader populates the struct with nulls.
This PR modifies the scan behavior so that if the struct exists in the
Parquet schema but none of the fields from the read schema are present, we
instead pick an arbitrary field from the Parquet file to read and use its levels
to populate NULLs (as well as outer NULLs and array sizes when the struct is
nested inside another nested type).
This is done by changing the schema requested by the readers. We add an
additional field to the requested schema when clipping the Parquet file schema
according to the Spark schema. This means the readers actually read and
return more data than requested, which could cause problems. Only the
`VectorizedParquetRecordReader` is affected, since the other read code path via
parquet-mr already has an `UnsafeProjection` that outputs only the requested
schema fields in `P [...]
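For illustration, here is a minimal Scala sketch of the new clipping behavior,
modeled on the `ParquetSchemaSuite` clipping tests added in this PR (it assumes
access from the `org.apache.spark.sql.execution.datasources.parquet` package, as
the suite has):
```scala
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types._

// File schema: struct `s` only contains field `b`.
val fileSchema = MessageTypeParser.parseMessageType(
  """message root {
    |  required group s {
    |    required int32 b;
    |  }
    |}""".stripMargin)

// Read schema asks for `s.a`, which is absent from the file.
val readSchema = new StructType()
  .add("s", new StructType().add("a", IntegerType))

// Same call the new tests use; requires same-package visibility.
val clipped = ParquetReadSupport.clipParquetSchema(
  fileSchema, readSchema, caseSensitive = true, useFieldId = false,
  returnNullStructIfAllFieldsMissing = false)

// The clipped schema keeps the requested `a` plus the existing `b`, which is
// read only for its levels, roughly:
//   message root { required group s { optional int32 a; required int32 b; } }
println(clipped)
```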
To ensure `VectorizedParquetRecordReader` returns only the Spark-requested
fields, we create the `ColumnarBatch` with vectors that match the requested
schema (we drop the additional fields by recursively matching `sparkSchema`
against `sparkRequestedSchema` and ensuring structs have the same number of
fields in both). `ParquetColumnVector`s are then responsible for allocating
dummy vectors that hold the extra data temporarily while reading, but these are
not exposed to the outside.
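As a rough sketch of that truncation (the actual logic is the new `truncateType`
helper in `VectorizedParquetRecordReader`; this Scala version is illustrative
only), structs in the read type keep just as many leading fields as the
requested type has:
```scala
import org.apache.spark.sql.types._

// Illustrative only: recursively trim trailing struct fields of the read type
// so it matches the shape of the requested type, keeping the read type's leaves.
def truncate(readType: DataType, requestedType: DataType): DataType =
  (readType, requestedType) match {
    case (r: StructType, q: StructType) =>
      StructType(r.fields.take(q.fields.length).zip(q.fields).map {
        case (rf, qf) => rf.copy(dataType = truncate(rf.dataType, qf.dataType))
      })
    case (r: ArrayType, q: ArrayType) =>
      r.copy(elementType = truncate(r.elementType, q.elementType))
    case (r: MapType, q: MapType) =>
      r.copy(keyType = truncate(r.keyType, q.keyType),
        valueType = truncate(r.valueType, q.valueType))
    case _ => readType // primitives: keep the read type's precision from the file
  }

// Example (mirrors the Javadoc of truncateType):
//   truncate(array<struct<a:int,b:long,c:int>>, array<struct<a:int,b:long>>)
//     == array<struct<a:int,b:long>>
```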
The heuristic to pick the arbitrary leaf field is as follows: we try to
minimize the number of arrays or maps (repeated fields) in the path to a leaf
column, because the more repeated fields there are, the more data we are likely
to read. At the same repetition level, we consider the type of each column to
pick the cheapest column to read (struct nesting does not affect the decision
here). We look at the byte size of the column type to pick the cheapest one as
follows (see the sketch after this list):
- BOOLEAN: 1 byte
- INT32, FLOAT: 4 bytes
- INT64, DOUBLE: 8 bytes
- INT96: 12 bytes
- BINARY, FIXED_LEN_BYTE_ARRAY, default case for future types: 32 bytes
(high cost due to variable/large size)
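For reference, a hedged Scala sketch of that cost table (mirroring the heuristic
added to `ParquetReadSupport` in this PR; the function name here is illustrative):
```scala
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._

// Approximate byte cost per value, used to pick the cheapest leaf column to read.
def primitiveCost(name: PrimitiveTypeName): Int = name match {
  case BOOLEAN => 1
  case INT32 | FLOAT => 4
  case INT64 | DOUBLE => 8
  case INT96 => 12
  // Variable-length types are undesirable; give them a high cost.
  case BINARY | FIXED_LEN_BYTE_ARRAY => 32
  // High default cost for types added in the future.
  case _ => 32
}
```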
### Why are the changes needed?
This is a bug fix: depending on which fields were requested, we incorrectly
assumed non-null struct values to be missing from the file and returned nulls
for them.
### Does this PR introduce _any_ user-facing change?
Yes. We previously assumed a struct to be null if all of the fields we were
trying to read from a Parquet file were missing from that file, even when the
file contained other fields whose definition levels could have been used to
determine nullness.
See an example from the Jira ticket below:
```python
df_a = sql('SELECT 1 as id, named_struct("a", 1) AS s')
path = "/tmp/missing_col_test"
df_a.write.format("parquet").save(path)
df_b = sql('SELECT 2 as id, named_struct("b", 3) AS s')
spark.read.format("parquet").schema(df_b.schema).load(path).show()
```
This used to return:
```
+---+----+
| id| s|
+---+----+
| 1|NULL|
+---+----+
```
It now returns:
```
+---+------+
| id| s|
+---+------+
| 1|{NULL}|
+---+------+
```
### How was this patch tested?
Added new unit tests and updated an existing test to expect the new behavior.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52557 from ZiyaZa/missing_struct.
Authored-by: Ziya Mukhtarov <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
docs/sql-migration-guide.md | 4 +
.../org/apache/spark/sql/internal/SQLConf.scala | 12 +
.../datasources/parquet/ParquetColumnVector.java | 48 ++-
.../parquet/SpecificParquetRecordReaderBase.java | 40 +--
.../parquet/VectorizedParquetRecordReader.java | 69 ++++-
.../execution/vectorized/OffHeapColumnVector.java | 2 +-
.../execution/vectorized/OnHeapColumnVector.java | 2 +-
.../execution/vectorized/WritableColumnVector.java | 2 +-
.../datasources/parquet/ParquetReadSupport.scala | 161 ++++++++---
.../datasources/parquet/ParquetRowConverter.scala | 22 +-
.../parquet/ParquetSchemaConverter.scala | 40 ++-
.../parquet/ParquetFieldIdSchemaSuite.scala | 3 +-
.../datasources/parquet/ParquetIOSuite.scala | 97 ++++++-
.../datasources/parquet/ParquetSchemaSuite.scala | 322 ++++++++++++++++++++-
.../parquet/ParquetVectorizedSuite.scala | 195 ++++++++++++-
15 files changed, 898 insertions(+), 121 deletions(-)
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 0f78c0f5e551..e5becac54032 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -22,6 +22,10 @@ license: |
* Table of contents
{:toc}
+## Upgrading from Spark SQL 4.0 to 4.1
+
+- Since Spark 4.1, the Parquet reader no longer assumes all struct values to
be null, if all the requested fields are missing in the parquet file. The new
default behavior is to read an additional struct field that is present in the
file to determine nullness. To restore the previous behavior, set
`spark.sql.legacy.parquet.returnNullStructIfAllFieldsMissing` to `true`.
+
## Upgrading from Spark SQL 3.5 to 4.0
- Since Spark 4.0, `spark.sql.ansi.enabled` is on by default. To restore the
previous behavior, set `spark.sql.ansi.enabled` to `false` or
`SPARK_ANSI_SQL_MODE` to `false`.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2b859ee398eb..dd0dbe36d69a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1534,6 +1534,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING =
+ buildConf("spark.sql.legacy.parquet.returnNullStructIfAllFieldsMissing")
+ .internal()
+ .doc("When true, if all requested fields of a struct are missing in a
parquet file, assume " +
+ "the struct is always null, even if other fields are present. The
default behavior is " +
+ "to fetch and read an arbitrary non-requested field present in the
file to determine " +
+ "struct nullness. If enabled, schema pruning may cause non-null
structs to be read as " +
+ "null.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PARQUET_RECORD_FILTER_ENABLED =
buildConf("spark.sql.parquet.recordLevelFilter.enabled")
.doc("If true, enables Parquet's native record-level filtering using the
pushed down " +
"filters. " +
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 3331c8dfd8f5..37c936c84d5f 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -21,14 +21,10 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Set;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.network.util.JavaUtils;
-import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
-import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.catalyst.types.DataTypeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructType;
@@ -69,16 +65,9 @@ final class ParquetColumnVector {
ParquetColumn column,
WritableColumnVector vector,
int capacity,
- MemoryMode memoryMode,
Set<ParquetColumn> missingColumns,
boolean isTopLevel,
Object defaultValue) {
- DataType sparkType = column.sparkType();
- if (!DataTypeUtils.sameType(sparkType, vector.dataType())) {
- throw new IllegalArgumentException("Spark type: " + sparkType +
- " doesn't match the type: " + vector.dataType() + " in column vector");
- }
-
this.column = column;
this.vector = vector;
this.children = new ArrayList<>();
@@ -111,11 +100,10 @@ final class ParquetColumnVector {
if (column.variantFileType().isDefined()) {
ParquetColumn fileContentCol = column.variantFileType().get();
- WritableColumnVector fileContent = memoryMode == MemoryMode.OFF_HEAP
- ? new OffHeapColumnVector(capacity, fileContentCol.sparkType())
- : new OnHeapColumnVector(capacity, fileContentCol.sparkType());
- ParquetColumnVector contentVector = new
ParquetColumnVector(fileContentCol,
- fileContent, capacity, memoryMode, missingColumns, false, null);
+ WritableColumnVector fileContent = vector.reserveNewColumn(
+ capacity, fileContentCol.sparkType());
+ ParquetColumnVector contentVector = new
ParquetColumnVector(fileContentCol, fileContent,
+ capacity, missingColumns, /* isTopLevel= */ false, /* defaultValue= */
null);
children.add(contentVector);
variantSchema =
SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
fieldsToExtract =
SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
@@ -123,23 +111,30 @@ final class ParquetColumnVector {
definitionLevels = contentVector.definitionLevels;
} else if (isPrimitive) {
if (column.repetitionLevel() > 0) {
- repetitionLevels = allocateLevelsVector(capacity, memoryMode);
+ repetitionLevels = vector.reserveNewColumn(capacity,
DataTypes.IntegerType);
}
// We don't need to create and store definition levels if the column is
top-level.
if (!isTopLevel) {
- definitionLevels = allocateLevelsVector(capacity, memoryMode);
+ definitionLevels = vector.reserveNewColumn(capacity,
DataTypes.IntegerType);
}
} else {
- JavaUtils.checkArgument(column.children().size() ==
vector.getNumChildren(),
- "The number of column children is different from the number of vector
children");
+ // If a child is not present in the allocated vectors, it means we don't
care about this
+ // child's data, we just want to read its levels to help assemble some
parent struct. So we
+ // create a dummy vector below to hold the child's data. There can only
be one such child.
+ JavaUtils.checkArgument(column.children().size() ==
vector.getNumChildren() ||
+ column.children().size() == vector.getNumChildren() + 1,
+ "The number of column children is not equal to the number of vector
children or that + 1");
boolean allChildrenAreMissing = true;
for (int i = 0; i < column.children().size(); i++) {
- ParquetColumnVector childCv = new
ParquetColumnVector(column.children().apply(i),
- vector.getChild(i), capacity, memoryMode, missingColumns, false,
null);
+ ParquetColumn childColumn = column.children().apply(i);
+ WritableColumnVector childVector = i < vector.getNumChildren()
+ ? vector.getChild(i)
+ : vector.reserveNewColumn(capacity, childColumn.sparkType());
+ ParquetColumnVector childCv = new ParquetColumnVector(childColumn,
childVector, capacity,
+ missingColumns, /* isTopLevel= */ false, /* defaultValue= */ null);
children.add(childCv);
-
// Only use levels from non-missing child, this can happen if only
some but not all
// fields of a struct are missing.
if (!childCv.vector.isAllNull()) {
@@ -375,13 +370,6 @@ final class ParquetColumnVector {
vector.addElementsAppended(rowId);
}
- private static WritableColumnVector allocateLevelsVector(int capacity,
MemoryMode memoryMode) {
- return switch (memoryMode) {
- case ON_HEAP -> new OnHeapColumnVector(capacity, DataTypes.IntegerType);
- case OFF_HEAP -> new OffHeapColumnVector(capacity,
DataTypes.IntegerType);
- };
- }
-
/**
* For a collection (i.e., array or map) element at index 'idx', returns the
starting index of
* the next collection after it.
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 038112086e47..eb0063688e70 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -87,6 +87,8 @@ public abstract class SpecificParquetRecordReaderBase<T>
extends RecordReader<Vo
protected ParquetRowGroupReader reader;
+ protected Configuration configuration;
+
@Override
public void initialize(InputSplit inputSplit, TaskAttemptContext
taskAttemptContext)
throws IOException, InterruptedException {
@@ -99,7 +101,7 @@ public abstract class SpecificParquetRecordReaderBase<T>
extends RecordReader<Vo
Option<HadoopInputFile> inputFile,
Option<SeekableInputStream> inputStream,
Option<ParquetMetadata> fileFooter) throws IOException,
InterruptedException {
- Configuration configuration = taskAttemptContext.getConfiguration();
+ this.configuration = taskAttemptContext.getConfiguration();
FileSplit split = (FileSplit) inputSplit;
this.file = split.getPath();
ParquetReadOptions options = HadoopReadOptions
@@ -164,22 +166,22 @@ public abstract class SpecificParquetRecordReaderBase<T>
extends RecordReader<Vo
* configurations.
*/
protected void initialize(String path, List<String> columns) throws
IOException {
- Configuration config = new Configuration();
- config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false);
- config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
- config.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
- config.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(),
false);
- config.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(), false);
+ this.configuration = new Configuration();
+ this.configuration.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() ,
false);
+ this.configuration.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(),
false);
+ this.configuration.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
+
this.configuration.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(),
false);
+
this.configuration.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(),
false);
this.file = new Path(path);
- long length =
this.file.getFileSystem(config).getFileStatus(this.file).getLen();
+ long length =
this.file.getFileSystem(configuration).getFileStatus(this.file).getLen();
ParquetReadOptions options = HadoopReadOptions
- .builder(config, file)
+ .builder(configuration, file)
.withRange(0, length)
.build();
ParquetFileReader fileReader = ParquetFileReader.open(
- HadoopInputFile.fromPath(file, config), options);
+ HadoopInputFile.fromPath(file, configuration), options);
this.reader = new ParquetRowGroupReaderImpl(fileReader);
this.fileSchema = fileReader.getFooter().getFileMetaData().getSchema();
@@ -201,9 +203,10 @@ public abstract class SpecificParquetRecordReaderBase<T>
extends RecordReader<Vo
}
}
fileReader.setRequestedSchema(requestedSchema);
- this.parquetColumn = new ParquetToSparkSchemaConverter(config)
+ this.parquetColumn = new ParquetToSparkSchemaConverter(configuration)
.convertParquetColumn(requestedSchema, Option.empty());
this.sparkSchema = (StructType) parquetColumn.sparkType();
+ this.sparkRequestedSchema = this.sparkSchema;
this.totalRowCount = fileReader.getFilteredRecordCount();
}
@@ -216,15 +219,16 @@ public abstract class SpecificParquetRecordReaderBase<T>
extends RecordReader<Vo
this.reader = rowGroupReader;
this.fileSchema = fileSchema;
this.requestedSchema = requestedSchema;
- Configuration config = new Configuration();
- config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false);
- config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
- config.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
- config.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(),
false);
- config.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(), false);
- this.parquetColumn = new ParquetToSparkSchemaConverter(config)
+ this.configuration = new Configuration();
+ this.configuration.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() ,
false);
+ this.configuration.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(),
false);
+ this.configuration.setBoolean(SQLConf.CASE_SENSITIVE().key(), false);
+
this.configuration.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED().key(),
false);
+
this.configuration.setBoolean(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG().key(),
false);
+ this.parquetColumn = new ParquetToSparkSchemaConverter(configuration)
.convertParquetColumn(requestedSchema, Option.empty());
this.sparkSchema = (StructType) parquetColumn.sparkType();
+ this.sparkRequestedSchema = this.sparkSchema;
this.totalRowCount = totalRowCount;
}
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index b15f79df527e..72125701fd49 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -48,10 +48,10 @@ import
org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.internal.SQLConf$;
+import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
/**
* A specialized RecordReader that reads into InternalRows or ColumnarBatches
directly using the
@@ -265,7 +265,15 @@ public class VectorizedParquetRecordReader extends
SpecificParquetRecordReaderBa
MemoryMode memMode,
StructType partitionColumns,
InternalRow partitionValues) {
- StructType batchSchema = new StructType(sparkSchema.fields());
+ boolean returnNullStructIfAllFieldsMissing = configuration.getBoolean(
+
SQLConf$.MODULE$.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING().key(),
+ (boolean)
SQLConf$.MODULE$.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING()
+ .defaultValue().get());
+ StructType batchSchema = returnNullStructIfAllFieldsMissing
+ ? new StructType(sparkSchema.fields())
+ // Truncate to match requested schema to make sure extra struct field
that we read for
+ // nullability is not included in columnarBatch and exposed outside.
+ : (StructType) truncateType(sparkSchema, sparkRequestedSchema);
int constantColumnLength = 0;
if (partitionColumns != null) {
@@ -287,7 +295,8 @@ public class VectorizedParquetRecordReader extends
SpecificParquetRecordReaderBa
defaultValue =
ResolveDefaultColumns.existenceDefaultValues(sparkRequestedSchema)[i];
}
columnVectors[i] = new
ParquetColumnVector(parquetColumn.children().apply(i),
- (WritableColumnVector) vectors[i], capacity, memMode, missingColumns,
true, defaultValue);
+ (WritableColumnVector) vectors[i], capacity, missingColumns, /*
isTopLevel= */ true,
+ defaultValue);
}
if (partitionColumns != null) {
@@ -309,6 +318,58 @@ public class VectorizedParquetRecordReader extends
SpecificParquetRecordReaderBa
initBatch(MEMORY_MODE, partitionColumns, partitionValues);
}
+ /**
+ * Keeps the hierarchy and fields of readType, recursively truncating struct
fields from the end
+ * of the fields list to match the same number of fields in requestedType.
This is used to get rid
+ * of the extra fields that are added to the structs when the fields we
wanted to read initially
+ * were missing in the file schema. So this returns a type that we would be
reading if everything
+ * was present in the file, matching Spark's expected schema.
+ *
+ * <p> Example: <pre>{@code
+ * readType: array<struct<a:int,b:long,c:int>>
+ * requestedType: array<struct<a:int,b:long>>
+ * returns: array<struct<a:int,b:long>>
+ * }</pre>
+ * We cannot return requestedType here because there might be slight
differences, like nullability
+ * of fields or the type precision (smallint/int)
+ */
+ @VisibleForTesting
+ static DataType truncateType(DataType readType, DataType requestedType) {
+ if (requestedType instanceof UserDefinedType<?> requestedUDT) {
+ requestedType = requestedUDT.sqlType();
+ }
+
+ if (readType instanceof StructType readStruct &&
+ requestedType instanceof StructType requestedStruct) {
+ StructType result = new StructType();
+ for (int i = 0; i < requestedStruct.fields().length; i++) {
+ StructField readField = readStruct.fields()[i];
+ StructField requestedField = requestedStruct.fields()[i];
+ DataType truncatedType = truncateType(readField.dataType(),
requestedField.dataType());
+ result = result.add(readField.copy(
+ readField.name(), truncatedType, readField.nullable(),
readField.metadata()));
+ }
+ return result;
+ }
+
+ if (readType instanceof ArrayType readArray &&
+ requestedType instanceof ArrayType requestedArray) {
+ DataType truncatedElementType = truncateType(
+ readArray.elementType(), requestedArray.elementType());
+ return readArray.copy(truncatedElementType, readArray.containsNull());
+ }
+
+ if (readType instanceof MapType readMap && requestedType instanceof
MapType requestedMap) {
+ DataType truncatedKeyType = truncateType(readMap.keyType(),
requestedMap.keyType());
+ DataType truncatedValueType = truncateType(readMap.valueType(),
requestedMap.valueType());
+ return readMap.copy(truncatedKeyType, truncatedValueType,
readMap.valueContainsNull());
+ }
+
+ assert !ParquetSchemaConverter.isComplexType(readType);
+ assert !ParquetSchemaConverter.isComplexType(requestedType);
+ return readType;
+ }
+
/**
* Returns the ColumnarBatch object that will be used for all rows returned
by this reader.
* This object is reused. Calling this enables the vectorized reader. This
should be called
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index da52cdf5c835..2f64ffb42aa0 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -633,7 +633,7 @@ public final class OffHeapColumnVector extends
WritableColumnVector {
}
@Override
- protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) {
+ public OffHeapColumnVector reserveNewColumn(int capacity, DataType type) {
return new OffHeapColumnVector(capacity, type);
}
}
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index fd3b07e3e217..cd8d0b688bed 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -646,7 +646,7 @@ public final class OnHeapColumnVector extends
WritableColumnVector {
}
@Override
- protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) {
+ public OnHeapColumnVector reserveNewColumn(int capacity, DataType type) {
return new OnHeapColumnVector(capacity, type);
}
}
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index fc465e73006b..3f552679bb6f 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -944,7 +944,7 @@ public abstract class WritableColumnVector extends
ColumnVector {
/**
* Reserve a new column.
*/
- protected abstract WritableColumnVector reserveNewColumn(int capacity,
DataType type);
+ public abstract WritableColumnVector reserveNewColumn(int capacity, DataType
type);
protected boolean isArray() {
return type instanceof ArrayType || type instanceof BinaryType || type
instanceof StringType ||
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 09fd0eccec4b..b42bd2a355f8 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -132,6 +132,9 @@ object ParquetReadSupport extends Logging {
SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get)
val ignoreMissingIds =
conf.getBoolean(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.key,
SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID.defaultValue.get)
+ val returnNullStructIfAllFieldsMissing = conf.getBoolean(
+ SQLConf.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING.key,
+
SQLConf.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING.defaultValue.get)
if (!ignoreMissingIds &&
!containsFieldIds(parquetFileSchema) &&
@@ -150,7 +153,7 @@ object ParquetReadSupport extends Logging {
|""".stripMargin)
}
val parquetClippedSchema =
ParquetReadSupport.clipParquetSchema(parquetFileSchema,
- catalystRequestedSchema, caseSensitive, useFieldId)
+ catalystRequestedSchema, caseSensitive, useFieldId,
returnNullStructIfAllFieldsMissing)
// We pass two schema to ParquetRecordMaterializer:
// - parquetRequestedSchema: the schema of the file data we want to read
@@ -192,9 +195,10 @@ object ParquetReadSupport extends Logging {
parquetSchema: MessageType,
catalystSchema: StructType,
caseSensitive: Boolean,
- useFieldId: Boolean): MessageType = {
- val clippedParquetFields = clipParquetGroupFields(
- parquetSchema.asGroupType(), catalystSchema, caseSensitive, useFieldId)
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): MessageType = {
+ val clippedParquetFields =
clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema,
+ caseSensitive, useFieldId, returnNullStructIfAllFieldsMissing)
if (clippedParquetFields.isEmpty) {
ParquetSchemaConverter.EMPTY_MESSAGE
} else {
@@ -209,24 +213,28 @@ object ParquetReadSupport extends Logging {
parquetType: Type,
catalystType: DataType,
caseSensitive: Boolean,
- useFieldId: Boolean): Type = {
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): Type = {
val newParquetType = catalystType match {
- case t: ArrayType if !isPrimitiveCatalystType(t.elementType) =>
+ case t: ArrayType if ParquetSchemaConverter.isComplexType(t.elementType)
=>
// Only clips array types with nested type as element type.
- clipParquetListType(parquetType.asGroupType(), t.elementType,
caseSensitive, useFieldId)
+ clipParquetListType(parquetType.asGroupType(), t.elementType,
caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing)
case t: MapType
- if !isPrimitiveCatalystType(t.keyType) ||
- !isPrimitiveCatalystType(t.valueType) =>
+ if ParquetSchemaConverter.isComplexType(t.keyType) ||
+ ParquetSchemaConverter.isComplexType(t.valueType) =>
// Only clips map types with nested key type or value type
clipParquetMapType(
- parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive,
useFieldId)
+ parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing)
case t: StructType if VariantMetadata.isVariantStruct(t) =>
- clipVariantSchema(parquetType.asGroupType(), t)
+ clipVariantSchema(parquetType.asGroupType(), t,
returnNullStructIfAllFieldsMissing)
case t: StructType =>
- clipParquetGroup(parquetType.asGroupType(), t, caseSensitive,
useFieldId)
+ clipParquetGroup(parquetType.asGroupType(), t, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing)
case _ =>
// UDTs and primitive types are not clipped. For UDTs, a clipped
version might not be able
@@ -241,18 +249,6 @@ object ParquetReadSupport extends Logging {
}
}
- /**
- * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is
not equivalent to
- * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but
it's not an
- * [[AtomicType]].
- */
- private def isPrimitiveCatalystType(dataType: DataType): Boolean = {
- dataType match {
- case _: ArrayType | _: MapType | _: StructType => false
- case _ => true
- }
- }
-
/**
* Clips a Parquet [[GroupType]] which corresponds to a Catalyst
[[ArrayType]]. The element type
* of the [[ArrayType]] should also be a nested type, namely an
[[ArrayType]], a [[MapType]], or a
@@ -262,15 +258,17 @@ object ParquetReadSupport extends Logging {
parquetList: GroupType,
elementType: DataType,
caseSensitive: Boolean,
- useFieldId: Boolean): Type = {
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): Type = {
// Precondition of this method, should only be called for lists with
nested element types.
- assert(!isPrimitiveCatalystType(elementType))
+ assert(ParquetSchemaConverter.isComplexType(elementType))
// Unannotated repeated group should be interpreted as required list of
required element, so
// list element type is just the group itself. Clip it.
if (parquetList.getLogicalTypeAnnotation == null &&
parquetList.isRepetition(Repetition.REPEATED)) {
- clipParquetType(parquetList, elementType, caseSensitive, useFieldId)
+ clipParquetType(parquetList, elementType, caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing)
} else {
assert(
parquetList.getLogicalTypeAnnotation.isInstanceOf[ListLogicalTypeAnnotation],
@@ -304,14 +302,16 @@ object ParquetReadSupport extends Logging {
.as(LogicalTypeAnnotation.listType())
.addField(
clipParquetType(
- repeatedGroup, elementType, caseSensitive, useFieldId))
+ repeatedGroup, elementType, caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing))
.named(parquetList.getName)
} else {
val newRepeatedGroup = Types
.repeatedGroup()
.addField(
clipParquetType(
- repeatedGroup.getType(0), elementType, caseSensitive,
useFieldId))
+ repeatedGroup.getType(0), elementType, caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing))
.named(repeatedGroup.getName)
val newElementType = if (useFieldId && repeatedGroup.getId != null) {
@@ -341,9 +341,11 @@ object ParquetReadSupport extends Logging {
keyType: DataType,
valueType: DataType,
caseSensitive: Boolean,
- useFieldId: Boolean): GroupType = {
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): GroupType = {
// Precondition of this method, only handles maps with nested key types or
value types.
- assert(!isPrimitiveCatalystType(keyType) ||
!isPrimitiveCatalystType(valueType))
+ assert(ParquetSchemaConverter.isComplexType(keyType) ||
+ ParquetSchemaConverter.isComplexType(valueType))
val repeatedGroup = parquetMap.getType(0).asGroupType()
val parquetKeyType = repeatedGroup.getType(0)
@@ -354,9 +356,11 @@ object ParquetReadSupport extends Logging {
.repeatedGroup()
.as(repeatedGroup.getLogicalTypeAnnotation)
.addField(
- clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId))
+ clipParquetType(parquetKeyType, keyType, caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing))
.addField(
- clipParquetType(parquetValueType, valueType, caseSensitive,
useFieldId))
+ clipParquetType(parquetValueType, valueType, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing))
.named(repeatedGroup.getName)
if (useFieldId && repeatedGroup.getId != null) {
newRepeatedGroup.withId(repeatedGroup.getId.intValue())
@@ -384,9 +388,11 @@ object ParquetReadSupport extends Logging {
parquetRecord: GroupType,
structType: StructType,
caseSensitive: Boolean,
- useFieldId: Boolean): GroupType = {
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): GroupType = {
val clippedParquetFields =
- clipParquetGroupFields(parquetRecord, structType, caseSensitive,
useFieldId)
+ clipParquetGroupFields(parquetRecord, structType, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing)
Types
.buildGroup(parquetRecord.getRepetition)
.as(parquetRecord.getLogicalTypeAnnotation)
@@ -394,7 +400,10 @@ object ParquetReadSupport extends Logging {
.named(parquetRecord.getName)
}
- private def clipVariantSchema(parquetType: GroupType, variantStruct:
StructType): GroupType = {
+ private def clipVariantSchema(
+ parquetType: GroupType,
+ variantStruct: StructType,
+ returnNullStructIfAllFieldsMissing: Boolean): GroupType = {
// TODO(SHREDDING): clip `parquetType` to retain the necessary columns.
parquetType
}
@@ -408,7 +417,8 @@ object ParquetReadSupport extends Logging {
parquetRecord: GroupType,
structType: StructType,
caseSensitive: Boolean,
- useFieldId: Boolean): Seq[Type] = {
+ useFieldId: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): Seq[Type] = {
val toParquet = new SparkToParquetSchemaConverter(
writeLegacyParquetFormat = false,
useFieldId = useFieldId)
@@ -418,11 +428,16 @@ object ParquetReadSupport extends Logging {
parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT))
lazy val idToParquetFieldMap =
parquetRecord.getFields.asScala.filter(_.getId != null).groupBy(f =>
f.getId.intValue())
+ var isStructWithMissingAllFields = true
def matchCaseSensitiveField(f: StructField): Type = {
caseSensitiveParquetFieldMap
.get(f.name)
- .map(clipParquetType(_, f.dataType, caseSensitive, useFieldId))
+ .map { parquetType =>
+ isStructWithMissingAllFields = false
+ clipParquetType(parquetType, f.dataType, caseSensitive, useFieldId,
+ returnNullStructIfAllFieldsMissing)
+ }
.getOrElse(toParquet.convertField(f, inShredded = false))
}
@@ -437,7 +452,9 @@ object ParquetReadSupport extends Logging {
throw
QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError(
f.name, parquetTypesString)
} else {
- clipParquetType(parquetTypes.head, f.dataType, caseSensitive,
useFieldId)
+ isStructWithMissingAllFields = false
+ clipParquetType(parquetTypes.head, f.dataType, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing)
}
}.getOrElse(toParquet.convertField(f, inShredded = false))
}
@@ -453,7 +470,9 @@ object ParquetReadSupport extends Logging {
throw
QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(
fieldId, parquetTypesString)
} else {
- clipParquetType(parquetTypes.head, f.dataType, caseSensitive,
useFieldId)
+ isStructWithMissingAllFields = false
+ clipParquetType(parquetTypes.head, f.dataType, caseSensitive,
useFieldId,
+ returnNullStructIfAllFieldsMissing)
}
}.getOrElse {
// When there is no ID match, we use a fake name to avoid a name
match by accident
@@ -463,7 +482,7 @@ object ParquetReadSupport extends Logging {
}
val shouldMatchById = useFieldId && ParquetUtils.hasFieldIds(structType)
- structType.map { f =>
+ val clippedType = structType.map { f =>
if (shouldMatchById && ParquetUtils.hasFieldId(f)) {
matchIdField(f)
} else if (caseSensitive) {
@@ -472,6 +491,66 @@ object ParquetReadSupport extends Logging {
matchCaseInsensitiveField(f)
}
}
+ // Ignore MessageType, because it is the root of the schema, not a struct.
+ if (returnNullStructIfAllFieldsMissing || !isStructWithMissingAllFields ||
+ parquetRecord.isInstanceOf[MessageType]) {
+ clippedType
+ } else {
+ // Read one arbitrary field to understand when the struct value is null
or not null.
+ clippedType :+ findCheapestGroupField(parquetRecord)
+ }
+ }
+
+ /**
+ * Finds the leaf node under a given file schema node that is likely to be
cheapest to fetch.
+ * Keeps this leaf node inside the same parent hierarchy. This is used when
all struct fields in
+ * the requested schema are missing. Uses a very simple heuristic based on
the parquet type.
+ */
+ private def findCheapestGroupField(parentGroupType: GroupType): Type = {
+ def findCheapestGroupFieldRecurse(curType: Type, repLevel: Int = 0):
(Type, Int, Int) = {
+ curType match {
+ case groupType: GroupType =>
+ var (bestType, bestRepLevel, bestCost) = (Option.empty[Type], 0, 0)
+ for (field <- groupType.getFields.asScala) {
+ val newRepLevel = repLevel + (if
(field.isRepetition(Repetition.REPEATED)) 1 else 0)
+ // Never take a field at a deeper repetition level, since it's
likely to have more data.
+ // Don't do safety checks because we should already have done them
when traversing the
+ // schema for the first time.
+ if (bestType.isEmpty || newRepLevel <= bestRepLevel) {
+ val (childType, childRepLevel, childCost) =
+ findCheapestGroupFieldRecurse(field, newRepLevel)
+ // Always prefer elements with a lower repetition level, since
more nesting of arrays
+ // is likely to result in more data. At the same repetition
level, prefer the smaller
+ // type.
+ if (bestType.isEmpty || childRepLevel < bestRepLevel ||
+ (childRepLevel == bestRepLevel && childCost < bestCost)) {
+ // This is the new best path.
+ bestType = Some(childType)
+ bestRepLevel = childRepLevel
+ bestCost = childCost
+ }
+ }
+ }
+ (groupType.withNewFields(bestType.get), bestRepLevel, bestCost)
+ case primitiveType: PrimitiveType =>
+ val cost = primitiveType.getPrimitiveTypeName match {
+ case PrimitiveType.PrimitiveTypeName.BOOLEAN => 1
+ case PrimitiveType.PrimitiveTypeName.INT32 => 4
+ case PrimitiveType.PrimitiveTypeName.INT64 => 8
+ case PrimitiveType.PrimitiveTypeName.INT96 => 12
+ case PrimitiveType.PrimitiveTypeName.FLOAT => 4
+ case PrimitiveType.PrimitiveTypeName.DOUBLE => 8
+ // Strings seem undesirable, since they don't have a fixed size.
Give them a high cost.
+ case PrimitiveType.PrimitiveTypeName.BINARY |
+ PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => 32
+ // High default cost for types added in the future.
+ case _ => 32
+ }
+ (primitiveType, repLevel, cost)
+ }
+ }
+ // Ignore the highest level of the hierarchy since we are interested only
in the subfield.
+ findCheapestGroupFieldRecurse(parentGroupType)._1.asGroupType().getType(0)
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index cb5e7bf53215..f9d50bf28ea8 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -249,18 +249,24 @@ private[parquet] class ParquetRowConverter(
}
}
parquetType.getFields.asScala.map { parquetField =>
- val catalystFieldIndex = Option(parquetField.getId).flatMap { fieldId =>
+ Option(parquetField.getId).flatMap { fieldId =>
// field has id, try to match by id first before falling back to match
by name
catalystFieldIdxByFieldId.get(fieldId.intValue())
- }.getOrElse {
+ }.orElse {
// field doesn't have id, just match by name
- catalystFieldIdxByName(parquetField.getName)
+ catalystFieldIdxByName.get(parquetField.getName)
+ }.map { catalystFieldIndex =>
+ val catalystField = catalystType(catalystFieldIndex)
+ // Create a RowUpdater instance for converting Parquet objects to
Catalyst rows.
+ val rowUpdater: RowUpdater = new RowUpdater(currentRow,
catalystFieldIndex)
+ // Converted field value should be set to the `fieldIndex`-th cell of
`currentRow`
+ newConverter(parquetField, catalystField.dataType, rowUpdater)
+ }.getOrElse {
+ // This should only happen if we are reading an arbitrary field from a
struct for its levels
+ // that is not otherwise requested.
+ val catalystType =
SparkShreddingUtils.parquetTypeToSparkType(parquetField)
+ newConverter(parquetField, catalystType, NoopUpdater)
}
- val catalystField = catalystType(catalystFieldIndex)
- // Create a RowUpdater instance for converting Parquet objects to
Catalyst rows.
- val rowUpdater: RowUpdater = new RowUpdater(currentRow,
catalystFieldIndex)
- // Converted field value should be set to the `fieldIndex`-th cell of
`currentRow`
- newConverter(parquetField, catalystField.dataType, rowUpdater)
}.toArray
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 0df21e2a5229..947c021c1bd3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -349,6 +349,26 @@ class ParquetToSparkSchemaConverter(
groupColumn: GroupColumnIO,
sparkReadType: Option[DataType] = None): ParquetColumn = {
val field = groupColumn.getType.asGroupType()
+
+ /*
+ * We need to use the Spark SQL type if available, since conversion from
Parquet to Spark could
+ * cause precision loss. For instance, Spark read schema is
smallint/tinyint but Parquet only
+ * supports int. This is only applicable to primitive types, as the Spark
type and the Parquet
+ * converted type for complex types should be largely the same (the only
difference can be at
+ * the leaves or if we are adding non-requested struct fields for
detecting struct nullability).
+ * @param sparkReadType Matching Spark schema field type for the Parquet
field
+ * @param converted ParquetColumn that is recursively converted from the
Parquet field
+ * @return `sparkReadType` if it is defined and not a complex type,
otherwise the type of
+ * `converted`
+ */
+ def getSparkTypeIfApplicable(
+ sparkReadType: Option[DataType],
+ converted: ParquetColumn): DataType = {
+ sparkReadType
+ .filterNot(ParquetSchemaConverter.isComplexType)
+ .getOrElse(converted.sparkType)
+ }
+
Option(field.getLogicalTypeAnnotation).fold(
convertInternal(groupColumn,
sparkReadType.map(_.asInstanceOf[StructType]))) {
// A Parquet list is represented as a 3-level structure:
@@ -379,7 +399,7 @@ class ParquetToSparkSchemaConverter(
if (isElementType(repeatedType, field.getName)) {
var converted = convertField(repeated, sparkReadElementType)
- val convertedType =
sparkReadElementType.getOrElse(converted.sparkType)
+ val convertedType = getSparkTypeIfApplicable(sparkReadElementType,
converted)
// legacy format such as:
// optional group my_list (LIST) {
@@ -393,7 +413,7 @@ class ParquetToSparkSchemaConverter(
} else {
val element = repeated.asInstanceOf[GroupColumnIO].getChild(0)
val converted = convertField(element, sparkReadElementType)
- val convertedType =
sparkReadElementType.getOrElse(converted.sparkType)
+ val convertedType = getSparkTypeIfApplicable(sparkReadElementType,
converted)
val optional = element.getType.isRepetition(OPTIONAL)
ParquetColumn(ArrayType(convertedType, containsNull = optional),
groupColumn, Seq(converted))
@@ -423,8 +443,8 @@ class ParquetToSparkSchemaConverter(
val sparkReadValueType =
sparkReadType.map(_.asInstanceOf[MapType].valueType)
val convertedKey = convertField(key, sparkReadKeyType)
val convertedValue = convertField(value, sparkReadValueType)
- val convertedKeyType =
sparkReadKeyType.getOrElse(convertedKey.sparkType)
- val convertedValueType =
sparkReadValueType.getOrElse(convertedValue.sparkType)
+ val convertedKeyType = getSparkTypeIfApplicable(sparkReadKeyType,
convertedKey)
+ val convertedValueType = getSparkTypeIfApplicable(sparkReadValueType,
convertedValue)
val valueOptional = value.getType.isRepetition(OPTIONAL)
ParquetColumn(
MapType(convertedKeyType, convertedValueType,
@@ -835,4 +855,16 @@ private[sql] object ParquetSchemaConverter {
messageParameters = Map("msg" -> message))
}
}
+
+ /**
+ * Whether a [[DataType]] is complex. Complex [[DataType]] is not equivalent
to
+ * non-[[AtomicType]]. For example, [[CalendarIntervalType]] is not complex,
but it's not an
+ * [[AtomicType]] either.
+ */
+ def isComplexType(dataType: DataType): Boolean = {
+ dataType match {
+ case _: ArrayType | _: MapType | _: StructType => true
+ case _ => false
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala
index d3aad531ed7a..b02820fa1b13 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFieldIdSchemaSuite.scala
@@ -46,7 +46,8 @@ class ParquetFieldIdSchemaSuite extends ParquetSchemaTest {
fileSchema,
catalystSchema,
caseSensitive = caseSensitive,
- useFieldId = useFieldId)
+ useFieldId = useFieldId,
+ returnNullStructIfAllFieldsMissing = false)
// each fake name should be uniquely generated
val fakeColumnNames =
actual.getPaths.asScala.flatten.filter(_.startsWith(FAKE_COLUMN_NAME))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index f52b0bdd8790..dacd58ee66fe 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -38,11 +38,13 @@ import org.apache.parquet.io.api.Binary
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, TestUtils}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.localTime
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import
org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -739,9 +741,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
}
test("vectorized reader: missing all struct fields") {
- Seq(true, false).foreach { offheapEnabled =>
+ for {
+ offheapEnabled <- Seq(true, false)
+ returnNullStructIfAllFieldsMissing <- Seq(true, false)
+ } {
withSQLConf(
SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key ->
"true",
+ SQLConf.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING.key
->
+ returnNullStructIfAllFieldsMissing.toString,
SQLConf.COLUMN_VECTOR_OFFHEAP_ENABLED.key ->
offheapEnabled.toString) {
val data = Seq(
Tuple1((1, "a")),
@@ -755,10 +762,72 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
.add("_4", LongType, nullable = true),
nullable = true)
+ val expectedAnswer = if (!returnNullStructIfAllFieldsMissing) {
+ Row(Row(null, null)) :: Row(Row(null, null)) :: Row(null) :: Nil
+ } else {
+ Row(null) :: Row(null) :: Row(null) :: Nil
+ }
+
withParquetFile(data) { file =>
- checkAnswer(spark.read.schema(readSchema).parquet(file),
- Row(null) :: Row(null) :: Row(null) :: Nil
- )
+ val df = spark.read.schema(readSchema).parquet(file)
+ val scanNode = df.queryExecution.executedPlan.collectLeaves().head
+ VerifyNoAdditionalScanOutputExec(scanNode).execute().collect()
+ checkAnswer(df, expectedAnswer)
+ }
+ }
+ }
+ }
+
+ test("SPARK-53535: vectorized reader: missing all struct fields, struct with
complex fields") {
+ val data = Seq(
+ Row(Row(Seq(11, 12, null, 14), Row("21", 22), Row(true)), 100),
+ Row(Row(Seq(11, 12, null, 14), Row("21", 22), Row(false)), 100),
+ Row(null, 100)
+ )
+
+ val tableSchema = new StructType()
+ .add("_1", new StructType()
+ .add("_1", ArrayType(IntegerType, containsNull = true))
+ .add("_2", new StructType()
+ .add("_1", StringType)
+ .add("_2", IntegerType))
+ .add("_3", new StructType()
+ .add("_1", BooleanType)))
+ .add("_2", IntegerType)
+
+ val readSchema = new StructType()
+ .add("_1", new StructType()
+ .add("_101", IntegerType)
+ .add("_102", LongType))
+
+ withTempPath { path =>
+ val file = path.getCanonicalPath
+ spark.createDataFrame(data.asJava,
tableSchema).write.partitionBy("_2").parquet(file)
+
+ for {
+ offheapEnabled <- Seq(true, false)
+ returnNullStructIfAllFieldsMissing <- Seq(true, false)
+ } {
+ withSQLConf(
+ SQLConf.LEGACY_PARQUET_RETURN_NULL_STRUCT_IF_ALL_FIELDS_MISSING.key
->
+ returnNullStructIfAllFieldsMissing.toString,
+ SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key ->
"true",
+ SQLConf.COLUMN_VECTOR_OFFHEAP_ENABLED.key -> offheapEnabled.toString
+ ) {
+ val expectedAnswer = if (!returnNullStructIfAllFieldsMissing) {
+ Row(Row(null, null), 100) :: Row(Row(null, null), 100) ::
Row(null, 100) :: Nil
+ } else {
+ Row(null, 100) :: Row(null, 100) :: Row(null, 100) :: Nil
+ }
+
+ withAllParquetReaders {
+ val df = spark.read.schema(readSchema).parquet(file)
+ val scanNode = df.queryExecution.executedPlan.collectLeaves().head
+ if (scanNode.supportsColumnar) {
+ VerifyNoAdditionalScanOutputExec(scanNode).execute().collect()
+ }
+ checkAnswer(df, expectedAnswer)
+ }
}
}
}
@@ -1655,3 +1724,21 @@ class
TaskCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAtt
sys.error("Intentional exception for testing purposes")
}
}
+
+case class VerifyNoAdditionalScanOutputExec(override val child: SparkPlan)
extends UnaryExecNode {
+ override def doExecute(): RDD[InternalRow] = {
+ val childOutputTypes = child.output.map(_.dataType)
+ child.executeColumnar().foreachPartition { batches =>
+ batches.foreach { input =>
+ 0.until(input.numCols).foreach { index =>
+ assert(childOutputTypes(index) == input.column(index).dataType,
+ "Found additional columns in the ColumnarBatch that are not
present in output schema")
+ }
+ }
+ }
+ sparkContext.emptyRDD[InternalRow]
+ }
+ override def output: Seq[Attribute] = Nil
+ override def withNewChildInternal(newChild: SparkPlan):
VerifyNoAdditionalScanOutputExec =
+ copy(child = newChild)
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index e8d08f5d6020..bf0fbd1f9b89 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -2480,9 +2480,11 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
parquetSchema: String,
catalystSchema: StructType,
expectedSchema: String,
- caseSensitive: Boolean = true): Unit = {
+ caseSensitive: Boolean = true,
+ returnNullStructIfAllFieldsMissing: Boolean = false): Unit = {
testSchemaClipping(testName, parquetSchema, catalystSchema,
- MessageTypeParser.parseMessageType(expectedSchema), caseSensitive)
+ MessageTypeParser.parseMessageType(expectedSchema), caseSensitive,
+ returnNullStructIfAllFieldsMissing)
}
private def testSchemaClipping(
@@ -2490,13 +2492,15 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
parquetSchema: String,
catalystSchema: StructType,
expectedSchema: MessageType,
- caseSensitive: Boolean): Unit = {
+ caseSensitive: Boolean,
+ returnNullStructIfAllFieldsMissing: Boolean): Unit = {
test(s"Clipping - $testName") {
val actual = ParquetReadSupport.clipParquetSchema(
MessageTypeParser.parseMessageType(parquetSchema),
catalystSchema,
caseSensitive,
- useFieldId = false)
+ useFieldId = false,
+ returnNullStructIfAllFieldsMissing)
try {
expectedSchema.checkContains(actual)
@@ -2857,7 +2861,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
catalystSchema = new StructType(),
- expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE,
+ expectedSchema = ParquetSchemaConverter.EMPTY_MESSAGE.toString,
caseSensitive = true)
testSchemaClipping(
@@ -2886,6 +2890,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
| required group f0 {
| optional float f02;
| optional double f03;
+ | required int32 f00;
| }
|}
""".stripMargin)
@@ -3063,7 +3068,312 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
MessageTypeParser.parseMessageType(parquetSchema),
catalystSchema,
caseSensitive = false,
- useFieldId = false)
+ useFieldId = false,
+ returnNullStructIfAllFieldsMissing = false)
}
}
+
+ testSchemaClipping(
+ s"SPARK-53535: struct in struct missing in requested schema",
+ parquetSchema =
+ """message root {
+ | required int32 f0;
+ | required group f1 {
+ | required group f10 {
+ | required int32 f100;
+ | required int32 f101;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType().add("f0", IntegerType, nullable = true),
+ expectedSchema =
+ """message root {
+ | required int32 f0;
+ |}
+ """.stripMargin)
+
+ for (returnNullStructIfAllFieldsMissing <- Seq(true, false)) {
+ testSchemaClipping(
+ s"SPARK-53535: struct in struct, with missing field being requested, " +
+
s"returnNullStructIfAllFieldsMissing=$returnNullStructIfAllFieldsMissing",
+ parquetSchema =
+ """message root {
+ | required group f0 {
+ | required group f00 {
+ | required int64 f000;
+ | required int32 f001;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType()
+ .add("f0", new StructType()
+ .add("f01", IntegerType, nullable = true), nullable = true),
+ expectedSchema =
+ ("""message root {
+ | required group f0 {
+ | optional int32 f01;""" + (if
(!returnNullStructIfAllFieldsMissing) {
+ """
+ | required group f00 {
+ | required int32 f001;
+ | }""" } else { "" }) +
+ """
+ | }
+ |}
+ """).stripMargin,
+ returnNullStructIfAllFieldsMissing = returnNullStructIfAllFieldsMissing)
+
+ testSchemaClipping(
+ s"SPARK-53535: missing struct with complex fields, " +
+
s"returnNullStructIfAllFieldsMissing=$returnNullStructIfAllFieldsMissing",
+ parquetSchema =
+ """message root {
+ | optional group _1 {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ | optional group _2 {
+ | optional binary _1 (UTF8);
+ | optional int32 _2;
+ | }
+ | optional group _3 {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ | optional boolean _2;
+ | optional int32 _3;
+ | }
+ | optional group _4 {
+ | optional int64 _1;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType()
+ .add("_1", new StructType()
+ .add("_101", IntegerType)
+ .add("_102", LongType)),
+ expectedSchema =
+ ("""message root {
+ | optional group _1 {
+ | optional int32 _101;
+ | optional int64 _102;""" + (if
(!returnNullStructIfAllFieldsMissing) {
+ """
+ | optional group _3 {
+ | optional boolean _2;
+ | }""" } else { "" }) +
+ """
+ | }
+ |}
+ """).stripMargin,
+ returnNullStructIfAllFieldsMissing = returnNullStructIfAllFieldsMissing)
+ }
+
+ testSchemaClipping(
+ s"SPARK-53535: various missing structs, cheapest type selection works as
expected",
+ parquetSchema =
+ """message root {
+ | optional group pickShortestType1 {
+ | optional int64 long;
+ | optional int32 int;
+ | optional double double;
+ | }
+ | optional group pickShortestType2 {
+ | optional int64 long;
+ | optional int32 int;
+ | optional boolean boolean;
+ | }
+ | optional group dontPickArrayOrMap {
+ | optional int64 long;
+ | optional group array (LIST) {
+ | repeated group list {
+ | optional boolean element;
+ | }
+ | }
+ | optional group map (MAP) {
+ | repeated group key_value {
+ | required boolean key;
+ | required int32 value;
+ | }
+ | }
+ | }
+ | optional group pickArrayOrMap {
+ | optional group arrayOfArray (LIST) {
+ | repeated group list {
+ | optional group element (LIST) {
+ | repeated group list {
+ | optional boolean element;
+ | }
+ | }
+ | }
+ | }
+ | optional group arrayOfLong (LIST) {
+ | repeated group list {
+ | optional int64 element;
+ | }
+ | }
+ | optional group map (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | required binary value (UTF8);
+ | }
+ | }
+ | }
+ | optional group structNesting {
+ | optional int64 long;
+ | optional group struct {
+ | optional int32 int;
+ | optional group struct {
+ | optional boolean boolean;
+ | }
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType()
+ .add("pickShortestType1", new StructType().add("missingColumn",
IntegerType))
+ .add("pickShortestType2", new StructType().add("missingColumn",
IntegerType))
+ .add("dontPickArrayOrMap", new StructType().add("missingColumn",
IntegerType))
+ .add("pickArrayOrMap", new StructType().add("missingColumn",
IntegerType))
+ .add("structNesting", new StructType().add("missingColumn",
IntegerType)),
+ expectedSchema =
+ """message root {
+ | optional group pickShortestType1 {
+ | optional int32 missingColumn;
+ | optional int32 int;
+ | }
+ | optional group pickShortestType2 {
+ | optional int32 missingColumn;
+ | optional boolean boolean;
+ | }
+ | optional group dontPickArrayOrMap {
+ | optional int32 missingColumn;
+ | optional int64 long;
+ | }
+ | optional group pickArrayOrMap {
+ | optional int32 missingColumn;
+ | optional group map (MAP) {
+ | repeated group key_value {
+ | required int32 key;
+ | }
+ | }
+ | }
+ | optional group structNesting {
+ | optional int32 missingColumn;
+ | optional group struct {
+ | optional group struct {
+ | optional boolean boolean;
+ | }
+ | }
+ | }
+ |}
+ """.stripMargin)
+
+ for (returnNullStructIfAllFieldsMissing <- Seq(true, false)) {
+ testSchemaClipping(
+ s"struct in struct missing in requested schema, " +
+ s"returnNullStructIfAllFieldsMissing=$returnNullStructIfAllFieldsMissing",
+ parquetSchema =
+ """message root {
+ | required int32 f0;
+ | required group f1 {
+ | required group f10 {
+ | required int32 f100;
+ | required int32 f101;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType().add("f0", IntegerType, nullable =
true),
+ expectedSchema =
+ """message root {
+ | required int32 f0;
+ |}
+ """.stripMargin,
+ returnNullStructIfAllFieldsMissing = returnNullStructIfAllFieldsMissing)
+
+ testSchemaClipping(
+ s"struct in struct, with missing field being requested, " +
+ s"returnNullStructIfAllFieldsMissing=$returnNullStructIfAllFieldsMissing",
+ parquetSchema =
+ """message root {
+ | required group f0 {
+ | required group f00 {
+ | required int64 f000;
+ | required int32 f001;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType()
+ .add("f0", new StructType()
+ .add("f01", IntegerType, nullable = true), nullable = true),
+ expectedSchema =
+ ("""message root {
+ | required group f0 {
+ | optional int32 f01;""" + (if (!returnNullStructIfAllFieldsMissing) {
+ """
+ | required group f00 {
+ | required int32 f001;
+ | }""" } else { "" }) +
+ """
+ | }
+ |}
+ """).stripMargin,
+ returnNullStructIfAllFieldsMissing = returnNullStructIfAllFieldsMissing)
+
+ testSchemaClipping(
+ s"missing struct with complex fields, " +
+ s"returnNullStructIfAllFieldsMissing=$returnNullStructIfAllFieldsMissing",
+ parquetSchema =
+ """message root {
+ | optional group _1 {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ | optional group _2 {
+ | optional binary _1 (UTF8);
+ | optional int32 _2;
+ | }
+ | optional group _3 {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional int32 element;
+ | }
+ | }
+ | optional boolean _2;
+ | optional int32 _3;
+ | }
+ | optional group _4 {
+ | optional int64 _1;
+ | }
+ | }
+ |}
+ """.stripMargin,
+ catalystSchema = new StructType()
+ .add("_1", new StructType()
+ .add("_101", IntegerType)
+ .add("_102", LongType)),
+ expectedSchema =
+ ("""message root {
+ | optional group _1 {
+ | optional int32 _101;
+ | optional int64 _102;""" + (if (!returnNullStructIfAllFieldsMissing) {
+ """
+ | optional group _3 {
+ | optional boolean _2;
+ | }""" } else { "" }) +
+ """
+ | }
+ |}
+ """).stripMargin,
+ returnNullStructIfAllFieldsMissing = returnNullStructIfAllFieldsMissing)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala
index 5d68fcac1385..f9b3d4fc68f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala
@@ -34,7 +34,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions.{RowOrdering, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ParquetRowGroupReader
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
@@ -481,6 +481,173 @@ class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSpark
}
}
+ truncateTypeTest("primitive type", IntegerType, LongType, IntegerType)
+
+ truncateTypeTest("basic struct",
+ readType = StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType),
+ StructField("c", BooleanType)
+ )),
+ requestedType = StructType(Seq(
+ StructField("a", LongType),
+ StructField("b", StringType)
+ )),
+ expectedType = StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ ))
+ )
+
+ truncateTypeTest("nested struct",
+ readType = StructType(Seq(
+ StructField("nested", StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType),
+ StructField("z", DoubleType)
+ ))),
+ StructField("extra", BooleanType)
+ )),
+ requestedType = StructType(Seq(
+ StructField("nested", StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType)
+ ))),
+ StructField("extra", BooleanType)
+ )),
+ expectedType = StructType(Seq(
+ StructField("nested", StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType)
+ ))),
+ StructField("extra", BooleanType)
+ ))
+ )
+
+ truncateTypeTest("empty structs",
+ readType = StructType(Seq.empty),
+ requestedType = StructType(Seq.empty),
+ expectedType = StructType(Seq.empty)
+ )
+
+ truncateTypeTest("simple arrays",
+ readType = ArrayType(IntegerType),
+ requestedType = ArrayType(LongType),
+ expectedType = ArrayType(IntegerType)
+ )
+
+ truncateTypeTest("structs in arrays",
+ readType = ArrayType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType),
+ StructField("c", BooleanType)
+ ))),
+ requestedType = ArrayType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ ))),
+ expectedType = ArrayType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ )
+
+ truncateTypeTest("nested array, containsNull is preserved",
+ readType = ArrayType(ArrayType(IntegerType, containsNull = false)),
+ requestedType = ArrayType(ArrayType(IntegerType, containsNull = true)),
+ expectedType = ArrayType(ArrayType(IntegerType, containsNull = false))
+ )
+
+ truncateTypeTest("map with primitive key/value types",
+ readType = MapType(StringType, IntegerType),
+ requestedType = MapType(StringType, LongType),
+ expectedType = MapType(StringType, IntegerType)
+ )
+
+ truncateTypeTest("map with struct values needing truncation",
+ readType = MapType(StringType, StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType),
+ StructField("c", DoubleType)
+ ))),
+ requestedType = MapType(StringType, StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ ))),
+ expectedType = MapType(StringType, StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ )
+
+ truncateTypeTest("map with struct keys needing truncation",
+ readType = MapType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", DoubleType),
+ StructField("c", StringType)
+ )), IntegerType),
+ requestedType = MapType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", FloatType)
+ )), IntegerType),
+ expectedType = MapType(StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", DoubleType)
+ )), IntegerType)
+ )
+
+ truncateTypeTest("map valueContainsNull is preserved",
+ readType = MapType(StringType, IntegerType, valueContainsNull = false),
+ requestedType = MapType(StringType, IntegerType, valueContainsNull = true),
+ expectedType = MapType(StringType, IntegerType, valueContainsNull = false)
+ )
+
+ truncateTypeTest("all complex types",
+ readType = StructType(Seq(
+ StructField("complexField", ArrayType(MapType(StringType, StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType),
+ StructField("z", BooleanType)
+ ))))),
+ StructField("extraTopLevel", DoubleType)
+ )),
+ requestedType = StructType(Seq(
+ StructField("complexField", ArrayType(MapType(StringType, StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType)
+ )))))
+ )),
+ expectedType = StructType(Seq(
+ StructField("complexField", ArrayType(MapType(StringType, StructType(Seq(
+ StructField("x", IntegerType),
+ StructField("y", StringType)
+ )))))
+ ))
+ )
+
+ truncateTypeTest("struct UDT",
+ readType = new StructType()
+ .add("_1", LongType)
+ .add("_2", IntegerType)
+ .add("_3", StringType),
+ requestedType = new TestStructUDT(),
+ expectedType = new StructType()
+ .add("_1", LongType)
+ .add("_2", IntegerType)
+ )
+
+ private def truncateTypeTest(
+ testName: String,
+ readType: DataType,
+ requestedType: DataType,
+ expectedType: DataType
+ ): Unit = {
+ test(s"truncateType - $testName") {
+ val result = VectorizedParquetRecordReader.truncateType(readType, requestedType)
+ assert(result === expectedType)
+ }
+ }
+
private def testPrimitiveString(
firstRowIndexesOpt: Option[Seq[Long]],
rangesOpt: Option[Seq[(Long, Long)]],
@@ -762,3 +929,29 @@ class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSpark
}
}
}
+
+@SQLUserDefinedType(udt = classOf[TestStructUDT])
+case class TestStruct(a: Integer, b: java.lang.Long)
+
+class TestStructUDT extends UserDefinedType[TestStruct] {
+ override def sqlType: DataType = new StructType()
+ .add("_1", IntegerType)
+ .add("_2", LongType)
+
+ override def serialize(n: TestStruct): Any = {
+ val row = new SpecificInternalRow(sqlType.asInstanceOf[StructType].map(_.dataType))
+ if (n.a == null) row.setNullAt(0) else row.setInt(0, n.a)
+ if (n.b == null) row.setNullAt(1) else row.setLong(1, n.b)
+ row
+ }
+
+ override def userClass: Class[TestStruct] = classOf[TestStruct]
+
+ override def deserialize(datum: Any): TestStruct = {
+ datum match {
+ case row: InternalRow => TestStruct(
+ if (row.isNullAt(0)) null else row.getInt(0),
+ if (row.isNullAt(1)) null else row.getLong(1))
+ }
+ }
+}
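For readers following the new `truncateType` tests above, here is a minimal standalone sketch of the truncation rule that the expected results imply. It is not the actual `VectorizedParquetRecordReader.truncateType` implementation, and the object name `TruncateTypeSketch` is made up for illustration; the behavior is inferred purely from the `expectedType` values in the test cases.

```scala
import org.apache.spark.sql.types._

// Illustrative sketch only: drop the trailing struct fields that were read
// solely to recover definition levels, while keeping the field types and
// nullability flags of the read schema.
object TruncateTypeSketch {
  def truncate(readType: DataType, requestedType: DataType): DataType =
    (readType, requestedType) match {
      case (r: StructType, q: StructType) =>
        // Keep only as many leading fields as the requested struct has,
        // truncating each kept field recursively; names and nullability
        // come from the read schema.
        StructType(r.fields.zip(q.fields).map { case (rf, qf) =>
          rf.copy(dataType = truncate(rf.dataType, qf.dataType))
        })
      case (r: ArrayType, q: ArrayType) =>
        // containsNull is preserved from the read schema.
        ArrayType(truncate(r.elementType, q.elementType), r.containsNull)
      case (r: MapType, q: MapType) =>
        // Both keys and values may need truncation; valueContainsNull is
        // preserved from the read schema.
        MapType(
          truncate(r.keyType, q.keyType),
          truncate(r.valueType, q.valueType),
          r.valueContainsNull)
      case (r, q: UserDefinedType[_]) =>
        // A requested UDT is matched through its underlying SQL type.
        truncate(r, q.sqlType)
      case (r, _) =>
        // Primitive leaves keep the type that was actually read.
        r
    }
}
```

Applied to the "struct UDT" case above, this sketch reduces the read type to `new StructType().add("_1", LongType).add("_2", IntegerType)`, which is the expected result checked by that test.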