This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new ae41f826fd GH-48705: [Ruby] Add support for reading dictionary array (#48706)
ae41f826fd is described below

commit ae41f826fd862fa4288ed8a8d3093aad184bb894
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sat Jan 3 09:59:26 2026 +0900

    GH-48705: [Ruby] Add support for reading dictionary array (#48706)
    
    ### Rationale for this change
    
    Dictionary arrays are a special data type. To read them, we need to
    process dictionary batch messages.
    
    ### What changes are included in this PR?
    
    * Add `ArrowFormat::DictionaryType`
    * Add `ArrowFormat::DictionaryArray`
    * Add support for dictionary batch messages
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes.
    * GitHub Issue: #48705
    
    Authored-by: Sutou Kouhei <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 ruby/red-arrow-format/lib/arrow-format/array.rb    |  80 ++++++------
 ruby/red-arrow-format/lib/arrow-format/field.rb    |   4 +-
 .../lib/arrow-format/file-reader.rb                | 139 +++++++++++++++------
 ruby/red-arrow-format/lib/arrow-format/readable.rb |  79 +++++++-----
 .../lib/arrow-format/streaming-pull-reader.rb      |  57 +++++++--
 ruby/red-arrow-format/lib/arrow-format/type.rb     |  56 +++++++++
 ruby/red-arrow-format/test/test-reader.rb          |  13 ++
 7 files changed, 313 insertions(+), 115 deletions(-)

diff --git a/ruby/red-arrow-format/lib/arrow-format/array.rb b/ruby/red-arrow-format/lib/arrow-format/array.rb
index 9d82aae16f..0c27e24bc6 100644
--- a/ruby/red-arrow-format/lib/arrow-format/array.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/array.rb
@@ -79,54 +79,34 @@ module ArrowFormat
       super(type, size, validity_buffer)
       @values_buffer = values_buffer
     end
-  end
 
-  class Int8Array < IntArray
     def to_a
-      apply_validity(@values_buffer.values(:S8, 0, @size))
+      apply_validity(@values_buffer.values(@type.buffer_type, 0, @size))
     end
   end
 
+  class Int8Array < IntArray
+  end
+
   class UInt8Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:U8, 0, @size))
-    end
   end
 
   class Int16Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:s16, 0, @size))
-    end
   end
 
   class UInt16Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:u16, 0, @size))
-    end
   end
 
   class Int32Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:s32, 0, @size))
-    end
   end
 
   class UInt32Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:u32, 0, @size))
-    end
   end
 
   class Int64Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:s64, 0, @size))
-    end
   end
 
   class UInt64Array < IntArray
-    def to_a
-      apply_validity(@values_buffer.values(:u64, 0, @size))
-    end
   end
 
   class FloatingPointArray < Array
@@ -410,6 +390,27 @@ module ArrowFormat
     end
   end
 
+  class MapArray < VariableSizeListArray
+    def to_a
+      super.collect do |entries|
+        if entries.nil?
+          entries
+        else
+          hash = {}
+          entries.each do |key, value|
+            hash[key] = value
+          end
+          hash
+        end
+      end
+    end
+
+    private
+    def offset_type
+      :s32 # TODO: big endian support
+    end
+  end
+
   class UnionArray < Array
     def initialize(type, size, types_buffer, children)
       super(type, size, nil)
@@ -449,24 +450,27 @@ module ArrowFormat
     end
   end
 
-  class MapArray < VariableSizeListArray
+  class DictionaryArray < Array
+    def initialize(type, size, validity_buffer, indices_buffer, dictionary)
+      super(type, size, validity_buffer)
+      @indices_buffer = indices_buffer
+      @dictionary = dictionary
+    end
+
     def to_a
-      super.collect do |entries|
-        if entries.nil?
-          entries
+      values = []
+      @dictionary.each do |dictionary_chunk|
+        values.concat(dictionary_chunk.to_a)
+      end
+      buffer_type = @type.index_type.buffer_type
+      indices = apply_validity(@indices_buffer.values(buffer_type, 0, @size))
+      indices.collect do |index|
+        if index.nil?
+          nil
         else
-          hash = {}
-          entries.each do |key, value|
-            hash[key] = value
-          end
-          hash
+          values[index]
         end
       end
     end
-
-    private
-    def offset_type
-      :s32 # TODO: big endian support
-    end
   end
 end
