MATH-1267 Helper for finding the grid coordinates of a "Neuron" in a "NeuronSquareMesh2D".
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f348d34f Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f348d34f Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f348d34f Branch: refs/heads/master Commit: f348d34fb1ef3c968e2981283618aa85db203901 Parents: cd55cbb Author: Gilles <er...@apache.org> Authored: Mon Sep 14 02:06:43 2015 +0200 Committer: Gilles <er...@apache.org> Committed: Mon Sep 14 02:06:43 2015 +0200 ---------------------------------------------------------------------- .../ml/neuralnet/twod/util/LocationFinder.java | 104 +++++++++++++++++++ .../neuralnet/twod/util/LocationFinderTest.java | 70 +++++++++++++ 2 files changed, 174 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/f348d34f/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java new file mode 100644 index 0000000..7450c94 --- /dev/null +++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinder.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.neuralnet.twod.util;
+
+import java.util.Map;
+import java.util.HashMap;
+import org.apache.commons.math4.ml.neuralnet.Neuron;
+import org.apache.commons.math4.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.apache.commons.math4.exception.MathIllegalStateException;
+
+/**
+ * Records where each neuron of a {@link NeuronSquareMesh2D} sits in the
+ * grid, so that its (row, column) coordinates can later be looked up from
+ * the neuron alone.
+ */
+public class LocationFinder {
+    /** Grid coordinates of each neuron, keyed by the neuron's identifier. */
+    private final Map<Long, Location> locations = new HashMap<Long, Location>();
+
+    /**
+     * Pair of grid coordinates: a row index and a column index, both
+     * fixed at construction time.
+     */
+    public static class Location {
+        /** Index of the row. */
+        private final int row;
+        /** Index of the column. */
+        private final int column;
+
+        /**
+         * Creates a (row, column) pair.
+         *
+         * @param row Row index.
+         * @param column Column index.
+         */
+        public Location(int row, int column) {
+            this.column = column;
+            this.row = row;
+        }
+
+        /**
+         * Gets the row index.
+         *
+         * @return the row index.
+         */
+        public int getRow() {
+            return row;
+        }
+
+        /**
+         * Gets the column index.
+         *
+         * @return the column index.
+         */
+        public int getColumn() {
+            return column;
+        }
+    }
+
+    /**
+     * Scans every cell of {@code map} and stores the coordinates of the
+     * neuron found there, keyed by that neuron's identifier.
+     *
+     * @param map Mesh whose neurons will be located by this instance.
+     *
+     * @throws MathIllegalStateException if two cells of the mesh hold neurons
+     * with the same identifier. This indicates an inconsistent state due to
+     * a bug in the construction code of the underlying
+     * {@link org.apache.commons.math4.ml.neuralnet.Network network}.
+     */
+    public LocationFinder(NeuronSquareMesh2D map) {
+        final int numRows = map.getNumberOfRows();
+        final int numColumns = map.getNumberOfColumns();
+
+        for (int row = 0; row < numRows; row++) {
+            for (int column = 0; column < numColumns; column++) {
+                final Long key = map.getNeuron(row, column).getIdentifier();
+                final Location previous = locations.put(key, new Location(row, column));
+                if (previous != null) {
+                    // The identifier was already recorded for another cell:
+                    // identifiers in the mesh are not unique.
+                    throw new MathIllegalStateException();
+                }
+            }
+        }
+    }
+
+    /**
+     * Looks up the grid coordinates of a neuron.
+     *
+     * @param n Neuron.
+     * @return the (row, column) coordinates of {@code n}, or {@code null}
+     * if no neuron with the identifier of {@code n} belongs to the
+     * {@link #LocationFinder(NeuronSquareMesh2D) map used to build this instance}.
+     */
+    public Location getLocation(Neuron n) {
+        final Long key = n.getIdentifier();
+        return locations.get(key);
+    }
+}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f348d34f/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
new file mode 100644
index 0000000..6aaad65
--- /dev/null
+++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/util/LocationFinderTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.neuralnet.twod.util;
+
+import org.apache.commons.math4.ml.neuralnet.Neuron;
+import org.apache.commons.math4.ml.neuralnet.Network;
+import org.apache.commons.math4.ml.neuralnet.FeatureInitializer;
+import org.apache.commons.math4.ml.neuralnet.FeatureInitializerFactory;
+import org.apache.commons.math4.ml.neuralnet.SquareNeighbourhood;
+import org.apache.commons.math4.ml.neuralnet.twod.NeuronSquareMesh2D;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for {@link LocationFinder}.
+ */
+public class LocationFinderTest {
+    /** Initializer for the (single) feature of every neuron. */
+    private final FeatureInitializer init = FeatureInitializerFactory.uniform(0, 2);
+
+    /*
+     * Test assumes that the network is
+     *
+     *  0-----1
+     *  |     |
+     *  |     |
+     *  2-----3
+     */
+    @Test
+    public void test2x2Network() {
+        final NeuronSquareMesh2D map = createMap(2, 2);
+        final LocationFinder finder = new LocationFinder(map);
+        final Network net = map.getNetwork();
+
+        assertLocation(0, 0, finder.getLocation(net.getNeuron(0)));
+        assertLocation(0, 1, finder.getLocation(net.getNeuron(1)));
+        assertLocation(1, 0, finder.getLocation(net.getNeuron(2)));
+        assertLocation(1, 1, finder.getLocation(net.getNeuron(3)));
+    }
+
+    /*
+     * A neuron whose identifier does not occur in the map used to build
+     * the finder must not be located: "getLocation" returns null.
+     */
+    @Test
+    public void testNeuronNotInMap() {
+        final LocationFinder finder = new LocationFinder(createMap(2, 2));
+        // Assumes identifiers are assigned consecutively from 0 (as the
+        // 2x2 test above does): the 3x3 mesh then holds identifiers 0 to 8,
+        // whereas the 2x2 mesh only holds 0 to 3.
+        final Network other = createMap(3, 3).getNetwork();
+        final Neuron stranger = other.getNeuron(8);
+
+        Assert.assertNull(finder.getLocation(stranger));
+    }
+
+    /**
+     * Creates a non-wrapping mesh with a von Neumann neighbourhood.
+     *
+     * @param numRows Number of rows.
+     * @param numCols Number of columns.
+     * @return a new mesh whose neurons have one feature.
+     */
+    private NeuronSquareMesh2D createMap(int numRows,
+                                         int numCols) {
+        final FeatureInitializer[] initArray = { init };
+        return new NeuronSquareMesh2D(numRows, false,
+                                      numCols, false,
+                                      SquareNeighbourhood.VON_NEUMANN,
+                                      initArray);
+    }
+
+    /**
+     * Checks that a location was found and has the expected coordinates.
+     *
+     * @param expectedRow Expected row index.
+     * @param expectedColumn Expected column index.
+     * @param actual Location returned by the finder.
+     */
+    private static void assertLocation(int expectedRow,
+                                       int expectedColumn,
+                                       LocationFinder.Location actual) {
+        // Fail with an assertion error rather than a NullPointerException.
+        Assert.assertNotNull(actual);
+        Assert.assertEquals(expectedRow, actual.getRow());
+        Assert.assertEquals(expectedColumn, actual.getColumn());
+    }
+}