Repository: spark
Updated Branches:
  refs/heads/master e1dc85373 -> bdff299f9


[SPARK-14900][ML] spark.ml classification metrics should include accuracy

## What changes were proposed in this pull request?

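Add an `accuracy` metric to the MulticlassMetrics class and expose it as the
"accuracy" metric name in MulticlassClassificationEvaluator. The overall
`precision`, `recall`, and `fMeasure` values in MulticlassMetrics always equal
accuracy, so they are now deprecated in favor of `accuracy`. Because every
supported metric is larger-is-better, `isLargerBetter` now simply returns true.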

## How was this patch tested?

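Updated the Scala unit tests in mllib.evaluation (MulticlassMetricsSuite). They
now check the value of `accuracy`. They also check that `accuracy` equals the
deprecated `precision`, `recall`, and `fMeasure` values and `weightedRecall`.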

Author: [email protected] <[email protected]>

Closes #12882 from wangmiao1981/accuracy.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdff299f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdff299f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdff299f

Branch: refs/heads/master
Commit: bdff299f9e51b06b809fe505bda466009e759831
Parents: e1dc853
Author: [email protected] <[email protected]>
Authored: Fri May 13 08:29:37 2016 +0100
Committer: Sean Owen <[email protected]>
Committed: Fri May 13 08:29:37 2016 +0100

----------------------------------------------------------------------
 .../MulticlassClassificationEvaluator.scala        | 15 +++++----------
 .../spark/mllib/evaluation/MulticlassMetrics.scala | 17 ++++++++++++++---
 .../mllib/evaluation/MulticlassMetricsSuite.scala  |  9 +++++----
 3 files changed, 24 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bdff299f/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 3d89843..8408516 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -40,15 +40,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
 
   /**
    * param for metric name in evaluation (supports `"f1"` (default), 
`"precision"`, `"recall"`,
-   * `"weightedPrecision"`, `"weightedRecall"`)
+   * `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
    * @group param
    */
   @Since("1.5.0")
   val metricName: Param[String] = {
     val allowedParams = ParamValidators.inArray(Array("f1", "precision",
-      "recall", "weightedPrecision", "weightedRecall"))
+      "recall", "weightedPrecision", "weightedRecall", "accuracy"))
     new Param(this, "metricName", "metric name in evaluation " +
-      "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
+      "(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", 
allowedParams)
   }
 
   /** @group getParam */
@@ -86,18 +86,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
       case "recall" => metrics.recall
       case "weightedPrecision" => metrics.weightedPrecision
       case "weightedRecall" => metrics.weightedRecall
+      case "accuracy" => metrics.accuracy
     }
     metric
   }
 
   @Since("1.5.0")
-  override def isLargerBetter: Boolean = $(metricName) match {
-    case "f1" => true
-    case "precision" => true
-    case "recall" => true
-    case "weightedPrecision" => true
-    case "weightedRecall" => true
-  }
+  override def isLargerBetter: Boolean = true
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): MulticlassClassificationEvaluator = 
defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/bdff299f/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 5dde2bd..719695a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -139,7 +139,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
    * Returns precision
    */
   @Since("1.1.0")
-  lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
+  @deprecated("Use accuracy.", "2.0.0")
+  lazy val precision: Double = accuracy
 
   /**
    * Returns recall
@@ -148,14 +149,24 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
    * of all false negatives)
    */
   @Since("1.1.0")
-  lazy val recall: Double = precision
+  @deprecated("Use accuracy.", "2.0.0")
+  lazy val recall: Double = accuracy
 
   /**
    * Returns f-measure
    * (equals to precision and recall because precision equals recall)
    */
   @Since("1.1.0")
-  lazy val fMeasure: Double = precision
+  @deprecated("Use accuracy.", "2.0.0")
+  lazy val fMeasure: Double = accuracy
+
+  /**
+   * Returns accuracy
+   * (equals to the total number of correctly classified instances
+   * out of the total number of instances.)
+   */
+  @Since("2.0.0")
+  lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount
 
   /**
    * Returns weighted true positive rate

http://git-wip-us.apache.org/repos/asf/spark/blob/bdff299f/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index d55bc8c..f316c67 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -69,11 +69,12 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
     assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
 
-    assert(math.abs(metrics.recall -
+    assert(math.abs(metrics.accuracy -
       (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
-    assert(math.abs(metrics.recall - metrics.precision) < delta)
-    assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
-    assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
+    assert(math.abs(metrics.accuracy - metrics.precision) < delta)
+    assert(math.abs(metrics.accuracy - metrics.recall) < delta)
+    assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta)
+    assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
     assert(math.abs(metrics.weightedFalsePositiveRate -
       ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < 
delta)
     assert(math.abs(metrics.weightedPrecision -


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to