Repository: spark
Updated Branches:
  refs/heads/master 24d07e45d -> d1cf32010


[SPARK-14886][MLLIB] RankingMetrics.ndcgAt throws 
java.lang.ArrayIndexOutOfBoundsException

## What changes were proposed in this pull request?

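Handle the case where the number of predictions is smaller than both the 
label set size and k in the nDCG computation, so that the loop no longer 
indexes past the end of the predictions array.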

## How was this patch tested?

New unit test; existing tests

Author: Sean Owen <[email protected]>

Closes #12756 from srowen/SPARK-14886.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d1cf3201
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d1cf3201
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d1cf3201

Branch: refs/heads/master
Commit: d1cf320105504f908ee01f33044d0a6b29c3c03f
Parents: 24d07e4
Author: Sean Owen <[email protected]>
Authored: Fri Apr 29 09:21:27 2016 +0200
Committer: Nick Pentreath <[email protected]>
Committed: Fri Apr 29 09:21:27 2016 +0200

----------------------------------------------------------------------
 .../spark/mllib/evaluation/RankingMetrics.scala |  2 +-
 .../mllib/evaluation/RankingMetricsSuite.scala  | 26 ++++++++++++++++----
 2 files changed, 22 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d1cf3201/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index c45742c..4ed4a05 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -140,7 +140,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: 
RDD[(Array[T], Array[T])]
         var i = 0
         while (i < n) {
           val gain = 1.0 / math.log(i + 2)
-          if (labSet.contains(pred(i))) {
+          if (i < pred.length && labSet.contains(pred(i))) {
             dcg += gain
           }
           if (i < labSetSize) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d1cf3201/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 77ec49d..8e9d910 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -22,14 +22,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
-  test("Ranking metrics: map, ndcg") {
+
+  test("Ranking metrics: MAP, NDCG") {
     val predictionAndLabels = sc.parallelize(
       Seq(
-        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
-        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
-        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+        (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
+        (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),
+        (Array(1, 2, 3, 4, 5), Array[Int]())
       ), 2)
-    val eps: Double = 1E-5
+    val eps = 1.0E-5
 
     val metrics = new RankingMetrics(predictionAndLabels)
     val map = metrics.meanAveragePrecision
@@ -48,6 +49,21 @@ class RankingMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+  }
+
+  test("MAP, NDCG with few predictions (SPARK-14886)") {
+    val predictionAndLabels = sc.parallelize(
+      Seq(
+        (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
+        (Array[Int](), Array(1, 2, 3))
+      ), 2)
+    val eps = 1.0E-5
 
+    val metrics = new RankingMetrics(predictionAndLabels)
+    assert(metrics.precisionAt(1) ~== 0.5 absTol eps)
+    assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
+    assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
+    assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
   }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to