This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 661064a7a38  [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`
661064a7a38 is described below

commit 661064a7a3811da27da0d5b024764238d2a1fb3f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Nov 29 13:27:33 2022 +0800

    [SPARK-41148][CONNECT][PYTHON] Implement `DataFrame.dropna` and `DataFrame.na.drop`

    ### What changes were proposed in this pull request?
    Implement `DataFrame.dropna` and `DataFrame.na.drop`.

    ### Why are the changes needed?
    For API coverage.

    ### Does this PR introduce _any_ user-facing change?
    Yes, new methods `DataFrame.dropna` and `DataFrame.na.drop`.

    ### How was this patch tested?
    Added unit tests.

    Closes #38819 from zhengruifeng/connect_df_na_drop.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  25 ++++
 .../org/apache/spark/sql/connect/dsl/package.scala |  33 +++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  18 +++
 .../connect/planner/SparkConnectProtoSuite.scala   |  19 +++
 python/pyspark/sql/connect/dataframe.py            |  81 +++++++++++
 python/pyspark/sql/connect/plan.py                 |  39 +++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 158 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  77 ++++++++++
 .../sql/tests/connect/test_connect_basic.py        |  32 +++++
 .../sql/tests/connect/test_connect_plan_only.py    |  16 +++
 10 files changed, 426 insertions(+), 72 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index a676871c9e0..cbdf6311657 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -55,6 +55,7 @@ message Relation {
 
     // NA functions
     NAFill fill_na = 90;
+    NADrop drop_na = 91;
 
     // stat functions
     StatSummary summary = 100;
@@ -440,6 +441,30 @@ message NAFill {
   repeated Expression.Literal values = 3;
 }
 
+
+// Drop rows containing null values.
+// It will invoke 'Dataset.na.drop' (same as 'DataFrameNaFunctions.drop') to compute the results.
+message NADrop {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Optional) Optional list of column names to consider.
+  //
+  // When it is empty, all the columns in the input relation will be considered.
+  repeated string cols = 2;
+
+  // (Optional) The minimum number of non-null and non-NaN values required to keep.
+  //
+  // When not set, it is equivalent to the number of considered columns, which means
+  // a row will be kept only if all columns are non-null.
+  //
+  // 'how' options ('all', 'any') can be easily converted to this field:
+  // - 'all' -> set 'min_non_nulls' 1;
+  // - 'any' -> keep 'min_non_nulls' unset;
+  optional int32 min_non_nulls = 3;
+}
+
+
 // Rename columns on the input relation by the same length of names.
 message RenameColumnsBySameLengthNames {
   // (Required) The input relation of RenameColumnsBySameLengthNames.
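The comments on `min_non_nulls` above describe the whole client-side contract: the wire format carries only `cols` and an optional `min_non_nulls`, and the familiar `how`/`thresh` arguments are folded into that single field before the plan is sent to the server. Below is a minimal sketch of that folding; the helper name `to_min_non_nulls` is hypothetical, and the logic mirrors the argument handling the PySpark client performs in `dataframe.py` further down in this patch.

    from typing import Optional

    def to_min_non_nulls(how: str = "any", thresh: Optional[int] = None) -> Optional[int]:
        # Fold 'how'/'thresh' into NADrop.min_non_nulls; None means "leave the field unset".
        min_non_nulls: Optional[int] = None
        if how == "all":
            min_non_nulls = 1  # keep a row if it has at least one non-null value
        elif how != "any":
            raise ValueError(f"how ('{how}') should be 'any' or 'all'")
        if thresh is not None:
            min_non_nulls = thresh  # 'thresh' overwrites 'how'
        return min_non_nulls

For example, `to_min_non_nulls("all")` returns 1, `to_min_non_nulls("any")` returns None (so the proto field stays unset and the server requires every considered column to be non-null), and an explicit `thresh` always wins over `how`.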
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 61d7abe9e15..dd1c7f0574b 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -278,6 +278,39 @@ package object dsl {
             .build())
           .build()
     }
+
+    def drop(
+        how: Option[String] = None,
+        minNonNulls: Option[Int] = None,
+        cols: Seq[String] = Seq.empty): Relation = {
+      require(!(how.nonEmpty && minNonNulls.nonEmpty))
+      require(how.isEmpty || Seq("any", "all").contains(how.get))
+
+      val dropna = proto.NADrop
+        .newBuilder()
+        .setInput(logicalPlan)
+
+      if (cols.nonEmpty) {
+        dropna.addAllCols(cols.asJava)
+      }
+
+      var _minNonNulls = -1
+      how match {
+        case Some("all") => _minNonNulls = 1
+        case _ =>
+      }
+      if (minNonNulls.nonEmpty) {
+        _minNonNulls = minNonNulls.get
+      }
+      if (_minNonNulls > 0) {
+        dropna.setMinNonNulls(_minNonNulls)
+      }
+
+      Relation
+        .newBuilder()
+        .setDropNa(dropna.build())
+        .build()
+    }
   }
 
   implicit class DslStatFunctions(val logicalPlan: Relation) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b4eaa03df5d..b2b6e6ffc54 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -83,6 +83,7 @@ class SparkConnectPlanner(session: SparkSession) {
         transformSubqueryAlias(rel.getSubqueryAlias)
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
+      case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
@@ -212,6 +213,23 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformNADrop(rel: proto.NADrop): LogicalPlan = {
+    val dataset = Dataset.ofRows(session, transformRelation(rel.getInput))
+
+    val cols = rel.getColsList.asScala.toArray
+
+    (cols.nonEmpty, rel.hasMinNonNulls) match {
+      case (true, true) =>
+        dataset.na.drop(minNonNulls = rel.getMinNonNulls, cols = cols).logicalPlan
+      case (true, false) =>
+        dataset.na.drop(cols = cols).logicalPlan
+      case (false, true) =>
+        dataset.na.drop(minNonNulls = rel.getMinNonNulls).logicalPlan
+      case (false, false) =>
+        dataset.na.drop().logicalPlan
+    }
+  }
+
   private def transformStatSummary(rel: proto.StatSummary): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 6c4297f5437..63bd3eccf17 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -342,6 +342,25 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.na.fill(Map("id" -> 1L, "name" -> "xyz")))
   }
 
Test drop na") { + comparePlans(connectTestRelation.na.drop(), sparkTestRelation.na.drop()) + comparePlans( + connectTestRelation.na.drop(cols = Seq("id")), + sparkTestRelation.na.drop(cols = Seq("id"))) + comparePlans( + connectTestRelation.na.drop(how = Some("all")), + sparkTestRelation.na.drop(how = "all")) + comparePlans( + connectTestRelation.na.drop(how = Some("all"), cols = Seq("id", "name")), + sparkTestRelation.na.drop(how = "all", cols = Seq("id", "name"))) + comparePlans( + connectTestRelation.na.drop(minNonNulls = Some(1)), + sparkTestRelation.na.drop(minNonNulls = 1)) + comparePlans( + connectTestRelation.na.drop(minNonNulls = Some(1), cols = Seq("id", "name")), + sparkTestRelation.na.drop(minNonNulls = 1, cols = Seq("id", "name"))) + } + test("Test summary") { comparePlans( connectTestRelation.summary("count", "mean", "stddev"), diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f157835f4a4..725c7fc90da 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -873,6 +873,77 @@ class DataFrame(object): session=self._session, ) + def dropna( + self, + how: str = "any", + thresh: Optional[int] = None, + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, + ) -> "DataFrame": + """Returns a new :class:`DataFrame` omitting rows with null values. + :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + how : str, optional + 'any' or 'all'. + If 'any', drop a row if it contains any nulls. + If 'all', drop a row only if all its values are null. + thresh: int, optional + default None + If specified, drop rows that have less than `thresh` non-null values. + This overwrites the `how` parameter. + subset : str, tuple or list, optional + optional list of column names to consider. + + Returns + ------- + :class:`DataFrame` + DataFrame with null only rows excluded. + """ + min_non_nulls: Optional[int] = None + + if how is not None: + if not isinstance(how, str): + raise TypeError(f"how should be a str, but got {type(how).__name__}") + if how == "all": + min_non_nulls = 1 + elif how == "any": + min_non_nulls = None + else: + raise ValueError("how ('" + how + "') should be 'any' or 'all'") + + if thresh is not None: + if not isinstance(thresh, int): + raise TypeError(f"thresh should be a int, but got {type(thresh).__name__}") + + # 'thresh' overwrites 'how' + min_non_nulls = thresh + + _cols: List[str] = [] + if subset is not None: + if isinstance(subset, str): + _cols = [subset] + elif isinstance(subset, (tuple, list)): + for c in subset: + if not isinstance(c, str): + raise TypeError( + f"cols should be a str, tuple[str] or list[str], " + f"but got {type(c).__name__}" + ) + _cols = list(subset) + else: + raise TypeError( + f"cols should be a str, tuple[str] or list[str], " + f"but got {type(subset).__name__}" + ) + + return DataFrame.withPlan( + plan.NADrop(child=self._plan, cols=_cols, min_non_nulls=min_non_nulls), + session=self._session, + ) + @property def stat(self) -> "DataFrameStatFunctions": """Returns a :class:`DataFrameStatFunctions` for statistic functions. 
@@ -1312,6 +1383,16 @@ class DataFrameNaFunctions: fill.__doc__ = DataFrame.fillna.__doc__ + def drop( + self, + how: str = "any", + thresh: Optional[int] = None, + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None, + ) -> DataFrame: + return self.df.dropna(how=how, thresh=thresh, subset=subset) + + drop.__doc__ = DataFrame.dropna.__doc__ + class DataFrameStatFunctions: """Functionality for statistic functions with :class:`DataFrame`. diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 805628cfe5b..0f611654ee5 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -995,6 +995,45 @@ class NAFill(LogicalPlan): """ +class NADrop(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + cols: Optional[List[str]], + min_non_nulls: Optional[int], + ) -> None: + super().__init__(child) + + self.cols = cols + self.min_non_nulls = min_non_nulls + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.drop_na.input.CopyFrom(self._child.plan(session)) + if self.cols is not None and len(self.cols) > 0: + plan.drop_na.cols.extend(self.cols) + if self.min_non_nulls is not None: + plan.drop_na.min_non_nulls = self.min_non_nulls + return plan + + def print(self, indent: int = 0) -> str: + i = " " * indent + return f"{i}" f"<NADrop cols='{self.cols}' " f"min_non_nulls='{self.min_non_nulls}'>" + + def _repr_html_(self) -> str: + return f""" + <ul> + <li> + <b>NADrop</b><br /> + Cols: {self.cols} <br /> + Min_non_nulls: {self.min_non_nulls} <br /> + {self._child_repr_()} + </li> + </ul> + """ + + class StatSummary(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None: super().__init__(child) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 5ac2e8d7952..856fbe5b68b 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xfc\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xae\x0c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...] 
) @@ -66,6 +66,7 @@ _SHOWSTRING = DESCRIPTOR.message_types_by_name["ShowString"] _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"] _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"] _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"] +_NADROP = DESCRIPTOR.message_types_by_name["NADrop"] _RENAMECOLUMNSBYSAMELENGTHNAMES = DESCRIPTOR.message_types_by_name["RenameColumnsBySameLengthNames"] _RENAMECOLUMNSBYNAMETONAMEMAP = DESCRIPTOR.message_types_by_name["RenameColumnsByNameToNameMap"] _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = ( @@ -390,6 +391,17 @@ NAFill = _reflection.GeneratedProtocolMessageType( ) _sym_db.RegisterMessage(NAFill) +NADrop = _reflection.GeneratedProtocolMessageType( + "NADrop", + (_message.Message,), + { + "DESCRIPTOR": _NADROP, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.NADrop) + }, +) +_sym_db.RegisterMessage(NADrop) + RenameColumnsBySameLengthNames = _reflection.GeneratedProtocolMessageType( "RenameColumnsBySameLengthNames", (_message.Message,), @@ -431,75 +443,77 @@ if _descriptor._USE_C_DESCRIPTORS == False: _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 1614 - _UNKNOWN._serialized_start = 1616 - _UNKNOWN._serialized_end = 1625 - _RELATIONCOMMON._serialized_start = 1627 - _RELATIONCOMMON._serialized_end = 1676 - _SQL._serialized_start = 1678 - _SQL._serialized_end = 1705 - _READ._serialized_start = 1708 - _READ._serialized_end = 2134 - _READ_NAMEDTABLE._serialized_start = 1850 - _READ_NAMEDTABLE._serialized_end = 1911 - _READ_DATASOURCE._serialized_start = 1914 - _READ_DATASOURCE._serialized_end = 2121 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2052 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2110 - _PROJECT._serialized_start = 2136 - _PROJECT._serialized_end = 2253 - _FILTER._serialized_start = 2255 - _FILTER._serialized_end = 2367 - _JOIN._serialized_start = 2370 - _JOIN._serialized_end = 2820 - _JOIN_JOINTYPE._serialized_start = 2633 - _JOIN_JOINTYPE._serialized_end = 2820 - _SETOPERATION._serialized_start = 2823 - _SETOPERATION._serialized_end = 3219 - _SETOPERATION_SETOPTYPE._serialized_start = 3082 - _SETOPERATION_SETOPTYPE._serialized_end = 3196 - _LIMIT._serialized_start = 3221 - _LIMIT._serialized_end = 3297 - _OFFSET._serialized_start = 3299 - _OFFSET._serialized_end = 3378 - _TAIL._serialized_start = 3380 - _TAIL._serialized_end = 3455 - _AGGREGATE._serialized_start = 3458 - _AGGREGATE._serialized_end = 3668 - _SORT._serialized_start = 3671 - _SORT._serialized_end = 4221 - _SORT_SORTFIELD._serialized_start = 3825 - _SORT_SORTFIELD._serialized_end = 4013 - _SORT_SORTDIRECTION._serialized_start = 4015 - _SORT_SORTDIRECTION._serialized_end = 4123 - _SORT_SORTNULLS._serialized_start = 4125 - _SORT_SORTNULLS._serialized_end = 4207 - _DROP._serialized_start = 4223 - _DROP._serialized_end = 4323 - _DEDUPLICATE._serialized_start = 4326 - _DEDUPLICATE._serialized_end = 4497 - _LOCALRELATION._serialized_start = 4499 - _LOCALRELATION._serialized_end = 4534 - _SAMPLE._serialized_start = 4537 - _SAMPLE._serialized_end = 4761 - _RANGE._serialized_start = 4764 - _RANGE._serialized_end = 4909 - _SUBQUERYALIAS._serialized_start = 4911 - _SUBQUERYALIAS._serialized_end = 5025 - _REPARTITION._serialized_start = 5028 - _REPARTITION._serialized_end = 5170 - _SHOWSTRING._serialized_start = 5173 - 
_SHOWSTRING._serialized_end = 5314 - _STATSUMMARY._serialized_start = 5316 - _STATSUMMARY._serialized_end = 5408 - _STATCROSSTAB._serialized_start = 5410 - _STATCROSSTAB._serialized_end = 5511 - _NAFILL._serialized_start = 5514 - _NAFILL._serialized_end = 5648 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5650 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5764 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5767 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6026 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5959 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6026 + _RELATION._serialized_end = 1664 + _UNKNOWN._serialized_start = 1666 + _UNKNOWN._serialized_end = 1675 + _RELATIONCOMMON._serialized_start = 1677 + _RELATIONCOMMON._serialized_end = 1726 + _SQL._serialized_start = 1728 + _SQL._serialized_end = 1755 + _READ._serialized_start = 1758 + _READ._serialized_end = 2184 + _READ_NAMEDTABLE._serialized_start = 1900 + _READ_NAMEDTABLE._serialized_end = 1961 + _READ_DATASOURCE._serialized_start = 1964 + _READ_DATASOURCE._serialized_end = 2171 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2102 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2160 + _PROJECT._serialized_start = 2186 + _PROJECT._serialized_end = 2303 + _FILTER._serialized_start = 2305 + _FILTER._serialized_end = 2417 + _JOIN._serialized_start = 2420 + _JOIN._serialized_end = 2870 + _JOIN_JOINTYPE._serialized_start = 2683 + _JOIN_JOINTYPE._serialized_end = 2870 + _SETOPERATION._serialized_start = 2873 + _SETOPERATION._serialized_end = 3269 + _SETOPERATION_SETOPTYPE._serialized_start = 3132 + _SETOPERATION_SETOPTYPE._serialized_end = 3246 + _LIMIT._serialized_start = 3271 + _LIMIT._serialized_end = 3347 + _OFFSET._serialized_start = 3349 + _OFFSET._serialized_end = 3428 + _TAIL._serialized_start = 3430 + _TAIL._serialized_end = 3505 + _AGGREGATE._serialized_start = 3508 + _AGGREGATE._serialized_end = 3718 + _SORT._serialized_start = 3721 + _SORT._serialized_end = 4271 + _SORT_SORTFIELD._serialized_start = 3875 + _SORT_SORTFIELD._serialized_end = 4063 + _SORT_SORTDIRECTION._serialized_start = 4065 + _SORT_SORTDIRECTION._serialized_end = 4173 + _SORT_SORTNULLS._serialized_start = 4175 + _SORT_SORTNULLS._serialized_end = 4257 + _DROP._serialized_start = 4273 + _DROP._serialized_end = 4373 + _DEDUPLICATE._serialized_start = 4376 + _DEDUPLICATE._serialized_end = 4547 + _LOCALRELATION._serialized_start = 4549 + _LOCALRELATION._serialized_end = 4584 + _SAMPLE._serialized_start = 4587 + _SAMPLE._serialized_end = 4811 + _RANGE._serialized_start = 4814 + _RANGE._serialized_end = 4959 + _SUBQUERYALIAS._serialized_start = 4961 + _SUBQUERYALIAS._serialized_end = 5075 + _REPARTITION._serialized_start = 5078 + _REPARTITION._serialized_end = 5220 + _SHOWSTRING._serialized_start = 5223 + _SHOWSTRING._serialized_end = 5364 + _STATSUMMARY._serialized_start = 5366 + _STATSUMMARY._serialized_end = 5458 + _STATCROSSTAB._serialized_start = 5460 + _STATCROSSTAB._serialized_end = 5561 + _NAFILL._serialized_start = 5564 + _NAFILL._serialized_end = 5698 + _NADROP._serialized_start = 5701 + _NADROP._serialized_end = 5835 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5837 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5951 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5954 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6213 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6146 + 
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6213 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 6b05621744f..6832db56190 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -82,6 +82,7 @@ class Relation(google.protobuf.message.Message): DROP_FIELD_NUMBER: builtins.int TAIL_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int + DROP_NA_FIELD_NUMBER: builtins.int SUMMARY_FIELD_NUMBER: builtins.int CROSSTAB_FIELD_NUMBER: builtins.int UNKNOWN_FIELD_NUMBER: builtins.int @@ -133,6 +134,8 @@ class Relation(google.protobuf.message.Message): def fill_na(self) -> global___NAFill: """NA functions""" @property + def drop_na(self) -> global___NADrop: ... + @property def summary(self) -> global___StatSummary: """stat functions""" @property @@ -165,6 +168,7 @@ class Relation(google.protobuf.message.Message): drop: global___Drop | None = ..., tail: global___Tail | None = ..., fill_na: global___NAFill | None = ..., + drop_na: global___NADrop | None = ..., summary: global___StatSummary | None = ..., crosstab: global___StatCrosstab | None = ..., unknown: global___Unknown | None = ..., @@ -182,6 +186,8 @@ class Relation(google.protobuf.message.Message): b"deduplicate", "drop", b"drop", + "drop_na", + b"drop_na", "fill_na", b"fill_na", "filter", @@ -241,6 +247,8 @@ class Relation(google.protobuf.message.Message): b"deduplicate", "drop", b"drop", + "drop_na", + b"drop_na", "fill_na", b"fill_na", "filter", @@ -312,6 +320,7 @@ class Relation(google.protobuf.message.Message): "drop", "tail", "fill_na", + "drop_na", "summary", "crosstab", "unknown", @@ -1547,6 +1556,74 @@ class NAFill(google.protobuf.message.Message): global___NAFill = NAFill +class NADrop(google.protobuf.message.Message): + """Drop rows containing null values. + It will invoke 'Dataset.na.drop' (same as 'DataFrameNaFunctions.drop') to compute the results. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + COLS_FIELD_NUMBER: builtins.int + MIN_NON_NULLS_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) The input relation.""" + @property + def cols( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Optional) Optional list of column names to consider. + + When it is empty, all the columns in the input relation will be considered. + """ + min_non_nulls: builtins.int + """(Optional) The minimum number of non-null and non-NaN values required to keep. + + When not set, it is equivalent to the number of considered columns, which means + a row will be kept only if all columns are non-null. + + 'how' options ('all', 'any') can be easily converted to this field: + - 'all' -> set 'min_non_nulls' 1; + - 'any' -> keep 'min_non_nulls' unset; + """ + def __init__( + self, + *, + input: global___Relation | None = ..., + cols: collections.abc.Iterable[builtins.str] | None = ..., + min_non_nulls: builtins.int | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_min_non_nulls", + b"_min_non_nulls", + "input", + b"input", + "min_non_nulls", + b"min_non_nulls", + ], + ) -> builtins.bool: ... 
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_min_non_nulls",
+            b"_min_non_nulls",
+            "cols",
+            b"cols",
+            "input",
+            b"input",
+            "min_non_nulls",
+            b"min_non_nulls",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_min_non_nulls", b"_min_non_nulls"]
+    ) -> typing_extensions.Literal["min_non_nulls"] | None: ...
+
+global___NADrop = NADrop
+
 class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
     """Rename columns on the input relation by the same length of names."""
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0fec48779ef..eb025fd5d04 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -498,6 +498,38 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.spark.sql(query).na.fill({"a": True, "b": 2}).toPandas(),
         )
 
+    def test_drop_na(self):
+        # SPARK-41148: Test drop na
+        query = """
+            SELECT * FROM VALUES
+            (false, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+            AS tab(a, b, c)
+            """
+        # +-----+----+----+
+        # |    a|   b|   c|
+        # +-----+----+----+
+        # |false|   1|null|
+        # |false|null| 2.0|
+        # | null|   3| 3.0|
+        # +-----+----+----+
+
+        self.assert_eq(
+            self.connect.sql(query).dropna().toPandas(),
+            self.spark.sql(query).dropna().toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.drop(how="all", thresh=1).toPandas(),
+            self.spark.sql(query).na.drop(how="all", thresh=1).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+            self.spark.sql(query).dropna(thresh=1, subset=("a", "b")).toPandas(),
+        )
+        self.assert_eq(
+            self.connect.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+            self.spark.sql(query).na.drop(how="any", thresh=2, subset="a").toPandas(),
+        )
+
     def test_empty_dataset(self):
         # SPARK-41005: Test arrow based collection with empty dataset.
         self.assertTrue(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 109702af7e8..367c514d0a8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -102,6 +102,22 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.fill_na.values[1].string, "abc")
         self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
 
+    def test_drop_na(self):
+        # SPARK-41148: Test drop na
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = df.dropna()._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, [])
+        self.assertEqual(plan.root.drop_na.HasField("min_non_nulls"), False)
+
+        plan = df.na.drop(thresh=2, subset=("col_a", "col_b"))._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.drop_na.min_non_nulls, 2)
+
+        plan = df.dropna(how="all", subset="col_c")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.drop_na.cols, ["col_c"])
+        self.assertEqual(plan.root.drop_na.min_non_nulls, 1)
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org