This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new 49bb86c  [SPARK-51971] Improve `DataFrame.collect` to return the original values
49bb86c is described below

commit 49bb86c70148fab5e9c62e740441f3d3c02dce86
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Wed Apr 30 22:22:24 2025 -0700

    [SPARK-51971] Improve `DataFrame.collect` to return the original values
    
    ### What changes were proposed in this pull request?
    
    This PR aims to improve `DataFrame.collect` to return the original values.
    
    Note that this PR supports simple value types first. More types, such as
    `Decimal`, will be added later.
    
    ### Why are the changes needed?
    
    The initial implementation is limited to returning rows of `String`
    values.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, because there are no released versions yet.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #99 from dongjoon-hyun/SPARK-51971.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/Catalog.swift                 |  14 +-
 Sources/SparkConnect/DataFrame.swift               |  30 +++-
 Sources/SparkConnect/Row.swift                     |  16 +--
 Tests/SparkConnectTests/DataFrameTests.swift       | 151 +++++++++++----------
 .../Resources/queries/binary.sql.answer            |   2 +-
 Tests/SparkConnectTests/SQLTests.swift             |   2 +-
 Tests/SparkConnectTests/SparkSessionTests.swift    |   2 +-
 7 files changed, 117 insertions(+), 100 deletions(-)

