Repository: spark
Updated Branches:
  refs/heads/master 6b88825a2 -> dc0c4490a


http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
new file mode 100644
index 0000000..d679085
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, 
PredictorParams}
+
+/**
+ * :: DeveloperApi ::
+ * Params for regression.
+ * Currently adds nothing beyond [[PredictorParams]], but may add functionality later.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait RegressorParams extends PredictorParams
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label regression: a [[Predictor]] whose fitted model is a [[RegressionModel]].
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam Learner  Concrete Estimator type; must itself extend this class.
+ * @tparam M  Concrete Model type; must extend [[RegressionModel]].
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Regressor[
+    FeaturesType,
+    Learner <: Regressor[FeaturesType, Learner, M],
+    M <: RegressionModel[FeaturesType, M]]
+  extends Predictor[FeaturesType, Learner, M]
+  with RegressorParams {
+
+  // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[Regressor]].
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam M  Concrete Model type.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+  extends PredictionModel[FeaturesType, M] with RegressorParams {
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Predict real-valued label for the given features.
+   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+   *
+   * @param features  Features to predict a label for.
+   * @return Predicted label.
+   */
+  @DeveloperApi
+  protected def predict(features: FeaturesType): Double
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 77785bd..480bbfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
 
 import org.apache.spark.SparkException
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
@@ -110,9 +111,14 @@ sealed trait Vector extends Serializable {
 }
 
 /**
+ * :: DeveloperApi ::
+ *
  * User-defined type for [[Vector]] which allows easy interaction with SQL
  * via [[org.apache.spark.sql.DataFrame]].
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
+@DeveloperApi
 private[spark] class VectorUDT extends UserDefinedType[Vector] {
 
   override def sqlType: StructType = {
@@ -169,6 +175,13 @@ private[spark] class VectorUDT extends 
UserDefinedType[Vector] {
   override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
 
   override def userClass: Class[Vector] = classOf[Vector]
+
+  /**
+   * Any two [[VectorUDT]] instances compare equal; everything else is unequal.
+   */
+  override def equals(o: Any): Boolean = {
+    o match {
+      case _: VectorUDT => true
+      case _ => false
+    }
+  }
+
+  /**
+   * Constant hash shared by every instance, so that `hashCode` stays consistent with
+   * `equals` (equal objects must have equal hash codes).
+   */
+  override def hashCode(): Int = classOf[VectorUDT].getName.hashCode
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 56a9dbd..50995ff 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -65,7 +65,7 @@ public class JavaPipelineSuite {
       .setStages(new PipelineStage[] {scaler, lr});
     PipelineModel model = pipeline.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM 
prediction");
+    DataFrame predictions = jsql.sql("SELECT label, probability, prediction 
FROM prediction");
     predictions.collectAsList();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f4ba23c..2628402 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -18,17 +18,22 @@
 package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
+import java.lang.Math;
 import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
-import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.sql.Row;
+
 
 public class JavaLogisticRegressionSuite implements Serializable {
 
@@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
   private transient SQLContext jsql;
   private transient DataFrame dataset;
 
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+  private double eps = 1e-5;
+
   @Before
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
     jsql = new SQLContext(jsc);
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
-    dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+    datasetRDD = jsc.parallelize(points, 2);
+    dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+    dataset.registerTempTable("dataset");
   }
 
   @After
@@ -51,29 +61,88 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
   }
 
   @Test
-  public void logisticRegression() {
+  public void logisticRegressionDefaultParams() {
     LogisticRegression lr = new LogisticRegression();
+    assert(lr.getLabelCol().equals("label"));
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM 
prediction");
+    DataFrame predictions = jsql.sql("SELECT label, probability, prediction 
FROM prediction");
     predictions.collectAsList();
+    // Check defaults
+    assert(model.getThreshold() == 0.5);
+    assert(model.getFeaturesCol().equals("features"));
+    assert(model.getPredictionCol().equals("prediction"));
+    assert(model.getProbabilityCol().equals("probability"));
   }
 
   @Test
   public void logisticRegressionWithSetters() {
+    // Set params, train, and check as many params as we can.
     LogisticRegression lr = new LogisticRegression()
       .setMaxIter(10)
-      .setRegParam(1.0);
+      .setRegParam(1.0)
+      .setThreshold(0.6)
+      .setProbabilityCol("myProbability");
     LogisticRegressionModel model = lr.fit(dataset);
-    model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
-      .registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM 
prediction");
-    predictions.collectAsList();
+    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+    assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
+    assert(model.getThreshold() == 0.6);
+
+    // Modify model params, and check that the params worked.
+    model.setThreshold(1.0);
+    model.transform(dataset).registerTempTable("predAllZero");
+    DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM 
predAllZero");
+    for (Row r: predAllZero.collectAsList()) {
+      assert(r.getDouble(0) == 0.0);
+    }
+    // Call transform with params, and check that the params worked.
+    model.transform(dataset, model.threshold().w(0.0), 
model.probabilityCol().w("myProb"))
+      .registerTempTable("predNotAllZero");
+    DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM 
predNotAllZero");
+    boolean foundNonZero = false;
+    for (Row r: predNotAllZero.collectAsList()) {
+      if (r.getDouble(0) != 0.0) foundNonZero = true;
+    }
+    assert(foundNonZero);
+
+    // Call fit() with new params, and check as many params as we can.
+    LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), 
lr.regParam().w(0.1),
+        lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+    assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
+    assert(model2.getThreshold() == 0.4);
+    assert(model2.getProbabilityCol().equals("theProb"));
   }
 
+  @SuppressWarnings("unchecked")
   @Test
-  public void logisticRegressionFitWithVarargs() {
+  public void logisticRegressionPredictorClassifierMethods() {
     LogisticRegression lr = new LogisticRegression();
-    lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+    LogisticRegressionModel model = lr.fit(dataset);
+    assert(model.numClasses() == 2);
+
+    model.transform(dataset).registerTempTable("transformed");
+    DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM 
transformed");
+    for (Row row: trans1.collect()) {
+      Vector raw = (Vector)row.get(0);
+      Vector prob = (Vector)row.get(1);
+      assert(raw.size() == 2);
+      assert(prob.size() == 2);
+      double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
+      assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
+      assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+    }
+
+    DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM 
transformed");
+    for (Row row: trans2.collect()) {
+      double pred = row.getDouble(0);
+      Vector prob = (Vector)row.get(1);
+      double probOfPred = prob.apply((int)pred);
+      for (int i = 0; i < prob.size(); ++i) {
+        assert(probOfPred >= prob.apply(i));
+      }
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
new file mode 100644
index 0000000..5bd616e
--- /dev/null
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+    .generateLogisticInputAsList;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * Java API tests for LinearRegression: default params, setters, and per-call overrides.
+ *
+ * Checks use JUnit's Assert rather than the Java `assert` keyword, because `assert`
+ * statements are skipped unless the JVM is started with assertions enabled (-ea).
+ */
+public class JavaLinearRegressionSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+
+  /** Creates a local context and registers the generated points as temp table "dataset". */
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+    jsql = new SQLContext(jsc);
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+    datasetRDD = jsc.parallelize(points, 2);
+    dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+    dataset.registerTempTable("dataset");
+  }
+
+  /** Stops the context created in setUp(). */
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  /** Fits with default params and checks the default column names. */
+  @Test
+  public void linearRegressionDefaultParams() {
+    LinearRegression lr = new LinearRegression();
+    Assert.assertEquals("label", lr.getLabelCol());
+    LinearRegressionModel model = lr.fit(dataset);
+    model.transform(dataset).registerTempTable("prediction");
+    DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+    predictions.collect();
+    // Check defaults
+    Assert.assertEquals("features", model.getFeaturesCol());
+    Assert.assertEquals("prediction", model.getPredictionCol());
+  }
+
+  /** Checks that params set via setters and via fit() arguments reach the fitted model. */
+  @Test
+  public void linearRegressionWithSetters() {
+    // Set params, train, and check as many params as we can.
+    LinearRegression lr = new LinearRegression()
+        .setMaxIter(10)
+        .setRegParam(1.0);
+    LinearRegressionModel model = lr.fit(dataset);
+    Assert.assertTrue(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    Assert.assertTrue(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+
+    // Call fit() with new params, and check as many params as we can.
+    LinearRegressionModel model2 =
+        lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+    Assert.assertTrue(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    Assert.assertTrue(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+    Assert.assertEquals("thePred", model2.getPredictionCol());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 33e40dc..b3d1bfc 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,44 +20,108 @@ package org.apache.spark.ml.classification
 import org.scalatest.FunSuite
 
 import 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
 
 class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
   @transient var sqlContext: SQLContext = _
   @transient var dataset: DataFrame = _
+  private val eps: Double = 1e-5
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
-      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+      sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 
42), 2))
   }
 
-  test("logistic regression") {
+  test("logistic regression: default params") {
     val lr = new LogisticRegression
+    assert(lr.getLabelCol == "label")
+    assert(lr.getFeaturesCol == "features")
+    assert(lr.getPredictionCol == "prediction")
+    assert(lr.getRawPredictionCol == "rawPrediction")
+    assert(lr.getProbabilityCol == "probability")
     val model = lr.fit(dataset)
     model.transform(dataset)
-      .select("label", "prediction")
+      .select("label", "probability", "prediction", "rawPrediction")
       .collect()
+    assert(model.getThreshold === 0.5)
+    assert(model.getFeaturesCol == "features")
+    assert(model.getPredictionCol == "prediction")
+    assert(model.getRawPredictionCol == "rawPrediction")
+    assert(model.getProbabilityCol == "probability")
   }
 
   test("logistic regression with setters") {
+    // Set params, train, and check as many params as we can.
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
+      .setThreshold(0.6)
+      .setProbabilityCol("myProbability")
     val model = lr.fit(dataset)
-    model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
-      .select("label", "score", "prediction")
+    assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+    assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+    assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+    assert(model.getThreshold === 0.6)
+
+    // Modify model params, and check that the params worked.
+    model.setThreshold(1.0)
+    val predAllZero = model.transform(dataset)
+      .select("prediction", "myProbability")
       .collect()
+      .map { case Row(pred: Double, prob: Vector) => pred }
+    assert(predAllZero.forall(_ === 0),
+      s"With threshold=1.0, expected predictions to be all 0, but only" +
+      s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
+    // Call transform with params, and check that the params worked.
+    val predNotAllZero =
+      model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> 
"myProb")
+        .select("prediction", "myProb")
+        .collect()
+        .map { case Row(pred: Double, prob: Vector) => pred }
+    assert(predNotAllZero.exists(_ !== 0.0))
+
+    // Call fit() with new params, and check as many params as we can.
+    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, 
lr.threshold -> 0.4,
+      lr.probabilityCol -> "theProb")
+    assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+    assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+    assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+    assert(model2.getThreshold === 0.4)
+    assert(model2.getProbabilityCol == "theProb")
   }
 
-  test("logistic regression fit and transform with varargs") {
+  test("logistic regression: Predictor, Classifier methods") {
+    val sqlContext = this.sqlContext
     val lr = new LogisticRegression
-    val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
-    model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> 
"probability")
-      .select("label", "probability", "prediction")
-      .collect()
+
+    val model = lr.fit(dataset)
+    assert(model.numClasses === 2)
+
+    val threshold = model.getThreshold
+    val results = model.transform(dataset)
+
+    // Compare rawPrediction with probability
+    results.select("rawPrediction", "probability").collect().map {
+      case Row(raw: Vector, prob: Vector) =>
+        assert(raw.size === 2)
+        assert(prob.size === 2)
+        val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+        assert(prob(1) ~== probFromRaw1 relTol eps)
+        assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
+    }
+
+    // Compare prediction with probability
+    results.select("prediction", "probability").collect().map {
+      case Row(pred: Double, prob: Vector) =>
+        val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+        assert(pred == predFromProb)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
new file mode 100644
index 0000000..bbb44c3
--- /dev/null
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+/**
+ * Tests for [[LinearRegression]]: default column params, setters, and params passed to fit().
+ */
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    // 100 generated points (seed 42) in 2 partitions, shared by every test below.
+    dataset = sqlContext.createDataFrame(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+  }
+
+  test("linear regression: default params") {
+    val lr = new LinearRegression
+    assert(lr.getLabelCol == "label")
+    val model = lr.fit(dataset)
+    model.transform(dataset)
+      .select("label", "prediction")
+      .collect()
+    // Check defaults
+    assert(model.getFeaturesCol == "features")
+    assert(model.getPredictionCol == "prediction")
+  }
+
+  test("linear regression with setters") {
+    // Set params, train, and check as many as we can.
+    val lr = new LinearRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0)
+    val model = lr.fit(dataset)
+    // Compare against Some(...) rather than calling Option.get, so a missing param fails the
+    // assertion with a readable message instead of throwing NoSuchElementException.
+    assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+    assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+
+    // Call fit() with new params, and check as many as we can.
+    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
+    assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
+    assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
+    assert(model2.getPredictionCol == "thePred")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/dc0c4490/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b17532c..4065a56 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,7 @@ object MimaExcludes {
         case v if v.startsWith("1.3") =>
           Seq(
             MimaBuild.excludeSparkPackage("deploy"),
+            MimaBuild.excludeSparkPackage("ml"),
             // These are needed if checking against the sbt build, since they are part of
             // the maven-generated artifacts in the 1.2 build.
             MimaBuild.excludeSparkPackage("unused"),
@@ -142,6 +143,11 @@ object MimaExcludes {
               "org.apache.spark.graphx.Graph.getCheckpointFiles"),
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.graphx.Graph.isCheckpointed")
+          ) ++ Seq(
+            // SPARK-4789 Standardize ML Prediction APIs
+            
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"),
+            
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
+            
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType")
           )
 
         case v if v.startsWith("1.2") =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to