This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 55d6b51af7cd [SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests`
55d6b51af7cd is described below
commit 55d6b51af7cd9108752eea65e7eef13da01118e8
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Dec 10 08:53:19 2023 +0800
[SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests`
### What changes were proposed in this pull request?
Reorganize `EWMTests`
### Why are the changes needed?
Break `EWMTests` into smaller files to be consistent with the pandas tests (see
https://github.com/pandas-dev/pandas/tree/main/pandas/tests/window).
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44273 from zhengruifeng/ps_test_ewm.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 8 +-
.../{test_parity_ewm.py => window/__init__.py} | 21 --
.../test_parity_ewm_error.py} | 11 +-
.../test_parity_ewm_mean.py} | 11 +-
.../test_parity_groupby_ewm_mean.py} | 11 +-
.../test_parity_ewm.py => window/__init__.py} | 21 --
.../pyspark/pandas/tests/window/test_ewm_error.py | 97 ++++++++++
.../pyspark/pandas/tests/window/test_ewm_mean.py | 194 +++++++++++++++++++
.../test_groupby_ewm_mean.py} | 215 +--------------------
9 files changed, 328 insertions(+), 261 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index e67cfce0f5c0..ca35fdabc0c4 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -729,7 +729,9 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.test_default_index",
"pyspark.pandas.tests.test_expanding",
"pyspark.pandas.tests.test_extension",
- "pyspark.pandas.tests.test_ewm",
+ "pyspark.pandas.tests.window.test_ewm_error",
+ "pyspark.pandas.tests.window.test_ewm_mean",
+ "pyspark.pandas.tests.window.test_groupby_ewm_mean",
"pyspark.pandas.tests.test_frame_spark",
"pyspark.pandas.tests.test_generic_functions",
"pyspark.pandas.tests.test_frame_interpolate",
@@ -1113,7 +1115,9 @@ pyspark_pandas_connect_part2 = Module(
"pyspark.pandas.tests.connect.test_parity_series_interpolate",
"pyspark.pandas.tests.connect.resample.test_parity_frame",
"pyspark.pandas.tests.connect.resample.test_parity_series",
- "pyspark.pandas.tests.connect.test_parity_ewm",
+ "pyspark.pandas.tests.connect.window.test_parity_ewm_error",
+ "pyspark.pandas.tests.connect.window.test_parity_ewm_mean",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean",
"pyspark.pandas.tests.connect.test_parity_rolling",
"pyspark.pandas.tests.connect.test_parity_expanding",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
b/python/pyspark/pandas/tests/connect/window/__init__.py
similarity index 53%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/__init__.py
index 748728203337..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/__init__.py
@@ -14,24 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-
-
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase, TestUtils):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
index 748728203337..7f6b0e8494cf 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
@@ -16,17 +16,22 @@
#
import unittest
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_ewm_error import EWMErrorMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase, TestUtils):
+class EWMParityErrorTests(
+ EWMErrorMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+ TestUtils,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_ewm_error import * #
noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
index 748728203337..8c7144799bce 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
@@ -16,17 +16,22 @@
#
import unittest
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_ewm_mean import EWMMeanMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase, TestUtils):
+class EWMParityMeanTests(
+ EWMMeanMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+ TestUtils,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_ewm_mean import * #
noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
similarity index 79%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to
python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
index 748728203337..76254698b757 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
@@ -16,17 +16,22 @@
#
import unittest
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_groupby_ewm_mean import
GroupByEWMMeanMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase, TestUtils):
+class EWMParityGroupByMeanTests(
+ GroupByEWMMeanMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+ TestUtils,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean
import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
b/python/pyspark/pandas/tests/window/__init__.py
similarity index 53%
rename from python/pyspark/pandas/tests/connect/test_parity_ewm.py
rename to python/pyspark/pandas/tests/window/__init__.py
index 748728203337..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/window/__init__.py
@@ -14,24 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-
-
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils,
ReusedConnectTestCase, TestUtils):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_ewm_error.py
b/python/pyspark/pandas/tests/window/test_ewm_error.py
new file mode 100644
index 000000000000..02018fb10617
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_ewm_error.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.window import ExponentialMoving
+
+
+class EWMErrorMixin:
+ def test_ewm_error(self):
+ with self.assertRaisesRegex(
+ TypeError, "psdf_or_psser must be a series or dataframe; however,
got:.*int"
+ ):
+ ExponentialMoving(1, 2)
+
+ psdf = ps.range(10)
+
+ with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+ psdf.ewm(min_periods=-1, alpha=0.5).mean()
+
+ with self.assertRaisesRegex(ValueError, "com must be >= 0"):
+ psdf.ewm(com=-0.1).mean()
+
+ with self.assertRaisesRegex(ValueError, "span must be >= 1"):
+ psdf.ewm(span=0.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
+ psdf.ewm(halflife=0).mean()
+
+ with self.assertRaisesRegex(ValueError, "alpha must be in"):
+ psdf.ewm(alpha=1.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "Must pass one of com, span,
halflife, or alpha"):
+ psdf.ewm().mean()
+
+ with self.assertRaisesRegex(
+ ValueError, "com, span, halflife, and alpha are mutually exclusive"
+ ):
+ psdf.ewm(com=0.5, alpha=0.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+ psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean()
+
+ with self.assertRaisesRegex(ValueError, "com must be >= 0"):
+ psdf.groupby(psdf.id).ewm(com=-0.1).mean()
+
+ with self.assertRaisesRegex(ValueError, "span must be >= 1"):
+ psdf.groupby(psdf.id).ewm(span=0.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
+ psdf.groupby(psdf.id).ewm(halflife=0).mean()
+
+ with self.assertRaisesRegex(ValueError, "alpha must be in"):
+ psdf.groupby(psdf.id).ewm(alpha=1.7).mean()
+
+ with self.assertRaisesRegex(ValueError, "Must pass one of com, span,
halflife, or alpha"):
+ psdf.groupby(psdf.id).ewm().mean()
+
+ with self.assertRaisesRegex(
+ ValueError, "com, span, halflife, and alpha are mutually exclusive"
+ ):
+ psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean()
+
+
+class EWMErrorTests(
+ EWMErrorMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.window.test_ewm_error import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_ewm_mean.py
b/python/pyspark/pandas/tests/window/test_ewm_mean.py
new file mode 100644
index 000000000000..00750b867610
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_ewm_mean.py
@@ -0,0 +1,194 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class EWMMeanMixin:
+ def _test_ewm_func(self, f):
+ pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(getattr(psser.ewm(com=0.2), f)(),
getattr(pser.ewm(com=0.2), f)())
+ self.assert_eq(
+ getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2),
f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(span=1.7), f)(),
getattr(pser.ewm(span=1.7), f)())
+ self.assert_eq(
+ getattr(psser.ewm(span=1.7), f)().sum(),
getattr(pser.ewm(span=1.7), f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(),
getattr(pser.ewm(halflife=0.5), f)())
+ self.assert_eq(
+ getattr(psser.ewm(halflife=0.5), f)().sum(),
getattr(pser.ewm(halflife=0.5), f)().sum()
+ )
+ self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(),
getattr(pser.ewm(alpha=0.7), f)())
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7), f)().sum(),
getattr(pser.ewm(alpha=0.7), f)().sum()
+ )
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7, min_periods=2), f)(),
+ getattr(pser.ewm(alpha=0.7, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ )
+
+ pdf = pd.DataFrame(
+ {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(getattr(psdf.ewm(com=0.2), f)(),
getattr(pdf.ewm(com=0.2), f)())
+ self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(),
getattr(pdf.ewm(com=0.2), f)().sum())
+ self.assert_eq(getattr(psdf.ewm(span=1.7), f)(),
getattr(pdf.ewm(span=1.7), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7),
f)().sum()
+ )
+ self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(),
getattr(pdf.ewm(halflife=0.5), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5), f)().sum(),
getattr(pdf.ewm(halflife=0.5), f)().sum()
+ )
+ self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(),
getattr(pdf.ewm(alpha=0.7), f)())
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7), f)().sum(),
getattr(pdf.ewm(alpha=0.7), f)().sum()
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(),
+ getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+ )
+
+ pdf = pd.DataFrame(
+ {
+ "s1": [None, 2, 3, 4],
+ "s2": [1, None, 3, 4],
+ "s3": [1, 3, 4, 5],
+ "s4": [1, 0, 3, 4],
+ "s5": [None, None, 1, None],
+ "s6": [None, None, None, None],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(
+ getattr(psdf.ewm(com=0.2, ignore_na=True), f)(),
+ getattr(pdf.ewm(com=0.2, ignore_na=True), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+ getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(com=0.2, ignore_na=False), f)(),
+ getattr(pdf.ewm(com=0.2, ignore_na=False), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+ getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7, ignore_na=True), f)(),
+ getattr(pdf.ewm(span=1.7, ignore_na=True), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+ getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7, ignore_na=False), f)(),
+ getattr(pdf.ewm(span=1.7, ignore_na=False), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+ getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(),
+ getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+ getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(),
+ getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+ getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2),
f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2),
f)().sum(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+ )
+ self.assert_eq(
+ getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2),
f)().sum(),
+ getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2),
f)().sum(),
+ )
+
+ def test_ewm_mean(self):
+ self._test_ewm_func("mean")
+
+
+class EWMMeanTests(
+ EWMMeanMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.window.test_ewm_mean import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_ewm.py
b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
similarity index 52%
rename from python/pyspark/pandas/tests/test_ewm.py
rename to python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
index a8886a0af69c..fb29cb8ea04f 100644
--- a/python/pyspark/pandas/tests/test_ewm.py
+++ b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
@@ -19,214 +19,9 @@ import pandas as pd
import pyspark.pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-from pyspark.pandas.window import ExponentialMoving
-class EWMTestsMixin:
- def test_ewm_error(self):
- with self.assertRaisesRegex(
- TypeError, "psdf_or_psser must be a series or dataframe; however,
got:.*int"
- ):
- ExponentialMoving(1, 2)
-
- psdf = ps.range(10)
-
- with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
- psdf.ewm(min_periods=-1, alpha=0.5).mean()
-
- with self.assertRaisesRegex(ValueError, "com must be >= 0"):
- psdf.ewm(com=-0.1).mean()
-
- with self.assertRaisesRegex(ValueError, "span must be >= 1"):
- psdf.ewm(span=0.7).mean()
-
- with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
- psdf.ewm(halflife=0).mean()
-
- with self.assertRaisesRegex(ValueError, "alpha must be in"):
- psdf.ewm(alpha=1.7).mean()
-
- with self.assertRaisesRegex(ValueError, "Must pass one of com, span,
halflife, or alpha"):
- psdf.ewm().mean()
-
- with self.assertRaisesRegex(
- ValueError, "com, span, halflife, and alpha are mutually exclusive"
- ):
- psdf.ewm(com=0.5, alpha=0.7).mean()
-
- with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
- psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean()
-
- with self.assertRaisesRegex(ValueError, "com must be >= 0"):
- psdf.groupby(psdf.id).ewm(com=-0.1).mean()
-
- with self.assertRaisesRegex(ValueError, "span must be >= 1"):
- psdf.groupby(psdf.id).ewm(span=0.7).mean()
-
- with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
- psdf.groupby(psdf.id).ewm(halflife=0).mean()
-
- with self.assertRaisesRegex(ValueError, "alpha must be in"):
- psdf.groupby(psdf.id).ewm(alpha=1.7).mean()
-
- with self.assertRaisesRegex(ValueError, "Must pass one of com, span,
halflife, or alpha"):
- psdf.groupby(psdf.id).ewm().mean()
-
- with self.assertRaisesRegex(
- ValueError, "com, span, halflife, and alpha are mutually exclusive"
- ):
- psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean()
-
- def _test_ewm_func(self, f):
- pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a")
- psser = ps.from_pandas(pser)
- self.assert_eq(getattr(psser.ewm(com=0.2), f)(),
getattr(pser.ewm(com=0.2), f)())
- self.assert_eq(
- getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2),
f)().sum()
- )
- self.assert_eq(getattr(psser.ewm(span=1.7), f)(),
getattr(pser.ewm(span=1.7), f)())
- self.assert_eq(
- getattr(psser.ewm(span=1.7), f)().sum(),
getattr(pser.ewm(span=1.7), f)().sum()
- )
- self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(),
getattr(pser.ewm(halflife=0.5), f)())
- self.assert_eq(
- getattr(psser.ewm(halflife=0.5), f)().sum(),
getattr(pser.ewm(halflife=0.5), f)().sum()
- )
- self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(),
getattr(pser.ewm(alpha=0.7), f)())
- self.assert_eq(
- getattr(psser.ewm(alpha=0.7), f)().sum(),
getattr(pser.ewm(alpha=0.7), f)().sum()
- )
- self.assert_eq(
- getattr(psser.ewm(alpha=0.7, min_periods=2), f)(),
- getattr(pser.ewm(alpha=0.7, min_periods=2), f)(),
- )
- self.assert_eq(
- getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(),
- getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(),
- )
-
- pdf = pd.DataFrame(
- {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]},
index=np.random.rand(4)
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(getattr(psdf.ewm(com=0.2), f)(),
getattr(pdf.ewm(com=0.2), f)())
- self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(),
getattr(pdf.ewm(com=0.2), f)().sum())
- self.assert_eq(getattr(psdf.ewm(span=1.7), f)(),
getattr(pdf.ewm(span=1.7), f)())
- self.assert_eq(
- getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7),
f)().sum()
- )
- self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(),
getattr(pdf.ewm(halflife=0.5), f)())
- self.assert_eq(
- getattr(psdf.ewm(halflife=0.5), f)().sum(),
getattr(pdf.ewm(halflife=0.5), f)().sum()
- )
- self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(),
getattr(pdf.ewm(alpha=0.7), f)())
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7), f)().sum(),
getattr(pdf.ewm(alpha=0.7), f)().sum()
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(),
- getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
- getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
- )
-
- pdf = pd.DataFrame(
- {
- "s1": [None, 2, 3, 4],
- "s2": [1, None, 3, 4],
- "s3": [1, 3, 4, 5],
- "s4": [1, 0, 3, 4],
- "s5": [None, None, 1, None],
- "s6": [None, None, None, None],
- }
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(
- getattr(psdf.ewm(com=0.2, ignore_na=True), f)(),
- getattr(pdf.ewm(com=0.2, ignore_na=True), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(),
- getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(com=0.2, ignore_na=False), f)(),
- getattr(pdf.ewm(com=0.2, ignore_na=False), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(),
- getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(span=1.7, ignore_na=True), f)(),
- getattr(pdf.ewm(span=1.7, ignore_na=True), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(),
- getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(span=1.7, ignore_na=False), f)(),
- getattr(pdf.ewm(span=1.7, ignore_na=False), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(),
- getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(),
- getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
- getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(),
- getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
- getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2),
f)().sum(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2),
f)().sum(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
- )
- self.assert_eq(
- getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2),
f)().sum(),
- getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2),
f)().sum(),
- )
-
- def test_ewm_mean(self):
- self._test_ewm_func("mean")
-
+class GroupByEWMMeanMixin:
def _test_groupby_ewm_func(self, f):
pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
psser = ps.from_pandas(pser)
@@ -417,13 +212,17 @@ class EWMTestsMixin:
self._test_groupby_ewm_func("mean")
-class EWMTests(EWMTestsMixin, PandasOnSparkTestCase, TestUtils):
+class GroupByEWMMeanTests(
+ GroupByEWMMeanMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
pass
if __name__ == "__main__":
import unittest
- from pyspark.pandas.tests.test_ewm import * # noqa: F401
+ from pyspark.pandas.tests.window.test_groupby_ewm_mean import * # noqa:
F401
try:
import xmlrunner
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]