Repository: commons-math
Updated Branches:
  refs/heads/master 2fd6c8fa1 -> 6c4e1d719


MATH-1278

Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D".


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/6c4e1d71
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/6c4e1d71
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/6c4e1d71

Branch: refs/heads/master
Commit: 6c4e1d719fec98f04f0d80d2ff79dbfc2861bfaf
Parents: 2fd6c8f
Author: Gilles <er...@apache.org>
Authored: Sun Sep 20 22:02:21 2015 +0200
Committer: Gilles <er...@apache.org>
Committed: Sun Sep 20 22:02:21 2015 +0200

----------------------------------------------------------------------
 src/changes/changes.xml                         |  4 ++
 .../commons/math4/ml/neuralnet/Network.java     | 24 ++++++++++
 .../commons/math4/ml/neuralnet/Neuron.java      | 16 +++++++
 .../ml/neuralnet/twod/NeuronSquareMesh2D.java   | 48 ++++++++++++++++++++
 .../commons/math4/ml/neuralnet/NetworkTest.java | 41 +++++++++++++++++
 .../commons/math4/ml/neuralnet/NeuronTest.java  | 26 +++++++++++
 6 files changed, 159 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 90f3ee3..1b51e95 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces!
     </release>
 
     <release version="4.0" date="XXXX-XX-XX" description="">
+      <action dev="erans" type="add" issue="MATH-1278"> <!-- backported to 3.6 -->
+        Deep copy of "Network" (package "o.a.c.m.ml.neuralnet") to allow evaluation of
+        intermediate states during training.
+      </action>
       <action dev="oertl" type="update" issue="MATH-1276"> <!-- backported to 3.6 -->
         Improved performance of sampling and inverse cumulative probability calculation
         for geometric distributions.

http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java
index 5520063..a223929 100644
--- a/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java
+++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/Network.java
@@ -28,6 +28,7 @@ import java.util.Collection;
 import java.util.Iterator;
 import java.util.Comparator;
 import java.util.Collections;
+import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
@@ -137,6 +138,29 @@ public class Network
     }
 
     /**
+     * Creates a deep copy of this network.
+     * Neurons and link sets are duplicated, so that once this method
+     * returns, modifying either network leaves the other unchanged.
+     *
+     * @return a new, independent network with the same state as this one.
+     */
+    public synchronized Network copy() {
+        final Network duplicate = new Network(nextId.get(), featureSize);
+
+        // Every neuron is duplicated through its own copy() method.
+        for (final Map.Entry<Long, Neuron> entry : neuronMap.entrySet()) {
+            duplicate.neuronMap.put(entry.getKey(), entry.getValue().copy());
+        }
+
+        // Every link set gets its own storage in the duplicate.
+        for (final Map.Entry<Long, Set<Long>> entry : linkMap.entrySet()) {
+            duplicate.linkMap.put(entry.getKey(), new HashSet<Long>(entry.getValue()));
+        }
+
+        return duplicate;
+    }
+
+    /**
      * {@inheritDoc}
      */
     @Override