diff --git a/Sources/SparkConnect/Catalog.swift 
b/Sources/SparkConnect/Catalog.swift
index 04bc29c..51f056d 100644
--- a/Sources/SparkConnect/Catalog.swift
+++ b/Sources/SparkConnect/Catalog.swift
@@ -252,7 +252,7 @@ public actor Catalog: Sendable {
       catalog.tableExists = tableExists
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 
   /// Check if the table or view with the specified name exists. This can 
either be a temporary
@@ -270,7 +270,7 @@ public actor Catalog: Sendable {
       catalog.tableExists = tableExists
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 
   /// Check if the function with the specified name exists. This can either be 
a temporary function
@@ -287,7 +287,7 @@ public actor Catalog: Sendable {
       catalog.functionExists = functionExists
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 
   /// Check if the function with the specified name exists in the specified 
database under the Hive
@@ -305,7 +305,7 @@ public actor Catalog: Sendable {
       catalog.functionExists = functionExists
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 
   /// Caches the specified table in-memory.
@@ -338,7 +338,7 @@ public actor Catalog: Sendable {
       catalog.isCached = isCached
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 
   /// Invalidates and refreshes all the cached data and metadata of the given 
table.
@@ -407,7 +407,7 @@ public actor Catalog: Sendable {
       catalog.dropTempView = dropTempView
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect().first!.getAsBool(0)
   }
 
   /// Drops the global temporary view with the given view name in the catalog. 
If the view has been
@@ -423,6 +423,6 @@ public actor Catalog: Sendable {
       catalog.dropGlobalTempView = dropGlobalTempView
       return catalog
     })
-    return "true" == (try await df.collect().first!.get(0) as! String)
+    return try await df.collect()[0].getAsBool(0)
   }
 }
diff --git a/Sources/SparkConnect/DataFrame.swift 
b/Sources/SparkConnect/DataFrame.swift
index d3eb909..5531917 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -208,14 +208,34 @@ public actor DataFrame: Sendable {
       for i in 0..<batch.length {
         var values: [Sendable?] = []
         for column in batch.columns {
-          let str = column.array as! AsString
           if column.data.isNull(i) {
             values.append(nil)
-          } else if column.data.type.info == ArrowType.ArrowBinary {
-            let binary = str.asString(i).utf8.map { String(format: "%02x", $0) 
}.joined(separator: " ")
-            values.append("[\(binary)]")
           } else {
-            values.append(str.asString(i))
+            let array = column.array
+            switch column.data.type.info {
+            case .primitiveInfo(.boolean):
+              values.append(array.asAny(i) as? Bool)
+            case .primitiveInfo(.int8):
+              values.append(array.asAny(i) as? Int8)
+            case .primitiveInfo(.int16):
+              values.append(array.asAny(i) as? Int16)
+            case .primitiveInfo(.int32):
+              values.append(array.asAny(i) as? Int32)
+            case .primitiveInfo(.int64):
+              values.append(array.asAny(i) as! Int64)
+            case .primitiveInfo(.float):
+              values.append(array.asAny(i) as? Float)
+            case .primitiveInfo(.double):
+              values.append(array.asAny(i) as? Double)
+            case .primitiveInfo(.date32):
+              values.append(array.asAny(i) as! Date)
+            case ArrowType.ArrowBinary:
+              values.append((array as! AsString).asString(i).utf8)
+            case .complexInfo(.strct):
+              values.append((array as! AsString).asString(i))
+            default:
+              values.append(array.asAny(i) as? String)
+            }
           }
         }
         result.append(Row(valueArray: values))
diff --git a/Sources/SparkConnect/Row.swift b/Sources/SparkConnect/Row.swift
index 0caf505..67cfcfd 100644
--- a/Sources/SparkConnect/Row.swift
+++ b/Sources/SparkConnect/Row.swift
@@ -50,6 +50,10 @@ public struct Row: Sendable, Equatable {
     return values[i]
   }
 
+  public func getAsBool(_ i: Int) throws -> Bool {
+    return try get(i) as! Bool
+  }
+
   public static func == (lhs: Row, rhs: Row) -> Bool {
     if lhs.values.count != rhs.values.count {
       return false
@@ -59,16 +63,8 @@ public struct Row: Sendable, Equatable {
         return true
       } else if let a = x as? Bool, let b = y as? Bool {
         return a == b
-      } else if let a = x as? Int, let b = y as? Int {
-        return a == b
-      } else if let a = x as? Int8, let b = y as? Int8 {
-        return a == b
-      } else if let a = x as? Int16, let b = y as? Int16 {
-        return a == b
-      } else if let a = x as? Int32, let b = y as? Int32 {
-        return a == b
-      } else if let a = x as? Int64, let b = y as? Int64 {
-        return a == b
+      } else if let a = x as? any FixedWidthInteger, let b = y as? any 
FixedWidthInteger {
+        return Int64(a) == Int64(b)
       } else if let a = x as? Float, let b = y as? Float {
         return a == b
       } else if let a = x as? Double, let b = y as? Double {
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift 
b/Tests/SparkConnectTests/DataFrameTests.swift
index bf320e5..13bc7c4 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -318,7 +318,7 @@ struct DataFrameTests {
   @Test
   func sort() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let expected = Array((1...10).map{ Row(String($0)) })
+    let expected = Array((1...10).map{ Row($0) })
     #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
     await spark.stop()
   }
@@ -326,7 +326,7 @@ struct DataFrameTests {
   @Test
   func orderBy() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let expected = Array((1...10).map{ Row(String($0)) })
+    let expected = Array((1...10).map{ Row($0) })
     #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == 
expected)
     await spark.stop()
   }
@@ -354,7 +354,7 @@ struct DataFrameTests {
     #expect(
       try await spark.sql(
         "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 
'def')"
-      ).collect() == [Row("1", "true", "abc"), Row(nil, nil, nil), Row("3", 
"false", "def")])
+      ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false, 
"def")])
     await spark.stop()
   }
 
@@ -371,10 +371,11 @@ struct DataFrameTests {
   func head() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     #expect(try await spark.range(0).head().isEmpty)
-    #expect(try await spark.range(2).sort("id").head() == [Row("0")])
-    #expect(try await spark.range(2).sort("id").head(1) == [Row("0")])
-    #expect(try await spark.range(2).sort("id").head(2) == [Row("0"), 
Row("1")])
-    #expect(try await spark.range(2).sort("id").head(3) == [Row("0"), 
Row("1")])
+    print(try await spark.range(2).sort("id").head())
+    #expect(try await spark.range(2).sort("id").head() == [Row(0)])
+    #expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
+    #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
+    #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
     await spark.stop()
   }
 
