This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new 49bb86c [SPARK-51971] Improve `DataFrame.collect` to return the original values
49bb86c is described below
commit 49bb86c70148fab5e9c62e740441f3d3c02dce86
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Wed Apr 30 22:22:24 2025 -0700
[SPARK-51971] Improve `DataFrame.collect` to return the original values
### What changes were proposed in this pull request?
This PR aims to improve `DataFrame.collect` to return the original values.
Note that this PR covers simple value types first; more types such as `Decimal` will be added later.
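For example, after this change a query comes back with its original value types. A minimal sketch mirroring the updated tests (assumes an async context and a reachable Spark Connect server):
```swift
import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.sql(
  "SELECT * FROM VALUES (1, true, 'abc'), (3, false, 'def')").collect()
// Rows now carry the original values (integers, booleans, strings)
// instead of their string renderings.
assert(rows == [Row(1, true, "abc"), Row(3, false, "def")])
await spark.stop()
```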
### Why are the changes needed?
The initial implementation had a limitation: it returned all row values as `String`s.
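For instance, the `Catalog` helpers below previously had to string-compare against `"true"`; with typed rows they can read the boolean directly. A sketch of the before/after shape, where `df` is a single-row, single-`Bool`-column `DataFrame` as in the diff below:
```swift
import SparkConnect

func checkExists(_ df: DataFrame) async throws -> Bool {
  // Before this patch, every collected value surfaced as a String:
  //   "true" == (try await df.collect().first!.get(0) as! String)
  // After, Row keeps the Bool and exposes it via the new helper:
  return try await df.collect()[0].getAsBool(0)
}
```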
### Does this PR introduce _any_ user-facing change?
No, because there are no released versions yet.
### How was this patch tested?
Pass the CIs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #99 from dongjoon-hyun/SPARK-51971.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/Catalog.swift | 14 +-
Sources/SparkConnect/DataFrame.swift | 30 +++-
Sources/SparkConnect/Row.swift | 16 +--
Tests/SparkConnectTests/DataFrameTests.swift | 151 +++++++++++----------
.../Resources/queries/binary.sql.answer | 2 +-
Tests/SparkConnectTests/SQLTests.swift | 2 +-
Tests/SparkConnectTests/SparkSessionTests.swift | 2 +-
7 files changed, 117 insertions(+), 100 deletions(-)
diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift
index 04bc29c..51f056d 100644
--- a/Sources/SparkConnect/Catalog.swift
+++ b/Sources/SparkConnect/Catalog.swift
@@ -252,7 +252,7 @@ public actor Catalog: Sendable {
catalog.tableExists = tableExists
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
/// Check if the table or view with the specified name exists. This can either be a temporary
@@ -270,7 +270,7 @@ public actor Catalog: Sendable {
catalog.tableExists = tableExists
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
/// Check if the function with the specified name exists. This can either be a temporary function
@@ -287,7 +287,7 @@ public actor Catalog: Sendable {
catalog.functionExists = functionExists
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
/// Check if the function with the specified name exists in the specified database under the Hive
@@ -305,7 +305,7 @@ public actor Catalog: Sendable {
catalog.functionExists = functionExists
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
/// Caches the specified table in-memory.
@@ -338,7 +338,7 @@ public actor Catalog: Sendable {
catalog.isCached = isCached
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
/// Invalidates and refreshes all the cached data and metadata of the given table.
@@ -407,7 +407,7 @@ public actor Catalog: Sendable {
catalog.dropTempView = dropTempView
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect().first!.getAsBool(0)
}
/// Drops the global temporary view with the given view name in the catalog. If the view has been
@@ -423,6 +423,6 @@ public actor Catalog: Sendable {
catalog.dropGlobalTempView = dropGlobalTempView
return catalog
})
- return "true" == (try await df.collect().first!.get(0) as! String)
+ return try await df.collect()[0].getAsBool(0)
}
}
diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index d3eb909..5531917 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -208,14 +208,34 @@ public actor DataFrame: Sendable {
for i in 0..<batch.length {
var values: [Sendable?] = []
for column in batch.columns {
- let str = column.array as! AsString
if column.data.isNull(i) {
values.append(nil)
- } else if column.data.type.info == ArrowType.ArrowBinary {
- let binary = str.asString(i).utf8.map { String(format: "%02x", $0) }.joined(separator: " ")
- values.append("[\(binary)]")
} else {
- values.append(str.asString(i))
+ let array = column.array
+ switch column.data.type.info {
+ case .primitiveInfo(.boolean):
+ values.append(array.asAny(i) as? Bool)
+ case .primitiveInfo(.int8):
+ values.append(array.asAny(i) as? Int8)
+ case .primitiveInfo(.int16):
+ values.append(array.asAny(i) as? Int16)
+ case .primitiveInfo(.int32):
+ values.append(array.asAny(i) as? Int32)
+ case .primitiveInfo(.int64):
+ values.append(array.asAny(i) as! Int64)
+ case .primitiveInfo(.float):
+ values.append(array.asAny(i) as? Float)
+ case .primitiveInfo(.double):
+ values.append(array.asAny(i) as? Double)
+ case .primitiveInfo(.date32):
+ values.append(array.asAny(i) as! Date)
+ case ArrowType.ArrowBinary:
+ values.append((array as! AsString).asString(i).utf8)
+ case .complexInfo(.strct):
+ values.append((array as! AsString).asString(i))
+ default:
+ values.append(array.asAny(i) as? String)
+ }
}
}
result.append(Row(valueArray: values))
diff --git a/Sources/SparkConnect/Row.swift b/Sources/SparkConnect/Row.swift
index 0caf505..67cfcfd 100644
--- a/Sources/SparkConnect/Row.swift
+++ b/Sources/SparkConnect/Row.swift
@@ -50,6 +50,10 @@ public struct Row: Sendable, Equatable {
return values[i]
}
+ public func getAsBool(_ i: Int) throws -> Bool {
+ return try get(i) as! Bool
+ }
+
public static func == (lhs: Row, rhs: Row) -> Bool {
if lhs.values.count != rhs.values.count {
return false
@@ -59,16 +63,8 @@ public struct Row: Sendable, Equatable {
return true
} else if let a = x as? Bool, let b = y as? Bool {
return a == b
- } else if let a = x as? Int, let b = y as? Int {
- return a == b
- } else if let a = x as? Int8, let b = y as? Int8 {
- return a == b
- } else if let a = x as? Int16, let b = y as? Int16 {
- return a == b
- } else if let a = x as? Int32, let b = y as? Int32 {
- return a == b
- } else if let a = x as? Int64, let b = y as? Int64 {
- return a == b
+ } else if let a = x as? any FixedWidthInteger, let b = y as? any FixedWidthInteger {
+ return Int64(a) == Int64(b)
} else if let a = x as? Float, let b = y as? Float {
return a == b
} else if let a = x as? Double, let b = y as? Double {
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index bf320e5..13bc7c4 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -318,7 +318,7 @@ struct DataFrameTests {
@Test
func sort() async throws {
let spark = try await SparkSession.builder.getOrCreate()
- let expected = Array((1...10).map{ Row(String($0)) })
+ let expected = Array((1...10).map{ Row($0) })
#expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
await spark.stop()
}
@@ -326,7 +326,7 @@ struct DataFrameTests {
@Test
func orderBy() async throws {
let spark = try await SparkSession.builder.getOrCreate()
- let expected = Array((1...10).map{ Row(String($0)) })
+ let expected = Array((1...10).map{ Row($0) })
#expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
await spark.stop()
}
@@ -354,7 +354,7 @@ struct DataFrameTests {
#expect(
try await spark.sql(
"SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false,
'def')"
- ).collect() == [Row("1", "true", "abc"), Row(nil, nil, nil), Row("3",
"false", "def")])
+ ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false,
"def")])
await spark.stop()
}
@@ -371,10 +371,11 @@ struct DataFrameTests {
func head() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(0).head().isEmpty)
- #expect(try await spark.range(2).sort("id").head() == [Row("0")])
- #expect(try await spark.range(2).sort("id").head(1) == [Row("0")])
- #expect(try await spark.range(2).sort("id").head(2) == [Row("0"),
Row("1")])
- #expect(try await spark.range(2).sort("id").head(3) == [Row("0"),
Row("1")])
+ print(try await spark.range(2).sort("id").head())
+ #expect(try await spark.range(2).sort("id").head() == [Row(0)])
+ #expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
+ #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
+ #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
await spark.stop()
}
@@ -382,9 +383,9 @@ struct DataFrameTests {
func tail() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(0).tail(1).isEmpty)
- #expect(try await spark.range(2).sort("id").tail(1) == [Row("1")])
- #expect(try await spark.range(2).sort("id").tail(2) == [Row("0"),
Row("1")])
- #expect(try await spark.range(2).sort("id").tail(3) == [Row("0"),
Row("1")])
+ #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)])
+ #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)])
+ #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)])
await spark.stop()
}
@@ -464,35 +465,35 @@ struct DataFrameTests {
@Test
func join() async throws {
let spark = try await SparkSession.builder.getOrCreate()
- let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
- let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
+ let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)")
+ let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)")
let expectedCross = [
- Row("a", "1", "c", "2"),
- Row("a", "1", "d", "3"),
- Row("b", "2", "c", "2"),
- Row("b", "2", "d", "3"),
+ Row("a", 1, "c", 2),
+ Row("a", 1, "d", 3),
+ Row("b", 2, "c", 2),
+ Row("b", 2, "d", 3),
]
#expect(try await df1.join(df2).collect() == expectedCross)
#expect(try await df1.crossJoin(df2).collect() == expectedCross)
- #expect(try await df1.join(df2, "b").collect() == [Row("2", "b", "c")])
- #expect(try await df1.join(df2, ["b"]).collect() == [Row("2", "b", "c")])
+ #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")])
+ #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")])
- #expect(try await df1.join(df2, "b", "left").collect() == [Row("1", "a",
nil), Row("2", "b", "c")])
- #expect(try await df1.join(df2, "b", "right").collect() == [Row("2", "b",
"c"), Row("3", nil, "d")])
- #expect(try await df1.join(df2, "b", "semi").collect() == [Row("2", "b")])
- #expect(try await df1.join(df2, "b", "anti").collect() == [Row("1", "a")])
+ #expect(try await df1.join(df2, "b", "left").collect() == [Row(1, "a",
nil), Row(2, "b", "c")])
+ #expect(try await df1.join(df2, "b", "right").collect() == [Row(2, "b",
"c"), Row(3, nil, "d")])
+ #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")])
+ #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")])
let expectedOuter = [
- Row("1", "a", nil),
- Row("2", "b", "c"),
- Row("3", nil, "d"),
+ Row(1, "a", nil),
+ Row(2, "b", "c"),
+ Row(3, nil, "d"),
]
#expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
#expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
#expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)
- let expected = [Row("b", "2", "c", "2")]
+ let expected = [Row("b", 2, "c", 2)]
#expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected)
#expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
await spark.stop()
@@ -502,18 +503,18 @@ struct DataFrameTests {
func lateralJoin() async throws {
let spark = try await SparkSession.builder.getOrCreate()
if await spark.version.starts(with: "4.") {
- let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
- let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
+ let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)")
+ let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)")
let expectedCross = [
- Row("a", "1", "c", "2"),
- Row("a", "1", "d", "3"),
- Row("b", "2", "c", "2"),
- Row("b", "2", "d", "3"),
+ Row("a", 1, "c", 2),
+ Row("a", 1, "d", 3),
+ Row("b", 2, "c", 2),
+ Row("b", 2, "d", 3),
]
#expect(try await df1.lateralJoin(df2).collect() == expectedCross)
#expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross)
- let expected = [Row("b", "2", "c", "2")]
+ let expected = [Row("b", 2, "c", 2)]
#expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected)
#expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
}
@@ -525,8 +526,8 @@ struct DataFrameTests {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.except(spark.range(1, 5)).collect() == [])
- #expect(try await df.except(spark.range(2, 5)).collect() == [Row("1")])
- #expect(try await df.except(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+ #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)])
+ #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), Row(2)])
#expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0)
await spark.stop()
}
@@ -536,8 +537,8 @@ struct DataFrameTests {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
- #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
- #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+ #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)])
+ #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), Row(2)])
#expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1)
await spark.stop()
}
@@ -546,8 +547,8 @@ struct DataFrameTests {
func intersect() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
- #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
- #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row("2")])
+ #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), Row(2)])
+ #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)])
#expect(try await df.intersect(spark.range(3, 5)).collect() == [])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.intersect(df2).count() == 1)
@@ -558,8 +559,8 @@ struct DataFrameTests {
func intersectAll() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
- #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
- #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row("2")])
+ #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), Row(2)])
+ #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)])
#expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.intersectAll(df2).count() == 2)
@@ -570,8 +571,8 @@ struct DataFrameTests {
func union() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 2)
- #expect(try await df.union(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
- #expect(try await df.union(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+ #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)])
+ #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), Row(2)])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.union(df2).count() == 4)
await spark.stop()
@@ -581,8 +582,8 @@ struct DataFrameTests {
func unionAll() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 2)
- #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
- #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+ #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)])
+ #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), Row(2)])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.unionAll(df2).count() == 4)
await spark.stop()
@@ -593,8 +594,8 @@ struct DataFrameTests {
let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT 1 a, 2 b")
let df2 = try await spark.sql("SELECT 4 b, 3 a")
- #expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), Row("3", "4")])
- #expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", "3")])
+ #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)])
+ #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)])
let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df3.unionByName(df3).count() == 4)
await spark.stop()
@@ -650,7 +651,7 @@ struct DataFrameTests {
func groupBy() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect()
- #expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), Row("2", "1", "2", "2.0")])
+ #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)])
await spark.stop()
}
@@ -660,18 +661,18 @@ struct DataFrameTests {
let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
.agg("sum(quantity) sum").orderBy("city", "car_model").collect()
#expect(rows == [
- Row("Dublin", "Honda Accord", "10"),
- Row("Dublin", "Honda CRV", "3"),
- Row("Dublin", "Honda Civic", "20"),
- Row("Dublin", nil, "33"),
- Row("Fremont", "Honda Accord", "15"),
- Row("Fremont", "Honda CRV", "7"),
- Row("Fremont", "Honda Civic", "10"),
- Row("Fremont", nil, "32"),
- Row("San Jose", "Honda Accord", "8"),
- Row("San Jose", "Honda Civic", "5"),
- Row("San Jose", nil, "13"),
- Row(nil, nil, "78"),
+ Row("Dublin", "Honda Accord", 10),
+ Row("Dublin", "Honda CRV", 3),
+ Row("Dublin", "Honda Civic", 20),
+ Row("Dublin", nil, 33),
+ Row("Fremont", "Honda Accord", 15),
+ Row("Fremont", "Honda CRV", 7),
+ Row("Fremont", "Honda Civic", 10),
+ Row("Fremont", nil, 32),
+ Row("San Jose", "Honda Accord", 8),
+ Row("San Jose", "Honda Civic", 5),
+ Row("San Jose", nil, 13),
+ Row(nil, nil, 78),
])
await spark.stop()
}
@@ -682,21 +683,21 @@ struct DataFrameTests {
let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
.agg("sum(quantity) sum").orderBy("city", "car_model").collect()
#expect(rows == [
- Row("Dublin", "Honda Accord", "10"),
- Row("Dublin", "Honda CRV", "3"),
- Row("Dublin", "Honda Civic", "20"),
- Row("Dublin", nil, "33"),
- Row("Fremont", "Honda Accord", "15"),
- Row("Fremont", "Honda CRV", "7"),
- Row("Fremont", "Honda Civic", "10"),
- Row("Fremont", nil, "32"),
- Row("San Jose", "Honda Accord", "8"),
- Row("San Jose", "Honda Civic", "5"),
- Row("San Jose", nil, "13"),
- Row(nil, "Honda Accord", "33"),
- Row(nil, "Honda CRV", "10"),
- Row(nil, "Honda Civic", "35"),
- Row(nil, nil, "78"),
+ Row("Dublin", "Honda Accord", 10),
+ Row("Dublin", "Honda CRV", 3),
+ Row("Dublin", "Honda Civic", 20),
+ Row("Dublin", nil, 33),
+ Row("Fremont", "Honda Accord", 15),
+ Row("Fremont", "Honda CRV", 7),
+ Row("Fremont", "Honda Civic", 10),
+ Row("Fremont", nil, 32),
+ Row("San Jose", "Honda Accord", 8),
+ Row("San Jose", "Honda Civic", 5),
+ Row("San Jose", nil, 13),
+ Row(nil, "Honda Accord", 33),
+ Row(nil, "Honda CRV", 10),
+ Row(nil, "Honda Civic", 35),
+ Row(nil, nil, 78),
])
await spark.stop()
}
diff --git a/Tests/SparkConnectTests/Resources/queries/binary.sql.answer b/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
index 0d42bcd..52085ed 100644
--- a/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
+++ b/Tests/SparkConnectTests/Resources/queries/binary.sql.answer
@@ -1 +1 @@
-[[61 62 63]]
+[abc]
diff --git a/Tests/SparkConnectTests/SQLTests.swift b/Tests/SparkConnectTests/SQLTests.swift
index 498d3d2..f07c409 100644
--- a/Tests/SparkConnectTests/SQLTests.swift
+++ b/Tests/SparkConnectTests/SQLTests.swift
@@ -83,7 +83,7 @@ struct SQLTests {
for name in try! fm.contentsOfDirectory(atPath: path).sorted() {
guard name.hasSuffix(".sql") else { continue }
print(name)
- if queriesForSpark4Only.contains(name) {
+ if await !spark.version.starts(with: "4.") && queriesForSpark4Only.contains(name) {
print("Skip query \(name) due to the difference between Spark 3 and 4.")
continue
}
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index dd0c03a..2bc887e 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -92,7 +92,7 @@ struct SparkSessionTests {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.time(spark.range(1000).count) == 1000)
#if !os(Linux)
- #expect(try await spark.time(spark.range(1).collect) == [Row("0")])
+ #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
try await spark.time(spark.range(10).show)
#endif
await spark.stop()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]