This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 22cbb9694ca5 [SPARK-50746][SQL] Replace Either with VariantPathSegment
22cbb9694ca5 is described below
commit 22cbb9694ca53efef1d57387e14976d3906c2b15
Author: Chenhao Li <[email protected]>
AuthorDate: Tue Jan 7 13:55:46 2025 +0800
[SPARK-50746][SQL] Replace Either with VariantPathSegment
### What changes were proposed in this pull request?
It replaces `type PathSegment = Either[String, Int]` with a dedicated class
hierarchy rooted at `VariantPathSegment`. There is no semantic change, but the
code has clearer naming.
### Why are the changes needed?
To make the code easier to understand.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49385 from chenhao-db/VariantPathSegment.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/variant/variantExpressions.scala | 30 ++++++++++++----------
.../datasources/PushVariantIntoScan.scala | 4 +--
.../datasources/parquet/SparkShreddingUtils.scala | 15 +++++------
3 files changed, 26 insertions(+), 23 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index ba910b8c7e5f..ff8b168793b5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -184,33 +184,37 @@ case class ToVariantObject(child: Expression)
}
}
-object VariantPathParser extends RegexParsers {
- // A path segment in the `VariantGet` expression represents either an object
key access or an
- // array index access.
- type PathSegment = Either[String, Int]
+// A path segment in the `VariantGet` expression represents either an object
key access or an array
+// index access.
+sealed abstract class VariantPathSegment extends Serializable
+
+case class ObjectExtraction(key: String) extends VariantPathSegment
+case class ArrayExtraction(index: Int) extends VariantPathSegment
+
+object VariantPathParser extends RegexParsers {
private def root: Parser[Char] = '$'
// Parse index segment like `[123]`.
- private def index: Parser[PathSegment] =
+ private def index: Parser[VariantPathSegment] =
for {
index <- '[' ~> "\\d+".r <~ ']'
} yield {
- scala.util.Right(index.toInt)
+ ArrayExtraction(index.toInt)
}
// Parse key segment like `.name`, `['name']`, or `["name"]`.
- private def key: Parser[PathSegment] =
+ private def key: Parser[VariantPathSegment] =
for {
key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
"[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
} yield {
- scala.util.Left(key)
+ ObjectExtraction(key)
}
- private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key |
index))
+ private val parser: Parser[List[VariantPathSegment]] = phrase(root ~>
rep(key | index))
- def parse(str: String): Option[Array[PathSegment]] = {
+ def parse(str: String): Option[Array[VariantPathSegment]] = {
this.parseAll(parser, str) match {
case Success(result, _) => Some(result.toArray)
case _ => None
@@ -349,14 +353,14 @@ case object VariantGet {
/** The actual implementation of the `VariantGet` expression. */
def variantGet(
input: VariantVal,
- parsedPath: Array[VariantPathParser.PathSegment],
+ parsedPath: Array[VariantPathSegment],
dataType: DataType,
castArgs: VariantCastArgs): Any = {
var v = new Variant(input.getValue, input.getMetadata)
for (path <- parsedPath) {
v = path match {
- case scala.util.Left(key) if v.getType == Type.OBJECT =>
v.getFieldByKey(key)
- case scala.util.Right(index) if v.getType == Type.ARRAY =>
v.getElementAtIndex(index)
+ case ObjectExtraction(key) if v.getType == Type.OBJECT =>
v.getFieldByKey(key)
+ case ArrayExtraction(index) if v.getType == Type.ARRAY =>
v.getElementAtIndex(index)
case _ => null
}
if (v == null) return null
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 83d219c28983..33ba4f772a13 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet,
VariantPathParser}
+import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -54,7 +54,7 @@ case class VariantMetadata(
.build()
).build()
- def parsedPath(): Array[VariantPathParser.PathSegment] = {
+ def parsedPath(): Array[VariantPathSegment] = {
VariantPathParser.parse(path).getOrElse {
val name = if (failOnError) "variant_get" else "try_variant_get"
throw QueryExecutionErrors.invalidVariantGetPath(path, name)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index c0c490034415..ffb6704061e6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.variant._
-import
org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData,
DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.execution.RowToColumnConverter
@@ -56,9 +55,9 @@ case class SparkShreddedRow(row: SpecializedGetters) extends
ShreddingUtils.Shre
override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
}
-// The search result of a `PathSegment` in a `VariantSchema`.
+// The search result of a `VariantPathSegment` in a `VariantSchema`.
case class SchemaPathSegment(
- rawPath: PathSegment,
+ rawPath: VariantPathSegment,
// Whether this path segment is an object or array extraction.
isObject: Boolean,
// `schema.typedIdx`, if the path exists in the schema (for object
extraction, the schema
@@ -714,11 +713,11 @@ case object SparkShreddingUtils {
// found at a certain level of the file type, then `typedIdx` will
be -1 starting from
// this position, and the final `schema` will be null.
for (i <- rawPath.indices) {
- val isObject = rawPath(i).isLeft
+ val isObject = rawPath(i).isInstanceOf[ObjectExtraction]
var typedIdx = -1
var extractionIdx = -1
rawPath(i) match {
- case scala.util.Left(key) if schema != null &&
schema.objectSchema != null =>
+ case ObjectExtraction(key) if schema != null &&
schema.objectSchema != null =>
val fieldIdx = schema.objectSchemaMap.get(key)
if (fieldIdx != null) {
typedIdx = schema.typedIdx
@@ -727,7 +726,7 @@ case object SparkShreddingUtils {
} else {
schema = null
}
- case scala.util.Right(index) if schema != null &&
schema.arraySchema != null =>
+ case ArrayExtraction(index) if schema != null &&
schema.arraySchema != null =>
typedIdx = schema.typedIdx
extractionIdx = index
schema = schema.arraySchema
@@ -770,8 +769,8 @@ case object SparkShreddingUtils {
var v = new Variant(row.getBinary(variantIdx), topLevelMetadata)
while (pathIdx < pathLen) {
v = pathList(pathIdx).rawPath match {
- case scala.util.Left(key) if v.getType == Type.OBJECT =>
v.getFieldByKey(key)
- case scala.util.Right(index) if v.getType == Type.ARRAY =>
v.getElementAtIndex(index)
+ case ObjectExtraction(key) if v.getType == Type.OBJECT =>
v.getFieldByKey(key)
+ case ArrayExtraction(index) if v.getType == Type.ARRAY =>
v.getElementAtIndex(index)
case _ => null
}
if (v == null) return null
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]