@@ -382,9 +383,9 @@ struct DataFrameTests {
   func tail() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     #expect(try await spark.range(0).tail(1).isEmpty)
-    #expect(try await spark.range(2).sort("id").tail(1) == [Row("1")])
-    #expect(try await spark.range(2).sort("id").tail(2) == [Row("0"), 
Row("1")])
-    #expect(try await spark.range(2).sort("id").tail(3) == [Row("0"), 
Row("1")])
+    #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)])
+    #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)])
+    #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)])
     await spark.stop()
   }
 
@@ -464,35 +465,35 @@ struct DataFrameTests {
   @Test
   func join() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') 
AS T(a, b)")
-    let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') 
AS S(c, b)")
+    let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS 
T(a, b)")
+    let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS 
S(c, b)")
     let expectedCross = [
-      Row("a", "1", "c", "2"),
-      Row("a", "1", "d", "3"),
-      Row("b", "2", "c", "2"),
-      Row("b", "2", "d", "3"),
+      Row("a", 1, "c", 2),
+      Row("a", 1, "d", 3),
+      Row("b", 2, "c", 2),
+      Row("b", 2, "d", 3),
     ]
     #expect(try await df1.join(df2).collect() == expectedCross)
     #expect(try await df1.crossJoin(df2).collect() == expectedCross)
 
-    #expect(try await df1.join(df2, "b").collect() == [Row("2", "b", "c")])
-    #expect(try await df1.join(df2, ["b"]).collect() == [Row("2", "b", "c")])
+    #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")])
+    #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")])
 
-    #expect(try await df1.join(df2, "b", "left").collect() == [Row("1", "a", 
nil), Row("2", "b", "c")])
-    #expect(try await df1.join(df2, "b", "right").collect() == [Row("2", "b", 
"c"), Row("3", nil, "d")])
-    #expect(try await df1.join(df2, "b", "semi").collect() == [Row("2", "b")])
-    #expect(try await df1.join(df2, "b", "anti").collect() == [Row("1", "a")])
+    #expect(try await df1.join(df2, "b", "left").collect() == [Row(1, "a", 
nil), Row(2, "b", "c")])
+    #expect(try await df1.join(df2, "b", "right").collect() == [Row(2, "b", 
"c"), Row(3, nil, "d")])
+    #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")])
+    #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")])
 
     let expectedOuter = [
-      Row("1", "a", nil),
-      Row("2", "b", "c"),
-      Row("3", nil, "d"),
+      Row(1, "a", nil),
+      Row(2, "b", "c"),
+      Row(3, nil, "d"),
     ]
     #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
     #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
     #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)
 
-    let expected = [Row("b", "2", "c", "2")]
+    let expected = [Row("b", 2, "c", 2)]
     #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == 
expected)
     #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: 
"inner").collect() == expected)
     await spark.stop()
@@ -502,18 +503,18 @@ struct DataFrameTests {
   func lateralJoin() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     if await spark.version.starts(with: "4.") {
-      let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', 
'2') AS T(a, b)")
-      let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', 
'3') AS S(c, b)")
+      let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) 
AS T(a, b)")
+      let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) 
AS S(c, b)")
       let expectedCross = [
-        Row("a", "1", "c", "2"),
-        Row("a", "1", "d", "3"),
-        Row("b", "2", "c", "2"),
-        Row("b", "2", "d", "3"),
+        Row("a", 1, "c", 2),
+        Row("a", 1, "d", 3),
+        Row("b", 2, "c", 2),
+        Row("b", 2, "d", 3),
       ]
       #expect(try await df1.lateralJoin(df2).collect() == expectedCross)
       #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == 
expectedCross)
 
-      let expected = [Row("b", "2", "c", "2")]
+      let expected = [Row("b", 2, "c", 2)]
       #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() 
== expected)
       #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: 
"inner").collect() == expected)
     }
@@ -525,8 +526,8 @@ struct DataFrameTests {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 3)
     #expect(try await df.except(spark.range(1, 5)).collect() == [])
