This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ce3c03  feat(parquet/pqarrow): read/write variant (#434)
7ce3c03 is described below

commit 7ce3c03fbc268cb67dee53676c48e9bd74be1684
Author: Matt Topol <[email protected]>
AuthorDate: Sun Jul 13 11:12:41 2025 -0400

    feat(parquet/pqarrow): read/write variant (#434)
    
    ### Rationale for this change
    resolves #310
    
    ### What changes are included in this PR?
    Updating the `pqarrow` package to support full round trip read/write of
    Variant values via `arrow/extensions/variant`
    
    ### Are these changes tested?
    Yes, unit tests are added for both shredded and unshredded variants.
    
    ### Are there any user-facing changes?
    Just the new features.
---
 arrow/extensions/variant.go          |  82 ++++++++++++++++++++++++++-
 arrow/extensions/variant_test.go     |  81 +++++++++++++++++++++++++++
 parquet/file/record_reader.go        |   1 +
 parquet/pqarrow/encode_arrow_test.go | 104 +++++++++++++++++++++++++++++++++++
 parquet/pqarrow/file_reader.go       |  40 +++++++++++++-
 parquet/pqarrow/schema.go            |  44 +++++++++++----
 6 files changed, 339 insertions(+), 13 deletions(-)

diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go
index fbef4a6..fe97f24 100644
--- a/arrow/extensions/variant.go
+++ b/arrow/extensions/variant.go
@@ -62,6 +62,83 @@ func NewDefaultVariantType() *VariantType {
        return vt
 }
 
+func createShreddedField(dt arrow.DataType) arrow.DataType {
+       switch t := dt.(type) {
+       case arrow.ListLikeType:
+               return arrow.ListOfNonNullable(arrow.StructOf(
+                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                       arrow.Field{Name: "typed_value", Type: 
createShreddedField(t.Elem()), Nullable: true},
+               ))
+       case *arrow.StructType:
+               fields := make([]arrow.Field, 0, t.NumFields())
+               for i := range t.NumFields() {
+                       f := t.Field(i)
+                       fields = append(fields, arrow.Field{
+                               Name: f.Name,
+                               Type: arrow.StructOf(arrow.Field{
+                                       Name:     "value",
+                                       Type:     arrow.BinaryTypes.Binary,
+                                       Nullable: true,
+                               }, arrow.Field{
+                                       Name:     "typed_value",
+                                       Type:     createShreddedField(f.Type),
+                                       Nullable: true,
+                               }),
+                               Nullable: false,
+                               Metadata: f.Metadata,
+                       })
+               }
+               return arrow.StructOf(fields...)
+       default:
+               return dt
+       }
+}
+
+// NewShreddedVariantType creates a new VariantType extension type using the 
provided
+// type to define a shredded schema by setting the `typed_value` field 
accordingly and
+// properly constructing the shredded fields for structs, lists and so on.
+//
+// For example:
+//
+//     NewShreddedVariantType(arrow.StructOf(
+//          arrow.Field{Name: "latitude", Type: arrow.PrimitiveTypes.Float64},
+//          arrow.Field{Name: "longitude", Type: 
arrow.PrimitiveTypes.Float32}))
+//
+// Will create a variant type with the following structure:
+//
+//     arrow.StructOf(
+//          arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, 
Nullable: false},
+//          arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+//          arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+//            arrow.Field{Name: "latitude", Type: arrow.StructOf(
+//              arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+//              arrow.Field{Name: "typed_value", Type: 
arrow.PrimitiveTypes.Float64, Nullable: true}),
+//              Nullable: false},
+//          arrow.Field{Name: "longitude", Type: arrow.StructOf(
+//              arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+//              arrow.Field{Name: "typed_value", Type: 
arrow.PrimitiveTypes.Float32, Nullable: true}),
+//              Nullable: false},
+//      ), Nullable: true})
+//
+// This is intended to be a convenient way to create a shredded variant type 
from a definition
+// of the fields to shred. If the provided data type is nil, it will create a 
default
+// variant type.
+func NewShreddedVariantType(dt arrow.DataType) *VariantType {
+       if dt == nil {
+               return NewDefaultVariantType()
+       }
+
+       vt, _ := NewVariantType(arrow.StructOf(
+               arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, 
Nullable: false},
+               arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+               arrow.Field{
+                       Name:     "typed_value",
+                       Type:     createShreddedField(dt),
+                       Nullable: true,
+               }))
+       return vt
+}
+
 // NewVariantType creates a new variant type based on the provided storage 
