This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new ae41f826fd GH-48705: [Ruby] Add support for reading dictionary array
(#48706)
ae41f826fd is described below
commit ae41f826fd862fa4288ed8a8d3093aad184bb894
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sat Jan 3 09:59:26 2026 +0900
GH-48705: [Ruby] Add support for reading dictionary array (#48706)
### Rationale for this change
A dictionary array is a special data type. To read it, we need to process
dictionary batch messages.
### What changes are included in this PR?
* Add `ArrowFormat::DictionaryType`
* Add `ArrowFormat::DictionaryArray`
* Add support for reading dictionary batch messages in both the file reader
  and the streaming reader
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Yes. Dictionary arrays can now be read.
* GitHub Issue: #48705
Authored-by: Sutou Kouhei <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
ruby/red-arrow-format/lib/arrow-format/array.rb | 80 ++++++------
ruby/red-arrow-format/lib/arrow-format/field.rb | 4 +-
.../lib/arrow-format/file-reader.rb | 139 +++++++++++++++------
ruby/red-arrow-format/lib/arrow-format/readable.rb | 79 +++++++-----
.../lib/arrow-format/streaming-pull-reader.rb | 57 +++++++--
ruby/red-arrow-format/lib/arrow-format/type.rb | 56 +++++++++
ruby/red-arrow-format/test/test-reader.rb | 13 ++
7 files changed, 313 insertions(+), 115 deletions(-)
diff --git a/ruby/red-arrow-format/lib/arrow-format/array.rb
b/ruby/red-arrow-format/lib/arrow-format/array.rb
index 9d82aae16f..0c27e24bc6 100644
--- a/ruby/red-arrow-format/lib/arrow-format/array.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/array.rb
@@ -79,54 +79,34 @@ module ArrowFormat
super(type, size, validity_buffer)
@values_buffer = values_buffer
end
- end
- class Int8Array < IntArray
def to_a
- apply_validity(@values_buffer.values(:S8, 0, @size))
+ apply_validity(@values_buffer.values(@type.buffer_type, 0, @size))
end
end
+ class Int8Array < IntArray
+ end
+
class UInt8Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:U8, 0, @size))
- end
end
class Int16Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:s16, 0, @size))
- end
end
class UInt16Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:u16, 0, @size))
- end
end
class Int32Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:s32, 0, @size))
- end
end
class UInt32Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:u32, 0, @size))
- end
end
class Int64Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:s64, 0, @size))
- end
end
class UInt64Array < IntArray
- def to_a
- apply_validity(@values_buffer.values(:u64, 0, @size))
- end
end
class FloatingPointArray < Array
@@ -410,6 +390,27 @@ module ArrowFormat
end
end
+ class MapArray < VariableSizeListArray
+ def to_a
+ super.collect do |entries|
+ if entries.nil?
+ entries
+ else
+ hash = {}
+ entries.each do |key, value|
+ hash[key] = value
+ end
+ hash
+ end
+ end
+ end
+
+ private
+ def offset_type
+ :s32 # TODO: big endian support
+ end
+ end
+
class UnionArray < Array
def initialize(type, size, types_buffer, children)
super(type, size, nil)
@@ -449,24 +450,27 @@ module ArrowFormat
end
end
- class MapArray < VariableSizeListArray
+ class DictionaryArray < Array
+ def initialize(type, size, validity_buffer, indices_buffer, dictionary)
+ super(type, size, validity_buffer)
+ @indices_buffer = indices_buffer
+ @dictionary = dictionary
+ end
+
def to_a
- super.collect do |entries|
- if entries.nil?
- entries
+ values = []
+ @dictionary.each do |dictionary_chunk|
+ values.concat(dictionary_chunk.to_a)
+ end
+ buffer_type = @type.index_type.buffer_type
+ indices = apply_validity(@indices_buffer.values(buffer_type, 0, @size))
+ indices.collect do |index|
+ if index.nil?
+ nil
else
- hash = {}
- entries.each do |key, value|
- hash[key] = value
- end
- hash
+ values[index]
end
end
end
-
- private
- def offset_type
- :s32 # TODO: big endian support
- end
end
end
diff --git a/ruby/red-arrow-format/lib/arrow-format/field.rb
b/ruby/red-arrow-format/lib/arrow-format/field.rb
index ac531750f7..090113cfe6 100644
--- a/ruby/red-arrow-format/lib/arrow-format/field.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/field.rb
@@ -18,10 +18,12 @@ module ArrowFormat
class Field
attr_reader :name
attr_reader :type
- def initialize(name, type, nullable)
+ attr_reader :dictionary_id
+ def initialize(name, type, nullable, dictionary_id)
@name = name
@type = type
@nullable = nullable
+ @dictionary_id = dictionary_id
end
def nullable?
diff --git a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
index bf50bfd1cd..545638ca90 100644
--- a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
@@ -49,17 +49,65 @@ module ArrowFormat
validate
@footer = read_footer
- @record_batches = @footer.record_batches
+ @record_batch_blocks = @footer.record_batches
@schema = read_schema(@footer.schema)
+ @dictionaries = read_dictionaries
end
def n_record_batches
- @record_batches.size
+ @record_batch_blocks.size
end
def read(i)
- block = @record_batches[i]
+ fb_message, body = read_block(@record_batch_blocks[i])
+ fb_header = fb_message.header
+ unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
+ raise FileReadError.new(@buffer,
+ "Not a record batch message: #{i}: " +
+ fb_header.class.name)
+ end
+ read_record_batch(fb_header, @schema, body)
+ end
+
+ def each
+ return to_enum(__method__) {n_record_batches} unless block_given?
+
+ @record_batch_blocks.size.times do |i|
+ yield(read(i))
+ end
+ end
+
+ private
+ def validate
+ minimum_size = STREAMING_FORMAT_START_OFFSET +
+ FOOTER_SIZE_SIZE +
+ END_MARKER_SIZE
+ if @buffer.size < minimum_size
+ raise FileReadError.new(@buffer,
+ "Input must be larger than or equal to " +
+ "#{minimum_size}: #{@buffer.size}")
+ end
+
+ start_marker = @buffer.slice(0, START_MARKER_SIZE)
+ if start_marker != MAGIC_BUFFER
+ raise FileReadError.new(@buffer, "No start marker")
+ end
+ end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
+ END_MARKER_SIZE)
+ if end_marker != MAGIC_BUFFER
+ raise FileReadError.new(@buffer, "No end marker")
+ end
+ end
+
+ def read_footer
+ footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
+ footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
+ footer_data = @buffer.slice(footer_size_offset - footer_size,
+ footer_size)
+ Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
+ end
+ def read_block(block)
offset = block.offset
# If we can report property error information, we can use
@@ -101,54 +149,65 @@ module ArrowFormat
metadata = @buffer.slice(offset, metadata_length)
fb_message = Org::Apache::Arrow::Flatbuf::Message.new(metadata)
- fb_header = fb_message.header
- unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
- raise FileReadError.new(@buffer,
- "Not a record batch message: #{i}: " +
- fb_header.class.name)
- end
offset += metadata_length
body = @buffer.slice(offset, block.body_length)
- read_record_batch(fb_header, @schema, body)
- end
- def each
- return to_enum(__method__) {n_record_batches} unless block_given?
-
- @record_batches.size.times do |i|
- yield(read(i))
- end
+ [fb_message, body]
end
- private
- def validate
- minimum_size = STREAMING_FORMAT_START_OFFSET +
- FOOTER_SIZE_SIZE +
- END_MARKER_SIZE
- if @buffer.size < minimum_size
- raise FileReadError.new(@buffer,
- "Input must be larger than or equal to " +
- "#{minimum_size}: #{@buffer.size}")
- end
+ def read_dictionaries
+ dictionary_blocks = @footer.dictionaries
+ return nil if dictionary_blocks.nil?
- start_marker = @buffer.slice(0, START_MARKER_SIZE)
- if start_marker != MAGIC_BUFFER
- raise FileReadError.new(@buffer, "No start marker")
+ dictionary_fields = {}
+ @schema.fields.each do |field|
+ next unless field.type.is_a?(DictionaryType)
+ dictionary_fields[field.dictionary_id] = field
end
- end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
- END_MARKER_SIZE)
- if end_marker != MAGIC_BUFFER
- raise FileReadError.new(@buffer, "No end marker")
+
+ dictionaries = {}
+ dictionary_blocks.each do |block|
+ fb_message, body = read_block(block)
+ fb_header = fb_message.header
+ unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch)
+ raise FileReadError.new(@buffer,
+ "Not a dictionary batch message: " +
+ fb_header.inspect)
+ end
+
+ id = fb_header.id
+ if fb_header.delta?
+ unless dictionaries.key?(id)
+ raise FileReadError.new(@buffer,
+ "A delta dictionary batch message " +
+ "must exist after a non delta " +
+ "dictionary batch message: " +
+ fb_header.inspect)
+ end
+ else
+ if dictionaries.key?(id)
+ raise FileReadError.new(@buffer,
+ "Multiple non delta dictionary batch " +
+ "messages for the same ID is invalid: " +
+ fb_header.inspect)
+ end
+ end
+
+ value_type = dictionary_fields[id].type.value_type
+ schema = Schema.new([Field.new("dummy", value_type, true, nil)])
+ record_batch = read_record_batch(fb_header.data, schema, body)
+ if fb_header.delta?
+ dictionaries[id] << record_batch.columns[0]
+ else
+ dictionaries[id] = [record_batch.columns[0]]
+ end
end
+ dictionaries
end
- def read_footer
- footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
- footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
- footer_data = @buffer.slice(footer_size_offset - footer_size,
- footer_size)
- Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
+ def find_dictionary(id)
+ @dictionaries[id]
end
end
end
diff --git a/ruby/red-arrow-format/lib/arrow-format/readable.rb
b/ruby/red-arrow-format/lib/arrow-format/readable.rb
index 7aa2effde2..ad6be653e0 100644
--- a/ruby/red-arrow-format/lib/arrow-format/readable.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/readable.rb
@@ -26,6 +26,8 @@ require_relative "org/apache/arrow/flatbuf/bool"
require_relative "org/apache/arrow/flatbuf/date"
require_relative "org/apache/arrow/flatbuf/date_unit"
require_relative "org/apache/arrow/flatbuf/decimal"
+require_relative "org/apache/arrow/flatbuf/dictionary_encoding"
+require_relative "org/apache/arrow/flatbuf/dictionary_batch"
require_relative "org/apache/arrow/flatbuf/duration"
require_relative "org/apache/arrow/flatbuf/fixed_size_binary"
require_relative "org/apache/arrow/flatbuf/floating_point"
@@ -40,11 +42,12 @@ require_relative "org/apache/arrow/flatbuf/map"
require_relative "org/apache/arrow/flatbuf/message"
require_relative "org/apache/arrow/flatbuf/null"
require_relative "org/apache/arrow/flatbuf/precision"
+require_relative "org/apache/arrow/flatbuf/record_batch"
require_relative "org/apache/arrow/flatbuf/schema"
require_relative "org/apache/arrow/flatbuf/struct_"
require_relative "org/apache/arrow/flatbuf/time"
-require_relative "org/apache/arrow/flatbuf/timestamp"
require_relative "org/apache/arrow/flatbuf/time_unit"
+require_relative "org/apache/arrow/flatbuf/timestamp"
require_relative "org/apache/arrow/flatbuf/union"
require_relative "org/apache/arrow/flatbuf/union_mode"
require_relative "org/apache/arrow/flatbuf/utf8"
@@ -67,32 +70,7 @@ module ArrowFormat
when Org::Apache::Arrow::Flatbuf::Bool
type = BooleanType.singleton
when Org::Apache::Arrow::Flatbuf::Int
- case fb_type.bit_width
- when 8
- if fb_type.signed?
- type = Int8Type.singleton
- else
- type = UInt8Type.singleton
- end
- when 16
- if fb_type.signed?
- type = Int16Type.singleton
- else
- type = UInt16Type.singleton
- end
- when 32
- if fb_type.signed?
- type = Int32Type.singleton
- else
- type = UInt32Type.singleton
- end
- when 64
- if fb_type.signed?
- type = Int64Type.singleton
- else
- type = UInt64Type.singleton
- end
- end
+ type = read_type_int(fb_type)
when Org::Apache::Arrow::Flatbuf::FloatingPoint
case fb_type.precision
when Org::Apache::Arrow::Flatbuf::Precision::SINGLE
@@ -175,14 +153,52 @@ module ArrowFormat
type = Decimal256Type.new(fb_type.precision, fb_type.scale)
end
end
- Field.new(fb_field.name, type, fb_field.nullable?)
+
+ dictionary = fb_field.dictionary
+ if dictionary
+ dictionary_id = dictionary.id
+ index_type = read_type_int(dictionary.index_type)
+ type = DictionaryType.new(index_type, type, dictionary.ordered?)
+ else
+ dictionary_id = nil
+ end
+ Field.new(fb_field.name, type, fb_field.nullable?, dictionary_id)
+ end
+
+ def read_type_int(fb_type)
+ case fb_type.bit_width
+ when 8
+ if fb_type.signed?
+ Int8Type.singleton
+ else
+ UInt8Type.singleton
+ end
+ when 16
+ if fb_type.signed?
+ Int16Type.singleton
+ else
+ UInt16Type.singleton
+ end
+ when 32
+ if fb_type.signed?
+ Int32Type.singleton
+ else
+ UInt32Type.singleton
+ end
+ when 64
+ if fb_type.signed?
+ Int64Type.singleton
+ else
+ UInt64Type.singleton
+ end
+ end
end
def read_record_batch(fb_record_batch, schema, body)
n_rows = fb_record_batch.length
nodes = fb_record_batch.nodes
buffers = fb_record_batch.buffers
- columns = @schema.fields.collect do |field|
+ columns = schema.fields.collect do |field|
read_column(field, nodes, buffers, body)
end
RecordBatch.new(schema, n_rows, columns)
@@ -244,6 +260,11 @@ module ArrowFormat
read_column(child, nodes, buffers, body)
end
field.type.build_array(length, types, children)
+ when DictionaryType
+ indices_buffer = buffers.shift
+ indices = body.slice(indices_buffer.offset, indices_buffer.length)
+ dictionary = find_dictionary(field.dictionary_id)
+ field.type.build_array(length, validity, indices, dictionary)
end
end
end
diff --git a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
index ae231fccbc..8682f3e826 100644
--- a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
@@ -151,6 +151,8 @@ module ArrowFormat
end
@state = :schema
@schema = nil
+ @dictionaries = nil
+ @dictionary_fields = nil
end
def next_required_size
@@ -170,8 +172,23 @@ module ArrowFormat
case @state
when :schema
process_schema_message(message, body)
- when :record_batch
- process_record_batch_message(message, body)
+ when :initial_dictionaries
+ header = message.header
+ unless header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch)
+ raise ReadError.new("Not a dictionary batch message: " +
+ header.inspect)
+ end
+ process_dictionary_batch_message(message, body)
+ if @dictionaries.size == @dictionary_fields.size
+ @state = :data
+ end
+ when :data
+ case message.header
+ when Org::Apache::Arrow::Flatbuf::DictionaryBatch
+ process_dictionary_batch_message(message, body)
+ when Org::Apache::Arrow::Flatbuf::RecordBatch
+ process_record_batch_message(message, body)
+ end
end
end
@@ -183,17 +200,43 @@ module ArrowFormat
end
@schema = read_schema(header)
- # TODO: initial dictionaries support
- @state = :record_batch
+ @dictionaries = {}
+ @dictionary_fields = {}
+ @schema.fields.each do |field|
+ next unless field.type.is_a?(DictionaryType)
+ @dictionary_fields[field.dictionary_id] = field
+ end
+ if @dictionaries.size < @dictionary_fields.size
+ @state = :initial_dictionaries
+ else
+ @state = :data
+ end
end
- def process_record_batch_message(message, body)
+ def process_dictionary_batch_message(message, body)
header = message.header
- unless header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
- raise ReadError.new("Not a record batch message: " +
+ if @state == :initial_dictionaries and header.delta?
+ raise ReadError.new("An initial dictionary batch message must be " +
+ "a non delta dictionary batch message: " +
header.inspect)
end
+ field = @dictionary_fields[header.id]
+ value_type = field.type.value_type
+ schema = Schema.new([Field.new("dummy", value_type, true, nil)])
+ record_batch = read_record_batch(header.data, schema, body)
+ if header.delta?
+ @dictionaries[header.id] << record_batch.columns[0]
+ else
+ @dictionaries[header.id] = [record_batch.columns[0]]
+ end
+ end
+ def find_dictionary(id)
+ @dictionaries[id]
+ end
+
+ def process_record_batch_message(message, body)
+ header = message.header
@on_read.call(read_record_batch(header, @schema, body))
end
end
diff --git a/ruby/red-arrow-format/lib/arrow-format/type.rb
b/ruby/red-arrow-format/lib/arrow-format/type.rb
index 92a699509b..ebf4ce5fa9 100644
--- a/ruby/red-arrow-format/lib/arrow-format/type.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/type.rb
@@ -78,6 +78,10 @@ module ArrowFormat
"Int8"
end
+ def buffer_type
+ :S8
+ end
+
def build_array(size, validity_buffer, values_buffer)
Int8Array.new(self, size, validity_buffer, values_buffer)
end
@@ -98,6 +102,10 @@ module ArrowFormat
"UInt8"
end
+ def buffer_type
+ :U8
+ end
+
def build_array(size, validity_buffer, values_buffer)
UInt8Array.new(self, size, validity_buffer, values_buffer)
end
@@ -118,6 +126,10 @@ module ArrowFormat
"Int16"
end
+ def buffer_type
+ :s16
+ end
+
def build_array(size, validity_buffer, values_buffer)
Int16Array.new(self, size, validity_buffer, values_buffer)
end
@@ -138,6 +150,10 @@ module ArrowFormat
"UInt16"
end
+ def buffer_type
+ :u16
+ end
+
def build_array(size, validity_buffer, values_buffer)
UInt16Array.new(self, size, validity_buffer, values_buffer)
end
@@ -158,6 +174,10 @@ module ArrowFormat
"Int32"
end
+ def buffer_type
+ :s32
+ end
+
def build_array(size, validity_buffer, values_buffer)
Int32Array.new(self, size, validity_buffer, values_buffer)
end
@@ -178,6 +198,10 @@ module ArrowFormat
"UInt32"
end
+ def buffer_type
+ :u32
+ end
+
def build_array(size, validity_buffer, values_buffer)
UInt32Array.new(self, size, validity_buffer, values_buffer)
end
@@ -198,6 +222,10 @@ module ArrowFormat
"Int64"
end
+ def buffer_type
+ :s64
+ end
+
def build_array(size, validity_buffer, values_buffer)
Int64Array.new(self, size, validity_buffer, values_buffer)
end
@@ -218,6 +246,10 @@ module ArrowFormat
"UInt64"
end
+ def buffer_type
+ :u64
+ end
+
def build_array(size, validity_buffer, values_buffer)
UInt64Array.new(self, size, validity_buffer, values_buffer)
end
@@ -645,4 +677,28 @@ module ArrowFormat
SparseUnionArray.new(self, size, types_buffer, children)
end
end
+
+ class DictionaryType < Type
+ attr_reader :index_type
+ attr_reader :value_type
+ attr_reader :ordered
+ def initialize(index_type, value_type, ordered)
+ super()
+ @index_type = index_type
+ @value_type = value_type
+ @ordered = ordered
+ end
+
+ def name
+ "Dictionary"
+ end
+
+ def build_array(size, validity_buffer, indices_buffer, dictionary)
+ DictionaryArray.new(self,
+ size,
+ validity_buffer,
+ indices_buffer,
+ dictionary)
+ end
+ end
end
diff --git a/ruby/red-arrow-format/test/test-reader.rb
b/ruby/red-arrow-format/test/test-reader.rb
index 8164d20623..0e59d855ce 100644
--- a/ruby/red-arrow-format/test/test-reader.rb
+++ b/ruby/red-arrow-format/test/test-reader.rb
@@ -866,6 +866,19 @@ module ReaderTests
read)
end
end
+
+ sub_test_case("Dictionary") do
+ def build_array
+ values = ["a", "b", "c", nil, "a"]
+ string_array = Arrow::StringArray.new(values)
+ string_array.dictionary_encode
+ end
+
+ def test_read
+ assert_equal([{"value" => ["a", "b", "c", nil, "a"]}],
+ read)
+ end
+ end
end
end
end