Repository: spark
Updated Branches:
  refs/heads/master ecfb3e73f -> ec03866a7


[SPARK-11343][ML] Allow float and double prediction/label columns in 
RegressionEvaluator

mengxr, felixcheung

This pull request relaxes the type check on the prediction/label columns so that 
they may be either float or double. Internally, both columns are cast to double. 
The other evaluators may need similar changes.

Author: Dominik Dahlem <[email protected]>

Closes #9296 from 
dahlem/ddahlem_regression_evaluator_double_predictions_27102015.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec03866a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec03866a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec03866a

Branch: refs/heads/master
Commit: ec03866a7ef2d0826520755d47c8c9480148a76c
Parents: ecfb3e7
Author: Dominik Dahlem <[email protected]>
Authored: Mon Nov 2 16:11:42 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Mon Nov 2 16:11:42 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/evaluation/RegressionEvaluator.scala       | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec03866a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 3fd34d8..ba012f4 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -23,7 +23,8 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, 
HasPredictionCol}
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, FloatType}
 
 /**
  * :: Experimental ::
@@ -72,10 +73,13 @@ final class RegressionEvaluator @Since("1.4.0") 
(@Since("1.4.0") override val ui
   @Since("1.4.0")
   override def evaluate(dataset: DataFrame): Double = {
     val schema = dataset.schema
-    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    val predictionType = schema($(predictionCol)).dataType
+    require(predictionType == FloatType || predictionType == DoubleType)
+    val labelType = schema($(labelCol)).dataType
+    require(labelType == FloatType || labelType == DoubleType)
 
-    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
+    val predictionAndLabels = dataset
+      .select(col($(predictionCol)).cast(DoubleType), 
col($(labelCol)).cast(DoubleType))
       .map { case Row(prediction: Double, label: Double) =>
         (prediction, label)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to