This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new aa568354725c [SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package
aa568354725c is described below
commit aa568354725ce44fc0261973b97597ab0986edb1
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Apr 12 09:02:47 2024 +0900
[SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package
### What changes were proposed in this pull request?
This PR proposes to extend the `pyspark-connect` scheduled job to run ML tests as well.
### Why are the changes needed?
To make sure the pure Python library works with ML.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
Tested in my fork:
https://github.com/HyukjinKwon/spark/actions/runs/8643632135/job/23697401430
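For reference, a rough local reproduction of the scheduled job's new steps (a sketch only, assuming a Spark Connect server is already running and `SPARK_CONNECT_TESTING_REMOTE` points at it, as the workflow change below sets up):

    python packaging/connect/setup.py sdist
    pip install dist/pyspark-connect-*.tar.gz scikit-learn torch torchvision torcheval
    ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect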
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45941 from HyukjinKwon/test-ps-ci.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.github/workflows/build_python_connect.yml | 3 +-
python/packaging/connect/setup.py | 1 +
python/pyspark/ml/connect/classification.py | 1 -
python/pyspark/ml/param/__init__.py | 7 +-
.../tests/connect/test_connect_classification.py | 10 +-
.../ml/tests/connect/test_connect_evaluation.py | 5 +-
.../ml/tests/connect/test_connect_feature.py | 5 +-
.../ml/tests/connect/test_connect_function.py | 2 +
.../ml/tests/connect/test_connect_pipeline.py | 11 +-
.../ml/tests/connect/test_connect_summarizer.py | 5 +-
.../ml/tests/connect/test_connect_tuning.py | 9 +-
.../connect/test_legacy_mode_classification.py | 8 +-
.../tests/connect/test_legacy_mode_evaluation.py | 9 +-
.../ml/tests/connect/test_legacy_mode_feature.py | 6 +-
.../ml/tests/connect/test_legacy_mode_pipeline.py | 6 +-
.../tests/connect/test_legacy_mode_summarizer.py | 6 +-
.../ml/tests/connect/test_legacy_mode_tuning.py | 9 +-
.../tests/connect/test_parity_torch_data_loader.py | 28 ++-
.../tests/connect/test_parity_torch_distributor.py | 232 +++++++++++----------
19 files changed, 218 insertions(+), 145 deletions(-)
diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index ec7103e5dbeb..8deee026131e 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -72,6 +72,7 @@ jobs:
python packaging/connect/setup.py sdist
cd dist
pip install pyspark-connect-*.tar.gz
+ pip install scikit-learn torch torchvision torcheval
- name: Run tests
env:
SPARK_CONNECT_TESTING_REMOTE: sc://localhost
@@ -82,7 +83,7 @@ jobs:
          # Remove Py4J and PySpark zipped library to make sure there is no JVM connection
          rm python/lib/*
          rm -r python/pyspark
-          ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect
+          ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect
- name: Upload test results to report
if: always()
uses: actions/upload-artifact@v4
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index 3514e5cdc422..419ed36b4236 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@ if "SPARK_TESTING" in os.environ:
"pyspark.sql.tests.connect.shell",
"pyspark.sql.tests.pandas",
"pyspark.sql.tests.streaming",
+ "pyspark.ml.tests.connect",
]
try:
diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index 8d8c6227eac3..fc7b5cda88a2 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -320,7 +320,6 @@ class LogisticRegressionModel(
def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
import torch
-
import torch.nn as torch_nn
model_state_dict = self.torch_model.state_dict()
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 345b7f7a5964..f32ead2a580c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -30,8 +30,8 @@ from typing import (
)
import numpy as np
-from py4j.java_gateway import JavaObject
+from pyspark.util import is_remote_only
from pyspark.ml.linalg import DenseVector, Vector, Matrix
from pyspark.ml.util import Identifiable
@@ -516,9 +516,12 @@ class Params(Identifiable, metaclass=ABCMeta):
"""
Sets default params.
"""
+ if not is_remote_only():
+ from py4j.java_gateway import JavaObject
+
for param, value in kwargs.items():
p = getattr(self, param)
- if value is not None and not isinstance(value, JavaObject):
+            if value is not None and (is_remote_only() or not isinstance(value, JavaObject)):
try:
value = p.typeConverter(value)
except TypeError as e:
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index ebc1745874d9..8083090523a0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -17,7 +17,9 @@
#
import unittest
+import os
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -33,13 +35,15 @@ if should_test_connect:
@unittest.skipIf(
- not should_test_connect or not have_torch,
- connect_requirement_message or torch_requirement_message,
+ not should_test_connect or not have_torch or is_remote_only(),
+ connect_requirement_message
+ or torch_requirement_message
+ or "Requires PySpark core library in Spark Connect server",
)
class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
- SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 7f3b6bd0198c..359a77bbcb20 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import os
import unittest
from pyspark.sql import SparkSession
@@ -36,7 +37,9 @@ if should_test_connect:
)
class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
- self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+ self.spark = SparkSession.builder.remote(
+ os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+ ).getOrCreate()
def tearDown(self) -> None:
self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 04b1744c4995..c786ce2f87d0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import os
import unittest
from pyspark.sql import SparkSession
@@ -38,7 +39,9 @@ if should_test_connect:
)
class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
def setUp(self) -> None:
- self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+ self.spark = SparkSession.builder.remote(
+ os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+ ).getOrCreate()
def tearDown(self) -> None:
self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py
index b38d415e2bb2..f50376110660 100644
--- a/python/pyspark/ml/tests/connect/test_connect_function.py
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -17,6 +17,7 @@
import os
import unittest
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql.dataframe import DataFrame as SDF
from pyspark.ml import functions as SF
@@ -32,6 +33,7 @@ if should_test_connect:
from pyspark.ml.connect import functions as CF
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 45d19f2bcdde..4105f593f170 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,9 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import os
import unittest
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -34,13 +35,15 @@ except ImportError:
@unittest.skipIf(
- not should_test_connect or not have_torch,
- connect_requirement_message or torch_requirement_message,
+ not should_test_connect or not have_torch or is_remote_only(),
+ connect_requirement_message
+ or torch_requirement_message
+ or "Requires PySpark core library in Spark Connect server",
)
class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
- SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 866a3468388d..1cfd2ed229e5 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -16,6 +16,7 @@
#
import unittest
+import os
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -27,7 +28,9 @@ if should_test_connect:
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
def setUp(self) -> None:
- self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+ self.spark = SparkSession.builder.remote(
+ os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+ ).getOrCreate()
def tearDown(self) -> None:
self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index 7b10d91da064..d5fcb93099b6 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -17,7 +17,9 @@
#
import unittest
+import os
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -25,11 +27,14 @@ if should_test_connect:
    from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "Requires PySpark core library in Spark Connect server",
+)
class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
- SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index db9a29804808..dc2642a42d66 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -21,14 +21,17 @@ import unittest
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
have_torch = True
+torch_requirement_message = None
try:
import torch # noqa: F401
except ImportError:
have_torch = False
+ torch_requirement_message = "No torch found"
if should_test_connect:
from pyspark.ml.connect.classification import (
@@ -228,7 +231,10 @@ class ClassificationTestsMixin:
@unittest.skipIf(
-    not should_test_connect or not have_torch, connect_requirement_message or "No torch found"
+ not should_test_connect or not have_torch or is_remote_only(),
+ connect_requirement_message
+ or torch_requirement_message
+ or "pyspark-connect cannot test classic Spark",
)
class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index ae01031ff462..11c1f9aeee51 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -20,14 +20,17 @@ import tempfile
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
have_torcheval = True
+torcheval_requirement_message = None
try:
import torcheval # noqa: F401
except ImportError:
have_torcheval = False
+ torcheval_requirement_message = "torcheval is required"
if should_test_connect:
from pyspark.ml.connect.evaluation import (
@@ -177,8 +180,10 @@ class EvaluationTestsMixin:
@unittest.skipIf(
- not should_test_connect or not have_torcheval,
- connect_requirement_message or "torcheval is required",
+ not should_test_connect or not have_torcheval or is_remote_only(),
+ connect_requirement_message
+ or torcheval_requirement_message
+ or "pyspark-connect cannot test classic Spark",
)
class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 9565b3a09a5b..4915d4706b87 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -23,6 +23,7 @@ import unittest
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -194,7 +195,10 @@ class FeatureTestsMixin:
assembler2.transform(pandas_df)["out"].tolist()
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+ not should_test_connect or is_remote_only(),
+ connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
class FeatureTests(FeatureTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 104aff17e0b2..692144148af0 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -21,6 +21,7 @@ import unittest
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -167,7 +168,10 @@ class PipelineTestsMixin:
assert lorv2.getOrDefault(lorv2.maxIter) == 200
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+ not should_test_connect or is_remote_only(),
+ connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
class PipelineTests(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
index 7f09eb9f0742..253632a74c97 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
@@ -20,6 +20,7 @@ import unittest
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -62,7 +63,10 @@ class SummarizerTestsMixin:
assert_dict_allclose(result_local, expected_result)
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+ not should_test_connect or is_remote_only(),
+ connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 7f26788c137f..14f52d75e6d6 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -22,6 +22,7 @@ import sys
import numpy as np
+from pyspark.util import is_remote_only
from pyspark.ml.param import Param, Params
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession
@@ -29,10 +30,13 @@ from pyspark.sql.functions import rand
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
have_sklearn = True
+sklearn_requirement_message = None
try:
from sklearn.datasets import load_breast_cancer # noqa: F401
except ImportError:
have_sklearn = False
+ sklearn_requirement_message = "No sklearn found"
+
if should_test_connect:
import pandas as pd
@@ -279,7 +283,10 @@ class CrossValidatorTestsMixin:
@unittest.skipIf(
-    not should_test_connect or not have_sklearn, connect_requirement_message or "No sklearn found"
+ not should_test_connect or not have_sklearn or is_remote_only(),
+ connect_requirement_message
+ or sklearn_requirement_message
+ or "pyspark-connect cannot test classic Spark",
)
class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 1984efdc6c6e..462fe3822141 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -17,24 +17,30 @@
import unittest
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
+torch_requirement_message = None
have_torch = True
try:
import torch # noqa: F401
except ImportError:
have_torch = False
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
- def setUp(self) -> None:
- self.spark = (
- SparkSession.builder.remote("local[1]")
- .config("spark.default.parallelism", "1")
- .getOrCreate()
- )
+ torch_requirement_message = "torch is required"
+
+if not is_remote_only():
+    from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+    class TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
+ def setUp(self) -> None:
+ self.spark = (
+ SparkSession.builder.remote("local[1]")
+ .config("spark.default.parallelism", "1")
+ .getOrCreate()
+ )
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index 70aa80ba6d11..e40303ae9ce2 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -19,124 +19,134 @@ import os
import shutil
import unittest
+torch_requirement_message = None
have_torch = True
try:
import torch # noqa: F401
except ImportError:
have_torch = False
+ torch_requirement_message = "torch is required"
+from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_distributor import (
- TorchDistributorBaselineUnitTestsMixin,
- TorchDistributorLocalUnitTestsMixin,
- TorchDistributorDistributedUnitTestsMixin,
- TorchWrapperUnitTestsMixin,
- set_up_test_dirs,
- get_local_mode_conf,
- get_distributed_mode_conf,
-)
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTestsOnConnect(
- TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
-):
- @classmethod
- def setUpClass(cls):
- cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
-
- @classmethod
- def tearDownClass(cls):
- cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsOnConnect(
- TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
- @classmethod
- def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
- builder = SparkSession.builder.appName(cls.__name__)
- for k, v in get_local_mode_conf().items():
- builder = builder.config(k, v)
- builder = builder.config(
- "spark.driver.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
- )
- cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.mnist_dir_path)
- os.unlink(cls.gpu_discovery_script_file_name)
- cls.spark.stop()
-
- def _get_inputs_for_test_local_training_succeeds(self):
- return [
- ("0,1,2", 1, True, "0,1,2"),
- ("0,1,2", 3, True, "0,1,2"),
- ("0,1,2", 2, False, "0,1,2"),
- (None, 3, False, "NONE"),
- ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsIIOnConnect(
- TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
- @classmethod
- def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
- builder = SparkSession.builder.appName(cls.__name__)
- for k, v in get_local_mode_conf().items():
- builder = builder.config(k, v)
-
- builder = builder.config(
- "spark.driver.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
- )
- cls.spark = builder.remote("local[4]").getOrCreate()
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.mnist_dir_path)
- os.unlink(cls.gpu_discovery_script_file_name)
- cls.spark.stop()
-
- def _get_inputs_for_test_local_training_succeeds(self):
- return [
- ("0,1,2", 1, True, "0,1,2"),
- ("0,1,2", 3, True, "0,1,2"),
- ("0,1,2", 2, False, "0,1,2"),
- (None, 3, False, "NONE"),
- ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorDistributedUnitTestsOnConnect(
- TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
-):
- @classmethod
- def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
- builder = SparkSession.builder.appName(cls.__name__)
- for k, v in get_distributed_mode_conf().items():
- builder = builder.config(k, v)
-
- builder = builder.config(
- "spark.worker.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
- )
- cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.mnist_dir_path)
- os.unlink(cls.gpu_discovery_script_file_name)
- cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
- pass
+
+if not is_remote_only():
+ from pyspark.ml.torch.tests.test_distributor import (
+ TorchDistributorBaselineUnitTestsMixin,
+ TorchDistributorLocalUnitTestsMixin,
+ TorchDistributorDistributedUnitTestsMixin,
+ TorchWrapperUnitTestsMixin,
+ set_up_test_dirs,
+ get_local_mode_conf,
+ get_distributed_mode_conf,
+ )
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+ class TorchDistributorBaselineUnitTestsOnConnect(
+ TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+ ):
+ @classmethod
+ def setUpClass(cls):
+ cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.spark.stop()
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+ class TorchDistributorLocalUnitTestsOnConnect(
+ TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+ ):
+ @classmethod
+ def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+ builder = SparkSession.builder.appName(cls.__name__)
+ for k, v in get_local_mode_conf().items():
+ builder = builder.config(k, v)
+ builder = builder.config(
+ "spark.driver.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
+ )
+ cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.mnist_dir_path)
+ os.unlink(cls.gpu_discovery_script_file_name)
+ cls.spark.stop()
+
+ def _get_inputs_for_test_local_training_succeeds(self):
+ return [
+ ("0,1,2", 1, True, "0,1,2"),
+ ("0,1,2", 3, True, "0,1,2"),
+ ("0,1,2", 2, False, "0,1,2"),
+ (None, 3, False, "NONE"),
+ ]
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+ class TorchDistributorLocalUnitTestsIIOnConnect(
+ TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+ ):
+ @classmethod
+ def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+ builder = SparkSession.builder.appName(cls.__name__)
+ for k, v in get_local_mode_conf().items():
+ builder = builder.config(k, v)
+
+ builder = builder.config(
+ "spark.driver.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
+ )
+ cls.spark = builder.remote("local[4]").getOrCreate()
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.mnist_dir_path)
+ os.unlink(cls.gpu_discovery_script_file_name)
+ cls.spark.stop()
+
+ def _get_inputs_for_test_local_training_succeeds(self):
+ return [
+ ("0,1,2", 1, True, "0,1,2"),
+ ("0,1,2", 3, True, "0,1,2"),
+ ("0,1,2", 2, False, "0,1,2"),
+ (None, 3, False, "NONE"),
+ ]
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+ class TorchDistributorDistributedUnitTestsOnConnect(
+ TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+ ):
+ @classmethod
+ def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+ builder = SparkSession.builder.appName(cls.__name__)
+ for k, v in get_distributed_mode_conf().items():
+ builder = builder.config(k, v)
+
+ builder = builder.config(
+ "spark.worker.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
+ )
+ cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.mnist_dir_path)
+ os.unlink(cls.gpu_discovery_script_file_name)
+ cls.spark.stop()
+
+ @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+ )
+    class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
+ pass
if __name__ == "__main__":