diff --git a/ruby/red-arrow-format/lib/arrow-format/field.rb b/ruby/red-arrow-format/lib/arrow-format/field.rb
index ac531750f7..090113cfe6 100644
--- a/ruby/red-arrow-format/lib/arrow-format/field.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/field.rb
@@ -18,10 +18,12 @@ module ArrowFormat
   class Field
     attr_reader :name
     attr_reader :type
-    def initialize(name, type, nullable)
+    attr_reader :dictionary_id
+    def initialize(name, type, nullable, dictionary_id)
       @name = name
       @type = type
       @nullable = nullable
+      @dictionary_id = dictionary_id
     end
 
     def nullable?
diff --git a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
index bf50bfd1cd..545638ca90 100644
--- a/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/file-reader.rb
@@ -49,17 +49,65 @@ module ArrowFormat
 
       validate
       @footer = read_footer
-      @record_batches = @footer.record_batches
+      @record_batch_blocks = @footer.record_batches
       @schema = read_schema(@footer.schema)
+      @dictionaries = read_dictionaries
     end
 
     def n_record_batches
-      @record_batches.size
+      @record_batch_blocks.size
     end
 
     def read(i)
-      block = @record_batches[i]
+      fb_message, body = read_block(@record_batch_blocks[i])
+      fb_header = fb_message.header
+      unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
+        raise FileReadError.new(@buffer,
+                                "Not a record batch message: #{i}: " +
+                                fb_header.class.name)
+      end
+      read_record_batch(fb_header, @schema, body)
+    end
+
+    def each
+      return to_enum(__method__) {n_record_batches} unless block_given?
+
+      @record_batch_blocks.size.times do |i|
+        yield(read(i))
+      end
+    end
+
+    private
+    def validate
+      minimum_size = STREAMING_FORMAT_START_OFFSET +
+                     FOOTER_SIZE_SIZE +
+                     END_MARKER_SIZE
+      if @buffer.size < minimum_size
+        raise FileReadError.new(@buffer,
+                                "Input must be larger than or equal to " +
+                                "#{minimum_size}: #{@buffer.size}")
+      end
+
+      start_marker = @buffer.slice(0, START_MARKER_SIZE)
+      if start_marker != MAGIC_BUFFER
+        raise FileReadError.new(@buffer, "No start marker")
+      end
+      end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
+                                 END_MARKER_SIZE)
+      if end_marker != MAGIC_BUFFER
+        raise FileReadError.new(@buffer, "No end marker")
+      end
+    end
+
+    def read_footer
+      footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
+      footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
+      footer_data = @buffer.slice(footer_size_offset - footer_size,
+                                  footer_size)
+      Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
+    end
 
+    def read_block(block)
       offset = block.offset
 
       # If we can report property error information, we can use
@@ -101,54 +149,65 @@ module ArrowFormat
 
       metadata = @buffer.slice(offset, metadata_length)
       fb_message = Org::Apache::Arrow::Flatbuf::Message.new(metadata)
-      fb_header = fb_message.header
-      unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
-        raise FileReadError.new(@buffer,
-                                "Not a record batch message: #{i}: " +
-                                fb_header.class.name)
-      end
       offset += metadata_length
 
       body = @buffer.slice(offset, block.body_length)
-      read_record_batch(fb_header, @schema, body)
-    end
 
-    def each
-      return to_enum(__method__) {n_record_batches} unless block_given?
-
-      @record_batches.size.times do |i|
-        yield(read(i))
-      end
+      [fb_message, body]
     end
 
-    private
-    def validate
-      minimum_size = STREAMING_FORMAT_START_OFFSET +
-                     FOOTER_SIZE_SIZE +
-                     END_MARKER_SIZE
-      if @buffer.size < minimum_size
-        raise FileReadError.new(@buffer,
-                                "Input must be larger than or equal to " +
-                                "#{minimum_size}: #{@buffer.size}")
-      end
+    def read_dictionaries
+      dictionary_blocks = @footer.dictionaries
+      return nil if dictionary_blocks.nil?
 
-      start_marker = @buffer.slice(0, START_MARKER_SIZE)
-      if start_marker != MAGIC_BUFFER
-        raise FileReadError.new(@buffer, "No start marker")
+      dictionary_fields = {}
+      @schema.fields.each do |field|
+        next unless field.type.is_a?(DictionaryType)
+        dictionary_fields[field.dictionary_id] = field
       end
