MATH-1258: check for equal array lengths in distance functions

Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/7934bfea
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/7934bfea
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/7934bfea
Branch: refs/heads/MATH_3_X Commit: 7934bfea106206d2840ba062eef105001601588a Parents: 9cb16d5 Author: Otmar Ertl <otmar.e...@gmail.com> Authored: Thu Aug 20 17:46:54 2015 +0200 Committer: Otmar Ertl <otmar.e...@gmail.com> Committed: Thu Aug 20 17:55:42 2015 +0200 ---------------------------------------------------------------------- .../math3/ml/distance/CanberraDistance.java | 6 +- .../math3/ml/distance/ChebyshevDistance.java | 4 +- .../math3/ml/distance/DistanceMeasure.java | 5 +- .../math3/ml/distance/EarthMoversDistance.java | 6 +- .../math3/ml/distance/EuclideanDistance.java | 4 +- .../math3/ml/distance/ManhattanDistance.java | 4 +- .../apache/commons/math3/util/MathArrays.java | 80 ++++++++++++++++---- 7 files changed, 89 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java index d997352..d467c3b 100644 --- a/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/CanberraDistance.java @@ -16,7 +16,9 @@ */ package org.apache.commons.math3.ml.distance; +import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; /** * Calculates the Canberra distance between two points. 
@@ -29,7 +31,9 @@ public class CanberraDistance implements DistanceMeasure { private static final long serialVersionUID = -6972277381587032228L; /** {@inheritDoc} */ - public double compute(double[] a, double[] b) { + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + MathArrays.checkEqualLength(a, b); double sum = 0; for (int i = 0; i < a.length; i++) { final double num = FastMath.abs(a[i] - b[i]); http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java index 9eecb15..05dccb5 100644 --- a/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/ChebyshevDistance.java @@ -16,6 +16,7 @@ */ package org.apache.commons.math3.ml.distance; +import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.util.MathArrays; /** @@ -29,7 +30,8 @@ public class ChebyshevDistance implements DistanceMeasure { private static final long serialVersionUID = -4694868171115238296L; /** {@inheritDoc} */ - public double compute(double[] a, double[] b) { + public double compute(double[] a, double[] b) + throws DimensionMismatchException { return MathArrays.distanceInf(a, b); } http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java index 98bfc89..ff9c27f 100644 --- 
a/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/DistanceMeasure.java @@ -18,6 +18,8 @@ package org.apache.commons.math3.ml.distance; import java.io.Serializable; +import org.apache.commons.math3.exception.DimensionMismatchException; + /** * Interface for distance measures of n-dimensional vectors. * @@ -33,6 +35,7 @@ public interface DistanceMeasure extends Serializable { * @param a the first vector * @param b the second vector * @return the distance between the two vectors + * @throws DimensionMismatchException if the array lengths differ. */ - double compute(double[] a, double[] b); + double compute(double[] a, double[] b) throws DimensionMismatchException; } http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java index 13f2654..2518624 100644 --- a/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/EarthMoversDistance.java @@ -16,7 +16,9 @@ */ package org.apache.commons.math3.ml.distance; +import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; /** * Calculates the Earh Mover's distance (also known as Wasserstein metric) between two distributions. 
@@ -31,7 +33,9 @@ public class EarthMoversDistance implements DistanceMeasure { private static final long serialVersionUID = -5406732779747414922L; /** {@inheritDoc} */ - public double compute(double[] a, double[] b) { + public double compute(double[] a, double[] b) + throws DimensionMismatchException { + MathArrays.checkEqualLength(a, b); double lastDistance = 0; double totalDistance = 0; for (int i = 0; i < a.length; i++) { http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java index 5d8029e..187badc 100644 --- a/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/EuclideanDistance.java @@ -16,6 +16,7 @@ */ package org.apache.commons.math3.ml.distance; +import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.util.MathArrays; /** @@ -29,7 +30,8 @@ public class EuclideanDistance implements DistanceMeasure { private static final long serialVersionUID = 1717556319784040040L; /** {@inheritDoc} */ - public double compute(double[] a, double[] b) { + public double compute(double[] a, double[] b) + throws DimensionMismatchException { return MathArrays.distance(a, b); } http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java index 9e898c1..2eebe1b 100644 --- 
a/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java +++ b/src/main/java/org/apache/commons/math3/ml/distance/ManhattanDistance.java @@ -16,6 +16,7 @@ */ package org.apache.commons.math3.ml.distance; +import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.util.MathArrays; /** @@ -29,7 +30,8 @@ public class ManhattanDistance implements DistanceMeasure { private static final long serialVersionUID = -9108154600539125566L; /** {@inheritDoc} */ - public double compute(double[] a, double[] b) { + public double compute(double[] a, double[] b) + throws DimensionMismatchException { return MathArrays.distance1(a, b); } http://git-wip-us.apache.org/repos/asf/commons-math/blob/7934bfea/src/main/java/org/apache/commons/math3/util/MathArrays.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/util/MathArrays.java b/src/main/java/org/apache/commons/math3/util/MathArrays.java index 5bf7890..46a8716 100644 --- a/src/main/java/org/apache/commons/math3/util/MathArrays.java +++ b/src/main/java/org/apache/commons/math3/util/MathArrays.java @@ -194,8 +194,11 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>1</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. */ - public static double distance1(double[] p1, double[] p2) { + public static double distance1(double[] p1, double[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); double sum = 0; for (int i = 0; i < p1.length; i++) { sum += FastMath.abs(p1[i] - p2[i]); @@ -209,13 +212,16 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>1</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. 
*/ - public static int distance1(int[] p1, int[] p2) { - int sum = 0; - for (int i = 0; i < p1.length; i++) { - sum += FastMath.abs(p1[i] - p2[i]); - } - return sum; + public static int distance1(int[] p1, int[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); + int sum = 0; + for (int i = 0; i < p1.length; i++) { + sum += FastMath.abs(p1[i] - p2[i]); + } + return sum; } /** @@ -224,8 +230,11 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>2</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. */ - public static double distance(double[] p1, double[] p2) { + public static double distance(double[] p1, double[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); double sum = 0; for (int i = 0; i < p1.length; i++) { final double dp = p1[i] - p2[i]; @@ -251,8 +260,11 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>2</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. */ - public static double distance(int[] p1, int[] p2) { + public static double distance(int[] p1, int[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); double sum = 0; for (int i = 0; i < p1.length; i++) { final double dp = p1[i] - p2[i]; @@ -267,8 +279,11 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>∞</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. 
*/ - public static double distanceInf(double[] p1, double[] p2) { + public static double distanceInf(double[] p1, double[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); double max = 0; for (int i = 0; i < p1.length; i++) { max = FastMath.max(max, FastMath.abs(p1[i] - p2[i])); @@ -282,8 +297,11 @@ public class MathArrays { * @param p1 the first point * @param p2 the second point * @return the L<sub>∞</sub> distance between the two points + * @throws DimensionMismatchException if the array lengths differ. */ - public static int distanceInf(int[] p1, int[] p2) { + public static int distanceInf(int[] p1, int[] p2) + throws DimensionMismatchException { + checkEqualLength(p1, p2); int max = 0; for (int i = 0; i < p1.length; i++) { max = FastMath.max(max, FastMath.abs(p1[i] - p2[i])); @@ -399,6 +417,42 @@ public class MathArrays { checkEqualLength(a, b, true); } + + /** + * Check that both arrays have the same length. + * + * @param a Array. + * @param b Array. + * @param abort Whether to throw an exception if the check fails. + * @return {@code true} if the arrays have the same length. + * @throws DimensionMismatchException if the lengths differ and + * {@code abort} is {@code true}. + */ + public static boolean checkEqualLength(int[] a, + int[] b, + boolean abort) { + if (a.length == b.length) { + return true; + } else { + if (abort) { + throw new DimensionMismatchException(a.length, b.length); + } + return false; + } + } + + /** + * Check that both arrays have the same length. + * + * @param a Array. + * @param b Array. + * @throws DimensionMismatchException if the lengths differ. + */ + public static void checkEqualLength(int[] a, + int[] b) { + checkEqualLength(a, b, true); + } + /** * Check that the given array is sorted. 
* @@ -884,10 +938,8 @@ public class MathArrays { */ public static double linearCombination(final double[] a, final double[] b) throws DimensionMismatchException { + checkEqualLength(a, b); final int len = a.length; - if (len != b.length) { - throw new DimensionMismatchException(len, b.length); - } if (len == 1) { // Revert to scalar multiplication.