This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new c3f3eca  #95 Improve Compile-Time Safety for Literals
c3f3eca is described below

commit c3f3eca84a02c803ed8a1f9ebb7c142ff7706111
Author: Martin Grund <[email protected]>
AuthorDate: Mon Dec 30 17:01:20 2024 +0100

    #95 Improve Compile-Time Safety for Literals
    
    ### What changes were proposed in this pull request?
    This patch improves the compile-time safety for literals by providing 
specific interface types. This makes working with specific literal types much 
easier.
    
    This patch adds a number of helper methods for dealing with literals, for 
example `functions.StringLit()` or `functions.Int8Lit()`. In addition, it fixes 
some long-standing bugs around dealing with nil values.
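    
    As a rough usage sketch (based on the examples in this patch, and assuming 
a DataFrame `df` and context `ctx` are already in scope):
    
    ```go
    // Typed helper, checked at compile time:
    df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Int64Lit(20)))
    // Equivalent explicit form using the literal wrapper types:
    df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(types.Int64(20))))
    ```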
    
    ### Why are the changes needed?
    Usability, Safety, Stability
    
    ### Does this PR introduce _any_ user-facing change?
    Added new methods for using the type interfaces for conversion.
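    
    For example, the new typed nil values give a NULL literal an explicit data 
type (a sketch using the `types` values added in this patch):
    
    ```go
    // NULL literals with explicit Spark types:
    colNull := functions.Lit(types.StringNil) // NULL typed as string
    colLong := functions.Lit(types.Int64Nil)  // NULL typed as bigint
    ```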
    
    ### How was this patch tested?
    Added testing.
    
    Closes #96 from grundprinzip/safety_95.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 cmd/spark-connect-example-spark-session/main.go |   4 +-
 dev/gen.py                                      |   6 +-
 internal/tests/integration/dataframe_test.go    |  10 +-
 internal/tests/integration/functions_test.go    |   2 +-
 internal/tests/integration/sql_test.go          |   2 +-
 spark/sql/column/expressions.go                 |  45 +--
 spark/sql/dataframe_test.go                     |   2 +-
 spark/sql/functions/buiitins.go                 |  43 ++-
 spark/sql/functions/generated.go                | 114 ++++---
 spark/sql/group.go                              |   8 +-
 spark/sql/sparksession.go                       |   4 +
 spark/sql/types/arrow.go                        |  90 +++++-
 spark/sql/types/builtin.go                      | 404 ++++++++++++++++++++++++
 spark/sql/types/builtin_test.go                 |  83 +++++
 14 files changed, 689 insertions(+), 128 deletions(-)

diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go
index 516d87b..8c6aa90 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -22,6 +22,8 @@ import (
        "fmt"
        "log"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
        "github.com/apache/spark-connect-go/v35/spark/sql/functions"
 
        "github.com/apache/spark-connect-go/v35/spark/sql"
@@ -68,7 +70,7 @@ func main() {
        }
 
        df, _ = spark.Sql(ctx, "select * from range(100)")
-       df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(20)))
+       df, err = df.Filter(ctx, functions.Col("id").Lt(functions.Lit(types.Int64(20))))
        if err != nil {
                log.Fatalf("Failed: %s", err)
        }
diff --git a/dev/gen.py b/dev/gen.py
index 3b307e3..63ecd43 100644
--- a/dev/gen.py
+++ b/dev/gen.py
@@ -122,15 +122,15 @@ for fun in F.__dict__:
             args.append(p)
         elif param.annotation == str or typing.get_args(param.annotation) == (str, types.NoneType):
             res_params.append(f"{p} string")
-            conversions.append(f"lit_{p} := Lit({p})")
+            conversions.append(f"lit_{p} := StringLit({p})")
             args.append(f"lit_{p}")
         elif param.annotation == int or typing.get_args(param.annotation) == (int, types.NoneType):
             res_params.append(f"{p} int64")
-            conversions.append(f"lit_{p} := Lit({p})")
+            conversions.append(f"lit_{p} := Int64Lit({p})")
             args.append(f"lit_{p}")
         elif param.annotation == float or typing.get_args(param.annotation) == (float, types.NoneType):
             res_params.append(f"{p} float64")
-            conversions.append(f"lit_{p} := Lit({p})")
+            conversions.append(f"lit_{p} := Float64Lit({p})")
             args.append(f"lit_{p}")
         else:
             valid = False
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index f2265f5..e00b375 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -39,7 +39,7 @@ func TestDataFrame_Select(t *testing.T) {
        assert.NoError(t, err)
        df, err := spark.Sql(ctx, "select * from range(100)")
        assert.NoError(t, err)
-       df, err = df.Select(ctx, functions.Lit("1"), functions.Lit("2"))
+       df, err = df.Select(ctx, functions.StringLit("1"), functions.StringLit("2"))
        assert.NoError(t, err)
 
        res, err := df.Collect(ctx)
@@ -210,7 +210,7 @@ func TestDataFrame_WithColumn(t *testing.T) {
        ctx, spark := connect()
        df, err := spark.Sql(ctx, "select * from range(10)")
        assert.NoError(t, err)
-       df, err = df.WithColumn(ctx, "newCol", functions.Lit(1))
+       df, err = df.WithColumn(ctx, "newCol", functions.IntLit(1))
        assert.NoError(t, err)
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
@@ -226,8 +226,8 @@ func TestDataFrame_WithColumns(t *testing.T) {
        ctx, spark := connect()
        df, err := spark.Sql(ctx, "select * from range(10)")
        assert.NoError(t, err)
-       df, err = df.WithColumns(ctx, column.WithAlias("newCol1", functions.Lit(1)),
-               column.WithAlias("newCol2", functions.Lit(2)))
+       df, err = df.WithColumns(ctx, column.WithAlias("newCol1", functions.IntLit(1)),
+               column.WithAlias("newCol2", functions.IntLit(2)))
        assert.NoError(t, err)
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
@@ -592,7 +592,7 @@ func TestDataFrame_Pivot(t *testing.T) {
        df, err := spark.CreateDataFrame(ctx, data, schema)
        assert.NoError(t, err)
        gd := df.GroupBy(functions.Col("year"))
-       gd, err = gd.Pivot(ctx, "course", []any{"Java", "dotNET"})
+       gd, err = gd.Pivot(ctx, "course", []types.LiteralType{types.String("Java"), types.String("dotNET")})
        assert.NoError(t, err)
        df, err = gd.Sum(ctx, "earnings")
        assert.NoError(t, err)
diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go
index 1f89923..33f3352 100644
--- a/internal/tests/integration/functions_test.go
+++ b/internal/tests/integration/functions_test.go
@@ -33,7 +33,7 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
        }
 
        df, _ := spark.Sql(ctx, "select '[2]' as a from range(10)")
-       df, _ = df.Filter(ctx, functions.JsonArrayLength(functions.Col("a")).Eq(functions.Lit(1)))
+       df, _ = df.Filter(ctx, functions.JsonArrayLength(functions.Col("a")).Eq(functions.IntLit(1)))
        res, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Equal(t, 10, len(res))
diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go
index ba2dccf..28c32d4 100644
--- a/internal/tests/integration/sql_test.go
+++ b/internal/tests/integration/sql_test.go
@@ -46,7 +46,7 @@ func TestIntegration_RunSQLCommand(t *testing.T) {
        assert.NoError(t, err)
        assert.Equal(t, 100, len(res))
 
-       df, err = df.Filter(ctx, column.OfDF(df, "id").Lt(functions.Lit(10)))
+       df, err = df.Filter(ctx, column.OfDF(df, "id").Lt(functions.IntLit(10)))
        assert.NoError(t, err)
        res, err = df.Collect(ctx)
        assert.NoErrorf(t, err, "Must be able to collect the rows.")
diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go
index fe50d16..a2ffa24 100644
--- a/spark/sql/column/expressions.go
+++ b/spark/sql/column/expressions.go
@@ -20,6 +20,8 @@ import (
        "fmt"
        "strings"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
        "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
 
        proto "github.com/apache/spark-connect-go/v35/internal/generated"
@@ -308,52 +310,17 @@ func (s *sqlExression) ToProto(context.Context) (*proto.Expression, error) {
 }
 
 type literalExpression struct {
-       value any
+       value types.LiteralType
 }
 
 func (l *literalExpression) DebugString() string {
        return fmt.Sprintf("%v", l.value)
 }
 
-func (l *literalExpression) ToProto(context.Context) (*proto.Expression, error) {
-       expr := newProtoExpression()
-       expr.ExprType = &proto.Expression_Literal_{
-               Literal: &proto.Expression_Literal{},
-       }
-       switch v := l.value.(type) {
-       case int8:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Byte{Byte: int32(v)}
-       case int16:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
-       case int32:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: v}
-       case int64:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: v}
-       case uint8:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Short{Short: int32(v)}
-       case uint16:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Integer{Integer: int32(v)}
-       case uint32:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
-       case float32:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Float{Float: v}
-       case float64:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Double{Double: v}
-       case string:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_String_{String_: v}
-       case bool:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Boolean{Boolean: v}
-       case []byte:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Binary{Binary: v}
-       case int:
-               expr.GetLiteral().LiteralType = &proto.Expression_Literal_Long{Long: int64(v)}
-       default:
-               return nil, sparkerrors.WithType(sparkerrors.InvalidPlanError,
-                       fmt.Errorf("unsupported literal type %T", v))
-       }
-       return expr, nil
+func (l *literalExpression) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return l.value.ToProto(ctx)
 }
 
-func NewLiteral(value any) expression {
+func NewLiteral(value types.LiteralType) expression {
        return &literalExpression{value: value}
 }
diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go
index 39145b0..afa6111 100644
--- a/spark/sql/dataframe_test.go
+++ b/spark/sql/dataframe_test.go
@@ -41,7 +41,7 @@ func TestDataFrameImpl_GroupBy(t *testing.T) {
 
        assert.Equal(t, gd.groupType, "groupby")
 
-       df, err := gd.Agg(ctx, functions.Count(functions.Lit(1)))
+       df, err := gd.Agg(ctx, functions.Count(functions.Int64Lit(1)))
        assert.Nil(t, err)
        impl := df.(*dataFrameImpl)
        assert.NotNil(t, impl)
diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go
index a2a7bf8..6df737c 100644
--- a/spark/sql/functions/buiitins.go
+++ b/spark/sql/functions/buiitins.go
@@ -17,6 +17,7 @@ package functions
 
 import (
        "github.com/apache/spark-connect-go/v35/spark/sql/column"
+       "github.com/apache/spark-connect-go/v35/spark/sql/types"
 )
 
 func Expr(expr string) column.Column {
@@ -27,6 +28,46 @@ func Col(name string) column.Column {
        return column.NewColumn(column.NewColumnReference(name))
 }
 
-func Lit(value any) column.Column {
+func Lit(value types.LiteralType) column.Column {
        return column.NewColumn(column.NewLiteral(value))
 }
+
+func Int8Lit(value int8) column.Column {
+       return Lit(types.Int8(value))
+}
+
+func Int16Lit(value int16) column.Column {
+       return Lit(types.Int16(value))
+}
+
+func Int32Lit(value int32) column.Column {
+       return Lit(types.Int32(value))
+}
+
+func Int64Lit(value int64) column.Column {
+       return Lit(types.Int64(value))
+}
+
+func Float32Lit(value float32) column.Column {
+       return Lit(types.Float32(value))
+}
+
+func Float64Lit(value float64) column.Column {
+       return Lit(types.Float64(value))
+}
+
+func StringLit(value string) column.Column {
+       return Lit(types.String(value))
+}
+
+func BoolLit(value bool) column.Column {
+       return Lit(types.Boolean(value))
+}
+
+func BinaryLit(value []byte) column.Column {
+       return Lit(types.Binary(value))
+}
+
+func IntLit(value int) column.Column {
+       return Lit(types.Int(value))
+}
diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go
index cb9421e..6ef668e 100644
--- a/spark/sql/functions/generated.go
+++ b/spark/sql/functions/generated.go
@@ -15,9 +15,7 @@
 
 package functions
 
-import (
-       "github.com/apache/spark-connect-go/v35/spark/sql/column"
-)
+import "github.com/apache/spark-connect-go/v35/spark/sql/column"
 
 // BitwiseNOT - Computes bitwise not.
 //
@@ -137,7 +135,7 @@ func Nanvl(col1 column.Column, col2 column.Column) column.Column {
 //
 // Rand is the Golang equivalent of rand: (seed: Optional[int] = None) -> pyspark.sql.connect.column.Column
 func Rand(seed int64) column.Column {
-       lit_seed := Lit(seed)
+       lit_seed := Int64Lit(seed)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("rand", lit_seed))
 }
 
@@ -146,7 +144,7 @@ func Rand(seed int64) column.Column {
 //
 // Randn is the Golang equivalent of randn: (seed: Optional[int] = None) -> pyspark.sql.connect.column.Column
 func Randn(seed int64) column.Column {
-       lit_seed := Lit(seed)
+       lit_seed := Int64Lit(seed)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("randn", lit_seed))
 }
 
@@ -273,7 +271,7 @@ func Bin(col column.Column) column.Column {
 //
 // Bround is the Golang equivalent of bround: (col: 'ColumnOrName', scale: int = 0) -> pyspark.sql.connect.column.Column
 func Bround(col column.Column, scale int64) column.Column {
-       lit_scale := Lit(scale)
+       lit_scale := Int64Lit(scale)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("bround", col, lit_scale))
 }
 
@@ -302,8 +300,8 @@ func Ceiling(col column.Column) column.Column {
 //
 // Conv is the Golang equivalent of conv: (col: 'ColumnOrName', fromBase: int, toBase: int) -> pyspark.sql.connect.column.Column
 func Conv(col column.Column, fromBase int64, toBase int64) column.Column {
-       lit_fromBase := Lit(fromBase)
-       lit_toBase := Lit(toBase)
+       lit_fromBase := Int64Lit(fromBase)
+       lit_toBase := Int64Lit(toBase)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("conv", col, lit_fromBase, lit_toBase))
 }
 
@@ -503,7 +501,7 @@ func Rint(col column.Column) column.Column {
 //
 // Round is the Golang equivalent of round: (col: 'ColumnOrName', scale: int = 0) -> pyspark.sql.connect.column.Column
 func Round(col column.Column, scale int64) column.Column {
-       lit_scale := Lit(scale)
+       lit_scale := Int64Lit(scale)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("round", col, lit_scale))
 }
 
@@ -518,7 +516,7 @@ func Sec(col column.Column) column.Column {
 //
 // ShiftLeft is the Golang equivalent of shiftLeft: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func ShiftLeft(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftLeft", col, lit_numBits))
 }
 
@@ -526,7 +524,7 @@ func ShiftLeft(col column.Column, numBits int64) column.Column {
 //
 // Shiftleft is the Golang equivalent of shiftleft: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func Shiftleft(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftleft", col, lit_numBits))
 }
 
@@ -534,7 +532,7 @@ func Shiftleft(col column.Column, numBits int64) column.Column {
 //
 // ShiftRight is the Golang equivalent of shiftRight: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func ShiftRight(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftRight", col, lit_numBits))
 }
 
@@ -542,7 +540,7 @@ func ShiftRight(col column.Column, numBits int64) column.Column {
 //
 // Shiftright is the Golang equivalent of shiftright: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func Shiftright(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftright", col, lit_numBits))
 }
 
@@ -550,7 +548,7 @@ func Shiftright(col column.Column, numBits int64) column.Column {
 //
 // ShiftRightUnsigned is the Golang equivalent of shiftRightUnsigned: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func ShiftRightUnsigned(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftRightUnsigned", col, lit_numBits))
 }
 
@@ -558,7 +556,7 @@ func ShiftRightUnsigned(col column.Column, numBits int64) column.Column {
 //
 // Shiftrightunsigned is the Golang equivalent of shiftrightunsigned: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func Shiftrightunsigned(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("shiftrightunsigned", col, lit_numBits))
 }
 
@@ -684,7 +682,7 @@ func Unhex(col column.Column) column.Column {
 //
 // ApproxCountDistinct is the Golang equivalent of approx_count_distinct: (col: 'ColumnOrName', rsd: Optional[float] = None) -> pyspark.sql.connect.column.Column
 func ApproxCountDistinct(col column.Column, rsd float64) column.Column {
-       lit_rsd := Lit(rsd)
+       lit_rsd := Float64Lit(rsd)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("approx_count_distinct", col, lit_rsd))
 }
 
@@ -1122,7 +1120,7 @@ func HistogramNumeric(col column.Column, nBins column.Column) column.Column {
 //
 // Ntile is the Golang equivalent of ntile: (n: int) -> pyspark.sql.connect.column.Column
 func Ntile(n int64) column.Column {
-       lit_n := Lit(n)
+       lit_n := Int64Lit(n)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("ntile", lit_n))
 }
 
@@ -1207,8 +1205,8 @@ func ArrayCompact(col column.Column) column.Column {
 //
 // ArrayJoin is the Golang equivalent of array_join: (col: 'ColumnOrName', delimiter: str, null_replacement: Optional[str] = None) -> pyspark.sql.connect.column.Column
 func ArrayJoin(col column.Column, delimiter string, null_replacement string) column.Column {
-       lit_delimiter := Lit(delimiter)
-       lit_null_replacement := Lit(null_replacement)
+       lit_delimiter := StringLit(delimiter)
+       lit_null_replacement := StringLit(null_replacement)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("array_join", col, lit_delimiter, lit_null_replacement))
 }
 
@@ -1366,7 +1364,7 @@ func Get(col column.Column, index column.Column) column.Column {
 //
 // GetJsonObject is the Golang equivalent of get_json_object: (col: 'ColumnOrName', path: str) -> pyspark.sql.connect.column.Column
 func GetJsonObject(col column.Column, path string) column.Column {
-       lit_path := Lit(path)
+       lit_path := StringLit(path)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("get_json_object", col, lit_path))
 }
 
@@ -1406,7 +1404,7 @@ func InlineOuter(col column.Column) column.Column {
 //
 // JsonTuple is the Golang equivalent of json_tuple: (col: 'ColumnOrName', *fields: str) -> pyspark.sql.connect.column.Column
 func JsonTuple(col column.Column, fields string) column.Column {
-       lit_fields := Lit(fields)
+       lit_fields := StringLit(fields)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("json_tuple", col, lit_fields))
 }
 
@@ -1619,7 +1617,7 @@ func Trim(col column.Column) column.Column {
 //
 // ConcatWs is the Golang equivalent of concat_ws: (sep: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
 func ConcatWs(sep string, cols ...column.Column) column.Column {
-       lit_sep := Lit(sep)
+       lit_sep := StringLit(sep)
        vals := make([]column.Column, 0)
        vals = append(vals, lit_sep)
        vals = append(vals, cols...)
@@ -1631,7 +1629,7 @@ func ConcatWs(sep string, cols ...column.Column) column.Column {
 //
 // Decode is the Golang equivalent of decode: (col: 'ColumnOrName', charset: str) -> pyspark.sql.connect.column.Column
 func Decode(col column.Column, charset string) column.Column {
-       lit_charset := Lit(charset)
+       lit_charset := StringLit(charset)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("decode", col, lit_charset))
 }
 
@@ -1640,7 +1638,7 @@ func Decode(col column.Column, charset string) column.Column {
 //
 // Encode is the Golang equivalent of encode: (col: 'ColumnOrName', charset: str) -> pyspark.sql.connect.column.Column
 func Encode(col column.Column, charset string) column.Column {
-       lit_charset := Lit(charset)
+       lit_charset := StringLit(charset)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("encode", col, lit_charset))
 }
 
@@ -1649,7 +1647,7 @@ func Encode(col column.Column, charset string) column.Column {
 //
 // FormatNumber is the Golang equivalent of format_number: (col: 'ColumnOrName', d: int) -> pyspark.sql.connect.column.Column
 func FormatNumber(col column.Column, d int64) column.Column {
-       lit_d := Lit(d)
+       lit_d := Int64Lit(d)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("format_number", col, lit_d))
 }
 
@@ -1657,7 +1655,7 @@ func FormatNumber(col column.Column, d int64) column.Column {
 //
 // FormatString is the Golang equivalent of format_string: (format: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
 func FormatString(format string, cols ...column.Column) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        vals := make([]column.Column, 0)
        vals = append(vals, lit_format)
        vals = append(vals, cols...)
@@ -1669,7 +1667,7 @@ func FormatString(format string, cols ...column.Column) column.Column {
 //
 // Instr is the Golang equivalent of instr: (str: 'ColumnOrName', substr: str) -> pyspark.sql.connect.column.Column
 func Instr(str column.Column, substr string) column.Column {
-       lit_substr := Lit(substr)
+       lit_substr := StringLit(substr)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("instr", str, lit_substr))
 }
 
@@ -1695,8 +1693,8 @@ func Sentences(string column.Column, language column.Column, country column.Colu
 //
 // Substring is the Golang equivalent of substring: (str: 'ColumnOrName', pos: int, len: int) -> pyspark.sql.connect.column.Column
 func Substring(str column.Column, pos int64, len int64) column.Column {
-       lit_pos := Lit(pos)
-       lit_len := Lit(len)
+       lit_pos := Int64Lit(pos)
+       lit_len := Int64Lit(len)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("substring", str, lit_pos, lit_len))
 }
 
@@ -1707,8 +1705,8 @@ func Substring(str column.Column, pos int64, len int64) column.Column {
 //
 // SubstringIndex is the Golang equivalent of substring_index: (str: 'ColumnOrName', delim: str, count: int) -> pyspark.sql.connect.column.Column
 func SubstringIndex(str column.Column, delim string, count int64) column.Column {
-       lit_delim := Lit(delim)
-       lit_count := Lit(count)
+       lit_delim := StringLit(delim)
+       lit_count := Int64Lit(count)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("substring_index", str, lit_delim, lit_count))
 }
 
@@ -1716,7 +1714,7 @@ func SubstringIndex(str column.Column, delim string, count int64) column.Column
 //
 // Levenshtein is the Golang equivalent of levenshtein: (left: 'ColumnOrName', right: 'ColumnOrName', threshold: Optional[int] = None) -> pyspark.sql.connect.column.Column
 func Levenshtein(left column.Column, right column.Column, threshold int64) column.Column {
-       lit_threshold := Lit(threshold)
+       lit_threshold := Int64Lit(threshold)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("levenshtein", left, right, lit_threshold))
 }
 
@@ -1724,8 +1722,8 @@ func Levenshtein(left column.Column, right column.Column, threshold int64) colum
 //
 // Locate is the Golang equivalent of locate: (substr: str, str: 'ColumnOrName', pos: int = 1) -> pyspark.sql.connect.column.Column
 func Locate(substr string, str column.Column, pos int64) column.Column {
-       lit_substr := Lit(substr)
-       lit_pos := Lit(pos)
+       lit_substr := StringLit(substr)
+       lit_pos := Int64Lit(pos)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("locate", lit_substr, str, lit_pos))
 }
 
@@ -1733,8 +1731,8 @@ func Locate(substr string, str column.Column, pos int64) column.Column {
 //
 // Lpad is the Golang equivalent of lpad: (col: 'ColumnOrName', len: int, pad: str) -> pyspark.sql.connect.column.Column
 func Lpad(col column.Column, len int64, pad string) column.Column {
-       lit_len := Lit(len)
-       lit_pad := Lit(pad)
+       lit_len := Int64Lit(len)
+       lit_pad := StringLit(pad)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("lpad", col, lit_len, lit_pad))
 }
 
@@ -1742,8 +1740,8 @@ func Lpad(col column.Column, len int64, pad string) column.Column {
 //
 // Rpad is the Golang equivalent of rpad: (col: 'ColumnOrName', len: int, pad: str) -> pyspark.sql.connect.column.Column
 func Rpad(col column.Column, len int64, pad string) column.Column {
-       lit_len := Lit(len)
-       lit_pad := Lit(pad)
+       lit_len := Int64Lit(len)
+       lit_pad := StringLit(pad)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("rpad", col, lit_len, lit_pad))
 }
 
@@ -1751,7 +1749,7 @@ func Rpad(col column.Column, len int64, pad string) column.Column {
 //
 // Repeat is the Golang equivalent of repeat: (col: 'ColumnOrName', n: int) -> pyspark.sql.connect.column.Column
 func Repeat(col column.Column, n int64) column.Column {
-       lit_n := Lit(n)
+       lit_n := Int64Lit(n)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("repeat", col, lit_n))
 }
 
@@ -1759,8 +1757,8 @@ func Repeat(col column.Column, n int64) column.Column {
 //
 // Split is the Golang equivalent of split: (str: 'ColumnOrName', pattern: str, limit: int = -1) -> pyspark.sql.connect.column.Column
 func Split(str column.Column, pattern string, limit int64) column.Column {
-       lit_pattern := Lit(pattern)
-       lit_limit := Lit(limit)
+       lit_pattern := StringLit(pattern)
+       lit_limit := Int64Lit(limit)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("split", str, lit_pattern, lit_limit))
 }
 
@@ -1798,8 +1796,8 @@ func RegexpCount(str column.Column, regexp column.Column) column.Column {
 //
 // RegexpExtract is the Golang equivalent of regexp_extract: (str: 'ColumnOrName', pattern: str, idx: int) -> pyspark.sql.connect.column.Column
 func RegexpExtract(str column.Column, pattern string, idx int64) column.Column {
-       lit_pattern := Lit(pattern)
-       lit_idx := Lit(idx)
+       lit_pattern := StringLit(pattern)
+       lit_idx := Int64Lit(idx)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("regexp_extract", str, lit_pattern, lit_idx))
 }
 
@@ -1861,8 +1859,8 @@ func BitLength(col column.Column) column.Column {
 //
 // Translate is the Golang equivalent of translate: (srcCol: 'ColumnOrName', matching: str, replace: str) -> pyspark.sql.connect.column.Column
 func Translate(srcCol column.Column, matching string, replace string) column.Column {
-       lit_matching := Lit(matching)
-       lit_replace := Lit(replace)
+       lit_matching := StringLit(matching)
+       lit_replace := StringLit(replace)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("translate", srcCol, lit_matching, lit_replace))
 }
 
@@ -2214,7 +2212,7 @@ func Localtimestamp() column.Column {
 //
 // DateFormat is the Golang equivalent of date_format: (date: 'ColumnOrName', format: str) -> pyspark.sql.connect.column.Column
 func DateFormat(date column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("date_format", date, lit_format))
 }
 
@@ -2391,7 +2389,7 @@ func AddMonths(start column.Column, months column.Column) column.Column {
 //
 // ToDate is the Golang equivalent of to_date: (col: 'ColumnOrName', format: Optional[str] = None) -> pyspark.sql.connect.column.Column
 func ToDate(col column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("to_date", col, lit_format))
 }
 
@@ -2432,7 +2430,7 @@ func UnixSeconds(col column.Column) column.Column {
 //
 // ToTimestamp is the Golang equivalent of to_timestamp: (col: 'ColumnOrName', format: Optional[str] = None) -> pyspark.sql.connect.column.Column
 func ToTimestamp(col column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("to_timestamp", col, lit_format))
 }
 
@@ -2518,7 +2516,7 @@ func XpathString(xml column.Column, path column.Column) column.Column {
 //
 // Trunc is the Golang equivalent of trunc: (date: 'ColumnOrName', format: str) -> pyspark.sql.connect.column.Column
 func Trunc(date column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("trunc", date, lit_format))
 }
 
@@ -2526,7 +2524,7 @@ func Trunc(date column.Column, format string) column.Column {
 //
 // DateTrunc is the Golang equivalent of date_trunc: (format: str, timestamp: 'ColumnOrName') -> pyspark.sql.connect.column.Column
 func DateTrunc(format string, timestamp column.Column) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("date_trunc", lit_format, timestamp))
 }
 
@@ -2535,7 +2533,7 @@ func DateTrunc(format string, timestamp column.Column) column.Column {
 //
 // NextDay is the Golang equivalent of next_day: (date: 'ColumnOrName', dayOfWeek: str) -> pyspark.sql.connect.column.Column
 func NextDay(date column.Column, dayOfWeek string) column.Column {
-       lit_dayOfWeek := Lit(dayOfWeek)
+       lit_dayOfWeek := StringLit(dayOfWeek)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("next_day", date, lit_dayOfWeek))
 }
 
@@ -2552,7 +2550,7 @@ func LastDay(date column.Column) column.Column {
 //
 // FromUnixtime is the Golang equivalent of from_unixtime: (timestamp: 'ColumnOrName', format: str = 'yyyy-MM-dd HH:mm:ss') -> pyspark.sql.connect.column.Column
 func FromUnixtime(timestamp column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("from_unixtime", timestamp, lit_format))
 }
 
@@ -2564,7 +2562,7 @@ func FromUnixtime(timestamp column.Column, format string) column.Column {
 //
 // UnixTimestamp is the Golang equivalent of unix_timestamp: (timestamp: Optional[ForwardRef('ColumnOrName')] = None, format: str = 'yyyy-MM-dd HH:mm:ss') -> pyspark.sql.connect.column.Column
 func UnixTimestamp(timestamp column.Column, format string) column.Column {
-       lit_format := Lit(format)
+       lit_format := StringLit(format)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("unix_timestamp", timestamp, lit_format))
 }
 
@@ -2646,9 +2644,9 @@ func TimestampMicros(col column.Column) column.Column {
 //
 // Window is the Golang equivalent of window: (timeColumn: 'ColumnOrName', windowDuration: str, slideDuration: Optional[str] = None, startTime: Optional[str] = None) -> pyspark.sql.connect.column.Column
 func Window(timeColumn column.Column, windowDuration string, slideDuration string, startTime string) column.Column {
-       lit_windowDuration := Lit(windowDuration)
-       lit_slideDuration := Lit(slideDuration)
-       lit_startTime := Lit(startTime)
+       lit_windowDuration := StringLit(windowDuration)
+       lit_slideDuration := StringLit(slideDuration)
+       lit_startTime := StringLit(startTime)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("window", timeColumn,
                lit_windowDuration, lit_slideDuration, lit_startTime))
 }
@@ -2873,7 +2871,7 @@ func Sha1(col column.Column) column.Column {
 //
 // Sha2 is the Golang equivalent of sha2: (col: 'ColumnOrName', numBits: int) -> pyspark.sql.connect.column.Column
 func Sha2(col column.Column, numBits int64) column.Column {
-       lit_numBits := Lit(numBits)
+       lit_numBits := Int64Lit(numBits)
        return column.NewColumn(column.NewUnresolvedFunctionWithColumns("sha2", col, lit_numBits))
 }
 
@@ -3085,7 +3083,7 @@ func BitmapOrAgg(col column.Column) column.Column {
 //
 // CallFunction is the Golang equivalent of call_function: (funcName: str, *cols: 'ColumnOrName') -> pyspark.sql.connect.column.Column
 func CallFunction(funcName string, cols ...column.Column) column.Column {
-       lit_funcName := Lit(funcName)
+       lit_funcName := StringLit(funcName)
        vals := make([]column.Column, 0)
        vals = append(vals, lit_funcName)
        vals = append(vals, cols...)
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 3f67e18..975dc50 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -19,6 +19,8 @@ package sql
 import (
        "context"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/types"
+
        proto "github.com/apache/spark-connect-go/v35/internal/generated"
        "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
        "github.com/apache/spark-connect-go/v35/spark/sql/column"
@@ -29,7 +31,7 @@ type GroupedData struct {
        df           *dataFrameImpl
        groupType    string
        groupingCols []column.Convertible
-       pivotValues  []any
+       pivotValues  []types.LiteralType
        pivotCol     column.Convertible
 }
 
@@ -171,7 +173,7 @@ func (gd *GroupedData) Sum(ctx context.Context, cols ...string) (DataFrame, erro
 
 // Count Computes the count value for each group.
 func (gd *GroupedData) Count(ctx context.Context) (DataFrame, error) {
-       return gd.Agg(ctx, functions.Count(functions.Lit(1)).Alias("count"))
+       return gd.Agg(ctx, functions.Count(functions.Lit(types.Int64(1))).Alias("count"))
 }
 
 // Mean Computes the average value for each numeric column for each group.
@@ -179,7 +181,7 @@ func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, err
        return gd.Avg(ctx, cols...)
 }
 
-func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []any) (*GroupedData, error) {
+func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []types.LiteralType) (*GroupedData, error) {
        if gd.groupType != "groupby" {
                if gd.groupType == "pivot" {
                        return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot cannot be applied on pivot")
diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go
index 58436be..4bf94cc 100644
--- a/spark/sql/sparksession.go
+++ b/spark/sql/sparksession.go
@@ -230,6 +230,10 @@ func (s *sparkSessionImpl) CreateDataFrame(ctx context.Context, data [][]any, sc
        // Iterate over all fields and add the values:
        for _, row := range data {
                for i, field := range schema.Fields {
+                       if row[i] == nil {
+                               rb.Field(i).AppendNull()
+                               continue
+                       }
                        switch field.DataType {
                        case types.BOOLEAN:
                                rb.Field(i).(*array.BooleanBuilder).Append(row[i].(bool))
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index 360bed9..2921881 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -75,77 +75,137 @@ func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
        case arrow.BOOL:
                data := array.NewBooleanData(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.INT8:
                data := array.NewInt8Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.INT16:
                data := array.NewInt16Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.INT32:
                data := array.NewInt32Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.INT64:
                data := array.NewInt64Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.FLOAT16:
                data := array.NewFloat16Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.FLOAT32:
                data := array.NewFloat32Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.FLOAT64:
                data := array.NewFloat64Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.DECIMAL | arrow.DECIMAL128:
                data := array.NewDecimal128Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.DECIMAL256:
                data := array.NewDecimal256Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.STRING:
                data := array.NewStringData(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.BINARY:
                data := array.NewBinaryData(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.TIMESTAMP:
                data := array.NewTimestampData(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.DATE64:
                data := array.NewDate64Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.DATE32:
                data := array.NewDate32Data(data)
                for i := 0; i < data.Len(); i++ {
-                       buf = append(buf, data.Value(i))
+                       if data.IsNull(i) {
+                               buf = append(buf, nil)
+                       } else {
+                               buf = append(buf, data.Value(i))
+                       }
                }
        case arrow.LIST:
                data := array.NewListData(data)
diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go
new file mode 100644
index 0000000..15f50d1
--- /dev/null
+++ b/spark/sql/types/builtin.go
@@ -0,0 +1,404 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+       "context"
+
+       proto "github.com/apache/spark-connect-go/v35/internal/generated"
+)
+
+type LiteralType interface {
+       ToProto(ctx context.Context) (*proto.Expression, error)
+}
+
+type NumericLiteral interface {
+       LiteralType
+       // marker method for compile time safety.
+       isNumericLiteral()
+}
+
+type PrimitiveTypeLiteral interface {
+       LiteralType
+       isPrimitiveTypeLiteral()
+}
+
+type Int8 int8
+
+func (t Int8) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Byte{Byte: int32(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Int8) isNumericLiteral() {}
+
+func (t Int8) isPrimitiveTypeLiteral() {}
+
+type Int16 int16
+
+func (t Int16) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Short{Short: int32(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Int16) isNumericLiteral() {}
+
+func (t Int16) isPrimitiveTypeLiteral() {}
+
+type Int32 int32
+
+func (t Int32) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Integer{Integer: int32(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Int32) isNumericLiteral() {}
+
+func (t Int32) isPrimitiveTypeLiteral() {}
+
+type Int64 int64
+
+func (t Int64) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Long{Long: int64(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Int64) isNumericLiteral() {}
+
+func (t Int64) isPrimitiveTypeLiteral() {}
+
+type Int int
+
+func (t Int) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return Int64(t).ToProto(ctx)
+}
+
+func (t Int) isNumericLiteral() {}
+
+func (t Int) isPrimitiveTypeLiteral() {}
+
+type Float32 float32
+
+func (t Float32) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Float{Float: float32(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Float32) isNumericLiteral() {}
+
+func (t Float32) isPrimitiveTypeLiteral() {}
+
+type Float64 float64
+
+func (t Float64) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Double{Double: float64(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Float64) isNumericLiteral() {}
+
+func (t Float64) isPrimitiveTypeLiteral() {}
+
+type String string
+
+func (t String) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_String_{String_: string(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t String) isPrimitiveTypeLiteral() {}
+
+type Boolean bool
+
+func (t Boolean) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Boolean{Boolean: bool(t)},
+                       },
+               },
+       }, nil
+}
+
+func (t Boolean) isPrimitiveTypeLiteral() {}
+
+type Binary []byte
+
+func (t Binary) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Binary{Binary: t},
+                       },
+               },
+       }, nil
+}
+
+type Int8NilType struct{}
+
+var Int8Nil = Int8NilType{}
+
+func (t Int8NilType) isNumericLiteral() {}
+
+func (t Int8NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int8NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Byte_{
+                                                       Byte: &proto.DataType_Byte{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type Int16NilType struct{}
+
+var Int16Nil = Int16NilType{}
+
+func (t Int16NilType) isNumericLiteral() {}
+
+func (t Int16NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int16NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Short_{
+                                                       Short: &proto.DataType_Short{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type Int32NilType struct{}
+
+var Int32Nil = Int32NilType{}
+
+func (t Int32NilType) isNumericLiteral() {}
+
+func (t Int32NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int32NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Integer_{
+                                                       Integer: &proto.DataType_Integer{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type Int64NilType struct{}
+
+var Int64Nil = Int64NilType{}
+
+func (t Int64NilType) isNumericLiteral() {}
+
+func (t Int64NilType) isPrimitiveTypeLiteral() {}
+
+func (t Int64NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Long_{
+                                                       Long: &proto.DataType_Long{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type IntNilType struct{}
+
+var IntNil = IntNilType{}
+
+func (t IntNilType) isNumericLiteral() {}
+
+func (t IntNilType) isPrimitiveTypeLiteral() {}
+
+func (t IntNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return Int64NilType{}.ToProto(ctx)
+}
+
+type Float32NilType struct{}
+
+var Float32Nil = Float32NilType{}
+
+func (t Float32NilType) isNumericLiteral() {}
+
+func (t Float32NilType) isPrimitiveTypeLiteral() {}
+
+func (t Float32NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Float_{
+                                                       Float: &proto.DataType_Float{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type Float64NilType struct{}
+
+var Float64Nil = Float64NilType{}
+
+func (t Float64NilType) isNumericLiteral() {}
+
+func (t Float64NilType) isPrimitiveTypeLiteral() {}
+
+func (t Float64NilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Double_{
+                                                       Double: &proto.DataType_Double{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type StringNilType struct{}
+
+var StringNil = StringNilType{}
+
+func (t StringNilType) isPrimitiveTypeLiteral() {}
+
+func (t StringNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_String_{
+                                                       String_: &proto.DataType_String{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type BooleanNilType struct{}
+
+var BooleanNil = BooleanNilType{}
+
+func (t BooleanNilType) isPrimitiveTypeLiteral() {}
+
+func (t BooleanNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Boolean_{
+                                                       Boolean: &proto.DataType_Boolean{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
+
+type BinaryNilType struct{}
+
+var BinaryNil = BinaryNilType{}
+
+func (t BinaryNilType) ToProto(ctx context.Context) (*proto.Expression, error) {
+       return &proto.Expression{
+               ExprType: &proto.Expression_Literal_{
+                       Literal: &proto.Expression_Literal{
+                               LiteralType: &proto.Expression_Literal_Null{
+                                       Null: &proto.DataType{
+                                               Kind: &proto.DataType_Binary_{
+                                                       Binary: &proto.DataType_Binary{},
+                                               },
+                                       },
+                               },
+                       },
+               },
+       }, nil
+}
diff --git a/spark/sql/types/builtin_test.go b/spark/sql/types/builtin_test.go
new file mode 100644
index 0000000..0a326ac
--- /dev/null
+++ b/spark/sql/types/builtin_test.go
@@ -0,0 +1,83 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+       "context"
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestBuiltinTypes(t *testing.T) {
+       p, err := Int8(1).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetByte(), int32(1))
+
+       p, err = Int16(1).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetShort(), int32(1))
+
+       p, err = Int32(1).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetInteger(), int32(1))
+
+       p, err = Int64(1).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetLong(), int64(1))
+
+       p, err = Float32(1.0).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetFloat(), float32(1.0))
+
+       p, err = Float64(1.0).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetDouble(), float64(1.0))
+
+       p, err = String("1").ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetString_(), "1")
+
+       p, err = Boolean(true).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetBoolean(), true)
+
+       p, err = Binary([]byte{1}).ToProto(context.TODO())
+       assert.NoError(t, err)
+       assert.Equal(t, p.GetLiteral().GetBinary(), []byte{1})
+}
+
+func testMe(n NumericLiteral) bool {
+       return true
+}
+
+func testPrimitive(p PrimitiveTypeLiteral) bool {
+       return true
+}
+
+func TestNumericTypes(t *testing.T) {
+       assert.True(t, testMe(Int8(1)))
+       assert.True(t, testMe(Int16(1)))
+       assert.True(t, testMe(Int32(1)))
+       assert.True(t, testMe(Int64(1)))
+       assert.True(t, testMe(Float32(1.0)))
+       assert.True(t, testMe(Float64(1.0)))
+
+       assert.True(t, testPrimitive(String("a")))
+       assert.True(t, testPrimitive(Boolean(true)))
+       assert.True(t, testPrimitive(Int16(1)))
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
