This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ee22158f2a [SPARK-41829][CONNECT][PYTHON] Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
6ee22158f2a is described below

commit 6ee22158f2a1891d39c4274fb6fe96d6fbb6c1fc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jan 5 15:26:17 2023 +0800

    [SPARK-41829][CONNECT][PYTHON] Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
    
    ### What changes were proposed in this pull request?
    Add the missing ordering parameter in `Sort` and `sortWithinPartitions`
    
    ### Why are the changes needed?
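    API coverage: the Spark Connect `DataFrame.sort`/`orderBy` and `sortWithinPartitions` did not accept the `ascending` keyword (or a single list of columns) that the PySpark `DataFrame` API supports.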
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    Re-enabled the `sort` and `sortWithinPartitions` doctests, which were previously skipped under TODO(SPARK-41829).
    
    Closes #39398 from zhengruifeng/connect_fix_41829.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 59 ++++++++++++++++++++++++++++-----
 python/pyspark/sql/connect/plan.py      | 21 +++++-------
 2 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index a22c2cc6421..13a421ca72a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -22,6 +22,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    Sequence,
     TYPE_CHECKING,
     overload,
     Callable,
@@ -44,7 +45,13 @@ from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
-from pyspark.sql.connect.functions import _invoke_function, col, lit, expr as sql_expression
+from pyspark.sql.connect.functions import (
+    _to_col,
+    _invoke_function,
+    col,
+    lit,
+    expr as sql_expression,
+)
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
     DataFrameNaFunctions as PySparkDataFrameNaFunctions,
@@ -342,18 +349,56 @@ class DataFrame:
 
     tail.__doc__ = PySparkDataFrame.tail.__doc__
 
-    def sort(self, *cols: "ColumnOrName") -> "DataFrame":
+    def _sort_cols(
+        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+    ) -> List[Column]:
+        """Return the sort-order Columns, with desc() applied as requested by `ascending`."""
+        if not cols:
+            raise ValueError("should sort by at least one column")
+
+        _cols: List[Column] = []
+        if len(cols) == 1 and isinstance(cols[0], list):
+            _cols = [_to_col(c) for c in cols[0]]
+        else:
+            _cols = [_to_col(cast("ColumnOrName", c)) for c in cols]
+
+        ascending = kwargs.get("ascending", True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                _cols = [c.desc() for c in _cols]
+        elif isinstance(ascending, list):
+            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+        else:
+            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
+
+        return _cols
+
+    def sort(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         return DataFrame.withPlan(
-            plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
+            plan.Sort(
+                self._plan,
+                columns=self._sort_cols(cols, kwargs),
+                is_global=True,
+            ),
+            session=self._session,
         )
 
     sort.__doc__ = PySparkDataFrame.sort.__doc__
 
     orderBy = sort
 
-    def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
+    def sortWithinPartitions(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         return DataFrame.withPlan(
-            plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
+            plan.Sort(
+                self._plan,
+                columns=self._sort_cols(cols, kwargs),
+                is_global=False,
+            ),
+            session=self._session,
         )
 
     sortWithinPartitions.__doc__ = PySparkDataFrame.sortWithinPartitions.__doc__
@@ -1440,10 +1485,6 @@ def _test() -> None:
         # TODO(SPARK-41827): groupBy requires all cols be Column or str
         del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
 
-        # TODO(SPARK-41829): Add Dataframe sort ordering
-        del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
-        del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
-
         # TODO(SPARK-41830): fix sample parameters
         del pyspark.sql.connect.dataframe.DataFrame.sample.__doc__
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index a6d1ad4068b..1973755be27 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -496,27 +496,24 @@ class Sort(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        columns: List["ColumnOrName"],
+        columns: List[Column],
         is_global: bool,
     ) -> None:
         super().__init__(child)
+
+        assert all(isinstance(c, Column) for c in columns)
+        assert isinstance(is_global, bool)
+
         self.columns = columns
         self.is_global = is_global
 
     def _convert_col(
-        self, col: "ColumnOrName", session: "SparkConnectClient"
+        self, col: Column, session: "SparkConnectClient"
     ) -> proto.Expression.SortOrder:
-        sort: Optional[SortOrder] = None
-        if isinstance(col, Column):
-            if isinstance(col._expr, SortOrder):
-                sort = col._expr
-            else:
-                sort = SortOrder(col._expr)
+        if isinstance(col._expr, SortOrder):
+            return col._expr.to_plan(session).sort_order
         else:
-            sort = SortOrder(ColumnReference(name=col))
-        assert sort is not None
-
-        return sort.to_plan(session).sort_order
+            return SortOrder(col._expr).to_plan(session).sort_order
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to