http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java
index eee8151..b45bfa9 100644
--- a/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java
+++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/Neuron.java
@@ -67,6 +67,22 @@ public class Neuron implements Serializable {
     }
 
     /**
+     * Creates a deep copy of this neuron, carrying over its identifier,
+     * features and update counters. Upon return, the two instances are
+     * independent: updating one does not affect the other.
+     *
+     * @return a new neuron with the same state as this one.
+     */
+    public synchronized Neuron copy() {
+        final Neuron duplicate = new Neuron(getIdentifier(), getFeatures());
+        // The update statistics are part of the state: carry them over too.
+        duplicate.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
+        duplicate.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());
+
+        return duplicate;
+    }
+
+    /**
      * Gets the neuron's identifier.
      *
      * @return the identifier.

http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
index 2ad150b..56ac1e2 100644
--- a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
+++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
@@ -198,6 +198,54 @@ public class NeuronSquareMesh2D
         createLinks();
     }
 
+    /**
+     * Assembles a mesh from already existing parts; reserved for
+     * {@link #copy()}, which passes in freshly copied arguments.
+     *
+     * @param wrapRowDim Whether the first dimension is wrapped (i.e. the
+     * first and last neurons are linked together).
+     * @param wrapColDim Whether the second dimension is wrapped (i.e. the
+     * first and last neurons are linked together).
+     * @param neighbourhoodType Neighbourhood type.
+     * @param net Underlying network (stored as is, not copied).
+     * @param idGrid Neuron identifiers (stored as is, not copied).
+     */
+    private NeuronSquareMesh2D(boolean wrapRowDim,
+                               boolean wrapColDim,
+                               SquareNeighbourhood neighbourhoodType,
+                               Network net,
+                               long[][] idGrid) {
+        network = net;
+        identifiers = idGrid;
+        numberOfRows = idGrid.length;
+        numberOfColumns = idGrid[0].length;
+        wrapRows = wrapRowDim;
+        wrapColumns = wrapColDim;
+        neighbourhood = neighbourhoodType;
+    }
+
+    /**
+     * Performs a deep copy of this instance.
+     * Upon return, the copied and original instances will be independent:
+     * updating one will not affect the other.
+     *
+     * @return a new instance with the same state as this instance.
+     */
+    public synchronized NeuronSquareMesh2D copy() {
+        // Clone every row: "identifiers.clone()" alone would be shallow
+        // and leave both meshes sharing the same row arrays.
+        final long[][] idGrid = new long[numberOfRows][];
+        for (int r = 0; r < numberOfRows; r++) {
+            idGrid[r] = identifiers[r].clone();
+        }
+
+        return new NeuronSquareMesh2D(wrapRows,
+                                      wrapColumns,
+                                      neighbourhood,
+                                      network.copy(),
+                                      idGrid);
+    }
+
     /** {@inheritDoc} */
     @Override
     public Iterator<Neuron> iterator() {

http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java
index 4f81eb6..7163e0d 100644
--- a/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java
+++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/NetworkTest.java
@@ -132,6 +132,47 @@ public class NetworkTest {
         Assert.assertFalse(isUnspecifiedOrder);
     }
 
+    /*
+     * Test assumes that the network is
+     *
+     *  0-----1
+     *  |     |
+     *  |     |
+     *  2-----3
+     */
+    @Test
+    public void testCopy() {
+        final FeatureInitializer[] initArray = { init };
+        final Network net = new NeuronSquareMesh2D(2, false,
+                                                   2, false,
+                                                   SquareNeighbourhood.VON_NEUMANN,
+                                                   initArray).getNetwork();
+
+        final Network copy = net.copy();
+
+        final Neuron netNeuron0 = net.getNeuron(0);
+        final Neuron copyNeuron0 = copy.getNeuron(0);
+        final Neuron netNeuron1 = net.getNeuron(1);
+        final Neuron copyNeuron1 = copy.getNeuron(1);
+        // Neurons must have been duplicated, not shared.
+        Assert.assertNotSame(netNeuron0, copyNeuron0);
+
+        // Check that both networks have the same connections.
+        Collection<Neuron> netNeighbours = net.getNeighbours(netNeuron0);
+        Collection<Neuron> copyNeighbours = copy.getNeighbours(copyNeuron0);
+        Assert.assertTrue(netNeighbours.contains(netNeuron1));
+        Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
+
+        // Delete neuron 1 from original.
+        net.deleteNeuron(netNeuron1);
+
+        // Check that the networks now differ.
+        netNeighbours = net.getNeighbours(netNeuron0);
+        copyNeighbours = copy.getNeighbours(copyNeuron0);
+        Assert.assertFalse(netNeighbours.contains(netNeuron1));
+        Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
+    }
+
     @Test
     public void testSerialize()
         throws IOException,

http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c4e1d71/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java
index 92ec4bf..64f7013 100644
--- a/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java
+++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/NeuronTest.java
@@ -88,6 +88,32 @@ public class NeuronTest {
     }
 
     @Test
+    public void testCopy() {
+        final Neuron n = new Neuron(1, new double[] { 9.87 });
+
+        // Update original.
+        double[] update = new double[] { n.getFeatures()[0] + 2.34 };
+        n.compareAndSetFeatures(n.getFeatures(), update);
+
+        // Create a copy, and check that it has the same state as the original.
+        final Neuron copy = n.copy();
+        Assert.assertEquals(n.getFeatures()[0], copy.getFeatures()[0], 0d);
+        Assert.assertEquals(n.getNumberOfAttemptedUpdates(),
+                            copy.getNumberOfAttemptedUpdates());
+        Assert.assertEquals(n.getNumberOfSuccessfulUpdates(),
+                            copy.getNumberOfSuccessfulUpdates());
+
+        // Update original.
+        update = new double[] { 1.23 * n.getFeatures()[0] };
+        n.compareAndSetFeatures(n.getFeatures(), update);
+
+        // Check that original and copy differ.
+        Assert.assertNotEquals(n.getFeatures()[0], copy.getFeatures()[0], 0d);
+        Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(),
+                               copy.getNumberOfSuccessfulUpdates());
+    }
+
+    @Test
     public void testSerialize()
         throws IOException,
                ClassNotFoundException {

Reply via email to