MATH-1370 Cache some data for improved performance.
Thanks to Ryan Gaffney. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/0cc22651 Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/0cc22651 Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/0cc22651 Branch: refs/heads/feature-MATH-1370 Commit: 0cc22651bbd5db9fefac5f3114bb881aadca0d0c Parents: 69ed91c Author: Gilles <gil...@harfang.homelinux.org> Authored: Tue May 31 11:01:08 2016 +0200 Committer: Gilles <gil...@harfang.homelinux.org> Committed: Tue May 31 11:01:08 2016 +0200 ---------------------------------------------------------------------- .../distribution/EnumeratedDistribution.java | 43 ++++++++------ .../EnumeratedDistributionTest.java | 61 ++++++++++++++++++++ 2 files changed, 87 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/0cc22651/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java index 88ce037..3fa6279 100644 --- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java @@ -20,7 +20,9 @@ import java.io.Serializable; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.commons.math4.exception.MathArithmeticException; import org.apache.commons.math4.exception.NotANumberException; @@ -59,11 +61,16 @@ public class EnumeratedDistribution<T> implements Serializable { private final List<T> singletons; /** * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, - * probability[i] is the probability that a random variable following this distribution takes + * probabilities[i] is the probability that a random variable following this distribution takes * the value singletons[i]. */ private final double[] probabilities; /** + * Probabilities of aggregated distinct random variables, cached to speed up probability lookup. + */ + private final Map<T, Double> massPoints; + + /** * Cumulative probabilities, cached to speed up sampling. */ private final double[] cumulativeProbabilities; @@ -92,7 +99,7 @@ public class EnumeratedDistribution<T> implements Serializable { singletons.add(sample.getKey()); final double p = sample.getValue(); if (p < 0) { - throw new NotPositiveException(sample.getValue()); + throw new NotPositiveException(p); } if (Double.isInfinite(p)) { throw new NotFiniteNumberException(p); @@ -105,11 +112,23 @@ public class EnumeratedDistribution<T> implements Serializable { probabilities = MathArrays.normalizeArray(probs, 1.0); + massPoints = new HashMap<T, Double>(); cumulativeProbabilities = new double[probabilities.length]; double sum = 0; for (int i = 0; i < probabilities.length; i++) { - sum += probabilities[i]; + double probability = probabilities[i]; + + sum += probability; cumulativeProbabilities[i] = sum; + + T randomVariable = singletons.get(i); + final double existingProbability; + if (massPoints.containsKey(randomVariable)) { + existingProbability = massPoints.get(randomVariable); + } else { + existingProbability = 0.0; + } + massPoints.put(randomVariable, existingProbability + probability); } } @@ -120,22 +139,14 @@ public class EnumeratedDistribution<T> implements Serializable { * distribution.</p> * * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)}, - * or both are null, then {@code probability(x1) = probability(x2)}.</p> + * or both are null, then {@code probability(x1) == probability(x2)}.</p> * * @param x the point at which the PMF is evaluated * @return the value of the probability mass function at {@code x} */ double probability(final T x) { - double probability = 0; - - for (int i = 0; i < probabilities.length; i++) { - if ((x == null && singletons.get(i) == null) || - (x != null && x.equals(singletons.get(i)))) { - probability += probabilities[i]; - } - } - - return probability; + final Double p = massPoints.get(x); + return p == null ? 0 : p.doubleValue(); } /** @@ -195,9 +206,7 @@ public class EnumeratedDistribution<T> implements Serializable { index = -index - 1; } - if (index >= 0 && - index < probabilities.length && - randomValue < cumulativeProbabilities[index]) { + if (randomValue < cumulativeProbabilities[index]) { return singletons.get(index); } http://git-wip-us.apache.org/repos/asf/commons-math/blob/0cc22651/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java new file mode 100644 index 0000000..c1756b0 --- /dev/null +++ b/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math4.distribution; + +import static org.junit.Assert.assertEquals; + +import org.apache.commons.math4.util.Pair; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + + +/** + * Test class for {@link EnumeratedDistribution}. + */ +public class EnumeratedDistributionTest { + @Test + public void testProbability() { + final String[] values = {"car", "bike", null}; + final List<Pair<String, Double>> pmf = Arrays.asList( + new Pair<String, Double>(values[0], 0.1), + new Pair<String, Double>(values[1], 0.3), + new Pair<String, Double>(values[1], 0.2), + new Pair<String, Double>(values[2], 0.2), + new Pair<String, Double>(values[2], 0.2) + ); + final EnumeratedDistribution<String> distribution = new EnumeratedDistribution<String>(pmf); + assertEquals(0.1, distribution.probability(values[0]), 0); + assertEquals(0.5, distribution.probability(values[1]), 0); + assertEquals(0.4, distribution.probability(values[2]), 0); + } + + @Test + public void testGetPmf() { + final String s = "bike"; + final List<Pair<String, Double>> pmf = Arrays.asList( + new Pair<String, Double>(s, 0.1), + new Pair<String, Double>(s, 0.3), + new Pair<String, Double>(null, 0.2), + new Pair<String, Double>(s, 0.2), + new Pair<String, Double>(null, 0.2) + ); + final EnumeratedDistribution<String> distribution = new EnumeratedDistribution<String>(pmf); + assertEquals(pmf, distribution.getPmf()); + } +}