Repository: commons-math Updated Branches: refs/heads/master 97d32b14e -> 97accb47d
[MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/97accb47 Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/97accb47 Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/97accb47 Branch: refs/heads/master Commit: 97accb47de63ee5063eda23641c6017e29ab81d7 Parents: 97d32b1 Author: Thomas Neidhart <thomas.neidh...@gmail.com> Authored: Tue Sep 30 21:16:07 2014 +0200 Committer: Thomas Neidhart <thomas.neidh...@gmail.com> Committed: Tue Sep 30 21:16:07 2014 +0200 ---------------------------------------------------------------------- src/changes/changes.xml | 4 +++ .../distribution/EnumeratedDistribution.java | 31 +++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/97accb47/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index c64fb58..fdb2bd4 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not 2. A few methods in the FastMath class are in fact slower than their counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). "> + <action dev="tn" type="fix" issue="MATH-1152" due-to="Andras Sereny"> + Improved performance of "EnumeratedDistribution#sample()" by caching + the cumulative probabilities and using a binary rather than a linear search. + </action> <action dev="tn" type="fix" issue="MATH-1148" due-to="Guillaume Marceau"> "MonotoneChain" did not take the tolerance factor into account when sorting the input points. 
In case of collinear points this could result http://git-wip-us.apache.org/repos/asf/commons-math/blob/97accb47/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java index 5117c2a..e95098c 100644 --- a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java @@ -19,6 +19,7 @@ package org.apache.commons.math3.distribution; import java.io.Serializable; import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.math3.exception.MathArithmeticException; @@ -64,6 +65,7 @@ public class EnumeratedDistribution<T> implements Serializable { * List of random variable values. */ private final List<T> singletons; + /** * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, * probability[i] is the probability that a random variable following this distribution takes @@ -72,6 +74,11 @@ public class EnumeratedDistribution<T> implements Serializable { private final double[] probabilities; /** + * Cumulative probabilities, cached to speed up sampling. + */ + private final double[] cumulativeProbabilities; + + /** * Create an enumerated distribution using the given probability mass function * enumeration. 
* @@ -123,6 +130,13 @@ public class EnumeratedDistribution<T> implements Serializable { } probabilities = MathArrays.normalizeArray(probs, 1.0); + + cumulativeProbabilities = new double[probabilities.length]; + double sum = 0; + for (int i = 0; i < probabilities.length; i++) { + sum += probabilities[i]; + cumulativeProbabilities[i] = sum; + } } /** @@ -186,18 +200,21 @@ public class EnumeratedDistribution<T> implements Serializable { */ public T sample() { final double randomValue = random.nextDouble(); - double sum = 0; - for (int i = 0; i < probabilities.length; i++) { - sum += probabilities[i]; - if (randomValue < sum) { - return singletons.get(i); + int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); + if (index < 0) { + index = -index-1; + } + + if (index >= 0 && index < probabilities.length) { + if (randomValue < cumulativeProbabilities[index]) { + return singletons.get(index); } } /* This should never happen, but it ensures we will return a correct - * object in case the loop above has some floating point inequality - * problem on the final iteration. */ + * object in case there is some floating point inequality problem + * wrt the cumulative probabilities. */ return singletons.get(singletons.size() - 1); }