Repository: spark
Updated Branches:
  refs/heads/master f87092181 -> 2a5c93079


[SPARK-13962][ML] spark.ml Evaluators should support other numeric types for label

## What changes were proposed in this pull request?

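Made BinaryClassificationEvaluator, MulticlassClassificationEvaluator and 
RegressionEvaluator accept all numeric types for the label column. The label 
column is now validated with SchemaUtils.checkNumericType and cast to 
DoubleType before the metrics are computed.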

## How was this patch tested?

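Unit tests: each evaluator suite gains a test that calls the new 
MLTestingUtils.checkNumericTypes helper for evaluators. The helper checks that 
every tested numeric label type gives the same result as DoubleType, and that 
a StringType label is rejected with an IllegalArgumentException.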

Author: BenFradet <[email protected]>

Closes #12500 from BenFradet/SPARK-13962.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a5c9307
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a5c9307
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a5c9307

Branch: refs/heads/master
Commit: 2a5c930790b4b92674e74f093380d89a9a625552
Parents: f870921
Author: BenFradet <[email protected]>
Authored: Tue Apr 26 08:55:50 2016 +0200
Committer: Nick Pentreath <[email protected]>
Committed: Tue Apr 26 08:55:50 2016 +0200

----------------------------------------------------------------------
 .../BinaryClassificationEvaluator.scala         | 12 ++--
 .../MulticlassClassificationEvaluator.scala     | 11 ++--
 .../ml/evaluation/RegressionEvaluator.scala     | 19 ++----
 .../BinaryClassificationEvaluatorSuite.scala    |  7 +-
 ...MulticlassClassificationEvaluatorSuite.scala |  6 +-
 .../evaluation/RegressionEvaluatorSuite.scala   |  6 +-
 .../apache/spark/ml/tree/impl/TreeTests.scala   |  9 ++-
 .../apache/spark/ml/util/MLTestingUtils.scala   | 69 +++++++++++++-------
 8 files changed, 88 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index bde8c27..0cbc391 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, 
DefaultParamsWritable, I
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 
 /**
@@ -73,13 +74,14 @@ class BinaryClassificationEvaluator @Since("1.4.0") 
(@Since("1.4.0") override va
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, 
new VectorUDT))
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))
 
     // TODO: When dataset metadata has been implemented, check 
rawPredictionCol vector length = 2.
-    val scoreAndLabels = dataset.select($(rawPredictionCol), 
$(labelCol)).rdd.map {
-      case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), 
label)
-      case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
-    }
+    val scoreAndLabels =
+      dataset.select(col($(rawPredictionCol)), 
col($(labelCol)).cast(DoubleType)).rdd.map {
+        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), 
label)
+        case Row(rawPrediction: Double, label: Double) => (rawPrediction, 
label)
+      }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
       case "areaUnderROC" => metrics.areaUnderROC()

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 3acfc22..3d89843 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, 
HasPredictionCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 
 /**
@@ -72,12 +73,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") 
(@Since("1.5.0") overrid
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val predictionAndLabels = dataset.select($(predictionCol), 
$(labelCol)).rdd.map {
-      case Row(prediction: Double, label: Double) =>
-        (prediction, label)
-    }
+    val predictionAndLabels =
+      dataset.select(col($(predictionCol)), 
col($(labelCol)).cast(DoubleType)).rdd.map {
+        case Row(prediction: Double, label: Double) => (prediction, label)
+      }
     val metrics = new MulticlassMetrics(predictionAndLabels)
     val metric = $(metricName) match {
       case "f1" => metrics.weightedFMeasure

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 988f6e9..031cd0d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -74,22 +74,13 @@ final class RegressionEvaluator @Since("1.4.0") 
(@Since("1.4.0") override val ui
   @Since("2.0.0")
   override def evaluate(dataset: Dataset[_]): Double = {
     val schema = dataset.schema
-    val predictionColName = $(predictionCol)
-    val predictionType = schema($(predictionCol)).dataType
-    require(predictionType == FloatType || predictionType == DoubleType,
-      s"Prediction column $predictionColName must be of type float or double, 
" +
-        s" but not $predictionType")
-    val labelColName = $(labelCol)
-    val labelType = schema($(labelCol)).dataType
-    require(labelType == FloatType || labelType == DoubleType,
-      s"Label column $labelColName must be of type float or double, but not 
$labelType")
+    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, 
FloatType))
+    SchemaUtils.checkNumericType(schema, $(labelCol))
 
     val predictionAndLabels = dataset
       .select(col($(predictionCol)).cast(DoubleType), 
col($(labelCol)).cast(DoubleType))
-      .rdd.
-      map { case Row(prediction: Double, label: Double) =>
-        (prediction, label)
-      }
+      .rdd
+      .map { case Row(prediction: Double, label: Double) => (prediction, 
label) }
     val metrics = new RegressionMetrics(predictionAndLabels)
     val metric = $(metricName) match {
       case "rmse" => metrics.rootMeanSquaredError

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index 2734995..ff345221 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
@@ -68,4 +68,9 @@ class BinaryClassificationEvaluatorSuite
       "equal to one of the following types: [DoubleType, ")
     assert(thrown.getMessage.replace("\n", "") contains "but was actually of 
type StringType.")
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val evaluator = new 
BinaryClassificationEvaluator().setRawPredictionCol("prediction")
+    MLTestingUtils.checkNumericTypes(evaluator, sqlContext)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 7ee6597..87e511a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class MulticlassClassificationEvaluatorSuite
@@ -36,4 +36,8 @@ class MulticlassClassificationEvaluatorSuite
       .setMetricName("recall")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, 
sqlContext)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 954d3be..c7b9483 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
@@ -83,4 +83,8 @@ class RegressionEvaluatorSuite
       .setMetricName("r2")
     testDefaultReadWrite(evaluator)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b650a9f..e3f0989 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -79,16 +79,21 @@ private[ml] object TreeTests extends SparkFunSuite {
    *              This must be non-empty.
    * @param numClasses  Number of classes label can take. If 0, mark as 
continuous.
    * @param labelColName  Name of the label column on which to set the 
metadata.
+   * @param featuresColName  Name of the features column
    * @return DataFrame with metadata
    */
-  def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): 
DataFrame = {
+  def setMetadata(
+      data: DataFrame,
+      numClasses: Int,
+      labelColName: String,
+      featuresColName: String): DataFrame = {
     val labelAttribute = if (numClasses == 0) {
       NumericAttribute.defaultAttr.withName(labelColName)
     } else {
       
NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
     }
     val labelMetadata = labelAttribute.toMetadata()
-    data.select(data("features"), data(labelColName).as(labelColName, 
labelMetadata))
+    data.select(data(featuresColName), data(labelColName).as(labelColName, 
labelMetadata))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/2a5c9307/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala 
b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 8108460..d9e6fd5a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.util
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.mllib.linalg.Vectors
@@ -47,12 +48,30 @@ object MLTestingUtils extends SparkFunSuite {
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => 
estimator.fit(dfs(t)))
     actuals.foreach(actual => check(expected, actual))
 
-    val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext)
+    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+      ("0", Vectors.dense(0, 2, 3), 0.0)
+    )).toDF("label", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
       estimator.fit(dfWithStringLabels)
     }
-    assert(thrown.getMessage contains
-      "Column label must be of type NumericType but was actually of type 
StringType")
+    assert(thrown.getMessage.contains(
+      "Column label must be of type NumericType but was actually of type 
StringType"))
+  }
+
+  def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): 
Unit = {
+    val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", 
"prediction")
+    val expected = evaluator.evaluate(dfs(DoubleType))
+    val actuals = dfs.keys.filter(_ != DoubleType).map(t => 
evaluator.evaluate(dfs(t)))
+    actuals.foreach(actual => assert(expected === actual))
+
+    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+      ("0", 0d)
+    )).toDF("label", "prediction")
+    val thrown = intercept[IllegalArgumentException] {
+      evaluator.evaluate(dfWithStringLabels)
+    }
+    assert(thrown.getMessage.contains(
+      "Column label must be of type NumericType but was actually of type 
StringType"))
   }
 
   def genClassifDFWithNumericLabelCol(
@@ -69,9 +88,10 @@ object MLTestingUtils extends SparkFunSuite {
 
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, 
DecimalType(10, 0))
-    types.map(t => t -> df.select(col(labelColName).cast(t), 
col(featuresColName)))
-      .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) }
-      .toMap
+    types.map { t =>
+        val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+        t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+      }.toMap
   }
 
   def genRegressionDFWithNumericLabelCol(
@@ -89,24 +109,29 @@ object MLTestingUtils extends SparkFunSuite {
 
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, 
DecimalType(10, 0))
-    types
-      .map(t => t -> df.select(col(labelColName).cast(t), 
col(featuresColName)))
-      .map { case (t, d) =>
-        t -> TreeTests.setMetadata(d, 0, 
labelColName).withColumn(censorColName, lit(0.0))
-      }
-      .toMap
+    types.map { t =>
+        val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+        t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+          .withColumn(censorColName, lit(0.0))
+      }.toMap
   }
 
-  def generateDFWithStringLabelCol(
+  def genEvaluatorDFWithNumericLabelCol(
       sqlContext: SQLContext,
       labelColName: String = "label",
-      featuresColName: String = "features",
-      censorColName: String = "censor"): DataFrame =
-    sqlContext.createDataFrame(Seq(
-      ("0", Vectors.dense(0, 2, 3), 0.0),
-      ("1", Vectors.dense(0, 3, 1), 1.0),
-      ("0", Vectors.dense(0, 2, 2), 0.0),
-      ("1", Vectors.dense(0, 3, 9), 1.0),
-      ("0", Vectors.dense(0, 2, 6), 0.0)
-    )).toDF(labelColName, featuresColName, censorColName)
+      predictionColName: String = "prediction"): Map[NumericType, DataFrame] = 
{
+    val df = sqlContext.createDataFrame(Seq(
+      (0, 0d),
+      (1, 1d),
+      (2, 2d),
+      (3, 3d),
+      (4, 4d)
+    )).toDF(labelColName, predictionColName)
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, 
DecimalType(10, 0))
+    types
+      .map(t => t -> df.select(col(labelColName).cast(t), 
col(predictionColName)))
+      .toMap
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to