Repository: spark
Updated Branches:
  refs/heads/branch-2.0 fab77dadf -> fffcec90b


[SPARK-17463][CORE] Make the values of CollectionAccumulator and SetAccumulator 
safe to read from other threads

## What changes were proposed in this pull request?

Make the values of CollectionAccumulator and SetAccumulator safe to read from 
other threads, fixing the ConcurrentModificationException reported in 
[JIRA](https://issues.apache.org/jira/browse/SPARK-17463).

## How was this patch tested?

Existing tests.

Author: Shixiong Zhu <[email protected]>

Closes #15063 from zsxwing/SPARK-17463.

(cherry picked from commit e33bfaed3b160fbc617c878067af17477a0044f5)
Signed-off-by: Josh Rosen <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fffcec90
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fffcec90
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fffcec90

Branch: refs/heads/branch-2.0
Commit: fffcec90b65047c3031c2b96679401f8fbef6337
Parents: fab77da
Author: Shixiong Zhu <[email protected]>
Authored: Wed Sep 14 13:33:51 2016 -0700
Committer: Josh Rosen <[email protected]>
Committed: Wed Sep 14 13:34:27 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/TaskMetrics.scala | 41 +++++++++++++-------
 .../org/apache/spark/util/AccumulatorV2.scala   |  7 +++-
 .../org/apache/spark/util/JsonProtocol.scala    | 11 +++---
 .../apache/spark/util/JsonProtocolSuite.scala   |  3 +-
 .../spark/sql/execution/debug/package.scala     | 24 +++++++-----
 5 files changed, 54 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fffcec90/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index dd149a9..52a3499 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.executor
 
+import java.util.{ArrayList, Collections}
+
+import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
 
 import org.apache.spark._
@@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable {
   /**
    * Storage statuses of any blocks that have been updated as a result of this 
task.
    */
-  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = 
_updatedBlockStatuses.value
+  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = {
+    // This is called on driver. All accumulator updates have a fixed value. 
So it's safe to use
+    // `asScala` which accesses the internal values using `java.util.Iterator`.
+    _updatedBlockStatuses.value.asScala
+  }
 
   // Setters and increment-ers
   private[spark] def setExecutorDeserializeTime(v: Long): Unit =
@@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable {
   private[spark] def incPeakExecutionMemory(v: Long): Unit = 
_peakExecutionMemory.add(v)
   private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
     _updatedBlockStatuses.add(v)
-  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): 
Unit =
+  private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, 
BlockStatus)]): Unit =
     _updatedBlockStatuses.setValue(v)
