This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 348b1bcff75c [SPARK-47769][SQL] Add schema_of_variant_agg expression
348b1bcff75c is described below

commit 348b1bcff75cd6eb951c5792cfe8a65cbe8aba73
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Tue Apr 16 13:45:37 2024 +0800

    [SPARK-47769][SQL] Add schema_of_variant_agg expression
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new `schema_of_variant_agg` aggregate expression. It returns
    the merged schema of a variant column, formatted as a SQL type string.
    Compared to `schema_of_variant`, which is a scalar expression and returns one
    schema per row, `schema_of_variant_agg` merges the schemas of all rows into a
    single schema.
    
    Usage examples:
    
    ```
    > SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j);
     BIGINT
    > SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j);
     STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
    ```
    
    ### Why are the changes needed?
    
    This expression can help the user explore the content of variant values.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new SQL aggregate expression, `schema_of_variant_agg`, is added.
    
    ### How was this patch tested?
    
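    Unit tests in `VariantEndToEndSuite` covering literal input, non-grouping
    aggregation, and grouping aggregation.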
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45934 from chenhao-db/schema_of_variant_agg.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../expressions/variant/variantExpressions.scala   | 66 ++++++++++++++++++++++
 .../sql-functions/sql-expression-schema.md         |  1 +
 .../apache/spark/sql/VariantEndToEndSuite.scala    | 42 ++++++++++++++
 4 files changed, 110 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9447ea63b51f..c56d04b570e5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -825,6 +825,7 @@ object FunctionRegistry {
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
     expression[SchemaOfVariant]("schema_of_variant"),
+    expression[SchemaOfVariantAgg]("schema_of_variant_agg"),
 
     // cast
     expression[Cast]("cast"),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 8b09bf5f7de0..cab75e1996ab 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -27,11 +27,13 @@ import 
org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, 
TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.json.JsonInferSchema
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, 
VARIANT_GET}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, 
QueryExecutionErrors}
@@ -615,3 +617,67 @@ object SchemaOfVariant {
   def mergeSchema(t1: DataType, t2: DataType): DataType =
     JsonInferSchema.compatibleType(t1, t2, VariantType)
 }
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(v) - Returns the merged schema in the SQL format of a 
variant column.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j);
+       BIGINT
+      > SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": 
true}'), ('{"c": 1.23}') AS tab(j);
+       STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
+  """,
+  since = "4.0.0",
+  group = "variant_funcs")
+// scalastyle:on line.size.limit
+case class SchemaOfVariantAgg(
+    child: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int)
+    extends TypedImperativeAggregate[DataType]
+    with ExpectsInputTypes
+    with QueryErrorsBase
+    with UnaryLike[Expression] {
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = false
+
+  override def createAggregationBuffer(): DataType = NullType
+
+  override def update(buffer: DataType, input: InternalRow): DataType = {
+    val inputVariant = child.eval(input).asInstanceOf[VariantVal]
+    if (inputVariant != null) {
+      val v = new Variant(inputVariant.getValue, inputVariant.getMetadata)
+      SchemaOfVariant.mergeSchema(buffer, SchemaOfVariant.schemaOf(v))
+    } else {
+      buffer
+    }
+  }
+
+  override def merge(buffer: DataType, input: DataType): DataType =
+    SchemaOfVariant.mergeSchema(buffer, input)
+
+  override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql)
+
+  override def serialize(buffer: DataType): Array[Byte] = 
buffer.json.getBytes("UTF-8")
+
+  override def deserialize(storageFormat: Array[Byte]): DataType =
+    DataType.fromJson(new String(storageFormat, "UTF-8"))
+
+  override def prettyName: String = "schema_of_variant_agg"
+
+  override def withNewMutableAggBufferOffset(
+      newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildInternal(newChild: Expression): 
Expression =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 05491034e6c7..8b70c88332df 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -438,6 +438,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance 
| SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<variance(col):double> |
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | 
SELECT parse_json('{"a":1,"b":0.8}') | 
struct<parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | 
schema_of_variant | SELECT schema_of_variant(parse_json('null')) | 
struct<schema_of_variant(parse_json(null)):string> |
+| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | 
schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES 
('1'), ('2'), ('3') AS tab(j) | 
struct<schema_of_variant_agg(parse_json(j)):string> |
 | 
org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder
 | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 
'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | 
org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | 
variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | 
struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | 
SELECT xpath_boolean('<a><b>1</b></a>','a/b') | 
struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index d8b1dca21ca6..58528b918673 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -112,4 +112,46 @@ class VariantEndToEndSuite extends QueryTest with 
SharedSparkSession {
       "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>"
     )
   }
+
+  test("schema_of_variant_agg") {
+    // Literal input.
+    checkAnswer(
+      sql("""SELECT schema_of_variant_agg(parse_json('{"a": [1, 2, 3]}'))"""),
+      Seq(Row("STRUCT<a: ARRAY<BIGINT>>")))
+
+    // Non-grouping aggregation.
+    def checkNonGrouping(input: Seq[String], expected: String): Unit = {
+      
checkAnswer(input.toDF("json").selectExpr("schema_of_variant_agg(parse_json(json))"),
+        Seq(Row(expected)))
+    }
+
+    checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "STRUCT<a: ARRAY<BIGINT>>")
+    checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "STRUCT<a: 
ARRAY<BIGINT>>")
+    checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY<STRUCT<a: BIGINT, 
b: BIGINT>>")
+    checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), 
"STRUCT<a: VARIANT>")
+    checkNonGrouping(Seq("""{"a": "banana"}""", """{"b": "apple"}"""),
+      "STRUCT<a: STRING, b: STRING>")
+    checkNonGrouping(Seq("""{"a": "data"}""", null), "STRUCT<a: STRING>")
+    checkNonGrouping(Seq(null, null), "VOID")
+    checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "STRUCT<a: 
VOID>")
+    checkNonGrouping(Seq(
+      """{"hi":[]}""",
+      """{"hi":[{},{}]}""",
+      """{"hi":[{"it's":[{"me":[{"a": 1}]}]}]}"""),
+      "STRUCT<hi: ARRAY<STRUCT<`it's`: ARRAY<STRUCT<me: ARRAY<STRUCT<a: 
BIGINT>>>>>>>")
+
+    // Grouping aggregation.
+    withView("v") {
+      (0 to 100).map { id =>
+        val json = if (id % 4 == 0) s"""{"a": [$id]}""" else s"""{"a": 
["$id"]}"""
+        (id, json)
+      }.toDF("id", "json").createTempView("v")
+      checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v 
group by id % 2"),
+        Seq(Row("STRUCT<a: ARRAY<STRING>>"), Row("STRUCT<a: ARRAY<VARIANT>>")))
+      checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v 
group by id % 3"),
+        Seq.fill(3)(Row("STRUCT<a: ARRAY<VARIANT>>")))
+      checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v 
group by id % 4"),
+        Seq.fill(3)(Row("STRUCT<a: ARRAY<STRING>>")) ++ Seq(Row("STRUCT<a: 
ARRAY<BIGINT>>")))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to