Repository: spark
Updated Branches:
  refs/heads/master be317d4a9 -> bf5496dbd


http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00..7bf9225 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
 
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.util.Utils
 
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
-    name: String,
-    val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
 
-  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
-  override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
-    new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
-      Some(SQLMetrics.ACCUM_IDENTIFIER))
-  }
-
-  def reset(): Unit = {
-    this.value = param.zero
-  }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
-  /**
-   * A function that defines how we aggregate the final accumulator results among all tasks,
-   * and represent it in string for a SQL physical operator.
-   */
-  val stringValue: Seq[T] => String
-
-  def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+  // This is a workaround for SPARK-11013.
+  // We may use -1 as the initial value of the accumulator. If the accumulator is valid, we will
+  // update it at the end of the task and the value will be at least 0, so we can filter out the
+  // -1 values before calculating max, min, etc.
+  private[this] var _value = initValue
 
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+  override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
 
-  def value: T
-
-  override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
-  def add(incr: Long): LongSQLMetricValue = {
-    _value += incr
-    this
+  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+    case o: SQLMetric => _value += o.localValue
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Long = _value
-
-  // Needed for SQLListenerSuite
-  override def equals(other: Any): Boolean = other match {
-    case o: LongSQLMetricValue => value == o.value
-    case _ => false
-  }
+  override def isZero(): Boolean = _value == initValue
 
-  override def hashCode(): Int = _value.hashCode()
-}
+  override def add(v: Long): Unit = _value += v
 
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
-  extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+  def +=(v: Long): Unit = _value += v
 
-  override def +=(term: Long): Unit = {
-    localValue.add(term)
-  }
+  override def localValue: Long = _value
 
-  override def add(term: Long): Unit = {
-    localValue.add(term)
+  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+  private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+    new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
   }
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
-  extends SQLMetricParam[LongSQLMetricValue, Long] {
-
-  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
 
-  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
-    r1.add(r2.value)
-
-  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
-  override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+  def reset(): Unit = _value = initValue
 }
 
-private object LongSQLMetricParam
-  extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.bytesToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.msDurationToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
 
 private[sql] object SQLMetrics {
-
   // Identifier for distinguishing SQL metrics from other accumulators
   private[sql] val ACCUM_IDENTIFIER = "sql"
 
-  private def createLongMetric(
-      sc: SparkContext,
-      name: String,
-      param: LongSQLMetricParam): LongSQLMetric = {
-    val acc = new LongSQLMetric(name, param)
-    // This is an internal accumulator so we need to register it explicitly.
-    Accumulators.register(acc)
-    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
+  private[sql] val SUM_METRIC = "sum"
+  private[sql] val SIZE_METRIC = "size"
+  private[sql] val TIMING_METRIC = "timing"
 
-  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
-    createLongMetric(sc, name, LongSQLMetricParam)
+  def createMetric(sc: SparkContext, name: String): SQLMetric = {
+    val acc = new SQLMetric(SUM_METRIC)
+    acc.register(sc, name = Some(name), countFailedValues = true)
+    acc
   }
 
   /**
    * Create a metric to report the size information (including total, min, med, max) like data size,
    * spill size, etc.
    */
-  def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in physical operator UI may look like:
     // data size total (min, med, max):
     // 100GB (100MB, 1GB, 10GB)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+    val acc = new SQLMetric(SIZE_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
 
-  def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in physical operator UI may look like:
     // duration(min, med, max):
     // 5s (800ms, 1s, 2s)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
-  }
-
-  def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
-    val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
-    val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
-    val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
-    val metricParam = metricParamName match {
-      case `longSQLMetricParam` => LongSQLMetricParam
-      case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
-      case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
-    }
-    metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+    val acc = new SQLMetric(TIMING_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
 
   /**
-   * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
-   * care about the value.
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and represent it as a string for a SQL physical operator.
    */
-  val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+  def stringValue(metricsType: String, values: Seq[Long]): String = {
+    if (metricsType == SUM_METRIC) {
+      NumberFormat.getInstance().format(values.sum)
+    } else {
+      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+        Utils.bytesToString
+      } else if (metricsType == TIMING_METRIC) {
+        Utils.msDurationToString
+      } else {
+        throw new IllegalStateException("unexpected metrics type: " + metricsType)
+      }
+
+      val validValues = values.filter(_ >= 0)
+      val Seq(sum, min, med, max) = {
+        val metric = if (validValues.length == 0) {
+          Seq.fill(4)(0L)
+        } else {
+          val sorted = validValues.sorted
+          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        }
+        metric.map(strFormat)
+      }
+      s"\n$sum ($min, $med, $max)"
+    }
+  }
 }
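
For context, here is a minimal standalone Scala sketch (not part of the patch; the object name and plain number formatting are assumptions) that mirrors the aggregation SQLMetrics.stringValue performs for the "size" and "timing" metric types: per-task values that were never updated keep the -1 initial value and are filtered out before computing sum, min, median, and max.

object MetricStringValueSketch {
  // Mirrors the non-"sum" branch of SQLMetrics.stringValue above:
  // drop the -1 placeholders, then report sum (min, med, max).
  def statistics(values: Seq[Long]): String = {
    val validValues = values.filter(_ >= 0)
    val Seq(sum, min, med, max) =
      if (validValues.isEmpty) {
        Seq.fill(4)(0L)
      } else {
        val sorted = validValues.sorted
        Seq(sorted.sum, sorted.head, sorted(sorted.length / 2), sorted.last)
      }
    s"$sum ($min, $med, $max)"
  }

  def main(args: Array[String]): Unit = {
    // -1 entries come from tasks that never updated the metric (SPARK-11013 workaround).
    println(statistics(Seq(-1L, 100L, 300L, 200L)))  // prints: 600 (100, 200, 300)
  }
}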

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e91..9118593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
         taskEnd.taskInfo.taskId,
         taskEnd.stageId,
         taskEnd.stageAttemptId,
-        taskEnd.taskMetrics.accumulatorUpdates(),
+        taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
         finishTask = true)
     }
   }
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
           }
         }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+          executionUIData.accumulatorMetrics(accumulatorId).metricType)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
 
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+      metricTypeFunc: Long => String): Map[Long, String] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
-      val param = paramFunc(accumulatorId)
-      (accumulatorId,
-        param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+      val metricType = metricTypeFunc(accumulatorId)
+      accumulatorId ->
+        SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
     }
   }
 
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
         // Filter out accumulators that are not SQL metrics
         // For now we assume all SQL metrics are Long's that have been JSON serialized as String's
         if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
-          val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+          val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
           Some(a.copy(update = Some(newValue)))
         } else {
           None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
 private[ui] case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+    metricType: String)
 
 /**
  * Store all accumulatorUpdates for all tasks in a Spark stage.
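
For reference, a small standalone sketch (names hypothetical, not part of the patch) of the merge step above: SQLListener now groups the per-task (accumulator id, value) pairs by id and formats each group using the metric's type string via SQLMetrics.stringValue, instead of looking up a SQLMetricParam. The formatter below is a simplified stand-in for SQLMetrics.stringValue.

object MergeUpdatesSketch {
  // Stand-in for SQLMetrics.stringValue(metricType, values); the real one
  // formats sizes/timings and reports sum (min, med, max).
  def stringValue(metricType: String, values: Seq[Long]): String =
    s"$metricType: ${values.sum}"

  // Same shape as SQLListener.mergeAccumulatorUpdates after this change:
  // a metric-type lookup (Long => String) replaces the old SQLMetricParam lookup.
  def mergeAccumulatorUpdates(
      accumulatorUpdates: Seq[(Long, Long)],
      metricTypeFunc: Long => String): Map[Long, String] = {
    accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
      accumulatorId -> stringValue(metricTypeFunc(accumulatorId), values.map(_._2))
    }
  }

  def main(args: Array[String]): Unit = {
    val updates = Seq((1L, 10L), (1L, 15L), (2L, 7L))
    println(mergeAccumulatorUpdates(updates, _ => "sum"))  // Map(1 -> sum: 25, 2 -> sum: 7)
  }
}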

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e..8f5681b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
 
         val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e..0e6356b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
 
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     sql("SELECT * FROM t1").count()
     sql("SELECT * FROM t2").count()
 
-    Accumulators.synchronized {
-      val accsSize = Accumulators.originals.size
+    AccumulatorContext.synchronized {
+      val accsSize = AccumulatorContext.originals.size
       sqlContext.uncacheTable("t1")
       sqlContext.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) == AccumulatorContext.originals.size)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e..8de4d8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
 class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
 
-  test("LongSQLMetric should not box Long") {
-    val l = SQLMetrics.createLongMetric(sparkContext, "long")
+  test("SQLMetric should not box Long") {
+    val l = SQLMetrics.createMetric(sparkContext, "long")
     val f = () => {
       l += 1L
       l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("metrics can be loaded by history server") {
-    val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+    val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
     metric += 10L
     val metricInfo = metric.toInfo(Some(metric.localValue), None)
     metricInfo.update match {
-      case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
-      case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+      case Some(v: Long) => assert(v === 10L)
+      case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
       case _ => fail("metric update is missing")
     }
     assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6..8572ed1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
 
 import org.mockito.Mockito.{mock, when}
 
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.ui.SparkUI
 
 class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
+  import org.apache.spark.AccumulatorSuite.makeInfo
 
   private def createTestDataFrame: DataFrame = {
     Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
   private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
     val metrics = mock(classOf[TaskMetrics])
-    when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
-      new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
-        value = None, internal = true, countFailedValues = true)
+    when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+      val acc = new LongAccumulator
+      acc.metadata = AccumulatorMetadata(id, Some(""), true)
+      acc.setValue(update)
+      acc
     }.toSeq)
     metrics
   }
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0,
+        createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
     // This task has both accumulators that are SQL metrics and accumulators that are not.
     // The listener should only track the ones that are actually SQL metrics.
-    val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+    val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
     val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
     val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea0..8a0578c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
           case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
           case other => other.longMetric("numOutputRows")
         }
-        metrics += metric.value.value
+        metrics += metric.value
       }
     }
     sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        metrics += qe.executedPlan.longMetric("dataSize").value.value
+        metrics += qe.executedPlan.longMetric("dataSize").value
         val bottomAgg = qe.executedPlan.children(0).children(0)
-        metrics += bottomAgg.longMetric("dataSize").value.value
+        metrics += bottomAgg.longMetric("dataSize").value
       }
     }
     sqlContext.listenerManager.register(listener)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf5496db/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c338..b52b96a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
     "Partition pruning predicates only supported for partitioned tables.")
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
 
   override def producedAttributes: AttributeSet = outputSet ++
     AttributeSet(partitionPruningPred.flatMap(_.references))
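
The HiveTableScanExec hunk above shows the operator-side pattern after this change. A minimal sketch of that pattern follows (the class name is hypothetical, and it assumes the code lives inside the Spark SQL module where the private[sql] SQLMetrics object is visible): declare the metrics map with SQLMetrics.createMetric and bump the accumulator once per output row.

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

// Hypothetical operator skeleton, following the same shape as
// HiveTableScanExec.metrics after this patch.
class RowCountingScanSketch(sparkContext: SparkContext) {

  lazy val metrics: Map[String, SQLMetric] = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  // Called once per produced row; SQLMetric.+= takes a Long, so no boxing occurs.
  def recordOutputRow(): Unit = {
    metrics("numOutputRows") += 1L
  }
}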

