This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e56e84  Added implementations of FillNa
1e56e84 is described below

commit 1e56e844b96c6e19e91ebbca88b0c95531062325
Author: Alex Ott <[email protected]>
AuthorDate: Tue Dec 31 13:09:36 2024 +0100

    Added implementations of FillNa
    
    ### What changes were proposed in this pull request?
    
    This PR adds the `FillNa` and `FillNaWithValues` methods to `DataFrame`.
    
    ### Why are the changes needed?
    
    The Go client's `DataFrame` had no methods for replacing null values; this adds them.
    
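    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `DataFrame` gains two public methods, `FillNa` and `FillNaWithValues`.
    
    ### How was this patch tested?
    
    Added the integration test `TestDataFrame_FillNa`.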
    
    Closes #94 from alexott/fillna-implementation.
    
    Authored-by: Alex Ott <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 53 ++++++++++++++++++++++++++++
 spark/sql/dataframe.go                       | 46 ++++++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index e00b375..232657b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -31,6 +31,7 @@ import (
 
        "github.com/apache/spark-connect-go/v35/spark/sql"
        "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
 )
 
 func TestDataFrame_Select(t *testing.T) {
@@ -802,3 +803,55 @@ func TestDataFrame_Unpivot(t *testing.T) {
        assert.NoError(t, err)
        assert.Equal(t, int64(4), cnt)
 }
+
+func TestDataFrame_FillNa(t *testing.T) {
+       ctx, spark := connect()
+       file, err := os.CreateTemp("", "fillna")
+       require.NoError(t, err) // file is nil on error, so check before calling file.Name()
+       defer os.Remove(file.Name())
+       defer file.Close()
+       _, err = file.WriteString(`{"id":1,"int":null, "int2": 1}
+{"id":null,"int":12, "int2": null}
+`)
+       require.NoError(t, err)
+
+       df, err := spark.Read().Format("json").
+               Option("inferSchema", "true").
+               Load(file.Name())
+       require.NoError(t, err)
+
+       // No columns given: every null becomes 10, including the null id, which then sorts last.
+       filled, err := df.FillNa(ctx, types.Int64(10))
+       require.NoError(t, err)
+       sorted, err := filled.Sort(ctx, functions.Col("id").Asc())
+       require.NoError(t, err)
+       res, err := sorted.Collect(ctx)
+       require.NoError(t, err)
+       require.Equal(t, 2, len(res))
+       assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[0].Values())
+       assert.Equal(t, []any{int64(10), int64(12), int64(10)}, res[1].Values())
+
+       // Only "int" and "int2" are filled; the null id is kept and sorts first.
+       filled, err = df.FillNa(ctx, types.Int64(10), "int", "int2")
+       require.NoError(t, err)
+       sorted, err = filled.Sort(ctx, functions.Col("id").Asc())
+       require.NoError(t, err)
+       res, err = sorted.Collect(ctx)
+       require.NoError(t, err)
+       require.Equal(t, 2, len(res))
+       assert.Equal(t, []any{nil, int64(12), int64(10)}, res[0].Values())
+       assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
+
+       // Per-column fill values from a map; "id" is not a key, so it stays null.
+       filled, err = df.FillNaWithValues(ctx, map[string]types.PrimitiveTypeLiteral{
+               "int": types.Int64(10), "int2": types.Int64(20),
+       })
+       require.NoError(t, err)
+       sorted, err = filled.Sort(ctx, functions.Col("id").Asc())
+       require.NoError(t, err)
+       res, err = sorted.Collect(ctx)
+       require.NoError(t, err)
+       require.Equal(t, 2, len(res))
+       assert.Equal(t, []any{nil, int64(12), int64(20)}, res[0].Values())
+       assert.Equal(t, []any{int64(1), int64(10), int64(1)}, res[1].Values())
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 65f9484..5feb6a3 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -108,6 +108,10 @@ type DataFrame interface {
        ExceptAll(ctx context.Context, other DataFrame) DataFrame
        // Explain returns the string explain plan for the current DataFrame according to the explainMode.
        Explain(ctx context.Context, explainMode utils.ExplainMode) (string, error)
+       // FillNa returns a new DataFrame with null values in the given columns (all columns if none are given) replaced by value.
+       FillNa(ctx context.Context, value types.PrimitiveTypeLiteral, columns ...string) (DataFrame, error)
+       // FillNaWithValues returns a new DataFrame with null values in each column named by a map key replaced by that key's value.
+       FillNaWithValues(ctx context.Context, values map[string]types.PrimitiveTypeLiteral) (DataFrame, error)
        // Filter filters the data frame by a column condition.
        Filter(ctx context.Context, condition column.Convertible) (DataFrame, error)
        // FilterByString filters the data frame by a string condition.
@@ -1426,3 +1430,64 @@ func (df *dataFrameImpl) Unpivot(ctx context.Context,
        }
        return NewDataFrame(df.session, rel), nil
 }
+
+// makeDataframeWithFillNaRelation returns a DataFrame whose plan is an
+// NAFill relation over df.relation with a fresh plan ID. values and columns
+// are placed in the proto unchanged: FillNa sends a single value (applied to
+// every listed column, or to all columns when columns is empty), while
+// FillNaWithValues sends values[i] as the fill for columns[i].
+func makeDataframeWithFillNaRelation(df *dataFrameImpl, values []*proto.Expression_Literal, columns []string) DataFrame {
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_FillNa{
+                       FillNa: &proto.NAFill{
+                               Input:  df.relation,
+                               Cols:   columns,
+                               Values: values,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// FillNa returns a new DataFrame in which null values in the given columns
+// are replaced with value; with no columns, every column is considered
+// (presumably type-mismatched columns are skipped server-side, as with
+// Spark's na.fill - confirm). value must be non-nil; an error from
+// value.ToProto is returned unchanged.
+func (df *dataFrameImpl) FillNa(ctx context.Context, value types.PrimitiveTypeLiteral, columns ...string) (DataFrame, error) {
+       valueLiteral, err := value.ToProto(ctx)
+       if err != nil {
+               return nil, err
+       }
+       return makeDataframeWithFillNaRelation(df, []*proto.Expression_Literal{
+               valueLiteral.GetLiteral(),
+       }, columns), nil
+}
+
+// FillNaWithValues returns a new DataFrame in which null values in each
+// column named by a key of values are replaced with that key's value. Every
+// map value must be non-nil; if any conversion fails, its error is returned.
+//
+// NOTE(review): an empty map produces an NAFill with no values, which the
+// server is expected to reject - confirm and consider validating here.
+func (df *dataFrameImpl) FillNaWithValues(ctx context.Context,
+       values map[string]types.PrimitiveTypeLiteral,
+) (DataFrame, error) {
+       valueLiterals := make([]*proto.Expression_Literal, 0, len(values))
+       columns := make([]string, 0, len(values))
+       // Map iteration order is random, so the column order in the plan may
+       // differ between calls; appending in the same iteration keeps
+       // columns[i] paired with valueLiterals[i].
+       for k, v := range values {
+               valueLiteral, err := v.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               valueLiterals = append(valueLiterals, valueLiteral.GetLiteral())
+               columns = append(columns, k)
+       }
+       return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to