This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 512099bfb77c [SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports both V1 and V2 functions
512099bfb77c is described below
commit 512099bfb77cbd98f5ea26fc37d06b694c11dee1
Author: Cheng Pan <[email protected]>
AuthorDate: Tue Dec 23 15:56:14 2025 +0800
[SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports both V1 and V2 functions
### What changes were proposed in this pull request?
This PR fixes a bug: when the user uses a custom `DelegatingCatalogExtension`
as the session catalog, Spark cannot properly load the V2 functions provided
by the catalog. A typical use case is Iceberg's `SparkSessionCatalog`:
```
$ spark-sql \
  --conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \
...
```
```
spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
[ROUTINE_NOT_FOUND] The routine `system`.`iceberg_version` cannot be found.
Verify the spelling and correctness of the schema and catalog.
If you did not qualify the name with a schema and catalog, verify the
current_schema() output, or qualify the name with the correct schema and
catalog.
To tolerate the error on drop use DROP ... IF EXISTS. SQLSTATE: 42883; line 1 pos 7
```
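For context, a `DelegatingCatalogExtension` typically serves its own V2
functions and delegates everything else to the built-in session catalog. A
minimal sketch (hypothetical class and function names, roughly the pattern
Iceberg's `SparkSessionCatalog` follows):
```
import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction

// Hypothetical extension: serve our own V2 functions, delegate the rest
// (tables, V1 functions) to the built-in session catalog.
class MySessionCatalog extends DelegatingCatalogExtension {
  private val myFunctions: Map[Identifier, UnboundFunction] = Map(
    // e.g. Identifier.of(Array("system"), "my_version") -> MyVersionFunc
  )

  override def loadFunction(ident: Identifier): UnboundFunction =
    myFunctions.getOrElse(ident, super.loadFunction(ident))
}
```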
### Why are the changes needed?
Bug fix. Previously, function resolution unconditionally took the V1 path
whenever the resolved catalog was the session catalog, so V2 functions exposed
by a custom `DelegatingCatalogExtension` were never looked up.
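Condensed from the `FunctionResolution.scala` diff below, the dispatch changes
from a catalog-type check to a match on the loaded function:
```
// Before: the session catalog unconditionally took the V1 path.
if (CatalogV2Util.isSessionCatalog(catalog)) {
  resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
} else {
  resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
}

// After: ask the catalog first; fall back to V1 resolution only when the
// catalog reports a V1 function.
catalog.asFunctionCatalog.loadFunction(ident) match {
  case V1Function(_) => resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
  case unboundV2Func => resolveV2Function(unboundV2Func, u.arguments, u)
}
```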
### Does this PR introduce _any_ user-facing change?
Yes, it fixes a bug.
### How was this patch tested?
Added a new unit test. Also manually tested with Iceberg:
```
spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
1.10.0
Time taken: 1.715 seconds, Fetched 1 row(s)
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53531 from pan3793/SPARK-54760.
Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/FunctionResolution.scala | 23 +--
.../identifier-clause-legacy.sql.out | 8 +-
.../analyzer-results/identifier-clause.sql.out | 8 +-
.../results/identifier-clause-legacy.sql.out | 8 +-
.../sql-tests/results/identifier-clause.sql.out | 8 +-
.../DataSourceV2DataFrameSessionCatalogSuite.scala | 8 +-
.../sql/connector/DataSourceV2FunctionSuite.scala | 182 ++++++++++-----------
.../DataSourceV2SQLSessionCatalogSuite.scala | 9 +-
.../connector/SupportsCatalogOptionsSuite.scala | 9 +-
.../sql/connector/TestV2SessionCatalogBase.scala | 60 ++++++-
10 files changed, 197 insertions(+), 126 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 800126e0030e..8d6e2931a73b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -26,17 +26,16 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{
CatalogManager,
- CatalogV2Util,
- FunctionCatalog,
- Identifier,
LookupCatalog
}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.functions.{
AggregateFunction => V2AggregateFunction,
- ScalarFunction
+ ScalarFunction,
+ UnboundFunction
}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
+import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
class FunctionResolution(
@@ -52,10 +51,14 @@ class FunctionResolution(
resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
val CatalogAndIdentifier(catalog, ident) =
relationResolution.expandIdentifier(u.nameParts)
- if (CatalogV2Util.isSessionCatalog(catalog)) {
- resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
- } else {
- resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
+ catalog.asFunctionCatalog.loadFunction(ident) match {
+ case V1Function(_) =>
+ // this triggers V1 function resolution a second time, but it should be cheap
+ // (no RPC to the external catalog), since the metadata has already been cached
+ // in FunctionRegistry during the above `catalog.loadFunction` call.
+ resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
+ case unboundV2Func =>
+ resolveV2Function(unboundV2Func, u.arguments, u)
}
}
}
@@ -272,11 +275,9 @@ class FunctionResolution(
}
private def resolveV2Function(
- catalog: FunctionCatalog,
- ident: Identifier,
+ unbound: UnboundFunction,
arguments: Seq[Expression],
u: UnresolvedFunction): Expression = {
- val unbound = catalog.loadFunction(ident)
val inputType = StructType(arguments.zipWithIndex.map {
case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
})
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
index 94fff8f58697..95639c72a0ad 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
-- !query analysis
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
- "sqlState" : "42601",
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`",
- "limit" : "2"
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index e6a406072c48..e3150b199658 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
-- !query analysis
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
- "sqlState" : "42601",
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`",
- "limit" : "2"
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
index 6a99be057010..13a4b43fd058 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
@@ -1112,11 +1112,11 @@ struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
- "sqlState" : "42601",
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`",
- "limit" : "2"
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 0c0473791201..beeb3b13fe1e 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1112,11 +1112,11 @@ struct<>
-- !query output
org.apache.spark.sql.AnalysisException
{
- "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
- "sqlState" : "42601",
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`",
- "limit" : "2"
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
},
"queryContext" : [ {
"objectType" : "",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 8959b285b028..bc6ceeb24593 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -168,6 +168,10 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
spark.sessionState.catalogManager.catalog(name)
}
+ protected def sessionCatalog: Catalog = {
+ catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog]
+ }
+
protected val v2Format: String =
classOf[FakeV2ProviderWithCustomSchema].getName
protected val catalogClassName: String =
classOf[InMemoryTableSessionCatalog].getName
@@ -178,7 +182,9 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio
override def afterEach(): Unit = {
super.afterEach()
- catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog].clearTables()
+ sessionCatalog.checkUsage()
+ sessionCatalog.clearTables()
+ sessionCatalog.clearFunctions()
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index c6f2da686fe9..366528e46ff2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
comparePlans(df1.queryExecution.optimizedPlan, df2.queryExecution.optimizedPlan)
checkAnswer(df1, Row(3) :: Nil)
}
+}
- private case object StrLenDefault extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_default"
+case object StrLenDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_default"
- override def produceResult(input: InternalRow): Int = {
- val s = input.getString(0)
- s.length
- }
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
}
+}
- case object StrLenMagic extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_magic"
+case object StrLenMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_magic"
- def invoke(input: UTF8String): Int = {
- input.toString.length
- }
+ def invoke(input: UTF8String): Int = {
+ input.toString.length
}
+}
- case object StrLenBadMagic extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_magic"
+case object StrLenBadMagic extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic"
- def invoke(input: String): Int = {
- input.length
- }
+ def invoke(input: String): Int = {
+ input.length
}
+}
- case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_magic"
-
- def invoke(input: String): Int = {
- input.length
- }
+case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_magic"
- override def produceResult(input: InternalRow): Int = {
- val s = input.getString(0)
- s.length
- }
+ def invoke(input: String): Int = {
+ input.length
}
- private case object StrLenNoImpl extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_noimpl"
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
}
+}
- // input type doesn't match arguments accepted by `UnboundFunction.bind`
- private case object StrLenBadInputTypes extends ScalarFunction[Int] {
- override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "strlen_bad_input_types"
- }
+case object StrLenNoImpl extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_noimpl"
+}
- private case object BadBoundFunction extends BoundFunction {
- override def inputTypes(): Array[DataType] = Array(StringType)
- override def resultType(): DataType = IntegerType
- override def name(): String = "bad_bound_func"
- }
+// input type doesn't match arguments accepted by `UnboundFunction.bind`
+case object StrLenBadInputTypes extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "strlen_bad_input_types"
+}
- object UnboundDecimalAverage extends UnboundFunction {
- override def name(): String = "decimal_avg"
+case object BadBoundFunction extends BoundFunction {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "bad_bound_func"
+}
- override def bind(inputType: StructType): BoundFunction = {
- if (inputType.fields.length > 1) {
- throw new UnsupportedOperationException("Too many arguments")
- }
+object UnboundDecimalAverage extends UnboundFunction {
+ override def name(): String = "decimal_avg"
- // put interval type here for testing purpose
- inputType.fields(0).dataType match {
- case _: NumericType | _: DayTimeIntervalType => DecimalAverage
- case dataType =>
- throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
- }
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length > 1) {
+ throw new UnsupportedOperationException("Too many arguments")
}
- override def description(): String =
- "decimal_avg: produces an average using decimal division"
+ // put interval type here for testing purpose
+ inputType.fields(0).dataType match {
+ case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+ case dataType =>
+ throw new UnsupportedOperationException(s"Unsupported input type: $dataType")
+ }
}
- object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
- override def name(): String = "decimal_avg"
- override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
- override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
+ override def description(): String =
+ "decimal_avg: produces an average using decimal division"
+}
- override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
+object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+ override def name(): String = "decimal_avg"
+ override def inputTypes(): Array[DataType] = Array(DecimalType.SYSTEM_DEFAULT)
+ override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
- override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
- if (input.isNullAt(0)) {
- state
- } else {
- val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
- DecimalType.SYSTEM_DEFAULT.scale)
- state match {
- case (_, d) if d == 0 =>
- (l, 1)
- case (total, count) =>
- (total + l, count + 1)
- }
- }
- }
+ override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
- override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
- (leftState._1 + rightState._1, leftState._2 + rightState._2)
+ override def update(state: (Decimal, Int), input: InternalRow): (Decimal, Int) = {
+ if (input.isNullAt(0)) {
+ state
+ } else {
+ val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
+ DecimalType.SYSTEM_DEFAULT.scale)
+ state match {
+ case (_, d) if d == 0 =>
+ (l, 1)
+ case (total, count) =>
+ (total + l, count + 1)
+ }
}
+ }
- override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+ override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): (Decimal, Int) = {
+ (leftState._1 + rightState._1, leftState._2 + rightState._2)
}
- object NoImplAverage extends UnboundFunction {
- override def name(): String = "no_impl_avg"
- override def description(): String = name()
+ override def produceResult(state: (Decimal, Int)): Decimal = state._1 / Decimal(state._2)
+}
+
+object NoImplAverage extends UnboundFunction {
+ override def name(): String = "no_impl_avg"
+ override def description(): String = name()
- override def bind(inputType: StructType): BoundFunction = {
- throw SparkUnsupportedOperationException()
- }
+ override def bind(inputType: StructType): BoundFunction = {
+ throw SparkUnsupportedOperationException()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index 7463eb34d17f..dcc49b252fdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog}
class DataSourceV2SQLSessionCatalogSuite
@@ -79,4 +79,11 @@ class DataSourceV2SQLSessionCatalogSuite
assert(getTableMetadata("default.t").columns().map(_.name()) ===
Seq("c2", "c1"))
}
}
+
+ test("SPARK-54760: DelegatingCatalogExtension supports both V1 and V2
functions") {
+ sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+ checkAnswer(
+ sql("SELECT char_length('Hello') as v1, ns.strlen('Spark') as v2"),
+ Row(5, 5))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 6b5bd982ee5a..ef4128c29722 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -52,6 +52,10 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog]
}
+ protected def sessionCatalog: InMemoryTableSessionCatalog = {
+ catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog]
+ }
+
private implicit def stringToIdentifier(value: String): Identifier = {
Identifier.of(Array.empty, value)
}
@@ -65,7 +69,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
override def afterEach(): Unit = {
super.afterEach()
- Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables())
+ Try(sessionCatalog.checkUsage())
+ Try(sessionCatalog.clearTables())
catalog(catalogName).listTables(Array.empty).foreach(
catalog(catalogName).dropTable(_))
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
@@ -146,7 +151,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.save()
- val table = catalog(SESSION_CATALOG_NAME).loadTable(Identifier.of(Array("default"), "t1"))
+ val table = sessionCatalog.loadTable(Identifier.of(Array("default"), "t1"))
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.columns() sameElements
Array(Column.create("id", LongType)), "Schema did not match")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index 2254abef3fcb..6a82dca9cafc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -21,21 +21,32 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column,
DelegatingCatalogExtension, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
/**
* A V2SessionCatalog implementation that can be extended to generate arbitrary `Table` definitions
* for testing DDL as well as write operations (through df.write.saveAsTable, df.write.insertInto
- * and SQL).
+ * and SQL). It also supports V2 function operations.
*/
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension {
protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
+ protected val functions: java.util.Map[Identifier, UnboundFunction] =
+ new ConcurrentHashMap[Identifier, UnboundFunction]()
private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
+ private val funcCreated: AtomicBoolean = new AtomicBoolean(false)
+
+ def checkUsage(): Unit = {
+ assert(tableCreated.get || funcCreated.get,
+ "Either tables or functions are not created, maybe didn't use the
session catalog code path?")
+ }
private def addTable(ident: Identifier, table: T): Unit = {
tableCreated.set(true)
@@ -96,13 +107,54 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
}
def clearTables(): Unit = {
- assert(
- tableCreated.get,
- "Tables are not created, maybe didn't use the session catalog code
path?")
tables.keySet().asScala.foreach(super.dropTable)
tables.clear()
tableCreated.set(false)
}
+
+ override def listFunctions(namespace: Array[String]): Array[Identifier] = {
+ (Try(listFunctions0(namespace)), Try(super.listFunctions(namespace))) match {
+ case (Success(v2), Success(v1)) => v2 ++ v1
+ case (Success(v2), Failure(_)) => v2
+ case (Failure(_), Success(v1)) => v1
+ case (Failure(_), Failure(_)) =>
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ private def listFunctions0(namespace: Array[String]): Array[Identifier] = {
+ if (namespace.isEmpty || namespaceExists(namespace)) {
+ functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+ } else {
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ override def loadFunction(ident: Identifier): UnboundFunction = {
+ Option(functions.get(ident)) match {
+ case Some(func) => func
+ case _ =>
+ super.loadFunction(ident)
+ }
+ }
+
+ override def functionExists(ident: Identifier): Boolean = {
+ functions.containsKey(ident) || super.functionExists(ident)
+ }
+
+ def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction = {
+ funcCreated.set(true)
+ functions.put(ident, fn)
+ }
+
+ def dropFunction(ident: Identifier): Unit = {
+ functions.remove(ident)
+ }
+
+ def clearFunctions(): Unit = {
+ functions.clear()
+ funcCreated.set(false)
+ }
}
object TestV2SessionCatalogBase {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]