-    #expect(try await df.except(spark.range(2, 5)).collect() == [Row("1")])
-    #expect(try await df.except(spark.range(3, 5)).collect() == [Row("1"), 
Row("2")])
+    #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)])
+    #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), 
Row(2)])
     #expect(try await spark.sql("SELECT * FROM VALUES 1, 
1").except(df).count() == 0)
     await spark.stop()
   }
@@ -536,8 +537,8 @@ struct DataFrameTests {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 3)
     #expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
-    #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
-    #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row("1"), 
Row("2")])
+    #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)])
+    #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), 
Row(2)])
     #expect(try await spark.sql("SELECT * FROM VALUES 1, 
1").exceptAll(df).count() == 1)
     await spark.stop()
   }
@@ -546,8 +547,8 @@ struct DataFrameTests {
   func intersect() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 3)
-    #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row("1"), 
Row("2")])
-    #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row("2")])
+    #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), 
Row(2)])
+    #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)])
     #expect(try await df.intersect(spark.range(3, 5)).collect() == [])
     let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
     #expect(try await df2.intersect(df2).count() == 1)
@@ -558,8 +559,8 @@ struct DataFrameTests {
   func intersectAll() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 3)
-    #expect(try await df.intersectAll(spark.range(1, 5)).collect() == 
[Row("1"), Row("2")])
-    #expect(try await df.intersectAll(spark.range(2, 5)).collect() == 
[Row("2")])
+    #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), 
Row(2)])
+    #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)])
     #expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
     let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
     #expect(try await df2.intersectAll(df2).count() == 2)
@@ -570,8 +571,8 @@ struct DataFrameTests {
   func union() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 2)
-    #expect(try await df.union(spark.range(1, 3)).collect() == [Row("1"), 
Row("1"), Row("2")])
-    #expect(try await df.union(spark.range(2, 3)).collect() == [Row("1"), 
Row("2")])
+    #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), 
Row(1), Row(2)])
+    #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), 
Row(2)])
     let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
     #expect(try await df2.union(df2).count() == 4)
     await spark.stop()
@@ -581,8 +582,8 @@ struct DataFrameTests {
   func unionAll() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     let df = try await spark.range(1, 2)
-    #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row("1"), 
Row("1"), Row("2")])
-    #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row("1"), 
Row("2")])
+    #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), 
Row(1), Row(2)])
+    #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), 
Row(2)])
     let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
     #expect(try await df2.unionAll(df2).count() == 4)
     await spark.stop()
@@ -593,8 +594,8 @@ struct DataFrameTests {
     let spark = try await SparkSession.builder.getOrCreate()
     let df1 = try await spark.sql("SELECT 1 a, 2 b")
     let df2 = try await spark.sql("SELECT 4 b, 3 a")
-    #expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), 
Row("3", "4")])
-    #expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", 
"3")])
+    #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)])
+    #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)])
     let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
     #expect(try await df3.unionByName(df3).count() == 4)
     await spark.stop()
@@ -650,7 +651,7 @@ struct DataFrameTests {
   func groupBy() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     let rows = try await spark.range(3).groupBy("id").agg("count(*)", 
"sum(*)", "avg(*)").collect()
-    #expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), 
Row("2", "1", "2", "2.0")])
+    #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)])
     await spark.stop()
   }
 
