This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 348b1bcff75c [SPARK-47769][SQL] Add schema_of_variant_agg expression 348b1bcff75c is described below commit 348b1bcff75cd6eb951c5792cfe8a65cbe8aba73 Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Tue Apr 16 13:45:37 2024 +0800 [SPARK-47769][SQL] Add schema_of_variant_agg expression ### What changes were proposed in this pull request? This PR adds a new `schema_of_variant_agg` expression. It returns the merged schema of a variant column in SQL format. Compared to `schema_of_variant`, which is a scalar expression and returns one schema per row, the `schema_of_variant_agg` expression merges the schemas of all rows into a single schema. Usage examples: ``` > SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j); BIGINT > SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j); STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)> ``` ### Why are the changes needed? This expression can help the user explore the content of variant values. ### Does this PR introduce _any_ user-facing change? Yes. A new SQL expression is added. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45934 from chenhao-db/schema_of_variant_agg. 
Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/variant/variantExpressions.scala | 66 ++++++++++++++++++++++ .../sql-functions/sql-expression-schema.md | 1 + .../apache/spark/sql/VariantEndToEndSuite.scala | 42 ++++++++++++++ 4 files changed, 110 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9447ea63b51f..c56d04b570e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -825,6 +825,7 @@ object FunctionRegistry { expressionBuilder("variant_get", VariantGetExpressionBuilder), expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder), expression[SchemaOfVariant]("schema_of_variant"), + expression[SchemaOfVariantAgg]("schema_of_variant_agg"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 8b09bf5f7de0..cab75e1996ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -27,11 +27,13 @@ import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, 
TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.json.JsonInferSchema import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET} +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} @@ -615,3 +617,67 @@ object SchemaOfVariant { def mergeSchema(t1: DataType, t2: DataType): DataType = JsonInferSchema.compatibleType(t1, t2, VariantType) } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(v) - Returns the merged schema in the SQL format of a variant column.", + examples = """ + Examples: + > SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j); + BIGINT + > SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j); + STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)> + """, + since = "4.0.0", + group = "variant_funcs") +// scalastyle:on line.size.limit +case class SchemaOfVariantAgg( + child: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[DataType] + with ExpectsInputTypes + with QueryErrorsBase + with UnaryLike[Expression] { + def this(child: Expression) = this(child, 0, 0) + + override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) + + override def dataType: DataType = StringType + + override def nullable: Boolean = false + + override def createAggregationBuffer(): DataType = NullType + + override def update(buffer: DataType, input: InternalRow): DataType = { + val inputVariant = child.eval(input).asInstanceOf[VariantVal] + if 
(inputVariant != null) { + val v = new Variant(inputVariant.getValue, inputVariant.getMetadata) + SchemaOfVariant.mergeSchema(buffer, SchemaOfVariant.schemaOf(v)) + } else { + buffer + } + } + + override def merge(buffer: DataType, input: DataType): DataType = + SchemaOfVariant.mergeSchema(buffer, input) + + override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql) + + override def serialize(buffer: DataType): Array[Byte] = buffer.json.getBytes("UTF-8") + + override def deserialize(storageFormat: Array[Byte]): DataType = + DataType.fromJson(new String(storageFormat, "UTF-8")) + + override def prettyName: String = "schema_of_variant_agg" + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 05491034e6c7..8b70c88332df 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -438,6 +438,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> | | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> | | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> | +| 
org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct<schema_of_variant_agg(parse_json(j)):string> | | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> | | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> | | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index d8b1dca21ca6..58528b918673 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -112,4 +112,46 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>" ) } + + test("schema_of_variant_agg") { + // Literal input. + checkAnswer( + sql("""SELECT schema_of_variant_agg(parse_json('{"a": [1, 2, 3]}'))"""), + Seq(Row("STRUCT<a: ARRAY<BIGINT>>"))) + + // Non-grouping aggregation. 
+ def checkNonGrouping(input: Seq[String], expected: String): Unit = { + checkAnswer(input.toDF("json").selectExpr("schema_of_variant_agg(parse_json(json))"), + Seq(Row(expected))) + } + + checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "STRUCT<a: ARRAY<BIGINT>>") + checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "STRUCT<a: ARRAY<BIGINT>>") + checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY<STRUCT<a: BIGINT, b: BIGINT>>") + checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), "STRUCT<a: VARIANT>") + checkNonGrouping(Seq("""{"a": "banana"}""", """{"b": "apple"}"""), + "STRUCT<a: STRING, b: STRING>") + checkNonGrouping(Seq("""{"a": "data"}""", null), "STRUCT<a: STRING>") + checkNonGrouping(Seq(null, null), "VOID") + checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "STRUCT<a: VOID>") + checkNonGrouping(Seq( + """{"hi":[]}""", + """{"hi":[{},{}]}""", + """{"hi":[{"it's":[{"me":[{"a": 1}]}]}]}"""), + "STRUCT<hi: ARRAY<STRUCT<`it's`: ARRAY<STRUCT<me: ARRAY<STRUCT<a: BIGINT>>>>>>>") + + // Grouping aggregation. + withView("v") { + (0 to 100).map { id => + val json = if (id % 4 == 0) s"""{"a": [$id]}""" else s"""{"a": ["$id"]}""" + (id, json) + }.toDF("id", "json").createTempView("v") + checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 2"), + Seq(Row("STRUCT<a: ARRAY<STRING>>"), Row("STRUCT<a: ARRAY<VARIANT>>"))) + checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 3"), + Seq.fill(3)(Row("STRUCT<a: ARRAY<VARIANT>>"))) + checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 4"), + Seq.fill(3)(Row("STRUCT<a: ARRAY<STRING>>")) ++ Seq(Row("STRUCT<a: ARRAY<BIGINT>>"))) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org