This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c4842918cf95 [SPARK-55568][SQL] Separate schema construction from
field stats collection
c4842918cf95 is described below
commit c4842918cf958ca4681415e755fd66cfda48a5e3
Author: Qiegang Long <[email protected]>
AuthorDate: Thu May 28 02:09:20 2026 +0800
[SPARK-55568][SQL] Separate schema construction from field stats collection
### Why are the changes needed?
Variant shredding schema inference is expensive and can take over 100ms per
file. Replace fold-based schema merging with deferred schema construction using
single-pass field statistics collection.
Previous approach:
- Used foldLeft to build and merge complete schemas for each row
- Merged schemas repeatedly across 4096 rows
- High allocation overhead from recursive schema construction
New approach:
- Separates schema construction from field statistics collection to avoid
excessive intermediate allocations and repeated merges.
- Traverses each row's fields in a single pass, using a field statistics tree
to track field types and row counts.
- Uses lastSeenRow so that a field is counted at most once per row.
- Defers schema construction until after all rows are processed.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Functional test:
* Pass all existing unit tests
Performance vs master:
- Tested scenarios with different field counts, array sizes, and batch
sizes (1-4096 rows, 10-200 fields, varying nesting depths and sparsity patterns).
- 1.7x to 2.4x speedup across test scenarios
- Consistent performance across multiple runs
- 96% of tests show improvement
### Was this patch authored or co-authored using generative AI tooling?
Co-authored with Claude Sonnet 4.5
Closes #54343 from qlong/SPARK-55568-optimize-variant-schema-inference.
Lead-authored-by: Qiegang Long <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../parquet/InferVariantShreddingSchema.scala | 356 ++++++++++++++++-----
.../parquet/VariantInferShreddingSuite.scala | 261 ++++++++++++++-
2 files changed, 516 insertions(+), 101 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
index 1ebb61968150..bbfbbfde0ba4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.parquet
+import scala.collection.mutable
+
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
@@ -93,74 +95,31 @@ class InferVariantShreddingSchema(val schema: StructType) {
private val COUNT_METADATA_KEY = "COUNT"
- /**
- * Return an appropriate schema for shredding a Variant value.
- * It is similar to the SchemaOfVariant expression, but the rules are
somewhat different, because
- * we want the types to be consistent with what will be allowed during
shredding. E.g.
- * SchemaOfVariant will consider the common type across Integer and Double
to be double, but we
- * consider it to be VariantType, since shredding will not allow those types
to be written to
- * the same typed_value.
- * We also maintain metadata on struct fields to track how frequently they
occur. Rare fields
- * are dropped in the final schema.
- */
- private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
- case Type.OBJECT =>
- if (maxDepth <= 0) return VariantType
- val size = v.objectSize()
- val fields = new Array[StructField](size)
- for (i <- 0 until size) {
- val field = v.getFieldAtIndex(i)
- fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
- metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY,
1).build())
- }
- // According to the variant spec, object fields must be sorted
alphabetically. So we don't
- // have to sort, but just need to validate they are sorted.
- for (i <- 1 until size) {
- if (fields(i - 1).name >= fields(i).name) {
- throw new SparkRuntimeException(
- errorClass = "MALFORMED_VARIANT",
- messageParameters = Map.empty
- )
- }
- }
- StructType(fields)
- case Type.ARRAY =>
- if (maxDepth <= 0) return VariantType
- var elementType: DataType = NullType
- for (i <- 0 until v.arraySize()) {
- elementType = mergeSchema(elementType,
schemaOf(v.getElementAtIndex(i), maxDepth - 1))
- }
- ArrayType(elementType)
- case Type.NULL => NullType
- case Type.BOOLEAN => BooleanType
- case Type.LONG =>
- // Compute the smallest decimal that can contain this value.
- // This will allow us to merge with decimal later without introducing
excessive precision.
- // If we only end up encountering integer values, we'll convert back to
LongType when we
- // finalize.
- val d = BigDecimal(v.getLong())
- val precision = d.precision
- if (precision <= Decimal.MAX_LONG_DIGITS) {
- DecimalType(precision, 0)
- } else {
- // Value is too large for Decimal(18, 0), so record its type as long.
- LongType
+ // Node for tree-based field tracking
+ private case class FieldNode(
+ // Scalar type, or a marker (StructType(empty) / ArrayType(NullType));
structural
+ // shape lives in `children` and `arrayElementNode`.
+ var dataType: DataType,
+ var rowCount: Int = 0, // Count of distinct rows containing this
field
+ var lastSeenRow: Int = -1, // Last row index that incremented
rowCount
+ var arrayElementCount: Long = 0, // Total occurrences across all array
elements
+ children: mutable.Map[String, FieldNode] = mutable.Map.empty,
+ var arrayElementNode: Option[FieldNode] = None
+ ) {
+
+ def getOrCreateChild(fieldName: String): FieldNode = {
+ children.getOrElseUpdate(fieldName, FieldNode(NullType))
+ }
+
+ def getChildren: Seq[(String, FieldNode)] = children.toSeq
+
+ def getOrCreateArrayElement(): FieldNode = {
+ arrayElementNode.getOrElse {
+ val node = FieldNode(NullType)
+ arrayElementNode = Some(node)
+ node
}
- case Type.STRING => StringType
- case Type.DOUBLE => DoubleType
- case Type.DECIMAL =>
- // Don't strip trailing zeros to determine scale. Even if we allow scale
relaxation during
- // shredding, it's useful to take trailing zeros as a hint that the
extra digits may be used
- // in later values, and use the larger scale.
- val d = Decimal(v.getDecimalWithOriginalScale())
- DecimalType(d.precision, d.scale)
- case Type.DATE => DateType
- case Type.TIMESTAMP => TimestampType
- case Type.TIMESTAMP_NTZ => TimestampNTZType
- case Type.FLOAT => FloatType
- case Type.BINARY => BinaryType
- // Spark doesn't support UUID, so shred it as an untyped value.
- case Type.UUID => VariantType
+ }
}
private def getFieldCount(field: StructField): Long = {
@@ -203,6 +162,9 @@ class InferVariantShreddingSchema(val schema: StructType) {
mergeDecimalWithLong(d)
case (StructType(fields1), StructType(fields2)) =>
// Rely on fields being sorted by name, and merge fields with the same
name recursively.
+ // In this inference path, non-empty struct merges are unused:
`inferPrimitiveType` only
+ // produces empty struct markers; field lists are built in
`buildSchemaFromStats` from the
+ // FieldNode tree instead.
val newFields = new java.util.ArrayList[StructField]()
var f1Idx = 0
@@ -351,36 +313,256 @@ class InferVariantShreddingSchema(val schema:
StructType) {
}
def inferSchema(rows: Seq[InternalRow]): StructType = {
- // For each path to a Variant value, iterate over all rows and update the
inferred schema.
- // Add the result to a map, which we'll use to update the full schema.
- // maxShreddedFieldsPerFile is a global max for all fields, so initialize
it here.
+ // For each variant path, collect field statistics using a single pass
val maxFields = MaxFields(maxShreddedFieldsPerFile)
+
val inferredSchemas = pathsToVariant.map { path =>
- var numNonNullValues = 0
- val simpleSchema = rows.foldLeft(NullType: DataType) {
- case (partialSchema, row) =>
- getValueAtPath(schema, row, path).map { variantVal =>
- numNonNullValues += 1
- val v = new Variant(variantVal.getValue, variantVal.getMetadata)
- val schemaOfRow = schemaOf(v, maxShreddingDepth)
- mergeSchema(partialSchema, schemaOfRow)
- // If getValueAtPath returned None, the value is null in this row;
just ignore.
- }
- .getOrElse(partialSchema)
- // If we didn't find any non-null rows, use an unshredded schema.
- }
+ val rootNode = FieldNode(NullType)
+ var numNonNullVariants = 0
- // Don't infer a schema for fields that appear in less than 10% of rows.
- // Ensure that minCardinality is at least 1 if we have any rows.
- val minCardinality = (numNonNullValues + 9) / 10
+ // Single pass: process all rows for this variant path
+ rows.zipWithIndex.foreach { case (row, rowIdx) =>
+ getValueAtPath(schema, row, path).foreach { variantVal =>
+ numNonNullVariants += 1
+ val v = new Variant(variantVal.getValue, variantVal.getMetadata)
+ rootNode.dataType = mergeSchema(rootNode.dataType,
inferPrimitiveType(v, 0))
+ // Traverse variant and update field stats tree
+ collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
+ }
+ }
+ // Build final schema from collected statistics
+ val minCardinality = (numNonNullVariants + 9) / 10
+ val simpleSchema = buildSchemaFromStats(
+ rootNode,
+ minCardinality,
+ inArrayContext = false,
+ isArray = rootNode.arrayElementNode.isDefined)
val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality,
maxFields)
val shreddingSchema =
SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
val schemaWithMetadata =
SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
(path, schemaWithMetadata)
}.toMap
- // Insert each inferred schema into the full schema.
+ // Insert each inferred schema into the full schema
updateSchema(schema, inferredSchemas)
}
+
+ /**
+ * Recursively traverse a variant value and build a field statistics tree.
+ * For each field encountered, record its type and track distinct row count.
+ * For fields inside arrays, also increment the occurrence count.
+ */
+ private def collectFieldStats(
+ v: Variant,
+ currentNode: FieldNode,
+ rowIdx: Int,
+ depth: Int,
+ inArrayContext: Boolean): Unit = {
+
+ if (depth >= maxShreddingDepth) return
+
+ v.getType match {
+ case Type.OBJECT =>
+ val size = v.objectSize()
+ // Validate fields are sorted (per variant spec)
+ for (i <- 1 until size) {
+ val prevKey = v.getFieldAtIndex(i - 1).key
+ val currKey = v.getFieldAtIndex(i).key
+ if (prevKey >= currKey) {
+ throw new SparkRuntimeException(
+ errorClass = "MALFORMED_VARIANT",
+ messageParameters = Map.empty
+ )
+ }
+ }
+
+ // Process each field
+ for (i <- 0 until size) {
+ val field = v.getFieldAtIndex(i)
+ val fieldName = field.key
+
+ // Get or create child node (O(1) map access - no path string
building!)
+ val childNode = currentNode.getOrCreateChild(fieldName)
+
+ // Track row-level presence only outside array context.
+ if (inArrayContext) {
+ childNode.arrayElementCount += 1
+ } else if (childNode.lastSeenRow != rowIdx) {
+ childNode.rowCount += 1
+ childNode.lastSeenRow = rowIdx
+ }
+
+ // Infer and merge type
+ val fieldType = inferPrimitiveType(field.value, depth)
+ childNode.dataType = mergeSchema(childNode.dataType, fieldType)
+
+ // Recurse into nested structures (pass child node, not path string)
+ collectFieldStats(field.value, childNode, rowIdx, depth + 1,
inArrayContext)
+ }
+
+ case Type.ARRAY =>
+ val arrayNode = currentNode.getOrCreateArrayElement()
+
+ // Track distinct row count for the array field itself
+ if (arrayNode.lastSeenRow != rowIdx) {
+ arrayNode.rowCount += 1
+ arrayNode.lastSeenRow = rowIdx
+ }
+
+ val arraySize = v.arraySize()
+ if (arraySize > 0) {
+ // Process array elements
+ for (i <- 0 until arraySize) {
+ val element = v.getElementAtIndex(i)
+ val elementTypeClass = element.getType
+
+ // Primitives merge into `dataType` only; objects and arrays need
tree descent.
+ if (elementTypeClass != Type.OBJECT && elementTypeClass !=
Type.ARRAY) {
+ val primitiveType = inferPrimitiveType(element, depth)
+ arrayNode.dataType = mergeSchema(arrayNode.dataType,
primitiveType)
+ } else {
+ collectFieldStats(element, arrayNode, rowIdx, depth + 1,
inArrayContext = true)
+ }
+ }
+ }
+
+ case _ =>
+ }
+ }
+
+ /**
+ * Infer the type of a variant value without recursive field collection.
+ * For objects and arrays, return a marker type; recursive collection is
done separately.
+ */
+ private def inferPrimitiveType(v: Variant, depth: Int): DataType = {
+ if (depth >= maxShreddingDepth) return VariantType
+
+ v.getType match {
+ case Type.OBJECT =>
+ // Return empty struct as marker; fields collected separately
+ StructType(Seq.empty)
+ case Type.ARRAY =>
+ // Return array with null element as marker; elements processed
separately
+ ArrayType(NullType)
+ case Type.NULL => NullType
+ case Type.BOOLEAN => BooleanType
+ case Type.LONG =>
+ val d = BigDecimal(v.getLong())
+ val precision = d.precision
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ DecimalType(precision, 0)
+ } else {
+ LongType
+ }
+ case Type.STRING => StringType
+ case Type.DOUBLE => DoubleType
+ case Type.DECIMAL =>
+ val d = Decimal(v.getDecimalWithOriginalScale())
+ DecimalType(d.precision, d.scale)
+ case Type.DATE => DateType
+ case Type.TIMESTAMP => TimestampType
+ case Type.TIMESTAMP_NTZ => TimestampNTZType
+ case Type.FLOAT => FloatType
+ case Type.BINARY => BinaryType
+ case Type.UUID => VariantType
+ }
+ }
+
+ /**
+ * Build a schema from collected field statistics tree.
+ *
+ * When isArray=true the function builds and returns the full ArrayType for
this node
+ * (using its arrayElementNode to determine the element type).
+ * When isArray=false it returns the type for the node itself (scalar,
VariantType,
+ * or StructType).
+ *
+ * Cardinality metric:
+ * - inArrayContext=true uses arrayElementCount (total occurrences across
array positions).
+ * - inArrayContext=false uses rowCount (distinct rows containing the
field).
+ */
+ private def buildSchemaFromStats(
+ currentNode: FieldNode,
+ minCardinality: Int,
+ inArrayContext: Boolean,
+ isArray: Boolean): DataType = {
+
+ // Pick the right counter for this context; reused in filter, sort, and
metadata below.
+ def cardinality(n: FieldNode): Long =
+ if (inArrayContext) n.arrayElementCount else n.rowCount
+
+ // Array branch
+ if (isArray) {
+ // Case 1: mixed array and non-array rows at the same path merged
dataType to VariantType.
+ // The whole node is variant, not an array.
+ if (currentNode.dataType == VariantType) {
+ return VariantType
+ }
+ // Case 2 (defensive): every reachable isArray=true call site is paired
with a guard that
+ // already rules this out -- the root call collapses to
VariantType (Case 1), the
+ // Case 3 recursion short-circuits via the
`elemNode.children.nonEmpty &&
+ // elemNode.arrayElementNode.isDefined` check below, and the
struct-field
+ // ArrayType recursion only fires when the child was uniformly
an array (so
+ // `children` is empty). Keep this as a backstop for future call
sites.
+ if (currentNode.children.nonEmpty &&
currentNode.arrayElementNode.isDefined) {
+ return ArrayType(VariantType)
+ }
+
+ // Case 3: uniform inner array -- recurse into the element node and
merge with any scalar
+ // dataType on the same node (e.g. `[1, {"a":1}]` merges long +
struct -> variant).
+ currentNode.arrayElementNode match {
+ case Some(elemNode) if elemNode.rowCount >= minCardinality =>
+ // If the element node itself has both object children and a nested
array, the two
+ // element shapes cannot be reconciled: treat the element as variant.
+ val elementType = if (elemNode.children.nonEmpty &&
elemNode.arrayElementNode.isDefined) {
+ VariantType
+ } else {
+ buildSchemaFromStats(
+ elemNode,
+ minCardinality,
+ inArrayContext = true,
+ isArray = elemNode.arrayElementNode.isDefined)
+ }
+ return ArrayType(mergeSchema(elemNode.dataType, elementType))
+ case _ =>
+ return currentNode.dataType
+ }
+ }
+
+ // Non-array branch: incompatible types already collapsed to VariantType
during collection.
+ if (currentNode.dataType == VariantType) return VariantType
+
+ // Filter children by cardinality, keep the top N by frequency, sort
alphabetically.
+ val maxStructSize = Math.min(1000, maxShreddedFieldsPerFile)
+ val children = currentNode.getChildren
+ .filter { case (_, n) => cardinality(n) >= minCardinality }
+ .sortBy { case (name, n) => (-cardinality(n), name) }
+ .take(maxStructSize)
+ .sortBy(_._1)
+
+ if (children.isEmpty) {
+ // No qualifying children: fall back to any scalar merged at this node,
or variant.
+ return currentNode.dataType match {
+ case _: StructType | _: ArrayType | NullType => VariantType
+ case dt => dt
+ }
+ }
+
+ val fields = children.map { case (fieldName, childNode) =>
+ val fieldType = childNode.dataType match {
+ case StructType(_) =>
+ buildSchemaFromStats(childNode, minCardinality, inArrayContext,
isArray = false)
+ case ArrayType(_, _) =>
+ buildSchemaFromStats(
+ childNode, minCardinality, inArrayContext = true,
+ isArray = childNode.arrayElementNode.isDefined)
+ case other => other
+ }
+ val cnt = cardinality(childNode)
+ StructField(fieldName, fieldType,
+ metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY,
cnt).build())
+ }
+
+ StructType(fields.toSeq)
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
index e9712aac0478..cbd3d89c3658 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
@@ -111,6 +111,74 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
}
+ testWithTempDir("infer shredding pure primitive root column") { dir =>
+ val df = spark.sql(
+ """
+ | select parse_json('42') as v
+ | from range(0, 20, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // Root `dataType` merges to a numeric type; inference should preserve it
(typed_value),
+ // not collapse to an unshredded variant when there are no object fields.
+ val expected = LongType
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
+ testWithTempDir("infer shredding mixed object and primitive at root") { dir
=>
+ val df = spark.sql(
+ """
+ | select if(id % 2 = 0, parse_json('{"a": 1}'), parse_json('42')) as v
+ | from range(0, 20, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // Object rows and scalar rows cannot share one shredded struct; the
column should stay
+ // variant (value-only) at the logical level.
+ val expected = VariantType
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
+ testWithTempDir("infer shredding heterogeneous array elements (primitive and
object)") { dir =>
+ val df = spark.sql(
+ """
+ | select parse_json('[1, {"a": 1}]') as v
+ | from range(0, 20, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // Mixed element kinds should merge to variant elements, not a struct that
only reflects
+ // object fields.
+ val expected = ArrayType(VariantType, containsNull = false)
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
+ testWithTempDir("infer shredding mixed array and non-array at root") { dir =>
+ val df = spark.sql(
+ """
+ | select if(id % 2 = 0, parse_json('[1]'), parse_json('42')) as v
+ | from range(0, 20, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // Array rows and scalar rows cannot share one array element type; the
column must stay variant.
+ val expected = VariantType
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
+ testWithTempDir("infer shredding heterogeneous array elements (object and
nested array)") { dir =>
+ val df = spark.sql(
+ """
+ | select parse_json('[{"a": 1}, [2]]') as v
+ | from range(0, 20, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // Object elements and inner-array elements on the same aggregate must
become array<variant>.
+ val expected = ArrayType(VariantType, containsNull = false)
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
test("infer shredding does not infer rare rows") {
Seq(2, 9, 10, 11, 19, 20, 21, 100).foreach { inverseFreq =>
withTempDir { dir =>
@@ -162,6 +230,57 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
}
}
+ testWithTempDir("infer shredding nested array of arrays") { dir =>
+ // Array-of-arrays [[1,2],[3,4]] in all rows. Verifies that nested arrays
are
+ // correctly inferred and shredded using the FieldNode tree
(arrayElementNode chain).
+ val df = spark.sql(
+ """
+ | select parse_json('[[1, 2], [3, 4]]') as v
+ | from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ val expected = DataType.fromDDL("array<array<long>>")
+ checkFileSchema(expected, dir)
+ checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+ }
+
+ test("infer shredding does not infer rare nested arrays") {
+ // Similar to "infer shredding does not infer rare rows" but for nested
arrays.
+ // Demonstrates row-level gating for nested arrays.
+ Seq(2, 9, 10, 11, 19, 20, 21, 100).foreach { inverseFreq =>
+ withTempDir { dir =>
+ val df = spark.sql(
+ s"""
+ | select case when id % $inverseFreq = 0 then
+ | parse_json('{"a": ' || id ||
+ | ', "nestedArr": [[1, 2], [3, 4]]' ||
+ | ', "nestedArr2": [[1, 2], [3, 4]]' ||
+ | ', "b": "' || id || '"}')
+ | else
+ | parse_json('{"a": ' || id ||
+ | ', "nestedArr2": []' ||
+ | ', "b": "' || id || '"}')
+ | end as v
+ | from range(0, 10000, 1, 1)
+ |""".stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+ // "a" and "b" appear in all rows. nestedArr: only in 1/inverseFreq
rows (like rareArray).
+ // nestedArr2: always present, inner arrays only in 1/inverseFreq rows
(like rareArray2).
+ val expected = if (inverseFreq > 10) {
+ // nestedArr dropped. nestedArr2: array<variant> (inner array in <
10% rows)
+ DataType.fromDDL("struct<a long, b string, nestedArr2
array<variant>>")
+ } else {
+ // Both nested arrays in >= 10% of rows -> array<array<long>>
+ DataType.fromDDL(
+ "struct<a long, b string, " +
+ "nestedArr array<array<long>>, nestedArr2 array<array<long>>>")
+ }
+ checkFileSchema(expected, dir)
+ checkStringAndSchema(dir, df)
+ }
+ }
+ }
+
test("infer shredding does not infer wide schemas") {
Seq(50, 60, 70).foreach { topLevelFields =>
// If this changes, we should change the test, or set it explicitly.
@@ -203,13 +322,9 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
}
testWithTempDir("infer shredding key as data") { dir =>
- // The first 10 fields in each object include the row ID in the field
name, so they'll be
- // unique. Because we impose a 1000-field limit when building up the
schema, we'll end up
- // dropping all but the first 1000, so we won't include the non-unique
fields in the schema.
- // Since the unique names are below the count threshold, we'll end up
with an unshredded
- // schema.
- // In the future, we could consider trying to improve this by dropping
the least-common fields
- // when we hit the limit of 1000.
+ // The 50 first_*_<id> fields are unique per row (low cardinality) and
are filtered
+ // out; the 50 last_* fields are shared across all rows (high
cardinality) and are
+ // shredded into typed_value. v2 is shredded independently.
val bigObject = (0 until 100).map { i =>
if (i < 50) {
s""" "first_${i}_' || id || '": {"x": $i, "y": "${i + 1}"} """
@@ -226,15 +341,29 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
val footers = getFooters(dir)
assert(footers.size == 1)
- // We can't call checkFileSchema, because it only handles the case of
one Variant column in
- // the file.
- val largeExpected =
SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant"))
+ val actual = getFileSchema(dir)
+ val v_schema = actual.fields(0).dataType.asInstanceOf[StructType]
+ val v2_schema = actual.fields(1).dataType.asInstanceOf[StructType]
+
+ // v should have shredded typed_value containing exactly the 50 last_*
fields, each
+ // with the struct<x long, y string> shape (matching the literal types
in the input).
+ // No first_*_<id> field should survive.
+ assert(v_schema.fieldNames.contains("typed_value"))
+ val v_typed = v_schema("typed_value").dataType.asInstanceOf[StructType]
+ assert(!v_typed.fieldNames.exists(_.startsWith("first_")))
+ assert(v_typed.fieldNames.count(_.startsWith("last_")) == 50)
+ val perElementExpected = SparkShreddingUtils.variantShreddingSchema(
+ DataType.fromDDL("struct<x long, y string>"), isTopLevel = false)
+ v_typed.fields.foreach { f =>
+ assert(f.name.startsWith("last_"))
+ assert(f.dataType == perElementExpected)
+ }
+
+ // v2 should be fully shredded
val smallExpected = SparkShreddingUtils.variantShreddingSchema(
DataType.fromDDL("struct<x long, y long>"))
- val actual = getFileSchema(dir)
- assert(actual == StructType(Seq(
- StructField("v", largeExpected, nullable = false),
- StructField("v2", smallExpected, nullable = false))))
+ assert(v2_schema == smallExpected)
+
checkStringAndSchema(dir, df)
}
@@ -634,4 +763,108 @@ class VariantInferShreddingSuite extends
SharedSparkSession with ParquetTest {
checkFileSchema(expected, dir)
checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
}
+
+ testWithTempDir("special characters in field names - dots") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"field.with.dots": ' || id || ', "another.dotted.field": "value"}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify the schema contains fields with dots
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("another.dotted.field"))
+ assert(typedValue.fieldNames.contains("field.with.dots"))
+
+ // Verify the variant values round-trip correctly (catches metadata
mis-encoding).
+ checkStringAndSchema(dir, df)
+ }
+
+ testWithTempDir("special characters in field names - brackets") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"field[0]": ' || id || ', "another[key]": "value"}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify the schema contains fields with brackets
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("another[key]"))
+ assert(typedValue.fieldNames.contains("field[0]"))
+
+ // Verify the variant values round-trip correctly (catches metadata
mis-encoding).
+ checkStringAndSchema(dir, df)
+ }
+
+ testWithTempDir("special characters in field names - mixed") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"a.b[0]": ' || id || ', "c[d].e": "value", "normal_field": 42}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ // Verify the schema contains fields with mixed special characters
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("a.b[0]"))
+ assert(typedValue.fieldNames.contains("c[d].e"))
+ assert(typedValue.fieldNames.contains("normal_field"))
+
+ // Verify the variant values round-trip correctly (catches metadata
mis-encoding).
+ checkStringAndSchema(dir, df)
+ }
+
+ testWithTempDir("special characters in field names - literal empty
brackets") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"[]": ' || id || ', "normal_field": "value"}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("[]"))
+ assert(typedValue.fieldNames.contains("normal_field"))
+
+ // Verify the variant values round-trip correctly (catches metadata
mis-encoding).
+ checkStringAndSchema(dir, df)
+ }
+
+ testWithTempDir("special characters in field names - literal empty brackets
with array") { dir =>
+ val df = spark.sql(
+ """
+ |select parse_json(
+ | '{"[]": ' || id || ', "arr": [' || id || ', ' || (id + 1) || ']}'
+ |) as v
+ |from range(0, 100, 1, 1)
+ """.stripMargin)
+ df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+ val schema = getFileSchema(dir)
+ val vSchema = schema("v").dataType.asInstanceOf[StructType]
+ val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+ assert(typedValue.fieldNames.contains("[]"))
+ assert(typedValue.fieldNames.contains("arr"))
+
+ // Verify the variant values round-trip correctly (catches metadata
mis-encoding).
+ checkStringAndSchema(dir, df)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]