This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new be18718 [SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
be18718 is described below
commit be18718380fc501fe2a780debd089a1df91c1699
Author: schintap <[email protected]>
AuthorDate: Mon May 25 10:29:08 2020 +0900
[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
### What changes were proposed in this pull request?
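Calling `SparkContext.union` on pair RDDs (for example, RDDs produced by `zip`) fails: their underlying Java object is a `JavaPairRDD`, which py4j cannot store in the `JavaRDD` array that `union` builds. The fix checks the instance type of the underlying Java RDD before building that array.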
### Why are the changes needed?
These changes prevent failures when users call `union` on RDDs whose underlying Java RDD is of a type other than `JavaRDD`.
### Does this PR introduce _any_ user-facing change?
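Yes. `SparkContext.union` no longer raises a `Py4JError` when given pair RDDs such as those returned by `zip`.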
Before:
SparkSession available as 'spark'.
>>> rdd1 = sc.parallelize([1,2,3,4,5])
>>> rdd2 = sc.parallelize([6,7,8,9,10])
>>> pairRDD1 = rdd1.zip(rdd2)
>>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/gs/spark/latest/python/pyspark/context.py", line 870, in union
    jrdds[i] = rdds[i]._jrdd
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
	at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
	at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
	at py4j.commands.ArrayCommand.execute(ArrayCommand. [...]
After:
>>> rdd1 = sc.parallelize([1,2,3,4,5])
>>> rdd2 = sc.parallelize([6,7,8,9,10])
>>> pairRDD1 = rdd1.zip(rdd2)
>>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
>>> unionRDD1.collect()
[(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
### How was this patch tested?
Tested manually with the reproduction snippet above. A regression test, `test_union_pair_rdd`, was also added to `python/pyspark/tests/test_rdd.py`.
Closes #28603 from redsanket/SPARK-31788.
Authored-by: schintap <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
(cherry picked from commit a61911c50c391e61038cf01611629d2186d17a76)
Signed-off-by: HyukjinKwon <[email protected]>
---
python/pyspark/context.py | 12 ++++++++++--
python/pyspark/tests/test_rdd.py | 9 +++++++++
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d5f1506..3199aa7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@ from threading import RLock
from tempfile import NamedTemporaryFile
from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -864,10 +865,17 @@ class SparkContext(object):
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
rdds = [x._reserialize() for x in rdds]
+ gw = SparkContext._gateway
cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
- jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+ is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
+ jrdds = gw.new_array(cls, len(rdds))
for i in range(0, len(rdds)):
- jrdds[i] = rdds[i]._jrdd
+ if is_jrdd:
+ jrdds[i] = rdds[i]._jrdd
+ else:
+ # zip could return JavaPairRDD hence we ensure `_jrdd`
+ # to be `JavaRDD` by wrapping it in a `map`
+ jrdds[i] = rdds[i].map(lambda x: x)._jrdd
return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
def broadcast(self, value):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e2d910c..0f1ee5b 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,6 +166,15 @@ class RDDTests(ReusedPySparkTestCase):
set([(x, (x, x)) for x in 'abc'])
)
+ def test_union_pair_rdd(self):
+ # Regression test for SPARK-31788
+ rdd = self.sc.parallelize([1, 2])
+ pair_rdd = rdd.zip(rdd)
+ self.assertEqual(
+ self.sc.union([pair_rdd, pair_rdd]).collect(),
+ [((1, 1), (2, 2)), ((1, 1), (2, 2))]
+ )
+
def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]