New private method for factoring out some common code.
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/1c4ebd5b Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/1c4ebd5b Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/1c4ebd5b Branch: refs/heads/MATH_3_X Commit: 1c4ebd5bde1e1e597803d53be0ecd52fd1de5d28 Parents: 895f50e Author: Gilles <er...@apache.org> Authored: Sun Jul 19 00:07:03 2015 +0200 Committer: Gilles <er...@apache.org> Committed: Sun Jul 19 20:42:19 2015 +0200 ---------------------------------------------------------------------- .../ml/neuralnet/sofm/KohonenUpdateAction.java | 42 +++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/1c4ebd5b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java index 91587bf..e6f2305 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/sofm/KohonenUpdateAction.java @@ -20,13 +20,14 @@ package org.apache.commons.math3.ml.neuralnet.sofm; import java.util.Collection; import java.util.HashSet; import java.util.concurrent.atomic.AtomicLong; -import org.apache.commons.math3.ml.neuralnet.Network; + +import org.apache.commons.math3.analysis.function.Gaussian; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.neuralnet.MapUtils; +import org.apache.commons.math3.ml.neuralnet.Network; import org.apache.commons.math3.ml.neuralnet.Neuron; import 
org.apache.commons.math3.ml.neuralnet.UpdateAction; -import org.apache.commons.math3.ml.distance.DistanceMeasure; -import org.apache.commons.math3.linear.ArrayRealVector; -import org.apache.commons.math3.analysis.function.Gaussian; /** * Update formula for <a href="http://en.wikipedia.org/wiki/Kohonen"> @@ -91,6 +92,7 @@ public class KohonenUpdateAction implements UpdateAction { /** * {@inheritDoc} */ + @Override public void update(Network net, double[] features) { final long numCalls = numberOfCalls.incrementAndGet(); @@ -144,6 +146,26 @@ public class KohonenUpdateAction implements UpdateAction { } /** + * Tries to update a neuron. + * + * @param n Neuron to be updated. + * @param features Training data. + * @param learningRate Learning factor. + * @return {@code true} if the update succeeded, {@code false} if a + * concurrent update has been detected. + */ + private boolean attemptNeuronUpdate(Neuron n, + double[] features, + double learningRate) { + final double[] expect = n.getFeatures(); + final double[] update = computeFeatures(expect, + features, + learningRate); + + return n.compareAndSetFeatures(expect, update); + } + + /** * Atomically updates the given neuron. * * @param n Neuron to be updated. 
@@ -154,11 +176,7 @@ public class KohonenUpdateAction implements UpdateAction { double[] features, double learningRate) { while (true) { - final double[] expect = n.getFeatures(); - final double[] update = computeFeatures(expect, - features, - learningRate); - if (n.compareAndSetFeatures(expect, update)) { + if (attemptNeuronUpdate(n, features, learningRate)) { break; } } @@ -179,11 +197,7 @@ public class KohonenUpdateAction implements UpdateAction { while (true) { final Neuron best = MapUtils.findBest(features, net, distance); - final double[] expect = best.getFeatures(); - final double[] update = computeFeatures(expect, - features, - learningRate); - if (best.compareAndSetFeatures(expect, update)) { + if (attemptNeuronUpdate(best, features, learningRate)) { return best; }