This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new 056c0cf [SPARK-52522] Reapply `swift format`
056c0cf is described below
commit 056c0cf55c0267666c143c501f3b8740ca10b09d
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Tue Jun 17 20:20:22 2025 -0700
[SPARK-52522] Reapply `swift format`
### What changes were proposed in this pull request?
This PR aims to re-apply `swift format` to all source files.
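One recurring change in the diff below is how `swift format` lays out a `switch` expression assigned to a binding: the expression moves onto its own indented line (see, e.g., `Decimal128Array` in `Sources/SparkConnect/ArrowArray.swift`). A minimal sketch of the resulting style, using a hypothetical `Kind` enum and `scale(of:)` function that are not part of this repository:

enum Kind {
  case decimal128(precision: Int32, scale: Int32)
  case other
}

func scale(of kind: Kind) -> Int32 {
  // Hypothetical example: after formatting, the switch expression starts on
  // its own indented line instead of continuing on the assignment line.
  let scale: Int32 =
    switch kind {
    case .decimal128(_, let scale):
      scale
    default:
      18
    }
  return scale
}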
### Why are the changes needed?
To tidy up the source code.
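Another recurring change is dropping the explicit `get { }` wrapper from read-only computed properties (see, e.g., `write` and `writeStream` in `Sources/SparkConnect/DataFrame.swift`). A minimal sketch of the shorthand form, using a hypothetical type that is not part of this repository:

struct Connection {
  let name: String

  // Hypothetical example: a read-only computed property with an implicit
  // getter, the form kept in place of an explicit `get { ... }` block.
  var label: String {
    "conn-\(name)"
  }
}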
### Does this PR introduce _any_ user-facing change?
No behavior change.
### How was this patch tested?
Pass the CIs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #200 from dongjoon-hyun/SPARK-52522.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/ArrowArray.swift | 13 +-
Sources/SparkConnect/ArrowArrayBuilder.swift | 3 +-
Sources/SparkConnect/ArrowType.swift | 4 +-
Sources/SparkConnect/Catalog.swift | 24 +-
Sources/SparkConnect/DataFrame.swift | 112 ++-
Sources/SparkConnect/DataFrameReader.swift | 4 +-
Sources/SparkConnect/DataFrameWriter.swift | 4 +-
Sources/SparkConnect/Extension.swift | 47 +-
Sources/SparkConnect/SparkConnectClient.swift | 130 +--
Sources/SparkConnect/SparkSession.swift | 22 +-
Tests/SparkConnectTests/CatalogTests.swift | 487 ++++-----
.../SparkConnectTests/DataFrameInternalTests.swift | 112 +--
Tests/SparkConnectTests/DataFrameReaderTests.swift | 13 +-
Tests/SparkConnectTests/DataFrameTests.swift | 1038 ++++++++++----------
Tests/SparkConnectTests/DataFrameWriterTests.swift | 3 +-
Tests/SparkConnectTests/SQLTests.swift | 59 +-
.../SparkConnectClientTests.swift | 3 +-
Tests/SparkConnectTests/SparkSessionTests.swift | 135 +--
18 files changed, 1152 insertions(+), 1061 deletions(-)
diff --git a/Sources/SparkConnect/ArrowArray.swift b/Sources/SparkConnect/ArrowArray.swift
index a767b6e..cc348ed 100644
--- a/Sources/SparkConnect/ArrowArray.swift
+++ b/Sources/SparkConnect/ArrowArray.swift
@@ -255,12 +255,13 @@ public class Decimal128Array: FixedArray<Decimal> {
if self.arrowData.isNull(index) {
return nil
}
- let scale: Int32 = switch self.arrowData.type.id {
- case .decimal128(_, let scale):
- scale
- default:
- 18
- }
+ let scale: Int32 =
+ switch self.arrowData.type.id {
+ case .decimal128(_, let scale):
+ scale
+ default:
+ 18
+ }
let byteOffset = self.arrowData.stride * Int(index)
let value = self.arrowData.buffers[1].rawPointer.advanced(by:
byteOffset).load(
as: UInt64.self)
diff --git a/Sources/SparkConnect/ArrowArrayBuilder.swift b/Sources/SparkConnect/ArrowArrayBuilder.swift
index 20b3f27..da7074c 100644
--- a/Sources/SparkConnect/ArrowArrayBuilder.swift
+++ b/Sources/SparkConnect/ArrowArrayBuilder.swift
@@ -122,7 +122,8 @@ public class Time64ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Time64>, T
}
}
-public class Decimal128ArrayBuilder:
ArrowArrayBuilder<FixedBufferBuilder<Decimal>, Decimal128Array> {
+public class Decimal128ArrayBuilder:
ArrowArrayBuilder<FixedBufferBuilder<Decimal>, Decimal128Array>
+{
fileprivate convenience init(precision: Int32, scale: Int32) throws {
try self.init(ArrowTypeDecimal128(precision: precision, scale: scale))
}
diff --git a/Sources/SparkConnect/ArrowType.swift b/Sources/SparkConnect/ArrowType.swift
index 39555f3..a617b3a 100644
--- a/Sources/SparkConnect/ArrowType.swift
+++ b/Sources/SparkConnect/ArrowType.swift
@@ -294,7 +294,7 @@ public class ArrowType {
case .double:
return MemoryLayout<Double>.stride
case .decimal128:
- return 16 // Decimal 128 (= 16 * 8) bits
+ return 16 // Decimal 128 (= 16 * 8) bits
case .boolean:
return MemoryLayout<Bool>.stride
case .date32:
@@ -429,7 +429,7 @@ extension ArrowType.Info: Equatable {
case (.timeInfo(let lhsId), .timeInfo(let rhsId)):
return lhsId == rhsId
case (.complexInfo(let lhsId), .complexInfo(let rhsId)):
- return lhsId == rhsId
+ return lhsId == rhsId
default:
return false
}
diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift
index 4f3c917..c1b23d4 100644
--- a/Sources/SparkConnect/Catalog.swift
+++ b/Sources/SparkConnect/Catalog.swift
@@ -40,15 +40,13 @@ public struct SparkTable: Sendable, Equatable {
public var tableType: String
public var isTemporary: Bool
public var database: String? {
- get {
- guard let namespace else {
- return nil
- }
- if namespace.count == 1 {
- return namespace[0]
- } else {
- return nil
- }
+ guard let namespace else {
+ return nil
+ }
+ if namespace.count == 1 {
+ return namespace[0]
+ } else {
+ return nil
}
}
}
@@ -173,7 +171,9 @@ public actor Catalog: Sendable {
return catalog
})
return try await df.collect().map {
- try Database(name: $0[0] as! String, catalog: $0[1] as? String,
description: $0[2] as? String, locationUri: $0[3] as! String)
+ try Database(
+ name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2]
as? String,
+ locationUri: $0[3] as! String)
}
}
@@ -189,7 +189,9 @@ public actor Catalog: Sendable {
return catalog
})
return try await df.collect().map {
- try Database(name: $0[0] as! String, catalog: $0[1] as? String,
description: $0[2] as? String, locationUri: $0[3] as! String)
+ try Database(
+ name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2]
as? String,
+ locationUri: $0[3] as! String)
}.first!
}
diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index db720ba..2f590de 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -494,11 +494,12 @@ public actor DataFrame: Sendable {
/// - Parameter cols: Column names
/// - Returns: A ``DataFrame`` with subset of columns.
public func toDF(_ cols: String...) -> DataFrame {
- let df = if cols.isEmpty {
- DataFrame(spark: self.spark, plan: self.plan)
- } else {
- DataFrame(spark: self.spark, plan:
SparkConnectClient.getProject(self.plan.root, cols))
- }
+ let df =
+ if cols.isEmpty {
+ DataFrame(spark: self.spark, plan: self.plan)
+ } else {
+ DataFrame(spark: self.spark, plan:
SparkConnectClient.getProject(self.plan.root, cols))
+ }
return df
}
@@ -507,7 +508,8 @@ public actor DataFrame: Sendable {
/// - Returns: A ``DataFrame`` with the given schema.
public func to(_ schema: String) async throws -> DataFrame {
let dataType = try await sparkSession.client.ddlParse(schema)
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getToSchema(self.plan.root, dataType))
+ return DataFrame(
+ spark: self.spark, plan: SparkConnectClient.getToSchema(self.plan.root,
dataType))
}
/// Returns the content of the Dataset as a Dataset of JSON strings.
@@ -520,7 +522,8 @@ public actor DataFrame: Sendable {
/// - Parameter exprs: Expression strings
/// - Returns: A ``DataFrame`` with subset of columns.
public func selectExpr(_ exprs: String...) -> DataFrame {
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getProjectExprs(self.plan.root, exprs))
+ return DataFrame(
+ spark: self.spark, plan:
SparkConnectClient.getProjectExprs(self.plan.root, exprs))
}
/// Returns a new Dataset with a column dropped. This is a no-op if schema
doesn't contain column name.
@@ -564,7 +567,8 @@ public actor DataFrame: Sendable {
/// - Parameter statistics: Statistics names.
/// - Returns: A ``DataFrame`` containing specified statistics.
public func summary(_ statistics: String...) -> DataFrame {
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getSummary(self.plan.root, statistics))
+ return DataFrame(
+ spark: self.spark, plan: SparkConnectClient.getSummary(self.plan.root,
statistics))
}
/// Returns a new Dataset with a column renamed. This is a no-op if schema
doesn't contain existingName.
@@ -583,14 +587,16 @@ public actor DataFrame: Sendable {
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colNames: [String], _ newColNames: [String])
-> DataFrame {
let dic = Dictionary(uniqueKeysWithValues: zip(colNames, newColNames))
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getWithColumnRenamed(self.plan.root, dic))
+ return DataFrame(
+ spark: self.spark, plan:
SparkConnectClient.getWithColumnRenamed(self.plan.root, dic))
}
/// Returns a new Dataset with columns renamed. This is a no-op if schema
doesn't contain existingName.
/// - Parameter colsMap: A dictionary of existing column name and new column
name.
/// - Returns: A ``DataFrame`` with the renamed columns.
public func withColumnRenamed(_ colsMap: [String: String]) -> DataFrame {
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap))
+ return DataFrame(
+ spark: self.spark, plan:
SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap))
}
/// Filters rows using the given condition.
@@ -611,7 +617,8 @@ public actor DataFrame: Sendable {
/// - Parameter conditionExpr: A SQL expression string for filtering
/// - Returns: A new DataFrame containing only rows that match the condition
public func filter(_ conditionExpr: String) -> DataFrame {
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getFilter(self.plan.root, conditionExpr))
+ return DataFrame(
+ spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root,
conditionExpr))
}
/// Filters rows using the given condition (alias for filter).
@@ -691,7 +698,9 @@ public actor DataFrame: Sendable {
/// - seed: Seed for sampling.
/// - Returns: A subset of the records.
public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed:
Int64) -> DataFrame {
- return DataFrame(spark: self.spark, plan:
SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed))
+ return DataFrame(
+ spark: self.spark,
+ plan: SparkConnectClient.getSample(self.plan.root, withReplacement,
fraction, seed))
}
/// Returns a new ``Dataset`` by sampling a fraction of rows, using a random
seed.
@@ -765,7 +774,7 @@ public actor DataFrame: Sendable {
/// - Parameter n: The number of rows.
/// - Returns: ``[Row]``
public func tail(_ n: Int32) async throws -> [Row] {
- let lastN = DataFrame(spark:spark, plan:
SparkConnectClient.getTail(self.plan.root, n))
+ let lastN = DataFrame(spark: spark, plan:
SparkConnectClient.getTail(self.plan.root, n))
return try await lastN.collect()
}
@@ -786,7 +795,8 @@ public actor DataFrame: Sendable {
public func isStreaming() async throws -> Bool {
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
- let response = try await
service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
+ let response = try await service.analyzePlan(
+ spark.client.getIsStreaming(spark.sessionID, plan))
return response.isStreaming.isStreaming
}
}
@@ -850,8 +860,10 @@ public actor DataFrame: Sendable {
get async throws {
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping:
client)
- return try await service
- .analyzePlan(spark.client.getStorageLevel(spark.sessionID,
plan)).getStorageLevel.storageLevel.toStorageLevel
+ return
+ try await service
+ .analyzePlan(spark.client.getStorageLevel(spark.sessionID,
plan)).getStorageLevel
+ .storageLevel.toStorageLevel
}
}
}
@@ -878,7 +890,7 @@ public actor DataFrame: Sendable {
/// Prints the plans (logical and physical) to the console for debugging
purposes.
/// - Parameter extended: If `false`, prints only the physical plan.
public func explain(_ extended: Bool) async throws {
- if (extended) {
+ if extended {
try await explain("extended")
} else {
try await explain("simple")
@@ -891,7 +903,8 @@ public actor DataFrame: Sendable {
public func explain(_ mode: String) async throws {
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
- let response = try await
service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
+ let response = try await service.analyzePlan(
+ spark.client.getExplain(spark.sessionID, plan, mode))
print(response.explain.explainString)
}
}
@@ -903,7 +916,8 @@ public actor DataFrame: Sendable {
public func inputFiles() async throws -> [String] {
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
- let response = try await
service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan))
+ let response = try await service.analyzePlan(
+ spark.client.getInputFiles(spark.sessionID, plan))
return response.inputFiles.files
}
}
@@ -918,7 +932,8 @@ public actor DataFrame: Sendable {
public func printSchema(_ level: Int32) async throws {
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
- let response = try await
service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level))
+ let response = try await service.analyzePlan(
+ spark.client.getTreeString(spark.sessionID, plan, level))
print(response.treeString.treeString)
}
}
@@ -964,7 +979,9 @@ public actor DataFrame: Sendable {
/// - usingColumn: Column name that exists in both DataFrames
/// - joinType: Type of join (default: "inner")
/// - Returns: A new DataFrame with the join result
- public func join(_ right: DataFrame, _ usingColumn: String, _ joinType:
String = "inner") async -> DataFrame {
+ public func join(_ right: DataFrame, _ usingColumn: String, _ joinType:
String = "inner") async
+ -> DataFrame
+ {
await join(right, [usingColumn], joinType)
}
@@ -974,7 +991,9 @@ public actor DataFrame: Sendable {
/// - usingColumn: Names of the columns to join on. These columns must
exist on both sides.
/// - joinType: A join type name.
/// - Returns: A `DataFrame`.
- public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType:
String = "inner") async -> DataFrame {
+ public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType:
String = "inner") async
+ -> DataFrame
+ {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(
self.plan.root,
@@ -1112,7 +1131,8 @@ public actor DataFrame: Sendable {
/// - Returns: A `DataFrame`.
public func exceptAll(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
- let plan = SparkConnectClient.getSetOperation(self.plan.root, right,
SetOpType.except, isAll: true)
+ let plan = SparkConnectClient.getSetOperation(
+ self.plan.root, right, SetOpType.except, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1132,7 +1152,8 @@ public actor DataFrame: Sendable {
/// - Returns: A `DataFrame`.
public func intersectAll(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
- let plan = SparkConnectClient.getSetOperation(self.plan.root, right,
SetOpType.intersect, isAll: true)
+ let plan = SparkConnectClient.getSetOperation(
+ self.plan.root, right, SetOpType.intersect, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1144,7 +1165,8 @@ public actor DataFrame: Sendable {
/// - Returns: A `DataFrame`.
public func union(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
- let plan = SparkConnectClient.getSetOperation(self.plan.root, right,
SetOpType.union, isAll: true)
+ let plan = SparkConnectClient.getSetOperation(
+ self.plan.root, right, SetOpType.union, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1164,7 +1186,9 @@ public actor DataFrame: Sendable {
/// of this `DataFrame` will be added at the end in the schema of the union
result
/// - Parameter other: A `DataFrame` to union with.
/// - Returns: A `DataFrame`.
- public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool =
false) async -> DataFrame {
+ public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool =
false) async
+ -> DataFrame
+ {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(
self.plan.root,
@@ -1182,8 +1206,11 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: plan)
}
- private func buildRepartitionByExpression(numPartitions: Int32?,
partitionExprs: [String]) -> DataFrame {
- let plan = SparkConnectClient.getRepartitionByExpression(self.plan.root,
partitionExprs, numPartitions)
+ private func buildRepartitionByExpression(numPartitions: Int32?,
partitionExprs: [String])
+ -> DataFrame
+ {
+ let plan = SparkConnectClient.getRepartitionByExpression(
+ self.plan.root, partitionExprs, numPartitions)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1211,7 +1238,8 @@ public actor DataFrame: Sendable {
/// - partitionExprs: The partition expression strings.
/// - Returns: A `DataFrame`.
public func repartition(_ numPartitions: Int32, _ partitionExprs: String...)
-> DataFrame {
- return buildRepartitionByExpression(numPartitions: numPartitions,
partitionExprs: partitionExprs)
+ return buildRepartitionByExpression(
+ numPartitions: numPartitions, partitionExprs: partitionExprs)
}
/// Returns a new ``DataFrame`` partitioned by the given partitioning
expressions, using
@@ -1219,8 +1247,11 @@ public actor DataFrame: Sendable {
/// partitioned.
/// - Parameter partitionExprs: The partition expression strings.
/// - Returns: A `DataFrame`.
- public func repartitionByExpression(_ numPartitions: Int32?, _
partitionExprs: String...) -> DataFrame {
- return buildRepartitionByExpression(numPartitions: numPartitions,
partitionExprs: partitionExprs)
+ public func repartitionByExpression(_ numPartitions: Int32?, _
partitionExprs: String...)
+ -> DataFrame
+ {
+ return buildRepartitionByExpression(
+ numPartitions: numPartitions, partitionExprs: partitionExprs)
}
/// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions,
when the fewer partitions
@@ -1322,7 +1353,8 @@ public actor DataFrame: Sendable {
_ variableColumnName: String,
_ valueColumnName: String
) -> DataFrame {
- let plan = SparkConnectClient.getUnpivot(self.plan.root, ids, values,
variableColumnName, valueColumnName)
+ let plan = SparkConnectClient.getUnpivot(
+ self.plan.root, ids, values, variableColumnName, valueColumnName)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1421,7 +1453,8 @@ public actor DataFrame: Sendable {
}
func createTempView(_ viewName: String, replace: Bool, global: Bool) async
throws {
- try await spark.client.createTempView(self.plan.root, viewName, replace:
replace, isGlobal: global)
+ try await spark.client.createTempView(
+ self.plan.root, viewName, replace: replace, isGlobal: global)
}
/// Eagerly checkpoint a ``DataFrame`` and return the new ``DataFrame``.
@@ -1439,7 +1472,8 @@ public actor DataFrame: Sendable {
_ reliableCheckpoint: Bool = true,
_ storageLevel: StorageLevel? = nil
) async throws -> DataFrame {
- let plan = try await spark.client.getCheckpoint(self.plan.root, eager,
reliableCheckpoint, storageLevel)
+ let plan = try await spark.client.getCheckpoint(
+ self.plan.root, eager, reliableCheckpoint, storageLevel)
return DataFrame(spark: self.spark, plan: plan)
}
@@ -1474,9 +1508,7 @@ public actor DataFrame: Sendable {
/// Returns a ``DataFrameWriter`` that can be used to write non-streaming
data.
public var write: DataFrameWriter {
- get {
- DataFrameWriter(df: self)
- }
+ DataFrameWriter(df: self)
}
/// Create a write configuration builder for v2 sources.
@@ -1485,7 +1517,7 @@ public actor DataFrame: Sendable {
public func writeTo(_ table: String) -> DataFrameWriterV2 {
return DataFrameWriterV2(table, self)
}
-
+
/// Merges a set of updates, insertions, and deletions based on a source
table into a target table.
/// - Parameters:
/// - table: A target table name.
@@ -1497,8 +1529,6 @@ public actor DataFrame: Sendable {
/// Returns a ``DataStreamWriter`` that can be used to write streaming data.
public var writeStream: DataStreamWriter {
- get {
- DataStreamWriter(df: self)
- }
+ DataStreamWriter(df: self)
}
}
diff --git a/Sources/SparkConnect/DataFrameReader.swift b/Sources/SparkConnect/DataFrameReader.swift
index 274efdf..9c2076e 100644
--- a/Sources/SparkConnect/DataFrameReader.swift
+++ b/Sources/SparkConnect/DataFrameReader.swift
@@ -261,7 +261,9 @@ public actor DataFrameReader: Sendable {
/// - table: The JDBC table that should be read from or written into.
/// - properties: A string-string dictionary for connection properties.
/// - Returns: A `DataFrame`.
- public func jdbc(_ url: String, _ table: String, _ properties: [String:
String] = [:]) -> DataFrame {
+ public func jdbc(_ url: String, _ table: String, _ properties: [String:
String] = [:])
+ -> DataFrame
+ {
for (key, value) in properties {
self.extraOptions[key] = value
}
diff --git a/Sources/SparkConnect/DataFrameWriter.swift b/Sources/SparkConnect/DataFrameWriter.swift
index 11a5fa8..38492ac 100644
--- a/Sources/SparkConnect/DataFrameWriter.swift
+++ b/Sources/SparkConnect/DataFrameWriter.swift
@@ -236,7 +236,9 @@ public actor DataFrameWriter: Sendable {
/// - url: The JDBC URL of the form `jdbc:subprotocol:subname` to connect
to.
/// - table: Name of the table in the external database.
/// - properties:JDBC database connection arguments, a list of arbitrary
string tag/value.
- public func jdbc(_ url: String, _ table: String, _ properties: [String:
String] = [:]) async throws {
+ public func jdbc(_ url: String, _ table: String, _ properties: [String:
String] = [:])
+ async throws
+ {
for (key, value) in properties {
self.extraOptions[key] = value
}
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index f7d869e..4307a94 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -135,14 +135,15 @@ extension String {
}
var toExplainMode: ExplainMode {
- let mode = switch self {
- case "codegen": ExplainMode.codegen
- case "cost": ExplainMode.cost
- case "extended": ExplainMode.extended
- case "formatted": ExplainMode.formatted
- case "simple": ExplainMode.simple
- default: ExplainMode.simple
- }
+ let mode =
+ switch self {
+ case "codegen": ExplainMode.codegen
+ case "cost": ExplainMode.cost
+ case "extended": ExplainMode.extended
+ case "formatted": ExplainMode.formatted
+ case "simple": ExplainMode.simple
+ default: ExplainMode.simple
+ }
return mode
}
@@ -220,13 +221,14 @@ extension YearMonthInterval {
func toString() throws -> String {
let startFieldName = try fieldToString(self.startField)
let endFieldName = try fieldToString(self.endField)
- let interval = if startFieldName == endFieldName {
- "interval \(startFieldName)"
- } else if startFieldName < endFieldName {
- "interval \(startFieldName) to \(endFieldName)"
- } else {
- throw SparkConnectError.InvalidType
- }
+ let interval =
+ if startFieldName == endFieldName {
+ "interval \(startFieldName)"
+ } else if startFieldName < endFieldName {
+ "interval \(startFieldName) to \(endFieldName)"
+ } else {
+ throw SparkConnectError.InvalidType
+ }
return interval
}
}
@@ -246,13 +248,14 @@ extension DayTimeInterval {
func toString() throws -> String {
let startFieldName = try fieldToString(self.startField)
let endFieldName = try fieldToString(self.endField)
- let interval = if startFieldName == endFieldName {
- "interval \(startFieldName)"
- } else if startFieldName < endFieldName {
- "interval \(startFieldName) to \(endFieldName)"
- } else {
- throw SparkConnectError.InvalidType
- }
+ let interval =
+ if startFieldName == endFieldName {
+ "interval \(startFieldName)"
+ } else if startFieldName < endFieldName {
+ "interval \(startFieldName) to \(endFieldName)"
+ } else {
+ throw SparkConnectError.InvalidType
+ }
return interval
}
}
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index c1c9bd1..208601f 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -44,11 +44,11 @@ public actor SparkConnectClient {
self.port = self.url.port ?? 15002
var token: String? = nil
let processInfo = ProcessInfo.processInfo
-#if os(macOS) || os(Linux)
- var userName = processInfo.environment["SPARK_USER"] ??
processInfo.userName
-#else
- var userName = processInfo.environment["SPARK_USER"] ?? ""
-#endif
+ #if os(macOS) || os(Linux)
+ var userName = processInfo.environment["SPARK_USER"] ??
processInfo.userName
+ #else
+ var userName = processInfo.environment["SPARK_USER"] ?? ""
+ #endif
for param in self.url.path.split(separator: ";").dropFirst().filter({
!$0.isEmpty }) {
let kv = param.split(separator: "=")
switch String(kv[0]).lowercased() {
@@ -109,9 +109,11 @@ public actor SparkConnectClient {
self.sessionID = sessionID
let service = SparkConnectService.Client(wrapping: client)
- let request = analyze(self.sessionID!, {
- return OneOf_Analyze.sparkVersion(AnalyzePlanRequest.SparkVersion())
- })
+ let request = analyze(
+ self.sessionID!,
+ {
+ return OneOf_Analyze.sparkVersion(AnalyzePlanRequest.SparkVersion())
+ })
let response = try await service.analyzePlan(request)
return response
}
@@ -193,7 +195,7 @@ public actor SparkConnectClient {
request.operation.opType = .unset(unset)
return request
}
-
+
/// Request the server to unset keys
/// - Parameter keys: An array of keys
/// - Returns: Always return true
@@ -263,11 +265,12 @@ public actor SparkConnectClient {
request.userContext = userContext
request.sessionID = self.sessionID!
let response = try await service.config(request)
- let result = if response.pairs[0].hasValue {
- response.pairs[0].value
- } else {
- value
- }
+ let result =
+ if response.pairs[0].hasValue {
+ response.pairs[0].value
+ } else {
+ value
+ }
return result
}
}
@@ -295,11 +298,12 @@ public actor SparkConnectClient {
request.userContext = userContext
request.sessionID = self.sessionID!
let response = try await service.config(request)
- let result: String? = if response.pairs[0].hasValue {
- response.pairs[0].value
- } else {
- nil
- }
+ let result: String? =
+ if response.pairs[0].hasValue {
+ response.pairs[0].value
+ } else {
+ nil
+ }
return result
}
}
@@ -414,11 +418,13 @@ public actor SparkConnectClient {
func getAnalyzePlanRequest(_ sessionID: String, _ plan: Plan) async
-> AnalyzePlanRequest
{
- return analyze(sessionID, {
- var schema = AnalyzePlanRequest.Schema()
- schema.plan = plan
- return OneOf_Analyze.schema(schema)
- })
+ return analyze(
+ sessionID,
+ {
+ var schema = AnalyzePlanRequest.Schema()
+ schema.plan = plan
+ return OneOf_Analyze.schema(schema)
+ })
}
private func analyze(_ sessionID: String, _ f: () -> OneOf_Analyze) ->
AnalyzePlanRequest {
@@ -456,8 +462,7 @@ public actor SparkConnectClient {
})
}
- func getStorageLevel(_ sessionID: String, _ plan: Plan) async ->
AnalyzePlanRequest
- {
+ func getStorageLevel(_ sessionID: String, _ plan: Plan) async ->
AnalyzePlanRequest {
return analyze(
sessionID,
{
@@ -467,8 +472,7 @@ public actor SparkConnectClient {
})
}
- func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async ->
AnalyzePlanRequest
- {
+ func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async ->
AnalyzePlanRequest {
return analyze(
sessionID,
{
@@ -479,8 +483,7 @@ public actor SparkConnectClient {
})
}
- func getInputFiles(_ sessionID: String, _ plan: Plan) async ->
AnalyzePlanRequest
- {
+ func getInputFiles(_ sessionID: String, _ plan: Plan) async ->
AnalyzePlanRequest {
return analyze(
sessionID,
{
@@ -670,7 +673,9 @@ public actor SparkConnectClient {
return plan
}
- static func getSample(_ child: Relation, _ withReplacement: Bool, _
fraction: Double, _ seed: Int64) -> Plan {
+ static func getSample(
+ _ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed:
Int64
+ ) -> Plan {
var sample = Sample()
sample.input = child
sample.withReplacement = withReplacement
@@ -762,9 +767,10 @@ public actor SparkConnectClient {
addArtifactsRequest.clientType = self.clientType
addArtifactsRequest.batch = batch
let request = addArtifactsRequest
- _ = try await service.addArtifacts(request:
StreamingClientRequest<Spark_Connect_AddArtifactsRequest> { x in
- try await x.write(contentsOf: [request])
- })
+ _ = try await service.addArtifacts(
+ request: StreamingClientRequest<Spark_Connect_AddArtifactsRequest> { x
in
+ try await x.write(contentsOf: [request])
+ })
}
}
@@ -846,11 +852,13 @@ public actor SparkConnectClient {
func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType {
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
- let request = analyze(self.sessionID!, {
- var ddlParse = AnalyzePlanRequest.DDLParse()
- ddlParse.ddlString = ddlString
- return OneOf_Analyze.ddlParse(ddlParse)
- })
+ let request = analyze(
+ self.sessionID!,
+ {
+ var ddlParse = AnalyzePlanRequest.DDLParse()
+ ddlParse.ddlString = ddlString
+ return OneOf_Analyze.ddlParse(ddlParse)
+ })
do {
let response = try await service.analyzePlan(request)
return response.ddlParse.parsed
@@ -871,11 +879,13 @@ public actor SparkConnectClient {
func jsonToDdl(_ jsonString: String) async throws -> String {
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
- let request = analyze(self.sessionID!, {
- var jsonToDDL = AnalyzePlanRequest.JsonToDDL()
- jsonToDDL.jsonString = jsonString
- return OneOf_Analyze.jsonToDdl(jsonToDDL)
- })
+ let request = analyze(
+ self.sessionID!,
+ {
+ var jsonToDDL = AnalyzePlanRequest.JsonToDDL()
+ jsonToDDL.jsonString = jsonString
+ return OneOf_Analyze.jsonToDdl(jsonToDDL)
+ })
let response = try await service.analyzePlan(request)
return response.jsonToDdl.ddlString
}
@@ -884,12 +894,14 @@ public actor SparkConnectClient {
func sameSemantics(_ plan: Plan, _ otherPlan: Plan) async throws -> Bool {
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
- let request = analyze(self.sessionID!, {
- var sameSemantics = AnalyzePlanRequest.SameSemantics()
- sameSemantics.targetPlan = plan
- sameSemantics.otherPlan = otherPlan
- return OneOf_Analyze.sameSemantics(sameSemantics)
- })
+ let request = analyze(
+ self.sessionID!,
+ {
+ var sameSemantics = AnalyzePlanRequest.SameSemantics()
+ sameSemantics.targetPlan = plan
+ sameSemantics.otherPlan = otherPlan
+ return OneOf_Analyze.sameSemantics(sameSemantics)
+ })
let response = try await service.analyzePlan(request)
return response.sameSemantics.result
}
@@ -898,11 +910,13 @@ public actor SparkConnectClient {
func semanticHash(_ plan: Plan) async throws -> Int32 {
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
- let request = analyze(self.sessionID!, {
- var semanticHash = AnalyzePlanRequest.SemanticHash()
- semanticHash.plan = plan
- return OneOf_Analyze.semanticHash(semanticHash)
- })
+ let request = analyze(
+ self.sessionID!,
+ {
+ var semanticHash = AnalyzePlanRequest.SemanticHash()
+ semanticHash.plan = plan
+ return OneOf_Analyze.semanticHash(semanticHash)
+ })
let response = try await service.analyzePlan(request)
return response.semanticHash.result
}
@@ -986,7 +1000,9 @@ public actor SparkConnectClient {
})
}
- static func getRepartition(_ child: Relation, _ numPartitions: Int32, _
shuffle: Bool = false) -> Plan {
+ static func getRepartition(_ child: Relation, _ numPartitions: Int32, _
shuffle: Bool = false)
+ -> Plan
+ {
var repartition = Repartition()
repartition.input = child
repartition.numPartitions = numPartitions
@@ -1064,7 +1080,7 @@ public actor SparkConnectClient {
literal.short = Int32(value)
case let value as Int32:
literal.integer = value
- case let value as Int64: // Hint parameter raises exceptions for Int64
+ case let value as Int64: // Hint parameter raises exceptions for Int64
literal.integer = Int32(value)
case let value as Int:
literal.integer = Int32(value)
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 7e7326c..5203cba 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -89,16 +89,14 @@ public actor SparkSession {
/// Interface through which the user may create, drop, alter or query
underlying databases, tables, functions etc.
public var catalog: Catalog {
- get {
- return Catalog(spark: self)
- }
+ return Catalog(spark: self)
}
/// Stop the current client.
public func stop() async {
await client.stop()
}
-
+
/// Returns a ``DataFrame`` with no rows or columns.
public var emptyDataFrame: DataFrame {
get async {
@@ -222,9 +220,7 @@ public actor SparkSession {
///
/// - Returns: A ``DataFrameReader`` instance configured for this session
public var read: DataFrameReader {
- get {
- DataFrameReader(sparkSession: self)
- }
+ DataFrameReader(sparkSession: self)
}
/// Returns a ``DataStreamReader`` that can be used to read streaming data
in as a ``DataFrame``.
@@ -239,9 +235,7 @@ public actor SparkSession {
///
/// - Returns: A ``DataFrameReader`` instance configured for this session
public var readStream: DataStreamReader {
- get {
- DataStreamReader(sparkSession: self)
- }
+ DataStreamReader(sparkSession: self)
}
/// Returns a ``DataFrame`` representing the specified table or view.
@@ -337,11 +331,11 @@ public actor SparkSession {
/// ```swift
/// // Add a tag for a specific operation
/// try await spark.addTag("etl_job_2024")
- ///
+ ///
/// // Perform operations that will be tagged
/// let df = try await spark.sql("SELECT * FROM source_table")
/// try await df.write.saveAsTable("processed_table")
- ///
+ ///
/// // Remove the tag when done
/// try await spark.removeTag("etl_job_2024")
/// ```
@@ -422,9 +416,7 @@ public actor SparkSession {
/// Returns a `StreamingQueryManager` that allows managing all the
`StreamingQuery`s active on
/// `this`.
public var streams: StreamingQueryManager {
- get {
- StreamingQueryManager(self)
- }
+ StreamingQueryManager(self)
}
/// This is defined as the return type of `SparkSession.sparkContext` method.
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift
index 0888fdd..24ae1f6 100644
--- a/Tests/SparkConnectTests/CatalogTests.swift
+++ b/Tests/SparkConnectTests/CatalogTests.swift
@@ -25,288 +25,297 @@ import Testing
/// A test suite for `Catalog`
@Suite(.serialized)
struct CatalogTests {
-#if !os(Linux)
- @Test
- func currentCatalog() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.catalog.currentCatalog() == "spark_catalog")
- await spark.stop()
- }
+ #if !os(Linux)
+ @Test
+ func currentCatalog() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.catalog.currentCatalog() == "spark_catalog")
+ await spark.stop()
+ }
- @Test
- func setCurrentCatalog() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await spark.catalog.setCurrentCatalog("spark_catalog")
- if await spark.version >= "4.0.0" {
- try await #require(throws: SparkConnectError.CatalogNotFound) {
- try await spark.catalog.setCurrentCatalog("not_exist_catalog")
- }
- } else {
- try await #require(throws: Error.self) {
- try await spark.catalog.setCurrentCatalog("not_exist_catalog")
+ @Test
+ func setCurrentCatalog() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await spark.catalog.setCurrentCatalog("spark_catalog")
+ if await spark.version >= "4.0.0" {
+ try await #require(throws: SparkConnectError.CatalogNotFound) {
+ try await spark.catalog.setCurrentCatalog("not_exist_catalog")
+ }
+ } else {
+ try await #require(throws: Error.self) {
+ try await spark.catalog.setCurrentCatalog("not_exist_catalog")
+ }
}
+ await spark.stop()
}
- await spark.stop()
- }
-
- @Test
- func listCatalogs() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name:
"spark_catalog")])
- #expect(try await spark.catalog.listCatalogs(pattern: "*") ==
[CatalogMetadata(name: "spark_catalog")])
- #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count
== 0)
- await spark.stop()
- }
- @Test
- func currentDatabase() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.catalog.currentDatabase() == "default")
- await spark.stop()
- }
+ @Test
+ func listCatalogs() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name:
"spark_catalog")])
+ #expect(
+ try await spark.catalog.listCatalogs(pattern: "*") == [
+ CatalogMetadata(name: "spark_catalog")
+ ])
+ #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count
== 0)
+ await spark.stop()
+ }
- @Test
- func setCurrentDatabase() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await spark.catalog.setCurrentDatabase("default")
- try await #require(throws: SparkConnectError.SchemaNotFound) {
- try await spark.catalog.setCurrentDatabase("not_exist_database")
+ @Test
+ func currentDatabase() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.catalog.currentDatabase() == "default")
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func listDatabases() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let dbs = try await spark.catalog.listDatabases()
- #expect(dbs.count == 1)
- #expect(dbs[0].name == "default")
- #expect(dbs[0].catalog == "spark_catalog")
- #expect(dbs[0].description == "default database")
- #expect(dbs[0].locationUri.hasSuffix("spark-warehouse"))
- #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
- #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count
== 0)
- await spark.stop()
- }
+ @Test
+ func setCurrentDatabase() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await spark.catalog.setCurrentDatabase("default")
+ try await #require(throws: SparkConnectError.SchemaNotFound) {
+ try await spark.catalog.setCurrentDatabase("not_exist_database")
+ }
+ await spark.stop()
+ }
- @Test
- func getDatabase() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let db = try await spark.catalog.getDatabase("default")
- #expect(db.name == "default")
- #expect(db.catalog == "spark_catalog")
- #expect(db.description == "default database")
- #expect(db.locationUri.hasSuffix("spark-warehouse"))
- try await #require(throws: SparkConnectError.SchemaNotFound) {
- try await spark.catalog.getDatabase("not_exist_database")
+ @Test
+ func listDatabases() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let dbs = try await spark.catalog.listDatabases()
+ #expect(dbs.count == 1)
+ #expect(dbs[0].name == "default")
+ #expect(dbs[0].catalog == "spark_catalog")
+ #expect(dbs[0].description == "default database")
+ #expect(dbs[0].locationUri.hasSuffix("spark-warehouse"))
+ #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs)
+ #expect(try await spark.catalog.listDatabases(pattern:
"non_exist").count == 0)
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func databaseExists() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.catalog.databaseExists("default"))
+ @Test
+ func getDatabase() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let db = try await spark.catalog.getDatabase("default")
+ #expect(db.name == "default")
+ #expect(db.catalog == "spark_catalog")
+ #expect(db.description == "default database")
+ #expect(db.locationUri.hasSuffix("spark-warehouse"))
+ try await #require(throws: SparkConnectError.SchemaNotFound) {
+ try await spark.catalog.getDatabase("not_exist_database")
+ }
+ await spark.stop()
+ }
- let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with:
"")
- #expect(try await spark.catalog.databaseExists(dbName) == false)
- try await SQLHelper.withDatabase(spark, dbName) ({
- try await spark.sql("CREATE DATABASE \(dbName)").count()
- #expect(try await spark.catalog.databaseExists(dbName))
- })
- #expect(try await spark.catalog.databaseExists(dbName) == false)
- await spark.stop()
- }
+ @Test
+ func databaseExists() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.catalog.databaseExists("default"))
+
+ let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ #expect(try await spark.catalog.databaseExists(dbName) == false)
+ try await SQLHelper.withDatabase(spark, dbName)({
+ try await spark.sql("CREATE DATABASE \(dbName)").count()
+ #expect(try await spark.catalog.databaseExists(dbName))
+ })
+ #expect(try await spark.catalog.databaseExists(dbName) == false)
+ await spark.stop()
+ }
- @Test
- func createTable() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTable(spark, tableName)({
- try await spark.range(1).write.orc("/tmp/\(tableName)")
- #expect(try await spark.catalog.createTable(tableName,
"/tmp/\(tableName)", source: "orc").count() == 1)
- #expect(try await spark.catalog.tableExists(tableName))
- })
- await spark.stop()
- }
+ @Test
+ func createTable() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of:
"-", with: "")
+ try await SQLHelper.withTable(spark, tableName)({
+ try await spark.range(1).write.orc("/tmp/\(tableName)")
+ #expect(
+ try await spark.catalog.createTable(tableName, "/tmp/\(tableName)",
source: "orc").count()
+ == 1)
+ #expect(try await spark.catalog.tableExists(tableName))
+ })
+ await spark.stop()
+ }
- @Test
- func tableExists() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTable(spark, tableName)({
- try await spark.range(1).write.parquet("/tmp/\(tableName)")
+ @Test
+ func tableExists() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of:
"-", with: "")
+ try await SQLHelper.withTable(spark, tableName)({
+ try await spark.range(1).write.parquet("/tmp/\(tableName)")
+ #expect(try await spark.catalog.tableExists(tableName) == false)
+ #expect(try await spark.catalog.createTable(tableName,
"/tmp/\(tableName)").count() == 1)
+ #expect(try await spark.catalog.tableExists(tableName))
+ #expect(try await spark.catalog.tableExists("default", tableName))
+ #expect(try await spark.catalog.tableExists("default2", tableName) ==
false)
+ })
#expect(try await spark.catalog.tableExists(tableName) == false)
- #expect(try await spark.catalog.createTable(tableName,
"/tmp/\(tableName)").count() == 1)
- #expect(try await spark.catalog.tableExists(tableName))
- #expect(try await spark.catalog.tableExists("default", tableName))
- #expect(try await spark.catalog.tableExists("default2", tableName) ==
false)
- })
- #expect(try await spark.catalog.tableExists(tableName) == false)
- try await #require(throws: SparkConnectError.ParseSyntaxError) {
- try await spark.catalog.tableExists("invalid table name")
+ try await #require(throws: SparkConnectError.ParseSyntaxError) {
+ try await spark.catalog.tableExists("invalid table name")
+ }
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func listColumns() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
+ @Test
+ func listColumns() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+
+ // Table
+ let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of:
"-", with: "")
+ let path = "/tmp/\(tableName)"
+ try await SQLHelper.withTable(spark, tableName)({
+ try await spark.range(2).write.orc(path)
+ let expected =
+ if await spark.version.starts(with: "4.") {
+ [Row("id", nil, "bigint", true, false, false, false)]
+ } else {
+ [Row("id", nil, "bigint", true, false, false)]
+ }
+ #expect(try await spark.catalog.createTable(tableName, path, source:
"orc").count() == 2)
+ #expect(try await spark.catalog.listColumns(tableName).collect() ==
expected)
+ #expect(try await
spark.catalog.listColumns("default.\(tableName)").collect() == expected)
+ })
+
+ // View
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withTempView(spark, viewName)({
+ try await spark.range(1).createTempView(viewName)
+ let expected =
+ if await spark.version.starts(with: "4.") {
+ [Row("id", nil, "bigint", false, false, false, false)]
+ } else {
+ [Row("id", nil, "bigint", false, false, false)]
+ }
+ #expect(try await spark.catalog.listColumns(viewName).collect() ==
expected)
+ })
+
+ await spark.stop()
+ }
- // Table
- let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- let path = "/tmp/\(tableName)"
- try await SQLHelper.withTable(spark, tableName)({
- try await spark.range(2).write.orc(path)
- let expected = if await spark.version.starts(with: "4.") {
- [Row("id", nil, "bigint", true, false, false, false)]
- } else {
- [Row("id", nil, "bigint", true, false, false)]
- }
- #expect(try await spark.catalog.createTable(tableName, path, source:
"orc").count() == 2)
- #expect(try await spark.catalog.listColumns(tableName).collect() ==
expected)
- #expect(try await
spark.catalog.listColumns("default.\(tableName)").collect() == expected)
- })
+ @Test
+ func functionExists() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.catalog.functionExists("base64"))
+ #expect(try await spark.catalog.functionExists("non_exist_function") ==
false)
- // View
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTempView(spark, viewName)({
- try await spark.range(1).createTempView(viewName)
- let expected = if await spark.version.starts(with: "4.") {
- [Row("id", nil, "bigint", false, false, false, false)]
- } else {
- [Row("id", nil, "bigint", false, false, false)]
+ try await #require(throws: SparkConnectError.ParseSyntaxError) {
+ try await spark.catalog.functionExists("invalid function name")
}
- #expect(try await spark.catalog.listColumns(viewName).collect() ==
expected)
- })
-
- await spark.stop()
- }
-
- @Test
- func functionExists() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.catalog.functionExists("base64"))
- #expect(try await spark.catalog.functionExists("non_exist_function") ==
false)
-
- try await #require(throws: SparkConnectError.ParseSyntaxError) {
- try await spark.catalog.functionExists("invalid function name")
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func createTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTempView(spark, viewName)({
- #expect(try await spark.catalog.tableExists(viewName) == false)
- try await spark.range(1).createTempView(viewName)
- #expect(try await spark.catalog.tableExists(viewName))
-
- try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) {
+ @Test
+ func createTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists(viewName) == false)
try await spark.range(1).createTempView(viewName)
+ #expect(try await spark.catalog.tableExists(viewName))
+
+ try await #require(throws: SparkConnectError.TableOrViewAlreadyExists)
{
+ try await spark.range(1).createTempView(viewName)
+ }
+ })
+
+ try await #require(throws: SparkConnectError.InvalidViewName) {
+ try await spark.range(1).createTempView("invalid view name")
}
- })
- try await #require(throws: SparkConnectError.InvalidViewName) {
- try await spark.range(1).createTempView("invalid view name")
+ await spark.stop()
}
- await spark.stop()
- }
-
- @Test
- func createOrReplaceTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTempView(spark, viewName)({
- #expect(try await spark.catalog.tableExists(viewName) == false)
- try await spark.range(1).createOrReplaceTempView(viewName)
- #expect(try await spark.catalog.tableExists(viewName))
- try await spark.range(1).createOrReplaceTempView(viewName)
- })
+ @Test
+ func createOrReplaceTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists(viewName) == false)
+ try await spark.range(1).createOrReplaceTempView(viewName)
+ #expect(try await spark.catalog.tableExists(viewName))
+ try await spark.range(1).createOrReplaceTempView(viewName)
+ })
+
+ try await #require(throws: SparkConnectError.InvalidViewName) {
+ try await spark.range(1).createOrReplaceTempView("invalid view name")
+ }
- try await #require(throws: SparkConnectError.InvalidViewName) {
- try await spark.range(1).createOrReplaceTempView("invalid view name")
+ await spark.stop()
}
- await spark.stop()
- }
+ @Test
+ func createGlobalTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withGlobalTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
+ try await spark.range(1).createGlobalTempView(viewName)
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
- @Test
- func createGlobalTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withGlobalTempView(spark, viewName)({
+ try await #require(throws: SparkConnectError.TableOrViewAlreadyExists)
{
+ try await spark.range(1).createGlobalTempView(viewName)
+ }
+ })
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
- try await spark.range(1).createGlobalTempView(viewName)
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
- try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) {
- try await spark.range(1).createGlobalTempView(viewName)
+ try await #require(throws: SparkConnectError.InvalidViewName) {
+ try await spark.range(1).createGlobalTempView("invalid view name")
}
- })
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") ==
false)
- try await #require(throws: SparkConnectError.InvalidViewName) {
- try await spark.range(1).createGlobalTempView("invalid view name")
+ await spark.stop()
}
- await spark.stop()
- }
-
- @Test
- func createOrReplaceGlobalTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withGlobalTempView(spark, viewName)({
+ @Test
+ func createOrReplaceGlobalTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withGlobalTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
+ try await spark.range(1).createOrReplaceGlobalTempView(viewName)
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
+ try await spark.range(1).createOrReplaceGlobalTempView(viewName)
+ })
#expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
- try await spark.range(1).createOrReplaceGlobalTempView(viewName)
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
- try await spark.range(1).createOrReplaceGlobalTempView(viewName)
- })
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") ==
false)
-
- try await #require(throws: SparkConnectError.InvalidViewName) {
- try await spark.range(1).createOrReplaceGlobalTempView("invalid view
name")
- }
- await spark.stop()
- }
+ try await #require(throws: SparkConnectError.InvalidViewName) {
+ try await spark.range(1).createOrReplaceGlobalTempView("invalid view
name")
+ }
- @Test
- func dropTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTempView(spark, viewName)({ #expect(try await
spark.catalog.tableExists(viewName) == false)
- try await spark.range(1).createTempView(viewName)
- try await spark.catalog.dropTempView(viewName)
- #expect(try await spark.catalog.tableExists(viewName) == false)
- })
+ await spark.stop()
+ }
- #expect(try await spark.catalog.dropTempView("non_exist_view") == false)
- #expect(try await spark.catalog.dropTempView("invalid view name") == false)
- await spark.stop()
- }
+ @Test
+ func dropTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists(viewName) == false)
+ try await spark.range(1).createTempView(viewName)
+ try await spark.catalog.dropTempView(viewName)
+ #expect(try await spark.catalog.tableExists(viewName) == false)
+ })
- @Test
- func dropGlobalTempView() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
- try await SQLHelper.withTempView(spark, viewName)({ #expect(try await
spark.catalog.tableExists(viewName) == false)
- try await spark.range(1).createGlobalTempView(viewName)
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
- try await spark.catalog.dropGlobalTempView(viewName)
- #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
- })
+ #expect(try await spark.catalog.dropTempView("non_exist_view") == false)
+ #expect(try await spark.catalog.dropTempView("invalid view name") ==
false)
+ await spark.stop()
+ }
- #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") ==
false)
- #expect(try await spark.catalog.dropGlobalTempView("invalid view name") ==
false)
- await spark.stop()
- }
-#endif
+ @Test
+ func dropGlobalTempView() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-",
with: "")
+ try await SQLHelper.withTempView(spark, viewName)({
+ #expect(try await spark.catalog.tableExists(viewName) == false)
+ try await spark.range(1).createGlobalTempView(viewName)
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)"))
+ try await spark.catalog.dropGlobalTempView(viewName)
+ #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")
== false)
+ })
+
+ #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") ==
false)
+ #expect(try await spark.catalog.dropGlobalTempView("invalid view name")
== false)
+ await spark.stop()
+ }
+ #endif
@Test
func cacheTable() async throws {
diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift
index 96e8fc2..6c843c3 100644
--- a/Tests/SparkConnectTests/DataFrameInternalTests.swift
+++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift
@@ -25,63 +25,63 @@ import Testing
@Suite(.serialized)
struct DataFrameInternalTests {
-#if !os(Linux)
- @Test
- func showString() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.range(10).showString(2, 0, false).collect()
- #expect(rows.count == 1)
- #expect(rows[0].length == 1)
- #expect(
- try (rows[0].get(0) as! String).trimmingCharacters(in:
.whitespacesAndNewlines) == """
- +---+
- |id |
- +---+
- |0 |
- |1 |
- +---+
- only showing top 2 rows
- """)
- await spark.stop()
- }
+ #if !os(Linux)
+ @Test
+ func showString() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.range(10).showString(2, 0, false).collect()
+ #expect(rows.count == 1)
+ #expect(rows[0].length == 1)
+ #expect(
+ try (rows[0].get(0) as! String).trimmingCharacters(in:
.whitespacesAndNewlines) == """
+ +---+
+ |id |
+ +---+
+ |0 |
+ |1 |
+ +---+
+ only showing top 2 rows
+ """)
+ await spark.stop()
+ }
- @Test
- func showStringTruncate() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'),
('ghi', 'jkl')")
- .showString(2, 2, false).collect()
- #expect(rows.count == 1)
- #expect(rows[0].length == 1)
- print(try rows[0].get(0) as! String)
- #expect(
- try rows[0].get(0) as! String == """
- +----+----+
- |col1|col2|
- +----+----+
- | ab| de|
- | gh| jk|
- +----+----+
+ @Test
+ func showStringTruncate() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'),
('ghi', 'jkl')")
+ .showString(2, 2, false).collect()
+ #expect(rows.count == 1)
+ #expect(rows[0].length == 1)
+ print(try rows[0].get(0) as! String)
+ #expect(
+ try rows[0].get(0) as! String == """
+ +----+----+
+ |col1|col2|
+ +----+----+
+ | ab| de|
+ | gh| jk|
+ +----+----+
- """)
- await spark.stop()
- }
+ """)
+ await spark.stop()
+ }
- @Test
- func showStringVertical() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.range(10).showString(2, 0, true).collect()
- #expect(rows.count == 1)
- #expect(rows[0].length == 1)
- print(try rows[0].get(0) as! String)
- #expect(
- try (rows[0].get(0) as! String).trimmingCharacters(in:
.whitespacesAndNewlines) == """
- -RECORD 0--
- id | 0
- -RECORD 1--
- id | 1
- only showing top 2 rows
- """)
- await spark.stop()
- }
-#endif
+ @Test
+ func showStringVertical() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.range(10).showString(2, 0, true).collect()
+ #expect(rows.count == 1)
+ #expect(rows[0].length == 1)
+ print(try rows[0].get(0) as! String)
+ #expect(
+ try (rows[0].get(0) as! String).trimmingCharacters(in:
.whitespacesAndNewlines) == """
+ -RECORD 0--
+ id | 0
+ -RECORD 1--
+ id | 1
+ only showing top 2 rows
+ """)
+ await spark.stop()
+ }
+ #endif
}
diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift
b/Tests/SparkConnectTests/DataFrameReaderTests.swift
index 0dfd04b..bcee038 100644
--- a/Tests/SparkConnectTests/DataFrameReaderTests.swift
+++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift
@@ -18,9 +18,8 @@
//
import Foundation
-import Testing
-
import SparkConnect
+import Testing
/// A test suite for `DataFrameReader`
@Suite(.serialized)
@@ -95,8 +94,14 @@ struct DataFrameReaderTests {
let path = "../examples/src/main/resources/people.json"
#expect(try await spark.read.schema("age SHORT").json(path).dtypes.count
== 1)
#expect(try await spark.read.schema("age SHORT").json(path).dtypes[0] ==
("age", "smallint"))
- #expect(try await spark.read.schema("age SHORT, name
STRING").json(path).dtypes[0] == ("age", "smallint"))
- #expect(try await spark.read.schema("age SHORT, name
STRING").json(path).dtypes[1] == ("name", "string"))
+ #expect(
+ try await spark.read.schema("age SHORT, name
STRING").json(path).dtypes[0] == (
+ "age", "smallint"
+ ))
+ #expect(
+ try await spark.read.schema("age SHORT, name
STRING").json(path).dtypes[1] == (
+ "name", "string"
+ ))
await spark.stop()
}
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift
b/Tests/SparkConnectTests/DataFrameTests.swift
index f5c6eeb..2edd5f8 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -18,9 +18,8 @@
//
import Foundation
-import Testing
-
import SparkConnect
+import Testing
/// A test suite for `DataFrame`
@Suite(.serialized)
@@ -70,19 +69,21 @@ struct DataFrameTests {
let spark = try await SparkSession.builder.getOrCreate()
let schema1 = try await spark.sql("SELECT 'a' as col1").schema
- let answer1 = if await spark.version.starts(with: "4.") {
-
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
- } else {
- #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}}]}}"#
- }
+ let answer1 =
+ if await spark.version.starts(with: "4.") {
+
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
+ } else {
+ #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}}]}}"#
+ }
#expect(schema1 == answer1)
let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema
- let answer2 = if await spark.version.starts(with: "4.") {
-
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
- } else {
-
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}},{"name":"col2","dataType":{"string":{}}}]}}"#
- }
+ let answer2 =
+ if await spark.version.starts(with: "4.") {
+
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"#
+ } else {
+
#"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}},{"name":"col2","dataType":{"string":{}}}]}}"#
+ }
#expect(schema2 == answer2)
let emptySchema = try await spark.sql("DROP TABLE IF EXISTS
nonexistent").schema
@@ -208,14 +209,14 @@ struct DataFrameTests {
let schema1 = try await spark.range(1).to("shortID SHORT").schema
#expect(
schema1
- ==
#"{"struct":{"fields":[{"name":"shortID","dataType":{"short":{}},"nullable":true}]}}"#
+ ==
#"{"struct":{"fields":[{"name":"shortID","dataType":{"short":{}},"nullable":true}]}}"#
)
let schema2 = try await spark.sql("SELECT '1'").to("id INT").schema
print(schema2)
#expect(
schema2
- ==
#"{"struct":{"fields":[{"name":"id","dataType":{"integer":{}},"nullable":true}]}}"#
+ ==
#"{"struct":{"fields":[{"name":"id","dataType":{"integer":{}},"nullable":true}]}}"#
)
await spark.stop()
@@ -344,23 +345,23 @@ struct DataFrameTests {
await spark.stop()
}
-#if !os(Linux)
- @Test
- func sort() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let expected = Array((1...10).map{ Row($0) })
- #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
- await spark.stop()
- }
+ #if !os(Linux)
+ @Test
+ func sort() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let expected = Array((1...10).map { Row($0) })
+ #expect(try await spark.range(10, 0, -1).sort("id").collect() ==
expected)
+ await spark.stop()
+ }
- @Test
- func orderBy() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let expected = Array((1...10).map{ Row($0) })
- #expect(try await spark.range(10, 0, -1).orderBy("id").collect() ==
expected)
- await spark.stop()
- }
-#endif
+ @Test
+ func orderBy() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let expected = Array((1...10).map { Row($0) })
+ #expect(try await spark.range(10, 0, -1).orderBy("id").collect() ==
expected)
+ await spark.stop()
+ }
+ #endif
@Test
func table() async throws {
@@ -376,204 +377,167 @@ struct DataFrameTests {
await spark.stop()
}
-#if !os(Linux)
- @Test
- func collect() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(0).collect().isEmpty)
- #expect(
- try await spark.sql(
- "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false,
'def')"
- ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false,
"def")])
- await spark.stop()
- }
-
- @Test
- func collectMultiple() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1)
- #expect(try await df.collect().count == 1)
- #expect(try await df.collect().count == 1)
- await spark.stop()
- }
+ #if !os(Linux)
+ @Test
+ func collect() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(0).collect().isEmpty)
+ #expect(
+ try await spark.sql(
+ "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3,
false, 'def')"
+ ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false,
"def")])
+ await spark.stop()
+ }
- @Test
- func first() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(2).sort("id").first() == Row(0))
- #expect(try await spark.range(2).sort("id").head() == Row(0))
- await spark.stop()
- }
+ @Test
+ func collectMultiple() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1)
+ #expect(try await df.collect().count == 1)
+ #expect(try await df.collect().count == 1)
+ await spark.stop()
+ }
- @Test
- func head() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(0).head(1).isEmpty)
- #expect(try await spark.range(2).sort("id").head() == Row(0))
- #expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
- #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
- #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
- await spark.stop()
- }
+ @Test
+ func first() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(2).sort("id").first() == Row(0))
+ #expect(try await spark.range(2).sort("id").head() == Row(0))
+ await spark.stop()
+ }
- @Test
- func take() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(0).take(1).isEmpty)
- #expect(try await spark.range(2).sort("id").take(1) == [Row(0)])
- #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)])
- #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)])
- await spark.stop()
- }
+ @Test
+ func head() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(0).head(1).isEmpty)
+ #expect(try await spark.range(2).sort("id").head() == Row(0))
+ #expect(try await spark.range(2).sort("id").head(1) == [Row(0)])
+ #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)])
+ #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)])
+ await spark.stop()
+ }
- @Test
- func tail() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(0).tail(1).isEmpty)
- #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)])
- #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)])
- #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)])
- await spark.stop()
- }
+ @Test
+ func take() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(0).take(1).isEmpty)
+ #expect(try await spark.range(2).sort("id").take(1) == [Row(0)])
+ #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)])
+ #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)])
+ await spark.stop()
+ }
- @Test
- func show() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await spark.sql("SHOW TABLES").show()
- try await spark.sql("SELECT * FROM VALUES (true, false)").show()
- try await spark.sql("SELECT * FROM VALUES (1, 2)").show()
- try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi',
'jkl')").show()
-
- // Check all signatures
- try await spark.range(1000).show()
- try await spark.range(1000).show(1)
- try await spark.range(1000).show(true)
- try await spark.range(1000).show(false)
- try await spark.range(1000).show(1, true)
- try await spark.range(1000).show(1, false)
- try await spark.range(1000).show(1, 20)
- try await spark.range(1000).show(1, 20, true)
- try await spark.range(1000).show(1, 20, false)
+ @Test
+ func tail() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(0).tail(1).isEmpty)
+ #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)])
+ #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)])
+ #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)])
+ await spark.stop()
+ }
- await spark.stop()
- }
+ @Test
+ func show() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await spark.sql("SHOW TABLES").show()
+ try await spark.sql("SELECT * FROM VALUES (true, false)").show()
+ try await spark.sql("SELECT * FROM VALUES (1, 2)").show()
+ try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi',
'jkl')").show()
+
+ // Check all signatures
+ try await spark.range(1000).show()
+ try await spark.range(1000).show(1)
+ try await spark.range(1000).show(true)
+ try await spark.range(1000).show(false)
+ try await spark.range(1000).show(1, true)
+ try await spark.range(1000).show(1, false)
+ try await spark.range(1000).show(1, 20)
+ try await spark.range(1000).show(1, 20, true)
+ try await spark.range(1000).show(1, 20, false)
+
+ await spark.stop()
+ }
- @Test
- func showNull() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await spark.sql(
- "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false,
'def')"
- ).show()
- await spark.stop()
- }
+ @Test
+ func showNull() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await spark.sql(
+ "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false,
'def')"
+ ).show()
+ await spark.stop()
+ }
- @Test
- func showCommand() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await spark.sql("DROP TABLE IF EXISTS t").show()
- await spark.stop()
- }
+ @Test
+ func showCommand() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await spark.sql("DROP TABLE IF EXISTS t").show()
+ await spark.stop()
+ }
- @Test
- func cache() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(10).cache().count() == 10)
- await spark.stop()
- }
+ @Test
+ func cache() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(10).cache().count() == 10)
+ await spark.stop()
+ }
- @Test
- func checkpoint() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- if await spark.version >= "4.0.0" {
- // By default, reliable checkpoint location is required.
- try await #require(throws: Error.self) {
- try await spark.range(10).checkpoint()
+ @Test
+ func checkpoint() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ if await spark.version >= "4.0.0" {
+ // By default, reliable checkpoint location is required.
+ try await #require(throws: Error.self) {
+ try await spark.range(10).checkpoint()
+ }
+ // Checkpointing with unreliable checkpoint
+ let df = try await spark.range(10).checkpoint(true, false)
+ #expect(try await df.count() == 10)
}
- // Checkpointing with unreliable checkpoint
- let df = try await spark.range(10).checkpoint(true, false)
- #expect(try await df.count() == 10)
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func localCheckpoint() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- if await spark.version >= "4.0.0" {
- #expect(try await spark.range(10).localCheckpoint().count() == 10)
+ @Test
+ func localCheckpoint() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ if await spark.version >= "4.0.0" {
+ #expect(try await spark.range(10).localCheckpoint().count() == 10)
+ }
+ await spark.stop()
}
- await spark.stop()
- }
-
- @Test
- func persist() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(try await spark.range(20).persist().count() == 20)
- #expect(try await spark.range(21).persist(storageLevel:
StorageLevel.MEMORY_ONLY).count() == 21)
- await spark.stop()
- }
- @Test
- func persistInvalidStorageLevel() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- try await #require(throws: Error.self) {
- var invalidLevel = StorageLevel.DISK_ONLY
- invalidLevel.replication = 0
- try await spark.range(9999).persist(storageLevel: invalidLevel).count()
+ @Test
+ func persist() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(try await spark.range(20).persist().count() == 20)
+ #expect(
+ try await spark.range(21).persist(storageLevel:
StorageLevel.MEMORY_ONLY).count() == 21)
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func unpersist() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(30)
- #expect(try await df.persist().count() == 30)
- #expect(try await df.unpersist().count() == 30)
- await spark.stop()
- }
-
- @Test
- func join() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS
T(a, b)")
- let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS
S(c, b)")
- let expectedCross = [
- Row("a", 1, "c", 2),
- Row("a", 1, "d", 3),
- Row("b", 2, "c", 2),
- Row("b", 2, "d", 3),
- ]
- #expect(try await df1.join(df2).collect() == expectedCross)
- #expect(try await df1.crossJoin(df2).collect() == expectedCross)
-
- #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")])
- #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")])
-
- #expect(try await df1.join(df2, "b", "left").collect() == [Row(1, "a",
nil), Row(2, "b", "c")])
- #expect(try await df1.join(df2, "b", "right").collect() == [Row(2, "b",
"c"), Row(3, nil, "d")])
- #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")])
- #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")])
-
- let expectedOuter = [
- Row(1, "a", nil),
- Row(2, "b", "c"),
- Row(3, nil, "d"),
- ]
- #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
- #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
- #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)
+ @Test
+ func persistInvalidStorageLevel() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ try await #require(throws: Error.self) {
+ var invalidLevel = StorageLevel.DISK_ONLY
+ invalidLevel.replication = 0
+ try await spark.range(9999).persist(storageLevel: invalidLevel).count()
+ }
+ await spark.stop()
+ }
- let expected = [Row("b", 2, "c", 2)]
- #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() ==
expected)
- #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType:
"inner").collect() == expected)
- await spark.stop()
- }
+ @Test
+ func unpersist() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(30)
+ #expect(try await df.persist().count() == 30)
+ #expect(try await df.unpersist().count() == 30)
+ await spark.stop()
+ }
- @Test
- func lateralJoin() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- if await spark.version.starts(with: "4.") {
+ @Test
+ func join() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2)
AS T(a, b)")
let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3)
AS S(c, b)")
let expectedCross = [
@@ -582,337 +546,393 @@ struct DataFrameTests {
Row("b", 2, "c", 2),
Row("b", 2, "d", 3),
]
- #expect(try await df1.lateralJoin(df2).collect() == expectedCross)
- #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() ==
expectedCross)
+ #expect(try await df1.join(df2).collect() == expectedCross)
+ #expect(try await df1.crossJoin(df2).collect() == expectedCross)
+
+ #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")])
+ #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")])
+
+ #expect(
+ try await df1.join(df2, "b", "left").collect() == [Row(1, "a", nil),
Row(2, "b", "c")])
+ #expect(
+ try await df1.join(df2, "b", "right").collect() == [Row(2, "b", "c"),
Row(3, nil, "d")])
+ #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")])
+ #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")])
+
+ let expectedOuter = [
+ Row(1, "a", nil),
+ Row(2, "b", "c"),
+ Row(3, nil, "d"),
+ ]
+ #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
+ #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
+ #expect(try await df1.join(df2, ["b"], "full").collect() ==
expectedOuter)
let expected = [Row("b", 2, "c", 2)]
- #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect()
== expected)
- #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType:
"inner").collect() == expected)
+ #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() ==
expected)
+ #expect(
+ try await df1.join(df2, joinExprs: "T.b = S.b", joinType:
"inner").collect() == expected)
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func except() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 3)
- #expect(try await df.except(spark.range(1, 5)).collect() == [])
- #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)])
- #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1),
Row(2)])
- #expect(try await spark.sql("SELECT * FROM VALUES 1,
1").except(df).count() == 0)
- await spark.stop()
- }
+ @Test
+ func lateralJoin() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ if await spark.version.starts(with: "4.") {
+ let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2)
AS T(a, b)")
+ let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3)
AS S(c, b)")
+ let expectedCross = [
+ Row("a", 1, "c", 2),
+ Row("a", 1, "d", 3),
+ Row("b", 2, "c", 2),
+ Row("b", 2, "d", 3),
+ ]
+ #expect(try await df1.lateralJoin(df2).collect() == expectedCross)
+ #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() ==
expectedCross)
+
+ let expected = [Row("b", 2, "c", 2)]
+ #expect(try await df1.lateralJoin(df2, joinExprs: "T.b =
S.b").collect() == expected)
+ #expect(
+ try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType:
"inner").collect()
+ == expected)
+ }
+ await spark.stop()
+ }
- @Test
- func exceptAll() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 3)
- #expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
- #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)])
- #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1),
Row(2)])
- #expect(try await spark.sql("SELECT * FROM VALUES 1,
1").exceptAll(df).count() == 1)
- await spark.stop()
- }
+ @Test
+ func except() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 3)
+ #expect(try await df.except(spark.range(1, 5)).collect() == [])
+ #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)])
+ #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1),
Row(2)])
+ #expect(try await spark.sql("SELECT * FROM VALUES 1,
1").except(df).count() == 0)
+ await spark.stop()
+ }
- @Test
- func intersect() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 3)
- #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1),
Row(2)])
- #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)])
- #expect(try await df.intersect(spark.range(3, 5)).collect() == [])
- let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
- #expect(try await df2.intersect(df2).count() == 1)
- await spark.stop()
- }
+ @Test
+ func exceptAll() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 3)
+ #expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
+ #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)])
+ #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1),
Row(2)])
+ #expect(try await spark.sql("SELECT * FROM VALUES 1,
1").exceptAll(df).count() == 1)
+ await spark.stop()
+ }
- @Test
- func intersectAll() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 3)
- #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1),
Row(2)])
- #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)])
- #expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
- let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
- #expect(try await df2.intersectAll(df2).count() == 2)
- await spark.stop()
- }
+ @Test
+ func intersect() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 3)
+ #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1),
Row(2)])
+ #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)])
+ #expect(try await df.intersect(spark.range(3, 5)).collect() == [])
+ let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+ #expect(try await df2.intersect(df2).count() == 1)
+ await spark.stop()
+ }
- @Test
- func union() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 2)
- #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1),
Row(1), Row(2)])
- #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1),
Row(2)])
- let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
- #expect(try await df2.union(df2).count() == 4)
- await spark.stop()
- }
+ @Test
+ func intersectAll() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 3)
+ #expect(try await df.intersectAll(spark.range(1, 5)).collect() ==
[Row(1), Row(2)])
+ #expect(try await df.intersectAll(spark.range(2, 5)).collect() ==
[Row(2)])
+ #expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
+ let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+ #expect(try await df2.intersectAll(df2).count() == 2)
+ await spark.stop()
+ }
- @Test
- func unionAll() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(1, 2)
- #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1),
Row(1), Row(2)])
- #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1),
Row(2)])
- let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
- #expect(try await df2.unionAll(df2).count() == 4)
- await spark.stop()
- }
+ @Test
+ func union() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 2)
+ #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1),
Row(1), Row(2)])
+ #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1),
Row(2)])
+ let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+ #expect(try await df2.union(df2).count() == 4)
+ await spark.stop()
+ }
- @Test
- func unionByName() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df1 = try await spark.sql("SELECT 1 a, 2 b")
- let df2 = try await spark.sql("SELECT 4 b, 3 a")
- #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)])
- #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)])
- let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
- #expect(try await df3.unionByName(df3).count() == 4)
- await spark.stop()
- }
+ @Test
+ func unionAll() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(1, 2)
+ #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1),
Row(1), Row(2)])
+ #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1),
Row(2)])
+ let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+ #expect(try await df2.unionAll(df2).count() == 4)
+ await spark.stop()
+ }
- @Test
- func repartition() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let tmpDir = "/tmp/" + UUID().uuidString
- let df = try await spark.range(2025)
- for n in [1, 3, 5] as [Int32] {
- try await df.repartition(n).write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
- }
- try await
spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
- await spark.stop()
- }
+ @Test
+ func unionByName() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df1 = try await spark.sql("SELECT 1 a, 2 b")
+ let df2 = try await spark.sql("SELECT 4 b, 3 a")
+ #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3,
4)])
+ #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)])
+ let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+ #expect(try await df3.unionByName(df3).count() == 4)
+ await spark.stop()
+ }
- @Test
- func repartitionByExpression() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let tmpDir = "/tmp/" + UUID().uuidString
- let df = try await spark.range(2025)
- for n in [1, 3, 5] as [Int32] {
- try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
- try await df.repartitionByExpression(n,
"id").write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
- }
- try await spark.range(1).repartition(10,
"id").write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
- try await
spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
- await spark.stop()
- }
+ @Test
+ func repartition() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let tmpDir = "/tmp/" + UUID().uuidString
+ let df = try await spark.range(2025)
+ for n in [1, 3, 5] as [Int32] {
+ try await df.repartition(n).write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+ }
+ try await
spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+ await spark.stop()
+ }
- @Test
- func coalesce() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let tmpDir = "/tmp/" + UUID().uuidString
- let df = try await spark.range(2025)
- for n in [1, 2, 3] as [Int32] {
- try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
- }
- try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir)
- #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
- await spark.stop()
- }
+ @Test
+ func repartitionByExpression() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let tmpDir = "/tmp/" + UUID().uuidString
+ let df = try await spark.range(2025)
+ for n in [1, 3, 5] as [Int32] {
+ try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+ try await df.repartitionByExpression(n,
"id").write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+ }
+ try await spark.range(1).repartition(10,
"id").write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+ try await
spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+ await spark.stop()
+ }
- @Test
- func distinct() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3)
T(a)")
- #expect(try await df.distinct().count() == 3)
- await spark.stop()
- }
+ @Test
+ func coalesce() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let tmpDir = "/tmp/" + UUID().uuidString
+ let df = try await spark.range(2025)
+ for n in [1, 2, 3] as [Int32] {
+ try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+ }
+ try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir)
+ #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+ await spark.stop()
+ }
- @Test
- func dropDuplicates() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3)
T(a)")
- #expect(try await df.dropDuplicates().count() == 3)
- #expect(try await df.dropDuplicates("a").count() == 3)
- await spark.stop()
- }
+ @Test
+ func distinct() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1),
(3) T(a)")
+ #expect(try await df.distinct().count() == 3)
+ await spark.stop()
+ }
- @Test
- func dropDuplicatesWithinWatermark() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3)
T(a)")
- #expect(try await df.dropDuplicatesWithinWatermark().count() == 3)
- #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3)
- await spark.stop()
- }
+ @Test
+ func dropDuplicates() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1),
(3) T(a)")
+ #expect(try await df.dropDuplicates().count() == 3)
+ #expect(try await df.dropDuplicates("a").count() == 3)
+ await spark.stop()
+ }
- @Test
- func withWatermark() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark
- .sql("""
- SELECT * FROM VALUES
- (1, now()),
- (1, now() - INTERVAL 1 HOUR),
- (1, now() - INTERVAL 2 HOUR)
- T(data, eventTime)
- """)
- .withWatermark("eventTime", "1 minute") // This tests only API for now
- #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1)
- await spark.stop()
- }
+ @Test
+ func dropDuplicatesWithinWatermark() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1),
(3) T(a)")
+ #expect(try await df.dropDuplicatesWithinWatermark().count() == 3)
+ #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3)
+ await spark.stop()
+ }
- @Test
- func describe() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(10)
- let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"),
Row("0"), Row("9")]
- #expect(try await df.describe().select("id").collect() == expected)
- #expect(try await df.describe("id").select("id").collect() == expected)
- await spark.stop()
- }
+ @Test
+ func withWatermark() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df =
+ try await spark
+ .sql(
+ """
+ SELECT * FROM VALUES
+ (1, now()),
+ (1, now() - INTERVAL 1 HOUR),
+ (1, now() - INTERVAL 2 HOUR)
+ T(data, eventTime)
+ """
+ )
+ .withWatermark("eventTime", "1 minute") // This tests only API for now
+ #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1)
+ await spark.stop()
+ }
- @Test
- func summary() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let expected = [
- Row("10"), Row("4.5"), Row("3.0276503540974917"),
- Row("0"), Row("2"), Row("4"), Row("7"), Row("9")
- ]
- #expect(try await spark.range(10).summary().select("id").collect() ==
expected)
- #expect(try await spark.range(10).summary("min",
"max").select("id").collect() == [Row("0"), Row("9")])
- await spark.stop()
- }
+ @Test
+ func describe() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(10)
+ let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"),
Row("0"), Row("9")]
+ #expect(try await df.describe().select("id").collect() == expected)
+ #expect(try await df.describe("id").select("id").collect() == expected)
+ await spark.stop()
+ }
- @Test
- func groupBy() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.range(3).groupBy("id").agg("count(*)",
"sum(*)", "avg(*)").collect()
- #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)])
- await spark.stop()
- }
+ @Test
+ func summary() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let expected = [
+ Row("10"), Row("4.5"), Row("3.0276503540974917"),
+ Row("0"), Row("2"), Row("4"), Row("7"), Row("9"),
+ ]
+ #expect(try await spark.range(10).summary().select("id").collect() ==
expected)
+ #expect(
+ try await spark.range(10).summary("min", "max").select("id").collect()
== [
+ Row("0"), Row("9"),
+ ])
+ await spark.stop()
+ }
- @Test
- func rollup() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
- .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
- #expect(rows == [
- Row("Dublin", "Honda Accord", 10),
- Row("Dublin", "Honda CRV", 3),
- Row("Dublin", "Honda Civic", 20),
- Row("Dublin", nil, 33),
- Row("Fremont", "Honda Accord", 15),
- Row("Fremont", "Honda CRV", 7),
- Row("Fremont", "Honda Civic", 10),
- Row("Fremont", nil, 32),
- Row("San Jose", "Honda Accord", 8),
- Row("San Jose", "Honda Civic", 5),
- Row("San Jose", nil, 13),
- Row(nil, nil, 78),
- ])
- await spark.stop()
- }
+ @Test
+ func groupBy() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.range(3).groupBy("id").agg("count(*)",
"sum(*)", "avg(*)")
+ .collect()
+ #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2,
2.0)])
+ await spark.stop()
+ }
- @Test
- func cube() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
- .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
- #expect(rows == [
- Row("Dublin", "Honda Accord", 10),
- Row("Dublin", "Honda CRV", 3),
- Row("Dublin", "Honda Civic", 20),
- Row("Dublin", nil, 33),
- Row("Fremont", "Honda Accord", 15),
- Row("Fremont", "Honda CRV", 7),
- Row("Fremont", "Honda Civic", 10),
- Row("Fremont", nil, 32),
- Row("San Jose", "Honda Accord", 8),
- Row("San Jose", "Honda Civic", 5),
- Row("San Jose", nil, 13),
- Row(nil, "Honda Accord", 33),
- Row(nil, "Honda CRV", 10),
- Row(nil, "Honda Civic", 35),
- Row(nil, nil, 78),
- ])
- await spark.stop()
- }
+ @Test
+ func rollup() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
+ .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+ #expect(
+ rows == [
+ Row("Dublin", "Honda Accord", 10),
+ Row("Dublin", "Honda CRV", 3),
+ Row("Dublin", "Honda Civic", 20),
+ Row("Dublin", nil, 33),
+ Row("Fremont", "Honda Accord", 15),
+ Row("Fremont", "Honda CRV", 7),
+ Row("Fremont", "Honda Civic", 10),
+ Row("Fremont", nil, 32),
+ Row("San Jose", "Honda Accord", 8),
+ Row("San Jose", "Honda Civic", 5),
+ Row("San Jose", nil, 13),
+ Row(nil, nil, 78),
+ ])
+ await spark.stop()
+ }
- @Test
- func toJSON() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.range(2).toJSON()
- #expect(try await df.columns == ["to_json(struct(id))"])
- #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")])
+ @Test
+ func cube() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
+ .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+ #expect(
+ rows == [
+ Row("Dublin", "Honda Accord", 10),
+ Row("Dublin", "Honda CRV", 3),
+ Row("Dublin", "Honda Civic", 20),
+ Row("Dublin", nil, 33),
+ Row("Fremont", "Honda Accord", 15),
+ Row("Fremont", "Honda CRV", 7),
+ Row("Fremont", "Honda Civic", 10),
+ Row("Fremont", nil, 32),
+ Row("San Jose", "Honda Accord", 8),
+ Row("San Jose", "Honda Civic", 5),
+ Row("San Jose", nil, 13),
+ Row(nil, "Honda Accord", 33),
+ Row(nil, "Honda CRV", 10),
+ Row(nil, "Honda Civic", 35),
+ Row(nil, nil, 78),
+ ])
+ await spark.stop()
+ }
- let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")]
- #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() ==
expected)
- await spark.stop()
- }
+ @Test
+ func toJSON() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.range(2).toJSON()
+ #expect(try await df.columns == ["to_json(struct(id))"])
+ #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")])
- @Test
- func unpivot() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.sql(
- """
- SELECT * FROM
- VALUES (1, 11, 12L),
- (2, 21, 22L)
- T(id, int, long)
- """)
- let expected = [
- Row(1, "int", 11),
- Row(1, "long", 12),
- Row(2, "int", 21),
- Row(2, "long", 22),
- ]
- #expect(try await df.unpivot(["id"], ["int", "long"], "variable",
"value").collect() == expected)
- #expect(try await df.melt(["id"], ["int", "long"], "variable",
"value").collect() == expected)
- await spark.stop()
- }
+ let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")]
+ #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect()
== expected)
+ await spark.stop()
+ }
- @Test
- func transpose() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- if await spark.version.starts(with: "4.") {
- #expect(try await spark.range(1).transpose().columns == ["key", "0"])
- #expect(try await spark.range(1).transpose().count() == 0)
-
+ @Test
+ func unpivot() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql(
- """
- SELECT * FROM
- VALUES ('A', 1, 2),
- ('B', 3, 4)
- T(id, val1, val2)
- """)
+ """
+ SELECT * FROM
+ VALUES (1, 11, 12L),
+ (2, 21, 22L)
+ T(id, int, long)
+ """)
let expected = [
- Row("val1", 1, 3),
- Row("val2", 2, 4),
+ Row(1, "int", 11),
+ Row(1, "long", 12),
+ Row(2, "int", 21),
+ Row(2, "long", 22),
]
- #expect(try await df.transpose().collect() == expected)
- #expect(try await df.transpose("id").collect() == expected)
+ #expect(
+ try await df.unpivot(["id"], ["int", "long"], "variable",
"value").collect() == expected)
+ #expect(
+ try await df.melt(["id"], ["int", "long"], "variable",
"value").collect() == expected)
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func decimal() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let df = try await spark.sql(
- """
- SELECT * FROM VALUES
- (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)),
- (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL))
- """)
- #expect(try await df.dtypes.map { $0.1 } ==
- ["decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)"])
- let expected = [
- Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)),
- Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil)
- ]
- #expect(try await df.collect() == expected)
- await spark.stop()
- }
-#endif
+ @Test
+ func transpose() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ if await spark.version.starts(with: "4.") {
+ #expect(try await spark.range(1).transpose().columns == ["key", "0"])
+ #expect(try await spark.range(1).transpose().count() == 0)
+
+ let df = try await spark.sql(
+ """
+ SELECT * FROM
+ VALUES ('A', 1, 2),
+ ('B', 3, 4)
+ T(id, val1, val2)
+ """)
+ let expected = [
+ Row("val1", 1, 3),
+ Row("val2", 2, 4),
+ ]
+ #expect(try await df.transpose().collect() == expected)
+ #expect(try await df.transpose("id").collect() == expected)
+ }
+ await spark.stop()
+ }
+
+ @Test
+ func decimal() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql(
+ """
+ SELECT * FROM VALUES
+ (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)),
+ (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL))
+ """)
+ #expect(
+ try await df.dtypes.map { $0.1 } == [
+ "decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)",
+ ])
+ let expected = [
+ Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)),
+ Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil),
+ ]
+ #expect(try await df.collect() == expected)
+ await spark.stop()
+ }
+ #endif
@Test
func storageLevel() async throws {
diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift
b/Tests/SparkConnectTests/DataFrameWriterTests.swift
index 5228667..7e91a30 100644
--- a/Tests/SparkConnectTests/DataFrameWriterTests.swift
+++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift
@@ -18,9 +18,8 @@
//
import Foundation
-import Testing
-
import SparkConnect
+import Testing
/// A test suite for `DataFrameWriter`
@Suite(.serialized)
diff --git a/Tests/SparkConnectTests/SQLTests.swift
b/Tests/SparkConnectTests/SQLTests.swift
index 5c5efb2..808c27b 100644
--- a/Tests/SparkConnectTests/SQLTests.swift
+++ b/Tests/SparkConnectTests/SQLTests.swift
@@ -27,7 +27,8 @@ import Testing
struct SQLTests {
let fm = FileManager.default
let path = Bundle.module.path(forResource: "queries", ofType: "")!
- let regenerateGoldenFiles =
ProcessInfo.processInfo.environment["SPARK_GENERATE_GOLDEN_FILES"] == "1"
+ let regenerateGoldenFiles =
+ ProcessInfo.processInfo.environment["SPARK_GENERATE_GOLDEN_FILES"] == "1"
let regexID = /#\d+L?/
let regexPlanId = /plan_id=\d+/
@@ -90,35 +91,39 @@ struct SQLTests {
"variant.sql",
]
-#if !os(Linux)
- @Test
- func runAll() async throws {
- let spark = try await SparkSession.builder.getOrCreate()
- let MAX = Int32.max
- for name in try! fm.contentsOfDirectory(atPath: path).sorted() {
- guard name.hasSuffix(".sql") else { continue }
- print(name)
- if await !spark.version.starts(with: "4.") &&
queriesForSpark4Only.contains(name) {
- print("Skip query \(name) due to the difference between Spark 3 and
4.")
- continue
- }
+ #if !os(Linux)
+ @Test
+ func runAll() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let MAX = Int32.max
+ for name in try! fm.contentsOfDirectory(atPath: path).sorted() {
+ guard name.hasSuffix(".sql") else { continue }
+ print(name)
+ if await !spark.version.starts(with: "4.") &&
queriesForSpark4Only.contains(name) {
+ print("Skip query \(name) due to the difference between Spark 3 and
4.")
+ continue
+ }
- let sql = try String(contentsOf: URL(fileURLWithPath:
"\(path)/\(name)"), encoding: .utf8)
- let result = try await spark.sql(sql).showString(MAX, MAX,
false).collect()[0].get(0) as! String
- let answer = cleanUp(result.trimmingCharacters(in:
.whitespacesAndNewlines))
- if (regenerateGoldenFiles) {
- let path =
"\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer"
- fm.createFile(atPath: path, contents: answer.data(using: .utf8)!,
attributes: nil)
- } else {
- let expected = cleanUp(try String(contentsOf: URL(fileURLWithPath:
"\(path)/\(name).answer"), encoding: .utf8))
+ let sql = try String(contentsOf: URL(fileURLWithPath:
"\(path)/\(name)"), encoding: .utf8)
+ let result =
+ try await spark.sql(sql).showString(MAX, MAX,
false).collect()[0].get(0) as! String
+ let answer = cleanUp(result.trimmingCharacters(in:
.whitespacesAndNewlines))
+ if regenerateGoldenFiles {
+ let path =
+
"\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer"
+ fm.createFile(atPath: path, contents: answer.data(using: .utf8)!,
attributes: nil)
+ } else {
+ let expected = cleanUp(
+ try String(contentsOf: URL(fileURLWithPath:
"\(path)/\(name).answer"), encoding: .utf8)
+ )
.trimmingCharacters(in: .whitespacesAndNewlines)
- if (answer != expected) {
- print("Try to compare normalized result.")
- #expect(normalize(answer) == normalize(expected))
+ if answer != expected {
+ print("Try to compare normalized result.")
+ #expect(normalize(answer) == normalize(expected))
+ }
}
}
+ await spark.stop()
}
- await spark.stop()
- }
-#endif
+ #endif
}
diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift
b/Tests/SparkConnectTests/SparkConnectClientTests.swift
index e47eab6..cd57905 100644
--- a/Tests/SparkConnectTests/SparkConnectClientTests.swift
+++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift
@@ -35,7 +35,8 @@ struct SparkConnectClientTests {
@Test
func parameters() async throws {
- let client = SparkConnectClient(remote:
"sc://host1:123/;tOkeN=abcd;user_ID=test;USER_agent=myagent")
+ let client = SparkConnectClient(
+ remote: "sc://host1:123/;tOkeN=abcd;user_ID=test;USER_agent=myagent")
#expect(await client.token == "abcd")
#expect(await client.userContext.userID == "test")
#expect(await client.clientType == "myagent")
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift
b/Tests/SparkConnectTests/SparkSessionTests.swift
index 1b4a658..326f37d 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -58,7 +58,8 @@ struct SparkSessionTests {
await SparkSession.builder.clear()
let spark1 = try await SparkSession.builder.getOrCreate()
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ??
"sc://localhost"
- let spark2 = try await
SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate()
+ let spark2 = try await
SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)")
+ .getOrCreate()
await spark2.stop()
#expect(spark1.sessionID == spark2.sessionID)
#expect(spark1 == spark2)
@@ -81,11 +82,11 @@ struct SparkSessionTests {
@Test func userContext() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
-#if os(macOS) || os(Linux)
- let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
-#else
- let defaultUserContext = "".toUserContext
-#endif
+ #if os(macOS) || os(Linux)
+ let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
+ #else
+ let defaultUserContext = "".toUserContext
+ #endif
#expect(await spark.client.userContext == defaultUserContext)
await spark.stop()
}
@@ -129,74 +130,76 @@ struct SparkSessionTests {
await spark.stop()
}
-#if !os(Linux)
- @Test
- func sql() async throws {
- await SparkSession.builder.clear()
- let spark = try await SparkSession.builder.getOrCreate()
- let expected = [Row(true, 1, "a")]
- if await spark.version.starts(with: "4.") {
- #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() ==
expected)
- #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y":
1, "z": "a"]).collect() == expected)
+ #if !os(Linux)
+ @Test
+ func sql() async throws {
+ await SparkSession.builder.clear()
+ let spark = try await SparkSession.builder.getOrCreate()
+ let expected = [Row(true, 1, "a")]
+ if await spark.version.starts(with: "4.") {
+ #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect()
== expected)
+ #expect(
+ try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1,
"z": "a"]).collect()
+ == expected)
+ }
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func addInvalidArtifact() async throws {
- await SparkSession.builder.clear()
- let spark = try await SparkSession.builder.getOrCreate()
- await #expect(throws: SparkConnectError.InvalidArgument) {
- try await spark.addArtifact("x.txt")
+ @Test
+ func addInvalidArtifact() async throws {
+ await SparkSession.builder.clear()
+ let spark = try await SparkSession.builder.getOrCreate()
+ await #expect(throws: SparkConnectError.InvalidArgument) {
+ try await spark.addArtifact("x.txt")
+ }
+ await spark.stop()
}
- await spark.stop()
- }
- @Test
- func addArtifact() async throws {
- let fm = FileManager()
- let path = "my.jar"
- let url = URL(fileURLWithPath: path)
+ @Test
+ func addArtifact() async throws {
+ let fm = FileManager()
+ let path = "my.jar"
+ let url = URL(fileURLWithPath: path)
- await SparkSession.builder.clear()
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
- if await spark.version.starts(with: "4.") {
- try await spark.addArtifact(path)
- try await spark.addArtifact(url)
+ await SparkSession.builder.clear()
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
+ if await spark.version.starts(with: "4.") {
+ try await spark.addArtifact(path)
+ try await spark.addArtifact(url)
+ }
+ try fm.removeItem(atPath: path)
+ await spark.stop()
}
- try fm.removeItem(atPath: path)
- await spark.stop()
- }
- @Test
- func addArtifacts() async throws {
- let fm = FileManager()
- let path = "my.jar"
- let url = URL(fileURLWithPath: path)
+ @Test
+ func addArtifacts() async throws {
+ let fm = FileManager()
+ let path = "my.jar"
+ let url = URL(fileURLWithPath: path)
- await SparkSession.builder.clear()
- let spark = try await SparkSession.builder.getOrCreate()
- #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
- if await spark.version.starts(with: "4.") {
- try await spark.addArtifacts(url, url)
+ await SparkSession.builder.clear()
+ let spark = try await SparkSession.builder.getOrCreate()
+ #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
+ if await spark.version.starts(with: "4.") {
+ try await spark.addArtifacts(url, url)
+ }
+ try fm.removeItem(atPath: path)
+ await spark.stop()
}
- try fm.removeItem(atPath: path)
- await spark.stop()
- }
- @Test
- func executeCommand() async throws {
- await SparkSession.builder.clear()
- let spark = try await SparkSession.builder.getOrCreate()
- if await spark.version.starts(with: "4.") {
- await #expect(throws: SparkConnectError.DataSourceNotFound) {
- try await spark.executeCommand("runner", "command", [:]).show()
+ @Test
+ func executeCommand() async throws {
+ await SparkSession.builder.clear()
+ let spark = try await SparkSession.builder.getOrCreate()
+ if await spark.version.starts(with: "4.") {
+ await #expect(throws: SparkConnectError.DataSourceNotFound) {
+ try await spark.executeCommand("runner", "command", [:]).show()
+ }
}
+ await spark.stop()
}
- await spark.stop()
- }
-#endif
+ #endif
@Test
func table() async throws {
@@ -215,10 +218,10 @@ struct SparkSessionTests {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.time(spark.range(1000).count) == 1000)
-#if !os(Linux)
- #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
- try await spark.time(spark.range(10).show)
-#endif
+ #if !os(Linux)
+ #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
+ try await spark.time(spark.range(10).show)
+ #endif
await spark.stop()
}
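
For readers skimming the hunks above, the re-indentation follows a few recurring conventions visible in this diff: bodies of `#if` blocks gain one extra indentation level, long initializers move the `if`/`switch` expression onto a new indented line after the `=`, and long call arguments (for example in `#expect(...)`) wrap onto their own indented lines. Below is a minimal, self-contained sketch of those conventions only; the type and function names are hypothetical and not taken from this commit, and the `if` expression assumes Swift 5.9 or later.

// Hypothetical example illustrating the indentation style applied in the diff above.
struct FormatSketch {
  func label(for version: String) -> String {
    // A long initializer wraps after `=`, with the `if` expression indented one level.
    let label =
      if version.starts(with: "4.") {
        "spark-4"
      } else {
        "spark-3"
      }

    #if !os(Linux)
      // Statements inside `#if` blocks are indented one extra level.
      return label.uppercased()
    #else
      return label
    #endif
  }
}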