This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 30e8856dc feat(go/adbc/sqldriver): handle timestamp/time.Time values
for input (#3109)
30e8856dc is described below
commit 30e8856dc89a41d2e412a7ac5279053b753c6d24
Author: Matt Topol <[email protected]>
AuthorDate: Wed Jul 9 14:51:11 2025 -0400
feat(go/adbc/sqldriver): handle timestamp/time.Time values for input (#3109)
fixes #3103
Adds new cases to `arrFromVal` to handle `time.Time`, `arrow.Timestamp`,
and `arrow.Time32`/`arrow.Time64` values. These cases only work when a
parameter schema is available, for example one provided by the FlightSQL
server when using a prepared statement, or obtained by other means.
If no parameter schema is available, binding these types returns an error (unsupported types now return an error instead of panicking).
---
go/adbc/sqldriver/driver.go | 95 +++++++++++++++++++++++++-----
go/adbc/sqldriver/driver_internals_test.go | 69 +++++++++++++++++++++-
2 files changed, 147 insertions(+), 17 deletions(-)
diff --git a/go/adbc/sqldriver/driver.go b/go/adbc/sqldriver/driver.go
index 9515fda6e..03b6eecb6 100644
--- a/go/adbc/sqldriver/driver.go
+++ b/go/adbc/sqldriver/driver.go
@@ -413,10 +413,9 @@ func (s *stmt) CheckNamedValue(val *driver.NamedValue)
error {
return nil
}
-func arrFromVal(val any) arrow.Array {
+func arrFromVal(val any, dt arrow.DataType) (arrow.Array, error) {
var (
buffers = make([]*memory.Buffer, 2)
- dt arrow.DataType
)
switch v := val.(type) {
case bool:
@@ -459,17 +458,65 @@ func arrFromVal(val any) arrow.Array {
dt = arrow.PrimitiveTypes.Date64
buffers[1] =
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
case []byte:
- dt = arrow.BinaryTypes.Binary
- buffers[1] =
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+ if dt == nil || dt.ID() == arrow.BINARY {
+ dt = arrow.BinaryTypes.Binary
+ buffers[1] =
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+ } else if dt.ID() == arrow.LARGE_BINARY {
+ dt = arrow.BinaryTypes.LargeBinary
+ buffers[1] =
memory.NewBufferBytes(arrow.Int64Traits.CastToBytes([]int64{0, int64(len(v))}))
+ }
buffers = append(buffers, memory.NewBufferBytes(v))
case string:
- dt = arrow.BinaryTypes.String
- buffers[1] =
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
-
+ if dt == nil || dt.ID() == arrow.STRING {
+ dt = arrow.BinaryTypes.String
+ buffers[1] =
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes([]int32{0, int32(len(v))}))
+ } else if dt.ID() == arrow.LARGE_STRING {
+ dt = arrow.BinaryTypes.LargeString
+ buffers[1] =
memory.NewBufferBytes(arrow.Int64Traits.CastToBytes([]int64{0, int64(len(v))}))
+ }
buf := unsafe.Slice(unsafe.StringData(v), len(v))
buffers = append(buffers, memory.NewBufferBytes(buf))
+ case arrow.Time32:
+ if dt == nil || dt.ID() != arrow.TIME32 {
+ return nil, errors.New("can only create array from
arrow.Time32 with a provided parameter schema")
+ }
+
+ buffers[1] =
memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&v))[:])
+ case arrow.Time64:
+ if dt == nil || dt.ID() != arrow.TIME64 {
+ return nil, errors.New("can only create array from
arrow.Time64 with a provided parameter schema")
+ }
+
+ buffers[1] =
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
+ case arrow.Timestamp:
+ if dt == nil || dt.ID() != arrow.TIMESTAMP {
+ return nil, errors.New("can only create array from
arrow.Timestamp with a provided parameter schema")
+ }
+
+ buffers[1] =
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&v))[:])
+ case time.Time:
+ if dt == nil {
+ return nil, errors.New("can only create array from
time.Time with a provided parameter schema")
+ }
+
+ switch dt.ID() {
+ case arrow.DATE32:
+ val := arrow.Date32FromTime(v)
+ buffers[1] =
memory.NewBufferBytes((*[4]byte)(unsafe.Pointer(&val))[:])
+ case arrow.DATE64:
+ val := arrow.Date64FromTime(v)
+ buffers[1] =
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&val))[:])
+ case arrow.TIMESTAMP:
+ val, err := arrow.TimestampFromTime(v,
dt.(*arrow.TimestampType).Unit)
+ if err != nil {
+ return nil, fmt.Errorf("could not convert
time.Time to arrow.Timestamp: %v", err)
+ }
+ buffers[1] =
memory.NewBufferBytes((*[8]byte)(unsafe.Pointer(&val))[:])
+ default:
+ return nil, fmt.Errorf("time.Time with type %s
unsupported", dt)
+ }
default:
- panic(fmt.Sprintf("unsupported type %T", val))
+ return nil, fmt.Errorf("unsupported type %T", val)
}
for _, b := range buffers {
if b != nil {
@@ -478,10 +525,10 @@ func arrFromVal(val any) arrow.Array {
}
data := array.NewData(dt, 1, buffers, nil, 0, 0)
defer data.Release()
- return array.MakeFromData(data)
+ return array.MakeFromData(data), nil
}
-func createBoundRecord(values []driver.NamedValue, schema *arrow.Schema)
arrow.Record {
+func createBoundRecord(values []driver.NamedValue, schema *arrow.Schema)
(arrow.Record, error) {
fields := make([]arrow.Field, len(values))
cols := make([]arrow.Array, len(values))
if schema == nil {
@@ -492,13 +539,16 @@ func createBoundRecord(values []driver.NamedValue, schema
*arrow.Schema) arrow.R
} else {
f.Name = v.Name
}
- arr := arrFromVal(v.Value)
+ arr, err := arrFromVal(v.Value, nil)
+ if err != nil {
+ return nil, err
+ }
defer arr.Release()
f.Type = arr.DataType()
cols[v.Ordinal-1] = arr
}
- return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1)
+ return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1),
nil
}
for _, v := range values {
@@ -514,17 +564,25 @@ func createBoundRecord(values []driver.NamedValue, schema
*arrow.Schema) arrow.R
f := &fields[idx]
f.Name = name
- arr := arrFromVal(v.Value)
+ arr, err := arrFromVal(v.Value, f.Type)
+ if err != nil {
+ return nil, err
+ }
defer arr.Release()
f.Type = arr.DataType()
cols[idx] = arr
}
- return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1)
+ return array.NewRecord(arrow.NewSchema(fields, nil), cols, 1), nil
}
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue)
(driver.Result, error) {
if len(args) > 0 {
- if err := s.stmt.Bind(ctx, createBoundRecord(args,
s.paramSchema)); err != nil {
+ rec, err := createBoundRecord(args, s.paramSchema)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := s.stmt.Bind(ctx, rec); err != nil {
return nil, err
}
}
@@ -539,7 +597,12 @@ func (s *stmt) ExecContext(ctx context.Context, args
[]driver.NamedValue) (drive
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue)
(driver.Rows, error) {
if len(args) > 0 {
- if err := s.stmt.Bind(ctx, createBoundRecord(args,
s.paramSchema)); err != nil {
+ rec, err := createBoundRecord(args, s.paramSchema)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := s.stmt.Bind(ctx, rec); err != nil {
return nil, err
}
}
diff --git a/go/adbc/sqldriver/driver_internals_test.go
b/go/adbc/sqldriver/driver_internals_test.go
index 539d55735..0c42e5ac2 100644
--- a/go/adbc/sqldriver/driver_internals_test.go
+++ b/go/adbc/sqldriver/driver_internals_test.go
@@ -146,6 +146,11 @@ var (
Name: "int",
Type: arrow.PrimitiveTypes.Int32,
}
+
+ tstampSec, _ = arrow.TimestampFromTime(testTime, arrow.Second)
+ tstampMilli, _ = arrow.TimestampFromTime(testTime, arrow.Millisecond)
+ tstampMicro, _ = arrow.TimestampFromTime(testTime, arrow.Microsecond)
+ tstampNano, _ = arrow.TimestampFromTime(testTime, arrow.Nanosecond)
)
func TestNextRowTypes(t *testing.T) {
@@ -328,6 +333,7 @@ func TestNextRowTypes(t *testing.T) {
func TestArrFromVal(t *testing.T) {
tests := []struct {
value any
+ inputDataType arrow.DataType
expectedDataType arrow.DataType
expectedStringValue string
}{
@@ -401,15 +407,76 @@ func TestArrFromVal(t *testing.T) {
expectedDataType: arrow.BinaryTypes.Binary,
expectedStringValue:
base64.StdEncoding.EncodeToString([]byte("my-string")),
},
+ {
+ value: []byte("my-string"),
+ inputDataType: arrow.BinaryTypes.LargeBinary,
+ expectedDataType: arrow.BinaryTypes.LargeBinary,
+ expectedStringValue:
base64.StdEncoding.EncodeToString([]byte("my-string")),
+ },
{
value: "my-string",
expectedDataType: arrow.BinaryTypes.String,
expectedStringValue: "my-string",
},
+ {
+ value: "my-string",
+ inputDataType: arrow.BinaryTypes.LargeString,
+ expectedDataType: arrow.BinaryTypes.LargeString,
+ expectedStringValue: "my-string",
+ },
+ {
+ value: tstampSec,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Second},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Second},
+ expectedStringValue:
testTime.UTC().Truncate(time.Second).Format("2006-01-02 15:04:05Z"),
+ },
+ {
+ value: tstampMilli,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Millisecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Millisecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Millisecond).Format("2006-01-02 15:04:05.000Z"),
+ },
+ {
+ value: tstampMicro,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Microsecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Microsecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Microsecond).Format("2006-01-02 15:04:05.000000Z"),
+ },
+ {
+ value: tstampNano,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Nanosecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Nanosecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Nanosecond).Format("2006-01-02
15:04:05.000000000Z"),
+ },
+ {
+ value: testTime,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Second},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Second},
+ expectedStringValue:
testTime.UTC().Truncate(time.Second).Format("2006-01-02 15:04:05Z"),
+ },
+ {
+ value: testTime,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Millisecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Millisecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Millisecond).Format("2006-01-02 15:04:05.000Z"),
+ },
+ {
+ value: testTime,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Microsecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Microsecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Microsecond).Format("2006-01-02 15:04:05.000000Z"),
+ },
+ {
+ value: testTime,
+ inputDataType: &arrow.TimestampType{Unit:
arrow.Nanosecond},
+ expectedDataType: &arrow.TimestampType{Unit:
arrow.Nanosecond},
+ expectedStringValue:
testTime.UTC().Truncate(time.Nanosecond).Format("2006-01-02
15:04:05.000000000Z"),
+ },
}
for i, test := range tests {
t.Run(fmt.Sprintf("%d-%T", i, test.value), func(t *testing.T) {
- arr := arrFromVal(test.value)
+ arr, err := arrFromVal(test.value, test.inputDataType)
+ require.NoError(t, err)
assert.Equal(t, test.expectedDataType, arr.DataType())
require.Equal(t, 1, arr.Len())