@@ -660,18 +661,18 @@ struct DataFrameTests {
     let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
       .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
     #expect(rows == [
-      Row("Dublin", "Honda Accord", "10"),
-      Row("Dublin", "Honda CRV", "3"),
-      Row("Dublin", "Honda Civic", "20"),
-      Row("Dublin", nil, "33"),
-      Row("Fremont", "Honda Accord", "15"),
-      Row("Fremont", "Honda CRV", "7"),
-      Row("Fremont", "Honda Civic", "10"),
-      Row("Fremont", nil, "32"),
-      Row("San Jose", "Honda Accord", "8"),
-      Row("San Jose", "Honda Civic", "5"),
-      Row("San Jose", nil, "13"),
-      Row(nil, nil, "78"),
+      Row("Dublin", "Honda Accord", 10),
+      Row("Dublin", "Honda CRV", 3),
+      Row("Dublin", "Honda Civic", 20),
+      Row("Dublin", nil, 33),
+      Row("Fremont", "Honda Accord", 15),
+      Row("Fremont", "Honda CRV", 7),
+      Row("Fremont", "Honda Civic", 10),
+      Row("Fremont", nil, 32),
+      Row("San Jose", "Honda Accord", 8),
+      Row("San Jose", "Honda Civic", 5),
+      Row("San Jose", nil, 13),
+      Row(nil, nil, 78),
     ])
     await spark.stop()
   }
@@ -682,21 +683,21 @@ struct DataFrameTests {
     let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
       .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
     #expect(rows == [
-      Row("Dublin", "Honda Accord", "10"),
-      Row("Dublin", "Honda CRV", "3"),
-      Row("Dublin", "Honda Civic", "20"),
-      Row("Dublin", nil, "33"),
-      Row("Fremont", "Honda Accord", "15"),
-      Row("Fremont", "Honda CRV", "7"),
-      Row("Fremont", "Honda Civic", "10"),
-      Row("Fremont", nil, "32"),
-      Row("San Jose", "Honda Accord", "8"),
-      Row("San Jose", "Honda Civic", "5"),
-      Row("San Jose", nil, "13"),
-      Row(nil, "Honda Accord", "33"),
-      Row(nil, "Honda CRV", "10"),
-      Row(nil, "Honda Civic", "35"),
-      Row(nil, nil, "78"),
+      Row("Dublin", "Honda Accord", 10),
+      Row("Dublin", "Honda CRV", 3),
+      Row("Dublin", "Honda Civic", 20),
+      Row("Dublin", nil, 33),
+      Row("Fremont", "Honda Accord", 15),
+      Row("Fremont", "Honda CRV", 7),
+      Row("Fremont", "Honda Civic", 10),
+      Row("Fremont", nil, 32),
+      Row("San Jose", "Honda Accord", 8),
+      Row("San Jose", "Honda Civic", 5),
+      Row("San Jose", nil, 13),
+      Row(nil, "Honda Accord", 33),
+      Row(nil, "Honda CRV", 10),
+      Row(nil, "Honda Civic", 35),
+      Row(nil, nil, 78),
     ])
     await spark.stop()
   }
diff --git a/Tests/SparkConnectTests/Resources/queries/binary.sql.answer 
b/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
index 0d42bcd..52085ed 100644
--- a/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
+++ b/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
@@ -1 +1 @@
-[[61 62 63]]
+[abc]
diff --git a/Tests/SparkConnectTests/SQLTests.swift 
b/Tests/SparkConnectTests/SQLTests.swift
index 498d3d2..f07c409 100644
--- a/Tests/SparkConnectTests/SQLTests.swift
+++ b/Tests/SparkConnectTests/SQLTests.swift
@@ -83,7 +83,7 @@ struct SQLTests {
     for name in try! fm.contentsOfDirectory(atPath: path).sorted() {
       guard name.hasSuffix(".sql") else { continue }
       print(name)
-      if queriesForSpark4Only.contains(name) {
+      if await !spark.version.starts(with: "4.") && 
queriesForSpark4Only.contains(name) {
         print("Skip query \(name) due to the difference between Spark 3 and 
4.")
         continue
       }
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift 
b/Tests/SparkConnectTests/SparkSessionTests.swift
index dd0c03a..2bc887e 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -92,7 +92,7 @@ struct SparkSessionTests {
     let spark = try await SparkSession.builder.getOrCreate()
     #expect(try await spark.time(spark.range(1000).count) == 1000)
 #if !os(Linux)
-    #expect(try await spark.time(spark.range(1).collect) == [Row("0")])
+    #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
     try await spark.time(spark.range(10).show)
 #endif
     await spark.stop()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to