This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new c11e872 [SPARK-54892] Support `list` data type
c11e872 is described below
commit c11e872ee5ba75422b49836c89ff945541e5506c
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Sun Jan 4 18:49:20 2026 +0900
[SPARK-54892] Support `list` data type
### What changes were proposed in this pull request?
This PR aims to support the `list` data type by syncing with the Apache Arrow
Swift project.
### Why are the changes needed?
We still cannot use Apache Arrow Swift directly because of mismatches such as
https://github.com/apache/arrow/pull/46628. Until those are resolved, we need
to sync with upstream to bring in newly added but unreleased features.
- https://github.com/apache/arrow-swift/issues/16
- https://github.com/apache/arrow-swift/pull/39
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Pass the CIs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #272 from dongjoon-hyun/SPARK-54892.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/ArrowArray.swift | 145 +++++++++++++++++++++-----
Sources/SparkConnect/ArrowArrayBuilder.swift | 59 ++++++++++-
Sources/SparkConnect/ArrowBufferBuilder.swift | 85 ++++++++++++++-
Sources/SparkConnect/ArrowReader.swift | 40 +++++++
Sources/SparkConnect/ArrowReaderHelper.swift | 26 +++--
Sources/SparkConnect/ArrowType.swift | 32 +++++-
Sources/SparkConnect/ArrowWriter.swift | 31 +++---
Sources/SparkConnect/ProtoUtil.swift | 9 +-
8 files changed, 372 insertions(+), 55 deletions(-)
diff --git a/Sources/SparkConnect/ArrowArray.swift
b/Sources/SparkConnect/ArrowArray.swift
index 61a4e49..b95a896 100644
--- a/Sources/SparkConnect/ArrowArray.swift
+++ b/Sources/SparkConnect/ArrowArray.swift
@@ -120,7 +120,9 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder {
case .binary:
return try ArrowArrayHolderImpl(BinaryArray(with))
case .strct:
- return try ArrowArrayHolderImpl(StructArray(with))
+ return try ArrowArrayHolderImpl(NestedArray(with))
+ case .list:
+ return try ArrowArrayHolderImpl(NestedArray(with))
default:
throw ArrowError.invalid("Array not found for type: \(arrowType)")
}
@@ -395,16 +397,40 @@ public class BinaryArray: ArrowArray<Data> {
}
/// @nodoc
-public class StructArray: ArrowArray<[Any?]> {
- public private(set) var arrowFields: [ArrowArrayHolder]?
+public class NestedArray: ArrowArray<[Any?]> {
+ private var children: [ArrowArrayHolder]?
+
public required init(_ arrowData: ArrowData) throws {
try super.init(arrowData)
- var fields = [ArrowArrayHolder]()
- for child in arrowData.children {
- fields.append(try ArrowArrayHolderImpl.loadArray(child.type, with:
child))
- }
- self.arrowFields = fields
+ switch arrowData.type.id {
+ case .list:
+ guard arrowData.children.count == 1 else {
+ throw ArrowError.invalid("List array must have exactly one child")
+ }
+
+ guard let listType = arrowData.type as? ArrowTypeList else {
+ throw ArrowError.invalid("Expected ArrowTypeList for list type ID")
+ }
+
+ self.children = [
+ try ArrowArrayHolderImpl.loadArray(
+ listType.elementField.type,
+ with: arrowData.children[0]
+ )
+ ]
+
+ case .strct:
+ var fields = [ArrowArrayHolder]()
+ for child in arrowData.children {
+ fields.append(try ArrowArrayHolderImpl.loadArray(child.type, with:
child))
+ }
+ self.children = fields
+
+ default:
+ throw ArrowError.invalid(
+ "NestedArray only supports list and struct types, got:
\(arrowData.type.id)")
+ }
}
public override subscript(_ index: UInt) -> [Any?]? {
@@ -412,36 +438,105 @@ public class StructArray: ArrowArray<[Any?]> {
return nil
}
- if let fields = arrowFields {
+ guard let children = self.children else {
+ return nil
+ }
+
+ switch arrowData.type.id {
+ case .list:
+ guard let values = children.first else { return nil }
+
+ let offsets = self.arrowData.buffers[1]
+ let offsetIndex = Int(index) * MemoryLayout<Int32>.stride
+
+ let startOffset = offsets.rawPointer.advanced(by: offsetIndex).load(as:
Int32.self)
+ let endOffset = offsets.rawPointer.advanced(by: offsetIndex +
MemoryLayout<Int32>.stride)
+ .load(as: Int32.self)
+
+ var items = [Any?]()
+ for i in startOffset..<endOffset {
+ items.append(values.array.asAny(UInt(i)))
+ }
+
+ return items
+
+ case .strct:
var result = [Any?]()
- for field in fields {
+ for field in children {
result.append(field.array.asAny(index))
}
-
return result
- }
- return nil
+ default:
+ return nil
+ }
}
public override func asString(_ index: UInt) -> String {
- if self.arrowData.isNull(index) {
- return ""
- }
+ switch arrowData.type.id {
+ case .list:
+ if self.arrowData.isNull(index) {
+ return "null"
+ }
+
+ guard let list = self[index] else {
+ return "null"
+ }
- var output = "{"
- if let fields = arrowFields {
- for fieldIndex in 0..<fields.count {
- let asStr = fields[fieldIndex].array as? AsString
- if fieldIndex == 0 {
- output.append("\(asStr!.asString(index))")
+ var output = "["
+ for (i, item) in list.enumerated() {
+ if i > 0 {
+ output.append(",")
+ }
+
+ if item == nil {
+ output.append("null")
+ } else if let asStringItem = item as? AsString {
+ output.append(asStringItem.asString(0))
} else {
- output.append(",\(asStr!.asString(index))")
+ output.append("\(item!)")
}
}
+ output.append("]")
+ return output
+
+ case .strct:
+ if self.arrowData.isNull(index) {
+ return ""
+ }
+
+ var output = "{"
+ if let children = self.children {
+ for fieldIndex in 0..<children.count {
+ let asStr = children[fieldIndex].array as? AsString
+ if fieldIndex == 0 {
+ output.append("\(asStr!.asString(index))")
+ } else {
+ output.append(",\(asStr!.asString(index))")
+ }
+ }
+ }
+ output += "}"
+ return output
+
+ default:
+ return ""
}
+ }
+
+ public var isListArray: Bool {
+ return arrowData.type.id == .list
+ }
+
+ public var isStructArray: Bool {
+ return arrowData.type.id == .strct
+ }
+
+ public var fields: [ArrowArrayHolder]? {
+ return arrowData.type.id == .strct ? children : nil
+ }
- output += "}"
- return output
+ public var values: ArrowArrayHolder? {
+ return arrowData.type.id == .list ? children?.first : nil
}
}
diff --git a/Sources/SparkConnect/ArrowArrayBuilder.swift
b/Sources/SparkConnect/ArrowArrayBuilder.swift
index 2b977c1..4c75b90 100644
--- a/Sources/SparkConnect/ArrowArrayBuilder.swift
+++ b/Sources/SparkConnect/ArrowArrayBuilder.swift
@@ -135,13 +135,13 @@ public class TimestampArrayBuilder:
ArrowArrayBuilder<FixedBufferBuilder<Int64>,
}
}
-public class StructArrayBuilder: ArrowArrayBuilder<StructBufferBuilder,
StructArray> {
+public class StructArrayBuilder: ArrowArrayBuilder<StructBufferBuilder,
NestedArray> {
let builders: [any ArrowArrayHolderBuilder]
let fields: [ArrowField]
public init(_ fields: [ArrowField], builders: [any ArrowArrayHolderBuilder])
throws {
self.fields = fields
self.builders = builders
- try super.init(ArrowNestedType(ArrowType.ArrowStruct, fields: fields))
+ try super.init(ArrowTypeStruct(ArrowType.ArrowStruct, fields: fields))
self.bufferBuilder.initializeTypeInfo(fields)
}
@@ -153,7 +153,7 @@ public class StructArrayBuilder:
ArrowArrayBuilder<StructBufferBuilder, StructAr
}
self.builders = builders
- try super.init(ArrowNestedType(ArrowType.ArrowStruct, fields: fields))
+ try super.init(ArrowTypeStruct(ArrowType.ArrowStruct, fields: fields))
}
public override func append(_ values: [Any?]?) {
@@ -169,7 +169,7 @@ public class StructArrayBuilder:
ArrowArrayBuilder<StructBufferBuilder, StructAr
}
}
- public override func finish() throws -> StructArray {
+ public override func finish() throws -> NestedArray {
let buffers = self.bufferBuilder.finish()
var childData = [ArrowData]()
for builder in self.builders {
@@ -180,11 +180,42 @@ public class StructArrayBuilder:
ArrowArrayBuilder<StructBufferBuilder, StructAr
self.type, buffers: buffers,
children: childData, nullCount: self.nullCount,
length: self.length)
- let structArray = try StructArray(arrowData)
+ let structArray = try NestedArray(arrowData)
return structArray
}
}
+public class ListArrayBuilder: ArrowArrayBuilder<ListBufferBuilder,
NestedArray> {
+ let valueBuilder: any ArrowArrayHolderBuilder
+
+ public override init(_ arrowType: ArrowType) throws {
+ guard let listType = arrowType as? ArrowTypeList else {
+ throw ArrowError.invalid("Expected ArrowTypeList")
+ }
+ let arrowField = listType.elementField
+ self.valueBuilder = try ArrowArrayBuilders.loadBuilder(arrowType:
arrowField.type)
+ try super.init(arrowType)
+ }
+
+ public override func append(_ values: [Any?]?) {
+ self.bufferBuilder.append(values)
+ if let vals = values {
+ for val in vals {
+ self.valueBuilder.appendAny(val)
+ }
+ }
+ }
+
+ public override func finish() throws -> NestedArray {
+ let buffers = self.bufferBuilder.finish()
+ let childData = try valueBuilder.toHolder().array.arrowData
+ let arrowData = try ArrowData(
+ self.type, buffers: buffers, children: [childData], nullCount:
self.nullCount,
+ length: self.length)
+ return try NestedArray(arrowData)
+ }
+}
+
public class ArrowArrayBuilders {
public static func loadBuilder( // swiftlint:disable:this
cyclomatic_complexity
_ builderType: Any.Type
@@ -304,6 +335,16 @@ public class ArrowArrayBuilders {
throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not
found")
}
return try TimestampArrayBuilder(timestampType.unit)
+ case .strct:
+ guard let structType = arrowType as? ArrowTypeStruct else {
+ throw ArrowError.invalid("Expected ArrowStructType for
\(arrowType.id)")
+ }
+ return try StructArrayBuilder(structType.fields)
+ case .list:
+ guard let listType = arrowType as? ArrowTypeList else {
+ throw ArrowError.invalid("Expected ArrowTypeList for \(arrowType.id)")
+ }
+ return try ListArrayBuilder(listType)
default:
throw ArrowError.unknownType("Builder not found for arrow type:
\(arrowType.id)")
}
@@ -378,4 +419,12 @@ public class ArrowArrayBuilders {
) throws -> Decimal128ArrayBuilder {
return try Decimal128ArrayBuilder(precision: precision, scale: scale)
}
+
+ public static func loadStructArrayBuilder(_ fields: [ArrowField]) throws ->
StructArrayBuilder {
+ return try StructArrayBuilder(fields)
+ }
+
+ public static func loadListArrayBuilder(_ listType: ArrowTypeList) throws ->
ListArrayBuilder {
+ return try ListArrayBuilder(listType)
+ }
}
diff --git a/Sources/SparkConnect/ArrowBufferBuilder.swift
b/Sources/SparkConnect/ArrowBufferBuilder.swift
index b20e964..620c23a 100644
--- a/Sources/SparkConnect/ArrowBufferBuilder.swift
+++ b/Sources/SparkConnect/ArrowBufferBuilder.swift
@@ -343,20 +343,20 @@ public class Date64BufferBuilder:
AbstractWrapperBufferBuilder<Date, Int64> {
public final class StructBufferBuilder: BaseBufferBuilder, ArrowBufferBuilder {
public typealias ItemType = [Any?]
- var info: ArrowNestedType?
+ var info: ArrowTypeStruct?
public init() throws {
let nulls = ArrowBuffer.createBuffer(0, size:
UInt(MemoryLayout<UInt8>.stride))
super.init(nulls)
}
public func initializeTypeInfo(_ fields: [ArrowField]) {
- info = ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
+ info = ArrowTypeStruct(ArrowType.ArrowStruct, fields: fields)
}
public func append(_ newValue: [Any?]?) {
let index = UInt(self.length)
self.length += 1
- if length > self.nulls.length {
+ if self.length > self.nulls.length {
self.resize(length)
}
@@ -385,3 +385,82 @@ public final class StructBufferBuilder: BaseBufferBuilder,
ArrowBufferBuilder {
return [nulls]
}
}
+
+public class ListBufferBuilder: BaseBufferBuilder, ArrowBufferBuilder {
+ public typealias ItemType = [Any?]
+ var offsets: ArrowBuffer
+
+ public required init() throws {
+ self.offsets = ArrowBuffer.createBuffer(1, size:
UInt(MemoryLayout<Int32>.stride))
+ let nulls = ArrowBuffer.createBuffer(0, size:
UInt(MemoryLayout<UInt8>.stride))
+ super.init(nulls)
+ self.offsets.rawPointer.storeBytes(of: Int32(0), as: Int32.self)
+ }
+
+ public func append(_ count: Int) {
+ let index = UInt(self.length)
+ self.length += 1
+
+ if length >= self.offsets.length {
+ self.resize(length + 1)
+ }
+
+ let offsetIndex = Int(index) * MemoryLayout<Int32>.stride
+ let currentOffset = self.offsets.rawPointer.advanced(by:
offsetIndex).load(as: Int32.self)
+
+ BitUtility.setBit(index + self.offset, buffer: self.nulls)
+ let newOffset = currentOffset + Int32(count)
+ self.offsets.rawPointer.advanced(by: offsetIndex +
MemoryLayout<Int32>.stride).storeBytes(
+ of: newOffset, as: Int32.self)
+ }
+
+ public func append(_ newValue: [Any?]?) {
+ let index = UInt(self.length)
+ self.length += 1
+
+ if self.length >= self.offsets.length {
+ self.resize(self.length + 1)
+ }
+
+ let offsetIndex = Int(index) * MemoryLayout<Int32>.stride
+ let currentOffset = self.offsets.rawPointer.advanced(by:
offsetIndex).load(as: Int32.self)
+
+ if let vals = newValue {
+ BitUtility.setBit(index + self.offset, buffer: self.nulls)
+ let newOffset = currentOffset + Int32(vals.count)
+ self.offsets.rawPointer.advanced(by: offsetIndex +
MemoryLayout<Int32>.stride).storeBytes(
+ of: newOffset, as: Int32.self)
+ } else {
+ self.nullCount += 1
+ BitUtility.clearBit(index + self.offset, buffer: self.nulls)
+ self.offsets.rawPointer.advanced(by: offsetIndex +
MemoryLayout<Int32>.stride).storeBytes(
+ of: currentOffset, as: Int32.self)
+ }
+ }
+
+ public override func isNull(_ index: UInt) -> Bool {
+ return !BitUtility.isSet(index + self.offset, buffer: self.nulls)
+ }
+
+ public func resize(_ length: UInt) {
+ if length > self.offsets.length {
+ let resizeLength = resizeLength(self.offsets)
+ var offsets = ArrowBuffer.createBuffer(resizeLength, size:
UInt(MemoryLayout<Int32>.size))
+ var nulls = ArrowBuffer.createBuffer(
+ resizeLength / 8 + 1, size: UInt(MemoryLayout<UInt8>.size))
+ ArrowBuffer.copyCurrent(self.offsets, to: &offsets, len:
self.offsets.capacity)
+ ArrowBuffer.copyCurrent(self.nulls, to: &nulls, len: self.nulls.capacity)
+ self.offsets = offsets
+ self.nulls = nulls
+ }
+ }
+
+ public func finish() -> [ArrowBuffer] {
+ let length = self.length
+ var nulls = ArrowBuffer.createBuffer(length / 8 + 1, size:
UInt(MemoryLayout<UInt8>.size))
+ var offsets = ArrowBuffer.createBuffer(length + 1, size:
UInt(MemoryLayout<Int32>.size))
+ ArrowBuffer.copyCurrent(self.nulls, to: &nulls, len: nulls.capacity)
+ ArrowBuffer.copyCurrent(self.offsets, to: &offsets, len: offsets.capacity)
+ return [nulls, offsets]
+ }
+}
diff --git a/Sources/SparkConnect/ArrowReader.swift
b/Sources/SparkConnect/ArrowReader.swift
index f0699c2..c19a9e7 100644
--- a/Sources/SparkConnect/ArrowReader.swift
+++ b/Sources/SparkConnect/ArrowReader.swift
@@ -126,6 +126,44 @@ public class ArrowReader { // swiftlint:disable:this
type_body_length
rbLength: UInt(loadInfo.batchData.recordBatch.length))
}
+ private func loadListData(_ loadInfo: DataLoadInfo, field:
org_apache_arrow_flatbuf_Field)
+ -> Result<ArrowArrayHolder, ArrowError>
+ {
+ guard let node = loadInfo.batchData.nextNode() else {
+ return .failure(.invalid("Node not found"))
+ }
+
+ guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Null buffer not found"))
+ }
+
+ guard let offsetBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Offset buffer not found"))
+ }
+
+ let nullLength = UInt(ceil(Double(node.length) / 8))
+ let arrowNullBuffer = makeBuffer(
+ nullBuffer, fileData: loadInfo.fileData, length: nullLength,
+ messageOffset: loadInfo.messageOffset)
+ let arrowOffsetBuffer = makeBuffer(
+ offsetBuffer, fileData: loadInfo.fileData, length: UInt(node.length + 1),
+ messageOffset: loadInfo.messageOffset)
+
+ guard field.childrenCount == 1, let childField = field.children(at: 0)
else {
+ return .failure(.invalid("List must have exactly one child"))
+ }
+
+ switch loadField(loadInfo, field: childField) {
+ case .success(let childHolder):
+ return makeArrayHolder(
+ field, buffers: [arrowNullBuffer, arrowOffsetBuffer], nullCount:
UInt(node.nullCount),
+ children: [childHolder.array.arrowData],
+ rbLength: UInt(loadInfo.batchData.recordBatch.length))
+ case .failure(let error):
+ return .failure(error)
+ }
+ }
+
private func loadPrimitiveData(
_ loadInfo: DataLoadInfo,
field: org_apache_arrow_flatbuf_Field
@@ -204,6 +242,8 @@ public class ArrowReader { // swiftlint:disable:this
type_body_length
switch field.typeType {
case .struct_:
return loadStructData(loadInfo, field: field)
+ case .list:
+ return loadListData(loadInfo, field: field)
default:
if isFixedPrimitive(field.typeType) {
return loadPrimitiveData(loadInfo, field: field)
diff --git a/Sources/SparkConnect/ArrowReaderHelper.swift
b/Sources/SparkConnect/ArrowReaderHelper.swift
index bb0100a..ea397db 100644
--- a/Sources/SparkConnect/ArrowReaderHelper.swift
+++ b/Sources/SparkConnect/ArrowReaderHelper.swift
@@ -160,7 +160,7 @@ private func makeFixedHolder<T>(
}
}
-func makeStructHolder(
+func makeNestedHolder(
_ field: ArrowField,
buffers: [ArrowBuffer],
nullCount: UInt,
@@ -170,9 +170,12 @@ func makeStructHolder(
do {
let arrowData = try ArrowData(
field.type,
- buffers: buffers, children: children,
- nullCount: nullCount, length: rbLength)
- return .success(ArrowArrayHolderImpl(try StructArray(arrowData)))
+ buffers: buffers,
+ children: children,
+ nullCount: nullCount,
+ length: rbLength
+ )
+ return .success(ArrowArrayHolderImpl(try NestedArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
} catch {
@@ -236,7 +239,10 @@ func makeArrayHolder( // swiftlint:disable:this
cyclomatic_complexity
case .timestamp:
return makeTimestampHolder(field, buffers: buffers, nullCount: nullCount)
case .strct:
- return makeStructHolder(
+ return makeNestedHolder(
+ field, buffers: buffers, nullCount: nullCount, children: children!,
rbLength: rbLength)
+ case .list:
+ return makeNestedHolder(
field, buffers: buffers, nullCount: nullCount, children: children!,
rbLength: rbLength)
default:
return .failure(.unknownType("Type \(typeId) currently not supported"))
@@ -345,7 +351,15 @@ func findArrowType( // swiftlint:disable:this
cyclomatic_complexity function_bo
ArrowField(childField.name ?? "", type: childType, isNullable:
childField.nullable))
}
- return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
+ return ArrowTypeStruct(ArrowType.ArrowStruct, fields: fields)
+ case .list:
+ guard field.childrenCount == 1, let childField = field.children(at: 0)
else {
+ return ArrowType(ArrowType.ArrowUnknown)
+ }
+ let childType = findArrowType(childField)
+ let childFieldName = childField.name ?? "item"
+ return ArrowTypeList(
+ ArrowField(childFieldName, type: childType, isNullable:
childField.nullable))
default:
return ArrowType(ArrowType.ArrowUnknown)
}
diff --git a/Sources/SparkConnect/ArrowType.swift
b/Sources/SparkConnect/ArrowType.swift
index d7b773e..6ad39a7 100644
--- a/Sources/SparkConnect/ArrowType.swift
+++ b/Sources/SparkConnect/ArrowType.swift
@@ -190,7 +190,7 @@ public class ArrowTypeTimestamp: ArrowType {
}
/// @nodoc
-public class ArrowNestedType: ArrowType {
+public class ArrowTypeStruct: ArrowType {
let fields: [ArrowField]
public init(_ info: ArrowType.Info, fields: [ArrowField]) {
self.fields = fields
@@ -198,6 +198,19 @@ public class ArrowNestedType: ArrowType {
}
}
+public class ArrowTypeList: ArrowType {
+ public let elementField: ArrowField
+
+ public init(_ elementField: ArrowField) {
+ self.elementField = elementField
+ super.init(ArrowType.ArrowList)
+ }
+
+ public convenience init(_ elementType: ArrowType, nullable: Bool = true) {
+ self.init(ArrowField("item", type: elementType, isNullable: nullable))
+ }
+}
+
/// @nodoc
public class ArrowType {
public private(set) var info: ArrowType.Info
@@ -222,6 +235,7 @@ public class ArrowType {
public static let ArrowTime64 = Info.timeInfo(ArrowTypeId.time64)
public static let ArrowTimestamp = Info.timeInfo(ArrowTypeId.timestamp)
public static let ArrowStruct = Info.complexInfo(ArrowTypeId.strct)
+ public static let ArrowList = Info.complexInfo(ArrowTypeId.list)
public init(_ info: ArrowType.Info) {
self.info = info
@@ -355,7 +369,7 @@ public class ArrowType {
return MemoryLayout<Int8>.stride
case .string:
return MemoryLayout<Int8>.stride
- case .strct:
+ case .strct, .list:
return 0
default:
fatalError("Stride requested for unknown type: \(self)")
@@ -412,6 +426,20 @@ public class ArrowType {
return "z"
case ArrowTypeId.string:
return "u"
+ case ArrowTypeId.strct:
+ if let structType = self as? ArrowTypeStruct {
+ var format = "+s"
+ for field in structType.fields {
+ format += try field.type.cDataFormatId
+ }
+ return format
+ }
+ throw ArrowError.invalid("Invalid struct type")
+ case ArrowTypeId.list:
+ if let listType = self as? ArrowTypeList {
+ return "+l" + (try listType.elementField.type.cDataFormatId)
+ }
+ throw ArrowError.invalid("Invalid list type")
default:
throw ArrowError.notImplemented
}
diff --git a/Sources/SparkConnect/ArrowWriter.swift
b/Sources/SparkConnect/ArrowWriter.swift
index 9d44f9e..c7522a1 100644
--- a/Sources/SparkConnect/ArrowWriter.swift
+++ b/Sources/SparkConnect/ArrowWriter.swift
@@ -76,7 +76,7 @@ public class ArrowWriter { // swiftlint:disable:this
type_body_length
Offset, ArrowError
> {
var fieldsOffset: Offset?
- if let nestedField = field.type as? ArrowNestedType {
+ if let nestedField = field.type as? ArrowTypeStruct {
var offsets = [Offset]()
for field in nestedField.fields {
switch writeField(&fbb, field: field) {
@@ -180,10 +180,11 @@ public class ArrowWriter { // swiftlint:disable:this
type_body_length
length: Int64(column.length),
nullCount: Int64(column.nullCount))
offsets.append(fbb.create(struct: fieldNode))
- if let nestedType = column.type as? ArrowNestedType {
- let structArray = column.array as? StructArray
- writeFieldNodes(
- nestedType.fields, columns: structArray!.arrowFields!, offsets:
&offsets, fbb: &fbb)
+ if let nestedType = column.type as? ArrowTypeStruct {
+ let nestedArray = column.array as? NestedArray
+ if let nestedFields = nestedArray?.fields {
+ writeFieldNodes(nestedType.fields, columns: nestedFields, offsets:
&offsets, fbb: &fbb)
+ }
}
}
}
@@ -204,11 +205,13 @@ public class ArrowWriter { // swiftlint:disable:this
type_body_length
offset: Int64(bufferOffset), length: Int64(bufferDataSize))
buffers.append(buffer)
bufferOffset += bufferDataSize
- if let nestedType = column.type as? ArrowNestedType {
- let structArray = column.array as? StructArray
- writeBufferInfo(
- nestedType.fields, columns: structArray!.arrowFields!,
- bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb)
+ if let nestedType = column.type as? ArrowTypeStruct {
+ let nestedArray = column.array as? NestedArray
+ if let nestedFields = nestedArray?.fields {
+ writeBufferInfo(
+ nestedType.fields, columns: nestedFields,
+ bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb)
+ }
}
}
}
@@ -267,13 +270,15 @@ public class ArrowWriter { // swiftlint:disable:this
type_body_length
for var bufferData in colBufferData {
addPadForAlignment(&bufferData)
writer.append(bufferData)
- if let nestedType = column.type as? ArrowNestedType {
- guard let structArray = column.array as? StructArray else {
+ if let nestedType = column.type as? ArrowTypeStruct {
+ guard let nestedArray = column.array as? NestedArray,
+ let nestedFields = nestedArray.fields
+ else {
return .failure(.invalid("Struct type array expected for nested
type"))
}
switch writeRecordBatchData(
- &writer, fields: nestedType.fields, columns:
structArray.arrowFields!)
+ &writer, fields: nestedType.fields, columns: nestedFields)
{
case .success:
continue
diff --git a/Sources/SparkConnect/ProtoUtil.swift
b/Sources/SparkConnect/ProtoUtil.swift
index f890d07..6024315 100644
--- a/Sources/SparkConnect/ProtoUtil.swift
+++ b/Sources/SparkConnect/ProtoUtil.swift
@@ -96,7 +96,14 @@ func fromProto( // swiftlint:disable:this
cyclomatic_complexity function_body_l
children.append(fromProto(field: childField))
}
- arrowType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
+ arrowType = ArrowTypeStruct(ArrowType.ArrowStruct, fields: children)
+ case .list:
+ guard field.childrenCount == 1, let childField = field.children(at: 0)
else {
+ arrowType = ArrowType(ArrowType.ArrowUnknown)
+ break
+ }
+ let childArrowField = fromProto(field: childField)
+ arrowType = ArrowTypeList(childArrowField)
default:
arrowType = ArrowType(ArrowType.ArrowUnknown)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]