This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 104dc2b708c4 [SPARK-51503][SQL] Support Variant type in XML scan
104dc2b708c4 is described below

commit 104dc2b708c44bfa948f9bef3a0bc3b27bdcd260
Author: Xiaonan Yang <xiaonan.y...@databricks.com>
AuthorDate: Fri Apr 11 08:10:12 2025 +0900

    [SPARK-51503][SQL] Support Variant type in XML scan
    
    ### What changes were proposed in this pull request?
    
    This PR introduces the capability to read XML data as the Variant type. It
    adds Variant support to both the XML scan and the `from_xml` SQL expression.
    Writing variant values to XML will be supported in a subsequent PR.
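    
    For illustration, a minimal sketch of `from_xml` with a `variant` schema,
    mirroring the new SQL test added in this PR (the literal XML and the
    extracted field are just example data):
    
    ```scala
    import org.apache.spark.sql.functions.{col, variant_get}
    
    // Parse an XML string literal into a single VARIANT value, then extract a field.
    spark.sql("SELECT from_xml('<ROW><year>2012</year></ROW>', 'variant') AS var")
      .select(variant_get(col("var"), "$.year", "int"))
    ```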
    
    Reading XML as Variant follows similar rules to reading XML as structs and
    uses the same options. There are two modes of reading XML data as Variants
    (both are sketched below):
    - The first mode ingests an entire XML file as variants via a new
      `singleVariantColumn` option. The option specifies the name of the single
      output column with variant type. Once the option is set, the schema
      inference stage is skipped and a single variant column is used as the
      schema of the XML scan. In the parsing stage, each XML record enclosed by
      the `rowTag` option is converted to a single variant object, and its child
      elements are recursively parsed as variant values.
    - The second mode specifies individual columns as the variant type, which is
      achieved by providing a user-defined schema. In this mode, fields not
      declared as variants are parsed as non-variant values following the
      existing behavior, and fields declared as variants are parsed as variants.
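    
    A minimal sketch of both modes through the DataFrame reader, based on the
    new tests in this PR (the file paths and column names are placeholders):
    
    ```scala
    // Mode 1: skip schema inference and read each record (per rowTag) into a
    // single variant column named "var".
    val singleVariantDF = spark.read
      .format("xml")
      .option("rowTag", "ROW")
      .option("singleVariantColumn", "var")
      .load("/path/to/records.xml")
    
    // Mode 2: provide a schema in which selected fields are declared as variant;
    // the remaining fields are parsed as before.
    val mixedDF = spark.read
      .format("xml")
      .option("rowTag", "book")
      .schema("author string, title string, genre variant, publish_dates variant")
      .load("/path/to/books.xml")
    ```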
    
    ### Why are the changes needed?
    
    Allow users to ingest XML data as Variants
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests are added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50300 from xiaonanyang-db/SPARK-51503.
    
    Authored-by: Xiaonan Yang <xiaonan.y...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |   5 +
 .../expressions/xml/XmlExpressionEvalUtils.scala   |  52 +-
 .../sql/catalyst/expressions/xmlExpressions.scala  |  49 +-
 .../spark/sql/catalyst/xml/StaxXmlParser.scala     | 263 ++++++++-
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala |   6 +
 .../execution/datasources/xml/XmlDataSource.scala  |  14 +-
 .../execution/datasources/xml/XmlFileFormat.scala  |   2 +-
 .../org/apache/spark/sql/XmlFunctionsSuite.scala   |  12 +-
 .../sql/execution/datasources/xml/XmlSuite.scala   |  28 +
 .../datasources/xml/XmlVariantSuite.scala          | 586 +++++++++++++++++++++
 10 files changed, 967 insertions(+), 50 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index f098e27555b1..37cc80a12394 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -994,6 +994,11 @@
           "Input schema <schema> can only contain STRING as a key type for a MAP."
         ]
       },
+      "INVALID_XML_SCHEMA" : {
+        "message" : [
+          "Input schema <schema> must be a struct or a variant."
+        ]
+      },
       "IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
         "message" : [
           "The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [<mismatchedColumns>], left side: [<leftType>], right side: [<rightType>]."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
index 89d7b8d9421a..5a2cd8ad76ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions.xml
 
-import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.catalyst.xml.XmlInferSchema
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils}
+import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, GenericArrayData, PermissiveMode}
+import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -119,3 +122,48 @@ case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator {
     }
   }
 }
+
+case class XmlToStructsEvaluator(
+    options: Map[String, String],
+    nullableSchema: DataType,
+    nameOfCorruptRecord: String,
+    timeZoneId: Option[String],
+    child: Expression
+) {
+  @transient lazy val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord)
+
+  // This converts parsed rows to the desired output by the given schema.
+  @transient
+  private lazy val converter =
+    (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
+
+  // Parser that parse XML strings as internal rows
+  @transient
+  private lazy val parser = {
+    val mode = parsedOptions.parseMode
+    if (mode != PermissiveMode && mode != FailFastMode) {
+      throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode)
+    }
+
+    // The parser is only used when the input schema is StructType
+    val schema = nullableSchema.asInstanceOf[StructType]
+    ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val rawParser = new StaxXmlParser(schema, parsedOptions)
+
+    val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
+
+    new FailureSafeParser[String](
+      input => rawParser.doParseColumn(input, mode, xsdSchema),
+      mode,
+      schema,
+      parsedOptions.columnNameOfCorruptRecord)
+  }
+
+  final def evaluate(xml: UTF8String): Any = {
+    if (xml == null) return null
+    nullableSchema match {
+      case _: VariantType => StaxXmlParser.parseVariant(xml.toString, parsedOptions)
+      case _: StructType => converter(parser.parse(xml.toString))
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 25a054f79c36..d6e3eea579d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.expressions.xml.XmlExpressionEvalUtils
-import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, FailureSafeParser, PermissiveMode}
+import org.apache.spark.sql.catalyst.expressions.xml.{XmlExpressionEvalUtils, XmlToStructsEvaluator}
+import org.apache.spark.sql.catalyst.util.DropMalformedMode
 import org.apache.spark.sql.catalyst.util.TypeUtils._
-import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
+import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, XmlInferSchema, XmlOptions}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
@@ -54,7 +54,7 @@ import org.apache.spark.unsafe.types.UTF8String
   since = "4.0.0")
 // scalastyle:on line.size.limit
 case class XmlToStructs(
-    schema: StructType,
+    schema: DataType,
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
@@ -65,7 +65,7 @@ case class XmlToStructs(
 
   def this(child: Expression, schema: Expression, options: Map[String, String]) =
     this(
-      schema = ExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = options,
       child = child,
       timeZoneId = None)
@@ -81,45 +81,34 @@ case class XmlToStructs(
 
   def this(child: Expression, schema: Expression, options: Expression) =
     this(
-      schema = ExprUtils.evalSchemaExpr(schema),
+      schema = ExprUtils.evalTypeExpr(schema),
       options = ExprUtils.convertToMapData(options),
       child = child,
       timeZoneId = None)
 
-  // This converts parsed rows to the desired output by the given schema.
+  override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
+    case _: StructType | _: VariantType =>
+      val checkResult = ExprUtils.checkXmlSchema(nullableSchema)
+      if (checkResult.isFailure) checkResult else super.checkInputDataTypes()
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "INVALID_XML_SCHEMA",
+        messageParameters = Map("schema" -> toSQLType(nullableSchema)))
+  }
+
   @transient
-  private lazy val converter =
-    (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
+  private lazy val evaluator: XmlToStructsEvaluator =
+    XmlToStructsEvaluator(options, nullableSchema, nameOfCorruptRecord, timeZoneId, child)
 
   private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
 
-  @transient
-  private lazy val parser = {
-    val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord)
-    val mode = parsedOptions.parseMode
-    if (mode != PermissiveMode && mode != FailFastMode) {
-      throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode)
-    }
-    ExprUtils.verifyColumnNameOfCorruptRecord(
-      nullableSchema, parsedOptions.columnNameOfCorruptRecord)
-    val rawParser = new StaxXmlParser(schema, parsedOptions)
-    val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema)
-
-    new FailureSafeParser[String](
-      input => rawParser.doParseColumn(input, mode, xsdSchema),
-      mode,
-      nullableSchema,
-      parsedOptions.columnNameOfCorruptRecord)
-  }
-
   override def dataType: DataType = nullableSchema
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
     copy(timeZoneId = Option(timeZoneId))
   }
 
-  override def nullSafeEval(xml: Any): Any =
-    converter(parser.parse(xml.asInstanceOf[UTF8String].toString))
+  override def nullSafeEval(xml: Any): Any = evaluator.evaluate(xml.asInstanceOf[UTF8String])
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val expr = ctx.addReferenceObj("this", this)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 4b892da9db25..c82571d1f37f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.xml
 import java.io.{BufferedReader, CharConversionException, FileNotFoundException, InputStream, InputStreamReader, IOException, StringReader}
 import java.nio.charset.{Charset, MalformedInputException}
 import java.text.NumberFormat
+import java.util
 import java.util.Locale
 import javax.xml.stream.{XMLEventReader, XMLStreamException}
 import javax.xml.stream.events._
@@ -28,6 +29,7 @@ import javax.xml.validation.Schema
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.Try
+import scala.util.control.Exception.allCatch
 import scala.util.control.NonFatal
 import scala.xml.SAXException
 
@@ -45,7 +47,10 @@ import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.types.variant.{Variant, VariantBuilder}
+import org.apache.spark.types.variant.VariantBuilder.FieldEntry
+import org.apache.spark.types.variant.VariantUtil
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
 class StaxXmlParser(
     schema: StructType,
@@ -138,11 +143,19 @@ class StaxXmlParser(
       xsdSchema.foreach { schema =>
         schema.newValidator().validate(new StreamSource(new StringReader(xml)))
       }
-      val parser = StaxXmlParserUtils.filteredReader(xml)
-      val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
-      val result = Some(convertObject(parser, schema, rootAttributes))
-      parser.close()
-      result
+      options.singleVariantColumn match {
+        case Some(_) =>
+          // If the singleVariantColumn is specified, parse the entire xml string as a Variant
+          val v = StaxXmlParser.parseVariant(xml, options)
+          Some(InternalRow(v))
+        case _ =>
+          // Otherwise, parse the xml string as Structs
+          val parser = StaxXmlParserUtils.filteredReader(xml)
+          val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
+          val result = Some(convertObject(parser, schema, rootAttributes))
+          parser.close()
+          result
+      }
     } catch {
       case e: SparkUpgradeException => throw e
       case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException
@@ -379,6 +392,10 @@ class StaxXmlParser(
                 }
                 row(index) = values :+ newValue
 
+              case VariantType =>
+                val v = StaxXmlParser.convertVariant(parser, attributes, options)
+                row(index) = new VariantVal(v.getValue, v.getMetadata)
+
               case dt: DataType =>
                 row(index) = convertField(parser, dt, field, attributes)
             }
@@ -897,4 +914,238 @@ object StaxXmlParser {
       curRecord
     }
   }
+
+  /**
+   * Parse the input XML string as a Variant value
+   */
+  def parseVariant(xml: String, options: XmlOptions): VariantVal = {
+    val parser = StaxXmlParserUtils.filteredReader(xml)
+    val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
+    val variant = convertVariant(parser, rootAttributes, options)
+    val v = new VariantVal(variant.getValue, variant.getMetadata)
+    parser.close()
+    v
+  }
+
+  /**
+   * Parse an XML element from the XML event stream into a Variant.
+   * This method transforms the XML element along with its attributes and child elements
+   * into a hierarchical Variant data structure that preserves the XML structure.
+   *
+   * @param parser The XML event stream reader positioned after the start element
+   * @param attributes The attributes of the current XML element to be included in the Variant
+   * @param options Configuration options that control how XML is parsed into Variants
+   * @return A Variant representing the XML element with its attributes and child content
+   */
+  def convertVariant(
+      parser: XMLEventReader,
+      attributes: Array[Attribute],
+      options: XmlOptions): Variant = {
+    // The variant builder for the root startElement
+    val rootBuilder = new VariantBuilder(false)
+    val start = rootBuilder.getWritePos
+
+    // Map to store the variant values of all child fields
+    // Each field could have multiple entries, which means it's an array
+    // The map is sorted by field name, and the ordering is based on the case sensitivity
+    val caseSensitivityOrdering: Ordering[String] = if (SQLConf.get.caseSensitiveAnalysis) {
+      (x: String, y: String) => x.compareTo(y)
+    } else {
+      (x: String, y: String) => x.compareToIgnoreCase(y)
+    }
+    val fieldToVariants = collection.mutable.TreeMap.empty[String, java.util.ArrayList[Variant]](
+      caseSensitivityOrdering
+    )
+
+    // Handle attributes first
+    StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options).foreach {
+      case (f, v) =>
+        val builder = new VariantBuilder(false)
+        appendXMLCharacterToVariant(builder, v, options)
+        val variants = fieldToVariants.getOrElseUpdate(f, new java.util.ArrayList[Variant]())
+        variants.add(builder.result())
+    }
+
+    var shouldStop = false
+    while (!shouldStop) {
+      parser.nextEvent() match {
+        case s: StartElement =>
+          // For each child element, convert it to a variant and keep track of it in
+          // fieldsToVariants
+          val attributes = s.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
+          val field = StaxXmlParserUtils.getName(s.asStartElement.getName, options)
+          val variants = fieldToVariants.getOrElseUpdate(field, new java.util.ArrayList[Variant]())
+          variants.add(convertVariant(parser, attributes, options))
+
+        case c: Characters if !c.isWhiteSpace =>
+          // Treat the character as a value tag field, where we use the [[XMLOptions.valueTag]] as
+          // the field key
+          val builder = new VariantBuilder(false)
+          appendXMLCharacterToVariant(builder, c.getData, options)
+          val variants = fieldToVariants.getOrElseUpdate(
+            options.valueTag,
+            new java.util.ArrayList[Variant]()
+          )
+          variants.add(builder.result())
+
+        case _: EndElement =>
+          if (fieldToVariants.nonEmpty) {
+            val onlyValueTagField = fieldToVariants.keySet.forall(_ == options.valueTag)
+            if (onlyValueTagField) {
+              // If the element only has value tag field, parse the element as a variant primitive
+              rootBuilder.appendVariant(fieldToVariants(options.valueTag).get(0))
+            } else {
+              writeVariantObject(rootBuilder, fieldToVariants)
+            }
+          }
+          shouldStop = true
+
+        case _: EndDocument => shouldStop = true
+
+        case _ => // do nothing
+      }
+    }
+
+    // If the element is empty, we treat it as a Variant null
+    if (rootBuilder.getWritePos == start) {
+      rootBuilder.appendNull()
+    }
+
+    rootBuilder.result()
+  }
+
+  /**
+   * Write a variant object to the variant builder.
+   *
+   * @param builder The variant builder to write to
+   * @param fieldToVariants A map of field names to their corresponding variant values of the object
+   */
+  private def writeVariantObject(
+      builder: VariantBuilder,
+      fieldToVariants: collection.mutable.TreeMap[String, java.util.ArrayList[Variant]]): Unit = {
+    val start = builder.getWritePos
+    val objectFieldEntries = new java.util.ArrayList[FieldEntry]()
+
+    val (lastFieldKey, lastFieldValue) =
+      fieldToVariants.tail.foldLeft(fieldToVariants.head._1, fieldToVariants.head._2) {
+        case ((key, variantVals), (k, v)) =>
+          if (!SQLConf.get.caseSensitiveAnalysis && k.equalsIgnoreCase(key)) {
+            variantVals.addAll(v)
+            (key, variantVals)
+          } else {
+            writeVariantObjectField(key, variantVals, builder, start, objectFieldEntries)
+            (k, v)
+          }
+      }
+
+    writeVariantObjectField(lastFieldKey, lastFieldValue, builder, start, objectFieldEntries)
+
+    // Finish writing the variant object
+    builder.finishWritingObject(start, objectFieldEntries)
+  }
+
+  /**
+   * Write a single field to a variant object
+   *
+   * @param fieldName the name of the object field
+   * @param fieldVariants the variant value of the field. A field could have multiple variant values,
+   *                      which means it's an array field
+   * @param builder the variant builder
+   * @param objectStart the start position of the variant object in the builder
+   * @param objectFieldEntries a list tracking all fields of the variant object
+   */
+  private def writeVariantObjectField(
+      fieldName: String,
+      fieldVariants: java.util.ArrayList[Variant],
+      builder: VariantBuilder,
+      objectStart: Int,
+      objectFieldEntries: java.util.ArrayList[FieldEntry]): Unit = {
+    val start = builder.getWritePos
+    val fieldId = builder.addKey(fieldName)
+    objectFieldEntries.add(
+      new FieldEntry(fieldName, fieldId, builder.getWritePos - objectStart)
+    )
+
+    val fieldValue = if (fieldVariants.size() > 1) {
+      // If the field has more than one entry, it's an array field. Build a Variant
+      // array as the field value
+      val arrayBuilder = new VariantBuilder(false)
+      val arrayStart = arrayBuilder.getWritePos
+      val offsets = new util.ArrayList[Integer]()
+      fieldVariants.asScala.foreach { v =>
+        offsets.add(arrayBuilder.getWritePos - arrayStart)
+        arrayBuilder.appendVariant(v)
+      }
+      arrayBuilder.finishWritingArray(arrayStart, offsets)
+      arrayBuilder.result()
+    } else {
+      // Otherwise, just use the first variant as the field value
+      fieldVariants.get(0)
+    }
+
+    // Append the field value to the variant builder
+    builder.appendVariant(fieldValue)
+  }
+
+  /**
+   * Convert an XML Character value `s` into a variant value and append the result to `builder`.
+   * The result can only be one of a variant boolean/long/decimal/string. Anything other than
+   * the supported types will be appended to the Variant builder as a string.
+   *
+   * Floating point types (double, float) are not considered to avoid precision loss.
+   */
+  private def appendXMLCharacterToVariant(
+      builder: VariantBuilder,
+      s: String,
+      options: XmlOptions): Unit = {
+    if (s == null || s == options.nullValue) {
+      builder.appendNull()
+      return
+    }
+
+    val value = if (options.ignoreSurroundingSpaces) s.trim() else s
+
+    // Exit early for empty strings
+    if (value.isEmpty) {
+      builder.appendString(value)
+      return
+    }
+
+    // Try parsing the value as boolean first
+    if (value.toLowerCase(Locale.ROOT) == "true") {
+      builder.appendBoolean(true)
+      return
+    }
+    if (value.toLowerCase(Locale.ROOT) == "false") {
+      builder.appendBoolean(false)
+      return
+    }
+
+    // Try parsing the value as a long
+    allCatch opt value.toLong match {
+      case Some(l) =>
+        builder.appendLong(l)
+        return
+      case _ =>
+    }
+
+    // Try parsing the value as decimal
+    val decimalParser = ExprUtils.getDecimalParser(options.locale)
+    allCatch opt decimalParser(value) match {
+      case Some(decimalValue) =>
+        var d = decimalValue
+        if (d.scale() < 0) {
+          d = d.setScale(0)
+        }
+        if (d.scale <= VariantUtil.MAX_DECIMAL16_PRECISION &&
+            d.precision <= VariantUtil.MAX_DECIMAL16_PRECISION) {
+          builder.appendDecimal(d)
+          return
+        }
+      case _ =>
+    }
+
+    // If the character is of other primitive types, parse it as a string
+    builder.appendString(value)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 2fb25478e529..132bb1e35947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -165,6 +165,11 @@ class XmlOptions(
   val charset = parameters.getOrElse(ENCODING,
     parameters.getOrElse(CHARSET, XmlOptions.DEFAULT_CHARSET))
 
+  // This option takes in a column name and specifies that the entire XML record should be stored
+  // as a single VARIANT type column in the table with the given column name.
+  // E.g. spark.read.format("xml").option("singleVariantColumn", "colName")
+  val singleVariantColumn = parameters.get(SINGLE_VARIANT_COLUMN)
+
   def buildXmlFactory(): XMLInputFactory = {
     XMLInputFactory.newInstance()
   }
@@ -208,6 +213,7 @@ object XmlOptions extends DataSourceOptions {
   val INDENT = newOption("indent")
   val PREFERS_DECIMAL = newOption("prefersDecimal")
   val VALIDATE_NAME = newOption("validateName")
+  val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn")
   // Options with alternative
   val ENCODING = "encoding"
   val CHARSET = "charset"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
index fafde89001aa..23bca3572539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType, VariantType}
 import org.apache.spark.util.Utils
 
 /**
@@ -67,10 +67,14 @@ abstract class XmlDataSource extends Serializable with Logging {
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: XmlOptions): Option[StructType] = {
-    if (inputPaths.nonEmpty) {
-      Some(infer(sparkSession, inputPaths, parsedOptions))
-    } else {
-      None
+    parsedOptions.singleVariantColumn match {
+      case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
+      case None =>
+        if (inputPaths.nonEmpty) {
+          Some(infer(sparkSession, inputPaths, parsedOptions))
+        } else {
+          None
+        }
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 2448f43d651d..eb647c41d0d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -137,7 +137,7 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat]
 
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: VariantType => false
+    case _: VariantType => true
 
     case _: TimeType => false
     case _: AtomicType => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
index f9d003572a22..afb0ceac5b50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
@@ -126,10 +126,10 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession {
       exception = intercept[AnalysisException] {
         Seq("1").toDS().select(from_xml($"value", lit("ARRAY<int>"), Map[String, String]().asJava))
       },
-      condition = "INVALID_SCHEMA.NON_STRUCT_TYPE",
+      condition = "DATATYPE_MISMATCH.INVALID_XML_SCHEMA",
       parameters = Map(
-        "inputSchema" -> "\"ARRAY<int>\"",
-        "dataType" -> "\"ARRAY<INT>\""
+        "schema" -> "\"ARRAY<INT>\"",
+        "sqlExpr" -> "\"from_xml(value)\""
       ),
       context = ExpectedContext(fragment = "from_xml", getCurrentClassCallSitePattern)
     )
@@ -138,10 +138,10 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession {
       exception = intercept[AnalysisException] {
         Seq("1").toDF("xml").selectExpr(s"from_xml(xml, 'ARRAY<int>')")
       },
-      condition = "INVALID_SCHEMA.NON_STRUCT_TYPE",
+      condition = "DATATYPE_MISMATCH.INVALID_XML_SCHEMA",
       parameters = Map(
-        "inputSchema" -> "\"ARRAY<int>\"",
-        "dataType" -> "\"ARRAY<INT>\""
+        "schema" -> "\"ARRAY<INT>\"",
+        "sqlExpr" -> "\"from_xml(xml)\""
       ),
       context = ExpectedContext(
         fragment = "from_xml(xml, 'ARRAY<int>')",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 560292b263ba..5c4f4a96aee1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -3371,6 +3371,34 @@ class XmlSuite
     val outputString = writer.toString
     assert(outputString == testString)
   }
+
+  test("from_xml only allow StructType or VariantType") {
+    val xmlData = "<a>1</a>"
+    val df = Seq((8, xmlData)).toDF("number", "payload")
+
+    Seq(
+      "array<string>",
+      "map<string, string>"
+    ).foreach { schema =>
+      val exception = intercept[AnalysisException](
+        df.withColumn(
+            "decoded",
+            from_xml(df.col("payload"), schema, Map[String, String]().asJava)
+          )
+      )
+      assert(exception.getCondition.contains("INVALID_XML_SCHEMA"))
+    }
+
+    Seq(
+      "struct<a:string>",
+      "variant"
+    ).foreach { schema =>
+      df.withColumn(
+        "decoded",
+        from_xml(df.col("payload"), schema, Map[String, String]().asJava)
+      )
+    }
+  }
 }
 
 // Mock file system that checks the number of open files
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala
new file mode 100644
index 000000000000..e43ef005ecb8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala
@@ -0,0 +1,586 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.xml
+
+import java.time.ZoneOffset
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions}
+import org.apache.spark.sql.functions.{col, variant_get}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData {
+
+  private val baseOptions = Map("rowTag" -> "ROW", "valueTag" -> "_VALUE", "attributePrefix" -> "_")
+
+  private val resDir = "test-data/xml-resources/"
+
+  // ==========================
+  // ====== Parser tests ======
+  // ==========================
+
+  private def testParser(
+      xml: String,
+      expectedJsonStr: String,
+      extraOptions: Map[String, String] = Map.empty): Unit = {
+    val parsed = StaxXmlParser.parseVariant(xml, XmlOptions(baseOptions ++ extraOptions))
+    assert(parsed.toJson(ZoneOffset.UTC) == expectedJsonStr)
+  }
+
+  test("Parser: parse primitive XML elements (long, decimal, double, etc.) as variants") {
+    // Boolean -> Boolean
+    testParser("<ROW><isActive>false</isActive></ROW>", """{"isActive":false}""")
+    testParser("<ROW><isActive>true</isActive></ROW>", """{"isActive":true}""")
+
+    // Long -> Long
+    testParser("<ROW><id>2</id></ROW>", """{"id":2}""")
+    testParser("<ROW><id>+2</id></ROW>", """{"id":2}""")
+    testParser("<ROW><id>-2</id></ROW>", """{"id":-2}""")
+
+    // Decimal -> Decimal
+    testParser(
+      xml = "<ROW><price>158,058,049.001</price></ROW>",
+      expectedJsonStr = """{"price":158058049.001}"""
+    )
+    testParser(
+      xml = "<ROW><decimal>10.05</decimal></ROW>",
+      expectedJsonStr = """{"decimal":10.05}"""
+    )
+    testParser(
+      xml = "<ROW><amount>5.0</amount></ROW>",
+      expectedJsonStr = """{"amount":5}"""
+    )
+    // This is parsed as String, because it is too large for Decimal
+    testParser(
+      xml = "<ROW><amount>1e40</amount></ROW>",
+      expectedJsonStr = """{"amount":"1e40"}"""
+    )
+
+    // Date -> String
+    testParser(
+      xml = "<ROW><createdAt>2023-10-01</createdAt></ROW>",
+      expectedJsonStr = """{"createdAt":"2023-10-01"}"""
+    )
+
+    // Timestamp -> String
+    testParser(
+      xml = "<ROW><createdAt>2023-10-01T12:00:00Z</createdAt></ROW>",
+      expectedJsonStr = """{"createdAt":"2023-10-01T12:00:00Z"}"""
+    )
+
+    // String -> String
+    testParser("<ROW><name>Sam</name></ROW>", """{"name":"Sam"}""")
+    // Strings with spaces
+    testParser(
+      "<ROW><note>  hello world  </note></ROW>",
+      expectedJsonStr = """{"note":"  hello world  "}""",
+      extraOptions = Map("ignoreSurroundingSpaces" -> "false")
+    )
+    testParser(
+      xml = "<ROW><note>  hello world  </note></ROW>",
+      expectedJsonStr = """{"note":"hello world"}"""
+    )
+  }
+
+  test("Parser: parse XML attributes as variants") {
+    // XML elements with only attributes
+    testParser(
+      xml = "<ROW id=\"2\"></ROW>",
+      expectedJsonStr = """{"_id":2}"""
+    )
+    testParser(
+      xml = "<ROW><a><b attr=\"1\"></b></a></ROW>",
+      expectedJsonStr = """{"a":{"b":{"_attr":1}}}"""
+    )
+    testParser(
+      xml = "<ROW id=\"2\" name=\"Sam\" amount=\"93\"></ROW>",
+      expectedJsonStr = """{"_amount":93,"_id":2,"_name":"Sam"}"""
+    )
+
+    // XML elements with attributes and elements
+    testParser(
+      xml = "<ROW id=\"2\" name=\"Sam\"><amount>93</amount></ROW>",
+      expectedJsonStr = """{"_id":2,"_name":"Sam","amount":93}"""
+    )
+
+    // XML elements with attributes and nested elements
+    testParser(
+      xml = "<ROW id=\"2\" name=\"Sam\"><info><amount>93</amount></info></ROW>",
+      expectedJsonStr = """{"_id":2,"_name":"Sam","info":{"amount":93}}"""
+    )
+
+    // XML elements with attributes and value tag
+    testParser(
+      xml = "<ROW id=\"2\" name=\"Sam\">93</ROW>",
+      expectedJsonStr = """{"_VALUE":93,"_id":2,"_name":"Sam"}"""
+    )
+  }
+
+  test("Parser: parse XML value tags as variants") {
+    // XML elements with value tags and attributes
+    testParser(
+      xml = "<ROW id=\"2\" name=\"Sam\">93</ROW>",
+      expectedJsonStr = """{"_VALUE":93,"_id":2,"_name":"Sam"}"""
+    )
+
+    // XML elements with value tags and nested elements
+    testParser(
+      xml = "<ROW><info>Sam<amount>93</amount></info></ROW>",
+      expectedJsonStr = """{"info":{"_VALUE":"Sam","amount":93}}"""
+    )
+  }
+
+  test("Parser: parse XML elements as variant object") {
+    testParser(
+      xml = "<ROW><info><name>Sam</name><amount>93</amount></info></ROW>",
+      expectedJsonStr = """{"info":{"amount":93,"name":"Sam"}}"""
+    )
+  }
+
+  test("Parser: parse XML elements as variant array") {
+    testParser(
+      xml = "<ROW><array>1</array><array>2</array></ROW>",
+      expectedJsonStr = """{"array":[1,2]}"""
+    )
+  }
+
+  test("Parser: null and empty XML elements are parsed as variant null") {
+    // XML elements with null and empty values
+    testParser(
+      xml = """<ROW><name></name><amount>93</amount><space> </space><newline>
+                      </newline></ROW>""",
+      expectedJsonStr = """{"amount":93,"name":null,"newline":null,"space":null}"""
+    )
+    testParser(
+      xml = "<ROW><name>Sam</name><amount>n/a</amount></ROW>",
+      expectedJsonStr = """{"amount":null,"name":"Sam"}""",
+      extraOptions = Map("nullValue" -> "n/a")
+    )
+  }
+
+  test("Parser: Parse whitespaces with quotes") {
+    // XML elements with whitespaces
+    testParser(
+      xml = s"""
+               |<ROW>
+               |  <a>" "</a>
+               |  <b>" "<c>1</c></b>
+               |  <d><e attr=" "></e></d>
+               |</ROW>
+               |""".stripMargin,
+      expectedJsonStr = """{"a":"\" \"","b":{"_VALUE":"\" 
\"","c":1},"d":{"e":{"_attr":" "}}}""",
+      extraOptions = Map("ignoreSurroundingSpaces" -> "false")
+    )
+  }
+
+  test("Parser: Comments are ignored") {
+    testParser(
+      xml = """
+              |<ROW>
+              |   <!-- comment -->
+              |   <name><!-- before value --> Sam <!-- after value --></name>
+              |   <!-- comment -->
+              |   <amount>93</amount>
+              |   <!-- <a>1</a> -->
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"amount":93,"name":"Sam"}"""
+    )
+  }
+
+  test("Parser: CDATA should be handled properly") {
+    testParser(
+      xml = """
+              |<!-- CDATA outside row should be ignored -->
+              |<ROW>
+              |   <name><![CDATA[Sam]]></name>
+              |   <amount>93</amount>
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"amount":93,"name":"Sam"}"""
+    )
+  }
+
+  test("Parser: parse mixed types as variants") {
+    val expectedJsonStr =
+      """
+        |{
+        |   "arrayOfArray1":[
+        |     {"item":[1,2,3]},
+        |     {"item":["str1","str2"]}
+        |   ],
+        |   "arrayOfArray2":[
+        |     {"item":[1,2,3]},
+        |     {"item":[1.1,2.1,3.1]}
+        |   ],
+        |   "arrayOfBigInteger":[922337203685477580700,-922337203685477580800],
+        |   "arrayOfBoolean":[true,false,true],
+        |   
"arrayOfDouble":[1.2,1.7976931348623157,"4.9E-324","2.2250738585072014E-308"],
+        |   "arrayOfInteger":[1,2147483647,-2147483648],
+        |   
"arrayOfLong":[21474836470,9223372036854775807,-9223372036854775808],
+        |   "arrayOfNull":[null,null],
+        |   "arrayOfString":["str1","str2"],
+        |   "arrayOfStruct":[
+        |     {"field1":true,"field2":"str1"},
+        |     {"field1":false},
+        |     {"field3":null}
+        |   ],
+        |   "struct":{
+        |     "field1":true,
+        |     "field2":92233720368547758070
+        |   },
+        |   "structWithArrayFields":{
+        |     "field1":[4,5,6],
+        |     "field2":["str1","str2"]
+        |   }
+        |}
+        |""".stripMargin.replaceAll("\\s+", "")
+    testParser(
+      xml = complexFieldAndType1.head,
+      expectedJsonStr = expectedJsonStr
+    )
+
+    val expectedJsonStr2 =
+      """
+        |{
+        |   "arrayOfArray1":[
+        |     {"array":{"item":5}},
+        |     {
+        |       "array":[
+        |         {"item":[6,7]},
+        |         {"item":8}
+        |       ]
+        |     }
+        |   ],
+        |   "arrayOfArray2":[
+        |     {"array":{"item":{"inner1":"str1"}}},
+        |     {
+        |       "array":[
+        |         null,
+        |         {
+        |           "item":[
+        |             {"inner2":["str3","str33"]},
+        |             {"inner1":"str11","inner2":"str4"}
+        |           ]
+        |         }
+        |       ]
+        |     },
+        |     {
+        |       "array":{
+        |         "item":{
+        |           "inner3":[
+        |             {"inner4":[2,3]},
+        |             null
+        |           ]
+        |         }
+        |       }
+        |     }
+        |   ]
+        |}
+        """.stripMargin.replaceAll("\\s+", "")
+    testParser(
+      xml = complexFieldAndType2.head,
+      expectedJsonStr = expectedJsonStr2
+    )
+  }
+
+  test("Parser: Case sensitivity test") {
+    val xmlString =
+      """
+        |<ROW>
+        |   <a>1<b>2</b></a>
+        |   <A>3<b>4</b></A>
+        |</ROW>
+        |""".stripMargin
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      testParser(
+        xml = xmlString,
+        expectedJsonStr = """{"a":[{"_VALUE":1,"b":2},{"_VALUE":3,"b":4}]}"""
+      )
+    }
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      testParser(
+        xml = xmlString,
+        expectedJsonStr = """{"A":{"_VALUE":3,"b":4},"a":{"_VALUE":1,"b":2}}"""
+      )
+    }
+  }
+
+  test("Parser: XML array elements interspersed between other elements") {
+    testParser(
+      xml = """
+              |<ROW>
+              |   <a>1</a>
+              |   <b>2</b>
+              |   <a>3</a>
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"a":[1,3],"b":2}"""
+    )
+
+    testParser(
+      xml = """
+              |<ROW>
+              |   value1
+              |   <a>1</a>
+              |   value2
+              |   <a>2</a>
+              |   value3
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"_VALUE":["value1","value2","value3"],"a":[1,2]}"""
+    )
+
+    // long and double
+    testParser(
+      xml = """
+              |<ROW>
+              |  <a>
+              |    1
+              |    <b>2</b>
+              |    3
+              |    <b>4</b>
+              |    5.0
+              |  </a>
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"a":{"_VALUE":[1,3,5],"b":[2,4]}}"""
+    )
+
+    // Comments
+    testParser(
+      xml = """
+              |<ROW>
+              |   <!-- comment -->
+              |   <a>1</a>
+              |   <!-- comment -->
+              |   <b>2</b>
+              |   <!-- comment -->
+              |   <a>3</a>
+              |</ROW>
+              |""".stripMargin,
+      expectedJsonStr = """{"a":[1,3],"b":2}"""
+    )
+  }
+
+  // =======================
+  // ====== DSL tests ======
+  // =======================
+
+  private def createDSLDataFrame(
+      fileName: String,
+      singleVariantColumn: Option[String] = None,
+      schemaDDL: Option[String] = None,
+      extraOptions: Map[String, String] = Map.empty): DataFrame = {
+    assert(
+      singleVariantColumn.isDefined || schemaDDL.isDefined,
+      "Either singleVariantColumn or schema must be defined to ingest XML 
files as variants via DSL"
+    )
+    var reader = spark.read.format("xml").options(baseOptions ++ extraOptions)
+    singleVariantColumn.foreach(
+      singleVariantColumnName =>
+        reader = reader.option("singleVariantColumn", singleVariantColumnName)
+    )
+    schemaDDL.foreach(s => reader = reader.schema(s))
+
+    reader.load(getTestResourcePath(resDir + fileName))
+  }
+
+  test("DSL: read XML files using singleVariantColumn") {
+    val df = createDSLDataFrame(fileName = "cars.xml", singleVariantColumn = 
Some("var"))
+    checkAnswer(
+      df.select(variant_get(col("var"), "$.year", "int")),
+      Seq(Row(2012), Row(1997), Row(2015))
+    )
+  }
+
+  test("DSL: read XML files with defined schema") {
+    val df = createDSLDataFrame(
+      fileName = "books-complicated.xml",
+      schemaDDL = Some(
+        "_id string, author string, title string, genre variant, price double, 
" +
+          "publish_dates variant"
+      ),
+      extraOptions = Map("rowTag" -> "book")
+    )
+    checkAnswer(
+      df.select(variant_get(col("genre"), "$.name", "string")),
+      Seq(Row("Computer"), Row("Fantasy"), Row("Fantasy"))
+    )
+  }
+
+  test("DSL: provided schema in singleVariantColumn mode") {
+    // Specified schema in singleVariantColumn mode can't contain columns 
other than the variant
+    // column and the corrupted record column
+    checkError(
+      exception = intercept[AnalysisException] {
+        createDSLDataFrame(
+          fileName = "cars.xml",
+          singleVariantColumn = Some("var"),
+          schemaDDL = Some("year variant, make string, model string, comment 
string")
+        )
+      },
+      condition = "INVALID_SINGLE_VARIANT_COLUMN",
+      parameters = Map(
+        "schema" -> """"STRUCT<year: VARIANT, make: STRING, model: STRING, 
comment: STRING>""""
+      )
+    )
+    checkError(
+      exception = intercept[AnalysisException] {
+        createDSLDataFrame(
+          fileName = "cars.xml",
+          singleVariantColumn = Some("var"),
+          schemaDDL = Some("_corrupt_record string")
+        )
+      },
+      condition = "INVALID_SINGLE_VARIANT_COLUMN",
+      parameters = Map(
+        "schema" -> """"STRUCT<_corrupt_record: STRING>""""
+      )
+    )
+
+    // Valid schema in singleVariantColumn mode
+    createDSLDataFrame(
+      fileName = "cars.xml",
+      singleVariantColumn = Some("var"),
+      schemaDDL = Some("var variant")
+    )
+    createDSLDataFrame(
+      fileName = "cars.xml",
+      singleVariantColumn = Some("var"),
+      schemaDDL = Some("var variant, _corrupt_record string")
+    )
+  }
+
+  test("DSL: handle malformed record in singleVariantColumn mode") {
+    // FAILFAST mode
+    checkError(
+      exception = intercept[SparkException] {
+        createDSLDataFrame(
+          fileName = "cars-malformed.xml",
+          singleVariantColumn = Some("var"),
+          extraOptions = Map("mode" -> "FAILFAST")
+        ).collect()
+      }.getCause.asInstanceOf[SparkException],
+      condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+      parameters = Map(
+        "badRecord" -> "[null]",
+        "failFastMode" -> "FAILFAST")
+    )
+
+    // PERMISSIVE mode
+    val df = createDSLDataFrame(
+      fileName = "cars-malformed.xml",
+      singleVariantColumn = Some("var"),
+      extraOptions = Map("mode" -> "PERMISSIVE")
+    )
+    checkAnswer(
+      df.select(variant_get(col("var"), "$.year", "int")),
+      Seq(Row(2015), Row(null), Row(null))
+    )
+
+    // DROPMALFORMED mode
+    val df2 = createDSLDataFrame(
+      fileName = "cars-malformed.xml",
+      singleVariantColumn = Some("var"),
+      extraOptions = Map("mode" -> "DROPMALFORMED")
+    )
+    checkAnswer(
+      df2.select(variant_get(col("var"), "$.year", "int")),
+      Seq(Row(2015))
+    )
+  }
+
+  test("DSL: test XSD validation") {
+    val df = createDSLDataFrame(
+      fileName = "basket_invalid.xml",
+      singleVariantColumn = Some("var"),
+      extraOptions = Map(
+        "rowTag" -> "basket",
+        "rowValidationXSDPath" -> getTestResourcePath(resDir + 
"basket.xsd").replace("file:/", "/")
+      )
+    )
+    checkAnswer(
+      df.select(variant_get(col("var"), "$", "string")),
+      Seq(
+        // The first row matches the XSD and thus is parsed as Variant 
successfully
+        
Row("""{"entry":[{"key":1,"value":"fork"},{"key":2,"value":"cup"}]}"""),
+        // The second row fails the XSD validation and is not parsed
+        Row(null)
+      )
+    )
+  }
+
+  // =======================
+  // ====== SQL tests ======
+  // =======================
+
+  test("SQL: read an entire XML record as variant using from_xml SQL 
expression") {
+    val xmlStr =
+      """
+        |<ROW>
+        |    <year>2012<!--A comment within tags--></year>
+        |    <make>Tesla</make>
+        |    <model>S</model>
+        |    <comment>No comment</comment>
+        |</ROW>
+        |""".stripMargin
+
+    // Read the entire XML record as a single variant
+    // Verify we can extract fields from the variant type
+    checkAnswer(
+      spark
+        .sql(s"""SELECT from_xml('$xmlStr', 'variant') as var""")
+        .select(variant_get(col("var"), "$.year", "int")),
+      Seq(Row(2012))
+    )
+  }
+
+  test("SQL: read partial XML record as variant using from_xml with a defined 
schema") {
+    val xmlStr =
+      """
+        |<book>
+        |   <author>Gambardella</author>
+        |   <title>Hello</title>
+        |   <genre>
+        |     <genreid>1</genreid>
+        |     <name>Computer</name>
+        |   </genre>
+        |   <price>44.95</price>
+        |   <publish_dates>
+        |     <publish_date>
+        |       <day>1</day>
+        |       <month>10</month>
+        |       <year>2000</year>
+        |     </publish_date>
+        |   </publish_dates>
+        | </book>
+        | """.stripMargin.replaceAll("\\s+", "")
+    // Read specific elements in the XML record as variant
+    val schemaDDL =
+      "author string, title string, genre variant, price double, publish_dates 
variant"
+    // Verify we can extract fields from the variant type
+    checkAnswer(
+      spark
+        .sql(s"""SELECT from_xml('$xmlStr', '$schemaDDL') as 
book""".stripMargin)
+        .select(variant_get(col("book.publish_dates"), "$.publish_date.year", 
"int")),
+      Seq(Row(2000))
+    )
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
