This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 104dc2b708c4 [SPARK-51503][SQL] Support Variant type in XML scan
104dc2b708c4 is described below

commit 104dc2b708c44bfa948f9bef3a0bc3b27bdcd260
Author: Xiaonan Yang <xiaonan.y...@databricks.com>
AuthorDate: Fri Apr 11 08:10:12 2025 +0900

    [SPARK-51503][SQL] Support Variant type in XML scan

    ### What changes were proposed in this pull request?

    This PR introduces the capability of reading XML data as the Variant type. It adds Variant support to both the XML scan and the `from_xml` SQL expression. Writing variant values to XML will be supported in a subsequent PR.

    Reading XML as Variant follows similar rules to reading XML as structs and uses the same options. There are two modes of reading XML data as Variants:

    - The first mode ingests an entire XML file as a variant via the new `singleVariantColumn` option. The option specifies the name of the single output column of variant type. Once the option is set, the schema inference stage is skipped and a single variant column is used as the schema of the XML scan. In the parsing stage, each XML record enclosed by the `rowTag` element is converted to a single variant object, and its child elements are recursively parsed as variant values.
    - The second mode specifies individual columns as variant type through a user-provided schema. In this mode, fields declared as variant are parsed as variants, while all other fields are parsed as non-variant values following the existing behavior.

    (A short usage sketch of both modes appears after the first diff hunk below.)

    ### Why are the changes needed?

    Allow users to ingest XML data as Variants.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    New unit tests are added.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50300 from xiaonanyang-db/SPARK-51503.

    Authored-by: Xiaonan Yang <xiaonan.y...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |   5 +
 .../expressions/xml/XmlExpressionEvalUtils.scala    |  52 +-
 .../sql/catalyst/expressions/xmlExpressions.scala   |  49 +-
 .../spark/sql/catalyst/xml/StaxXmlParser.scala      | 263 ++++++++-
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala  |   6 +
 .../execution/datasources/xml/XmlDataSource.scala   |  14 +-
 .../execution/datasources/xml/XmlFileFormat.scala   |   2 +-
 .../org/apache/spark/sql/XmlFunctionsSuite.scala    |  12 +-
 .../sql/execution/datasources/xml/XmlSuite.scala    |  28 +
 .../datasources/xml/XmlVariantSuite.scala           | 586 +++++++++++++++++++++
 10 files changed, 967 insertions(+), 50 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index f098e27555b1..37cc80a12394 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -994,6 +994,11 @@
         "Input schema <schema> can only contain STRING as a key type for a MAP."
       ]
     },
+    "INVALID_XML_SCHEMA" : {
+      "message" : [
+        "Input schema <schema> must be a struct or a variant."
+      ]
+    },
     "IN_SUBQUERY_DATA_TYPE_MISMATCH" : {
       "message" : [
         "The data type of one or more elements in the left hand side of an IN subquery is not compatible with the data type of the output of the subquery. Mismatched columns: [<mismatchedColumns>], left side: [<leftType>], right side: [<rightType>]."
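For illustration, here is a minimal usage sketch of the two modes described in the commit message above. It assumes a Spark build that includes this patch and an active `spark` session (e.g. `spark-shell`); the file path, row tag, schema, and column names are hypothetical and are not part of the change itself.

```scala
import org.apache.spark.sql.functions.{col, variant_get}

// Mode 1: ingest each `rowTag` record as a single VARIANT column.
// `singleVariantColumn` names the output column; schema inference is skipped.
val single = spark.read
  .format("xml")
  .option("rowTag", "ROW")
  .option("singleVariantColumn", "var")
  .load("/tmp/records.xml")                 // hypothetical input path
single.select(variant_get(col("var"), "$.year", "int")).show()

// Mode 2: declare only some fields as VARIANT in a user-provided schema;
// the remaining fields are parsed as non-variant values as before.
val partial = spark.read
  .format("xml")
  .option("rowTag", "ROW")
  .schema("id LONG, payload VARIANT")       // hypothetical schema
  .load("/tmp/records.xml")
partial.select(variant_get(col("payload"), "$.name", "string")).show()

// The `from_xml` SQL expression accepts a 'variant' target type the same way.
spark.sql("SELECT from_xml('<ROW><year>2012</year></ROW>', 'variant') AS v")
  .select(variant_get(col("v"), "$.year", "int"))
  .show()
```

`variant_get` is used here only to show that fields can be extracted from the resulting variant values, mirroring the new tests added in this patch.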
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala index 89d7b8d9421a..5a2cd8ad76ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.xml -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.catalyst.xml.XmlInferSchema +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} +import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -119,3 +122,48 @@ case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator { } } } + +case class XmlToStructsEvaluator( + options: Map[String, String], + nullableSchema: DataType, + nameOfCorruptRecord: String, + timeZoneId: Option[String], + child: Expression +) { + @transient lazy val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord) + + // This converts parsed rows to the desired output by the given schema. + @transient + private lazy val converter = + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null + + // Parser that parse XML strings as internal rows + @transient + private lazy val parser = { + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode) + } + + // The parser is only used when the input schema is StructType + val schema = nullableSchema.asInstanceOf[StructType] + ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val rawParser = new StaxXmlParser(schema, parsedOptions) + + val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema) + + new FailureSafeParser[String]( + input => rawParser.doParseColumn(input, mode, xsdSchema), + mode, + schema, + parsedOptions.columnNameOfCorruptRecord) + } + + final def evaluate(xml: UTF8String): Any = { + if (xml == null) return null + nullableSchema match { + case _: VariantType => StaxXmlParser.parseVariant(xml.toString, parsedOptions) + case _: StructType => converter(parser.parse(xml.toString)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 25a054f79c36..d6e3eea579d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import 
org.apache.spark.sql.catalyst.expressions.xml.XmlExpressionEvalUtils -import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, FailureSafeParser, PermissiveMode} +import org.apache.spark.sql.catalyst.expressions.xml.{XmlExpressionEvalUtils, XmlToStructsEvaluator} +import org.apache.spark.sql.catalyst.util.DropMalformedMode import org.apache.spark.sql.catalyst.util.TypeUtils._ -import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} +import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, XmlInferSchema, XmlOptions} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation @@ -54,7 +54,7 @@ import org.apache.spark.unsafe.types.UTF8String since = "4.0.0") // scalastyle:on line.size.limit case class XmlToStructs( - schema: StructType, + schema: DataType, options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) @@ -65,7 +65,7 @@ case class XmlToStructs( def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = ExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = options, child = child, timeZoneId = None) @@ -81,45 +81,34 @@ case class XmlToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( - schema = ExprUtils.evalSchemaExpr(schema), + schema = ExprUtils.evalTypeExpr(schema), options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) - // This converts parsed rows to the desired output by the given schema. + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { + case _: StructType | _: VariantType => + val checkResult = ExprUtils.checkXmlSchema(nullableSchema) + if (checkResult.isFailure) checkResult else super.checkInputDataTypes() + case _ => + DataTypeMismatch( + errorSubClass = "INVALID_XML_SCHEMA", + messageParameters = Map("schema" -> toSQLType(nullableSchema))) + } + @transient - private lazy val converter = - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null + private lazy val evaluator: XmlToStructsEvaluator = + XmlToStructsEvaluator(options, nullableSchema, nameOfCorruptRecord, timeZoneId, child) private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - @transient - private lazy val parser = { - val parsedOptions = new XmlOptions(options, timeZoneId.get, nameOfCorruptRecord) - val mode = parsedOptions.parseMode - if (mode != PermissiveMode && mode != FailFastMode) { - throw QueryCompilationErrors.parseModeUnsupportedError("from_xml", mode) - } - ExprUtils.verifyColumnNameOfCorruptRecord( - nullableSchema, parsedOptions.columnNameOfCorruptRecord) - val rawParser = new StaxXmlParser(schema, parsedOptions) - val xsdSchema = Option(parsedOptions.rowValidationXSDPath).map(ValidatorUtil.getSchema) - - new FailureSafeParser[String]( - input => rawParser.doParseColumn(input, mode, xsdSchema), - mode, - nullableSchema, - parsedOptions.columnNameOfCorruptRecord) - } - override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def nullSafeEval(xml: Any): Any = - converter(parser.parse(xml.asInstanceOf[UTF8String].toString)) + override def nullSafeEval(xml: Any): Any = evaluator.evaluate(xml.asInstanceOf[UTF8String]) override 
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val expr = ctx.addReferenceObj("this", this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 4b892da9db25..c82571d1f37f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.xml import java.io.{BufferedReader, CharConversionException, FileNotFoundException, InputStream, InputStreamReader, IOException, StringReader} import java.nio.charset.{Charset, MalformedInputException} import java.text.NumberFormat +import java.util import java.util.Locale import javax.xml.stream.{XMLEventReader, XMLStreamException} import javax.xml.stream.events._ @@ -28,6 +29,7 @@ import javax.xml.validation.Schema import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try +import scala.util.control.Exception.allCatch import scala.util.control.NonFatal import scala.xml.SAXException @@ -45,7 +47,10 @@ import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.types.variant.{Variant, VariantBuilder} +import org.apache.spark.types.variant.VariantBuilder.FieldEntry +import org.apache.spark.types.variant.VariantUtil +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} class StaxXmlParser( schema: StructType, @@ -138,11 +143,19 @@ class StaxXmlParser( xsdSchema.foreach { schema => schema.newValidator().validate(new StreamSource(new StringReader(xml))) } - val parser = StaxXmlParserUtils.filteredReader(xml) - val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) - val result = Some(convertObject(parser, schema, rootAttributes)) - parser.close() - result + options.singleVariantColumn match { + case Some(_) => + // If the singleVariantColumn is specified, parse the entire xml string as a Variant + val v = StaxXmlParser.parseVariant(xml, options) + Some(InternalRow(v)) + case _ => + // Otherwise, parse the xml string as Structs + val parser = StaxXmlParserUtils.filteredReader(xml) + val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) + val result = Some(convertObject(parser, schema, rootAttributes)) + parser.close() + result + } } catch { case e: SparkUpgradeException => throw e case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException @@ -379,6 +392,10 @@ class StaxXmlParser( } row(index) = values :+ newValue + case VariantType => + val v = StaxXmlParser.convertVariant(parser, attributes, options) + row(index) = new VariantVal(v.getValue, v.getMetadata) + case dt: DataType => row(index) = convertField(parser, dt, field, attributes) } @@ -897,4 +914,238 @@ object StaxXmlParser { curRecord } } + + /** + * Parse the input XML string as a Variant value + */ + def parseVariant(xml: String, options: XmlOptions): VariantVal = { + val parser = StaxXmlParserUtils.filteredReader(xml) + val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser) + val variant = convertVariant(parser, rootAttributes, options) + val v = new VariantVal(variant.getValue, variant.getMetadata) + parser.close() + v + } + + /** + * Parse 
an XML element from the XML event stream into a Variant. + * This method transforms the XML element along with its attributes and child elements + * into a hierarchical Variant data structure that preserves the XML structure. + * + * @param parser The XML event stream reader positioned after the start element + * @param attributes The attributes of the current XML element to be included in the Variant + * @param options Configuration options that control how XML is parsed into Variants + * @return A Variant representing the XML element with its attributes and child content + */ + def convertVariant( + parser: XMLEventReader, + attributes: Array[Attribute], + options: XmlOptions): Variant = { + // The variant builder for the root startElement + val rootBuilder = new VariantBuilder(false) + val start = rootBuilder.getWritePos + + // Map to store the variant values of all child fields + // Each field could have multiple entries, which means it's an array + // The map is sorted by field name, and the ordering is based on the case sensitivity + val caseSensitivityOrdering: Ordering[String] = if (SQLConf.get.caseSensitiveAnalysis) { + (x: String, y: String) => x.compareTo(y) + } else { + (x: String, y: String) => x.compareToIgnoreCase(y) + } + val fieldToVariants = collection.mutable.TreeMap.empty[String, java.util.ArrayList[Variant]]( + caseSensitivityOrdering + ) + + // Handle attributes first + StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options).foreach { + case (f, v) => + val builder = new VariantBuilder(false) + appendXMLCharacterToVariant(builder, v, options) + val variants = fieldToVariants.getOrElseUpdate(f, new java.util.ArrayList[Variant]()) + variants.add(builder.result()) + } + + var shouldStop = false + while (!shouldStop) { + parser.nextEvent() match { + case s: StartElement => + // For each child element, convert it to a variant and keep track of it in + // fieldsToVariants + val attributes = s.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray + val field = StaxXmlParserUtils.getName(s.asStartElement.getName, options) + val variants = fieldToVariants.getOrElseUpdate(field, new java.util.ArrayList[Variant]()) + variants.add(convertVariant(parser, attributes, options)) + + case c: Characters if !c.isWhiteSpace => + // Treat the character as a value tag field, where we use the [[XMLOptions.valueTag]] as + // the field key + val builder = new VariantBuilder(false) + appendXMLCharacterToVariant(builder, c.getData, options) + val variants = fieldToVariants.getOrElseUpdate( + options.valueTag, + new java.util.ArrayList[Variant]() + ) + variants.add(builder.result()) + + case _: EndElement => + if (fieldToVariants.nonEmpty) { + val onlyValueTagField = fieldToVariants.keySet.forall(_ == options.valueTag) + if (onlyValueTagField) { + // If the element only has value tag field, parse the element as a variant primitive + rootBuilder.appendVariant(fieldToVariants(options.valueTag).get(0)) + } else { + writeVariantObject(rootBuilder, fieldToVariants) + } + } + shouldStop = true + + case _: EndDocument => shouldStop = true + + case _ => // do nothing + } + } + + // If the element is empty, we treat it as a Variant null + if (rootBuilder.getWritePos == start) { + rootBuilder.appendNull() + } + + rootBuilder.result() + } + + /** + * Write a variant object to the variant builder. 
+ * + * @param builder The variant builder to write to + * @param fieldToVariants A map of field names to their corresponding variant values of the object + */ + private def writeVariantObject( + builder: VariantBuilder, + fieldToVariants: collection.mutable.TreeMap[String, java.util.ArrayList[Variant]]): Unit = { + val start = builder.getWritePos + val objectFieldEntries = new java.util.ArrayList[FieldEntry]() + + val (lastFieldKey, lastFieldValue) = + fieldToVariants.tail.foldLeft(fieldToVariants.head._1, fieldToVariants.head._2) { + case ((key, variantVals), (k, v)) => + if (!SQLConf.get.caseSensitiveAnalysis && k.equalsIgnoreCase(key)) { + variantVals.addAll(v) + (key, variantVals) + } else { + writeVariantObjectField(key, variantVals, builder, start, objectFieldEntries) + (k, v) + } + } + + writeVariantObjectField(lastFieldKey, lastFieldValue, builder, start, objectFieldEntries) + + // Finish writing the variant object + builder.finishWritingObject(start, objectFieldEntries) + } + + /** + * Write a single field to a variant object + * + * @param fieldName the name of the object field + * @param fieldVariants the variant value of the field. A field could have multiple variant value, + * which means it's an array field + * @param builder the variant builder + * @param objectStart the start position of the variant object in the builder + * @param objectFieldEntries a list tracking all fields of the variant object + */ + private def writeVariantObjectField( + fieldName: String, + fieldVariants: java.util.ArrayList[Variant], + builder: VariantBuilder, + objectStart: Int, + objectFieldEntries: java.util.ArrayList[FieldEntry]): Unit = { + val start = builder.getWritePos + val fieldId = builder.addKey(fieldName) + objectFieldEntries.add( + new FieldEntry(fieldName, fieldId, builder.getWritePos - objectStart) + ) + + val fieldValue = if (fieldVariants.size() > 1) { + // If the field has more than one entry, it's an array field. Build a Variant + // array as the field value + val arrayBuilder = new VariantBuilder(false) + val arrayStart = arrayBuilder.getWritePos + val offsets = new util.ArrayList[Integer]() + fieldVariants.asScala.foreach { v => + offsets.add(arrayBuilder.getWritePos - arrayStart) + arrayBuilder.appendVariant(v) + } + arrayBuilder.finishWritingArray(arrayStart, offsets) + arrayBuilder.result() + } else { + // Otherwise, just use the first variant as the field value + fieldVariants.get(0) + } + + // Append the field value to the variant builder + builder.appendVariant(fieldValue) + } + + /** + * Convert an XML Character value `s` into a variant value and append the result to `builder`. + * The result can only be one of a variant boolean/long/decimal/string. Anything other than + * the supported types will be appended to the Variant builder as a string. + * + * Floating point types (double, float) are not considered to avoid precision loss. 
+ */ + private def appendXMLCharacterToVariant( + builder: VariantBuilder, + s: String, + options: XmlOptions): Unit = { + if (s == null || s == options.nullValue) { + builder.appendNull() + return + } + + val value = if (options.ignoreSurroundingSpaces) s.trim() else s + + // Exit early for empty strings + if (value.isEmpty) { + builder.appendString(value) + return + } + + // Try parsing the value as boolean first + if (value.toLowerCase(Locale.ROOT) == "true") { + builder.appendBoolean(true) + return + } + if (value.toLowerCase(Locale.ROOT) == "false") { + builder.appendBoolean(false) + return + } + + // Try parsing the value as a long + allCatch opt value.toLong match { + case Some(l) => + builder.appendLong(l) + return + case _ => + } + + // Try parsing the value as decimal + val decimalParser = ExprUtils.getDecimalParser(options.locale) + allCatch opt decimalParser(value) match { + case Some(decimalValue) => + var d = decimalValue + if (d.scale() < 0) { + d = d.setScale(0) + } + if (d.scale <= VariantUtil.MAX_DECIMAL16_PRECISION && + d.precision <= VariantUtil.MAX_DECIMAL16_PRECISION) { + builder.appendDecimal(d) + return + } + case _ => + } + + // If the character is of other primitive types, parse it as a string + builder.appendString(value) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index 2fb25478e529..132bb1e35947 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -165,6 +165,11 @@ class XmlOptions( val charset = parameters.getOrElse(ENCODING, parameters.getOrElse(CHARSET, XmlOptions.DEFAULT_CHARSET)) + // This option takes in a column name and specifies that the entire XML record should be stored + // as a single VARIANT type column in the table with the given column name. + // E.g. 
spark.read.format("xml").option("singleVariantColumn", "colName") + val singleVariantColumn = parameters.get(SINGLE_VARIANT_COLUMN) + def buildXmlFactory(): XMLInputFactory = { XMLInputFactory.newInstance() } @@ -208,6 +213,7 @@ object XmlOptions extends DataSourceOptions { val INDENT = newOption("indent") val PREFERS_DECIMAL = newOption("prefersDecimal") val VALIDATE_NAME = newOption("validateName") + val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn") // Options with alternative val ENCODING = "encoding" val CHARSET = "charset" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index fafde89001aa..23bca3572539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType, VariantType} import org.apache.spark.util.Utils /** @@ -67,10 +67,14 @@ abstract class XmlDataSource extends Serializable with Logging { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: XmlOptions): Option[StructType] = { - if (inputPaths.nonEmpty) { - Some(infer(sparkSession, inputPaths, parsedOptions)) - } else { - None + parsedOptions.singleVariantColumn match { + case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType)))) + case None => + if (inputPaths.nonEmpty) { + Some(infer(sparkSession, inputPaths, parsedOptions)) + } else { + None + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala index 2448f43d651d..eb647c41d0d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala @@ -137,7 +137,7 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { - case _: VariantType => false + case _: VariantType => true case _: TimeType => false case _: AtomicType => true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala index f9d003572a22..afb0ceac5b50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala @@ -126,10 +126,10 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDS().select(from_xml($"value", lit("ARRAY<int>"), Map[String, String]().asJava)) }, - condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_XML_SCHEMA", parameters = Map( - "inputSchema" -> "\"ARRAY<int>\"", - "dataType" -> "\"ARRAY<INT>\"" + "schema" -> "\"ARRAY<INT>\"", + "sqlExpr" -> 
"\"from_xml(value)\"" ), context = ExpectedContext(fragment = "from_xml", getCurrentClassCallSitePattern) ) @@ -138,10 +138,10 @@ class XmlFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { Seq("1").toDF("xml").selectExpr(s"from_xml(xml, 'ARRAY<int>')") }, - condition = "INVALID_SCHEMA.NON_STRUCT_TYPE", + condition = "DATATYPE_MISMATCH.INVALID_XML_SCHEMA", parameters = Map( - "inputSchema" -> "\"ARRAY<int>\"", - "dataType" -> "\"ARRAY<INT>\"" + "schema" -> "\"ARRAY<INT>\"", + "sqlExpr" -> "\"from_xml(xml)\"" ), context = ExpectedContext( fragment = "from_xml(xml, 'ARRAY<int>')", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 560292b263ba..5c4f4a96aee1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -3371,6 +3371,34 @@ class XmlSuite val outputString = writer.toString assert(outputString == testString) } + + test("from_xml only allow StructType or VariantType") { + val xmlData = "<a>1</a>" + val df = Seq((8, xmlData)).toDF("number", "payload") + + Seq( + "array<string>", + "map<string, string>" + ).foreach { schema => + val exception = intercept[AnalysisException]( + df.withColumn( + "decoded", + from_xml(df.col("payload"), schema, Map[String, String]().asJava) + ) + ) + assert(exception.getCondition.contains("INVALID_XML_SCHEMA")) + } + + Seq( + "struct<a:string>", + "variant" + ).foreach { schema => + df.withColumn( + "decoded", + from_xml(df.col("payload"), schema, Map[String, String]().asJava) + ) + } + } } // Mock file system that checks the number of open files diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala new file mode 100644 index 000000000000..e43ef005ecb8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources.xml + +import java.time.ZoneOffset + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions} +import org.apache.spark.sql.functions.{col, variant_get} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData { + + private val baseOptions = Map("rowTag" -> "ROW", "valueTag" -> "_VALUE", "attributePrefix" -> "_") + + private val resDir = "test-data/xml-resources/" + + // ========================== + // ====== Parser tests ====== + // ========================== + + private def testParser( + xml: String, + expectedJsonStr: String, + extraOptions: Map[String, String] = Map.empty): Unit = { + val parsed = StaxXmlParser.parseVariant(xml, XmlOptions(baseOptions ++ extraOptions)) + assert(parsed.toJson(ZoneOffset.UTC) == expectedJsonStr) + } + + test("Parser: parse primitive XML elements (long, decimal, double, etc.) as variants") { + // Boolean -> Boolean + testParser("<ROW><isActive>false</isActive></ROW>", """{"isActive":false}""") + testParser("<ROW><isActive>true</isActive></ROW>", """{"isActive":true}""") + + // Long -> Long + testParser("<ROW><id>2</id></ROW>", """{"id":2}""") + testParser("<ROW><id>+2</id></ROW>", """{"id":2}""") + testParser("<ROW><id>-2</id></ROW>", """{"id":-2}""") + + // Decimal -> Decimal + testParser( + xml = "<ROW><price>158,058,049.001</price></ROW>", + expectedJsonStr = """{"price":158058049.001}""" + ) + testParser( + xml = "<ROW><decimal>10.05</decimal></ROW>", + expectedJsonStr = """{"decimal":10.05}""" + ) + testParser( + xml = "<ROW><amount>5.0</amount></ROW>", + expectedJsonStr = """{"amount":5}""" + ) + // This is parsed as String, because it is too large for Decimal + testParser( + xml = "<ROW><amount>1e40</amount></ROW>", + expectedJsonStr = """{"amount":"1e40"}""" + ) + + // Date -> String + testParser( + xml = "<ROW><createdAt>2023-10-01</createdAt></ROW>", + expectedJsonStr = """{"createdAt":"2023-10-01"}""" + ) + + // Timestamp -> String + testParser( + xml = "<ROW><createdAt>2023-10-01T12:00:00Z</createdAt></ROW>", + expectedJsonStr = """{"createdAt":"2023-10-01T12:00:00Z"}""" + ) + + // String -> String + testParser("<ROW><name>Sam</name></ROW>", """{"name":"Sam"}""") + // Strings with spaces + testParser( + "<ROW><note> hello world </note></ROW>", + expectedJsonStr = """{"note":" hello world "}""", + extraOptions = Map("ignoreSurroundingSpaces" -> "false") + ) + testParser( + xml = "<ROW><note> hello world </note></ROW>", + expectedJsonStr = """{"note":"hello world"}""" + ) + } + + test("Parser: parse XML attributes as variants") { + // XML elements with only attributes + testParser( + xml = "<ROW id=\"2\"></ROW>", + expectedJsonStr = """{"_id":2}""" + ) + testParser( + xml = "<ROW><a><b attr=\"1\"></b></a></ROW>", + expectedJsonStr = """{"a":{"b":{"_attr":1}}}""" + ) + testParser( + xml = "<ROW id=\"2\" name=\"Sam\" amount=\"93\"></ROW>", + expectedJsonStr = """{"_amount":93,"_id":2,"_name":"Sam"}""" + ) + + // XML elements with attributes and elements + testParser( + xml = "<ROW id=\"2\" name=\"Sam\"><amount>93</amount></ROW>", + expectedJsonStr = """{"_id":2,"_name":"Sam","amount":93}""" + ) + + // XML elements with attributes and nested elements + testParser( + xml = "<ROW id=\"2\" name=\"Sam\"><info><amount>93</amount></info></ROW>", + 
expectedJsonStr = """{"_id":2,"_name":"Sam","info":{"amount":93}}""" + ) + + // XML elements with attributes and value tag + testParser( + xml = "<ROW id=\"2\" name=\"Sam\">93</ROW>", + expectedJsonStr = """{"_VALUE":93,"_id":2,"_name":"Sam"}""" + ) + } + + test("Parser: parse XML value tags as variants") { + // XML elements with value tags and attributes + testParser( + xml = "<ROW id=\"2\" name=\"Sam\">93</ROW>", + expectedJsonStr = """{"_VALUE":93,"_id":2,"_name":"Sam"}""" + ) + + // XML elements with value tags and nested elements + testParser( + xml = "<ROW><info>Sam<amount>93</amount></info></ROW>", + expectedJsonStr = """{"info":{"_VALUE":"Sam","amount":93}}""" + ) + } + + test("Parser: parse XML elements as variant object") { + testParser( + xml = "<ROW><info><name>Sam</name><amount>93</amount></info></ROW>", + expectedJsonStr = """{"info":{"amount":93,"name":"Sam"}}""" + ) + } + + test("Parser: parse XML elements as variant array") { + testParser( + xml = "<ROW><array>1</array><array>2</array></ROW>", + expectedJsonStr = """{"array":[1,2]}""" + ) + } + + test("Parser: null and empty XML elements are parsed as variant null") { + // XML elements with null and empty values + testParser( + xml = """<ROW><name></name><amount>93</amount><space> </space><newline> + </newline></ROW>""", + expectedJsonStr = """{"amount":93,"name":null,"newline":null,"space":null}""" + ) + testParser( + xml = "<ROW><name>Sam</name><amount>n/a</amount></ROW>", + expectedJsonStr = """{"amount":null,"name":"Sam"}""", + extraOptions = Map("nullValue" -> "n/a") + ) + } + + test("Parser: Parse whitespaces with quotes") { + // XML elements with whitespaces + testParser( + xml = s""" + |<ROW> + | <a>" "</a> + | <b>" "<c>1</c></b> + | <d><e attr=" "></e></d> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"a":"\" \"","b":{"_VALUE":"\" \"","c":1},"d":{"e":{"_attr":" "}}}""", + extraOptions = Map("ignoreSurroundingSpaces" -> "false") + ) + } + + test("Parser: Comments are ignored") { + testParser( + xml = """ + |<ROW> + | <!-- comment --> + | <name><!-- before value --> Sam <!-- after value --></name> + | <!-- comment --> + | <amount>93</amount> + | <!-- <a>1</a> --> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"amount":93,"name":"Sam"}""" + ) + } + + test("Parser: CDATA should be handled properly") { + testParser( + xml = """ + |<!-- CDATA outside row should be ignored --> + |<ROW> + | <name><![CDATA[Sam]]></name> + | <amount>93</amount> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"amount":93,"name":"Sam"}""" + ) + } + + test("Parser: parse mixed types as variants") { + val expectedJsonStr = + """ + |{ + | "arrayOfArray1":[ + | {"item":[1,2,3]}, + | {"item":["str1","str2"]} + | ], + | "arrayOfArray2":[ + | {"item":[1,2,3]}, + | {"item":[1.1,2.1,3.1]} + | ], + | "arrayOfBigInteger":[922337203685477580700,-922337203685477580800], + | "arrayOfBoolean":[true,false,true], + | "arrayOfDouble":[1.2,1.7976931348623157,"4.9E-324","2.2250738585072014E-308"], + | "arrayOfInteger":[1,2147483647,-2147483648], + | "arrayOfLong":[21474836470,9223372036854775807,-9223372036854775808], + | "arrayOfNull":[null,null], + | "arrayOfString":["str1","str2"], + | "arrayOfStruct":[ + | {"field1":true,"field2":"str1"}, + | {"field1":false}, + | {"field3":null} + | ], + | "struct":{ + | "field1":true, + | "field2":92233720368547758070 + | }, + | "structWithArrayFields":{ + | "field1":[4,5,6], + | "field2":["str1","str2"] + | } + |} + |""".stripMargin.replaceAll("\\s+", "") + testParser( + xml = 
complexFieldAndType1.head, + expectedJsonStr = expectedJsonStr + ) + + val expectedJsonStr2 = + """ + |{ + | "arrayOfArray1":[ + | {"array":{"item":5}}, + | { + | "array":[ + | {"item":[6,7]}, + | {"item":8} + | ] + | } + | ], + | "arrayOfArray2":[ + | {"array":{"item":{"inner1":"str1"}}}, + | { + | "array":[ + | null, + | { + | "item":[ + | {"inner2":["str3","str33"]}, + | {"inner1":"str11","inner2":"str4"} + | ] + | } + | ] + | }, + | { + | "array":{ + | "item":{ + | "inner3":[ + | {"inner4":[2,3]}, + | null + | ] + | } + | } + | } + | ] + |} + """.stripMargin.replaceAll("\\s+", "") + testParser( + xml = complexFieldAndType2.head, + expectedJsonStr = expectedJsonStr2 + ) + } + + test("Parser: Case sensitivity test") { + val xmlString = + """ + |<ROW> + | <a>1<b>2</b></a> + | <A>3<b>4</b></A> + |</ROW> + |""".stripMargin + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + testParser( + xml = xmlString, + expectedJsonStr = """{"a":[{"_VALUE":1,"b":2},{"_VALUE":3,"b":4}]}""" + ) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testParser( + xml = xmlString, + expectedJsonStr = """{"A":{"_VALUE":3,"b":4},"a":{"_VALUE":1,"b":2}}""" + ) + } + } + + test("Parser: XML array elements interspersed between other elements") { + testParser( + xml = """ + |<ROW> + | <a>1</a> + | <b>2</b> + | <a>3</a> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"a":[1,3],"b":2}""" + ) + + testParser( + xml = """ + |<ROW> + | value1 + | <a>1</a> + | value2 + | <a>2</a> + | value3 + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"_VALUE":["value1","value2","value3"],"a":[1,2]}""" + ) + + // long and double + testParser( + xml = """ + |<ROW> + | <a> + | 1 + | <b>2</b> + | 3 + | <b>4</b> + | 5.0 + | </a> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"a":{"_VALUE":[1,3,5],"b":[2,4]}}""" + ) + + // Comments + testParser( + xml = """ + |<ROW> + | <!-- comment --> + | <a>1</a> + | <!-- comment --> + | <b>2</b> + | <!-- comment --> + | <a>3</a> + |</ROW> + |""".stripMargin, + expectedJsonStr = """{"a":[1,3],"b":2}""" + ) + } + + // ======================= + // ====== DSL tests ====== + // ======================= + + private def createDSLDataFrame( + fileName: String, + singleVariantColumn: Option[String] = None, + schemaDDL: Option[String] = None, + extraOptions: Map[String, String] = Map.empty): DataFrame = { + assert( + singleVariantColumn.isDefined || schemaDDL.isDefined, + "Either singleVariantColumn or schema must be defined to ingest XML files as variants via DSL" + ) + var reader = spark.read.format("xml").options(baseOptions ++ extraOptions) + singleVariantColumn.foreach( + singleVariantColumnName => + reader = reader.option("singleVariantColumn", singleVariantColumnName) + ) + schemaDDL.foreach(s => reader = reader.schema(s)) + + reader.load(getTestResourcePath(resDir + fileName)) + } + + test("DSL: read XML files using singleVariantColumn") { + val df = createDSLDataFrame(fileName = "cars.xml", singleVariantColumn = Some("var")) + checkAnswer( + df.select(variant_get(col("var"), "$.year", "int")), + Seq(Row(2012), Row(1997), Row(2015)) + ) + } + + test("DSL: read XML files with defined schema") { + val df = createDSLDataFrame( + fileName = "books-complicated.xml", + schemaDDL = Some( + "_id string, author string, title string, genre variant, price double, " + + "publish_dates variant" + ), + extraOptions = Map("rowTag" -> "book") + ) + checkAnswer( + df.select(variant_get(col("genre"), "$.name", "string")), + Seq(Row("Computer"), Row("Fantasy"), Row("Fantasy")) + ) + } + + 
test("DSL: provided schema in singleVariantColumn mode") { + // Specified schema in singleVariantColumn mode can't contain columns other than the variant + // column and the corrupted record column + checkError( + exception = intercept[AnalysisException] { + createDSLDataFrame( + fileName = "cars.xml", + singleVariantColumn = Some("var"), + schemaDDL = Some("year variant, make string, model string, comment string") + ) + }, + condition = "INVALID_SINGLE_VARIANT_COLUMN", + parameters = Map( + "schema" -> """"STRUCT<year: VARIANT, make: STRING, model: STRING, comment: STRING>"""" + ) + ) + checkError( + exception = intercept[AnalysisException] { + createDSLDataFrame( + fileName = "cars.xml", + singleVariantColumn = Some("var"), + schemaDDL = Some("_corrupt_record string") + ) + }, + condition = "INVALID_SINGLE_VARIANT_COLUMN", + parameters = Map( + "schema" -> """"STRUCT<_corrupt_record: STRING>"""" + ) + ) + + // Valid schema in singleVariantColumn mode + createDSLDataFrame( + fileName = "cars.xml", + singleVariantColumn = Some("var"), + schemaDDL = Some("var variant") + ) + createDSLDataFrame( + fileName = "cars.xml", + singleVariantColumn = Some("var"), + schemaDDL = Some("var variant, _corrupt_record string") + ) + } + + test("DSL: handle malformed record in singleVariantColumn mode") { + // FAILFAST mode + checkError( + exception = intercept[SparkException] { + createDSLDataFrame( + fileName = "cars-malformed.xml", + singleVariantColumn = Some("var"), + extraOptions = Map("mode" -> "FAILFAST") + ).collect() + }.getCause.asInstanceOf[SparkException], + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map( + "badRecord" -> "[null]", + "failFastMode" -> "FAILFAST") + ) + + // PERMISSIVE mode + val df = createDSLDataFrame( + fileName = "cars-malformed.xml", + singleVariantColumn = Some("var"), + extraOptions = Map("mode" -> "PERMISSIVE") + ) + checkAnswer( + df.select(variant_get(col("var"), "$.year", "int")), + Seq(Row(2015), Row(null), Row(null)) + ) + + // DROPMALFORMED mode + val df2 = createDSLDataFrame( + fileName = "cars-malformed.xml", + singleVariantColumn = Some("var"), + extraOptions = Map("mode" -> "DROPMALFORMED") + ) + checkAnswer( + df2.select(variant_get(col("var"), "$.year", "int")), + Seq(Row(2015)) + ) + } + + test("DSL: test XSD validation") { + val df = createDSLDataFrame( + fileName = "basket_invalid.xml", + singleVariantColumn = Some("var"), + extraOptions = Map( + "rowTag" -> "basket", + "rowValidationXSDPath" -> getTestResourcePath(resDir + "basket.xsd").replace("file:/", "/") + ) + ) + checkAnswer( + df.select(variant_get(col("var"), "$", "string")), + Seq( + // The first row matches the XSD and thus is parsed as Variant successfully + Row("""{"entry":[{"key":1,"value":"fork"},{"key":2,"value":"cup"}]}"""), + // The second row fails the XSD validation and is not parsed + Row(null) + ) + ) + } + + // ======================= + // ====== SQL tests ====== + // ======================= + + test("SQL: read an entire XML record as variant using from_xml SQL expression") { + val xmlStr = + """ + |<ROW> + | <year>2012<!--A comment within tags--></year> + | <make>Tesla</make> + | <model>S</model> + | <comment>No comment</comment> + |</ROW> + |""".stripMargin + + // Read the entire XML record as a single variant + // Verify we can extract fields from the variant type + checkAnswer( + spark + .sql(s"""SELECT from_xml('$xmlStr', 'variant') as var""") + .select(variant_get(col("var"), "$.year", "int")), + Seq(Row(2012)) + ) + } + + test("SQL: read 
partial XML record as variant using from_xml with a defined schema") { + val xmlStr = + """ + |<book> + | <author>Gambardella</author> + | <title>Hello</title> + | <genre> + | <genreid>1</genreid> + | <name>Computer</name> + | </genre> + | <price>44.95</price> + | <publish_dates> + | <publish_date> + | <day>1</day> + | <month>10</month> + | <year>2000</year> + | </publish_date> + | </publish_dates> + | </book> + | """.stripMargin.replaceAll("\\s+", "") + // Read specific elements in the XML record as variant + val schemaDDL = + "author string, title string, genre variant, price double, publish_dates variant" + // Verify we can extract fields from the variant type + checkAnswer( + spark + .sql(s"""SELECT from_xml('$xmlStr', '$schemaDDL') as book""".stripMargin) + .select(variant_get(col("book.publish_dates"), "$.publish_date.year", "int")), + Seq(Row(2000)) + ) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org