This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new caf2a33d380 [SPARK-43230][CONNECT] Simplify `DataFrameNaFunctions.fillna`
caf2a33d380 is described below

commit caf2a33d380a967e6e528893145782b7e902eaa9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Apr 24 14:37:14 2023 +0800

    [SPARK-43230][CONNECT] Simplify `DataFrameNaFunctions.fillna`
    
    ### What changes were proposed in this pull request?
    Add a `private[sql]` helper, `fillValue(value: Any, cols: Option[Seq[String]])`, to `DataFrameNaFunctions`.
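    As the diff shows, the helper delegates to the existing `fillValue(value, attributes)` with either the given column subset or all output attributes. A minimal, self-contained sketch of the two fill paths it unifies, shown with the public `na.fill` overloads (the session, DataFrame, and column names here are illustrative, not part of this patch):
    
    ```scala
    import org.apache.spark.sql.SparkSession
    
    object FillValueSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        import spark.implicits._
        // Two nullable columns: a: bigint, b: string.
        val df = Seq((Some(1L), Option.empty[String]), (None, Some("x")))
          .toDF("a", "b")
        // cols = Some(Seq("a")): fill nulls only in the listed columns.
        df.na.fill(0L, Seq("a")).show()
        // cols = None: fill nulls in all eligible (here, string) columns.
        df.na.fill("missing").show()
        spark.stop()
      }
    }
    ```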
    
    ### Why are the changes needed?
    To simplify the `fillna` handling in `SparkConnectPlanner`: the per-literal-type match is replaced by a single call to the new helper.
    
    ### Does this PR introduce _any_ user-facing change?
    no, dev-only
    
    ### How was this patch tested?
    existing UTs
    
    Closes #40898 from zhengruifeng/connect_simplify_fillna.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../CheckConnectJvmClientCompatibility.scala       |  1 +
 .../sql/connect/planner/SparkConnectPlanner.scala  | 31 +++-------------------
 .../apache/spark/sql/DataFrameNaFunctions.scala    |  4 +++
 3 files changed, 8 insertions(+), 28 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7520588580a..65393780841 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -145,6 +145,7 @@ object CheckConnectJvmClientCompatibility {
 
       // DataFrameNaFunctions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.this"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),
 
       // DataFrameStatFunctions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ba394396077..7bc67f8c398 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -323,34 +323,9 @@ class SparkConnectPlanner(val session: SparkSession) {
     val cols = rel.getColsList.asScala.toArray
     val values = rel.getValuesList.asScala.toArray
     if (values.length == 1) {
-      val value = values.head
-      value.getLiteralTypeCase match {
-        case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-          if (cols.nonEmpty) {
-            dataset.na.fill(value = value.getBoolean, cols = cols).logicalPlan
-          } else {
-            dataset.na.fill(value = value.getBoolean).logicalPlan
-          }
-        case proto.Expression.Literal.LiteralTypeCase.LONG =>
-          if (cols.nonEmpty) {
-            dataset.na.fill(value = value.getLong, cols = cols).logicalPlan
-          } else {
-            dataset.na.fill(value = value.getLong).logicalPlan
-          }
-        case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-          if (cols.nonEmpty) {
-            dataset.na.fill(value = value.getDouble, cols = cols).logicalPlan
-          } else {
-            dataset.na.fill(value = value.getDouble).logicalPlan
-          }
-        case proto.Expression.Literal.LiteralTypeCase.STRING =>
-          if (cols.nonEmpty) {
-            dataset.na.fill(value = value.getString, cols = cols).logicalPlan
-          } else {
-            dataset.na.fill(value = value.getString).logicalPlan
-          }
-        case other => throw InvalidPlanInput(s"Unsupported value type: $other")
-      }
+      val value = LiteralValueProtoConverter.toCatalystValue(values.head)
+      val columns = if (cols.nonEmpty) Some(cols.toSeq) else None
+      dataset.na.fillValue(value, columns).logicalPlan
     } else {
       val valueMap = mutable.Map.empty[String, Any]
       cols.zip(values).foreach { case (col, value) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index cfaf443db5b..91da789cd77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -503,6 +503,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     df.filter(Column(predicate))
   }
 
+  private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = {
+    fillValue(value, cols.map(toAttributes).getOrElse(outputAttributes))
+  }
+
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in the 
specified
    * columns. If a specified column is not a numeric, string or boolean column,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
