This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dd153307cb97 [SPARK-50953][PYTHON][CONNECT] Add support for
non-literal paths in VariantGet
dd153307cb97 is described below
commit dd153307cb9735fd05a41124eca2a136f40f3b3f
Author: Harsh Motwani <[email protected]>
AuthorDate: Mon Feb 10 21:46:18 2025 +0800
[SPARK-50953][PYTHON][CONNECT] Add support for non-literal paths in
VariantGet
### What changes were proposed in this pull request?
This PR allows the `variant_get` expression to support non-literal path
inputs.
### Why are the changes needed?
Currently, `variant_get` only supports literal paths as the second
argument. Users may have columns containing paths which they would want to
extract from variants. This PR allows this functionality.
### Does this PR introduce _any_ user-facing change?
Yes, prior to this PR, `variant_get` did not have support for non-literal
paths.
### How was this patch tested?
Unit tests to make sure that:
1. The VariantGet/TryVariantGet expressions with non-literal paths have the
expected behavior regardless of codegen mode.
2. VariantGet expressions with non-literal paths do not get pushed down as
this functionality has not been implemented.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49609 from harshmotw-db/harsh-motwani_data/variant_get_column.
Authored-by: Harsh Motwani <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
python/pyspark/sql/connect/functions/builtin.py | 16 +++-
python/pyspark/sql/functions/builtin.py | 49 ++++++++----
python/pyspark/sql/tests/test_functions.py | 8 +-
.../scala/org/apache/spark/sql/functions.scala | 40 +++++++++-
.../expressions/variant/variantExpressions.scala | 66 +++++++++++-----
.../datasources/PushVariantIntoScan.scala | 8 +-
.../scala/org/apache/spark/sql/VariantSuite.scala | 88 ++++++++++++++++++++++
.../datasources/PushVariantIntoScanSuite.scala | 22 +++++-
8 files changed, 253 insertions(+), 44 deletions(-)
diff --git a/python/pyspark/sql/connect/functions/builtin.py
b/python/pyspark/sql/connect/functions/builtin.py
index f13eeab12dd3..51685def7dbc 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2161,15 +2161,23 @@ def is_variant_null(v: "ColumnOrName") -> Column:
is_variant_null.__doc__ = pysparkfuncs.is_variant_null.__doc__
-def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
- return _invoke_function("variant_get", _to_col(v), lit(path),
lit(targetType))
+def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str)
-> Column:
+ assert isinstance(path, (Column, str))
+ if isinstance(path, str):
+ return _invoke_function("variant_get", _to_col(v), lit(path),
lit(targetType))
+ else:
+ return _invoke_function("variant_get", _to_col(v), path,
lit(targetType))
variant_get.__doc__ = pysparkfuncs.variant_get.__doc__
-def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
- return _invoke_function("try_variant_get", _to_col(v), lit(path),
lit(targetType))
+def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType:
str) -> Column:
+ assert isinstance(path, (Column, str))
+ if isinstance(path, str):
+ return _invoke_function("try_variant_get", _to_col(v), lit(path),
lit(targetType))
+ else:
+ return _invoke_function("try_variant_get", _to_col(v), path,
lit(targetType))
try_variant_get.__doc__ = pysparkfuncs.try_variant_get.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py
b/python/pyspark/sql/functions/builtin.py
index 4575bf730fca..2b6d8569fdf8 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -20427,7 +20427,7 @@ def is_variant_null(v: "ColumnOrName") -> Column:
@_try_remote_functions
-def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
+def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str)
-> Column:
"""
Extracts a sub-variant from `v` according to `path`, and then cast the
sub-variant to
`targetType`. Returns null if the path does not exist. Throws an exception
if the cast fails.
@@ -20438,9 +20438,10 @@ def variant_get(v: "ColumnOrName", path: str,
targetType: str) -> Column:
----------
v : :class:`~pyspark.sql.Column` or str
a variant column or column name
- path : str
- the extraction path. A valid path should start with `$` and is
followed by zero or more
- segments like `[123]`, `.name`, `['name']`, or `["name"]`.
+ path : :class:`~pyspark.sql.Column` or str
+ a column containing the extraction path strings or a string
representing the extraction
+ path. A valid path should start with `$` and is followed by zero or
more segments like
+ `[123]`, `.name`, `['name']`, or `["name"]`.
targetType : str
the target data type to cast into, in a DDL-formatted string
@@ -20451,21 +20452,29 @@ def variant_get(v: "ColumnOrName", path: str,
targetType: str) -> Column:
Examples
--------
- >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+ >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path':
'$.a'} ])
>>> df.select(variant_get(parse_json(df.json), "$.a",
"int").alias("r")).collect()
[Row(r=1)]
>>> df.select(variant_get(parse_json(df.json), "$.b",
"int").alias("r")).collect()
[Row(r=None)]
+ >>> df.select(variant_get(parse_json(df.json), df.path,
"int").alias("r")).collect()
+ [Row(r=1)]
"""
from pyspark.sql.classic.column import _to_java_column
- return _invoke_function(
- "variant_get", _to_java_column(v), _enum_to_value(path),
_enum_to_value(targetType)
- )
+ assert isinstance(path, (Column, str))
+ if isinstance(path, str):
+ return _invoke_function(
+ "variant_get", _to_java_column(v), _enum_to_value(path),
_enum_to_value(targetType)
+ )
+ else:
+ return _invoke_function(
+ "variant_get", _to_java_column(v), _to_java_column(path),
_enum_to_value(targetType)
+ )
@_try_remote_functions
-def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
+def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType:
str) -> Column:
"""
Extracts a sub-variant from `v` according to `path`, and then cast the
sub-variant to
`targetType`. Returns null if the path does not exist or the cast fails.
@@ -20476,9 +20485,10 @@ def try_variant_get(v: "ColumnOrName", path: str,
targetType: str) -> Column:
----------
v : :class:`~pyspark.sql.Column` or str
a variant column or column name
- path : str
- the extraction path. A valid path should start with `$` and is
followed by zero or more
- segments like `[123]`, `.name`, `['name']`, or `["name"]`.
+ path : :class:`~pyspark.sql.Column` or str
+ a column containing the extraction path strings or a string
representing the extraction
+ path. A valid path should start with `$` and is followed by zero or
more segments like
+ `[123]`, `.name`, `['name']`, or `["name"]`.
targetType : str
the target data type to cast into, in a DDL-formatted string
@@ -20489,19 +20499,26 @@ def try_variant_get(v: "ColumnOrName", path: str,
targetType: str) -> Column:
Examples
--------
- >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+ >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path':
'$.a'} ])
>>> df.select(try_variant_get(parse_json(df.json), "$.a",
"int").alias("r")).collect()
[Row(r=1)]
>>> df.select(try_variant_get(parse_json(df.json), "$.b",
"int").alias("r")).collect()
[Row(r=None)]
>>> df.select(try_variant_get(parse_json(df.json), "$.a",
"binary").alias("r")).collect()
[Row(r=None)]
+ >>> df.select(try_variant_get(parse_json(df.json), df.path,
"int").alias("r")).collect()
+ [Row(r=1)]
"""
from pyspark.sql.classic.column import _to_java_column
- return _invoke_function(
- "try_variant_get", _to_java_column(v), _enum_to_value(path),
_enum_to_value(targetType)
- )
+ if isinstance(path, str):
+ return _invoke_function(
+ "try_variant_get", _to_java_column(v), _enum_to_value(path),
_enum_to_value(targetType)
+ )
+ else:
+ return _invoke_function(
+ "try_variant_get", _to_java_column(v), _to_java_column(path),
_enum_to_value(targetType)
+ )
@_try_remote_functions
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index 39db72b235bf..b627bc793f05 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1496,7 +1496,9 @@ class FunctionsTestsMixin:
self.assertEqual("""{"b":[{"c":"str2"}]}""", actual["var_lit"])
def test_variant_expressions(self):
- df = self.spark.createDataFrame([Row(json="""{ "a" : 1 }"""),
Row(json="""{ "b" : 2 }""")])
+ df = self.spark.createDataFrame(
+ [Row(json="""{ "a" : 1 }""", path="$.a"), Row(json="""{ "b" : 2
}""", path="$.b")]
+ )
v = F.parse_json(df.json)
def check(resultDf, expected):
@@ -1510,6 +1512,10 @@ class FunctionsTestsMixin:
check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])
check(df.select(F.variant_get(v, "$.a", "double")), [1.0, None])
+ # non-literal variant_get
+ check(df.select(F.variant_get(v, df.path, "int")), [1, 2])
+ check(df.select(F.try_variant_get(v, df.path, "binary")), [None, None])
+
with self.assertRaises(SparkRuntimeException) as ex:
df.select(F.variant_get(v, "$.a", "binary")).collect()
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 5670e513287e..ffa3a03e4224 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7115,7 +7115,7 @@ object functions {
def is_variant_null(v: Column): Column = Column.fn("is_variant_null", v)
/**
- * Extracts a sub-variant from `v` according to `path`, and then cast the
sub-variant to
+ * Extracts a sub-variant from `v` according to `path` string, and then cast
the sub-variant to
* `targetType`. Returns null if the path does not exist. Throws an
exception if the cast fails.
*
* @param v
@@ -7132,7 +7132,25 @@ object functions {
Column.fn("variant_get", v, lit(path), lit(targetType))
/**
- * Extracts a sub-variant from `v` according to `path`, and then cast the
sub-variant to
+ * Extracts a sub-variant from `v` according to `path` column, and then cast
the sub-variant to
+ * `targetType`. Returns null if the path does not exist. Throws an
exception if the cast fails.
+ *
+ * @param v
+ * a variant column.
+ * @param path
+ * the column containing the extraction path strings. A valid path string
should start with
+ * `$` and is followed by zero or more segments like `[123]`, `.name`,
`['name']`, or
+ * `["name"]`.
+ * @param targetType
+ * the target data type to cast into, in a DDL-formatted string.
+ * @group variant_funcs
+ * @since 4.0.0
+ */
+ def variant_get(v: Column, path: Column, targetType: String): Column =
+ Column.fn("variant_get", v, path, lit(targetType))
+
+ /**
+ * Extracts a sub-variant from `v` according to `path` string, and then cast
the sub-variant to
* `targetType`. Returns null if the path does not exist or the cast fails..
*
* @param v
@@ -7148,6 +7166,24 @@ object functions {
def try_variant_get(v: Column, path: String, targetType: String): Column =
Column.fn("try_variant_get", v, lit(path), lit(targetType))
+ /**
+ * Extracts a sub-variant from `v` according to `path` column, and then cast
the sub-variant to
+ * `targetType`. Returns null if the path does not exist or the cast fails..
+ *
+ * @param v
+ * a variant column.
+ * @param path
+ * the column containing the extraction path strings. A valid path string
should start with
+ * `$` and is followed by zero or more segments like `[123]`, `.name`,
`['name']`, or
+ * `["name"]`.
+ * @param targetType
+ * the target data type to cast into, in a DDL-formatted string.
+ * @group variant_funcs
+ * @since 4.0.0
+ */
+ def try_variant_get(v: Column, path: Column, targetType: String): Column =
+ Column.fn("try_variant_get", v, lit(path), lit(targetType))
+
/**
* Returns schema in the SQL format of a variant.
*
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index f722329097bc..0a72e792a04f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -246,13 +246,6 @@ case class VariantGet(
val check = super.checkInputDataTypes()
if (check.isFailure) {
check
- } else if (!path.foldable) {
- DataTypeMismatch(
- errorSubClass = "NON_FOLDABLE_INPUT",
- messageParameters = Map(
- "inputName" -> toSQLId("path"),
- "inputType" -> toSQLType(path.dataType),
- "inputExpr" -> toSQLExpr(path)))
} else if (!VariantGet.checkDataType(targetType)) {
DataTypeMismatch(
errorSubClass = "CAST_WITHOUT_SUGGESTION",
@@ -265,10 +258,12 @@ case class VariantGet(
override lazy val dataType: DataType = targetType.asNullable
- @transient private lazy val parsedPath = {
- val pathValue = path.eval().toString
- VariantPathParser.parse(pathValue).getOrElse {
- throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+ @transient private lazy val parsedPath: Option[Array[VariantPathSegment]] = {
+ if (path.foldable) {
+ val pathValue = path.eval().toString
+ Some(VariantGet.getParsedPath(pathValue, prettyName))
+ } else {
+ None
}
}
@@ -287,23 +282,37 @@ case class VariantGet(
timeZoneId,
zoneId)
- protected override def nullSafeEval(input: Any, path: Any): Any = {
- VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath,
dataType, castArgs)
+ protected override def nullSafeEval(input: Any, path: Any): Any = parsedPath
match {
+ case Some(pp) =>
+ VariantGet.variantGet(input.asInstanceOf[VariantVal], pp, dataType,
castArgs)
+ case _ =>
+ VariantGet.variantGet(input.asInstanceOf[VariantVal],
path.asInstanceOf[UTF8String], dataType,
+ castArgs, prettyName)
}
protected override def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
- val childCode = child.genCode(ctx)
val tmp = ctx.freshVariable("tmp", classOf[Object])
- val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
+ val childCode = child.genCode(ctx)
val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
+ val (pathCode, parsedPathArg) = if (parsedPath.isEmpty) {
+ val pathCode = path.genCode(ctx)
+ (pathCode, pathCode.value)
+ } else {
+ (
+ new ExprCode(EmptyBlock, FalseLiteral, TrueLiteral),
+ ctx.addReferenceObj("parsedPath", parsedPath.get)
+ )
+ }
val code = code"""
${childCode.code}
- boolean ${ev.isNull} = ${childCode.isNull};
+ ${pathCode.code}
+ boolean ${ev.isNull} = ${childCode.isNull} || ${pathCode.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
Object $tmp =
org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
- ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
+ ${childCode.value}, $parsedPathArg, $dataTypeArg,
+ $castArgsArg${if (parsedPath.isEmpty) s""", "$prettyName"""" else
""});
if ($tmp == null) {
${ev.isNull} = true;
} else {
@@ -350,6 +359,15 @@ case object VariantGet {
case _ => false
}
+ /**
+ * Get parsed Array[VariantPathSegment] from string representing path
+ */
+ def getParsedPath(pathValue: String, prettyName: String):
Array[VariantPathSegment] = {
+ VariantPathParser.parse(pathValue).getOrElse {
+ throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+ }
+ }
+
/** The actual implementation of the `VariantGet` expression. */
def variantGet(
input: VariantVal,
@@ -368,6 +386,20 @@ case object VariantGet {
VariantGet.cast(v, dataType, castArgs)
}
+ /**
+ * Implementation of the `VariantGet` expression where the path is provided
as a UTF8String
+ */
+ def variantGet(
+ input: VariantVal,
+ path: UTF8String,
+ dataType: DataType,
+ castArgs: VariantCastArgs,
+ prettyName: String): Any = {
+ val pathValue = path.toString
+ val parsedPath = VariantGet.getParsedPath(pathValue, prettyName)
+ variantGet(input, parsedPath, dataType, castArgs)
+ }
+
/**
* A simple wrapper of the `cast` function that takes `Variant` rather than
`VariantVal`. The
* `Cast` expression uses it and makes the implementation simpler.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 33ba4f772a13..e9cc23c6a5ba 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -95,9 +95,11 @@ object RequestedVariantField {
def fullVariant: RequestedVariantField =
RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"),
VariantType)
- def apply(v: VariantGet): RequestedVariantField =
+ def apply(v: VariantGet): RequestedVariantField = {
+ assert(v.path.foldable)
RequestedVariantField(
VariantMetadata(v.path.eval().toString, v.failOnError,
v.timeZoneId.get), v.dataType)
+ }
def apply(c: Cast): RequestedVariantField =
RequestedVariantField(
@@ -212,7 +214,7 @@ class VariantInRelation {
// fields, which also changes the struct type containing it, and it is
difficult to reconstruct
// the original struct value. This is not a big loss, because we need the
full variant anyway.
def collectRequestedFields(expr: Expression): Unit = expr match {
- case v@VariantGet(StructPathToVariant(fields), _, _, _, _) =>
+ case v@VariantGet(StructPathToVariant(fields), path, _, _, _) if
path.foldable =>
addField(fields, RequestedVariantField(v))
case c@Cast(StructPathToVariant(fields), _, _, _) => addField(fields,
RequestedVariantField(c))
case IsNotNull(StructPath(_, _)) | IsNull(StructPath(_, _)) =>
@@ -240,7 +242,7 @@ class VariantInRelation {
// Rewrite patterns should be consistent with visit patterns in
`collectRequestedFields`.
expr.transformDown {
- case g@VariantGet(v@StructPathToVariant(fields), _, _, _, _) =>
+ case g@VariantGet(v@StructPathToVariant(fields), path, _, _, _) if
path.foldable =>
// Rewrite the attribute in advance, rather than depending on the last
branch to rewrite it.
// Ww need to avoid the `v@StructPathToVariant(fields)` branch to
rewrite the child again.
GetStructField(rewriteAttribute(v), fields(RequestedVariantField(g)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 09b29b668b13..b6fe4af28ab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode,
ExpressionEvalHelper, Literal}
import
org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils,
VariantGet}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils,
GenericArrayData}
+import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
+import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -108,6 +110,92 @@ class VariantSuite extends QueryTest with
SharedSparkSession with ExpressionEval
checkAnswer(df.select(try_variant_get(v, "$.a", "binary")), rows(null,
null))
}
+ test("non-literal variant_get") {
+ def rows(results: Any*): Seq[Row] = results.map(Row(_))
+
+ Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { codegenMode =>
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+ // The first three rows have valid paths while the final row has an
invalid path
+ val df = Seq(("""{"a" : 1}""", "$.a", 2), ("""{"b" : 2}""", "$", 1),
+ ("""{"c" : 3}""", null, 1), (null, null, 1), (null, "$.a", 1),
+ ("""{"d" : 3}""", "abc", 0)).toDF("json", "path", "valid")
+ val v = parse_json(col("json"))
+ val df1 = df.where($"valid" > 0).select(variant_get(v, col("path"),
"string"))
+ checkAnswer(df1, rows("1", """{"b":2}""", null, null, null))
+
assert(df1.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+ // Invalid path
+ val df2 = df.select(variant_get(v, col("path"), "string"))
+ checkError(
+ exception = intercept[SparkRuntimeException] { df2.collect() },
+ condition = "INVALID_VARIANT_GET_PATH",
+ parameters = Map("path" -> "abc", "functionName" ->
toSQLId("variant_get")))
+
assert(df2.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+ // Invalid cast
+ val df3 = df.where($"valid" > 1).select(variant_get(v, col("path"),
"binary"))
+ checkError(
+ exception = intercept[SparkRuntimeException] { df3.collect() },
+ condition = "INVALID_VARIANT_CAST",
+ parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"")
+ )
+
assert(df3.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+
+ // try_variant_get
+ val df4 = df.where($"valid" > 0).select(try_variant_get(v,
col("path"), "string"))
+ checkAnswer(df4, rows("1", """{"b":2}""", null, null, null))
+
assert(df4.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+ // Invalid path
+ val df5 = df.select(try_variant_get(v, col("path"), "string"))
+ checkError(
+ exception = intercept[SparkRuntimeException] { df5.collect() },
+ condition = "INVALID_VARIANT_GET_PATH",
+ parameters = Map("path" -> "abc", "functionName" ->
toSQLId("try_variant_get")))
+
assert(df5.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+ // Invalid cast
+ val df6 = df.where($"valid" > 1).select(try_variant_get(v,
col("path"), "binary"))
+ checkAnswer(df6, rows(null))
+
assert(df6.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+ (codegenMode == "CODEGEN_ONLY"))
+
+ // SQL API
+ withTable("t") {
+ df.withColumn("v", parse_json(col("json"))).write.saveAsTable("t")
+ // variant_get
+ checkAnswer(sql("select variant_get(v, path, 'string') from t where
valid > 0"),
+ rows("1", """{"b":2}""", null, null, null))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("select variant_get(v, path, 'string') from t").collect()
+ },
+ condition = "INVALID_VARIANT_GET_PATH",
+ parameters = Map("path" -> "abc", "functionName" ->
toSQLId("variant_get")))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("select variant_get(v, path, 'binary') from t where valid >
1").collect()
+ },
+ condition = "INVALID_VARIANT_CAST",
+ parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"")
+ )
+ // try_variant_get
+ checkAnswer(sql("select try_variant_get(v, path, 'string') from t
where valid > 0"),
+ rows("1", """{"b":2}""", null, null, null))
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ sql("select try_variant_get(v, path, 'string') from t").collect()
+ },
+ condition = "INVALID_VARIANT_GET_PATH",
+ parameters = Map("path" -> "abc", "functionName" ->
toSQLId("try_variant_get")))
+ checkAnswer(sql("select try_variant_get(v, path, 'binary') from t
where valid > 1"),
+ rows(null))
+ }
+ }
+ }
+ }
+
test("round trip tests") {
val rand = new Random(42)
val input = Seq.fill(50) {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
index 2a866dcd66f0..5515c4053bc1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
@@ -59,7 +59,7 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
testOnFormats { format =>
sql("create table T (v variant, vs struct<v1 variant, v2 variant, i int>,
" +
- "va array<variant>, vd variant default parse_json('1')) " +
+ "va array<variant>, vd variant default parse_json('1'), s string) " +
s"using $format")
sql("select variant_get(v, '$.a', 'int') as a, v, cast(v as struct<b
float>) as v from T")
@@ -162,6 +162,26 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
assert(vd.dataType == VariantType)
case _ => fail()
}
+
+ // No push down if the path in variant_get is not a literal
+ sql("select variant_get(v, '$.a', 'int') as a, variant_get(v, s, 'int')
v2, v, " +
+ "cast(v as struct<b float>) as v from T")
+ .queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val v = output(0)
+ val s = output(4)
+ checkAlias(projectList(0), "a", GetStructField(v, 0))
+ checkAlias(projectList(1), "v2", VariantGet(GetStructField(v, 1), s,
+ targetType = IntegerType, failOnError = true, timeZoneId =
Some(localTimeZone)))
+ checkAlias(projectList(2), "v", GetStructField(v, 1))
+ checkAlias(projectList(3), "v", GetStructField(v, 2))
+ assert(v.dataType == StructType(Array(
+ field(0, IntegerType, "$.a"),
+ field(1, VariantType, "$", timeZone = "UTC"),
+ field(2, StructType(Array(StructField("b", FloatType))), "$"))))
+ case _ => fail()
+ }
}
test("No push down for JSON") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]