This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new c3f3eca #95 Improve Compile-Time Safety for Literals
c3f3eca is described below
commit c3f3eca84a02c803ed8a1f9ebb7c142ff7706111
Author: Martin Grund <[email protected]>
AuthorDate: Mon Dec 30 17:01:20 2024 +0100
#95 Improve Compile-Time Safety for Literals
### What changes were proposed in this pull request?
This patch improves compile-time safety for literals by providing specific
interface types, which makes working with concrete literal types much easier.
It adds a number of helper methods for dealing with literals, for example
`functions.StringLit()` or `functions.Int8Lit()`. In addition, it fixes some
long-standing bugs around handling nil values.
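For illustration, a minimal sketch of the new API (helper and type names as
added in this patch; not exhaustive):

```go
package example

import (
	"github.com/apache/spark-connect-go/v35/spark/sql/column"
	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
	"github.com/apache/spark-connect-go/v35/spark/sql/types"
)

// typedLiterals sketches how literals are constructed after this patch.
func typedLiterals() []column.Column {
	return []column.Column{
		// Lit now takes a types.LiteralType instead of `any`, so an
		// unsupported value is rejected at compile time, not at plan time:
		functions.Col("id").Lt(functions.Lit(types.Int64(20))),
		// Type-specific helpers avoid the explicit conversion:
		functions.Col("id").Lt(functions.Int64Lit(20)),
		// Typed nil sentinels express a NULL literal of a concrete type:
		functions.Lit(types.Int64Nil), // NULL of type BIGINT
	}
}
```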
### Why are the changes needed?
Usability, Safety, Stability
### Does this PR introduce _any_ user-facing change?
Yes. New type-specific literal helper methods (e.g. `functions.Int64Lit`) were
added, `functions.Lit` now takes a `types.LiteralType` instead of `any`, and
`GroupedData.Pivot` now takes `[]types.LiteralType` instead of `[]any`; see the
sketch below.
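```go
// Sketch adapted from the updated integration test in this patch; assumes
// ctx and a DataFrame df with "year", "course", "earnings" columns in scope.
gd := df.GroupBy(functions.Col("year"))
// Pivot values are now typed literals ([]types.LiteralType) instead of []any.
gd, err := gd.Pivot(ctx, "course",
	[]types.LiteralType{types.String("Java"), types.String("dotNET")})
if err != nil {
	return err
}
df, err = gd.Sum(ctx, "earnings")
```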
### How was this patch tested?
Added unit tests for the new literal types and updated the existing
integration tests.
Closes #96 from grundprinzip/safety_95.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
cmd/spark-connect-example-spark-session/main.go | 4 +-
dev/gen.py | 6 +-
internal/tests/integration/dataframe_test.go | 10 +-
internal/tests/integration/functions_test.go | 2 +-
internal/tests/integration/sql_test.go | 2 +-
spark/sql/column/expressions.go | 45 +--
spark/sql/dataframe_test.go | 2 +-
spark/sql/functions/buiitins.go | 43 ++-
spark/sql/functions/generated.go | 114 ++++---
spark/sql/group.go | 8 +-
spark/sql/sparksession.go | 4 +
spark/sql/types/arrow.go | 90 +++++-
spark/sql/types/builtin.go | 404 ++++++++++++++++++++++++
spark/sql/types/builtin_test.go | 83 +++++
14 files changed, 689 insertions(+), 128 deletions(-)
diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go
index 516d87b..8c6aa90 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -22,6 +22,8 @@ import (
"fmt"
"log"
+ "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
"github.com/apache/spark-connect-go/v35/spark/sql/functions"
"github.com/apache/spark-connect-go/v35/spark/sql"
@@ -68,7 +70,7 @@ func main() {
}
df, _ = spark.Sql(ctx, "select * from range(100)")
- df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(20)))
+ df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(types.Int64(20))))
if err != nil {
log.Fatalf("Failed: %s", err)
}
diff --git a/dev/gen.py b/dev/gen.py
index 3b307e3..63ecd43 100644
--- a/dev/gen.py
+++ b/dev/gen.py
@@ -122,15 +122,15 @@ for fun in F.__dict__:
args.append(p)
elif param.annotation == str or typing.get_args(param.annotation) == (str, types.NoneType):
res_params.append(f"{p} string")
- conversions.append(f"lit_{p} := Lit({p})")
+ conversions.append(f"lit_{p} := StringLit({p})")
args.append(f"lit_{p}")
elif param.annotation == int or typing.get_args(param.annotation) == (int, types.NoneType):
res_params.append(f"{p} int64")
- conversions.append(f"lit_{p} := Lit({p})")
+ conversions.append(f"lit_{p} := Int64Lit({p})")
args.append(f"lit_{p}")
elif param.annotation == float or typing.get_args(param.annotation) == (float, types.NoneType):
res_params.append(f"{p} float64")
- conversions.append(f"lit_{p} := Lit({p})")
+ conversions.append(f"lit_{p} := Float64Lit({p})")
args.append(f"lit_{p}")
else:
valid = False
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index f2265f5..e00b375 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -39,7 +39,7 @@ func TestDataFrame_Select(t *testing.T) {
assert.NoError(t, err)
df, err := spark.Sql(ctx, "select * from range(100)")
assert.NoError(t, err)
- df, err = df.Select(ctx, functions.Lit("1"), functions.Lit("2"))
+ df, err = df.Select(ctx, functions.StringLit("1"),
functions.StringLit("2"))
assert.NoError(t, err)
res, err := df.Collect(ctx)
@@ -210,7 +210,7 @@ func TestDataFrame_WithColumn(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(10)")
assert.NoError(t, err)
- df, err = df.WithColumn(ctx, "newCol", functions.Lit(1))
+ df, err = df.WithColumn(ctx, "newCol", functions.IntLit(1))
assert.NoError(t, err)
res, err := df.Collect(ctx)
assert.NoError(t, err)
@@ -226,8 +226,8 @@ func TestDataFrame_WithColumns(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(10)")
assert.NoError(t, err)
- df, err = df.WithColumns(ctx, column.WithAlias("newCol1",
functions.Lit(1)),
- column.WithAlias("newCol2", functions.Lit(2)))
+ df, err = df.WithColumns(ctx, column.WithAlias("newCol1",
functions.IntLit(1)),
+ column.WithAlias("newCol2", functions.IntLit(2)))
assert.NoError(t, err)
res, err := df.Collect(ctx)
assert.NoError(t, err)
@@ -592,7 +592,7 @@ func TestDataFrame_Pivot(t *testing.T) {
df, err := spark.CreateDataFrame(ctx, data, schema)
assert.NoError(t, err)
gd := df.GroupBy(functions.Col("year"))
- gd, err = gd.Pivot(ctx, "course", []any{"Java", "dotNET"})
+ gd, err = gd.Pivot(ctx, "course",
[]types.LiteralType{types.String("Java"), types.String("dotNET")})
assert.NoError(t, err)
df, err = gd.Sum(ctx, "earnings")
assert.NoError(t, err)
diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go
index 1f89923..33f3352 100644
--- a/internal/tests/integration/functions_test.go
+++ b/internal/tests/integration/functions_test.go
@@ -33,7 +33,7 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
}
df, _ := spark.Sql(ctx, "select '[2]' as a from range(10)")
- df, _ = df.Filter(ctx, functions.JsonArrayLength(functions.Col("a")).Eq(functions.Lit(1)))
+ df, _ = df.Filter(ctx, functions.JsonArrayLength(functions.Col("a")).Eq(functions.IntLit(1)))
res, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Equal(t, 10, len(res))
diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go
index ba2dccf..28c32d4 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -46,7 +46,7 @@ func TestIntegration_RunSQLCommand(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 100, len(res))
- df, err = df.Filter(ctx, column.OfDF(df, "id").Lt(functions.Lit(10)))
+ df, err = df.Filter(ctx, column.OfDF(df, "id").Lt(functions.IntLit(10)))
assert.NoError(t, err)
res, err = df.Collect(ctx)
assert.NoErrorf(t, err, "Must be able to collect the rows.")
diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go
index fe50d16..a2ffa24 100644
--- a/spark/sql/column/expressions.go
+++ b/spark/sql/column/expressions.go
@@ -20,6 +20,8 @@ import (
"fmt"
"strings"
+ "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
proto "github.com/apache/spark-connect-go/v35/internal/generated"
@@ -308,52 +310,17 @@ func (s *sqlExression) ToProto(context.Context) (*proto.Expression, error) {
}
type literalExpression struct {
- value any
+ value types.LiteralType
}
func (l *literalExpression) DebugString() string {
return fmt.Sprintf("%v", l.value)
}
-func (l *literalExpression) ToProto(context.Context) (*proto.Expression, error) {
- expr := newProtoExpression()
- expr.ExprType = &proto.Expression_Literal_{
- Literal: &proto.Expression_Literal{},
- }
- switch v := l.value.(type) {
- case int8:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Byte{Byte: int32(v)}
- case int16:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
- case int32:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: v}
- case int64:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: v}
- case uint8:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
- case uint16:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: int32(v)}
- case uint32:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
- case float32:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Float{Float: v}
- case float64:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Double{Double: v}
- case string:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_String_{String_: v}
- case bool:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Boolean{Boolean: v}
- case []byte:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Binary{Binary: v}
- case int:
- expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
- default:
- return nil, sparkerrors.WithType(sparkerrors.InvalidPlanError,
- fmt.Errorf("unsupported literal type %T", v))
- }
- return expr, nil
+func (l *literalExpression) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return l.value.ToProto(ctx)
}
-func NewLiteral(value any) expression {
+func NewLiteral(value types.LiteralType) expression {
return &literalExpression{value: value}
}
diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go
index 39145b0..afa6111 100644
--- a/spark/sql/dataframe_test.go
+++ b/spark/sql/dataframe_test.go
@@ -41,7 +41,7 @@ func TestDataFrameImpl_GroupBy(t *testing.T) {
assert.Equal(t, gd.groupType, "groupby")
- df, err := gd.Agg(ctx, functions.Count(functions.Lit(1)))
+ df, err := gd.Agg(ctx, functions.Count(functions.Int64Lit(1)))
assert.Nil(t, err)
impl := df.(*dataFrameImpl)
assert.NotNil(t, impl)
diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go
index a2a7bf8..6df737c 100644
--- a/spark/sql/functions/buiitins.go
+++ b/spark/sql/functions/buiitins.go
@@ -17,6 +17,7 @@ package functions
import (
"github.com/apache/spark-connect-go/v35/spark/sql/column"
+ "github.com/apache/spark-connect-go/v35/spark/sql/types"
)
func Expr(expr string) column.Column {
@@ -27,6 +28,46 @@ func Col(name string) column.Column {
return column.NewColumn(column.NewColumnReference(name))
}
-func Lit(value any) column.Column {
+func Lit(value types.LiteralType) column.Column {
return column.NewColumn(column.NewLiteral(value))
}
+
+func Int8Lit(value int8) column.Column {
+ return Lit(types.Int8(value))
+}
+
+func Int16Lit(value int16) column.Column {
+ return Lit(types.Int16(value))
+}
+
+func Int32Lit(value int32) column.Column {
+ return Lit(types.Int32(value))
+}
+
+func Int64Lit(value int64) column.Column {
+ return Lit(types.Int64(value))
+}
+
+func Float32Lit(value float32) column.Column {
+ return Lit(types.Float32(value))
+}
+
+func Float64Lit(value float64) column.Column {
+ return Lit(types.Float64(value))
+}
+
+func StringLit(value string) column.Column {
+ return Lit(types.String(value))
+}
+
+func BoolLit(value bool) column.Column {
+ return Lit(types.Boolean(value))
+}
+
+func BinaryLit(value []byte) column.Column {
+ return Lit(types.Binary(value))
+}
+
+func IntLit(value int) column.Column {
+ return Lit(types.Int(value))
+}
diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go
index cb9421e..6ef668e 100644
--- a/spark/sql/functions/generated.go
+++ b/spark/sql/functions/generated.go
@@ -15,9 +15,7 @@
package functions
-import (
- "github.com/apache/spark-connect-go/v35/spark/sql/column"
-)
+import "github.com/apache/spark-connect-go/v35/spark/sql/column"
// BitwiseNOT - Computes bitwise not.
//
@@ -137,7 +135,7 @@ func Nanvl(col1 column.Column, col2 column.Column) column.Column {
//
// Rand is the Golang equivalent of rand: (seed: Optional[int] = None) -> pyspark.sql.connect.column.Column
func Rand(seed int64) column.Column {
- lit_seed := Lit(seed)
+ lit_seed := Int64Lit(seed)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("rand", lit_seed))
}
@@ -146,7 +144,7 @@ func Rand(seed int64) column.Column {
//
// Randn is the Golang equivalent of randn: (seed: Optional[int] = None) -> pyspark.sql.connect.column.Column
func Randn(seed int64) column.Column {
- lit_seed := Lit(seed)
+ lit_seed := Int64Lit(seed)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("randn", lit_seed))
}
@@ -273,7 +271,7 @@ func Bin(col column.Column) column.Column {
//
// Bround is the Golang equivalent of bround: (col: 'ColumnOrName', scale: int = 0) -> pyspark.sql.connect.column.Column
func Bround(col column.Column, scale int64) column.Column {
- lit_scale := Lit(scale)
+ lit_scale := Int64Lit(scale)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("bround", col, lit_scale))
}
@@ -302,8 +300,8 @@ func Ceiling(col column.Column) column.Column {
//
// Conv is the Golang equivalent of conv: (col: 'ColumnOrName', fromBase: int, toBase: int) -> pyspark.sql.connect.column.Column
func Conv(col column.Column, fromBase int64, toBase int64) column.Column {
- lit_fromBase := Lit(fromBase)
- lit_toBase := Lit(toBase)
+ lit_fromBase := Int64Lit(fromBase)
+ lit_toBase := Int64Lit(toBase)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("conv", col, lit_fromBase, lit_toBase))
}
@@ -503,7 +501,7 @@ func Rint(col column.Column) column.Column {
//
// Round is the Golang equivalent of round: (col: 'ColumnOrName', scale: int = 0) -> pyspark.sql.connect.column.Column
func Round(col column.Column, scale int64) column.Column {
- lit_scale := Lit(scale)
+ lit_scale := Int64Lit(scale)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("round", col, lit_scale))
}
@@ -518,7 +516,7 @@ func Sec(col column.Column) column.Column {
//
// ShiftLeft is the Golang equivalent of shiftLeft: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func ShiftLeft(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftLeft", col, lit_numBits))
}
@@ -526,7 +524,7 @@ func ShiftLeft(col column.Column, numBits int64) column.Column {
//
// Shiftleft is the Golang equivalent of shiftleft: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func Shiftleft(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftleft", col, lit_numBits))
}
@@ -534,7 +532,7 @@ func Shiftleft(col column.Column, numBits int64) column.Column {
//
// ShiftRight is the Golang equivalent of shiftRight: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func ShiftRight(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftRight", col, lit_numBits))
}
@@ -542,7 +540,7 @@ func ShiftRight(col column.Column, numBits int64) column.Column {
//
// Shiftright is the Golang equivalent of shiftright: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func Shiftright(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftright", col, lit_numBits))
}
@@ -550,7 +548,7 @@ func Shiftright(col column.Column, numBits int64) column.Column {
//
// ShiftRightUnsigned is the Golang equivalent of shiftRightUnsigned: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func ShiftRightUnsigned(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftRightUnsigned", col, lit_numBits))
}
@@ -558,7 +556,7 @@ func ShiftRightUnsigned(col column.Column, numBits int64) column.Column {
//
// Shiftrightunsigned is the Golang equivalent of shiftrightunsigned: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func Shiftrightunsigned(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftrightunsigned", col, lit_numBits))
}
@@ -684,7 +682,7 @@ func Unhex(col column.Column) column.Column {
//
// ApproxCountDistinct is the Golang equivalent of approx_count_distinct: (col: 'ColumnOrName', rsd: Optional[float] = None) -> pyspark.sql.connect.column.Column
func ApproxCountDistinct(col column.Column, rsd float64) column.Column {
- lit_rsd := Lit(rsd)
+ lit_rsd := Float64Lit(rsd)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("approx_count_distinct", col, lit_rsd))
}
@@ -1122,7 +1120,7 @@ func HistogramNumeric(col column.Column, nBins column.Column) column.Column {
//
// Ntile is the Golang equivalent of ntile: (n: int) -> pyspark.sql.connect.column.Column
func Ntile(n int64) column.Column {
- lit_n := Lit(n)
+ lit_n := Int64Lit(n)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("ntile", lit_n))
}
@@ -1207,8 +1205,8 @@ func ArrayCompact(col column.Column) column.Column {
//
// ArrayJoin is the Golang equivalent of array_join: (col: 'ColumnOrName', delimiter: str, null_replacement: Optional[str] = None) -> pyspark.sql.connect.column.Column
func ArrayJoin(col column.Column, delimiter string, null_replacement string) column.Column {
- lit_delimiter := Lit(delimiter)
- lit_null_replacement := Lit(null_replacement)
+ lit_delimiter := StringLit(delimiter)
+ lit_null_replacement := StringLit(null_replacement)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("array_join", col, lit_delimiter, lit_null_replacement))
}
@@ -1366,7 +1364,7 @@ func Get(col column.Column, index column.Column) column.Column {
//
// GetJsonObject is the Golang equivalent of get_json_object: (col: 'ColumnOrName', path: str) -> pyspark.sql.connect.column.Column
func GetJsonObject(col column.Column, path string) column.Column {
- lit_path := Lit(path)
+ lit_path := StringLit(path)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("get_json_object", col, lit_path))
}
@@ -1406,7 +1404,7 @@ func InlineOuter(col column.Column) column.Column {
//
// JsonTuple is the Golang equivalent of json_tuple: (col: 'ColumnOrName', *fields: str) -> pyspark.sql.connect.column.Column
func JsonTuple(col column.Column, fields string) column.Column {
- lit_fields := Lit(fields)
+ lit_fields := StringLit(fields)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("json_tuple", col, lit_fields))
}
@@ -1619,7 +1617,7 @@ func Trim(col column.Column) column.Column {
//
// ConcatWs is the Golang equivalent of concat_ws: (sep: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
func ConcatWs(sep string, cols ...column.Column) column.Column {
- lit_sep := Lit(sep)
+ lit_sep := StringLit(sep)
vals := make([]column.Column, 0)
vals = append(vals, lit_sep)
vals = append(vals, cols...)
@@ -1631,7 +1629,7 @@ func ConcatWs(sep string, cols ...column.Column) column.Column {
//
// Decode is the Golang equivalent of decode: (col: 'ColumnOrName', charset: str) -> pyspark.sql.connect.column.Column
func Decode(col column.Column, charset string) column.Column {
- lit_charset := Lit(charset)
+ lit_charset := StringLit(charset)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("decode", col, lit_charset))
}
@@ -1640,7 +1638,7 @@ func Decode(col column.Column, charset string) column.Column {
//
// Encode is the Golang equivalent of encode: (col: 'ColumnOrName', charset: str) -> pyspark.sql.connect.column.Column
func Encode(col column.Column, charset string) column.Column {
- lit_charset := Lit(charset)
+ lit_charset := StringLit(charset)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("encode", col, lit_charset))
}
@@ -1649,7 +1647,7 @@ func Encode(col column.Column, charset string) column.Column {
//
// FormatNumber is the Golang equivalent of format_number: (col: 'ColumnOrName', d: int) -> pyspark.sql.connect.column.Column
func FormatNumber(col column.Column, d int64) column.Column {
- lit_d := Lit(d)
+ lit_d := Int64Lit(d)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("format_number", col, lit_d))
}
@@ -1657,7 +1655,7 @@ func FormatNumber(col column.Column, d int64) column.Column {
//
// FormatString is the Golang equivalent of format_string: (format: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
func FormatString(format string, cols ...column.Column) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
vals := make([]column.Column, 0)
vals = append(vals, lit_format)
vals = append(vals, cols...)
@@ -1669,7 +1667,7 @@ func FormatString(format string, cols ...column.Column) column.Column {
//
// Instr is the Golang equivalent of instr: (str: 'ColumnOrName', substr: str) -> pyspark.sql.connect.column.Column
func Instr(str column.Column, substr string) column.Column {
- lit_substr := Lit(substr)
+ lit_substr := StringLit(substr)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("instr", str, lit_substr))
}
@@ -1695,8 +1693,8 @@ func Sentences(string column.Column, language column.Column, country column.Colu
//
// Substring is the Golang equivalent of substring: (str: 'ColumnOrName', pos: int, len: int) -> pyspark.sql.connect.column.Column
func Substring(str column.Column, pos int64, len int64) column.Column {
- lit_pos := Lit(pos)
- lit_len := Lit(len)
+ lit_pos := Int64Lit(pos)
+ lit_len := Int64Lit(len)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("substring", str, lit_pos, lit_len))
}
@@ -1707,8 +1705,8 @@ func Substring(str column.Column, pos int64, len int64) column.Column {
//
// SubstringIndex is the Golang equivalent of substring_index: (str: 'ColumnOrName', delim: str, count: int) -> pyspark.sql.connect.column.Column
func SubstringIndex(str column.Column, delim string, count int64) column.Column {
- lit_delim := Lit(delim)
- lit_count := Lit(count)
+ lit_delim := StringLit(delim)
+ lit_count := Int64Lit(count)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("substring_index", str, lit_delim, lit_count))
}
@@ -1716,7 +1714,7 @@ func SubstringIndex(str column.Column, delim string, count int64) column.Column
//
// Levenshtein is the Golang equivalent of levenshtein: (left: 'ColumnOrName', right: 'ColumnOrName', threshold: Optional[int] = None) -> pyspark.sql.connect.column.Column
func Levenshtein(left column.Column, right column.Column, threshold int64) column.Column {
- lit_threshold := Lit(threshold)
+ lit_threshold := Int64Lit(threshold)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("levenshtein", left, right, lit_threshold))
}
@@ -1724,8 +1722,8 @@ func Levenshtein(left column.Column, right column.Column, threshold int64) colum
//
// Locate is the Golang equivalent of locate: (substr: str, str: 'ColumnOrName', pos: int = 1) -> pyspark.sql.connect.column.Column
func Locate(substr string, str column.Column, pos int64) column.Column {
- lit_substr := Lit(substr)
- lit_pos := Lit(pos)
+ lit_substr := StringLit(substr)
+ lit_pos := Int64Lit(pos)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("locate", lit_substr, str, lit_pos))
}
@@ -1733,8 +1731,8 @@ func Locate(substr string, str column.Column, pos int64) column.Column {
//
// Lpad is the Golang equivalent of lpad: (col: 'ColumnOrName', len: int, pad: str) -> pyspark.sql.connect.column.Column
func Lpad(col column.Column, len int64, pad string) column.Column {
- lit_len := Lit(len)
- lit_pad := Lit(pad)
+ lit_len := Int64Lit(len)
+ lit_pad := StringLit(pad)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("lpad", col, lit_len, lit_pad))
}
@@ -1742,8 +1740,8 @@ func Lpad(col column.Column, len int64, pad string) column.Column {
//
// Rpad is the Golang equivalent of rpad: (col: 'ColumnOrName', len: int, pad: str) -> pyspark.sql.connect.column.Column
func Rpad(col column.Column, len int64, pad string) column.Column {
- lit_len := Lit(len)
- lit_pad := Lit(pad)
+ lit_len := Int64Lit(len)
+ lit_pad := StringLit(pad)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("rpad", col, lit_len, lit_pad))
}
@@ -1751,7 +1749,7 @@ func Rpad(col column.Column, len int64, pad string) column.Column {
//
// Repeat is the Golang equivalent of repeat: (col: 'ColumnOrName', n: int) -> pyspark.sql.connect.column.Column
func Repeat(col column.Column, n int64) column.Column {
- lit_n := Lit(n)
+ lit_n := Int64Lit(n)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("repeat", col, lit_n))
}
@@ -1759,8 +1757,8 @@ func Repeat(col column.Column, n int64) column.Column {
//
// Split is the Golang equivalent of split: (str: 'ColumnOrName', pattern: str, limit: int = -1) -> pyspark.sql.connect.column.Column
func Split(str column.Column, pattern string, limit int64) column.Column {
- lit_pattern := Lit(pattern)
- lit_limit := Lit(limit)
+ lit_pattern := StringLit(pattern)
+ lit_limit := Int64Lit(limit)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("split", str, lit_pattern, lit_limit))
}
@@ -1798,8 +1796,8 @@ func RegexpCount(str column.Column, regexp column.Column) column.Column {
//
// RegexpExtract is the Golang equivalent of regexp_extract: (str: 'ColumnOrName', pattern: str, idx: int) -> pyspark.sql.connect.column.Column
func RegexpExtract(str column.Column, pattern string, idx int64) column.Column {
- lit_pattern := Lit(pattern)
- lit_idx := Lit(idx)
+ lit_pattern := StringLit(pattern)
+ lit_idx := Int64Lit(idx)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("regexp_extract", str, lit_pattern, lit_idx))
}
@@ -1861,8 +1859,8 @@ func BitLength(col column.Column) column.Column {
//
// Translate is the Golang equivalent of translate: (srcCol: 'ColumnOrName', matching: str, replace: str) -> pyspark.sql.connect.column.Column
func Translate(srcCol column.Column, matching string, replace string) column.Column {
- lit_matching := Lit(matching)
- lit_replace := Lit(replace)
+ lit_matching := StringLit(matching)
+ lit_replace := StringLit(replace)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("translate", srcCol, lit_matching, lit_replace))
}
@@ -2214,7 +2212,7 @@ func Localtimestamp() column.Column {
//
// DateFormat is the Golang equivalent of date_format: (date: 'ColumnOrName', format: str) -> pyspark.sql.connect.column.Column
func DateFormat(date column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("date_format", date, lit_format))
}
@@ -2391,7 +2389,7 @@ func AddMonths(start column.Column, months column.Column) column.Column {
//
// ToDate is the Golang equivalent of to_date: (col: 'ColumnOrName', format: Optional[str] = None) -> pyspark.sql.connect.column.Column
func ToDate(col column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("to_date", col, lit_format))
}
@@ -2432,7 +2430,7 @@ func UnixSeconds(col column.Column) column.Column {
//
// ToTimestamp is the Golang equivalent of to_timestamp: (col: 'ColumnOrName', format: Optional[str] = None) -> pyspark.sql.connect.column.Column
func ToTimestamp(col column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("to_timestamp", col, lit_format))
}
@@ -2518,7 +2516,7 @@ func XpathString(xml column.Column, path column.Column) column.Column {
//
// Trunc is the Golang equivalent of trunc: (date: 'ColumnOrName', format: str) -> pyspark.sql.connect.column.Column
func Trunc(date column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("trunc", date, lit_format))
}
@@ -2526,7 +2524,7 @@ func Trunc(date column.Column, format string) column.Column {
//
// DateTrunc is the Golang equivalent of date_trunc: (format: str, timestamp: 'ColumnOrName') -> pyspark.sql.connect.column.Column
func DateTrunc(format string, timestamp column.Column) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("date_trunc", lit_format, timestamp))
}
@@ -2535,7 +2533,7 @@ func DateTrunc(format string, timestamp column.Column) column.Column {
//
// NextDay is the Golang equivalent of next_day: (date: 'ColumnOrName', dayOfWeek: str) -> pyspark.sql.connect.column.Column
func NextDay(date column.Column, dayOfWeek string) column.Column {
- lit_dayOfWeek := Lit(dayOfWeek)
+ lit_dayOfWeek := StringLit(dayOfWeek)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("next_day", date, lit_dayOfWeek))
}
@@ -2552,7 +2550,7 @@ func LastDay(date column.Column) column.Column {
//
// FromUnixtime is the Golang equivalent of from_unixtime: (timestamp: 'ColumnOrName', format: str = 'yyyy-MM-dd HH:mm:ss') -> pyspark.sql.connect.column.Column
func FromUnixtime(timestamp column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("from_unixtime", timestamp, lit_format))
}
@@ -2564,7 +2562,7 @@ func FromUnixtime(timestamp column.Column, format string) column.Column {
//
// UnixTimestamp is the Golang equivalent of unix_timestamp: (timestamp: Optional[ForwardRef('ColumnOrName')] = None, format: str = 'yyyy-MM-dd HH:mm:ss') -> pyspark.sql.connect.column.Column
func UnixTimestamp(timestamp column.Column, format string) column.Column {
- lit_format := Lit(format)
+ lit_format := StringLit(format)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("unix_timestamp", timestamp, lit_format))
}
@@ -2646,9 +2644,9 @@ func TimestampMicros(col column.Column) column.Column {
//
// Window is the Golang equivalent of window: (timeColumn: 'ColumnOrName', windowDuration: str, slideDuration: Optional[str] = None, startTime: Optional[str] = None) -> pyspark.sql.connect.column.Column
func Window(timeColumn column.Column, windowDuration string, slideDuration string, startTime string) column.Column {
- lit_windowDuration := Lit(windowDuration)
- lit_slideDuration := Lit(slideDuration)
- lit_startTime := Lit(startTime)
+ lit_windowDuration := StringLit(windowDuration)
+ lit_slideDuration := StringLit(slideDuration)
+ lit_startTime := StringLit(startTime)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("window", timeColumn, lit_windowDuration, lit_slideDuration, lit_startTime))
}
@@ -2873,7 +2871,7 @@ func Sha1(col column.Column) column.Column {
//
// Sha2 is the Golang equivalent of sha2: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
func Sha2(col column.Column, numBits int64) column.Column {
- lit_numBits := Lit(numBits)
+ lit_numBits := Int64Lit(numBits)
return column.NewColumn(column.NewUnresolvedFunctionWithColumns("sha2", col, lit_numBits))
}
@@ -3085,7 +3083,7 @@ func BitmapOrAgg(col column.Column) column.Column {
//
// CallFunction is the Golang equivalent of call_function: (funcName: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
func CallFunction(funcName string, cols ...column.Column) column.Column {
- lit_funcName := Lit(funcName)
+ lit_funcName := StringLit(funcName)
vals := make([]column.Column, 0)
vals = append(vals, lit_funcName)
vals = append(vals, cols...)
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 3f67e18..975dc50 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -19,6 +19,8 @@ package sql
import (
"context"
+ "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
proto "github.com/apache/spark-connect-go/v35/internal/generated"
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
"github.com/apache/spark-connect-go/v35/spark/sql/column"
@@ -29,7 +31,7 @@ type GroupedData struct {
df *dataFrameImpl
groupType string
groupingCols []column.Convertible
- pivotValues []any
+ pivotValues []types.LiteralType
pivotCol column.Convertible
}
@@ -171,7 +173,7 @@ func (gd *GroupedData) Sum(ctx context.Context, cols ...string) (DataFrame, erro
// Count Computes the count value for each group.
func (gd *GroupedData) Count(ctx context.Context) (DataFrame, error) {
- return gd.Agg(ctx, functions.Count(functions.Lit(1)).Alias("count"))
+ return gd.Agg(ctx, functions.Count(functions.Lit(types.Int64(1))).Alias("count"))
}
// Mean Computes the average value for each numeric column for each group.
@@ -179,7 +181,7 @@ func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, err
return gd.Avg(ctx, cols...)
}
-func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []any) (*GroupedData, error) {
+func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []types.LiteralType) (*GroupedData, error) {
if gd.groupType != "groupby" {
if gd.groupType == "pivot" {
return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot cannot be applied on pivot")
diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go
index 58436be..4bf94cc 100644
--- a/spark/sql/sparksession.go
+++ b/spark/sql/sparksession.go
@@ -230,6 +230,10 @@ func (s *sparkSessionImpl) CreateDataFrame(ctx context.Context, data [][]any, sc
// Iterate over all fields and add the values:
for _, row := range data {
for i, field := range schema.Fields {
+ if row[i] == nil {
+ rb.Field(i).AppendNull()
+ continue
+ }
switch field.DataType {
case types.BOOLEAN:
rb.Field(i).(*array.BooleanBuilder).Append(row[i].(bool))
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 360bed9..2921881 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -75,77 +75,137 @@ func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
case arrow.BOOL:
data := array.NewBooleanData(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.INT8:
data := array.NewInt8Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.INT16:
data := array.NewInt16Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.INT32:
data := array.NewInt32Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.INT64:
data := array.NewInt64Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.FLOAT16:
data := array.NewFloat16Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.FLOAT32:
data := array.NewFloat32Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.FLOAT64:
data := array.NewFloat64Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.DECIMAL | arrow.DECIMAL128:
data := array.NewDecimal128Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.DECIMAL256:
data := array.NewDecimal256Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.STRING:
data := array.NewStringData(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.BINARY:
data := array.NewBinaryData(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.TIMESTAMP:
data := array.NewTimestampData(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.DATE64:
data := array.NewDate64Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.DATE32:
data := array.NewDate32Data(data)
for i := 0; i < data.Len(); i++ {
- buf = append(buf, data.Value(i))
+ if data.IsNull(i) {
+ buf = append(buf, nil)
+ } else {
+ buf = append(buf, data.Value(i))
+ }
}
case arrow.LIST:
data := array.NewListData(data)
diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go
new file mode 100644
index 0000000..15f50d1
--- /dev/null
+++ b/spark/sql/types/builtin.go
@@ -0,0 +1,404 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+ "context"
+
+ proto "github.com/apache/spark-connect-go/v35/internal/generated"
+)
+
+type LiteralType interface {
+ ToProto(ctx context.Context) (*proto.Expression, error)
+}
+
+type NumericLiteral interface {
+ LiteralType
+ // marker method for compile time safety.
+ isNumericLiteral()
+}
+
+type PrimitiveTypeLiteral interface {
+ LiteralType
+ isPrimitiveTypeLiteral()
+}
+
+type Int8 int8
+
+func (t Int8) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Byte{Byte: int32(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Int8) isNumericLiteral() {}
+
+func (t Int8) isPrimitiveTypeLiteral() {}
+
+type Int16 int16
+
+func (t Int16) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Short{Short: int32(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Int16) isNumericLiteral() {}
+
+func (t Int16) isPrimitiveTypeLiteral() {}
+
+type Int32 int32
+
+func (t Int32) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Integer{Integer: int32(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Int32) isNumericLiteral() {}
+
+func (t Int32) isPrimitiveTypeLiteral() {}
+
+type Int64 int64
+
+func (t Int64) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Long{Long: int64(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Int64) isNumericLiteral() {}
+
+func (t Int64) isPrimitiveTypeLiteral() {}
+
+type Int int
+
+func (t Int) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return Int64(t).ToProto(ctx)
+}
+
+func (t Int) isNumericLiteral() {}
+
+func (t Int) isPrimitiveTypeLiteral() {}
+
+type Float32 float32
+
+func (t Float32) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Float{Float: float32(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Float32) isNumericLiteral() {}
+
+func (t Float32) isPrimitiveTypeLiteral() {}
+
+type Float64 float64
+
+func (t Float64) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Double{Double: float64(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Float64) isNumericLiteral() {}
+
+func (t Float64) isPrimitiveTypeLiteral() {}
+
+type String string
+
+func (t String) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_String_{String_: string(t)},
+ },
+ },
+ }, nil
+}
+
+func (t String) isPrimitiveTypeLiteral() {}
+
+type Boolean bool
+
+func (t Boolean) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Boolean{Boolean: bool(t)},
+ },
+ },
+ }, nil
+}
+
+func (t Boolean) isPrimitiveTypeLiteral() {}
+
+type Binary []byte
+
+func (t Binary) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Binary{Binary: t},
+ },
+ },
+ }, nil
+}
+
+type Int8NilType struct{}
+
+var Int8Nil = Int8NilType{}
+
+func (t Int8NilType) isNumericLiteral() {}
+
+func (t Int8NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int8NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Byte_{
+ Byte: &proto.DataType_Byte{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type Int16NilType struct{}
+
+var Int16Nil = Int16NilType{}
+
+func (t Int16NilType) isNumericLiteral() {}
+
+func (t Int16NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int16NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Short_{
+ Short: &proto.DataType_Short{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type Int32NilType struct{}
+
+var Int32Nil = Int32NilType{}
+
+func (t Int32NilType) isNumericLiteral() {}
+
+func (t Int32NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int32NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Integer_{
+ Integer: &proto.DataType_Integer{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type Int64NilType struct{}
+
+var Int64Nil = Int64NilType{}
+
+func (t Int64NilType) isNumericLiteral() {}
+
+func (t Int64NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int64NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Long_{
+ Long: &proto.DataType_Long{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type IntNilType struct{}
+
+var IntNil = IntNilType{}
+
+func (t IntNilType) isNumericLiteral() {}
+
+func (t IntNilType) isPrimitiveTypeLiteral() {}
+
+func (t IntNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return Int64NilType{}.ToProto(ctx)
+}
+
+type Float32NilType struct{}
+
+var Float32Nil = Float32NilType{}
+
+func (t Float32NilType) isNumericLiteral() {}
+
+func (t Float32NilType) isPrimitiveTypeLiteral() {}
+
+func (t Float32NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Float_{
+ Float: &proto.DataType_Float{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type Float64NilType struct{}
+
+var Float64Nil = Float64NilType{}
+
+func (t Float64NilType) isNumericLiteral() {}
+
+func (t Float64NilType) isPrimitiveTypeLiteral() {}
+
+func (t Float64NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Double_{
+ Double: &proto.DataType_Double{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type StringNilType struct{}
+
+var StringNil = StringNilType{}
+
+func (t StringNilType) isPrimitiveTypeLiteral() {}
+
+func (t StringNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_String_{
+ String_: &proto.DataType_String{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type BooleanNilType struct{}
+
+var BooleanNil = BooleanNilType{}
+
+func (t BooleanNilType) isPrimitiveTypeLiteral() {}
+
+func (t BooleanNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Boolean_{
+ Boolean: &proto.DataType_Boolean{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
+
+type BinaryNilType struct{}
+
+var BinaryNil = BinaryNilType{}
+
+func (t BinaryNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+ return &proto.Expression{
+ ExprType: &proto.Expression_Literal_{
+ Literal: &proto.Expression_Literal{
+ LiteralType: &proto.Expression_Literal_Null{
+ Null: &proto.DataType{
+ Kind: &proto.DataType_Binary_{
+ Binary:
&proto.DataType_Binary{},
+ },
+ },
+ },
+ },
+ },
+ }, nil
+}
diff --git a/spark/sql/types/builtin_test.go b/spark/sql/types/builtin_test.go
new file mode 100644
index 0000000..0a326ac
--- /dev/null
+++ b/spark/sql/types/builtin_test.go
@@ -0,0 +1,83 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBuiltinTypes(t *testing.T) {
+ p, err := Int8(1).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetByte(), int32(1))
+
+ p, err = Int16(1).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetShort(), int32(1))
+
+ p, err = Int32(1).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetInteger(), int32(1))
+
+ p, err = Int64(1).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetLong(), int64(1))
+
+ p, err = Float32(1.0).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetFloat(), float32(1.0))
+
+ p, err = Float64(1.0).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetDouble(), float64(1.0))
+
+ p, err = String("1").ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetString_(), "1")
+
+ p, err = Boolean(true).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetBoolean(), true)
+
+ p, err = Binary([]byte{1}).ToProto(context.TODO())
+ assert.NoError(t, err)
+ assert.Equal(t, p.GetLiteral().GetBinary(), []byte{1})
+}
+
+func testMe(n NumericLiteral) bool {
+ return true
+}
+
+func testPrimitive(p PrimitiveTypeLiteral) bool {
+ return true
+}
+
+func TestNumericTypes(t *testing.T) {
+ assert.True(t, testMe(Int8(1)))
+ assert.True(t, testMe(Int16(1)))
+ assert.True(t, testMe(Int32(1)))
+ assert.True(t, testMe(Int64(1)))
+ assert.True(t, testMe(Float32(1.0)))
+ assert.True(t, testMe(Float64(1.0)))
+
+ assert.True(t, testPrimitive(String("a")))
+ assert.True(t, testPrimitive(Boolean(true)))
+ assert.True(t, testPrimitive(Int16(1)))
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]