This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new caf2a33d380 [SPARK-43230][CONNECT] Simplify
`DataFrameNaFunctions.fillna`
caf2a33d380 is described below
commit caf2a33d380a967e6e528893145782b7e902eaa9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Apr 24 14:37:14 2023 +0800
[SPARK-43230][CONNECT] Simplify `DataFrameNaFunctions.fillna`
### What changes were proposed in this pull request?
Add a helper function in `DataFrameNaFunctions`.
### Why are the changes needed?
To simplify `DataFrameNaFunctions.fillna`.
### Does this PR introduce _any_ user-facing change?
No, this is a dev-only change.
### How was this patch tested?
Existing unit tests.
Closes #40898 from zhengruifeng/connect_simplify_fillna.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../CheckConnectJvmClientCompatibility.scala | 1 +
.../sql/connect/planner/SparkConnectPlanner.scala | 31 +++-------------------
.../apache/spark/sql/DataFrameNaFunctions.scala | 4 +++
3 files changed, 8 insertions(+), 28 deletions(-)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7520588580a..65393780841 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -145,6 +145,7 @@ object CheckConnectJvmClientCompatibility {
// DataFrameNaFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.this"),
+
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"),
// DataFrameStatFunctions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"),
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ba394396077..7bc67f8c398 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -323,34 +323,9 @@ class SparkConnectPlanner(val session: SparkSession) {
val cols = rel.getColsList.asScala.toArray
val values = rel.getValuesList.asScala.toArray
if (values.length == 1) {
- val value = values.head
- value.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
- if (cols.nonEmpty) {
- dataset.na.fill(value = value.getBoolean, cols = cols).logicalPlan
- } else {
- dataset.na.fill(value = value.getBoolean).logicalPlan
- }
- case proto.Expression.Literal.LiteralTypeCase.LONG =>
- if (cols.nonEmpty) {
- dataset.na.fill(value = value.getLong, cols = cols).logicalPlan
- } else {
- dataset.na.fill(value = value.getLong).logicalPlan
- }
- case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
- if (cols.nonEmpty) {
- dataset.na.fill(value = value.getDouble, cols = cols).logicalPlan
- } else {
- dataset.na.fill(value = value.getDouble).logicalPlan
- }
- case proto.Expression.Literal.LiteralTypeCase.STRING =>
- if (cols.nonEmpty) {
- dataset.na.fill(value = value.getString, cols = cols).logicalPlan
- } else {
- dataset.na.fill(value = value.getString).logicalPlan
- }
- case other => throw InvalidPlanInput(s"Unsupported value type: $other")
- }
+ val value = LiteralValueProtoConverter.toCatalystValue(values.head)
+ val columns = if (cols.nonEmpty) Some(cols.toSeq) else None
+ dataset.na.fillValue(value, columns).logicalPlan
} else {
val valueMap = mutable.Map.empty[String, Any]
cols.zip(values).foreach { case (col, value) =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index cfaf443db5b..91da789cd77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -503,6 +503,10 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
df.filter(Column(predicate))
}
+ private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame
= {
+ fillValue(value, cols.map(toAttributes).getOrElse(outputAttributes))
+ }
+
/**
* Returns a new `DataFrame` that replaces null or NaN values in the
specified
* columns. If a specified column is not a numeric, string or boolean column,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]