This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 001da5d003c [SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions
001da5d003c is described below
commit 001da5d003caef3cda9978d35967ade55837e0bc
Author: itholic <[email protected]>
AuthorDate: Sun May 28 08:44:16 2023 +0800
[SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions
### What changes were proposed in this pull request?
This PR is a follow-up for SPARK-43671, refining the functions to use the
`pyspark_column_op` util to clean up the code.
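For context, a minimal sketch of what such a helper plausibly does, reconstructed from the `is_remote` branches this diff removes; the actual `pyspark_column_op` implementation in `pyspark.sql.utils` may differ in details:

```python
from typing import Callable


def pyspark_column_op(func_name: str) -> Callable:
    # Resolve the correct Column class once, instead of repeating
    # this branch in every CategoricalOps comparison method.
    from pyspark.sql.utils import is_remote

    if is_remote():
        # Spark Connect client
        from pyspark.sql.connect.column import Column
    else:
        # classic PySpark
        from pyspark.sql.column import Column

    from pyspark.pandas.base import column_op

    # Look up the operator (e.g. Column.__eq__) by name and wrap it
    # with column_op so it applies to pandas-on-Spark operands.
    return column_op(getattr(Column, func_name))
```

With a helper like this, each comparison method collapses to a single line, e.g. `_compare(left, right, "__eq__", is_equality_comparison=True)`.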
### Why are the changes needed?
To avoid scattering `is_remote` checks across too many places, which eases future maintenance.
### Does this PR introduce _any_ user-facing change?
No, it's a code cleanup.
### How was this patch tested?
The existing CI should pass.
Closes #41326 from itholic/categorical_followup.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../pandas/data_type_ops/categorical_ops.py | 69 +++++-----------------
1 file changed, 14 insertions(+), 55 deletions(-)
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 9f14a4b1ee7..66e181a6079 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -16,19 +16,18 @@
#
from itertools import chain
-from typing import cast, Any, Callable, Union
+from typing import cast, Any, Union
import pandas as pd
import numpy as np
from pandas.api.types import is_list_like, CategoricalDtype # type: ignore[attr-defined]
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
-from pyspark.pandas.base import column_op, IndexOpsMixin
+from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.data_type_ops.base import _sanitize_list_like, DataTypeOps
from pyspark.pandas.typedef import pandas_on_spark_type
from pyspark.sql import functions as F
-from pyspark.sql.column import Column as PySparkColumn
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import pyspark_column_op
class CategoricalOps(DataTypeOps):
@@ -66,73 +65,33 @@ class CategoricalOps(DataTypeOps):
def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(
- left, right, Column.__eq__, is_equality_comparison=True # type: ignore[arg-type]
- )
+ return _compare(left, right, "__eq__", is_equality_comparison=True)
def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(
- left, right, Column.__ne__, is_equality_comparison=True # type: ignore[arg-type]
- )
+ return _compare(left, right, "__ne__", is_equality_comparison=True)
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(left, right, Column.__lt__) # type: ignore[arg-type]
+ return _compare(left, right, "__lt__")
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(left, right, Column.__le__) # type: ignore[arg-type]
+ return _compare(left, right, "__le__")
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(left, right, Column.__gt__) # type: ignore[arg-type]
+ return _compare(left, right, "__gt__")
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
- return _compare(left, right, Column.__ge__) # type: ignore[arg-type]
+ return _compare(left, right, "__ge__")
def _compare(
left: IndexOpsLike,
right: Any,
- f: Callable[..., PySparkColumn],
+ func_name: str,
*,
is_equality_comparison: bool = False,
) -> SeriesOrIndex:
@@ -143,7 +102,7 @@ def _compare(
----------
left: A Categorical operand
right: The other operand to compare with
- f : The Spark Column function to apply
+ func_name: The Spark Column function name to apply
is_equality_comparison: True if it is equality comparison, ie. == or !=.
False by default.
Returns
@@ -158,15 +117,15 @@ def _compare(
if hash(left.dtype) != hash(right.dtype):
raise TypeError("Categoricals can only be compared if 'categories'
are the same.")
if cast(CategoricalDtype, left.dtype).ordered:
- return column_op(f)(left, right)
+ return pyspark_column_op(func_name)(left, right)
else:
- return column_op(f)(_to_cat(left), _to_cat(right))
+ return pyspark_column_op(func_name)(_to_cat(left), _to_cat(right))
elif not is_list_like(right):
categories = cast(CategoricalDtype, left.dtype).categories
if right not in categories:
raise TypeError("Cannot compare a Categorical with a scalar, which
is not a category.")
right_code = categories.get_loc(right)
- return column_op(f)(left, right_code)
+ return pyspark_column_op(func_name)(left, right_code)
else:
raise TypeError("Cannot compare a Categorical with the given type.")