+  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): 
Unit =
+    _updatedBlockStatuses.setValue(v.asJava)
 
   /**
    * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] 
or from persisted
@@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging {
       val name = info.name.get
       val value = info.update.get
       if (name == UPDATED_BLOCK_STATUSES) {
-        tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, 
BlockStatus)]])
+        tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, 
BlockStatus)]])
       } else {
         tm.nameToAccums.get(name).foreach(
           _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
@@ -299,8 +308,8 @@ private[spark] object TaskMetrics extends Logging {
 
 
 private[spark] class BlockStatusesAccumulator
-  extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
-  private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
+  extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, 
BlockStatus)]] {
+  private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, 
BlockStatus)]())
 
   override def isZero(): Boolean = _seq.isEmpty
 
@@ -308,25 +317,27 @@ private[spark] class BlockStatusesAccumulator
 
   override def copy(): BlockStatusesAccumulator = {
     val newAcc = new BlockStatusesAccumulator
-    newAcc._seq = _seq.clone()
+    newAcc._seq.addAll(_seq)
     newAcc
   }
 
   override def reset(): Unit = _seq.clear()
 
-  override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+  override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v)
 
-  override def merge(other: AccumulatorV2[(BlockId, BlockStatus), 
Seq[(BlockId, BlockStatus)]])
-  : Unit = other match {
-    case o: BlockStatusesAccumulator => _seq ++= o.value
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  override def merge(
+    other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, 
BlockStatus)]]): Unit = {
+    other match {
+      case o: BlockStatusesAccumulator => _seq.addAll(o.value)
+      case _ => throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with 
${other.getClass.getName}")
+    }
   }
 
-  override def value: Seq[(BlockId, BlockStatus)] = _seq
+  override def value: java.util.List[(BlockId, BlockStatus)] = _seq
 
-  def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+  def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = {
     _seq.clear()
-    _seq ++= newValue
+    _seq.addAll(newValue)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fffcec90/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala 
b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index d130a37..470d912 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
 
 import java.{lang => jl}
 import java.io.ObjectInputStream
-import java.util.ArrayList
+import java.util.{ArrayList, Collections}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
@@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata(
 /**
  * The base class for accumulators, that can accumulate inputs of type `IN`, 
and produce output of
  * type `OUT`.
+ *
+ * `OUT` should be a type that can be read atomically (e.g., Int, Long), or 
thread-safely
+ * (e.g., synchronized collections) because it will be read from other threads.
  */
 abstract class AccumulatorV2[IN, OUT] extends Serializable {
   private[spark] var metadata: AccumulatorMetadata = _
@@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, 
jl.Double] {
  * @since 2.0.0
  */
 class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
-  private val _list: java.util.List[T] = new ArrayList[T]
+  private val _list: java.util.List[T] = Collections.synchronizedList(new 
ArrayList[T]())
 
   override def isZero: Boolean = _list.isEmpty
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fffcec90/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala 
b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 18547d4..148635f 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -309,11 +309,12 @@ private[spark] object JsonProtocol {
         case v: Int => JInt(v)
         case v: Long => JInt(v)
         // We only have 3 kind of internal accumulator types, so if it's not 
int or long, it must be
-        // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]`
+        // the blocks accumulator, whose type is `java.util.List[(BlockId, 
BlockStatus)]`
         case v =>
-          JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case 
(id, status) =>
-            ("Block ID" -> id.toString) ~
-            ("Status" -> blockStatusToJson(status))
+          JArray(v.asInstanceOf[java.util.List[(BlockId, 
BlockStatus)]].asScala.toList.map {
+            case (id, status) =>
+              ("Block ID" -> id.toString) ~
+              ("Status" -> blockStatusToJson(status))
           })
       }
     } else {
@@ -740,7 +741,7 @@ private[spark] object JsonProtocol {
             val id = BlockId((blockJson \ "Block ID").extract[String])
             val status = blockStatusFromJson(blockJson \ "Status")
             (id, status)
-          }
+          }.asJava
         case _ => throw new IllegalArgumentException(s"unexpected json value 
$value for " +
           "accumulator " + name.get)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/fffcec90/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala 
b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 7a1ee03..cda3457 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
 
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.collection.Map
 
 import org.json4s.jackson.JsonMethods._
@@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite {
     })
     testAccumValue(Some(RESULT_SIZE), 3L, JInt(3))
     testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2))
-    testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson)
+    testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson)
     // For anything else, we just cast the value to a string
     testAccumValue(Some("anything"), blocks, JString(blocks.toString))
     testAccumValue(Some("anything"), 123, JString("123"))

http://git-wip-us.apache.org/repos/asf/spark/blob/fffcec90/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 082f97a..d321f4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable.HashSet
+import java.util.Collections
+
+import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -107,18 +109,20 @@ package object debug {
   case class DebugExec(child: SparkPlan) extends UnaryExecNode with 
CodegenSupport {
     def output: Seq[Attribute] = child.output
 
-    class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
-      private val _set = new HashSet[T]()
+    class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
+      private val _set = Collections.synchronizedSet(new 
java.util.HashSet[T]())
       override def isZero: Boolean = _set.isEmpty
-      override def copy(): AccumulatorV2[T, HashSet[T]] = {
+      override def copy(): AccumulatorV2[T, java.util.Set[T]] = {
         val newAcc = new SetAccumulator[T]()
-        newAcc._set ++= _set
+        newAcc._set.addAll(_set)
         newAcc
       }
       override def reset(): Unit = _set.clear()
-      override def add(v: T): Unit = _set += v
-      override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= 
other.value
-      override def value: HashSet[T] = _set
+      override def add(v: T): Unit = _set.add(v)
+      override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = {
+        _set.addAll(other.value)
+      }
+      override def value: java.util.Set[T] = _set
     }
 
     /**
@@ -138,7 +142,9 @@ package object debug {
       debugPrint(s"== ${child.simpleString} ==")
       debugPrint(s"Tuples output: ${tupleCount.value}")
       child.output.zip(columnStats).foreach { case (attr, metric) =>
-        val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
+        // This is called on driver. All accumulator updates have a fixed 
value. So it's safe to use
+        // `asScala` which accesses the internal values using 
`java.util.Iterator`.
+        val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", 
",", "}")
         debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to