This is an automated email from the ASF dual-hosted git repository.
dongjoon-hyun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new 10a3455 [SPARK-57309] Support `stat.sampleBy` for `DataFrame`
10a3455 is described below
commit 10a345563547936f2deb4692821c73a8fdbf0df2
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Sun Jun 7 18:47:58 2026 -0700
[SPARK-57309] Support `stat.sampleBy` for `DataFrame`
### What changes were proposed in this pull request?
This PR aims to support `sampleBy` for `DataFrame` by wiring the `StatSampleBy`
Spark Connect relation through `DataFrameStatFunctions`, which is exposed via
`DataFrame.stat` as in PySpark/Scala.
```swift
public func sampleBy<T: Sendable & Hashable>(_ col: String, _ fractions: [T: Double], _ seed: Int64) async -> DataFrame
public func sampleBy<T: Sendable & Hashable>(_ col: String, _ fractions: [T: Double]) async -> DataFrame
```
`sampleBy` returns a stratified sample without replacement based on the fraction
given for each stratum. A stratum that is not specified is treated as having a
fraction of zero. The seed is optional; a random seed is used when it is omitted.
### Why are the changes needed?
To improve API coverage by mirroring PySpark/Scala `DataFrameStatFunctions`.
### Does this PR introduce _any_ user-facing change?
Yes, this PR adds a new API, `DataFrame.stat.sampleBy`.
### How was this patch tested?
Pass the CIs with a new test case, `sampleBy`, in `DataFrameStatFunctionsTests`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Claude Opus 4.8)
This patch had conflicts when merged, resolved by
Committer: Dongjoon Hyun <[email protected]>
Closes #411 from dongjoon-hyun/SPARK-57309.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/DataFrameStatFunctions.swift | 57 ++++++++++++++++++++++
Sources/SparkConnect/SparkConnectClient.swift | 18 +++++++
.../DataFrameStatFunctionsTests.swift | 15 ++++++
3 files changed, 90 insertions(+)
diff --git a/Sources/SparkConnect/DataFrameStatFunctions.swift
b/Sources/SparkConnect/DataFrameStatFunctions.swift
index 29a9ecc..cfdf356 100644
--- a/Sources/SparkConnect/DataFrameStatFunctions.swift
+++ b/Sources/SparkConnect/DataFrameStatFunctions.swift
@@ -100,6 +100,33 @@ public actor DataFrameStatFunctions: Sendable {
return quantilesPerColumn.map { ($0 as! [any Sendable]).map { $0 as!
Double } }
}
+ /// Returns a stratified sample without replacement based on the fraction
given on each stratum.
+ /// - Parameters:
+ /// - col: The name of the column that defines the strata.
+ /// - fractions: The sampling fraction for each stratum. If a stratum is
not specified, its
+ /// fraction is treated as zero. Each fraction must be in `[0, 1]`.
+ /// - seed: The random seed.
+ /// - Returns: A ``DataFrame`` representing the stratified sample.
+ public func sampleBy<T: Sendable & Hashable>(
+ _ col: String, _ fractions: [T: Double], _ seed: Int64
+ ) async -> DataFrame {
+ let fractionLiterals = fractions.map { (stratumLiteral($0.key), $0.value) }
+ return await transform { SparkConnectClient.getStatSampleBy($0, col,
fractionLiterals, seed) }
+ }
+
+ /// Returns a stratified sample without replacement based on the fraction
given on each stratum,
+ /// using a random seed.
+ /// - Parameters:
+ /// - col: The name of the column that defines the strata.
+ /// - fractions: The sampling fraction for each stratum. If a stratum is
not specified, its
+ /// fraction is treated as zero. Each fraction must be in `[0, 1]`.
+ /// - Returns: A ``DataFrame`` representing the stratified sample.
+ public func sampleBy<T: Sendable & Hashable>(
+ _ col: String, _ fractions: [T: Double]
+ ) async -> DataFrame {
+ return await sampleBy(col, fractions, Int64.random(in:
Int64.min...Int64.max))
+ }
+
// MARK: - Helpers
/// Builds a single-value ``DataFrame`` from this ``DataFrame``'s plan using
the given plan
@@ -109,6 +136,36 @@ public actor DataFrameStatFunctions: Sendable {
let result = DataFrame(spark: await df.spark, plan: f(plan.root))
return try await result.collect()[0].get(0) as! Double
}
+
+ /// Builds a new ``DataFrame`` from this ``DataFrame``'s plan using the
given plan builder.
+ private func transform(_ f: (Relation) -> Plan) async -> DataFrame {
+ let plan = await df.getPlan() as! Plan
+ return DataFrame(spark: await df.spark, plan: f(plan.root))
+ }
+
+ /// Converts a `sampleBy` stratum value to an ``ExpressionLiteral``.
+ private func stratumLiteral(_ value: Sendable) -> ExpressionLiteral {
+ var literal = ExpressionLiteral()
+ switch value {
+ case let value as Bool:
+ literal.boolean = value
+ case let value as Int:
+ literal.long = Int64(value)
+ case let value as Int32:
+ literal.integer = value
+ case let value as Int64:
+ literal.long = value
+ case let value as Float:
+ literal.float = value
+ case let value as Double:
+ literal.double = value
+ case let value as String:
+ literal.string = value
+ default:
+ literal.string = value as! String
+ }
+ return literal
+ }
}
extension DataFrame {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift
b/Sources/SparkConnect/SparkConnectClient.swift
index a491c42..89f32eb 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -650,6 +650,24 @@ public actor SparkConnectClient {
return createPlan { $0.approxQuantile = approxQuantile }
}
+ static func getStatSampleBy(
+ _ child: Relation, _ col: String, _ fractions: [(ExpressionLiteral,
Double)], _ seed: Int64
+ ) -> Plan {
+ var sampleBy = Spark_Connect_StatSampleBy()
+ sampleBy.input = child
+ var colExpr = Spark_Connect_Expression()
+ colExpr.exprType = .unresolvedAttribute(col.toUnresolvedAttribute)
+ sampleBy.col = colExpr
+ sampleBy.fractions = fractions.map {
+ var fraction = Spark_Connect_StatSampleBy.Fraction()
+ fraction.stratum = $0.0
+ fraction.fraction = $0.1
+ return fraction
+ }
+ sampleBy.seed = seed
+ return createPlan { $0.sampleBy = sampleBy }
+ }
+
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
var sort = Sort()
sort.input = child
diff --git a/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
index bda3b55..578a354 100644
--- a/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
+++ b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
@@ -70,4 +70,19 @@ struct DataFrameStatFunctionsTests {
#expect(quantiles == [[1.0, 3.0, 5.0], [10.0, 30.0, 50.0]])
await spark.stop()
}
+
+ @Test
+ func sampleBy() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ // Strata 0, 1, 2 each have 33 rows.
+ let df = try await spark.sql("SELECT id % 3 AS key FROM range(0, 99)")
+ // A fraction of 1.0 keeps every row of a stratum; an unspecified stratum
(or 0.0) keeps none,
+ // so the result count is deterministic regardless of the seed.
+ #expect(try await df.stat.sampleBy("key", [0: 1.0, 1: 0.0], 0).count() ==
33)
+ // `Int64` strata are also supported.
+ #expect(try await df.stat.sampleBy("key", [Int64(0): 1.0, Int64(2): 1.0],
0).count() == 66)
+ // The seed is optional.
+ #expect(try await df.stat.sampleBy("key", [0: 1.0, 1: 1.0, 2:
1.0]).count() == 99)
+ await spark.stop()
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]