Repository: spark
Updated Branches:
  refs/heads/master 3a8b698e9 -> 6fc76e49c


Initialized the regVal for first iteration in SGD optimizer

Ported from https://github.com/apache/incubator-spark/pull/633

In runMiniBatchSGD, the regVal for the first iteration should be initialized
from the initial weights: for the L2 update it is half the sum of the squared
weights times regParam (0.5 * regParam * ||w||^2); for the L1 update, the same
logic is followed.

It may not be important here for SGD, since the updater doesn't take the loss
as a parameter to find the new weights, but it will give us the correct history
of the loss.
However, for the LBFGS optimizer we implemented, the correct loss with regVal
is crucial to find the new weights.

Author: DB Tsai <[email protected]>

Closes #40 from dbtsai/dbtsai-smallRegValFix and squashes the following commits:

77d47da [DB Tsai] In runMiniBatchSGD, the regVal (for 1st iter) should be 
initialized as half the sum of squared weights (times regParam) if it's L2 
update; for L1 update, the same logic is followed.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6fc76e49
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6fc76e49
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6fc76e49

Branch: refs/heads/master
Commit: 6fc76e49c19310ec0d6cdf4754271ad09d652576
Parents: 3a8b698
Author: DB Tsai <[email protected]>
Authored: Sun Mar 2 00:31:59 2014 -0800
Committer: Reynold Xin <[email protected]>
Committed: Sun Mar 2 00:31:59 2014 -0800

----------------------------------------------------------------------
 .../mllib/optimization/GradientDescent.scala    |  8 +++-
 .../spark/mllib/optimization/Updater.scala      |  2 +
 .../optimization/GradientDescentSuite.scala     | 41 ++++++++++++++++++++
 3 files changed, 50 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6fc76e49/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 8e87b98..b967b22 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -149,7 +149,13 @@ object GradientDescent extends Logging {
 
     // Initialize weights as a column vector
     var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
-    var regVal = 0.0
+
+    /**
+     * For the first iteration, regVal is computed from the initial weights:
+     * 0.5 * regParam * sum of squared weights for L2 update; likewise for L1.
+     */
+    var regVal = updater.compute(
+      weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2
 
     for (i <- 1 to numIterations) {
       // Sample a subset (fraction miniBatchFraction) of the total data

http://git-wip-us.apache.org/repos/asf/spark/blob/6fc76e49/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 889a03e..bf8f731 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -111,6 +111,8 @@ class SquaredL2Updater extends Updater {
     val step = gradient.mul(thisIterStepSize)
     // add up both updates from the gradient of the loss (= step) as well as
     // the gradient of the regularizer (= regParam * weightsOld)
+    // w' = w - thisIterStepSize * (gradient + regParam * w)
+    // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
     val newWeights = weightsOld.mul(1.0 - thisIterStepSize * 
regParam).sub(step)
     (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6fc76e49/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a453de6..631d0e2 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -104,4 +104,45 @@ class GradientDescentSuite extends FunSuite with 
LocalSparkContext with ShouldMa
     val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs 
}
     assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8)
   }
+
+  test("Test the loss and gradient of first iteration with regularization.") {
+
+    val gradient = new LogisticGradient()
+    val updater = new SquaredL2Updater()
+
+    // Add an extra variable consisting of all 1.0's for the intercept.
+    val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
+    val data = testData.map { case LabeledPoint(label, features) =>
+      label -> Array(1.0, features: _*)
+    }
+
+    val dataRDD = sc.parallelize(data, 2).cache()
+
+    // Prepare non-zero weights
+    val initialWeightsWithIntercept = Array(1.0, 0.5)
+
+    val regParam0 = 0
+    val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
+      dataRDD, gradient, updater, 1, 1, regParam0, 1.0, 
initialWeightsWithIntercept)
+
+    val regParam1 = 1
+    val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
+      dataRDD, gradient, updater, 1, 1, regParam1, 1.0, 
initialWeightsWithIntercept)
+
+    def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
+      math.abs(x - y) / (math.abs(y) + 1e-15) < tol
+    }
+
+    assert(compareDouble(
+      loss1(0),
+      loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
+        math.pow(initialWeightsWithIntercept(1), 2)) / 2),
+      """For non-zero weights, the regVal should be \frac{1}{2}\sum_i 
w_i^2.""")
+
+    assert(
+      compareDouble(newWeights1(0) , newWeights0(0) - 
initialWeightsWithIntercept(0)) &&
+      compareDouble(newWeights1(1) , newWeights0(1) - 
initialWeightsWithIntercept(1)),
+      "The different between newWeights with/without regularization " +
+        "should be initialWeightsWithIntercept.")
+  }
 }

Reply via email to