MATH-1267

Helper for finding the grid coordinates of a "Neuron" in a "NeuronSquareMesh2D".


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f348d34f
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f348d34f
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f348d34f

Branch: refs/heads/master
Commit: f348d34fb1ef3c968e2981283618aa85db203901
Parents: cd55cbb
Author: Gilles <er...@apache.org>
Authored: Mon Sep 14 02:06:43 2015 +0200
Committer: Gilles <er...@apache.org>
Committed: Mon Sep 14 02:06:43 2015 +0200

----------------------------------------------------------------------
 .../ml/neuralnet/twod/util/LocationFinder.java  | 104 +++++++++++++++++++
 .../neuralnet/twod/util/LocationFinderTest.java |  70 +++++++++++++
 2 files changed, 174 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/f348d34f/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java
new file mode 100644
index 0000000..7450c94
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.neuralnet.twod.util;
+
+import java.util.Map;
+import java.util.HashMap;
+import org.apache.commons.math4.ml.neuralnet.Neuron;
+import org.apache.commons.math4.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math4.exception.MathIllegalStateException;
+
+/**
+ * Helper class to find the grid coordinates of a neuron.
+ */
+public class LocationFinder {
+    /** Identifier to location mapping. */
+    private final Map<Long, Location> locations = new HashMap<Long, 
Location>();
+
+    /**
+     * Container holding a (row, column) pair.
+     */
+    public static class Location {
+        /** Row index. */
+        private final int row;
+        /** Column index. */
+        private final int column;
+
+        /**
+         * @param row Row index.
+         * @param column Column index.
+         */
+        public Location(int row,
+                        int column) {
+            this.row = row;
+            this.column = column;
+        }
+
+        /**
+         * @return the row index.
+         */
+        public int getRow() {
+            return row;
+        }
+
+        /**
+         * @return the column index.
+         */
+        public int getColumn() {
+            return column;
+        }
+    }
+
+    /**
+     * Builds a finder to retrieve the locations of neurons that
+     * belong to the given {@code map}.
+     *
+     * @param map Map.
+     *
+     * @throws MathIllegalStateException if the network contains non-unique
+     * identifiers.  This indicates an inconsistent state due to a bug in
+     * the construction code of the underlying
+     * {@link org.apache.commons.math4.ml.neuralnet.Network network}.
+     */
+    public LocationFinder(NeuronSquareMesh2D map) {
+        final int nR = map.getNumberOfRows();
+        final int nC = map.getNumberOfColumns();
+
+        for (int r = 0; r < nR; r++) {
+            for (int c = 0; c < nC; c++) {
+                final Long id = map.getNeuron(r, c).getIdentifier();
+                if (locations.get(id) != null) {
+                    throw new MathIllegalStateException();
+                }
+                locations.put(id, new Location(r, c));
+            }
+        }
+    }
+
+    /**
+     * Retrieves a neuron's grid coordinates.
+     *
+     * @param n Neuron.
+     * @return the (row, column) coordinates of {@code n}, or {@code null}
+     * if no such neuron belongs to the {@link 
#LocationFinder(NeuronSquareMesh2D)
+     * map used to build this instance}.
+     */
+    public Location getLocation(Neuron n) {
+        return locations.get(n.getIdentifier());
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-math/blob/f348d34f/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
new file mode 100644
index 0000000..6aaad65
--- /dev/null
+++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.neuralnet.twod.util;
+
+import org.apache.commons.math4.ml.neuralnet.Neuron;
+import org.apache.commons.math4.ml.neuralnet.Network;
+import org.apache.commons.math4.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math4.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math4.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math4.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for {@link LocationFinder}.
+ */
+public class LocationFinderTest {
+    final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
+
+    /*
+     * Test assumes that the network is
+     *
+     *  0-----1
+     *  |     |
+     *  |     |
+     *  2-----3
+     */
+    @Test
+    public void test2x2Network() {
+        final FeatureInitializer[] initArray = { init };
+        final NeuronSquareMesh2D map = new NeuronSquareMesh2D(2, false,
+                                                              2, false,
+                                                              
SquareNeighbourhood.VON_NEUMANN,
+                                                              initArray);
+        final LocationFinder finder = new LocationFinder(map);
+        final Network net = map.getNetwork();
+        LocationFinder.Location loc;
+
+        loc = finder.getLocation(net.getNeuron(0));
+        Assert.assertEquals(0, loc.getRow());
+        Assert.assertEquals(0, loc.getColumn());
+
+        loc = finder.getLocation(net.getNeuron(1));
+        Assert.assertEquals(0, loc.getRow());
+        Assert.assertEquals(1, loc.getColumn());
+
+        loc = finder.getLocation(net.getNeuron(2));
+        Assert.assertEquals(1, loc.getRow());
+        Assert.assertEquals(0, loc.getColumn());
+
+        loc = finder.getLocation(net.getNeuron(3));
+        Assert.assertEquals(1, loc.getRow());
+        Assert.assertEquals(1, loc.getColumn());
+    }
+}

Reply via email to