This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new a9704994 ci(parquet/pqarrow): integration tests for reading shredded 
variants (#455)
a9704994 is described below

commit a97049945a08d36a494de637cffb100a02f4bc5a
Author: Matt Topol <[email protected]>
AuthorDate: Wed Aug 27 11:04:26 2025 -0400

    ci(parquet/pqarrow): integration tests for reading shredded variants (#455)
    
    ### Rationale for this change
    Testing out the variant implementation here against parquet-java using
    the test cases generated in
    https://github.com/apache/parquet-testing/pull/90/files. Overall, it
    confirms that our implementation is generally compatible for reading
    Parquet files written by parquet-java, with some caveats.
    
    ### What changes are included in this PR?
    A new test suite in `parquet/pqarrow/variant_test.go` that uses the
    test cases defined in parquet-testing, attempts to read the Parquet
    files, and compares the resulting variants against the expected ones.
    
    Some issues were found that I believe lie with parquet-java and the
    test cases rather than with the Go implementation; as such, discussion
    is needed on the following:
    
    * The Parquet test files are missing the Variant Logical Type
    annotation. Currently I've worked around that for testing purposes,
    but not in a way that is sustainable or can be merged. As such, the
    files need to be re-generated with the Variant Logical Type annotation
    before these tests can be enabled.
    * Several test cases cover variations on situations where the `value`
    column is missing. Based on my reading of the

[spec](https://github.com/apache/parquet-format/blob/master/VariantShredding.md)
    this seems to be an invalid scenario. Specifically, the spec states
    that the `typed_value` field may be omitted when not shredding
    elements as a specific type, but says nothing about allowing omission
    of the `value` field. Per my reading of the spec, the Go
    implementation currently errors if this field is missing, meaning
    those test cases fail.
    * Test case 43, `testPartiallyShreddedObjectMissingFieldConflict`,
    seems to have a conflict between what is expected and what is in the
    spec. The `b` field exists within the `value` field while also being
    a shredded field; the test appears to assume the data in the `value`
    field would be ignored, but the

[spec](https://github.com/apache/parquet-format/blob/master/VariantShredding.md#objects)
    says that `value` **must never** contain fields represented by the
    shredded fields. This needs clarification on the desired behavior and
    result.
    * Test case 84, `testShreddedObjectWithOptionalFieldStructs`, tests
    the scenario where the shredded fields of an object are listed as
    `optional` in the schema, but the spec states that they *must* be
    `required`. Thus, the Go implementation errors on this test, as the
    spec says this is an error. Clarification is needed on whether this
    is a valid test case.
    * Test case 38, `testShreddedObjectMissingTypedValue`, tests the case
    where the `typed_value` field is missing. This is allowed by the
    spec, except that the spec states that in this scenario the `value`
    field **must** be `required`. The test case uses `optional` here,
    causing the Go implementation to fail. Clarification is needed.
    * Test case 125, `testPartiallyShreddedObjectFieldConflict`, again
    tests the case of a field existing in both the `value` and the
    shredded column, which the spec states is invalid and will lead to
    inconsistent results. Thus it is not valid for this test case to
    assert a specific result unless the spec is amended to state that the
    shredded field takes precedence in this case.
    * One thing that makes the tests a bit difficult is that when we
    un-shred back into variants, the current variant code in some
    libraries will automatically downcast to the smallest viable
    precision (e.g. downcasting an int32 into an int16 if it fits). This
    is worked around by testing the *values* rather than the types, but
    is worth mentioning, particularly in the case of decimal values.
    * A couple of error test cases verify that particular types, such as
    UINT32 or Fixed Len Byte Array(4), are **not** supported. However,
    nothing in the spec says that an implementation couldn't simply
    upcast a uint32 to an int64 or treat a fixed-length byte array
    shredded column as a binary column. So is it meaningful to explicitly
    error on those cases rather than allow them, since they are trivially
    convertible to valid variant types?
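    
    For reference, the points above concern the shape of the shredded
    storage group. Per my reading of the VariantShredding spec, a valid
    shredded variant group with one object field `a` shredded as int64
    would look roughly like the following (a hand-written sketch for
    illustration, not one of the generated test schemas; field names are
    hypothetical):
    
    ```
    required group var (VARIANT) {
      required binary metadata;
      optional binary value;        // residual, unshredded data
      optional group typed_value {  // shredded object
        required group a {          // shredded fields must be `required`
          optional binary value;
          optional int64 typed_value;
        }
      }
    }
    ```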
---
 arrow-testing                    |   2 +-
 arrow/extensions/variant.go      | 157 +++++++++++++------
 arrow/extensions/variant_test.go |  24 +--
 parquet-testing                  |   2 +-
 parquet/pqarrow/encode_arrow.go  |   2 +-
 parquet/pqarrow/schema.go        |  20 ++-
 parquet/pqarrow/schema_test.go   |   4 +-
 parquet/pqarrow/variant_test.go  | 326 +++++++++++++++++++++++++++++++++++++++
 parquet/schema/logical_types.go  |   3 +-
 parquet/variant/builder.go       |  16 +-
 parquet/variant/builder_test.go  |   4 +-
 parquet/variant/variant.go       |   5 +-
 parquet/variant/variant_test.go  |   4 +-
 13 files changed, 474 insertions(+), 95 deletions(-)

diff --git a/arrow-testing b/arrow-testing
index d2a13712..6a7b02fa 160000
--- a/arrow-testing
+++ b/arrow-testing
@@ -1 +1 @@
-Subproject commit d2a13712303498963395318a4eb42872e66aead7
+Subproject commit 6a7b02fac93d8addbcdbb213264e58bfdc3068e4
diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go
index fe97f247..659f571c 100644
--- a/arrow/extensions/variant.go
+++ b/arrow/extensions/variant.go
@@ -18,6 +18,7 @@ package extensions
 
 import (
        "bytes"
+       "errors"
        "fmt"
        "math"
        "reflect"
@@ -171,21 +172,23 @@ func NewVariantType(storage arrow.DataType) 
(*VariantType, error) {
                return nil, fmt.Errorf("%w: missing non-nullable field 
'metadata' in variant storage type %s", arrow.ErrInvalid, storage)
        }
 
-       if valueFieldIdx, ok = s.FieldIdx("value"); !ok {
-               return nil, fmt.Errorf("%w: missing non-nullable field 'value' 
in variant storage type %s", arrow.ErrInvalid, storage)
+       var valueOk, typedValueOk bool
+       valueFieldIdx, valueOk = s.FieldIdx("value")
+       typedValueFieldIdx, typedValueOk = s.FieldIdx("typed_value")
+
+       if !valueOk && !typedValueOk {
+               return nil, fmt.Errorf("%w: there must be at least one of 
'value' or 'typed_value' fields in variant storage type %s", arrow.ErrInvalid, 
storage)
        }
 
-       if s.NumFields() > 3 {
-               return nil, fmt.Errorf("%w: too many fields in variant storage 
type %s, expected 2 or 3", arrow.ErrInvalid, storage)
+       if s.NumFields() == 3 && (!valueOk || !typedValueOk) {
+               return nil, fmt.Errorf("%w: has 3 fields, but missing one of 
'value' or 'typed_value' fields, %s", arrow.ErrInvalid, storage)
        }
 
-       if s.NumFields() == 3 {
-               if typedValueFieldIdx, ok = s.FieldIdx("typed_value"); !ok {
-                       return nil, fmt.Errorf("%w: has 3 fields, but missing 
'typed_value' field, %s", arrow.ErrInvalid, storage)
-               }
+       if s.NumFields() > 3 {
+               return nil, fmt.Errorf("%w: too many fields in variant storage 
type %s, expected 2 or 3", arrow.ErrInvalid, storage)
        }
 
-       mdField, valField := s.Field(metadataFieldIdx), s.Field(valueFieldIdx)
+       mdField := s.Field(metadataFieldIdx)
        if mdField.Nullable {
                return nil, fmt.Errorf("%w: metadata field must be non-nullable 
binary type, got %s", arrow.ErrInvalid, mdField.Type)
        }
@@ -196,11 +199,14 @@ func NewVariantType(storage arrow.DataType) 
(*VariantType, error) {
                }
        }
 
-       if !isBinary(valField.Type) || (valField.Nullable && typedValueFieldIdx 
== -1) {
-               return nil, fmt.Errorf("%w: value field must be non-nullable 
binary type, got %s", arrow.ErrInvalid, valField.Type)
+       if valueOk {
+               valField := s.Field(valueFieldIdx)
+               if !isBinary(valField.Type) {
+                       return nil, fmt.Errorf("%w: value field must be binary 
type, got %s", arrow.ErrInvalid, valField.Type)
+               }
        }
 
-       if typedValueFieldIdx == -1 {
+       if !typedValueOk {
                return &VariantType{
                        ExtensionBase:      arrow.ExtensionBase{Storage: 
storage},
                        metadataFieldIdx:   metadataFieldIdx,
@@ -209,17 +215,17 @@ func NewVariantType(storage arrow.DataType) 
(*VariantType, error) {
                }, nil
        }
 
-       valueField := s.Field(valueFieldIdx)
-       if !valueField.Nullable {
-               return nil, fmt.Errorf("%w: value field must be nullable if 
typed_value is present, got %s", arrow.ErrInvalid, valueField.Type)
-       }
-
        typedValueField := s.Field(typedValueFieldIdx)
        if !typedValueField.Nullable {
                return nil, fmt.Errorf("%w: typed_value field must be nullable, 
got %s", arrow.ErrInvalid, typedValueField.Type)
        }
 
-       if nt, ok := typedValueField.Type.(arrow.NestedType); ok {
+       dt := typedValueField.Type
+       if dt.ID() == arrow.EXTENSION {
+               dt = dt.(arrow.ExtensionType).StorageType()
+       }
+
+       if nt, ok := dt.(arrow.NestedType); ok {
                if !validNestedType(nt) {
                        return nil, fmt.Errorf("%w: typed_value field must be a 
valid nested type, got %s", arrow.ErrInvalid, typedValueField.Type)
                }
@@ -242,6 +248,9 @@ func (v *VariantType) Metadata() arrow.Field {
 }
 
 func (v *VariantType) Value() arrow.Field {
+       if v.valueFieldIdx == -1 {
+               return arrow.Field{}
+       }
        return v.StorageType().(*arrow.StructType).Field(v.valueFieldIdx)
 }
 
@@ -286,7 +295,7 @@ func validStruct(s *arrow.StructType) bool {
        switch s.NumFields() {
        case 1:
                f := s.Field(0)
-               return f.Name == "value" && !f.Nullable && isBinary(f.Type)
+               return (f.Name == "value" && isBinary(f.Type)) || f.Name == 
"typed_value"
        case 2:
                valField, ok := s.FieldByName("value")
                if !ok || !valField.Nullable || !isBinary(valField.Type) {
@@ -365,8 +374,6 @@ func (v *VariantArray) initReader() {
                vt := v.ExtensionType().(*VariantType)
                st := v.Storage().(*array.Struct)
                metaField := st.Field(vt.metadataFieldIdx)
-               valueField := st.Field(vt.valueFieldIdx)
-
                metadata, ok := metaField.(arrow.TypedArray[[]byte])
                if !ok {
                        // we already validated that if the metadata field 
isn't a binary
@@ -374,24 +381,30 @@ func (v *VariantArray) initReader() {
                        metadata, _ = 
array.NewDictWrapper[[]byte](metaField.(*array.Dictionary))
                }
 
-               if vt.typedValueFieldIdx == -1 {
+               var value arrow.TypedArray[[]byte]
+               if vt.valueFieldIdx != -1 {
+                       valueField := st.Field(vt.valueFieldIdx)
+                       value = valueField.(arrow.TypedArray[[]byte])
+               }
+
+               var ivreader typedValReader
+               var err error
+               if vt.typedValueFieldIdx != -1 {
+                       ivreader, err = 
getReader(st.Field(vt.typedValueFieldIdx))
+                       if err != nil {
+                               v.rdrErr = err
+                               return
+                       }
+                       v.rdr = &shreddedVariantReader{
+                               metadata:   metadata,
+                               value:      value,
+                               typedValue: ivreader,
+                       }
+               } else {
                        v.rdr = &basicVariantReader{
                                metadata: metadata,
-                               value:    valueField.(arrow.TypedArray[[]byte]),
+                               value:    value,
                        }
-                       return
-               }
-
-               ivreader, err := getReader(st.Field(vt.typedValueFieldIdx))
-               if err != nil {
-                       v.rdrErr = err
-                       return
-               }
-
-               v.rdr = &shreddedVariantReader{
-                       metadata:   metadata,
-                       value:      valueField.(arrow.TypedArray[[]byte]),
-                       typedValue: ivreader,
                }
        })
 }
@@ -419,6 +432,9 @@ func (v *VariantArray) Metadata() arrow.TypedArray[[]byte] {
 // value of null).
 func (v *VariantArray) UntypedValues() arrow.TypedArray[[]byte] {
        vt := v.ExtensionType().(*VariantType)
+       if vt.valueFieldIdx == -1 {
+               return nil
+       }
        return 
v.Storage().(*array.Struct).Field(vt.valueFieldIdx).(arrow.TypedArray[[]byte])
 }
 
@@ -451,7 +467,6 @@ func (v *VariantArray) IsNull(i int) bool {
        }
 
        vt := v.ExtensionType().(*VariantType)
-       valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
        if vt.typedValueFieldIdx != -1 {
                typedArr := 
v.Storage().(*array.Struct).Field(vt.typedValueFieldIdx)
                if !typedArr.IsNull(i) {
@@ -459,6 +474,7 @@ func (v *VariantArray) IsNull(i int) bool {
                }
        }
 
+       valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
        b := valArr.(arrow.TypedArray[[]byte]).Value(i)
        return len(b) == 1 && b[0] == 0 // variant null
 }
@@ -747,9 +763,20 @@ func getReader(typedArr arrow.Array) (typedValReader, 
error) {
                        childType := child.DataType().(*arrow.StructType)
 
                        valueIdx, _ := childType.FieldIdx("value")
-                       valueArr := 
child.Field(valueIdx).(arrow.TypedArray[[]byte])
+                       var valueArr arrow.TypedArray[[]byte]
+                       if valueIdx != -1 {
+                               valueArr = 
child.Field(valueIdx).(arrow.TypedArray[[]byte])
+                       }
+
+                       typedValueIdx, exists := 
childType.FieldIdx("typed_value")
+                       if !exists {
+                               fieldReaders[fieldList[i].Name] = 
fieldReaderPair{
+                                       values:   valueArr,
+                                       typedVal: nil,
+                               }
+                               continue
+                       }
 
-                       typedValueIdx, _ := childType.FieldIdx("typed_value")
                        typedRdr, err := getReader(child.Field(typedValueIdx))
                        if err != nil {
                                return nil, fmt.Errorf("error getting typed 
value reader for field %s: %w", fieldList[i].Name, err)
@@ -768,13 +795,22 @@ func getReader(typedArr arrow.Array) (typedValReader, 
error) {
        case array.ListLike:
                listValues := arr.ListValues().(*array.Struct)
                elemType := listValues.DataType().(*arrow.StructType)
+
+               var valueArr arrow.TypedArray[[]byte]
+               var typedRdr typedValReader
+
                valueIdx, _ := elemType.FieldIdx("value")
-               valueArr := 
listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+               if valueIdx != -1 {
+                       valueArr = 
listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+               }
 
                typedValueIdx, _ := elemType.FieldIdx("typed_value")
-               typedRdr, err := getReader(listValues.Field(typedValueIdx))
-               if err != nil {
-                       return nil, fmt.Errorf("error getting typed value 
reader: %w", err)
+               if typedValueIdx != -1 {
+                       var err error
+                       typedRdr, err = 
getReader(listValues.Field(typedValueIdx))
+                       if err != nil {
+                               return nil, fmt.Errorf("error getting typed 
value reader: %w", err)
+                       }
                }
 
                return &typedListReader{
@@ -796,6 +832,7 @@ func constructVariant(b *variant.Builder, meta 
variant.Metadata, value []byte, t
        switch v := typedVal.(type) {
        case nil:
                if len(value) == 0 {
+                       b.AppendNull()
                        return nil
                }
 
@@ -846,6 +883,9 @@ func constructVariant(b *variant.Builder, meta 
variant.Metadata, value []byte, t
 
                return b.FinishArray(arrstart, elems)
        case []byte:
+               if len(value) > 0 {
+                       return errors.New("invalid variant, conflicting value 
and typed_value")
+               }
                return b.UnsafeAppendEncoded(v)
        default:
                return fmt.Errorf("%w: unsupported typed value type %T for 
variant", arrow.ErrInvalid, v)
@@ -876,14 +916,24 @@ func (v *typedObjReader) Value(meta variant.Metadata, i 
int) (any, error) {
                return nil, nil
        }
 
+       var err error
        result := make(map[string]typedPair)
        for name, rdr := range v.fieldRdrs {
-               typedValue, err := rdr.typedVal.Value(meta, i)
-               if err != nil {
-                       return nil, fmt.Errorf("error reading typed value for 
field %s at index %d: %w", name, i, err)
+               var typedValue any
+               if rdr.typedVal != nil {
+                       typedValue, err = rdr.typedVal.Value(meta, i)
+                       if err != nil {
+                               return nil, fmt.Errorf("error reading typed 
value for field %s at index %d: %w", name, i, err)
+                       }
                }
+
+               var val []byte
+               if rdr.values != nil {
+                       val = rdr.values.Value(i)
+               }
+
                result[name] = typedPair{
-                       Value:      rdr.values.Value(i),
+                       Value:      val,
                        TypedValue: typedValue,
                }
        }
@@ -913,7 +963,11 @@ func (v *typedListReader) Value(meta variant.Metadata, i 
int) (any, error) {
 
        result := make([]typedPair, 0, end-start)
        for j := start; j < end; j++ {
-               val := v.valueArr.Value(int(j))
+               var val []byte
+               if v.valueArr != nil {
+                       val = v.valueArr.Value(int(j))
+               }
+
                typedValue, err := v.typedVal.Value(meta, int(j))
                if err != nil {
                        return nil, fmt.Errorf("error reading typed value at 
index %d: %w", j, err)
@@ -956,12 +1010,17 @@ func (v *shreddedVariantReader) Value(i int) 
(variant.Value, error) {
        }
 
        b := variant.NewBuilderFromMeta(meta)
+       b.SetAllowDuplicates(true)
        typed, err := v.typedValue.Value(meta, i)
        if err != nil {
                return variant.NullValue, fmt.Errorf("error reading typed value 
at index %d: %w", i, err)
        }
 
-       if err := constructVariant(b, meta, v.value.Value(i), typed); err != 
nil {
+       var value []byte
+       if v.value != nil {
+               value = v.value.Value(i)
+       }
+       if err := constructVariant(b, meta, value, typed); err != nil {
                return variant.NullValue, fmt.Errorf("error constructing 
variant at index %d: %w", i, err)
        }
        return b.Build()
diff --git a/arrow/extensions/variant_test.go b/arrow/extensions/variant_test.go
index 6e539ee5..925d0621 100644
--- a/arrow/extensions/variant_test.go
+++ b/arrow/extensions/variant_test.go
@@ -61,21 +61,18 @@ func TestVariantExtensionType(t *testing.T) {
                expectedErr string
        }{
                {arrow.StructOf(arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary}),
-                       "missing non-nullable field 'value'"},
+                       "there must be at least one of 'value' or 'typed_value' 
fields in variant storage type"},
                {arrow.StructOf(arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary}), "missing non-nullable field 'metadata'"},
                {arrow.StructOf(arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary},
                        arrow.Field{Name: "value", Type: 
arrow.PrimitiveTypes.Int32}),
-                       "value field must be non-nullable binary type, got 
int32"},
+                       "value field must be binary type, got int32"},
                {arrow.StructOf(arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary},
                        arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary},
                        arrow.Field{Name: "extra", Type: 
arrow.BinaryTypes.Binary}),
-                       "has 3 fields, but missing 'typed_value' field"},
+                       "has 3 fields, but missing one of 'value' or 
'typed_value' field"},
                {arrow.StructOf(arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
                        arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: false}),
                        "metadata field must be non-nullable binary type"},
-               {arrow.StructOf(arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary, Nullable: false},
-                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true}),
-                       "value field must be non-nullable binary type"},
                {arrow.FixedWidthTypes.Boolean, "bad storage type bool for 
variant type"},
                {arrow.StructOf(
                        arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary, Nullable: false},
@@ -86,16 +83,6 @@ func TestVariantExtensionType(t *testing.T) {
                        arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.String, Nullable: false},
                        arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: false}),
                        "metadata field must be non-nullable binary type, got 
utf8"},
-               {arrow.StructOf(
-                       arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary, Nullable: false},
-                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: false},
-                       arrow.Field{Name: "typed_value", Type: 
arrow.BinaryTypes.String, Nullable: true}),
-                       "value field must be nullable if typed_value is 
present"},
-               {arrow.StructOf(
-                       arrow.Field{Name: "metadata", Type: 
arrow.BinaryTypes.Binary, Nullable: false},
-                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
-                       arrow.Field{Name: "typed_value", Type: 
arrow.BinaryTypes.String, Nullable: false}),
-                       "typed_value field must be nullable"},
        }
 
        for _, tt := range tests {
@@ -126,11 +113,6 @@ func TestVariantExtensionBadNestedTypes(t *testing.T) {
                        ), Nullable: false})},
                {"empty struct elem", arrow.StructOf(
                        arrow.Field{Name: "foobar", Type: arrow.StructOf(), 
Nullable: false})},
-               {"nullable value struct elem",
-                       arrow.StructOf(
-                               arrow.Field{Name: "foobar", Type: 
arrow.StructOf(
-                                       arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
-                               ), Nullable: false})},
                {"non-nullable two elem struct", arrow.StructOf(
                        arrow.Field{Name: "foobar", Type: arrow.StructOf(
                                arrow.Field{Name: "value", Type: 
arrow.BinaryTypes.Binary, Nullable: true},
diff --git a/parquet-testing b/parquet-testing
index 2dc8bf14..a3d96a65 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit 2dc8bf140ed6e28652fc347211c7d661714c7f95
+Subproject commit a3d96a65e11e2bbca7d22a894e8313ede90a33a3
diff --git a/parquet/pqarrow/encode_arrow.go b/parquet/pqarrow/encode_arrow.go
index cdaba241..5724e9f8 100644
--- a/parquet/pqarrow/encode_arrow.go
+++ b/parquet/pqarrow/encode_arrow.go
@@ -333,7 +333,7 @@ func writeDenseArrow(ctx *arrowWriteContext, cw 
file.ColumnChunkWriter, leafArr
                        case arrow.DECIMAL128:
                                for idx, val := range 
leafArr.(*array.Decimal128).Values() {
                                        debug.Assert(val.HighBits() == 0 || 
val.HighBits() == -1, "casting Decimal128 greater than the value range; high 
bits must be 0 or -1")
-                                       debug.Assert(val.LowBits() <= 
math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
+                                       debug.Assert(int64(val.LowBits()) <= 
math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
                                        data[idx] = int32(val.LowBits())
                                }
                        case arrow.DECIMAL256:
diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go
index 2c0e70b5..7c56e333 100644
--- a/parquet/pqarrow/schema.go
+++ b/parquet/pqarrow/schema.go
@@ -242,7 +242,7 @@ func repFromNullable(isnullable bool) parquet.Repetition {
 }
 
 func variantToNode(t *extensions.VariantType, field arrow.Field, props 
*parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error) 
{
-       fields := make(schema.FieldList, 2, 3)
+       fields := make(schema.FieldList, 1, 3)
        var err error
 
        fields[0], err = fieldToNode("metadata", t.Metadata(), props, arrProps)
@@ -250,9 +250,12 @@ func variantToNode(t *extensions.VariantType, field 
arrow.Field, props *parquet.
                return nil, err
        }
 
-       fields[1], err = fieldToNode("value", t.Value(), props, arrProps)
-       if err != nil {
-               return nil, err
+       if value := t.Value(); value.Type != nil {
+               valueField, err := fieldToNode("value", value, props, arrProps)
+               if err != nil {
+                       return nil, err
+               }
+               fields = append(fields, valueField)
        }
 
        if typed := t.TypedValue(); typed.Type != nil {
@@ -594,8 +597,9 @@ func getParquetType(typ arrow.DataType, props 
*parquet.WriterProperties, arrprop
                precision := int(dectype.GetPrecision())
                scale := int(dectype.GetScale())
 
+               logicalType := schema.NewDecimalLogicalType(int32(precision), 
int32(scale))
                if !props.StoreDecimalAsInteger() || precision > 18 {
-                       return parquet.Types.FixedLenByteArray, 
schema.NewDecimalLogicalType(int32(precision), int32(scale)), 
int(DecimalSize(int32(precision))), nil
+                       return parquet.Types.FixedLenByteArray, logicalType, 
int(DecimalSize(int32(precision))), nil
                }
 
                pqType := parquet.Types.Int32
@@ -603,7 +607,7 @@ func getParquetType(typ arrow.DataType, props 
*parquet.WriterProperties, arrprop
                        pqType = parquet.Types.Int64
                }
 
-               return pqType, schema.NoLogicalType{}, -1, nil
+               return pqType, logicalType, -1, nil
        case arrow.DATE32:
                return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil
        case arrow.DATE64:
@@ -612,14 +616,14 @@ func getParquetType(typ arrow.DataType, props 
*parquet.WriterProperties, arrprop
                pqType, logicalType, err := 
getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops)
                return pqType, logicalType, -1, err
        case arrow.TIME32:
-               return parquet.Types.Int32, schema.NewTimeLogicalType(true, 
schema.TimeUnitMillis), -1, nil
+               return parquet.Types.Int32, schema.NewTimeLogicalType(false, 
schema.TimeUnitMillis), -1, nil
        case arrow.TIME64:
                pqTimeUnit := schema.TimeUnitMicros
                if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond {
                        pqTimeUnit = schema.TimeUnitNanos
                }
 
-               return parquet.Types.Int64, schema.NewTimeLogicalType(true, 
pqTimeUnit), -1, nil
+               return parquet.Types.Int64, schema.NewTimeLogicalType(false, 
pqTimeUnit), -1, nil
        case arrow.FLOAT16:
                return parquet.Types.FixedLenByteArray, 
schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil
        case arrow.EXTENSION:
diff --git a/parquet/pqarrow/schema_test.go b/parquet/pqarrow/schema_test.go
index 6f3da880..6f5d14c7 100644
--- a/parquet/pqarrow/schema_test.go
+++ b/parquet/pqarrow/schema_test.go
@@ -184,11 +184,11 @@ func TestConvertArrowFlatPrimitives(t *testing.T) {
        arrowFields = append(arrowFields, arrow.Field{Name: "date64", Type: 
arrow.FixedWidthTypes.Date64, Nullable: false})
 
        parquetFields = append(parquetFields, 
schema.Must(schema.NewPrimitiveNodeLogical("time32", 
parquet.Repetitions.Required,
-               schema.NewTimeLogicalType(true, schema.TimeUnitMillis), 
parquet.Types.Int32, 0, -1)))
+               schema.NewTimeLogicalType(false, schema.TimeUnitMillis), 
parquet.Types.Int32, 0, -1)))
        arrowFields = append(arrowFields, arrow.Field{Name: "time32", Type: 
arrow.FixedWidthTypes.Time32ms, Nullable: false})
 
        parquetFields = append(parquetFields, 
schema.Must(schema.NewPrimitiveNodeLogical("time64", 
parquet.Repetitions.Required,
-               schema.NewTimeLogicalType(true, schema.TimeUnitMicros), 
parquet.Types.Int64, 0, -1)))
+               schema.NewTimeLogicalType(false, schema.TimeUnitMicros), 
parquet.Types.Int64, 0, -1)))
        arrowFields = append(arrowFields, arrow.Field{Name: "time64", Type: 
arrow.FixedWidthTypes.Time64us, Nullable: false})
 
        parquetFields = append(parquetFields, 
schema.NewInt96Node("timestamp96", parquet.Repetitions.Required, -1))
diff --git a/parquet/pqarrow/variant_test.go b/parquet/pqarrow/variant_test.go
new file mode 100644
index 00000000..81fa246b
--- /dev/null
+++ b/parquet/pqarrow/variant_test.go
@@ -0,0 +1,326 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+       "context"
+       "fmt"
+       "io"
+       "iter"
+       "os"
+       "path/filepath"
+       "slices"
+       "strings"
+       "testing"
+       "unsafe"
+
+       "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/endian"
+       "github.com/apache/arrow-go/v18/arrow/extensions"
+       "github.com/apache/arrow-go/v18/arrow/memory"
+       "github.com/apache/arrow-go/v18/internal/json"
+       "github.com/apache/arrow-go/v18/parquet"
+       "github.com/apache/arrow-go/v18/parquet/pqarrow"
+       "github.com/apache/arrow-go/v18/parquet/variant"
+       "github.com/stretchr/testify/suite"
+)
+
+type ShreddedVariantTestSuite struct {
+       suite.Suite
+
+       generate bool
+
+       dirPrefix string
+       outDir    string
+       cases     []Case
+
+       errorCases    []Case
+       singleVariant []Case
+       multiVariant  []Case
+}
+
+func (s *ShreddedVariantTestSuite) SetupSuite() {
+       dir := os.Getenv("PARQUET_TEST_DATA")
+       if dir == "" {
+               s.T().Skip("PARQUET_TEST_DATA environment variable not set")
+       }
+
+       s.dirPrefix = filepath.Join(dir, "..", "shredded_variant")
+       s.outDir = filepath.Join(dir, "..", "go_variant")
+       if s.generate {
+               s.Require().NoError(os.MkdirAll(s.outDir, 0o755), "Failed to create output directory: %s", s.outDir)
+       }
+
+       cases, err := os.Open(filepath.Join(s.dirPrefix, "cases.json"))
+       s.Require().NoError(err, "Failed to open cases.json")
+       defer cases.Close()
+
+       s.Require().NoError(json.NewDecoder(cases).Decode(&s.cases))
+
+       s.errorCases = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+               return c.ErrorMessage == ""
+       })
+
+       s.singleVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+               return c.ErrorMessage != "" || c.VariantFile == "" || len(c.VariantFiles) > 0
+       })
+
+       s.multiVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+               return c.ErrorMessage != "" || c.VariantFile != "" || len(c.VariantFiles) == 0
+       })
+
+       if s.generate {
+               cases.Seek(0, io.SeekStart)
+               outCases, err := os.Create(filepath.Join(s.outDir, "cases.json"))
+               s.Require().NoError(err, "Failed to create cases.json")
+               defer outCases.Close()
+
+               io.Copy(outCases, cases)
+               outCases.Sync()
+       }
+}
+
+type Case struct {
+       Number       int       `json:"case_number"`
+       Title        string    `json:"test"`
+       Notes        string    `json:"notes,omitempty"`
+       ParquetFile  string    `json:"parquet_file"`
+       VariantFile  string    `json:"variant_file,omitempty"`
+       VariantFiles []*string `json:"variant_files,omitempty"`
+       VariantData  string    `json:"variant,omitempty"`
+       Variants     string    `json:"variants,omitempty"`
+       ErrorMessage string    `json:"error_message,omitempty"`
+}
+
+func readUnsigned(b []byte) (result uint32) {
+       v := (*[4]byte)(unsafe.Pointer(&result))
+       copy(v[:], b)
+       return endian.FromLE(result)
+}
+
+func (s *ShreddedVariantTestSuite) readVariant(filename string) variant.Value {
+       data, err := os.ReadFile(filename)
+       s.Require().NoError(err, "Failed to read variant file: %s", filename)
+
+       hdr := data[0]
+       offsetSize := int(1 + ((hdr & 0b11000000) >> 6))
+       dictSize := int(readUnsigned(data[1 : 1+offsetSize]))
+       offsetListOffset := 1 + offsetSize
+       dataOffset := offsetListOffset + ((1 + dictSize) * offsetSize)
+
+       idx := offsetListOffset + (offsetSize * dictSize)
+       endOffset := dataOffset + int(readUnsigned(data[idx:idx+offsetSize]))
+       val, err := variant.New(data[:endOffset], data[endOffset:])
+       s.Require().NoError(err, "Failed to create variant from data: %s", filename)
+       return val
+}
+
+func (s *ShreddedVariantTestSuite) readParquet(filename string) arrow.Table {
+       file, err := os.Open(filepath.Join(s.dirPrefix, filename))
+       s.Require().NoError(err, "Failed to open Parquet file: %s", filename)
+       defer file.Close()
+
+       tbl, err := pqarrow.ReadTable(context.Background(), file, nil, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+       s.Require().NoError(err, "Failed to read Parquet file: %s", filename)
+       return tbl
+}
+
+func (s *ShreddedVariantTestSuite) writeVariantFile(filename string, val variant.Value) {
+       out, err := os.Create(filepath.Join(s.outDir, filename))
+       s.Require().NoError(err)
+       defer out.Close()
+
+       _, err = out.Write(val.Metadata().Bytes())
+       s.Require().NoError(err)
+       _, err = out.Write(val.Bytes())
+       s.Require().NoError(err)
+}
+
+func (s *ShreddedVariantTestSuite) writeParquetFile(filename string, tbl arrow.Table) {
+       out, err := os.Create(filepath.Join(s.outDir, filename))
+       s.Require().NoError(err)
+       defer out.Close()
+
+       s.Require().NoError(pqarrow.WriteTable(tbl, out, max(1, tbl.NumRows()), parquet.NewWriterProperties(
+               parquet.WithDictionaryDefault(false), parquet.WithStats(false),
+               parquet.WithStoreDecimalAsInteger(true),
+       ), pqarrow.DefaultWriterProps()))
+}
+
+func zip[T, U any](a iter.Seq[T], b iter.Seq[U]) iter.Seq2[T, U] {
+       return func(yield func(T, U) bool) {
+               nexta, stopa := iter.Pull(a)
+               nextb, stopb := iter.Pull(b)
+               defer stopa()
+               defer stopb()
+
+               for {
+                       a, ok := nexta()
+                       if !ok {
+                               return
+                       }
+                       b, ok := nextb()
+                       if !ok {
+                               return
+                       }
+                       if !yield(a, b) {
+                               return
+                       }
+               }
+       }
+}
+
+func (s *ShreddedVariantTestSuite) assertVariantEqual(expected, actual variant.Value) {
+       switch expected.BasicType() {
+       case variant.BasicObject:
+               exp := expected.Value().(variant.ObjectValue)
+               act := actual.Value().(variant.ObjectValue)
+
+               s.Equal(exp.NumElements(), act.NumElements(), "Expected %d elements in object, got %d", exp.NumElements(), act.NumElements())
+               for i := range exp.NumElements() {
+                       expectedField, err := exp.FieldAt(i)
+                       s.Require().NoError(err, "Failed to get expected field at index %d", i)
+                       actualField, err := act.FieldAt(i)
+                       s.Require().NoError(err, "Failed to get actual field at index %d", i)
+
+                       s.Equal(expectedField.Key, actualField.Key, "Expected field key %s, got %s", expectedField.Key, actualField.Key)
+                       s.assertVariantEqual(expectedField.Value, actualField.Value)
+               }
+       case variant.BasicArray:
+               exp := expected.Value().(variant.ArrayValue)
+               act := actual.Value().(variant.ArrayValue)
+
+               s.Equal(exp.Len(), act.Len(), "Expected array length %d, got %d", exp.Len(), act.Len())
+               for e, a := range zip(exp.Values(), act.Values()) {
+                       s.assertVariantEqual(e, a)
+               }
+       default:
+               switch expected.Type() {
+               case variant.Decimal4, variant.Decimal8, variant.Decimal16:
+                       e, err := json.Marshal(expected.Value())
+                       s.Require().NoError(err, "Failed to marshal expected value")
+                       a, err := json.Marshal(actual.Value())
+                       s.Require().NoError(err, "Failed to marshal actual value")
+                       s.JSONEq(string(e), string(a), "Expected variant value %s, got %s", e, a)
+               default:
+                       s.EqualValues(expected.Value(), actual.Value(), "Expected variant value %v, got %v", expected.Value(), actual.Value())
+               }
+       }
+}
+
+func (s *ShreddedVariantTestSuite) TestSingleVariantCases() {
+       for _, c := range s.singleVariant {
+               s.Run(c.Title, func() {
+                       s.Run(fmt.Sprint(c.Number), func() {
+                               if strings.Contains(c.ParquetFile, "-INVALID") {
+                                       s.T().Skip(c.Notes)
+                               }
+
+                               expected := s.readVariant(filepath.Join(s.dirPrefix, c.VariantFile))
+                               if s.generate {
+                                       s.writeVariantFile(c.VariantFile, expected)
+                               }
+
+                               tbl := s.readParquet(c.ParquetFile)
+                               defer tbl.Release()
+
+                               if s.generate {
+                                       s.writeParquetFile(c.ParquetFile, tbl)
+                               }
+
+                               col := tbl.Column(1).Data().Chunk(0)
+                               s.Require().IsType(&extensions.VariantArray{}, col)
+
+                               variantArray := col.(*extensions.VariantArray)
+                               s.Require().Equal(1, variantArray.Len(), "Expected single variant value")
+
+                               val, err := variantArray.Value(0)
+                               s.Require().NoError(err, "Failed to get variant value from array")
+                               s.assertVariantEqual(expected, val)
+                       })
+               })
+       }
+}
+
+func (s *ShreddedVariantTestSuite) TestMultiVariantCases() {
+       for _, c := range s.multiVariant {
+               s.Run(c.Title, func() {
+                       s.Run(fmt.Sprint(c.Number), func() {
+                               tbl := s.readParquet(c.ParquetFile)
+                               defer tbl.Release()
+
+                               if s.generate {
+                                       s.writeParquetFile(c.ParquetFile, tbl)
+                               }
+
+                               s.Require().EqualValues(len(c.VariantFiles), tbl.NumRows(), "Expected number of rows to match number of variant files")
+                               col := tbl.Column(1).Data().Chunk(0)
+                               s.Require().IsType(&extensions.VariantArray{}, col)
+
+                               variantArray := col.(*extensions.VariantArray)
+                               for i, variantFile := range c.VariantFiles {
+                                       if variantFile == nil {
+                                               s.True(variantArray.IsNull(i), "Expected null value at index %d", i)
+                                               continue
+                                       }
+
+                                       expected := s.readVariant(filepath.Join(s.dirPrefix, *variantFile))
+                                       if s.generate {
+                                               s.writeVariantFile(*variantFile, expected)
+                                       }
+
+                                       actual, err := variantArray.Value(i)
+                                       s.Require().NoError(err, "Failed to get variant value at index %d", i)
+                                       s.assertVariantEqual(expected, actual)
+                               }
+                       })
+               })
+       }
+}
+
+func (s *ShreddedVariantTestSuite) TestErrorCases() {
+       for _, c := range s.errorCases {
+               s.Run(c.Title, func() {
+                       s.Run(fmt.Sprint(c.Number), func() {
+                               switch c.Number {
+                               case 127:
+                                       s.T().Skip("Skipping case 127: test says uint32 should error, we just upcast to int64")
+                               case 137:
+                                       s.T().Skip("Skipping case 137: test says flba(4) should error, we just treat it as a binary variant")
+                               }
+
+                               tbl := s.readParquet(c.ParquetFile)
+                               defer tbl.Release()
+
+                               if s.generate {
+                                       s.writeParquetFile(c.ParquetFile, tbl)
+                               }
+
+                               col := tbl.Column(1).Data().Chunk(0)
+                               s.Require().IsType(&extensions.VariantArray{}, col)
+
+                               variantArray := col.(*extensions.VariantArray)
+                               _, err := variantArray.Value(0)
+                               s.Error(err, "Expected error for case %d: %s", c.Number, c.ErrorMessage)
+                       })
+               })
+       }
+}
+
+func TestShreddedVariantExamples(t *testing.T) {
+       suite.Run(t, &ShreddedVariantTestSuite{generate: false})
+}
diff --git a/parquet/schema/logical_types.go b/parquet/schema/logical_types.go
index 0c0ce559..e7f1c29f 100644
--- a/parquet/schema/logical_types.go
+++ b/parquet/schema/logical_types.go
@@ -24,6 +24,7 @@ import (
        "github.com/apache/arrow-go/v18/parquet"
        "github.com/apache/arrow-go/v18/parquet/internal/debug"
        format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet"
+       "github.com/apache/thrift/lib/go/thrift"
 )
 
 // DecimalMetadata is a struct for managing scale and precision information between
@@ -1139,7 +1140,7 @@ func (VariantLogicalType) IsCompatible(ct ConvertedType, _ DecimalMetadata) bool
 func (VariantLogicalType) IsApplicable(parquet.Type, int32) bool { return false }
 
 func (VariantLogicalType) toThrift() *format.LogicalType {
-       return &format.LogicalType{VARIANT: format.NewVariantType()}
+       return &format.LogicalType{VARIANT: &format.VariantType{SpecificationVersion: thrift.Int8Ptr(1)}}
 }
 
 func (VariantLogicalType) Equals(rhs LogicalType) bool {
diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go
index 194814c6..68fc178d 100644
--- a/parquet/variant/builder.go
+++ b/parquet/variant/builder.go
@@ -887,7 +887,7 @@ func (b *Builder) Build() (Value, error) {
 type variantPrimitiveType interface {
        constraints.Integer | constraints.Float | string | []byte |
                arrow.Date32 | arrow.Time64 | arrow.Timestamp | bool |
-               uuid.UUID | DecimalValue[decimal.Decimal32] |
+               uuid.UUID | DecimalValue[decimal.Decimal32] | time.Time |
                DecimalValue[decimal.Decimal64] | DecimalValue[decimal.Decimal128]
 }
 
@@ -895,17 +895,25 @@ type variantPrimitiveType interface {
 // variant value. At the moment this is just delegating to the [Builder.Append] method,
 // but in the future it will be optimized to avoid the extra overhead and reduce allocations.
 func Encode[T variantPrimitiveType](v T, opt ...AppendOpt) ([]byte, error) {
+       out, err := Of(v, opt...)
+       if err != nil {
+               return nil, fmt.Errorf("failed to encode variant value: %w", err)
+       }
+       return out.value, nil
+}
+
+func Of[T variantPrimitiveType](v T, opt ...AppendOpt) (Value, error) {
        var b Builder
        if err := b.Append(v, opt...); err != nil {
-               return nil, fmt.Errorf("failed to append value: %w", err)
+               return Value{}, fmt.Errorf("failed to append value: %w", err)
        }
 
        val, err := b.Build()
        if err != nil {
-               return nil, fmt.Errorf("failed to build variant value: %w", err)
+               return Value{}, fmt.Errorf("failed to build variant value: %w", err)
        }
 
-       return val.value, nil
+       return val, nil
 }
 
 func ParseJSON(data string, allowDuplicateKeys bool) (Value, error) {
diff --git a/parquet/variant/builder_test.go b/parquet/variant/builder_test.go
index 09fa80eb..982fa4e9 100644
--- a/parquet/variant/builder_test.go
+++ b/parquet/variant/builder_test.go
@@ -57,9 +57,7 @@ func TestBuildPrimitive(t *testing.T) {
                {"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }},
                {"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }},
                {"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }},
-               // FIXME: https://github.com/apache/parquet-testing/issues/82
-               // primitive_int64 is an int32 value, but the metadata is int64
-               {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }},
+               {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(1234567890123456789) }},
                {"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }},
                {"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }},
                {"primitive_string", func(b *variant.Builder) error {
diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go
index 800b7eb2..254bc3c3 100644
--- a/parquet/variant/variant.go
+++ b/parquet/variant/variant.go
@@ -650,7 +650,10 @@ func (v Value) Value() any {
                }
        case BasicShortString:
                sz := int(v.value[0] >> 2)
-               return unsafe.String(&v.value[1], sz)
+               if sz > 0 {
+                       return unsafe.String(&v.value[1], sz)
+               }
+               return ""
        case BasicObject:
                valueHdr := (v.value[0] >> basicTypeBits)
                fieldOffsetSz := (valueHdr & 0b11) + 1
diff --git a/parquet/variant/variant_test.go b/parquet/variant/variant_test.go
index 2ef4da38..c623f646 100644
--- a/parquet/variant/variant_test.go
+++ b/parquet/variant/variant_test.go
@@ -152,9 +152,7 @@ func TestPrimitiveVariants(t *testing.T) {
                {"primitive_int8", int8(42), variant.Int8, "42"},
                {"primitive_int16", int16(1234), variant.Int16, "1234"},
                {"primitive_int32", int32(123456), variant.Int32, "123456"},
-               // FIXME: https://github.com/apache/parquet-testing/issues/82
-               // primitive_int64 is an int32 value, but the metadata is int64
-               {"primitive_int64", int32(12345678), variant.Int32, "12345678"},
+               {"primitive_int64", int64(1234567890123456789), variant.Int64, "1234567890123456789"},
                {"primitive_float", float32(1234567940.0), variant.Float, "1234568000"},
                {"primitive_double", float64(1234567890.1234), variant.Double, "1234567890.1234"},
                {"primitive_string",