-      end_marker = @buffer.slice(@buffer.size - END_MARKER_SIZE,
-                                 END_MARKER_SIZE)
-      if end_marker != MAGIC_BUFFER
-        raise FileReadError.new(@buffer, "No end marker")
+
+      dictionaries = {}
+      dictionary_blocks.each do |block|
+        fb_message, body = read_block(block)
+        fb_header = fb_message.header
+        unless fb_header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch)
+          raise FileReadError.new(@buffer,
+                                  "Not a dictionary batch message: " +
+                                  fb_header.inspect)
+        end
+
+        id = fb_header.id
+        if fb_header.delta?
+          unless dictionaries.key?(id)
+            raise FileReadError.new(@buffer,
+                                    "A delta dictionary batch message " +
+                                    "must exist after a non delta " +
+                                    "dictionary batch message: " +
+                                    fb_header.inspect)
+          end
+        else
+          if dictionaries.key?(id)
+            raise FileReadError.new(@buffer,
+                                    "Multiple non delta dictionary batch " +
+                                    "messages for the same ID is invalid: " +
+                                    fb_header.inspect)
+          end
+        end
+
+        value_type = dictionary_fields[id].type.value_type
+        schema = Schema.new([Field.new("dummy", value_type, true, nil)])
+        record_batch = read_record_batch(fb_header.data, schema, body)
+        if fb_header.delta?
+          dictionaries[id] << record_batch.columns[0]
+        else
+          dictionaries[id] = [record_batch.columns[0]]
+        end
       end
+      dictionaries
     end
 
-    def read_footer
-      footer_size_offset = @buffer.size - END_MARKER_SIZE - FOOTER_SIZE_SIZE
-      footer_size = @buffer.get_value(FOOTER_SIZE_FORMAT, footer_size_offset)
-      footer_data = @buffer.slice(footer_size_offset - footer_size,
-                                  footer_size)
-      Org::Apache::Arrow::Flatbuf::Footer.new(footer_data)
+    def find_dictionary(id)
+      @dictionaries[id]
     end
   end
 end
diff --git a/ruby/red-arrow-format/lib/arrow-format/readable.rb b/ruby/red-arrow-format/lib/arrow-format/readable.rb
index 7aa2effde2..ad6be653e0 100644
--- a/ruby/red-arrow-format/lib/arrow-format/readable.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/readable.rb
@@ -26,6 +26,8 @@ require_relative "org/apache/arrow/flatbuf/bool"
 require_relative "org/apache/arrow/flatbuf/date"
 require_relative "org/apache/arrow/flatbuf/date_unit"
 require_relative "org/apache/arrow/flatbuf/decimal"
+require_relative "org/apache/arrow/flatbuf/dictionary_encoding"
+require_relative "org/apache/arrow/flatbuf/dictionary_batch"
 require_relative "org/apache/arrow/flatbuf/duration"
 require_relative "org/apache/arrow/flatbuf/fixed_size_binary"
 require_relative "org/apache/arrow/flatbuf/floating_point"
@@ -40,11 +42,12 @@ require_relative "org/apache/arrow/flatbuf/map"
 require_relative "org/apache/arrow/flatbuf/message"
 require_relative "org/apache/arrow/flatbuf/null"
 require_relative "org/apache/arrow/flatbuf/precision"
+require_relative "org/apache/arrow/flatbuf/record_batch"
 require_relative "org/apache/arrow/flatbuf/schema"
 require_relative "org/apache/arrow/flatbuf/struct_"
 require_relative "org/apache/arrow/flatbuf/time"
-require_relative "org/apache/arrow/flatbuf/timestamp"
 require_relative "org/apache/arrow/flatbuf/time_unit"
+require_relative "org/apache/arrow/flatbuf/timestamp"
 require_relative "org/apache/arrow/flatbuf/union"
 require_relative "org/apache/arrow/flatbuf/union_mode"
 require_relative "org/apache/arrow/flatbuf/utf8"
@@ -67,32 +70,7 @@ module ArrowFormat
       when Org::Apache::Arrow::Flatbuf::Bool
         type = BooleanType.singleton
       when Org::Apache::Arrow::Flatbuf::Int
