This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d86208d043a [SPARK-43620][CONNECT][PS] Fix Pandas APIs that depend on unsupported features
d86208d043a is described below

commit d86208d043a27a02d3c4ccf6a929c7e6b8ad0292
Author: Haejoon Lee <[email protected]>
AuthorDate: Wed Oct 4 12:22:34 2023 +0900

    [SPARK-43620][CONNECT][PS] Fix Pandas APIs that depend on unsupported features
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to fix the Pandas APIs that depend on unsupported PySpark features.
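
    For illustration, here is a minimal sketch of the rewrite pattern this PR applies (the DataFrame and category values below are made up, not taken from the patch): a `F.create_map(...)[column]` lookup needs `Column.__getitem__` with a `Column` key, which Spark Connect does not support, so it is replaced with an equivalent chain of `F.when(...).otherwise(...)` expressions.

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(0,), (1,), (2,)], ["code"])
    categories = ["a", "b", "c"]

    # Before: a map lookup keyed by another Column. This relies on
    # Column.__getitem__ accepting a Column, which SparkConnectColumn lacks:
    #   F.create_map(F.lit(0), F.lit("a"), F.lit(1), F.lit("b"), ...)[df.code]

    # After: an equivalent CASE WHEN chain, built back to front so that
    # unmatched codes fall through to NULL.
    scol = F.lit(None)
    for code, category in reversed(list(enumerate(categories))):
        scol = F.when(df.code == F.lit(code), F.lit(category)).otherwise(scol)

    df.select(scol.alias("category")).show()
    ```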
    
    ### Why are the changes needed?
    
    To increase the API coverage for Pandas API on Spark with Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Pandas data type APIs such as `astype` and `factorize` are now supported on Spark Connect.
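
    For example, the following now runs over a Spark Connect session (the Series values here are illustrative):

    ```python
    import pyspark.pandas as ps

    psser = ps.Series(["a", "b", "a", None])

    codes, uniques = psser.factorize()  # previously skipped on Spark Connect
    casted = psser.astype("category")   # likewise supported now
    ```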
    
    ### How was this patch tested?
    
    By enabling the existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43120 from itholic/SPARK-43620.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/base.py                      | 25 +++++++--------
 python/pyspark/pandas/data_type_ops/base.py        | 26 ++++++++++++----
 .../pandas/data_type_ops/categorical_ops.py        |  7 ++---
 .../data_type_ops/test_parity_binary_ops.py        |  4 +--
 .../data_type_ops/test_parity_boolean_ops.py       |  4 ---
 .../data_type_ops/test_parity_categorical_ops.py   | 12 --------
 .../connect/data_type_ops/test_parity_date_ops.py  |  4 ---
 .../data_type_ops/test_parity_datetime_ops.py      |  4 ---
 .../connect/data_type_ops/test_parity_null_ops.py  |  4 +--
 .../connect/data_type_ops/test_parity_num_ops.py   |  4 ---
 .../data_type_ops/test_parity_string_ops.py        |  4 ---
 .../data_type_ops/test_parity_timedelta_ops.py     |  4 ---
 .../tests/connect/indexes/test_parity_base.py      |  4 ---
 .../tests/connect/indexes/test_parity_category.py  | 36 +---------------------
 .../tests/connect/series/test_parity_compute.py    |  4 ---
 .../tests/connect/test_parity_categorical.py       | 24 ---------------
 16 files changed, 39 insertions(+), 131 deletions(-)

diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index ed6e983fdc8..fa513e8b9b6 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -1704,16 +1704,10 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
             if len(categories) == 0:
                 scol = F.lit(None)
             else:
-                kvs = list(
-                    chain(
-                        *[
-                            (F.lit(code), F.lit(category))
-                            for code, category in enumerate(categories)
-                        ]
-                    )
-                )
-                map_scol = F.create_map(*kvs)
-                scol = map_scol[self.spark.column]
+                scol = F.lit(None)
+                for code, category in reversed(list(enumerate(categories))):
+                    scol = F.when(self.spark.column == F.lit(code), F.lit(category)).otherwise(scol)
+
             codes, uniques = self._with_new_scol(
                 scol.alias(self._internal.data_spark_column_names[0])
             ).factorize(use_na_sentinel=use_na_sentinel)
@@ -1761,9 +1755,16 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         if len(kvs) == 0:  # uniques are all missing values
             new_scol = F.lit(na_sentinel_code)
         else:
-            map_scol = F.create_map(*kvs)
             null_scol = F.when(self.isnull().spark.column, F.lit(na_sentinel_code))
-            new_scol = null_scol.otherwise(map_scol[self.spark.column])
+            mapped_scol = None
+            for i in range(0, len(kvs), 2):
+                key = kvs[i]
+                value = kvs[i + 1]
+                if mapped_scol is None:
+                    mapped_scol = F.when(self.spark.column == key, value)
+                else:
+                    mapped_scol = mapped_scol.when(self.spark.column == key, value)
+            new_scol = null_scol.otherwise(mapped_scol)
 
         codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0]))
 
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 5d497a55a5f..4f57aa65be7 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -17,7 +17,6 @@
 
 import numbers
 from abc import ABCMeta
-from itertools import chain
 from typing import Any, Optional, Union
 
 import numpy as np
@@ -130,12 +129,27 @@ def _as_categorical_type(
         if len(categories) == 0:
             scol = F.lit(-1)
         else:
-            kvs = chain(
-                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
-            )
-            map_scol = F.create_map(*kvs)
+            scol = F.lit(-1)
+            if isinstance(
+                index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
+            ):
+                from pyspark.sql.functions import base64
+
+                stringified_column = base64(index_ops.spark.column)
+                for code, category in enumerate(categories):
+                    # Convert each category to base64 before comparison
+                    base64_category = F.base64(F.lit(category))
+                    scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
+                        scol
+                    )
+            else:
+                stringified_column = F.format_string("%s", index_ops.spark.column)
+
+                for code, category in enumerate(categories):
+                    scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
+                        scol
+                    )
 
-            scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
         return index_ops._with_new_scol(
             scol.cast(spark_type),
             field=index_ops._internal.data_fields[0].copy(
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 824666b5819..bbaded42be9 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-from itertools import chain
 from typing import cast, Any, Union
 
 import pandas as pd
@@ -135,7 +134,7 @@ def _to_cat(index_ops: IndexOpsLike) -> IndexOpsLike:
     if len(categories) == 0:
         scol = F.lit(None)
     else:
-        kvs = chain(*[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)])
-        map_scol = F.create_map(*kvs)
-        scol = map_scol[index_ops.spark.column]
+        scol = F.lit(None)
+        for code, category in reversed(list(enumerate(categories))):
+            scol = F.when(index_ops.spark.column == F.lit(code), F.lit(category)).otherwise(scol)
     return index_ops._with_new_scol(scol)
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
index 663c0007389..29b13868e03 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py
@@ -25,9 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 class BinaryOpsParityTests(
     BinaryOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
index 52d517967eb..9ad2aa0ad17 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py
@@ -30,10 +30,6 @@ class BooleanOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_boolean_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
index b680e5b3d79..1b4dabdb045 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
@@ -30,18 +30,6 @@ class CategoricalOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_eq(self):
-        super().test_eq()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_ne(self):
-        super().test_ne()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_categorical_ops import *
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
index e7b1c7de70d..baa3180baaa 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py
@@ -30,10 +30,6 @@ class DateOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_date_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
index 6d081b10aba..2641e3a32dc 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py
@@ -30,10 +30,6 @@ class DatetimeOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_datetime_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
index 63b53c02fd7..5df4c791c98 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
@@ -25,9 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 class NullOpsParityTests(
     NullOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
index 04aa24c4045..56eba708c94 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py
@@ -30,10 +30,6 @@ class NumOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
index ecbf94a6bde..f507756a7a4 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
@@ -30,10 +30,6 @@ class StringOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_string_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
index 058dd2bfd3f..edd29fa1ed2 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py
@@ -30,10 +30,6 @@ class TimedeltaOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.data_type_ops.test_parity_timedelta_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
index 3cf4dc9b3d2..8f1f2d2221c 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_base.py
@@ -29,10 +29,6 @@ class IndexesParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
     @unittest.skip("TODO(SPARK-43704): Enable 
IndexesParityTests.test_to_series.")
     def test_to_series(self):
         super().test_to_series()
diff --git a/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py b/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
index d99d013306f..aed7df26202 100644
--- a/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
+++ b/python/pyspark/pandas/tests/connect/indexes/test_parity_category.py
@@ -24,41 +24,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 class CategoricalIndexParityTests(
     CategoricalIndexTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_append(self):
-        super().test_append()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_intersection(self):
-        super().test_intersection()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_categories(self):
-        super().test_remove_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_unused_categories(self):
-        super().test_remove_unused_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_reorder_categories(self):
-        super().test_reorder_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_set_categories(self):
-        super().test_set_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_union(self):
-        super().test_union()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
index 31916f12b4e..8876fcb1398 100644
--- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
@@ -24,10 +24,6 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
     pass
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.series.test_parity_compute import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
index 210cfce8ddb..ca880aef572 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
@@ -29,30 +29,6 @@ class CategoricalParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_categories(self):
-        super().test_remove_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_unused_categories(self):
-        super().test_remove_unused_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_reorder_categories(self):
-        super().test_reorder_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_set_categories(self):
-        super().test_set_categories()
-
 
 if __name__ == "__main__":
    from pyspark.pandas.tests.connect.test_parity_categorical import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
