This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 17fac569b4e [SPARK-43660][CONNECT][PS] Enable `resample` with Spark Connect
17fac569b4e is described below

commit 17fac569b4e4b569d41f761db07d7bf112801e0c
Author: itholic <[email protected]>
AuthorDate: Fri Jul 7 10:24:40 2023 +0800

    [SPARK-43660][CONNECT][PS] Enable `resample` with Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to enable `resample` on Spark Connect.
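    Specifically, the `Resampler` internals now go through Connect-compatible expressions: a `get_make_interval` helper that uses `make_interval` on the Connect client, and a new `timestampdiff` function in `pyspark.pandas.spark.functions` backed by a `timestampdiff` case in `SparkConnectPlanner`, instead of calling `PythonSQLUtils` through the JVM gateway.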
    
    ### Why are the changes needed?
    
    To increase pandas API coverage on Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `resample` is now available on Spark Connect.
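
    A minimal usage sketch (illustrative data and frequency, not taken from this patch), which now also works when pandas API on Spark runs over a Spark Connect session:

    ```python
    import pandas as pd
    import pyspark.pandas as ps

    # Illustrative only: daily data aggregated into 3-day bins.
    pdf = pd.DataFrame(
        {"A": range(6)},
        index=pd.date_range("2023-01-01", periods=6, freq="D"),
    )
    psdf = ps.from_pandas(pdf)
    psdf.resample("3D").sum()
    ```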
    
    ### How was this patch tested?
    
    Removed the `@unittest.skip` markers from the parity tests; the existing CI should pass.
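    For example, the parity suite can be run locally with `python/run-tests --testnames 'pyspark.pandas.tests.connect.test_parity_resample'` (a standard dev-workflow invocation, assuming a local Spark build).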
    
    Closes #41877 from itholic/SPARK-43660.
    
    Authored-by: itholic <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 +++
 python/pyspark/pandas/resample.py                  | 46 ++++++++++++++++------
 python/pyspark/pandas/spark/functions.py           | 16 ++++++++
 .../pandas/tests/connect/test_parity_resample.py   |  8 +---
 4 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d3090e8b09b..5fd5f7d4c77 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1798,6 +1798,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         Some(NullIndex(children(0)))
 
+      case "timestampdiff" if fun.getArgumentsCount == 3 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        val unit = extractString(children(0), "unit")
+        Some(TimestampDiff(unit, children(1), children(2)))
+
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
diff --git a/python/pyspark/pandas/resample.py b/python/pyspark/pandas/resample.py
index 1bd4c075342..c6c6019c07e 100644
--- a/python/pyspark/pandas/resample.py
+++ b/python/pyspark/pandas/resample.py
@@ -26,6 +26,7 @@ from typing import (
     Generic,
     List,
     Optional,
+    Union,
 )
 
 import numpy as np
@@ -65,6 +66,8 @@ from pyspark.pandas.utils import (
     scol_for,
     verify_temp_column_name,
 )
+from pyspark.sql.utils import is_remote
+from pyspark.pandas.spark.functions import timestampdiff
 
 
 class Resampler(Generic[FrameLike], metaclass=ABCMeta):
@@ -131,8 +134,27 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
     def _agg_columns_scols(self) -> List[Column]:
         return [s.spark.column for s in self._agg_columns]
 
+    def get_make_interval(  # type: ignore[return]
+        self, unit: str, col: Union[Column, int, float]
+    ) -> Column:
+        if is_remote():
+            from pyspark.sql.connect.functions import lit, make_interval
+
+            col = col if not isinstance(col, (int, float)) else lit(col)  # type: ignore[assignment]
+            if unit == "MONTH":
+                return make_interval(months=col)  # type: ignore
+            if unit == "HOUR":
+                return make_interval(hours=col)  # type: ignore
+            if unit == "MINUTE":
+                return make_interval(mins=col)  # type: ignore
+            if unit == "SECOND":
+                return make_interval(secs=col)  # type: ignore
+        else:
+            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+            col = col._jc if isinstance(col, Column) else F.lit(col)._jc
+            return sql_utils.makeInterval(unit, col)
+
     def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
-        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
         origin_scol = F.lit(origin)
         (rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
         left_closed, right_closed = (self._closed == "left", self._closed == "right")
@@ -191,18 +213,18 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
             truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
             edge_label = truncated_ts_scol
             if left_closed and right_labeled:
-                edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+                edge_label += self.get_make_interval("MONTH", n)
             elif right_closed and left_labeled:
-                edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+                edge_label -= self.get_make_interval("MONTH", n)
 
             if left_labeled:
                 non_edge_label = F.when(
                     mod == 0,
-                    truncated_ts_scol - sql_utils.makeInterval("MONTH", F.lit(n)._jc),
-                ).otherwise(truncated_ts_scol - sql_utils.makeInterval("MONTH", mod._jc))
+                    truncated_ts_scol - self.get_make_interval("MONTH", n),
+                ).otherwise(truncated_ts_scol - self.get_make_interval("MONTH", mod))
             else:
                 non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
-                    truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod - n)._jc)
+                    truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
                 )
 
             return F.to_timestamp(
@@ -257,7 +279,7 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
             unit_str = unit_mapping[rule_code]
 
             truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
-            diff = sql_utils.timestampDiff(unit_str, origin_scol._jc, truncated_ts_scol._jc)
+            diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
             mod = F.lit(0) if n == 1 else (diff % F.lit(n))
 
             if rule_code == "H":
@@ -271,19 +293,19 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
 
             edge_label = truncated_ts_scol
             if left_closed and right_labeled:
-                edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+                edge_label += self.get_make_interval(unit_str, n)
             elif right_closed and left_labeled:
-                edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+                edge_label -= self.get_make_interval(unit_str, n)
 
             if left_labeled:
                 non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
-                    truncated_ts_scol - sql_utils.makeInterval(unit_str, mod._jc)
+                    truncated_ts_scol - self.get_make_interval(unit_str, mod)
                 )
             else:
                 non_edge_label = F.when(
                     mod == 0,
-                    truncated_ts_scol + sql_utils.makeInterval(unit_str, F.lit(n)._jc),
-                ).otherwise(truncated_ts_scol - sql_utils.makeInterval(unit_str, (mod - n)._jc))
+                    truncated_ts_scol + self.get_make_interval(unit_str, n),
+                ).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))
 
             return F.when(edge_cond, edge_label).otherwise(non_edge_label)
 
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 44650fd4d20..d6f6c6fdeeb 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -185,3 +185,19 @@ def null_index(col: Column) -> Column:
     else:
         sc = SparkContext._active_spark_context
         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "timestampdiff",
+            lit(unit),
+            start,  # type: ignore[arg-type]
+            end,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py b/python/pyspark/pandas/tests/connect/test_parity_resample.py
index ab3e7f4410b..e5957cc9b4a 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_resample.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py
@@ -24,13 +24,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 class ResampleTestsParityMixin(
     ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43660): Enable `resample` with Spark Connect.")
-    def test_dataframe_resample(self):
-        super().test_dataframe_resample()
-
-    @unittest.skip("TODO(SPARK-43660): Enable `resample` with Spark Connect.")
-    def test_series_resample(self):
-        super().test_series_resample()
+    pass
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