-        case fb_type.bit_width
-        when 8
-          if fb_type.signed?
-            type = Int8Type.singleton
-          else
-            type = UInt8Type.singleton
-          end
-        when 16
-          if fb_type.signed?
-            type = Int16Type.singleton
-          else
-            type = UInt16Type.singleton
-          end
-        when 32
-          if fb_type.signed?
-            type = Int32Type.singleton
-          else
-            type = UInt32Type.singleton
-          end
-        when 64
-          if fb_type.signed?
-            type = Int64Type.singleton
-          else
-            type = UInt64Type.singleton
-          end
-        end
+        type = read_type_int(fb_type)
       when Org::Apache::Arrow::Flatbuf::FloatingPoint
         case fb_type.precision
         when Org::Apache::Arrow::Flatbuf::Precision::SINGLE
@@ -175,14 +153,52 @@ module ArrowFormat
           type = Decimal256Type.new(fb_type.precision, fb_type.scale)
         end
       end
-      Field.new(fb_field.name, type, fb_field.nullable?)
+
+      dictionary = fb_field.dictionary
+      if dictionary
+        dictionary_id = dictionary.id
+        index_type = read_type_int(dictionary.index_type)
+        type = DictionaryType.new(index_type, type, dictionary.ordered?)
+      else
+        dictionary_id = nil
+      end
+      Field.new(fb_field.name, type, fb_field.nullable?, dictionary_id)
+    end
+
+    def read_type_int(fb_type)
+      case fb_type.bit_width
+      when 8
+        if fb_type.signed?
+          Int8Type.singleton
+        else
+          UInt8Type.singleton
+        end
+      when 16
+        if fb_type.signed?
+          Int16Type.singleton
+        else
+          UInt16Type.singleton
+        end
+      when 32
+        if fb_type.signed?
+          Int32Type.singleton
+        else
+          UInt32Type.singleton
+        end
+      when 64
+        if fb_type.signed?
+          Int64Type.singleton
+        else
+          UInt64Type.singleton
+        end
+      end
     end
 
     def read_record_batch(fb_record_batch, schema, body)
       n_rows = fb_record_batch.length
       nodes = fb_record_batch.nodes
       buffers = fb_record_batch.buffers
-      columns = @schema.fields.collect do |field|
+      columns = schema.fields.collect do |field|
         read_column(field, nodes, buffers, body)
       end
       RecordBatch.new(schema, n_rows, columns)
@@ -244,6 +260,11 @@ module ArrowFormat
           read_column(child, nodes, buffers, body)
         end
         field.type.build_array(length, types, children)
+      when DictionaryType
+        indices_buffer = buffers.shift
+        indices = body.slice(indices_buffer.offset, indices_buffer.length)
+        dictionary = find_dictionary(field.dictionary_id)
+        field.type.build_array(length, validity, indices, dictionary)
       end
     end
   end
diff --git a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
index ae231fccbc..8682f3e826 100644
--- a/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/streaming-pull-reader.rb
@@ -151,6 +151,8 @@ module ArrowFormat
       end
       @state = :schema
       @schema = nil
+      @dictionaries = nil
+      @dictionary_fields = nil
     end
 
     def next_required_size
@@ -170,8 +172,23 @@ module ArrowFormat
       case @state
       when :schema
         process_schema_message(message, body)
-      when :record_batch
-        process_record_batch_message(message, body)
+      when :initial_dictionaries
+        header = message.header
+        unless header.is_a?(Org::Apache::Arrow::Flatbuf::DictionaryBatch)
+          raise ReadError.new("Not a dictionary batch message: " +
+                              header.inspect)
+        end
+        process_dictionary_batch_message(message, body)
+        if @dictionaries.size == @dictionary_fields.size
+          @state = :data
+        end
+      when :data
+        case message.header
+        when Org::Apache::Arrow::Flatbuf::DictionaryBatch
+          process_dictionary_batch_message(message, body)
+        when Org::Apache::Arrow::Flatbuf::RecordBatch
+          process_record_batch_message(message, body)
+        end
       end
     end
 
@@ -183,17 +200,43 @@ module ArrowFormat
       end
 
       @schema = read_schema(header)
-      # TODO: initial dictionaries support
-      @state = :record_batch
+      @dictionaries = {}
+      @dictionary_fields = {}
+      @schema.fields.each do |field|
+        next unless field.type.is_a?(DictionaryType)
+        @dictionary_fields[field.dictionary_id] = field
+      end
+      if @dictionaries.size < @dictionary_fields.size
+        @state = :initial_dictionaries
+      else
+        @state = :data
+      end
     end
 
