Repository: commons-math Updated Branches: refs/heads/master 2fd6c8fa1 -> 6c4e1d719
MATH-1278 Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D". Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/6c4e1d71 Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/6c4e1d71 Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/6c4e1d71 Branch: refs/heads/master Commit: 6c4e1d719fec98f04f0d80d2ff79dbfc2861bfaf Parents: 2fd6c8f Author: Gilles <er...@apache.org> Authored: Sun Sep 20 22:02:21 2015 +0200 Committer: Gilles <er...@apache.org> Committed: Sun Sep 20 22:02:21 2015 +0200 ---------------------------------------------------------------------- src/changes/changes.xml | 4 ++ .../commons/math4/ml/neuralnet/Network.java | 24 ++++++++++ .../commons/math4/ml/neuralnet/Neuron.java | 16 +++++++ .../ml/neuralnet/twod/NeuronSquareMesh2D.java | 48 ++++++++++++++++++++ .../commons/math4/ml/neuralnet/NetworkTest.java | 41 +++++++++++++++++ .../commons/math4/ml/neuralnet/NeuronTest.java | 26 +++++++++++ 6 files changed, 159 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 90f3ee3..1b51e95 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces! </release> <release version="4.0" date="XXXX-XX-XX" description=""> + <action dev="erans" type="add" issue="MATH-1278"> <!-- backported to 3.6 --> + Deep copy of "Network" (package "o.a.c.m.ml.neuralnet") to allow evaluation of + intermediate states during training. 
+ </action> <action dev="oertl" type="update" issue="MATH-1276"> <!-- backported to 3.6 --> Improved performance of sampling and inverse cumulative probability calculation for geometric distributions. http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java index 5520063..a223929 100644 --- a/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java +++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java @@ -28,6 +28,7 @@ import java.util.Collection; import java.util.Iterator; import java.util.Comparator; import java.util.Collections; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; @@ -137,6 +138,29 @@ public class Network } /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. 
+ */ + public synchronized Network copy() { + final Network copy = new Network(nextId.get(), + featureSize); + + + for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { + copy.neuronMap.put(e.getKey(), e.getValue().copy()); + } + + for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { + copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue())); + } + + return copy; + } + + /** * {@inheritDoc} */ @Override http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java index eee8151..b45bfa9 100644 --- a/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java +++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java @@ -67,6 +67,22 @@ public class Neuron implements Serializable { } /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + */ + public synchronized Neuron copy() { + final Neuron copy = new Neuron(getIdentifier(), + getFeatures()); + copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get()); + copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get()); + + return copy; + } + + /** * Gets the neuron's identifier. * * @return the identifier. 
http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java index 2ad150b..56ac1e2 100644 --- a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java +++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java @@ -198,6 +198,54 @@ public class NeuronSquareMesh2D createLinks(); } + /** + * Constructor with restricted access, solely used for making a + * {@link #copy() deep copy}. + * + * @param wrapRowDim Whether to wrap the first dimension (i.e the first + * and last neurons will be linked together). + * @param wrapColDim Whether to wrap the second dimension (i.e the first + * and last neurons will be linked together). + * @param neighbourhoodType Neighbourhood type. + * @param net Underlying network. + * @param idGrid Neuron identifiers. + */ + private NeuronSquareMesh2D(boolean wrapRowDim, + boolean wrapColDim, + SquareNeighbourhood neighbourhoodType, + Network net, + long[][] idGrid) { + numberOfRows = idGrid.length; + numberOfColumns = idGrid[0].length; + wrapRows = wrapRowDim; + wrapColumns = wrapColDim; + neighbourhood = neighbourhoodType; + network = net; + identifiers = idGrid; + } + + /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. 
+ */ + public synchronized NeuronSquareMesh2D copy() { + final long[][] idGrid = new long[numberOfRows][numberOfColumns]; + for (int r = 0; r < numberOfRows; r++) { + for (int c = 0; c < numberOfColumns; c++) { + idGrid[r][c] = identifiers[r][c]; + } + } + + return new NeuronSquareMesh2D(wrapRows, + wrapColumns, + neighbourhood, + network.copy(), + idGrid); + } + /** {@inheritDoc} */ @Override public Iterator<Neuron> iterator() { http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java index 4f81eb6..7163e0d 100644 --- a/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java +++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java @@ -132,6 +132,47 @@ public class NetworkTest { Assert.assertFalse(isUnspecifiedOrder); } + /* + * Test assumes that the network is + * + * 0-----1 + * | | + * | | + * 2-----3 + */ + @Test + public void testCopy() { + final FeatureInitializer[] initArray = { init }; + final Network net = new NeuronSquareMesh2D(2, false, + 2, false, + SquareNeighbourhood.VON_NEUMANN, + initArray).getNetwork(); + + final Network copy = net.copy(); + + final Neuron netNeuron0 = net.getNeuron(0); + final Neuron copyNeuron0 = copy.getNeuron(0); + final Neuron netNeuron1 = net.getNeuron(1); + final Neuron copyNeuron1 = copy.getNeuron(1); + Collection<Neuron> netNeighbours; + Collection<Neuron> copyNeighbours; + + // Check that both networks have the same connections. + netNeighbours = net.getNeighbours(netNeuron0); + copyNeighbours = copy.getNeighbours(copyNeuron0); + Assert.assertTrue(netNeighbours.contains(netNeuron1)); + Assert.assertTrue(copyNeighbours.contains(copyNeuron1)); + + // Delete neuron 1 from original. 
+ net.deleteNeuron(netNeuron1); + + // Check that the networks now differ. + netNeighbours = net.getNeighbours(netNeuron0); + copyNeighbours = copy.getNeighbours(copyNeuron0); + Assert.assertFalse(netNeighbours.contains(netNeuron1)); + Assert.assertTrue(copyNeighbours.contains(copyNeuron1)); + } + @Test public void testSerialize() throws IOException, http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java index 92ec4bf..64f7013 100644 --- a/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java +++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java @@ -88,6 +88,32 @@ public class NeuronTest { } @Test + public void testCopy() { + final Neuron n = new Neuron(1, new double[] { 9.87 }); + + // Update original. + double[] update = new double[] { n.getFeatures()[0] + 2.34 }; + n.compareAndSetFeatures(n.getFeatures(), update); + + // Create a copy. + final Neuron copy = n.copy(); + + // Check that original and copy have the same value. + Assert.assertTrue(n.getFeatures()[0] == copy.getFeatures()[0]); + Assert.assertEquals(n.getNumberOfAttemptedUpdates(), + copy.getNumberOfAttemptedUpdates()); + + // Update original. + update = new double[] { 1.23 * n.getFeatures()[0] }; + n.compareAndSetFeatures(n.getFeatures(), update); + + // Check that original and copy differ. + Assert.assertFalse(n.getFeatures()[0] == copy.getFeatures()[0]); + Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(), + copy.getNumberOfSuccessfulUpdates()); + } + + @Test public void testSerialize() throws IOException, ClassNotFoundException {