This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new c03a875  Revert "[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs"
c03a875 is described below

commit c03a87599c78da38fcc7d95d422bf0bf639c4171
Author: HyukjinKwon <[email protected]>
AuthorDate: Wed May 27 10:16:08 2020 +0900

    Revert "[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs"
    
    This reverts commit be18718380fc501fe2a780debd089a1df91c1699.
---
 python/pyspark/context.py        | 12 ++----------
 python/pyspark/tests/test_rdd.py |  9 ---------
 2 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3199aa7..d5f1506 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,7 +25,6 @@ from threading import RLock
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
-from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -865,17 +864,10 @@ class SparkContext(object):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
-        gw = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
-        jrdds = gw.new_array(cls, len(rdds))
+        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
-            if is_jrdd:
-                jrdds[i] = rdds[i]._jrdd
-            else:
-                # zip could return JavaPairRDD hence we ensure `_jrdd`
-                # to be `JavaRDD` by wrapping it in a `map`
-                jrdds[i] = rdds[i].map(lambda x: x)._jrdd
+            jrdds[i] = rdds[i]._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 0f1ee5b..e2d910c 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,15 +166,6 @@ class RDDTests(ReusedPySparkTestCase):
             set([(x, (x, x)) for x in 'abc'])
         )
 
-    def test_union_pair_rdd(self):
-        # Regression test for SPARK-31788
-        rdd = self.sc.parallelize([1, 2])
-        pair_rdd = rdd.zip(rdd)
-        self.assertEqual(
-            self.sc.union([pair_rdd, pair_rdd]).collect(),
-            [((1, 1), (2, 2)), ((1, 1), (2, 2))]
-        )
-
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to