This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new cee0edf  [SPARK-51968] Support `(cache|uncache|refresh)Table`, `refreshByPath`, `isCached`, `clearCache` in `Catalog`
cee0edf is described below

commit cee0edf42e4075ac816a5bb46e90bba6de5a5d9e
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Wed Apr 30 11:38:22 2025 -0700

    [SPARK-51968] Support `(cache|uncache|refresh)Table`, `refreshByPath`, `isCached`, `clearCache` in `Catalog`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support the following APIs of `Catalog`; a rough usage sketch follows the list.
    - `cacheTable`
    - `uncacheTable`
    - `refreshTable`
    - `refreshByPath`
    - `isCached`
    - `clearCache`
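    
    A minimal usage sketch of these APIs, assuming a reachable Spark Connect server; the table name `sales` and the path `/data/sales` are placeholders, not part of this repository:
    
    ```swift
    import SparkConnect
    
    // Sketch only: `sales` and "/data/sales" are hypothetical.
    func demoCatalogCaching() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
    
      try await spark.catalog.cacheTable("sales")                            // default storage level
      try await spark.catalog.cacheTable("sales", StorageLevel.MEMORY_ONLY)  // explicit storage level
      print(try await spark.catalog.isCached("sales"))                       // true
    
      try await spark.catalog.refreshTable("sales")         // reload cached data/metadata for the table
      try await spark.catalog.refreshByPath("/data/sales")  // refresh everything cached under this path
    
      try await spark.catalog.uncacheTable("sales")  // drop this table from the in-memory cache
      try await spark.catalog.clearCache()           // drop all cached tables
    
      await spark.stop()
    }
    ```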
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #96 from dongjoon-hyun/SPARK-51968.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/Catalog.swift         |  86 ++++++++++++++++++++++++
 Sources/SparkConnect/SparkFileUtils.swift  |   1 +
 Tests/SparkConnectTests/CatalogTests.swift | 104 +++++++++++++++++++++++++++++
 3 files changed, 191 insertions(+)

diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift
index c47fac4..98f9a1b 100644
--- a/Sources/SparkConnect/Catalog.swift
+++ b/Sources/SparkConnect/Catalog.swift
@@ -199,4 +199,90 @@ public actor Catalog: Sendable {
   public func databaseExists(_ dbName: String) async throws -> Bool {
     return try await self.listDatabases(pattern: dbName).count > 0
   }
+
+  /// Caches the specified table in-memory.
+  /// - Parameters:
+  ///   - tableName: A qualified or unqualified name that designates a table/view.
+  ///   If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  ///   - storageLevel: storage level to cache table.
+  public func cacheTable(_ tableName: String, _ storageLevel: StorageLevel? = nil) async throws {
+    let df = getDataFrame({
+      var cacheTable = Spark_Connect_CacheTable()
+      cacheTable.tableName = tableName
+      if let storageLevel {
+        cacheTable.storageLevel = storageLevel.toSparkConnectStorageLevel
+      }
+      var catalog = Spark_Connect_Catalog()
+      catalog.cacheTable = cacheTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Returns true if the table is currently cached in-memory.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  /// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func isCached(_ tableName: String) async throws -> Bool {
+    let df = getDataFrame({
+      var isCached = Spark_Connect_IsCached()
+      isCached.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.isCached = isCached
+      return catalog
+    })
+    return "true" == (try await df.collect().first!.get(0) as! String)
+  }
+
+  /// Invalidates and refreshes all the cached data and metadata of the given table.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  /// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func refreshTable(_ tableName: String) async throws {
+    let df = getDataFrame({
+      var refreshTable = Spark_Connect_RefreshTable()
+      refreshTable.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.refreshTable = refreshTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Invalidates and refreshes all the cached data (and the associated metadata) for any ``DataFrame``
+  /// that contains the given data source path. Path matching is by checking for sub-directories,
+  /// i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate
+  /// everything that is a subdirectory of "/test/parent".
+  public func refreshByPath(_ path: String) async throws {
+    let df = getDataFrame({
+      var refreshByPath = Spark_Connect_RefreshByPath()
+      refreshByPath.path = path
+      var catalog = Spark_Connect_Catalog()
+      catalog.refreshByPath = refreshByPath
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Removes the specified table from the in-memory cache.
+  /// - Parameter tableName: A qualified or unqualified name that designates a table/view.
+  /// If no database identifier is provided, it refers to a temporary view or a table/view in the current database.
+  public func uncacheTable(_ tableName: String) async throws {
+    let df = getDataFrame({
+      var uncacheTable = Spark_Connect_UncacheTable()
+      uncacheTable.tableName = tableName
+      var catalog = Spark_Connect_Catalog()
+      catalog.uncacheTable = uncacheTable
+      return catalog
+    })
+    try await df.count()
+  }
+
+  /// Removes all cached tables from the in-memory cache.
+  public func clearCache() async throws {
+    let df = getDataFrame({
+      var catalog = Spark_Connect_Catalog()
+      catalog.clearCache_p = Spark_Connect_ClearCache()
+      return catalog
+    })
+    try await df.count()
+  }
 }
diff --git a/Sources/SparkConnect/SparkFileUtils.swift b/Sources/SparkConnect/SparkFileUtils.swift
index c91ee0c..1c7fa44 100644
--- a/Sources/SparkConnect/SparkFileUtils.swift
+++ b/Sources/SparkConnect/SparkFileUtils.swift
@@ -64,6 +64,7 @@ public enum SparkFileUtils {
   /// Create a directory given the abstract pathname
   /// - Parameter url: An URL location.
   /// - Returns: Return true if the directory is successfully created; otherwise, return false.
+  @discardableResult
   static func createDirectory(at url: URL) -> Bool {
     let fileManager = FileManager.default
     do {
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift
index 44562d5..afb0ea7 100644
--- a/Tests/SparkConnectTests/CatalogTests.swift
+++ b/Tests/SparkConnectTests/CatalogTests.swift
@@ -111,4 +111,108 @@ struct CatalogTests {
     await spark.stop()
   }
 #endif
+
+  @Test
+  func cacheTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.cacheTable(tableName, StorageLevel.MEMORY_ONLY)
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.cacheTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func isCached() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.isCached("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func refreshTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.refreshTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.refreshTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.refreshTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func refreshByPath() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.refreshByPath("/")
+      #expect(try await spark.catalog.isCached(tableName) == false)
+
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.refreshByPath("/")
+      #expect(try await spark.catalog.isCached(tableName))
+    })
+    await spark.stop()
+  }
+
+  @Test
+  func uncacheTable() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.uncacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName) == false)
+    })
+
+    try await #require(throws: Error.self) {
+      try await spark.catalog.uncacheTable("not_exist_table")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func clearCache() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    try await SQLHelper.withTable(spark, tableName)({
+      try await spark.range(1).write.saveAsTable(tableName)
+      try await spark.catalog.cacheTable(tableName)
+      #expect(try await spark.catalog.isCached(tableName))
+      try await spark.catalog.clearCache()
+      #expect(try await spark.catalog.isCached(tableName) == false)
+    })
+    await spark.stop()
+  }
 }

