This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new df593423adfe [SPARK-55568][SQL] Separate schema construction from field stats collection
df593423adfe is described below

commit df593423adfe86248fa714a7d4987f506c1ab1d0
Author: Qiegang Long <[email protected]>
AuthorDate: Thu May 28 02:09:20 2026 +0800

    [SPARK-55568][SQL] Separate schema construction from field stats collection
    
    ### Why are the changes needed?
    
    Variant shredding schema inference is expensive and can take over 100 ms per
    file. Replace fold-based schema merging with deferred schema construction
    driven by single-pass field statistics collection.
    
    Previous approach:
    - Used foldLeft to build and merge complete schemas for each row
    - Merged schemas repeatedly across 4096 rows
    - High allocation overhead from recursive schema construction
    
    New approach:
    - Separate schema construction from field statistics collection to avoid
      excessive intermediate allocations and repeated merges.
    - Traverse fields in a single pass, recording each field's type and row count
      in a field statistics tree.
    - Use `lastSeenRow` so that each row increments a field's row count at most once.
    - Defer schema construction until all rows have been processed.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Functional test:
    * Pass all existing unit tests
    
    Performance vs master:
    - Tested scenarios with different field counts, array sizes, and batch sizes
      (1-4096 rows, 10-200 fields, varying nesting depths and sparsity patterns).
    - 1.7x to 2.4x speedup across test scenarios
    - Consistent performance across multiple runs
    - 96% of tests show improvement
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Co-authored with Claude Sonnet 4.5
    
    Closes #54343 from qlong/SPARK-55568-optimize-variant-schema-inference.
    
    Lead-authored-by: Qiegang Long <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit c4842918cf958ca4681415e755fd66cfda48a5e3)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../parquet/InferVariantShreddingSchema.scala      | 356 ++++++++++++++++-----
 .../parquet/VariantInferShreddingSuite.scala       | 261 ++++++++++++++-
 2 files changed, 516 insertions(+), 101 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
index 1ebb61968150..bbfbbfde0ba4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
+import scala.collection.mutable
+
 import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.internal.SQLConf
