This is an automated email from the ASF dual-hosted git repository.
haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 48acded7f1b9 [SPARK-50718][PYTHON] Support `addArtifact(s)` for PySpark
48acded7f1b9 is described below
commit 48acded7f1b906fc7b29f26d4706c7495a67e256
Author: Haejoon Lee <[email protected]>
AuthorDate: Tue Jan 21 15:21:24 2025 +0900
[SPARK-50718][PYTHON] Support `addArtifact(s)` for PySpark
### What changes were proposed in this pull request?
This PR proposes to support `addArtifact(s)` for PySpark
### Why are the changes needed?
For feature parity with Spark Connect
### Does this PR introduce _any_ user-facing change?
No changes to existing APIs, but new APIs `addArtifact(s)` are added for Spark Classic
### How was this patch tested?
Added corresponding UTs with Spark Connect
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49572 from itholic/add_artifacts.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Haejoon Lee <[email protected]>
---
.../source/reference/pyspark.sql/spark_session.rst | 4 +-
python/pyspark/errors/error-conditions.json | 5 ++
python/pyspark/sql/session.py | 47 ++++++++---
python/pyspark/sql/tests/test_artifact.py | 94 ++++++++++++++++++++++
.../sql/tests/test_connect_compatibility.py | 2 -
python/pyspark/sql/tests/test_session.py | 1 -
6 files changed, 138 insertions(+), 15 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst
b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 0d6a1bc79b90..6c1142a23d58 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -44,6 +44,8 @@ See also :class:`SparkSession`.
.. autosummary::
:toctree: api/
+ SparkSession.addArtifact
+ SparkSession.addArtifacts
SparkSession.addTag
SparkSession.catalog
SparkSession.clearTags
@@ -84,8 +86,6 @@ Spark Connect Only
.. autosummary::
:toctree: api/
- SparkSession.addArtifact
- SparkSession.addArtifacts
SparkSession.clearProgressHandlers
SparkSession.client
SparkSession.copyFromLocalToFs
diff --git a/python/pyspark/errors/error-conditions.json
b/python/pyspark/errors/error-conditions.json
index b7c1ec23c3af..2a3bcd7240e7 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -266,6 +266,11 @@
"Argument `<arg_name>`(type: <arg_type>) should only contain a type in
[<allowed_types>], got <item_type>"
]
},
+ "DUPLICATED_ARTIFACT": {
+ "message": [
+ "Duplicate Artifact: <normalized_path>. Artifacts cannot be overwritten."
+ ]
+ },
"DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT": {
"message": [
"Duplicated field names in Arrow Struct are not allowed, got
<field_names>"
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f5bb269c23d6..fb4c83868f91 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -17,6 +17,7 @@
import os
import sys
import warnings
+import filecmp
from collections.abc import Sized
from functools import reduce, cached_property
from threading import RLock
@@ -2077,7 +2078,6 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.client"},
)
- @remote_only
def addArtifacts(
self, *path: str, pyfile: bool = False, archive: bool = False, file:
bool = False
) -> None:
@@ -2086,6 +2086,9 @@ class SparkSession(SparkConversionMixin):
.. versionadded:: 3.5.0
+ .. versionchanged:: 4.0.0
+ Supports Spark Classic.
+
Parameters
----------
*path : tuple of str
@@ -2100,16 +2103,40 @@ class SparkSession(SparkConversionMixin):
file : bool
Add a file to be downloaded with this Spark job on every node.
The ``path`` passed can only be a local file for now.
-
- Notes
- -----
- This is an API dedicated to Spark Connect client only. With regular
Spark Session, it throws
- an exception.
"""
- raise PySparkRuntimeError(
- errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
- messageParameters={"feature": "SparkSession.addArtifact(s)"},
- )
+ from pyspark.core.files import SparkFiles
+
+ if sum([file, pyfile, archive]) > 1:
+ raise PySparkValueError(
+ errorClass="INVALID_MULTIPLE_ARGUMENT_CONDITIONS",
+ messageParameters={
+ "arg_names": "'pyfile', 'archive' and/or 'file'",
+ "condition": "True together",
+ },
+ )
+ for p in path:
+ normalized_path = os.path.abspath(p)
+ target_dir = os.path.join(
+ SparkFiles.getRootDirectory(),
os.path.basename(normalized_path)
+ )
+
+ # Check if the target path already exists
+ if os.path.exists(target_dir):
+ # Compare the contents of the files. If identical, skip adding
this file.
+ # If different, raise an exception.
+ if filecmp.cmp(normalized_path, target_dir, shallow=False):
+ continue
+ else:
+ raise PySparkRuntimeError(
+ errorClass="DUPLICATED_ARTIFACT",
+ messageParameters={"normalized_path": normalized_path},
+ )
+ if archive:
+ self._sc.addArchive(*path)
+ elif pyfile:
+ self._sc.addPyFile(*path)
+ elif file:
+ self._sc.addFile(*path) # type: ignore[arg-type]
addArtifact = addArtifacts
diff --git a/python/pyspark/sql/tests/test_artifact.py
b/python/pyspark/sql/tests/test_artifact.py
new file mode 100644
index 000000000000..791ee4f40ec9
--- /dev/null
+++ b/python/pyspark/sql/tests/test_artifact.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+import os
+import tempfile
+
+from pyspark.sql.tests.connect.client.test_artifact import ArtifactTestsMixin
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors import PySparkRuntimeError
+
+
+class ArtifactTests(ArtifactTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def root(cls):
+ from pyspark.core.files import SparkFiles
+
+ return SparkFiles.getRootDirectory()
+
+ def test_add_pyfile(self):
+ self.check_add_pyfile(self.spark)
+
+ # Test multi sessions. Should be able to add the same
+ # file from different session.
+ self.check_add_pyfile(self.spark.newSession())
+
+ def test_add_file(self):
+ self.check_add_file(self.spark)
+
+ # Test multi sessions. Should be able to add the same
+ # file from different session.
+ self.check_add_file(self.spark.newSession())
+
+ def test_add_archive(self):
+ self.check_add_archive(self.spark)
+
+ # Test multi sessions. Should be able to add the same
+ # archive from a different session.
+ self.check_add_archive(self.spark.newSession())
+
+ def test_artifacts_cannot_be_overwritten(self):
+ with
tempfile.TemporaryDirectory(prefix="test_artifacts_cannot_be_overwritten") as d:
+ pyfile_path = os.path.join(d, "my_pyfile.py")
+ with open(pyfile_path, "w+") as f:
+ f.write("my_func = lambda: 10")
+
+ self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+ # Writing the same file twice is fine, and should not throw.
+ self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+ with open(pyfile_path, "w+") as f:
+ f.write("my_func = lambda: 11")
+
+ with self.assertRaises(PySparkRuntimeError) as pe:
+ self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="DUPLICATED_ARTIFACT",
+ messageParameters={"normalized_path": pyfile_path},
+ )
+
+ def test_add_zipped_package(self):
+ self.check_add_zipped_package(self.spark)
+
+ # Test multi sessions. Should be able to add the same
+ # zipped package from a different session.
+ self.check_add_zipped_package(self.spark.newSession())
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_artifact import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py
b/python/pyspark/sql/tests/test_connect_compatibility.py
index 4ac68292b402..37105ee04038 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -262,8 +262,6 @@ class ConnectCompatibilityTestsMixin:
expected_missing_connect_properties = {"sparkContext"}
expected_missing_classic_properties = {"is_stopped", "session_id"}
expected_missing_connect_methods = {
- "addArtifact",
- "addArtifacts",
"clearProgressHandlers",
"copyFromLocalToFs",
"newSession",
diff --git a/python/pyspark/sql/tests/test_session.py
b/python/pyspark/sql/tests/test_session.py
index c21247e3159c..4f3fdb3d3408 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -225,7 +225,6 @@ class SparkSessionTests3(unittest.TestCase,
PySparkErrorTestUtils):
with SparkSession.builder.master("local").getOrCreate() as session:
unsupported = [
(lambda: session.client, "client"),
- (session.addArtifacts, "addArtifact(s)"),
(lambda: session.copyFromLocalToFs("", ""),
"copyFromLocalToFs"),
]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]