type.
 //
 // The rules for a variant storage type are:
@@ -1480,8 +1557,9 @@ type shreddedObjBuilder struct {
 }
 
 func (b *shreddedObjBuilder) AppendMissing() {
-       b.structBldr.Append(true)
+       b.structBldr.AppendValues([]bool{false})
        for _, fieldBldr := range b.fieldBuilders {
+               fieldBldr.structBldr.Append(true)
                fieldBldr.valueBldr.AppendNull()
                fieldBldr.typedBldr.AppendMissing()
        }
@@ -1489,7 +1567,7 @@ func (b *shreddedObjBuilder) AppendMissing() {
 
 func (b *shreddedObjBuilder) tryTyped(v variant.Value) (residual []byte) {
        if v.Type() != variant.Object {
-               b.structBldr.AppendNull()
+               b.AppendMissing()
                return v.Bytes()
        }
 
diff --git a/arrow/extensions/variant_test.go b/arrow/extensions/variant_test.go
index 9a1c05f..6e539ee 100644
--- a/arrow/extensions/variant_test.go
+++ b/arrow/extensions/variant_test.go
@@ -1574,3 +1574,84 @@ func TestVariantBuilderUnmarshalJSON(t *testing.T) {
                assert.Equal(t, int8(5), innerVal2.Value())
        })
 }
+
+func TestNewSimpleShreddedVariantType(t *testing.T) {
+       assert.True(t, arrow.TypeEqual(extensions.NewDefaultVariantType(),
+               extensions.NewShreddedVariantType(nil)))
+
+       vt := extensions.NewShreddedVariantType(arrow.PrimitiveTypes.Float32)
+       s := arrow.StructOf(
+               arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+               arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+               arrow.Field{Name: "typed_value", Type: 
arrow.PrimitiveTypes.Float32, Nullable: true})
+
+       assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", 
s, vt.Storage)
+}
+
+func TestNewShreddedVariantType(t *testing.T) {
+       vt := extensions.NewShreddedVariantType(arrow.StructOf(arrow.Field{
+               Name: "event_type",
+               Type: arrow.BinaryTypes.String,
+       }, arrow.Field{
+               Name: "event_ts",
+               Type: arrow.FixedWidthTypes.Timestamp_us,
+       }))
+
+       assert.NotNil(t, vt)
+       s := arrow.StructOf(
+               arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+               arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+               arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+                       arrow.Field{Name: "event_type", Type: arrow.StructOf(
+                               arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                               arrow.Field{Name: "typed_value", Type: 
arrow.BinaryTypes.String, Nullable: true},
+                       )},
+                       arrow.Field{Name: "event_ts", Type: arrow.StructOf(
+                               arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                               arrow.Field{Name: "typed_value", Type: 
arrow.FixedWidthTypes.Timestamp_us, Nullable: true},
+                       )},
+               ), Nullable: true})
+
+       assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", 
s, vt.Storage)
+}
+
+func TestShreddedVariantNested(t *testing.T) {
+       vt := extensions.NewShreddedVariantType(arrow.StructOf(
+               arrow.Field{Name: "strval", Type: arrow.BinaryTypes.String},
+               arrow.Field{Name: "bool", Type: arrow.FixedWidthTypes.Boolean},
+               arrow.Field{Name: "location", Type: arrow.ListOf(arrow.StructOf(
+                       arrow.Field{Name: "latitude", Type: 
arrow.PrimitiveTypes.Float64},
+                       arrow.Field{Name: "longitude", Type: 
arrow.PrimitiveTypes.Float32},
+               ))}))
+
+       s := arrow.StructOf(
+               arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
+               arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, 
Nullable: true},
+               arrow.Field{Name: "typed_value", Type: arrow.StructOf(
+                       arrow.Field{Name: "strval", Type: arrow.StructOf(
+                               arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                               arrow.Field{Name: "typed_value", Type: 
arrow.BinaryTypes.String, Nullable: true},
+                       )},
+                       arrow.Field{Name: "bool", Type: arrow.StructOf(
+                               arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                               arrow.Field{Name: "typed_value", Type: 
arrow.FixedWidthTypes.Boolean, Nullable: true},
+                       )},
+                       arrow.Field{Name: "location", Type: arrow.StructOf(
+                               arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                               arrow.Field{Name: "typed_value", Type: 
arrow.ListOfNonNullable(arrow.StructOf(
+                                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
+                                       arrow.Field{Name: "typed_value", Type: 
arrow.StructOf(
+                                               arrow.Field{Name: "latitude", 
Type: arrow.StructOf(
+                                                       arrow.Field{Name: 
"value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+                                                       arrow.Field{Name: 
"typed_value", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
+                                               )},
+                                               arrow.Field{Name: "longitude", 
Type: arrow.StructOf(
+                                                       arrow.Field{Name: 
"value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+                                                       arrow.Field{Name: 
"typed_value", Type: arrow.PrimitiveTypes.Float32, Nullable: true},
+                                               )},
+                                       ), Nullable: true},
+                               )), Nullable: true})},
+               ), Nullable: true})
+
+       assert.Truef(t, arrow.TypeEqual(vt.Storage, s), "expected %s, got %s", 
s, vt.Storage)
+}
diff --git a/parquet/file/record_reader.go b/parquet/file/record_reader.go
index 81ec0af..a21e066 100644
--- a/parquet/file/record_reader.go
+++ b/parquet/file/record_reader.go
@@ -555,6 +555,7 @@ func (rr *recordReader) ReadRecordData(numRecords int64) 
(int64, error) {
                // no repetition levels, skip delimiting logic. each level
                // represents null or not null entry
                recordsRead = utils.Min(rr.levelsWritten-rr.levelsPos, 
numRecords)
+               valuesToRead = recordsRead
                // this is advanced by delimitRecords which we skipped
                rr.levelsPos += recordsRead
        } else {
diff --git a/parquet/pqarrow/encode_arrow_test.go 
b/parquet/pqarrow/encode_arrow_test.go
index 0a2edab..61bc263 100644
--- a/parquet/pqarrow/encode_arrow_test.go
+++ b/parquet/pqarrow/encode_arrow_test.go
@@ -2314,3 +2314,107 @@ func TestEmptyListDeltaBinaryPacked(t *testing.T) {
        assert.True(t, schema.Equal(tbl.Schema()))
        assert.EqualValues(t, 1, tbl.NumRows())
 }
+
+func TestReadWriteNonShreddedVariant(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       bldr := extensions.NewVariantBuilder(mem, 
extensions.NewDefaultVariantType())
+       defer bldr.Release()
+
+       jsonData := `[
+                       42,
+                       "text",
+                       [1, 2, 3],
+                       {"name": "Alice"},
+                       [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 
2"}],
+                       {"items": [1, "two", true], "metadata": {"created": 
"2025-01-01"}},
+                       null
+               ]`
+
+       err := bldr.UnmarshalJSON([]byte(jsonData))
+       require.NoError(t, err)
+
+       arr := bldr.NewArray()
+       defer arr.Release()
+
+       rec := array.NewRecord(arrow.NewSchema([]arrow.Field{
+               {Name: "variant", Type: arr.DataType(), Nullable: true},
+       }, nil), []arrow.Array{arr}, -1)
+
+       var buf bytes.Buffer
+       wr, err := pqarrow.NewFileWriter(rec.Schema(), &buf, nil,
+               pqarrow.DefaultWriterProps())
+       require.NoError(t, err)
+
+       require.NoError(t, wr.Write(rec))
+       rec.Release()
+       wr.Close()
+
+       rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+       require.NoError(t, err)
+       reader, err := pqarrow.NewFileReader(rdr, 
pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+       require.NoError(t, err)
+       defer rdr.Close()
+
+       tbl, err := reader.ReadTable(context.Background())
+       require.NoError(t, err)
+       defer tbl.Release()
+
+       assert.True(t, array.Equal(arr, tbl.Column(0).Data().Chunk(0)))
+}
+
+func TestReadWriteShreddedVariant(t *testing.T) {
+       vt := extensions.NewShreddedVariantType(arrow.StructOf(
+               arrow.Field{Name: "event_type", Type: arrow.BinaryTypes.String},
+               arrow.Field{Name: "event_ts", Type: 
arrow.FixedWidthTypes.Timestamp_us}))
+
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       bldr := vt.NewBuilder(mem)
+       defer bldr.Release()
+
+       jsonData := `[
+                       {"event_type": "noop", "event_ts": "1970-01-21 
00:29:54.114937Z"},
+                       42,
+                       {"event_type": "text", "event_ts": "1970-01-21 
00:29:54.954163Z"},
+                       {"event_type": "list", "event_ts": "1970-01-21 
00:29:54.240241Z"},
+                       "text",
+                       {"event_type": "object", "event_ts": "1970-01-21 
00:29:54.146402Z"},                    
+                       null
+               ]`
+
+       err := bldr.UnmarshalJSON([]byte(jsonData))
+       require.NoError(t, err)
+
+       arr := bldr.NewArray()
+       defer arr.Release()
+
+       rec := array.NewRecord(arrow.NewSchema([]arrow.Field{
+               {Name: "variant", Type: arr.DataType(), Nullable: true},
+       }, nil), []arrow.Array{arr}, -1)
+
+       var buf bytes.Buffer
+       wr, err := pqarrow.NewFileWriter(rec.Schema(), &buf,
+               
parquet.NewWriterProperties(parquet.WithDictionaryDefault(false)),
+               pqarrow.DefaultWriterProps())
+       require.NoError(t, err)
+
+       require.NoError(t, wr.Write(rec))
+       rec.Release()
+       wr.Close()
+
+       rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+       require.NoError(t, err)
+       reader, err := pqarrow.NewFileReader(rdr, 
pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+       require.NoError(t, err)
+       defer rdr.Close()
+
+       tbl, err := reader.ReadTable(context.Background())
+       require.NoError(t, err)
+       defer tbl.Release()
+
+       assert.Truef(t, array.Equal(arr, tbl.Column(0).Data().Chunk(0)),
+               "expected: %s\ngot: %s", arr, tbl.Column(0).Data().Chunk(0))
+}
diff --git a/parquet/pqarrow/file_reader.go b/parquet/pqarrow/file_reader.go
index b064107..8fb114c 100644
--- a/parquet/pqarrow/file_reader.go
+++ b/parquet/pqarrow/file_reader.go
@@ -111,6 +111,37 @@ func (fr *FileReader) Schema() (*arrow.Schema, error) {
        return FromParquet(fr.rdr.MetaData().Schema, &fr.Props, 
fr.rdr.MetaData().KeyValueMetadata())
 }
 
+type extensionReader struct {
+       colReaderImpl
+
+       fieldWithExt arrow.Field
+}
+
+func (er *extensionReader) Field() *arrow.Field {
+       return &er.fieldWithExt
+}
+
+func (er *extensionReader) BuildArray(boundedLen int64) (*arrow.Chunked, 
error) {
+       if er.colReaderImpl == nil {
+               return nil, errors.New("extension reader has no underlying 
column reader implementation")
+       }
+
+       chkd, err := er.colReaderImpl.BuildArray(boundedLen)
+       if err != nil {
+               return nil, err
+       }
+       defer chkd.Release()
+
+       extType := er.fieldWithExt.Type.(arrow.ExtensionType)
+
+       newChunks := make([]arrow.Array, len(chkd.Chunks()))
+       for i, c := range chkd.Chunks() {
+               newChunks[i] = array.NewExtensionArrayWithStorage(extType, c)
+       }
+
+       return arrow.NewChunked(extType, newChunks), nil
+}
+
 type colReaderImpl interface {
        LoadBatch(nrecs int64) error
        BuildArray(boundedLen int64) (*arrow.Chunked, error)
@@ -517,7 +548,14 @@ func (fr *FileReader) getReader(ctx context.Context, field 
*SchemaField, arrowFi
 
        switch arrowField.Type.ID() {
        case arrow.EXTENSION:
-               return nil, xerrors.New("extension type not implemented")
+               storageField := arrowField
+               storageField.Type = 
arrowField.Type.(arrow.ExtensionType).StorageType()
+               storageReader, err := fr.getReader(ctx, field, storageField)
+               if err != nil {
+                       return nil, err
+               }
+
+               return &ColumnReader{&extensionReader{colReaderImpl: 
storageReader, fieldWithExt: arrowField}}, nil
        case arrow.STRUCT:
 
                childReaders := make([]*ColumnReader, len(field.Children))
diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go
index 17603b9..2c0e70b 100644
--- a/parquet/pqarrow/schema.go
+++ b/parquet/pqarrow/schema.go
@@ -18,7 +18,6 @@ package pqarrow
 
 import (
        "encoding/base64"
-       "errors"
        "fmt"
        "math"
        "strconv"
@@ -243,25 +242,25 @@ func repFromNullable(isnullable bool) parquet.Repetition {
 }
 
 func variantToNode(t *extensions.VariantType, field arrow.Field, props 
*parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error) 
{
-       metadataNode, err := fieldToNode("metadata", t.Metadata(), props, 
arrProps)
+       fields := make(schema.FieldList, 2, 3)
+       var err error
+
+       fields[0], err = fieldToNode("metadata", t.Metadata(), props, arrProps)
        if err != nil {
                return nil, err
        }
 
-       valueNode, err := fieldToNode("value", t.Value(), props, arrProps)
+       fields[1], err = fieldToNode("value", t.Value(), props, arrProps)
        if err != nil {
                return nil, err
        }
 
-       fields := schema.FieldList{metadataNode, valueNode}
-
-       typedField := t.TypedValue()
-       if typedField.Type != nil {
-               typedNode, err := fieldToNode("typed_value", typedField, props, 
arrProps)
+       if typed := t.TypedValue(); typed.Type != nil {
+               typedValue, err := fieldToNode("typed_value", typed, props, 
arrProps)
                if err != nil {
                        return nil, err
                }
-               fields = append(fields, typedNode)
+               fields = append(fields, typedValue)
        }
 
        return schema.NewGroupNodeLogical(field.Name, 
repFromNullable(field.Nullable),
@@ -868,9 +867,34 @@ func variantToSchemaField(n *schema.GroupNode, 
currentLevels file.LevelInfo, ctx
        switch n.NumFields() {
        case 2, 3:
        default:
-               return errors.New("VARIANT group must have exactly 2 or 3 
children")
+               return fmt.Errorf("VARIANT group must have exactly 2 or 3 
children, not %d", n.NumFields())
        }
 
+       if n.RepetitionType() == parquet.Repetitions.Repeated {
+               // list of variants
+               out.Children = make([]SchemaField, 1)
+               repeatedAncestorDef := currentLevels.IncrementRepeated()
+               if err := groupToStructField(n, currentLevels, ctx, 
&out.Children[0]); err != nil {
+                       return err
+               }
+
+               storageType := out.Children[0].Field.Type
+               elemType, err := extensions.NewVariantType(storageType)
+               if err != nil {
+                       return err
+               }
+
+               out.Children[0].Field.Type = elemType
+               out.Field = &arrow.Field{Name: n.Name(), Type: 
arrow.ListOfField(*out.Children[0].Field), Nullable: true,
+                       Metadata: createFieldMeta(int(n.FieldID()))}
+               ctx.LinkParent(&out.Children[0], out)
+               out.LevelInfo = currentLevels
+               out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
+               return nil
+       }
+
+       currentLevels.Increment(n)
+
        var err error
        if err = groupToStructField(n, currentLevels, ctx, out); err != nil {
                return err

Reply via email to