This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 190c50451333 [SPARK-50388][PYTHON][TESTS] Further centralize import checks
190c50451333 is described below

commit 190c50451333c5f6be349defcac1ad1983632935
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Nov 22 10:22:30 2024 +0800

    [SPARK-50388][PYTHON][TESTS] Further centralize import checks
    
    ### What changes were proposed in this pull request?
    Further centralize import checks:
    1. Move the `have_xxx` flags from `sqlutils.py`, `pandasutils.py`, and other testing modules to `utils.py`;
    2. Still expose `have_pandas` and `have_pyarrow` from `sqlutils.py` (by importing them from `utils.py`), because they are used in too many places to update all call sites.
    
    ### Why are the changes needed?
    Simplify the import checks; e.g. `have_plotly` was previously defined in multiple places.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is a test-only change.
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48926 from zhengruifeng/py_dep_2.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/pandas/tests/io/test_io.py          |  7 ++--
 .../tests/plot/test_frame_plot_matplotlib.py       |  8 ++---
 .../pandas/tests/plot/test_frame_plot_plotly.py    |  8 ++---
 .../pyspark/pandas/tests/plot/test_series_plot.py  |  2 +-
 .../tests/plot/test_series_plot_matplotlib.py      |  8 ++---
 .../pandas/tests/plot/test_series_plot_plotly.py   |  8 ++---
 .../pyspark/pandas/tests/series/test_conversion.py |  2 +-
 python/pyspark/sql/metrics.py                      |  6 ++--
 python/pyspark/sql/tests/connect/test_df_debug.py  |  7 ++--
 python/pyspark/sql/tests/plot/test_frame_plot.py   |  4 +--
 .../sql/tests/plot/test_frame_plot_plotly.py       |  4 +--
 python/pyspark/testing/pandasutils.py              | 23 ------------
 python/pyspark/testing/sqlutils.py                 | 41 ++++++----------------
 python/pyspark/testing/utils.py                    | 33 +++++++++++++++++
 14 files changed, 64 insertions(+), 97 deletions(-)

diff --git a/python/pyspark/pandas/tests/io/test_io.py 
b/python/pyspark/pandas/tests/io/test_io.py
index d4e61319f229..6fbdc366dd76 100644
--- a/python/pyspark/pandas/tests/io/test_io.py
+++ b/python/pyspark/pandas/tests/io/test_io.py
@@ -22,12 +22,9 @@ import numpy as np
 import pandas as pd
 
 from pyspark import pandas as ps
