This is an automated email from the ASF dual-hosted git repository.

haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 48acded7f1b9 [SPARK-50718][PYTHON] Support `addArtifact(s)` for PySpark
48acded7f1b9 is described below

commit 48acded7f1b906fc7b29f26d4706c7495a67e256
Author: Haejoon Lee <[email protected]>
AuthorDate: Tue Jan 21 15:21:24 2025 +0900

    [SPARK-50718][PYTHON] Support `addArtifact(s)` for PySpark
    
    ### What changes were proposed in this pull request?
    
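    This PR proposes to support `SparkSession.addArtifact(s)` in Spark Classic PySpark.
    Previously it was Spark Connect only and raised `ONLY_SUPPORTED_WITH_SPARK_CONNECT`.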
    
    ### Why are the changes needed?
    
    For feature parity with Spark Connect
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API signature changes, but `SparkSession.addArtifact(s)`, previously Spark Connect only,
    now also works in Spark Classic.
    
    ### How was this patch tested?
    
    Added Spark Classic unit tests that reuse the Spark Connect artifact test mixin
    (`ArtifactTestsMixin`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49572 from itholic/add_artifacts.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Haejoon Lee <[email protected]>
---
 .../source/reference/pyspark.sql/spark_session.rst |  4 +-
 python/pyspark/errors/error-conditions.json        |  5 ++
 python/pyspark/sql/session.py                      | 47 ++++++++---
 python/pyspark/sql/tests/test_artifact.py          | 94 ++++++++++++++++++++++
 .../sql/tests/test_connect_compatibility.py        |  2 -
 python/pyspark/sql/tests/test_session.py           |  1 -
 6 files changed, 138 insertions(+), 15 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 0d6a1bc79b90..6c1142a23d58 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -44,6 +44,8 @@ See also :class:`SparkSession`.
 .. autosummary::
     :toctree: api/
 
+    SparkSession.addArtifact
+    SparkSession.addArtifacts
     SparkSession.addTag
     SparkSession.catalog
     SparkSession.clearTags
@@ -84,8 +86,6 @@ Spark Connect Only
 .. autosummary::
     :toctree: api/
 
-    SparkSession.addArtifact
-    SparkSession.addArtifacts
     SparkSession.clearProgressHandlers
     SparkSession.client
     SparkSession.copyFromLocalToFs
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index b7c1ec23c3af..2a3bcd7240e7 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -266,6 +266,11 @@
       "Argument `<arg_name>`(type: <arg_type>) should only contain a type in [<allowed_types>], got <item_type>"
     ]
   },
+  "DUPLICATED_ARTIFACT": {
+    "message": [
+      "Duplicate Artifact: <normalized_path>. Artifacts cannot be overwritten."
+    ]
+  },
   "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT": {
     "message": [
       "Duplicated field names in Arrow Struct are not allowed, got <field_names>"
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f5bb269c23d6..fb4c83868f91 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -17,6 +17,7 @@
 import os
 import sys
 import warnings
+import filecmp
 from collections.abc import Sized
 from functools import reduce, cached_property
 from threading import RLock
@@ -2077,7 +2078,6 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.client"},
         )
 
-    @remote_only
     def addArtifacts(
         self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
     ) -> None:
@@ -2086,6 +2086,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Parameters
         ----------
         *path : tuple of str
@@ -2100,16 +2103,40 @@ class SparkSession(SparkConversionMixin):
         file : bool
             Add a file to be downloaded with this Spark job on every node.
             The ``path`` passed can only be a local file for now.
-
-        Notes
-        -----
-        This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
-        an exception.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.addArtifact(s)"},
-        )
+        from pyspark.core.files import SparkFiles
+
+        if sum([file, pyfile, archive]) > 1:
+            raise PySparkValueError(
+                errorClass="INVALID_MULTIPLE_ARGUMENT_CONDITIONS",
+                messageParameters={
+                    "arg_names": "'pyfile', 'archive' and/or 'file'",
+                    "condition": "True together",
+                },
+            )
+        for p in path:
+            normalized_path = os.path.abspath(p)
+            target_dir = os.path.join(
+                SparkFiles.getRootDirectory(), os.path.basename(normalized_path)
+            )
+
+            # Check if the target path already exists
+            if os.path.exists(target_dir):
+                # Compare the contents of the files. If identical, skip adding this file.
+                # If different, raise an exception.
+                if filecmp.cmp(normalized_path, target_dir, shallow=False):
+                    continue
+                else:
+                    raise PySparkRuntimeError(
+                        errorClass="DUPLICATED_ARTIFACT",
+                        messageParameters={"normalized_path": normalized_path},
+                    )
+        if archive:
+            self._sc.addArchive(*path)
+        elif pyfile:
+            self._sc.addPyFile(*path)
+        elif file:
+            self._sc.addFile(*path)  # type: ignore[arg-type]
 
     addArtifact = addArtifacts
 
diff --git a/python/pyspark/sql/tests/test_artifact.py b/python/pyspark/sql/tests/test_artifact.py
new file mode 100644
index 000000000000..791ee4f40ec9
--- /dev/null
+++ b/python/pyspark/sql/tests/test_artifact.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+import os
+import tempfile
+
+from pyspark.sql.tests.connect.client.test_artifact import ArtifactTestsMixin
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors import PySparkRuntimeError
+
+
+class ArtifactTests(ArtifactTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def root(cls):
+        from pyspark.core.files import SparkFiles
+
+        return SparkFiles.getRootDirectory()
+
+    def test_add_pyfile(self):
+        self.check_add_pyfile(self.spark)
+
+        # Test multi sessions. Should be able to add the same
+        # file from different session.
+        self.check_add_pyfile(self.spark.newSession())
+
+    def test_add_file(self):
+        self.check_add_file(self.spark)
+
+        # Test multi sessions. Should be able to add the same
+        # file from different session.
+        self.check_add_file(self.spark.newSession())
+
+    def test_add_archive(self):
+        self.check_add_archive(self.spark)
+
+        # Test multi sessions. Should be able to add the same
+        # file from different session.
+        self.check_add_file(self.spark.newSession())
+
+    def test_artifacts_cannot_be_overwritten(self):
+        with tempfile.TemporaryDirectory(prefix="test_artifacts_cannot_be_overwritten") as d:
+            pyfile_path = os.path.join(d, "my_pyfile.py")
+            with open(pyfile_path, "w+") as f:
+                f.write("my_func = lambda: 10")
+
+            self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+            # Writing the same file twice is fine, and should not throw.
+            self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+            with open(pyfile_path, "w+") as f:
+                f.write("my_func = lambda: 11")
+
+            with self.assertRaises(PySparkRuntimeError) as pe:
+                self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass="DUPLICATED_ARTIFACT",
+                messageParameters={"normalized_path": pyfile_path},
+            )
+
+    def test_add_zipped_package(self):
+        self.check_add_zipped_package(self.spark)
+
+        # Test multi sessions. Should be able to add the same
+        # file from different session.
+        self.check_add_file(self.spark.newSession())
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_artifact import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index 4ac68292b402..37105ee04038 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -262,8 +262,6 @@ class ConnectCompatibilityTestsMixin:
         expected_missing_connect_properties = {"sparkContext"}
         expected_missing_classic_properties = {"is_stopped", "session_id"}
         expected_missing_connect_methods = {
-            "addArtifact",
-            "addArtifacts",
             "clearProgressHandlers",
             "copyFromLocalToFs",
             "newSession",
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index c21247e3159c..4f3fdb3d3408 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -225,7 +225,6 @@ class SparkSessionTests3(unittest.TestCase, PySparkErrorTestUtils):
         with SparkSession.builder.master("local").getOrCreate() as session:
             unsupported = [
                 (lambda: session.client, "client"),
-                (session.addArtifacts, "addArtifact(s)"),
                 (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
             ]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to