This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3bc374d945ff [SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by
Invoke & RuntimeReplaceable)
3bc374d945ff is described below
commit 3bc374d945ff91cda78e64c1d63fe9a95f735ebf
Author: panbingkun <[email protected]>
AuthorDate: Thu Nov 21 09:09:24 2024 +0100
[SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by Invoke &
RuntimeReplaceable)
### What changes were proposed in this pull request?
This PR aims to add `Codegen` support for `CsvToStructs` (`from_csv`).
### Why are the changes needed?
- Improve codegen coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Pass GA & existing UTs (e.g. CsvFunctionsSuite#`*from_csv*`)
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48873 from panbingkun/from_csv_codegen.
Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../expressions/csv/CsvExpressionEvalUtils.scala | 70 ++++++++++++++++-
.../sql/catalyst/expressions/csvExpressions.scala | 87 ++++++----------------
.../explain-results/function_from_csv.explain | 2 +-
3 files changed, 93 insertions(+), 66 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
index abd0703fa7d7..a91e4ab13001 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
@@ -18,10 +18,78 @@ package org.apache.spark.sql.catalyst.expressions.csv
import com.univocity.parsers.csv.CsvParser
-import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions,
UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser,
PermissiveMode}
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.unsafe.types.UTF8String
+/**
+ * The expression `CsvToStructs` will utilize the `Invoke` to call it, support
codegen.
+ */
+case class CsvToStructsEvaluator(
+ options: Map[String, String],
+ nullableSchema: StructType,
+ nameOfCorruptRecord: String,
+ timeZoneId: Option[String],
+ requiredSchema: Option[StructType]) {
+
+ // This converts parsed rows to the desired output by the given schema.
+ @transient
+ private lazy val converter = (rows: Iterator[InternalRow]) => {
+ if (!rows.hasNext) {
+ throw SparkException.internalError("Expected one row from CSV parser.")
+ }
+ val result = rows.next()
+ // CSV's parser produces one record only.
+ assert(!rows.hasNext)
+ result
+ }
+
+ @transient
+ private lazy val parser = {
+ // 'lineSep' is a plan-wise option so we set a noncharacter, according to
+ // the unicode specification, which should not appear in Java's strings.
+ // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
+ // scalastyle:off nonascii
+ val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
+ // scalastyle:on nonascii
+ val parsedOptions = new CSVOptions(
+ exprOptions,
+ columnPruning = true,
+ defaultTimeZoneId = timeZoneId.get,
+ defaultColumnNameOfCorruptRecord = nameOfCorruptRecord)
+ val mode = parsedOptions.parseMode
+ if (mode != PermissiveMode && mode != FailFastMode) {
+ throw QueryCompilationErrors.parseModeUnsupportedError("from_csv", mode)
+ }
+ ExprUtils.verifyColumnNameOfCorruptRecord(
+ nullableSchema,
+ parsedOptions.columnNameOfCorruptRecord)
+
+ val actualSchema =
+ StructType(nullableSchema.filterNot(_.name ==
parsedOptions.columnNameOfCorruptRecord))
+ val actualRequiredSchema =
+ StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
+ .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+ val rawParser = new UnivocityParser(actualSchema,
+ actualRequiredSchema,
+ parsedOptions)
+ new FailureSafeParser[String](
+ input => rawParser.parse(input),
+ mode,
+ nullableSchema,
+ parsedOptions.columnNameOfCorruptRecord)
+ }
+
+ final def evaluate(csv: UTF8String): InternalRow = {
+ converter(parser.parse(csv.toString))
+ }
+}
+
case class SchemaOfCsvEvaluator(options: Map[String, String]) {
@transient
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index e9cdc184e55a..02e5488835c9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -19,17 +19,16 @@ package org.apache.spark.sql.catalyst.expressions
import java.io.CharArrayWriter
-import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
TypeCheckSuccess}
import org.apache.spark.sql.catalyst.csv._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
+import org.apache.spark.sql.catalyst.expressions.csv.{CsvToStructsEvaluator,
SchemaOfCsvEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE,
TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils._
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
@@ -58,15 +57,17 @@ case class CsvToStructs(
timeZoneId: Option[String] = None,
requiredSchema: Option[StructType] = None)
extends UnaryExpression
- with TimeZoneAwareExpression
- with CodegenFallback
- with ExpectsInputTypes {
- override def nullIntolerant: Boolean = true
+ with RuntimeReplaceable
+ with ExpectsInputTypes
+ with TimeZoneAwareExpression {
+
override def nullable: Boolean = child.nullable
+ override def nodePatternsInternal(): Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE)
+
// The CSV input data might be missing certain fields. We force the
nullability
// of the user-provided schema to avoid data corruptions.
- val nullableSchema: StructType = schema.asNullable
+ private val nullableSchema: StructType = schema.asNullable
// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression, options: Map[String,
String]) =
@@ -85,55 +86,7 @@ case class CsvToStructs(
child = child,
timeZoneId = None)
- // This converts parsed rows to the desired output by the given schema.
- @transient
- lazy val converter = (rows: Iterator[InternalRow]) => {
- if (rows.hasNext) {
- val result = rows.next()
- // CSV's parser produces one record only.
- assert(!rows.hasNext)
- result
- } else {
- throw SparkException.internalError("Expected one row from CSV parser.")
- }
- }
-
- val nameOfCorruptRecord =
SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
-
- @transient lazy val parser = {
- // 'lineSep' is a plan-wise option so we set a noncharacter, according to
- // the unicode specification, which should not appear in Java's strings.
- // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
- // scalastyle:off nonascii
- val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
- // scalastyle:on nonascii
- val parsedOptions = new CSVOptions(
- exprOptions,
- columnPruning = true,
- defaultTimeZoneId = timeZoneId.get,
- defaultColumnNameOfCorruptRecord = nameOfCorruptRecord)
- val mode = parsedOptions.parseMode
- if (mode != PermissiveMode && mode != FailFastMode) {
- throw QueryCompilationErrors.parseModeUnsupportedError("from_csv", mode)
- }
- ExprUtils.verifyColumnNameOfCorruptRecord(
- nullableSchema,
- parsedOptions.columnNameOfCorruptRecord)
-
- val actualSchema =
- StructType(nullableSchema.filterNot(_.name ==
parsedOptions.columnNameOfCorruptRecord))
- val actualRequiredSchema =
- StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
- .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
- val rawParser = new UnivocityParser(actualSchema,
- actualRequiredSchema,
- parsedOptions)
- new FailureSafeParser[String](
- input => rawParser.parse(input),
- mode,
- nullableSchema,
- parsedOptions.columnNameOfCorruptRecord)
- }
+ private val nameOfCorruptRecord =
SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
@@ -141,15 +94,21 @@ case class CsvToStructs(
copy(timeZoneId = Option(timeZoneId))
}
- override def nullSafeEval(input: Any): Any = {
- val csv = input.asInstanceOf[UTF8String].toString
- converter(parser.parse(csv))
- }
-
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation ::
Nil
override def prettyName: String = "from_csv"
+ @transient
+ private lazy val evaluator: CsvToStructsEvaluator = CsvToStructsEvaluator(
+ options, nullableSchema, nameOfCorruptRecord, timeZoneId, requiredSchema)
+
+ override def replacement: Expression = Invoke(
+ Literal.create(evaluator, ObjectType(classOf[CsvToStructsEvaluator])),
+ "evaluate",
+ dataType,
+ Seq(child),
+ Seq(child.dataType))
+
override protected def withNewChildInternal(newChild: Expression):
CsvToStructs =
copy(child = newChild)
}
diff --git
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
index 89e03c818823..ef87c18948b2 100644
---
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
+++
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
@@ -1,2 +1,2 @@
-Project [from_csv(StructField(id,LongType,true),
StructField(a,IntegerType,true), StructField(b,DoubleType,true),
(mode,FAILFAST), g#0, Some(America/Los_Angeles), None) AS from_csv(g)#0]
+Project [invoke(CsvToStructsEvaluator(Map(mode ->
FAILFAST),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),None).evaluate(g#0))
AS from_csv(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]