Repository: spark Updated Branches: refs/heads/branch-1.0 2d3080855 -> 45bf91025
[SPARK-2251] fix concurrency issues in random sampler (branch-1.0) The following code is very likely to throw an exception: ~~~ val rdd = sc.parallelize(0 until 111, 10).sample(false, 0.1) rdd.zip(rdd).count() ~~~ because the same random number generator is shared across partition computations. This fix doesn't change the type signature. @pwendell Author: Xiangrui Meng <[email protected]> Closes #1234 from mengxr/fix-sample-1.0 and squashes the following commits: 88795e2 [Xiangrui Meng] fix concurrency issues in random sampler Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45bf9102 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45bf9102 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45bf9102 Branch: refs/heads/branch-1.0 Commit: 45bf91025a4492127b01b83a87b94b1dc90c4b2c Parents: 2d30808 Author: Xiangrui Meng <[email protected]> Authored: Thu Jun 26 13:32:50 2014 -0700 Committer: Patrick Wendell <[email protected]> Committed: Thu Jun 26 13:32:50 2014 -0700 ---------------------------------------------------------------------- .../org/apache/spark/util/random/RandomSampler.scala | 7 ++++--- .../apache/spark/rdd/PartitionwiseSampledRDDSuite.scala | 12 +++++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/45bf9102/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 4dc8ada..15b3e1a 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -72,9 +72,10 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: 
Boolean = false) /** * Return a sampler with is the complement of the range specified of the current sampler. */ - def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) + def cloneComplement(): BernoulliSampler[T] = + new BernoulliSampler[T](lb, ub, !complement)(new XORShiftRandom) - override def clone = new BernoulliSampler[T](lb, ub, complement) + override def clone = new BernoulliSampler[T](lb, ub, complement)(new XORShiftRandom) } /** @@ -104,5 +105,5 @@ class PoissonSampler[T](mean: Double) } } - override def clone = new PoissonSampler[T](mean) + override def clone = new PoissonSampler[T](mean)(new Poisson(mean, new DRand)) } http://git-wip-us.apache.org/repos/asf/spark/blob/45bf9102/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 00c273d..53e8e0b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.rdd import org.scalatest.FunSuite import org.apache.spark.SharedSparkContext -import org.apache.spark.util.random.RandomSampler +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler} /** a sampler that outputs its seed */ class MockSampler extends RandomSampler[Long, Long] { @@ -46,5 +46,15 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L) assert(sample.distinct.count == 2, "Seeds must be different.") } + + test("concurrency") { + // SPARK-2251: zip with self computes each partition twice. + // We want to make sure there are no concurrency issues. 
+ val rdd = sc.parallelize(0 until 111, 10) + for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) { + val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler) + sampled.zip(sampled).count() + } + } }