-from pyspark.testing.pandasutils import (
-    have_tabulate,
-    PandasOnSparkTestCase,
-    tabulate_requirement_message,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.testing.utils import have_tabulate, tabulate_requirement_message
 
 
 # This file contains test cases for 'Serialization / IO / Conversion'
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py 
b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
index 365d34b1f550..1d63cafe19b4 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
@@ -24,12 +24,8 @@ import numpy as np
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_matplotlib,
-    matplotlib_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_matplotlib, 
matplotlib_requirement_message
 
 if have_matplotlib:
     import matplotlib
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py 
b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
index 8d197649aaeb..530893257333 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
@@ -23,12 +23,8 @@ import numpy as np
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_plotly,
-    plotly_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 from pyspark.pandas.utils import name_like_string
 
 if have_plotly:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py 
b/python/pyspark/pandas/tests/plot/test_series_plot.py
index 6e0bdd232fc4..61d114f37b0e 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot.py
@@ -22,7 +22,7 @@ import numpy as np
 
 from pyspark import pandas as ps
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase
-from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 
 
 class SeriesPlotTestsMixin:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py 
b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
index c98c1aeea04e..0fdcbc9d748e 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
@@ -24,12 +24,8 @@ import pandas as pd
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
-from pyspark.testing.pandasutils import (
-    have_matplotlib,
-    matplotlib_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_matplotlib, 
matplotlib_requirement_message
 
 if have_matplotlib:
     import matplotlib
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py 
b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
index 1aa175f9308a..8123af26dbf4 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
@@ -24,12 +24,8 @@ import numpy as np
 from pyspark import pandas as ps
 from pyspark.pandas.config import set_option, reset_option
 from pyspark.pandas.utils import name_like_string
-from pyspark.testing.pandasutils import (
-    have_plotly,
-    plotly_requirement_message,
-    PandasOnSparkTestCase,
-    TestUtils,
-)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.testing.utils import have_plotly, plotly_requirement_message
 
 if have_plotly:
     from plotly import express
diff --git a/python/pyspark/pandas/tests/series/test_conversion.py 
b/python/pyspark/pandas/tests/series/test_conversion.py
index 71ae858631d4..7711d05abd76 100644
--- a/python/pyspark/pandas/tests/series/test_conversion.py
+++ b/python/pyspark/pandas/tests/series/test_conversion.py
@@ -21,7 +21,7 @@ import pandas as pd
 from pyspark import pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.testing.pandasutils import have_tabulate, 
tabulate_requirement_message
+from pyspark.testing.utils import have_tabulate, tabulate_requirement_message
 
 
 class SeriesConversionMixin:
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
index 0f4142e91b25..4ab9b041e313 100644
--- a/python/pyspark/sql/metrics.py
+++ b/python/pyspark/sql/metrics.py
@@ -21,10 +21,10 @@ from typing import Optional, List, Tuple, Dict, Any, Union, 
TYPE_CHECKING, Seque
 from pyspark.errors import PySparkValueError
 
 if TYPE_CHECKING:
-    from pyspark.testing.connectutils import have_graphviz
-
-    if have_graphviz:
+    try:
         import graphviz  # type: ignore
+    except ImportError:
+        pass
 
 
 class ObservedMetrics(abc.ABC):
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py 
b/python/pyspark/sql/tests/connect/test_df_debug.py
index 8a4ec68fda84..40b6a072e912 100644
--- a/python/pyspark/sql/tests/connect/test_df_debug.py
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -17,12 +17,9 @@
 
 import unittest
 
-from pyspark.testing.connectutils import (
-    should_test_connect,
-    have_graphviz,
-    graphviz_requirement_message,
-)
 from pyspark.sql.tests.connect.test_connect_basic import 
SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.utils import have_graphviz, graphviz_requirement_message
 
 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py 
b/python/pyspark/sql/tests/plot/test_frame_plot.py
index 3221a408d153..c37aef5f7c94 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -18,8 +18,8 @@
 import unittest
 from pyspark.errors import PySparkValueError
 from pyspark.sql import Row
-from pyspark.testing.sqlutils import (
-    ReusedSQLTestCase,
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
     have_plotly,
     plotly_requirement_message,
     have_pandas,
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py 
b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index 84a9c2aa0170..fd264c348882 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -19,8 +19,8 @@ import unittest
 from datetime import datetime
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.sqlutils import (
-    ReusedSQLTestCase,
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
     have_plotly,
     plotly_requirement_message,
     have_pandas,
diff --git a/python/pyspark/testing/pandasutils.py 
b/python/pyspark/testing/pandasutils.py
index 10e8ce6f69af..09d3ffb09708 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -23,29 +23,6 @@ from contextlib import contextmanager
 import decimal
 from typing import Any, Union
 
-tabulate_requirement_message = None
-try:
-    from tabulate import tabulate
-except ImportError as e:
-    # If tabulate requirement is not satisfied, skip related tests.
-    tabulate_requirement_message = str(e)
-have_tabulate = tabulate_requirement_message is None
-
-matplotlib_requirement_message = None
-try:
-    import matplotlib
-except ImportError as e:
-    # If matplotlib requirement is not satisfied, skip related tests.
-    matplotlib_requirement_message = str(e)
-have_matplotlib = matplotlib_requirement_message is None
-
-plotly_requirement_message = None
-try:
-    import plotly
-except ImportError as e:
-    # If plotly requirement is not satisfied, skip related tests.
-    plotly_requirement_message = str(e)
-have_plotly = plotly_requirement_message is None
 
 try:
     from pyspark.sql.pandas.utils import require_minimum_pandas_version
diff --git a/python/pyspark/testing/sqlutils.py 
b/python/pyspark/testing/sqlutils.py
index c833abfb805d..e5464257422a 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -22,23 +22,17 @@ import shutil
 import tempfile
 from contextlib import contextmanager
 
-pandas_requirement_message = None
-try:
-    from pyspark.sql.pandas.utils import require_minimum_pandas_version
-
-    require_minimum_pandas_version()
-except ImportError as e:
-    # If Pandas version requirement is not satisfied, skip related tests.
-    pandas_requirement_message = str(e)
-
-pyarrow_requirement_message = None
-try:
-    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+from pyspark.sql import SparkSession
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.testing.utils import (
+    ReusedPySparkTestCase,
+    PySparkErrorTestUtils,
+    have_pandas,
+    pandas_requirement_message,
+    have_pyarrow,
+    pyarrow_requirement_message,
+)
 
-    require_minimum_pyarrow_version()
-except ImportError as e:
-    # If Arrow version requirement is not satisfied, skip related tests.
-    pyarrow_requirement_message = str(e)
 
 test_not_compiled_message = None
 try:
@@ -48,21 +42,6 @@ try:
 except Exception as e:
     test_not_compiled_message = str(e)
 
-plotly_requirement_message = None
-try:
-    import plotly
-except ImportError as e:
-    plotly_requirement_message = str(e)
-have_plotly = plotly_requirement_message is None
-
-
-from pyspark.sql import SparkSession
-from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
-from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
-
-
-have_pandas = pandas_requirement_message is None
-have_pyarrow = pyarrow_requirement_message is None
 test_compiled = test_not_compiled_message is None
 
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index ca16628fc56f..1dd15666382f 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -82,6 +82,39 @@ deepspeed_requirement_message = None if have_deepspeed else 
"No module named 'de
 have_plotly = have_package("plotly")
 plotly_requirement_message = None if have_plotly else "No module named 
'plotly'"
 
+have_matplotlib = have_package("matplotlib")
+matplotlib_requirement_message = None if have_matplotlib else "No module named 
'matplotlib'"
+
+have_tabulate = have_package("tabulate")
+tabulate_requirement_message = None if have_tabulate else "No module named 
'tabulate'"
+
+have_graphviz = have_package("graphviz")
+graphviz_requirement_message = None if have_graphviz else "No module named 
'graphviz'"
+
+
+pandas_requirement_message = None
+try:
+    from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+    require_minimum_pandas_version()
+except Exception as e:
+    # If Pandas version requirement is not satisfied, skip related tests.
+    pandas_requirement_message = str(e)
+
+have_pandas = pandas_requirement_message is None
+
+
+pyarrow_requirement_message = None
+try:
+    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
+
+    require_minimum_pyarrow_version()
+except Exception as e:
+    # If Arrow version requirement is not satisfied, skip related tests.
+    pyarrow_requirement_message = str(e)
+
+have_pyarrow = pyarrow_requirement_message is None
+
 
 def read_int(b):
     return struct.unpack("!i", b)[0]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to