MATH-1278 Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D".
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f13693fd Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f13693fd Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f13693fd Branch: refs/heads/MATH_3_X Commit: f13693fdc3cd00c6acbcd51672076e38a333778c Parents: 78b9d81 Author: Gilles <er...@apache.org> Authored: Sun Sep 20 21:45:31 2015 +0200 Committer: Gilles <er...@apache.org> Committed: Sun Sep 20 21:45:31 2015 +0200 ---------------------------------------------------------------------- .../commons/math3/ml/neuralnet/Network.java | 24 ++++++++++ .../commons/math3/ml/neuralnet/Neuron.java | 16 +++++++ .../ml/neuralnet/twod/NeuronSquareMesh2D.java | 48 ++++++++++++++++++++ .../commons/math3/ml/neuralnet/NetworkTest.java | 41 +++++++++++++++++ .../commons/math3/ml/neuralnet/NeuronTest.java | 26 +++++++++++ 5 files changed, 155 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java index 6c4b8e9..70d8bb2 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java @@ -28,6 +28,7 @@ import java.util.Collection; import java.util.Iterator; import java.util.Comparator; import java.util.Collections; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import org.apache.commons.math3.exception.DimensionMismatchException; @@ -135,6 +136,29 @@ public class Network } /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + */ + public synchronized Network copy() { + final Network copy = new Network(nextId.get(), + featureSize); + + + for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { + copy.neuronMap.put(e.getKey(), e.getValue().copy()); + } + + for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { + copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue())); + } + + return copy; + } + + /** * {@inheritDoc} */ public Iterator<Neuron> iterator() { http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java index 3fd0c0a..300fa50 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java @@ -67,6 +67,22 @@ public class Neuron implements Serializable { } /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + */ + public synchronized Neuron copy() { + final Neuron copy = new Neuron(getIdentifier(), + getFeatures()); + copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get()); + copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get()); + + return copy; + } + + /** * Gets the neuron's identifier. * * @return the identifier. http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java index 2f4dd2d..d1c692e 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java @@ -197,6 +197,54 @@ public class NeuronSquareMesh2D createLinks(); } + /** + * Constructor with restricted access, solely used for making a + * {@link #copy() deep copy}. + * + * @param wrapRowDim Whether to wrap the first dimension (i.e the first + * and last neurons will be linked together). + * @param wrapColDim Whether to wrap the second dimension (i.e the first + * and last neurons will be linked together). + * @param neighbourhoodType Neighbourhood type. + * @param net Underlying network. + * @param idGrid Neuron identifiers. + */ + private NeuronSquareMesh2D(boolean wrapRowDim, + boolean wrapColDim, + SquareNeighbourhood neighbourhoodType, + Network net, + long[][] idGrid) { + numberOfRows = idGrid.length; + numberOfColumns = idGrid[0].length; + wrapRows = wrapRowDim; + wrapColumns = wrapColDim; + neighbourhood = neighbourhoodType; + network = net; + identifiers = idGrid; + } + + /** + * Performs a deep copy of this instance. + * Upon return, the copied and original instances will be independent: + * Updating one will not affect the other. + * + * @return a new instance with the same state as this instance. + */ + public synchronized NeuronSquareMesh2D copy() { + final long[][] idGrid = new long[numberOfRows][numberOfColumns]; + for (int r = 0; r < numberOfRows; r++) { + for (int c = 0; c < numberOfColumns; c++) { + idGrid[r][c] = identifiers[r][c]; + } + } + + return new NeuronSquareMesh2D(wrapRows, + wrapColumns, + neighbourhood, + network.copy(), + idGrid); + } + /** {@inheritDoc} */ public Iterator<Neuron> iterator() { return network.iterator(); http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java index 7f2bec9..aa83196 100644 --- a/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java +++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java @@ -127,6 +127,47 @@ public class NetworkTest { Assert.assertFalse(isUnspecifiedOrder); } + /* + * Test assumes that the network is + * + * 0-----1 + * | | + * | | + * 2-----3 + */ + @Test + public void testCopy() { + final FeatureInitializer[] initArray = { init }; + final Network net = new NeuronSquareMesh2D(2, false, + 2, false, + SquareNeighbourhood.VON_NEUMANN, + initArray).getNetwork(); + + final Network copy = net.copy(); + + final Neuron netNeuron0 = net.getNeuron(0); + final Neuron copyNeuron0 = copy.getNeuron(0); + final Neuron netNeuron1 = net.getNeuron(1); + final Neuron copyNeuron1 = copy.getNeuron(1); + Collection<Neuron> netNeighbours; + Collection<Neuron> copyNeighbours; + + // Check that both networks have the same connections. + netNeighbours = net.getNeighbours(netNeuron0); + copyNeighbours = copy.getNeighbours(copyNeuron0); + Assert.assertTrue(netNeighbours.contains(netNeuron1)); + Assert.assertTrue(copyNeighbours.contains(copyNeuron1)); + + // Delete neuron 1 from original. + net.deleteNeuron(netNeuron1); + + // Check that the networks now differ. + netNeighbours = net.getNeighbours(netNeuron0); + copyNeighbours = copy.getNeighbours(copyNeuron0); + Assert.assertFalse(netNeighbours.contains(netNeuron1)); + Assert.assertTrue(copyNeighbours.contains(copyNeuron1)); + } + @Test public void testSerialize() throws IOException, http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java index b03f07d..376d91c 100644 --- a/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java +++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java @@ -86,6 +86,32 @@ public class NeuronTest { } @Test + public void testCopy() { + final Neuron n = new Neuron(1, new double[] { 9.87 }); + + // Update original. + double[] update = new double[] { n.getFeatures()[0] + 2.34 }; + n.compareAndSetFeatures(n.getFeatures(), update); + + // Create a copy. + final Neuron copy = n.copy(); + + // Check that original and copy have the same value. + Assert.assertTrue(n.getFeatures()[0] == copy.getFeatures()[0]); + Assert.assertEquals(n.getNumberOfAttemptedUpdates(), + copy.getNumberOfAttemptedUpdates()); + + // Update original. + update = new double[] { 1.23 * n.getFeatures()[0] }; + n.compareAndSetFeatures(n.getFeatures(), update); + + // Check that original and copy differ. + Assert.assertFalse(n.getFeatures()[0] == copy.getFeatures()[0]); + Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(), + copy.getNumberOfSuccessfulUpdates()); + } + + @Test public void testSerialize() throws IOException, ClassNotFoundException {