Repository: commons-math Updated Branches: refs/heads/MATH_3_X f189a4c5a -> 69b5c8214
MATH-1264 Sort units according to distance from a given vector. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/69b5c821 Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/69b5c821 Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/69b5c821 Branch: refs/heads/MATH_3_X Commit: 69b5c82140d95c2d584b98c14d15f1066d55f187 Parents: f189a4c Author: Gilles <er...@apache.org> Authored: Tue Sep 1 14:25:26 2015 +0200 Committer: Gilles <er...@apache.org> Committed: Tue Sep 1 14:25:26 2015 +0200 ---------------------------------------------------------------------- src/changes/changes.xml | 4 + .../commons/math3/ml/neuralnet/MapUtils.java | 78 ++++++++++++++++++++ .../math3/ml/neuralnet/MapUtilsTest.java | 18 +++++ 3 files changed, 100 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/69b5c821/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 29ce88f..88594b5 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,10 @@ If the output is not quite correct, check for invisible trailing spaces! </properties> <body> <release version="3.6" date="XXXX-XX-XX" description=""> + <action dev="erans" type="add" issue="MATH-1264"> + "MapUtils" (package "o.a.c.m.ml.neuralnet"): Method to sort units according to distance + from a given vector. + </action> <action dev="erans" type="add" issue="MATH-1263"> Accessor (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D"). 
</action> http://git-wip-us.apache.org/repos/asf/commons-math/blob/69b5c821/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java index 9e67982..e7cf598 100644 --- a/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java +++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/MapUtils.java @@ -19,6 +19,9 @@ package org.apache.commons.math3.ml.neuralnet; import java.util.HashMap; import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.ArrayList; import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.neuralnet.twod.NeuronSquareMesh2D; import org.apache.commons.math3.exception.NoDataException; @@ -104,6 +107,46 @@ public class MapUtils { } /** + * Creates an array of neurons sorted in increasing order of the distance + * to the given {@code features}. + * + * @param features Data. + * @param neurons List of neurons to scan. If it is empty, an empty array + * will be returned. + * @param distance Distance function. + * @return the neurons, sorted in increasing order of distance in data + * space. + * @throws org.apache.commons.math3.exception.DimensionMismatchException + * if the size of the input is not compatible with the size of the + * neurons' features.
+ * + * @see #findBest(double[],Iterable,DistanceMeasure) + * @see #findBestAndSecondBest(double[],Iterable,DistanceMeasure) + * + * @since 3.6 + */ + public static Neuron[] sort(double[] features, + Iterable<Neuron> neurons, + DistanceMeasure distance) { + final List<PairNeuronDouble> list = new ArrayList<PairNeuronDouble>(); + + for (final Neuron n : neurons) { + final double d = distance.compute(n.getFeatures(), features); + list.add(new PairNeuronDouble(n, d)); + } + + Collections.sort(list); + + final int len = list.size(); + final Neuron[] sorted = new Neuron[len]; + + for (int i = 0; i < len; i++) { + sorted[i] = list.get(i).getNeuron(); + } + return sorted; + } + + /** * Computes the <a href="http://en.wikipedia.org/wiki/U-Matrix"> * U-matrix</a> of a two-dimensional map. * @@ -244,4 +287,39 @@ public class MapUtils { return ((double) notAdjacentCount) / count; } + + /** + * Helper data structure holding a (Neuron, double) pair. + */ + private static class PairNeuronDouble implements Comparable<PairNeuronDouble> { + /** Key */ + private final Neuron neuron; + /** Value */ + private final double value; + + /** + * @param neuron Neuron. + * @param value Value. + */ + public PairNeuronDouble(Neuron neuron, + double value) { + this.neuron = neuron; + this.value = value; + } + + /** @return the neuron. */ + public Neuron getNeuron() { + return neuron; + } + + /** @return the value. 
*/ + public double getValue() { + return value; + } + + /** {@inheritDoc} */ + public int compareTo(PairNeuronDouble other) { + return Double.compare(this.value, other.value); + } + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/69b5c821/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java index 72bf09c..b6216c1 100644 --- a/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java +++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/MapUtilsTest.java @@ -88,4 +88,22 @@ public class MapUtilsTest { Assert.assertEquals(3, allBest.size()); } + + @Test + public void testSort() { + final Set<Neuron> list = new HashSet<Neuron>(); + + for (int i = 0; i < 4; i++) { + list.add(new Neuron(i, new double[] { i - 0.5 })); + } + + final Neuron[] sorted = MapUtils.sort(new double[] { 3.4 }, + list, + new EuclideanDistance()); + + final long[] expected = new long[] { 3, 2, 1, 0 }; + for (int i = 0; i < list.size(); i++) { + Assert.assertEquals(expected[i], sorted[i].getIdentifier()); + } + } }