This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 512099bfb77c [SPARK-54760][SQL] DelegatingCatalogExtension as session 
catalog supports both V1 and V2 functions
512099bfb77c is described below

commit 512099bfb77cbd98f5ea26fc37d06b694c11dee1
Author: Cheng Pan <[email protected]>
AuthorDate: Tue Dec 23 15:56:14 2025 +0800

    [SPARK-54760][SQL] DelegatingCatalogExtension as session catalog supports 
both V1 and V2 functions
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug where, when the user configures a custom 
`DelegatingCatalogExtension` as the session catalog, Spark cannot load the v2 
functions provided by that catalog. A typical use case is Iceberg's 
`SparkSessionCatalog`:
    ```
    $ spark-sql \
      --conf 
spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \
      ...
    ```
    
    ```
    spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
    [ROUTINE_NOT_FOUND] The routine `system`.`iceberg_version` cannot be found. 
Verify the spelling and correctness of the schema and catalog.
    If you did not qualify the name with a schema and catalog, verify the 
current_schema() output, or qualify the name with the correct schema and 
catalog.
    To tolerate the error on drop use DROP ... IF EXISTS. SQLSTATE: 42883; line 
1 pos 7
    ```
    
    ### Why are the changes needed?
    
    To fix a bug: v2 functions provided by a custom session catalog fail to resolve with `ROUTINE_NOT_FOUND`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes a bug.
    
    ### How was this patch tested?
    
    Added a new unit test. Also manually tested with Iceberg:
    ```
    spark-sql (default)> SELECT spark_catalog.system.iceberg_version();
    1.10.0
    Time taken: 1.715 seconds, Fetched 1 row(s)
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53531 from pan3793/SPARK-54760.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/FunctionResolution.scala |  23 +--
 .../identifier-clause-legacy.sql.out               |   8 +-
 .../analyzer-results/identifier-clause.sql.out     |   8 +-
 .../results/identifier-clause-legacy.sql.out       |   8 +-
 .../sql-tests/results/identifier-clause.sql.out    |   8 +-
 .../DataSourceV2DataFrameSessionCatalogSuite.scala |   8 +-
 .../sql/connector/DataSourceV2FunctionSuite.scala  | 182 ++++++++++-----------
 .../DataSourceV2SQLSessionCatalogSuite.scala       |   9 +-
 .../connector/SupportsCatalogOptionsSuite.scala    |   9 +-
 .../sql/connector/TestV2SessionCatalogBase.scala   |  60 ++++++-
 10 files changed, 197 insertions(+), 126 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
index 800126e0030e..8d6e2931a73b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala
@@ -26,17 +26,16 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{
   CatalogManager,
-  CatalogV2Util,
-  FunctionCatalog,
-  Identifier,
   LookupCatalog
 }
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.functions.{
   AggregateFunction => V2AggregateFunction,
-  ScalarFunction
+  ScalarFunction,
+  UnboundFunction
 }
 import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
+import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types._
 
 class FunctionResolution(
@@ -52,10 +51,14 @@ class FunctionResolution(
       resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse {
         val CatalogAndIdentifier(catalog, ident) =
           relationResolution.expandIdentifier(u.nameParts)
-        if (CatalogV2Util.isSessionCatalog(catalog)) {
-          resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
-        } else {
-          resolveV2Function(catalog.asFunctionCatalog, ident, u.arguments, u)
+        catalog.asFunctionCatalog.loadFunction(ident) match {
+          case V1Function(_) =>
+            // this triggers the second time v1 function resolution but should 
be cheap
+            // (no RPC to external catalog), since the metadata has been 
already cached
+            // in FunctionRegistry during the above `catalog.loadFunction` 
call.
+            resolveV1Function(ident.asFunctionIdentifier, u.arguments, u)
+          case unboundV2Func =>
+            resolveV2Function(unboundV2Func, u.arguments, u)
         }
       }
     }
@@ -272,11 +275,9 @@ class FunctionResolution(
   }
 
   private def resolveV2Function(
-      catalog: FunctionCatalog,
-      ident: Identifier,
+      unbound: UnboundFunction,
       arguments: Seq[Expression],
       u: UnresolvedFunction): Expression = {
-    val unbound = catalog.loadFunction(ident)
     val inputType = StructType(arguments.zipWithIndex.map {
       case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
     })
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
index 94fff8f58697..95639c72a0ad 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index e6a406072c48..e3150b199658 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -972,11 +972,11 @@ VALUES(IDENTIFIER('a.b.c.d')())
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git 
a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
index 6a99be057010..13a4b43fd058 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git 
a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out 
b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 0c0473791201..beeb3b13fe1e 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1112,11 +1112,11 @@ struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
 {
-  "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
-  "sqlState" : "42601",
+  "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+  "sqlState" : "42K05",
   "messageParameters" : {
-    "identifier" : "`a`.`b`.`c`.`d`",
-    "limit" : "2"
+    "namespace" : "`a`.`b`.`c`",
+    "sessionCatalog" : "spark_catalog"
   },
   "queryContext" : [ {
     "objectType" : "",
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 8959b285b028..bc6ceeb24593 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -168,6 +168,10 @@ private [connector] trait SessionCatalogTest[T <: Table, 
Catalog <: TestV2Sessio
     spark.sessionState.catalogManager.catalog(name)
   }
 
+  protected def sessionCatalog: Catalog = {
+    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog]
+  }
+
   protected val v2Format: String = 
classOf[FakeV2ProviderWithCustomSchema].getName
 
   protected val catalogClassName: String = 
classOf[InMemoryTableSessionCatalog].getName
@@ -178,7 +182,9 @@ private [connector] trait SessionCatalogTest[T <: Table, 
Catalog <: TestV2Sessio
 
   override def afterEach(): Unit = {
     super.afterEach()
-    catalog(SESSION_CATALOG_NAME).asInstanceOf[Catalog].clearTables()
+    sessionCatalog.checkUsage()
+    sessionCatalog.clearTables()
+    sessionCatalog.clearFunctions()
     spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index c6f2da686fe9..366528e46ff2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -702,127 +702,127 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     comparePlans(df1.queryExecution.optimizedPlan, 
df2.queryExecution.optimizedPlan)
     checkAnswer(df1, Row(3) :: Nil)
   }
+}
 
-  private case object StrLenDefault extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_default"
+case object StrLenDefault extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_default"
 
-    override def produceResult(input: InternalRow): Int = {
-      val s = input.getString(0)
-      s.length
-    }
+  override def produceResult(input: InternalRow): Int = {
+    val s = input.getString(0)
+    s.length
   }
+}
 
-  case object StrLenMagic extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_magic"
+case object StrLenMagic extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_magic"
 
-    def invoke(input: UTF8String): Int = {
-      input.toString.length
-    }
+  def invoke(input: UTF8String): Int = {
+    input.toString.length
   }
+}
 
-  case object StrLenBadMagic extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_magic"
+case object StrLenBadMagic extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_magic"
 
-    def invoke(input: String): Int = {
-      input.length
-    }
+  def invoke(input: String): Int = {
+    input.length
   }
+}
 
-  case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_magic"
-
-    def invoke(input: String): Int = {
-      input.length
-    }
+case object StrLenBadMagicWithDefault extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_magic"
 
-    override def produceResult(input: InternalRow): Int = {
-      val s = input.getString(0)
-      s.length
-    }
+  def invoke(input: String): Int = {
+    input.length
   }
 
-  private case object StrLenNoImpl extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_noimpl"
+  override def produceResult(input: InternalRow): Int = {
+    val s = input.getString(0)
+    s.length
   }
+}
 
-  // input type doesn't match arguments accepted by `UnboundFunction.bind`
-  private case object StrLenBadInputTypes extends ScalarFunction[Int] {
-    override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "strlen_bad_input_types"
-  }
+case object StrLenNoImpl extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_noimpl"
+}
 
-  private case object BadBoundFunction extends BoundFunction {
-    override def inputTypes(): Array[DataType] = Array(StringType)
-    override def resultType(): DataType = IntegerType
-    override def name(): String = "bad_bound_func"
-  }
+// input type doesn't match arguments accepted by `UnboundFunction.bind`
+case object StrLenBadInputTypes extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(StringType, IntegerType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "strlen_bad_input_types"
+}
 
-  object UnboundDecimalAverage extends UnboundFunction {
-    override def name(): String = "decimal_avg"
+case object BadBoundFunction extends BoundFunction {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = IntegerType
+  override def name(): String = "bad_bound_func"
+}
 
-    override def bind(inputType: StructType): BoundFunction = {
-      if (inputType.fields.length > 1) {
-        throw new UnsupportedOperationException("Too many arguments")
-      }
+object UnboundDecimalAverage extends UnboundFunction {
+  override def name(): String = "decimal_avg"
 
-      // put interval type here for testing purpose
-      inputType.fields(0).dataType match {
-        case _: NumericType | _: DayTimeIntervalType => DecimalAverage
-        case dataType =>
-          throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
-      }
+  override def bind(inputType: StructType): BoundFunction = {
+    if (inputType.fields.length > 1) {
+      throw new UnsupportedOperationException("Too many arguments")
     }
 
-    override def description(): String =
-      "decimal_avg: produces an average using decimal division"
+    // put interval type here for testing purpose
+    inputType.fields(0).dataType match {
+      case _: NumericType | _: DayTimeIntervalType => DecimalAverage
+      case dataType =>
+        throw new UnsupportedOperationException(s"Unsupported input type: 
$dataType")
+    }
   }
 
-  object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
-    override def name(): String = "decimal_avg"
-    override def inputTypes(): Array[DataType] = 
Array(DecimalType.SYSTEM_DEFAULT)
-    override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
+  override def description(): String =
+    "decimal_avg: produces an average using decimal division"
+}
 
-    override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
+object DecimalAverage extends AggregateFunction[(Decimal, Int), Decimal] {
+  override def name(): String = "decimal_avg"
+  override def inputTypes(): Array[DataType] = 
Array(DecimalType.SYSTEM_DEFAULT)
+  override def resultType(): DataType = DecimalType.SYSTEM_DEFAULT
 
-    override def update(state: (Decimal, Int), input: InternalRow): (Decimal, 
Int) = {
-      if (input.isNullAt(0)) {
-        state
-      } else {
-        val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
-          DecimalType.SYSTEM_DEFAULT.scale)
-        state match {
-          case (_, d) if d == 0 =>
-            (l, 1)
-          case (total, count) =>
-            (total + l, count + 1)
-        }
-      }
-    }
+  override def newAggregationState(): (Decimal, Int) = (Decimal.ZERO, 0)
 
-    override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): 
(Decimal, Int) = {
-      (leftState._1 + rightState._1, leftState._2 + rightState._2)
+  override def update(state: (Decimal, Int), input: InternalRow): (Decimal, 
Int) = {
+    if (input.isNullAt(0)) {
+      state
+    } else {
+      val l = input.getDecimal(0, DecimalType.SYSTEM_DEFAULT.precision,
+        DecimalType.SYSTEM_DEFAULT.scale)
+      state match {
+        case (_, d) if d == 0 =>
+          (l, 1)
+        case (total, count) =>
+          (total + l, count + 1)
+      }
     }
+  }
 
-    override def produceResult(state: (Decimal, Int)): Decimal = state._1 / 
Decimal(state._2)
+  override def merge(leftState: (Decimal, Int), rightState: (Decimal, Int)): 
(Decimal, Int) = {
+    (leftState._1 + rightState._1, leftState._2 + rightState._2)
   }
 
-  object NoImplAverage extends UnboundFunction {
-    override def name(): String = "no_impl_avg"
-    override def description(): String = name()
+  override def produceResult(state: (Decimal, Int)): Decimal = state._1 / 
Decimal(state._2)
+}
+
+object NoImplAverage extends UnboundFunction {
+  override def name(): String = "no_impl_avg"
+  override def description(): String = name()
 
-    override def bind(inputType: StructType): BoundFunction = {
-      throw SparkUnsupportedOperationException()
-    }
+  override def bind(inputType: StructType): BoundFunction = {
+    throw SparkUnsupportedOperationException()
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index 7463eb34d17f..dcc49b252fdb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, 
Table, TableCatalog}
 
 class DataSourceV2SQLSessionCatalogSuite
@@ -79,4 +79,11 @@ class DataSourceV2SQLSessionCatalogSuite
       assert(getTableMetadata("default.t").columns().map(_.name()) === 
Seq("c2", "c1"))
     }
   }
+
+  test("SPARK-54760: DelegatingCatalogExtension supports both V1 and V2 
functions") {
+    sessionCatalog.createFunction(Identifier.of(Array("ns"), "strlen"), 
StrLen(StrLenDefault))
+    checkAnswer(
+      sql("SELECT char_length('Hello') as v1, ns.strlen('Spark') as v2"),
+      Row(5, 5))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 6b5bd982ee5a..ef4128c29722 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -52,6 +52,10 @@ class SupportsCatalogOptionsSuite extends QueryTest with 
SharedSparkSession with
     spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog]
   }
 
+  protected def sessionCatalog: InMemoryTableSessionCatalog = {
+    catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog]
+  }
+
   private implicit def stringToIdentifier(value: String): Identifier = {
     Identifier.of(Array.empty, value)
   }
@@ -65,7 +69,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with 
SharedSparkSession with
 
   override def afterEach(): Unit = {
     super.afterEach()
-    
Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables())
+    Try(sessionCatalog.checkUsage())
+    Try(sessionCatalog.clearTables())
     catalog(catalogName).listTables(Array.empty).foreach(
       catalog(catalogName).dropTable(_))
     spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
@@ -146,7 +151,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with 
SharedSparkSession with
     val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", 
"t1")
     dfw.save()
 
-    val table = 
catalog(SESSION_CATALOG_NAME).loadTable(Identifier.of(Array("default"), "t1"))
+    val table = sessionCatalog.loadTable(Identifier.of(Array("default"), "t1"))
     assert(table.partitioning().isEmpty, "Partitioning should be empty")
     assert(table.columns() sameElements
       Array(Column.create("id", LongType)), "Schema did not match")
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index 2254abef3fcb..6a82dca9cafc 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -21,21 +21,32 @@ import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
 
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, 
DelegatingCatalogExtension, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.StructType
 
 /**
  * A V2SessionCatalog implementation that can be extended to generate 
arbitrary `Table` definitions
  * for testing DDL as well as write operations (through df.write.saveAsTable, 
df.write.insertInto
- * and SQL).
+ * and SQL), also supports v2 function operations.
  */
 private[connector] trait TestV2SessionCatalogBase[T <: Table] extends 
DelegatingCatalogExtension {
 
   protected val tables: java.util.Map[Identifier, T] = new 
ConcurrentHashMap[Identifier, T]()
+  protected val functions: java.util.Map[Identifier, UnboundFunction] =
+    new ConcurrentHashMap[Identifier, UnboundFunction]()
 
   private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
+  private val funcCreated: AtomicBoolean = new AtomicBoolean(false)
+
+  def checkUsage(): Unit = {
+    assert(tableCreated.get || funcCreated.get,
+      "Either tables or functions are not created, maybe didn't use the 
session catalog code path?")
+  }
 
   private def addTable(ident: Identifier, table: T): Unit = {
     tableCreated.set(true)
@@ -96,13 +107,54 @@ private[connector] trait TestV2SessionCatalogBase[T <: 
Table] extends Delegating
   }
 
   def clearTables(): Unit = {
-    assert(
-      tableCreated.get,
-      "Tables are not created, maybe didn't use the session catalog code 
path?")
     tables.keySet().asScala.foreach(super.dropTable)
     tables.clear()
     tableCreated.set(false)
   }
+
+  override def listFunctions(namespace: Array[String]): Array[Identifier] = {
+    (Try(listFunctions0(namespace)), Try(super.listFunctions(namespace))) 
match {
+      case (Success(v2), Success(v1)) => v2 ++ v1
+      case (Success(v2), Failure(_)) => v2
+      case (Failure(_), Success(v1)) => v1
+      case (Failure(_), Failure(_)) =>
+        throw new NoSuchNamespaceException(namespace)
+    }
+  }
+
+  private def listFunctions0(namespace: Array[String]): Array[Identifier] = {
+    if (namespace.isEmpty || namespaceExists(namespace)) {
+      
functions.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+    } else {
+      throw new NoSuchNamespaceException(namespace)
+    }
+  }
+
+  override def loadFunction(ident: Identifier): UnboundFunction = {
+    Option(functions.get(ident)) match {
+      case Some(func) => func
+      case _ =>
+        super.loadFunction(ident)
+    }
+  }
+
+  override def functionExists(ident: Identifier): Boolean = {
+    functions.containsKey(ident) || super.functionExists(ident)
+  }
+
+  def createFunction(ident: Identifier, fn: UnboundFunction): UnboundFunction 
= {
+    funcCreated.set(true)
+    functions.put(ident, fn)
+  }
+
+  def dropFunction(ident: Identifier): Unit = {
+    functions.remove(ident)
+  }
+
+  def clearFunctions(): Unit = {
+    functions.clear()
+    funcCreated.set(false)
+  }
 }
 
 object TestV2SessionCatalogBase {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to