-    def process_record_batch_message(message, body)
+    def process_dictionary_batch_message(message, body)
       header = message.header
-      unless header.is_a?(Org::Apache::Arrow::Flatbuf::RecordBatch)
-        raise ReadError.new("Not a record batch message: " +
+      if @state == :initial_dictionaries and header.delta?
+        raise ReadError.new("An initial dictionary batch message must be " +
+                            "a non delta dictionary batch message: " +
                             header.inspect)
       end
+      field = @dictionary_fields[header.id]
+      value_type = field.type.value_type
+      schema = Schema.new([Field.new("dummy", value_type, true, nil)])
+      record_batch = read_record_batch(header.data, schema, body)
+      if header.delta?
+        @dictionaries[header.id] << record_batch.columns[0]
+      else
+        @dictionaries[header.id] = [record_batch.columns[0]]
+      end
+    end
 
+    def find_dictionary(id)
+      @dictionaries[id]
+    end
+
+    def process_record_batch_message(message, body)
+      header = message.header
       @on_read.call(read_record_batch(header, @schema, body))
     end
   end
diff --git a/ruby/red-arrow-format/lib/arrow-format/type.rb b/ruby/red-arrow-format/lib/arrow-format/type.rb
index 92a699509b..ebf4ce5fa9 100644
--- a/ruby/red-arrow-format/lib/arrow-format/type.rb
+++ b/ruby/red-arrow-format/lib/arrow-format/type.rb
@@ -78,6 +78,10 @@ module ArrowFormat
       "Int8"
     end
 
+    def buffer_type
+      :S8
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       Int8Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -98,6 +102,10 @@ module ArrowFormat
       "UInt8"
     end
 
+    def buffer_type
+      :U8
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       UInt8Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -118,6 +126,10 @@ module ArrowFormat
       "Int16"
     end
 
+    def buffer_type
+      :s16
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       Int16Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -138,6 +150,10 @@ module ArrowFormat
       "UInt16"
     end
 
+    def buffer_type
+      :u16
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       UInt16Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -158,6 +174,10 @@ module ArrowFormat
       "Int32"
     end
 
+    def buffer_type
+      :s32
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       Int32Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -178,6 +198,10 @@ module ArrowFormat
       "UInt32"
     end
 
+    def buffer_type
+      :u32
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       UInt32Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -198,6 +222,10 @@ module ArrowFormat
       "Int64"
     end
 
+    def buffer_type
+      :s64
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       Int64Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -218,6 +246,10 @@ module ArrowFormat
       "UInt64"
     end
 
+    def buffer_type
+      :u64
+    end
+
     def build_array(size, validity_buffer, values_buffer)
       UInt64Array.new(self, size, validity_buffer, values_buffer)
     end
@@ -645,4 +677,28 @@ module ArrowFormat
       SparseUnionArray.new(self, size, types_buffer, children)
     end
   end
+
+  class DictionaryType < Type
+    attr_reader :index_type
+    attr_reader :value_type
+    attr_reader :ordered
+    def initialize(index_type, value_type, ordered)
+      super()
+      @index_type = index_type
+      @value_type = value_type
+      @ordered = ordered
+    end
+
+    def name
+      "Dictionary"
+    end
+
+    def build_array(size, validity_buffer, indices_buffer, dictionary)
+      DictionaryArray.new(self,
+                          size,
+                          validity_buffer,
+                          indices_buffer,
+                          dictionary)
+    end
+  end
 end
diff --git a/ruby/red-arrow-format/test/test-reader.rb b/ruby/red-arrow-format/test/test-reader.rb
index 8164d20623..0e59d855ce 100644
--- a/ruby/red-arrow-format/test/test-reader.rb
+++ b/ruby/red-arrow-format/test/test-reader.rb
@@ -866,6 +866,19 @@ module ReaderTests
                          read)
           end
         end
+
+        sub_test_case("Dictionary") do
+          def build_array
+            values = ["a", "b", "c", nil, "a"]
+            string_array = Arrow::StringArray.new(values)
+            string_array.dictionary_encode
+          end
+
+          def test_read
+            assert_equal([{"value" => ["a", "b", "c", nil, "a"]}],
+                         read)
+          end
+        end
       end
     end
   end

Reply via email to