This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 17fac569b4e [SPARK-43660][CONNECT][PS] Enable `resample` with Spark Connect
17fac569b4e is described below
commit 17fac569b4e4b569d41f761db07d7bf112801e0c
Author: itholic <[email protected]>
AuthorDate: Fri Jul 7 10:24:40 2023 +0800
[SPARK-43660][CONNECT][PS] Enable `resample` with Spark Connect
### What changes were proposed in this pull request?
This PR proposes to enable `resample` on Spark Connect.
### Why are the changes needed?
To increase pandas API coverage on Spark Connect
### Does this PR introduce _any_ user-facing change?
`resample` is available on Spark Connect.
### How was this patch tested?
Removed the `unittest.skip` decorators from the resample parity tests; the existing CI should pass.
Closes #41877 from itholic/SPARK-43660.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 5 +++
python/pyspark/pandas/resample.py | 46 ++++++++++++++++------
python/pyspark/pandas/spark/functions.py | 16 ++++++++
.../pandas/tests/connect/test_parity_resample.py | 8 +---
4 files changed, 56 insertions(+), 19 deletions(-)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d3090e8b09b..5fd5f7d4c77 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1798,6 +1798,11 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(NullIndex(children(0)))
+ case "timestampdiff" if fun.getArgumentsCount == 3 =>
+ val children = fun.getArgumentsList.asScala.map(transformExpression)
+ val unit = extractString(children(0), "unit")
+ Some(TimestampDiff(unit, children(1), children(2)))
+
// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
val expr = transformExpression(fun.getArguments(0))
diff --git a/python/pyspark/pandas/resample.py
b/python/pyspark/pandas/resample.py
index 1bd4c075342..c6c6019c07e 100644
--- a/python/pyspark/pandas/resample.py
+++ b/python/pyspark/pandas/resample.py
@@ -26,6 +26,7 @@ from typing import (
Generic,
List,
Optional,
+ Union,
)
import numpy as np
@@ -65,6 +66,8 @@ from pyspark.pandas.utils import (
scol_for,
verify_temp_column_name,
)
+from pyspark.sql.utils import is_remote
+from pyspark.pandas.spark.functions import timestampdiff
class Resampler(Generic[FrameLike], metaclass=ABCMeta):
@@ -131,8 +134,27 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
def _agg_columns_scols(self) -> List[Column]:
return [s.spark.column for s in self._agg_columns]
+    def get_make_interval(
+        self, unit: str, col: Union[Column, int, float]
+    ) -> Column:
+        """
+        Build an interval column of ``col`` units of ``unit``.
+
+        On Spark Connect the interval is built with
+        ``pyspark.sql.connect.functions.make_interval``; otherwise the JVM helper
+        ``PythonSQLUtils.makeInterval`` is called and its result is wrapped in a
+        ``Column`` so the declared return type holds on both code paths.
+
+        Parameters
+        ----------
+        unit : str
+            Interval unit. On Spark Connect it must be one of "YEAR", "MONTH",
+            "WEEK", "DAY", "HOUR", "MINUTE" or "SECOND" (case-insensitive);
+            otherwise it is forwarded unchanged to ``PythonSQLUtils.makeInterval``.
+        col : Column, int or float
+            Number of units; Python numbers are wrapped with ``lit``.
+
+        Returns
+        -------
+        Column
+            The interval column.
+
+        Raises
+        ------
+        ValueError
+            On Spark Connect, if ``unit`` is not one of the supported units.
+        """
+        if is_remote():
+            from pyspark.sql.connect.functions import lit, make_interval
+
+            # ``make_interval`` takes one keyword argument per unit; an unknown
+            # unit must fail loudly instead of falling through to ``None``.
+            unit_to_kwarg = {
+                "YEAR": "years",
+                "MONTH": "months",
+                "WEEK": "weeks",
+                "DAY": "days",
+                "HOUR": "hours",
+                "MINUTE": "mins",
+                "SECOND": "secs",
+            }
+            kwarg = unit_to_kwarg.get(unit.upper())
+            if kwarg is None:
+                raise ValueError(f"Got the unexpected unit '{unit}'.")
+            if isinstance(col, (int, float)):
+                col = lit(col)  # type: ignore[assignment]
+            return make_interval(**{kwarg: col})  # type: ignore
+        else:
+            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+            jcol = col._jc if isinstance(col, Column) else F.lit(col)._jc
+            # Wrap the raw JVM column; Column arithmetic with it is unchanged.
+            return Column(sql_utils.makeInterval(unit, jcol))
+
def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
- sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
origin_scol = F.lit(origin)
(rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
left_closed, right_closed = (self._closed == "left", self._closed ==
"right")
@@ -191,18 +213,18 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
edge_label = truncated_ts_scol
if left_closed and right_labeled:
- edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+ edge_label += self.get_make_interval("MONTH", n)
elif right_closed and left_labeled:
- edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+ edge_label -= self.get_make_interval("MONTH", n)
if left_labeled:
non_edge_label = F.when(
mod == 0,
- truncated_ts_scol - sql_utils.makeInterval("MONTH",
F.lit(n)._jc),
- ).otherwise(truncated_ts_scol -
sql_utils.makeInterval("MONTH", mod._jc))
+ truncated_ts_scol - self.get_make_interval("MONTH", n),
+ ).otherwise(truncated_ts_scol -
self.get_make_interval("MONTH", mod))
else:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
- truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod -
n)._jc)
+ truncated_ts_scol - self.get_make_interval("MONTH", mod -
n)
)
return F.to_timestamp(
@@ -257,7 +279,7 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
unit_str = unit_mapping[rule_code]
truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
- diff = sql_utils.timestampDiff(unit_str, origin_scol._jc,
truncated_ts_scol._jc)
+ diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
mod = F.lit(0) if n == 1 else (diff % F.lit(n))
if rule_code == "H":
@@ -271,19 +293,19 @@ class Resampler(Generic[FrameLike], metaclass=ABCMeta):
edge_label = truncated_ts_scol
if left_closed and right_labeled:
- edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+ edge_label += self.get_make_interval(unit_str, n)
elif right_closed and left_labeled:
- edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+ edge_label -= self.get_make_interval(unit_str, n)
if left_labeled:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
- truncated_ts_scol - sql_utils.makeInterval(unit_str,
mod._jc)
+ truncated_ts_scol - self.get_make_interval(unit_str, mod)
)
else:
non_edge_label = F.when(
mod == 0,
- truncated_ts_scol + sql_utils.makeInterval(unit_str,
F.lit(n)._jc),
- ).otherwise(truncated_ts_scol -
sql_utils.makeInterval(unit_str, (mod - n)._jc))
+ truncated_ts_scol + self.get_make_interval(unit_str, n),
+ ).otherwise(truncated_ts_scol -
self.get_make_interval(unit_str, mod - n))
return F.when(edge_cond, edge_label).otherwise(non_edge_label)
diff --git a/python/pyspark/pandas/spark/functions.py
b/python/pyspark/pandas/spark/functions.py
index 44650fd4d20..d6f6c6fdeeb 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -185,3 +185,19 @@ def null_index(col: Column) -> Column:
else:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+
+
+def timestampdiff(unit: str, start: Column, end: Column) -> Column:
+    """
+    Compute the difference between the timestamps ``start`` and ``end`` in ``unit``.
+
+    On Spark Connect this sends an unresolved ``timestampdiff`` function whose
+    first argument is the unit as a string literal, which the server-side
+    planner turns into a ``TimestampDiff`` expression. Otherwise it calls the
+    JVM helper ``PythonSQLUtils.timestampDiff``.
+
+    Parameters
+    ----------
+    unit : str
+        Unit name, e.g. "HOUR".
+    start : Column
+        Start timestamp column.
+    end : Column
+        End timestamp column.
+
+    Returns
+    -------
+    Column
+    """
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns, lit
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "timestampdiff",
+            lit(unit),
+            start,  # type: ignore[arg-type]
+            end,  # type: ignore[arg-type]
+        )
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py
b/python/pyspark/pandas/tests/connect/test_parity_resample.py
index ab3e7f4410b..e5957cc9b4a 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_resample.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py
@@ -24,13 +24,7 @@ from pyspark.testing.pandasutils import
PandasOnSparkTestUtils, TestUtils
class ResampleTestsParityMixin(
ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils,
ReusedConnectTestCase
):
- @unittest.skip("TODO(SPARK-43660): Enable `resample` with Spark Connect.")
- def test_dataframe_resample(self):
- super().test_dataframe_resample()
-
- @unittest.skip("TODO(SPARK-43660): Enable `resample` with Spark Connect.")
- def test_series_resample(self):
- super().test_series_resample()
+    # No overrides: the resample tests inherited from ResampleTestsMixin now
+    # run unmodified against Spark Connect (via ReusedConnectTestCase).
+ pass
if __name__ == "__main__":
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]