This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6ee22158f2a [SPARK-41829][CONNECT][PYTHON] Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
6ee22158f2a is described below
commit 6ee22158f2a1891d39c4274fb6fe96d6fbb6c1fc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jan 5 15:26:17 2023 +0800
[SPARK-41829][CONNECT][PYTHON] Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
### What changes were proposed in this pull request?
Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
### Why are the changes needed?
API coverage
### Does this PR introduce _any_ user-facing change?
Yes, `sort` and `sortWithinPartitions` in Spark Connect now accept ordering arguments.
### How was this patch tested?
Enabled the existing doctests for `sort` and `sortWithinPartitions`.
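As an illustration (not part of the patch), the enabled doctests exercise the same ordering forms PySpark's `sort` documents; `df` below is a placeholder DataFrame:

```python
# Illustrative only: the ordering arguments PySpark's DataFrame.sort
# supports, which Spark Connect now accepts as well.
df.sort("age", ascending=False)                    # single column, descending
df.sort(["name", "age"], ascending=[True, False])  # per-column ascending flags
df.sortWithinPartitions(df.age.desc())             # partition-local ordering
```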
Closes #39398 from zhengruifeng/connect_fix_41829.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 59 ++++++++++++++++++++++++++++-----
 python/pyspark/sql/connect/plan.py      | 21 +++++-------
 2 files changed, 59 insertions(+), 21 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index a22c2cc6421..13a421ca72a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -22,6 +22,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    Sequence,
     TYPE_CHECKING,
     overload,
     Callable,
@@ -44,7 +45,13 @@ from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
-from pyspark.sql.connect.functions import _invoke_function, col, lit, expr as sql_expression
+from pyspark.sql.connect.functions import (
+    _to_col,
+    _invoke_function,
+    col,
+    lit,
+    expr as sql_expression,
+)
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
     DataFrameNaFunctions as PySparkDataFrameNaFunctions,
@@ -342,18 +349,56 @@ class DataFrame:
 
     tail.__doc__ = PySparkDataFrame.tail.__doc__
 
-    def sort(self, *cols: "ColumnOrName") -> "DataFrame":
+    def _sort_cols(
+        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+    ) -> List[Column]:
+        """Return a JVM Seq of Columns that describes the sort order"""
+        if cols is None:
+            raise ValueError("should sort by at least one column")
+
+        _cols: List[Column] = []
+        if len(cols) == 1 and isinstance(cols[0], list):
+            _cols = [_to_col(c) for c in cols[0]]
+        else:
+            _cols = [_to_col(cast("ColumnOrName", c)) for c in cols]
+
+        ascending = kwargs.get("ascending", True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                _cols = [c.desc() for c in _cols]
+        elif isinstance(ascending, list):
+            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+        else:
+            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
+
+        return _cols
+
+    def sort(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         return DataFrame.withPlan(
-            plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
+            plan.Sort(
+                self._plan,
+                columns=self._sort_cols(cols, kwargs),
+                is_global=True,
+            ),
+            session=self._session,
         )
 
     sort.__doc__ = PySparkDataFrame.sort.__doc__
 
     orderBy = sort
 
-    def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
+    def sortWithinPartitions(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         return DataFrame.withPlan(
-            plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
+            plan.Sort(
+                self._plan,
+                columns=self._sort_cols(cols, kwargs),
+                is_global=False,
+            ),
+            session=self._session,
         )
 
     sortWithinPartitions.__doc__ = PySparkDataFrame.sortWithinPartitions.__doc__
@@ -1440,10 +1485,6 @@ def _test() -> None:
 
     # TODO(SPARK-41827): groupBy requires all cols be Column or str
     del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
 
-    # TODO(SPARK-41829): Add Dataframe sort ordering
-    del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
-    del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
-
     # TODO(SPARK-41830): fix sample parameters
     del pyspark.sql.connect.dataframe.DataFrame.sample.__doc__
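The core of the dataframe.py change is `_sort_cols`' handling of the `ascending` keyword. A minimal standalone sketch of that normalization rule, using plain strings in place of Connect `Column` objects (the names here are illustrative, not the real API):

```python
from typing import List, Union

def normalize_sort(cols: List[str], ascending: Union[bool, List[bool]] = True) -> List[str]:
    # Mirrors _sort_cols' rule: a single flag flips every column, while a
    # list of flags is zipped pairwise with the columns.
    if isinstance(ascending, (bool, int)):
        return cols if ascending else [c + " DESC" for c in cols]
    elif isinstance(ascending, list):
        return [c if asc else c + " DESC" for asc, c in zip(ascending, cols)]
    else:
        raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))

assert normalize_sort(["name", "age"], ascending=[True, False]) == ["name", "age DESC"]
assert normalize_sort(["name", "age"], ascending=False) == ["name DESC", "age DESC"]
```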
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index a6d1ad4068b..1973755be27 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -496,27 +496,24 @@ class Sort(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        columns: List["ColumnOrName"],
+        columns: List[Column],
         is_global: bool,
     ) -> None:
         super().__init__(child)
+
+        assert all(isinstance(c, Column) for c in columns)
+        assert isinstance(is_global, bool)
+
         self.columns = columns
         self.is_global = is_global
 
     def _convert_col(
-        self, col: "ColumnOrName", session: "SparkConnectClient"
+        self, col: Column, session: "SparkConnectClient"
     ) -> proto.Expression.SortOrder:
-        sort: Optional[SortOrder] = None
-        if isinstance(col, Column):
-            if isinstance(col._expr, SortOrder):
-                sort = col._expr
-            else:
-                sort = SortOrder(col._expr)
+        if isinstance(col._expr, SortOrder):
+            return col._expr.to_plan(session).sort_order
         else:
-            sort = SortOrder(ColumnReference(name=col))
-        assert sort is not None
-
-        return sort.to_plan(session).sort_order
+            return SortOrder(col._expr).to_plan(session).sort_order
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
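Because `_sort_cols` now guarantees that `Sort` receives `Column` objects, `_convert_col` only has to check whether the column's expression is already a `SortOrder` (e.g. produced by `.desc()`). A toy model of that dispatch, with stand-in classes rather than the real Connect types:

```python
class Expression:
    pass

class SortOrder(Expression):
    # Stand-in for Connect's SortOrder: wraps a child expression with an
    # implicitly ascending ordering.
    def __init__(self, child: Expression) -> None:
        self.child = child

class Column:
    def __init__(self, expr: Expression) -> None:
        self._expr = expr

def convert_col(col: Column) -> SortOrder:
    # Reuse an explicit ordering (e.g. from Column.desc()); otherwise wrap
    # the bare expression in a default ascending SortOrder.
    if isinstance(col._expr, SortOrder):
        return col._expr
    return SortOrder(col._expr)
```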
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]