@@ -93,74 +95,31 @@ class InferVariantShreddingSchema(val schema: StructType) {
 
   private val COUNT_METADATA_KEY = "COUNT"
 
-  /**
-   * Return an appropriate schema for shredding a Variant value.
-   * It is similar to the SchemaOfVariant expression, but the rules are 
somewhat different, because
-   * we want the types to be consistent with what will be allowed during 
shredding. E.g.
-   * SchemaOfVariant will consider the common type across Integer and Double 
to be double, but we
-   * consider it to be VariantType, since shredding will not allow those types 
to be written to
-   * the same typed_value.
-   * We also maintain metadata on struct fields to track how frequently they 
occur. Rare fields
-   * are dropped in the final schema.
-   */
-  private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
-    case Type.OBJECT =>
-      if (maxDepth <= 0) return VariantType
-      val size = v.objectSize()
-      val fields = new Array[StructField](size)
-      for (i <- 0 until size) {
-        val field = v.getFieldAtIndex(i)
-        fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
-          metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 
1).build())
-      }
-      // According to the variant spec, object fields must be sorted 
alphabetically. So we don't
-      // have to sort, but just need to validate they are sorted.
-      for (i <- 1 until size) {
-        if (fields(i - 1).name >= fields(i).name) {
-          throw new SparkRuntimeException(
-            errorClass = "MALFORMED_VARIANT",
-            messageParameters = Map.empty
-          )
-        }
-      }
-      StructType(fields)
-    case Type.ARRAY =>
-      if (maxDepth <= 0) return VariantType
-      var elementType: DataType = NullType
-      for (i <- 0 until v.arraySize()) {
-        elementType = mergeSchema(elementType, 
schemaOf(v.getElementAtIndex(i), maxDepth - 1))
-      }
-      ArrayType(elementType)
-    case Type.NULL => NullType
-    case Type.BOOLEAN => BooleanType
-    case Type.LONG =>
-      // Compute the smallest decimal that can contain this value.
-      // This will allow us to merge with decimal later without introducing 
excessive precision.
-      // If we only end up encountering integer values, we'll convert back to 
LongType when we
-      // finalize.
-      val d = BigDecimal(v.getLong())
-      val precision = d.precision
-      if (precision <= Decimal.MAX_LONG_DIGITS) {
-        DecimalType(precision, 0)
-      } else {
-        // Value is too large for Decimal(18, 0), so record its type as long.
-        LongType
+  // Node for tree-based field tracking
+  private case class FieldNode(
+    // Scalar type, or a marker (StructType(empty) / ArrayType(NullType)); 
structural
+    // shape lives in `children` and `arrayElementNode`.
+    var dataType: DataType,
+    var rowCount: Int = 0,           // Count of distinct rows containing this 
field
+    var lastSeenRow: Int = -1,       // Last row index that incremented 
rowCount
+    var arrayElementCount: Long = 0, // Total occurrences across all array 
elements
+    children: mutable.Map[String, FieldNode] = mutable.Map.empty,
+    var arrayElementNode: Option[FieldNode] = None
+  ) {
+
+    def getOrCreateChild(fieldName: String): FieldNode = {
+      children.getOrElseUpdate(fieldName, FieldNode(NullType))
+    }
+
+    def getChildren: Seq[(String, FieldNode)] = children.toSeq
+
+    def getOrCreateArrayElement(): FieldNode = {
+      arrayElementNode.getOrElse {
+        val node = FieldNode(NullType)
+        arrayElementNode = Some(node)
+        node
       }
-    case Type.STRING => StringType
-    case Type.DOUBLE => DoubleType
-    case Type.DECIMAL =>
-      // Don't strip trailing zeros to determine scale. Even if we allow scale 
relaxation during
-      // shredding, it's useful to take trailing zeros as a hint that the 
extra digits may be used
-      // in later values, and use the larger scale.
-      val d = Decimal(v.getDecimalWithOriginalScale())
-      DecimalType(d.precision, d.scale)
-    case Type.DATE => DateType
-    case Type.TIMESTAMP => TimestampType
-    case Type.TIMESTAMP_NTZ => TimestampNTZType
-    case Type.FLOAT => FloatType
-    case Type.BINARY => BinaryType
-    // Spark doesn't support UUID, so shred it as an untyped value.
-    case Type.UUID => VariantType
+    }
   }
 
   private def getFieldCount(field: StructField): Long = {
@@ -203,6 +162,9 @@ class InferVariantShreddingSchema(val schema: StructType) {
         mergeDecimalWithLong(d)
       case (StructType(fields1), StructType(fields2)) =>
         // Rely on fields being sorted by name, and merge fields with the same 
name recursively.
+        // In this inference path, non-empty struct merges are unused: 
`inferPrimitiveType` only
+        // produces empty struct markers; field lists are built in 
`buildSchemaFromStats` from the
+        // FieldNode tree instead.
         val newFields = new java.util.ArrayList[StructField]()
 
         var f1Idx = 0
@@ -351,36 +313,256 @@ class InferVariantShreddingSchema(val schema: StructType) {
   }
 
   def inferSchema(rows: Seq[InternalRow]): StructType = {
-    // For each path to a Variant value, iterate over all rows and update the 
inferred schema.
-    // Add the result to a map, which we'll use to update the full schema.
-    // maxShreddedFieldsPerFile is a global max for all fields, so initialize 
it here.
+    // For each variant path, collect field statistics using a single pass
     val maxFields = MaxFields(maxShreddedFieldsPerFile)
+
     val inferredSchemas = pathsToVariant.map { path =>
-      var numNonNullValues = 0
-      val simpleSchema = rows.foldLeft(NullType: DataType) {
-        case (partialSchema, row) =>
-          getValueAtPath(schema, row, path).map { variantVal =>
-            numNonNullValues += 1
-            val v = new Variant(variantVal.getValue, variantVal.getMetadata)
-            val schemaOfRow = schemaOf(v, maxShreddingDepth)
-            mergeSchema(partialSchema, schemaOfRow)
-          // If getValueAtPath returned None, the value is null in this row; 
just ignore.
-          }
-          .getOrElse(partialSchema)
-        // If we didn't find any non-null rows, use an unshredded schema.
-      }
+      val rootNode = FieldNode(NullType)
+      var numNonNullVariants = 0
 
-      // Don't infer a schema for fields that appear in less than 10% of rows.
-      // Ensure that minCardinality is at least 1 if we have any rows.
-      val minCardinality = (numNonNullValues + 9) / 10
+      // Single pass: process all rows for this variant path
+      rows.zipWithIndex.foreach { case (row, rowIdx) =>
+        getValueAtPath(schema, row, path).foreach { variantVal =>
+          numNonNullVariants += 1
+          val v = new Variant(variantVal.getValue, variantVal.getMetadata)
+          rootNode.dataType = mergeSchema(rootNode.dataType, 
inferPrimitiveType(v, 0))
+          // Traverse variant and update field stats tree
+          collectFieldStats(v, rootNode, rowIdx, 0, inArrayContext = false)
+        }
+      }
 
+      // Build final schema from collected statistics
+      val minCardinality = (numNonNullVariants + 9) / 10
+      val simpleSchema = buildSchemaFromStats(
+        rootNode,
+        minCardinality,
+        inArrayContext = false,
+        isArray = rootNode.arrayElementNode.isDefined)
       val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality, 
maxFields)
       val shreddingSchema = 
SparkShreddingUtils.variantShreddingSchema(finalizedSchema)
       val schemaWithMetadata = 
SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema)
       (path, schemaWithMetadata)
     }.toMap
 
-    // Insert each inferred schema into the full schema.
+    // Insert each inferred schema into the full schema
     updateSchema(schema, inferredSchemas)
   }
+
+  /**
+   * Recursively traverse a variant value and build a field statistics tree.
+   * For each field encountered, record its type and track distinct row count.
+   * For fields inside arrays, also increment the occurrence count.
+   */
+  private def collectFieldStats(
+      v: Variant,
+      currentNode: FieldNode,
+      rowIdx: Int,
+      depth: Int,
+      inArrayContext: Boolean): Unit = {
+
+    if (depth >= maxShreddingDepth) return
+
+    v.getType match {
+      case Type.OBJECT =>
+        val size = v.objectSize()
+        // Validate fields are sorted (per variant spec)
+        for (i <- 1 until size) {
+          val prevKey = v.getFieldAtIndex(i - 1).key
+          val currKey = v.getFieldAtIndex(i).key
+          if (prevKey >= currKey) {
+            throw new SparkRuntimeException(
+              errorClass = "MALFORMED_VARIANT",
+              messageParameters = Map.empty
+            )
+          }
+        }
+
+        // Process each field
+        for (i <- 0 until size) {
+          val field = v.getFieldAtIndex(i)
+          val fieldName = field.key
+
+          // Get or create child node (O(1) map access - no path string 
building!)
+          val childNode = currentNode.getOrCreateChild(fieldName)
+
+          // Track row-level presence only outside array context.
+          if (inArrayContext) {
+            childNode.arrayElementCount += 1
+          } else if (childNode.lastSeenRow != rowIdx) {
+            childNode.rowCount += 1
+            childNode.lastSeenRow = rowIdx
+          }
+
+          // Infer and merge type
+          val fieldType = inferPrimitiveType(field.value, depth)
+          childNode.dataType = mergeSchema(childNode.dataType, fieldType)
+
+          // Recurse into nested structures (pass child node, not path string)
+          collectFieldStats(field.value, childNode, rowIdx, depth + 1, 
inArrayContext)
+        }
+
+      case Type.ARRAY =>
+        val arrayNode = currentNode.getOrCreateArrayElement()
+
+        // Track distinct row count for the array field itself
+        if (arrayNode.lastSeenRow != rowIdx) {
+          arrayNode.rowCount += 1
+          arrayNode.lastSeenRow = rowIdx
+        }
+
+        val arraySize = v.arraySize()
+        if (arraySize > 0) {
+          // Process array elements
+          for (i <- 0 until arraySize) {
+            val element = v.getElementAtIndex(i)
+            val elementTypeClass = element.getType
+
+            // Primitives merge into `dataType` only; objects and arrays need 
tree descent.
+            if (elementTypeClass != Type.OBJECT && elementTypeClass != 
Type.ARRAY) {
+              val primitiveType = inferPrimitiveType(element, depth)
+              arrayNode.dataType = mergeSchema(arrayNode.dataType, 
primitiveType)
+            } else {
+              collectFieldStats(element, arrayNode, rowIdx, depth + 1, 
inArrayContext = true)
+            }
+          }
+        }
+
+      case _ =>
+    }
+  }
+
+  /**
+   * Infer the type of a variant value without recursive field collection.
+   * For objects and arrays, return a marker type; recursive collection is 
done separately.
+   */
+  private def inferPrimitiveType(v: Variant, depth: Int): DataType = {
+    if (depth >= maxShreddingDepth) return VariantType
+
+    v.getType match {
+      case Type.OBJECT =>
+        // Return empty struct as marker; fields collected separately
+        StructType(Seq.empty)
+      case Type.ARRAY =>
+        // Return array with null element as marker; elements processed 
separately
+        ArrayType(NullType)
+      case Type.NULL => NullType
+      case Type.BOOLEAN => BooleanType
+      case Type.LONG =>
+        val d = BigDecimal(v.getLong())
+        val precision = d.precision
+        if (precision <= Decimal.MAX_LONG_DIGITS) {
+          DecimalType(precision, 0)
+        } else {
+          LongType
+        }
+      case Type.STRING => StringType
+      case Type.DOUBLE => DoubleType
+      case Type.DECIMAL =>
+        val d = Decimal(v.getDecimalWithOriginalScale())
+        DecimalType(d.precision, d.scale)
+      case Type.DATE => DateType
+      case Type.TIMESTAMP => TimestampType
+      case Type.TIMESTAMP_NTZ => TimestampNTZType
+      case Type.FLOAT => FloatType
+      case Type.BINARY => BinaryType
+      case Type.UUID => VariantType
+    }
+  }
+
+  /**
+   * Build a schema from collected field statistics tree.
+   *
+   * When isArray=true the function builds and returns the full ArrayType for 
this node
+   * (using its arrayElementNode to determine the element type).
+   * When isArray=false it returns the type for the node itself (scalar, 
VariantType,
+   * or StructType).
+   *
+   * Cardinality metric:
+   *  - inArrayContext=true  uses arrayElementCount (total occurrences across 
array positions).
+   *  - inArrayContext=false uses rowCount (distinct rows containing the 
field).
+   */
+  private def buildSchemaFromStats(
+      currentNode: FieldNode,
+      minCardinality: Int,
+      inArrayContext: Boolean,
+      isArray: Boolean): DataType = {
+
+    // Pick the right counter for this context; reused in filter, sort, and 
metadata below.
+    def cardinality(n: FieldNode): Long =
+      if (inArrayContext) n.arrayElementCount else n.rowCount
+
+    // Array branch
+    if (isArray) {
+      // Case 1: mixed array and non-array rows at the same path merged 
dataType to VariantType.
+      //         The whole node is variant, not an array.
+      if (currentNode.dataType == VariantType) {
+        return VariantType
+      }
+      // Case 2 (defensive): every reachable isArray=true call site is paired 
with a guard that
+      //         already rules this out -- the root call collapses to 
VariantType (Case 1), the
+      //         Case 3 recursion short-circuits via the 
`elemNode.children.nonEmpty &&
+      //         elemNode.arrayElementNode.isDefined` check below, and the 
struct-field
+      //         ArrayType recursion only fires when the child was uniformly 
an array (so
+      //         `children` is empty). Keep this as a backstop for future call 
sites.
+      if (currentNode.children.nonEmpty && 
currentNode.arrayElementNode.isDefined) {
+        return ArrayType(VariantType)
+      }
+
+      // Case 3: uniform inner array -- recurse into the element node and 
merge with any scalar
+      //         dataType on the same node (e.g. `[1, {"a":1}]` merges long + 
struct -> variant).
+      currentNode.arrayElementNode match {
+        case Some(elemNode) if elemNode.rowCount >= minCardinality =>
+          // If the element node itself has both object children and a nested 
array, the two
+          // element shapes cannot be reconciled: treat the element as variant.
+          val elementType = if (elemNode.children.nonEmpty && 
elemNode.arrayElementNode.isDefined) {
+            VariantType
+          } else {
+            buildSchemaFromStats(
+              elemNode,
+              minCardinality,
+              inArrayContext = true,
+              isArray = elemNode.arrayElementNode.isDefined)
+          }
+          return ArrayType(mergeSchema(elemNode.dataType, elementType))
+        case _ =>
+          return currentNode.dataType
+      }
+    }
+
+    // Non-array branch: incompatible types already collapsed to VariantType 
during collection.
+    if (currentNode.dataType == VariantType) return VariantType
+
+    // Filter children by cardinality, keep the top N by frequency, sort 
alphabetically.
+    val maxStructSize = Math.min(1000, maxShreddedFieldsPerFile)
+    val children = currentNode.getChildren
+      .filter { case (_, n) => cardinality(n) >= minCardinality }
+      .sortBy { case (name, n) => (-cardinality(n), name) }
+      .take(maxStructSize)
+      .sortBy(_._1)
+
+    if (children.isEmpty) {
+      // No qualifying children: fall back to any scalar merged at this node, 
or variant.
+      return currentNode.dataType match {
+        case _: StructType | _: ArrayType | NullType => VariantType
+        case dt => dt
+      }
+    }
+
+    val fields = children.map { case (fieldName, childNode) =>
+      val fieldType = childNode.dataType match {
+        case StructType(_) =>
+          buildSchemaFromStats(childNode, minCardinality, inArrayContext, 
isArray = false)
+        case ArrayType(_, _) =>
+          buildSchemaFromStats(
+            childNode, minCardinality, inArrayContext = true,
+            isArray = childNode.arrayElementNode.isDefined)
+        case other => other
+      }
+      val cnt = cardinality(childNode)
+      StructField(fieldName, fieldType,
+        metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 
cnt).build())
+    }
+
+    StructType(fields.toSeq)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
index e9712aac0478..cbd3d89c3658 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
@@ -111,6 +111,74 @@ class VariantInferShreddingSuite extends SharedSparkSession with ParquetTest {
     checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
   }
 
+  testWithTempDir("infer shredding pure primitive root column") { dir =>
+    val df = spark.sql(
+      """
+        | select parse_json('42') as v
+        | from range(0, 20, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Root `dataType` merges to a numeric type; inference should preserve it 
(typed_value),
+    // not collapse to an unshredded variant when there are no object fields.
+    val expected = LongType
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer shredding mixed object and primitive at root") { dir 
=>
+    val df = spark.sql(
+      """
+        | select if(id % 2 = 0, parse_json('{"a": 1}'), parse_json('42')) as v
+        | from range(0, 20, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Object rows and scalar rows cannot share one shredded struct; the 
column should stay
+    // variant (value-only) at the logical level.
+    val expected = VariantType
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer shredding heterogeneous array elements (primitive and 
object)") { dir =>
+    val df = spark.sql(
+      """
+        | select parse_json('[1, {"a": 1}]') as v
+        | from range(0, 20, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Mixed element kinds should merge to variant elements, not a struct that 
only reflects
+    // object fields.
+    val expected = ArrayType(VariantType, containsNull = false)
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer shredding mixed array and non-array at root") { dir =>
+    val df = spark.sql(
+      """
+        | select if(id % 2 = 0, parse_json('[1]'), parse_json('42')) as v
+        | from range(0, 20, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Array rows and scalar rows cannot share one array element type; the 
column must stay variant.
+    val expected = VariantType
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer shredding heterogeneous array elements (object and 
nested array)") { dir =>
+    val df = spark.sql(
+      """
+        | select parse_json('[{"a": 1}, [2]]') as v
+        | from range(0, 20, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Object elements and inner-array elements on the same aggregate must 
become array<variant>.
+    val expected = ArrayType(VariantType, containsNull = false)
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
   test("infer shredding does not infer rare rows") {
     Seq(2, 9, 10, 11, 19, 20, 21, 100).foreach { inverseFreq =>
       withTempDir { dir =>
@@ -162,6 +230,57 @@ class VariantInferShreddingSuite extends SharedSparkSession with ParquetTest {
     }
   }
 
+  testWithTempDir("infer shredding nested array of arrays") { dir =>
+    // Array-of-arrays [[1,2],[3,4]] in all rows. Verifies that nested arrays 
are
+    // correctly inferred and shredded using the FieldNode tree 
(arrayElementNode chain).
+    val df = spark.sql(
+      """
+        | select parse_json('[[1, 2], [3, 4]]') as v
+        | from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("array<array<long>>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  test("infer shredding does not infer rare nested arrays") {
+    // Similar to "infer shredding does not infer rare rows" but for nested 
arrays.
+    // Demonstrates row-level gating for nested arrays.
+    Seq(2, 9, 10, 11, 19, 20, 21, 100).foreach { inverseFreq =>
+      withTempDir { dir =>
+        val df = spark.sql(
+          s"""
+             | select case when id % $inverseFreq = 0 then
+             |  parse_json('{"a": ' || id ||
+             |  ', "nestedArr": [[1, 2], [3, 4]]' ||
+             |  ', "nestedArr2": [[1, 2], [3, 4]]' ||
+             |  ', "b": "' || id || '"}')
+             |  else
+             |  parse_json('{"a": ' || id ||
+             |  ', "nestedArr2": []' ||
+             |  ', "b": "' || id || '"}')
+             |  end as v
+             |  from range(0, 10000, 1, 1)
+             |""".stripMargin)
+        df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+        // "a" and "b" appear in all rows. nestedArr: only in 1/inverseFreq 
rows (like rareArray).
+        // nestedArr2: always present, inner arrays only in 1/inverseFreq rows 
(like rareArray2).
+        val expected = if (inverseFreq > 10) {
+          // nestedArr dropped. nestedArr2: array<variant> (inner array in < 
10% rows)
+          DataType.fromDDL("struct<a long, b string, nestedArr2 
array<variant>>")
+        } else {
+          // Both nested arrays in >= 10% of rows -> array<array<long>>
+          DataType.fromDDL(
+            "struct<a long, b string, " +
+            "nestedArr array<array<long>>, nestedArr2 array<array<long>>>")
+        }
+        checkFileSchema(expected, dir)
+        checkStringAndSchema(dir, df)
+      }
+    }
+  }
+
   test("infer shredding does not infer wide schemas") {
     Seq(50, 60, 70).foreach { topLevelFields =>
       // If this changes, we should change the test, or set it explicitly.
@@ -203,13 +322,9 @@ class VariantInferShreddingSuite extends SharedSparkSession with ParquetTest {
   }
 
   testWithTempDir("infer shredding key as data") { dir =>
-      // The first 10 fields in each object include the row ID in the field 
name, so they'll be
-      // unique. Because we impose a 1000-field limit when building up the 
schema, we'll end up
-      // dropping all but the first 1000, so we won't include the non-unique 
fields in the schema.
-      // Since the unique names are below the count threshold, we'll end up 
with an unshredded
-      // schema.
-      // In the future, we could consider trying to improve this by dropping 
the least-common fields
-      // when we hit the limit of 1000.
+      // The 50 first_*_<id> fields are unique per row (low cardinality) and 
are filtered
+      // out; the 50 last_* fields are shared across all rows (high 
cardinality) and are
+      // shredded into typed_value. v2 is shredded independently.
       val bigObject = (0 until 100).map { i =>
         if (i < 50) {
           s""" "first_${i}_' || id || '": {"x": $i, "y": "${i + 1}"}  """
@@ -226,15 +341,29 @@ class VariantInferShreddingSuite extends SharedSparkSession with ParquetTest {
       val footers = getFooters(dir)
       assert(footers.size == 1)
 
-      // We can't call checkFileSchema, because it only handles the case of 
one Variant column in
-      // the file.
-      val largeExpected = 
SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant"))
+      val actual = getFileSchema(dir)
+      val v_schema = actual.fields(0).dataType.asInstanceOf[StructType]
+      val v2_schema = actual.fields(1).dataType.asInstanceOf[StructType]
+
+      // v should have shredded typed_value containing exactly the 50 last_* 
fields, each
+      // with the struct<x long, y string> shape (matching the literal types 
in the input).
+      // No first_*_<id> field should survive.
+      assert(v_schema.fieldNames.contains("typed_value"))
+      val v_typed = v_schema("typed_value").dataType.asInstanceOf[StructType]
+      assert(!v_typed.fieldNames.exists(_.startsWith("first_")))
+      assert(v_typed.fieldNames.count(_.startsWith("last_")) == 50)
+      val perElementExpected = SparkShreddingUtils.variantShreddingSchema(
+        DataType.fromDDL("struct<x long, y string>"), isTopLevel = false)
+      v_typed.fields.foreach { f =>
+        assert(f.name.startsWith("last_"))
+        assert(f.dataType == perElementExpected)
+      }
+
+      // v2 should be fully shredded
       val smallExpected = SparkShreddingUtils.variantShreddingSchema(
         DataType.fromDDL("struct<x long, y long>"))
-      val actual = getFileSchema(dir)
-      assert(actual == StructType(Seq(
-              StructField("v", largeExpected, nullable = false),
-              StructField("v2", smallExpected, nullable = false))))
+      assert(v2_schema == smallExpected)
+
       checkStringAndSchema(dir, df)
   }
 
@@ -634,4 +763,108 @@ class VariantInferShreddingSuite extends SharedSparkSession with ParquetTest {
     checkFileSchema(expected, dir)
     checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
   }
+
+  testWithTempDir("special characters in field names - dots") { dir =>
+    val df = spark.sql(
+      """
+        |select parse_json(
+        |  '{"field.with.dots": ' || id || ', "another.dotted.field": "value"}'
+        |) as v
+        |from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+    // Verify the schema contains fields with dots
+    val schema = getFileSchema(dir)
+    val vSchema = schema("v").dataType.asInstanceOf[StructType]
+    val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+    assert(typedValue.fieldNames.contains("another.dotted.field"))
+    assert(typedValue.fieldNames.contains("field.with.dots"))
+
+    // Verify the variant values round-trip correctly (catches metadata 
mis-encoding).
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("special characters in field names - brackets") { dir =>
+    val df = spark.sql(
+      """
+        |select parse_json(
+        |  '{"field[0]": ' || id || ', "another[key]": "value"}'
+        |) as v
+        |from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+    // Verify the schema contains fields with brackets
+    val schema = getFileSchema(dir)
+    val vSchema = schema("v").dataType.asInstanceOf[StructType]
+    val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+    assert(typedValue.fieldNames.contains("another[key]"))
+    assert(typedValue.fieldNames.contains("field[0]"))
+
+    // Verify the variant values round-trip correctly (catches metadata 
mis-encoding).
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("special characters in field names - mixed") { dir =>
+    val df = spark.sql(
+      """
+        |select parse_json(
+        |  '{"a.b[0]": ' || id || ', "c[d].e": "value", "normal_field": 42}'
+        |) as v
+        |from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+    // Verify the schema contains fields with mixed special characters
+    val schema = getFileSchema(dir)
+    val vSchema = schema("v").dataType.asInstanceOf[StructType]
+    val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+    assert(typedValue.fieldNames.contains("a.b[0]"))
+    assert(typedValue.fieldNames.contains("c[d].e"))
+    assert(typedValue.fieldNames.contains("normal_field"))
+
+    // Verify the variant values round-trip correctly (catches metadata 
mis-encoding).
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("special characters in field names - literal empty 
brackets") { dir =>
+    val df = spark.sql(
+      """
+        |select parse_json(
+        |  '{"[]": ' || id || ', "normal_field": "value"}'
+        |) as v
+        |from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+    val schema = getFileSchema(dir)
+    val vSchema = schema("v").dataType.asInstanceOf[StructType]
+    val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+    assert(typedValue.fieldNames.contains("[]"))
+    assert(typedValue.fieldNames.contains("normal_field"))
+
+    // Verify the variant values round-trip correctly (catches metadata 
mis-encoding).
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("special characters in field names - literal empty brackets 
with array") { dir =>
+    val df = spark.sql(
+      """
+        |select parse_json(
+        |  '{"[]": ' || id || ', "arr": [' || id || ', ' || (id + 1) || ']}'
+        |) as v
+        |from range(0, 100, 1, 1)
+      """.stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+
+    val schema = getFileSchema(dir)
+    val vSchema = schema("v").dataType.asInstanceOf[StructType]
+    val typedValue = vSchema("typed_value").dataType.asInstanceOf[StructType]
+    assert(typedValue.fieldNames.contains("[]"))
+    assert(typedValue.fieldNames.contains("arr"))
+
+    // Verify the variant values round-trip correctly (catches metadata 
mis-encoding).
+    checkStringAndSchema(dir, df)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to