This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 211f351968f98c923e5634e59127fe945780db09 Author: Alex Herbert <aherb...@apache.org> AuthorDate: Fri May 14 20:40:25 2021 +0100 RNG-138: CompositeSamplers to sample from a weighted combination of samplers --- .../commons/rng/sampling/CompositeSamplers.java | 1093 ++++++++++++++++++++ .../rng/sampling/CompositeSamplersTest.java | 1004 ++++++++++++++++++ src/main/resources/pmd/pmd-ruleset.xml | 23 +- 3 files changed, 2118 insertions(+), 2 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CompositeSamplers.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CompositeSamplers.java new file mode 100644 index 0000000..a269338 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CompositeSamplers.java @@ -0,0 +1,1093 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.rng.sampling; + +import java.util.List; +import java.util.ArrayList; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.ContinuousSampler; +import org.apache.commons.rng.sampling.distribution.DiscreteSampler; +import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler; +import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.LongSampler; +import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateLongSampler; + +/** + * Factory class to create a sampler that combines sampling from multiple samplers. + * + * <p>The composite sampler is constructed using a {@link Builder builder} for the type of samplers + * that will form the composite. Each sampler has a weight in the composition. + * Samples are returned using a 2 step algorithm: + * + * <ol> + * <li>Select a sampler based on its weighting + * <li>Return a sample from the selected sampler + * </ol> + * + * <p>The weights used for each sampler create a discrete probability distribution. This is + * sampled using a discrete probability distribution sampler. The builder provides methods + * to change the default implementation. 
+ * + * <p>The following example will create a sampler to uniformly sample the border of a triangle + * using the line segment lengths as weights: + * + * <pre> + * UniformRandomProvider rng = RandomSource.KISS.create(); + * double[] a = {1.23, 4.56}; + * double[] b = {6.78, 9.01}; + * double[] c = {3.45, 2.34}; + * ObjectSampler<double[]> sampler = + * CompositeSamplers.<double[]>newObjectSamplerBuilder() + * .add(LineSampler.of(a, b, rng), Math.hypot(a[0] - b[0], a[1] - b[1])) + * .add(LineSampler.of(b, c, rng), Math.hypot(b[0] - c[0], b[1] - c[1])) + * .add(LineSampler.of(c, a, rng), Math.hypot(c[0] - a[0], c[1] - a[1])) + * .build(rng); + * </pre> + * + * @since 1.4 + */ +public final class CompositeSamplers { + /** + * A factory for creating a sampler of a user-defined + * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> + * discrete probability distribution</a>. + */ + public interface DiscreteProbabilitySamplerFactory { + /** + * Creates the sampler. + * + * @param rng Source of randomness. + * @param probabilities Discrete probability distribution. + * @return the sampler + */ + DiscreteSampler create(UniformRandomProvider rng, + double[] probabilities); + } + + /** + * The DiscreteProbabilitySampler class defines implementations that sample from a user-defined + * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> + * discrete probability distribution</a>. + * + * <p>All implementations support the {@link SharedStateDiscreteSampler} interface. + */ + public enum DiscreteProbabilitySampler implements DiscreteProbabilitySamplerFactory { + /** Sample using a guide table (see {@link GuideTableDiscreteSampler}). 
*/ + GUIDE_TABLE { + @Override + public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + return GuideTableDiscreteSampler.of(rng, probabilities); + } + }, + /** Sample using the alias method (see {@link AliasMethodDiscreteSampler}). */ + ALIAS_METHOD { + @Override + public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + return AliasMethodDiscreteSampler.of(rng, probabilities); + } + }, + /** + * Sample using an optimised look-up table (see + * {@link org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler.Enumerated + * MarsagliaTsangWangDiscreteSampler.Enumerated}). + */ + LOOKUP_TABLE { + @Override + public SharedStateDiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities); + } + }; + } + + /** + * A class to implement the SharedStateDiscreteSampler interface for a discrete probability + * sampler given a factory and the probability distribution. Each new instance will recreate + * the distribution sampler using the factory. + */ + private static class SharedStateDiscreteProbabilitySampler implements SharedStateDiscreteSampler { + /** The sampler. */ + private final DiscreteSampler sampler; + /** The factory to create a new discrete sampler. */ + private final DiscreteProbabilitySamplerFactory factory; + /** The probabilities. */ + private final double[] probabilities; + + /** + * @param sampler Sampler of the discrete distribution. + * @param factory Factory to create a new discrete sampler. + * @param probabilities Probabilities of the discrete distribution. 
+ * @throws NullPointerException if the {@code sampler} is null + */ + SharedStateDiscreteProbabilitySampler(DiscreteSampler sampler, + DiscreteProbabilitySamplerFactory factory, + double[] probabilities) { + this.sampler = requireNonNull(sampler, "discrete sampler"); + // Assume the factory and probabilities are not null + this.factory = factory; + this.probabilities = probabilities; + } + + @Override + public int sample() { + // Delegate + return sampler.sample(); + } + + @Override + public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + // The factory may destructively modify the probabilities + return new SharedStateDiscreteProbabilitySampler(factory.create(rng, probabilities.clone()), + factory, probabilities); + } + } + + /** + * Builds a composite sampler. + * + * <p>A composite sampler is a combination of multiple samplers + * that all return the same sample type. Each sampler has a weighting in the composition. + * Samples are returned using a 2 step algorithm: + * + * <ol> + * <li>Select a sampler based on its weighting + * <li>Return a sample from the selected sampler + * </ol> + * + * <p>Step 1 requires a discrete sampler constructed from a discrete probability distribution. + * The probability for each sampler is the sampler weight divided by the sum of the weights: + * <pre> + * p(i) = w(i) / sum(w) + * </pre> + * + * <p>The builder provides a method to set the factory used to generate the discrete sampler. + * + * @param <S> Type of sampler + */ + public interface Builder<S> { + /** + * Return the number of samplers in the composite. The size must be non-zero before + * the {@link #build(UniformRandomProvider) build} method can create a sampler. + * + * @return the size + */ + int size(); + + /** + * Adds the sampler to the composite. A sampler with a zero weight is ignored. + * + * @param sampler Sampler. + * @param weight Weight for the composition. 
+ * @return a reference to this builder + * @throws IllegalArgumentException if {@code weight} is negative, infinite or {@code NaN}. + * @throws NullPointerException if {@code sampler} is null. + */ + Builder<S> add(S sampler, double weight); + + /** + * Sets the factory to use to generate the composite's discrete sampler from the sampler + * weights. + * + * <p>Note: If the factory is not explicitly set then a default will be used. + * + * @param factory Factory. + * @return a reference to this builder + * @throws NullPointerException if {@code factory} is null. + */ + Builder<S> setFactory(DiscreteProbabilitySamplerFactory factory); + + /** + * Builds the composite sampler. The {@code rng} is the source of randomness for selecting + * which sampler to use for each sample. + * + * <p>Note: When the sampler is created the builder is reset to an empty state. + * This prevents building multiple composite samplers with the same samplers and + * their identical underlying source of randomness. + * + * @param rng Generator of uniformly distributed random numbers. + * @return the sampler + * @throws IllegalStateException if no samplers have been added to create a composite. + * @see #size() + */ + S build(UniformRandomProvider rng); + } + + /** + * Builds a composite sampler. + * + * <p>A single builder can be used to create composites of different implementing classes + * which support different sampler interfaces. The type of sampler is generic. The individual + * samplers and their weights can be collected by the builder. The build method creates + * the discrete probability distribution from the weights. The final composite is created + * using a factory to create the class. + * + * @param <S> Type of sampler + */ + private static class SamplerBuilder<S> implements Builder<S> { + /** The specialisation of the sampler. */ + private final Specialisation specialisation; + /** The weighted samplers. 
*/ + private final List<WeightedSampler<S>> weightedSamplers; + /** The factory to create the discrete probability sampler from the weights. */ + private DiscreteProbabilitySamplerFactory factory; + /** The factory to create the composite sampler. */ + private final SamplerFactory<S> compositeFactory; + + /** + * The specialisation of composite sampler to build. + * This is used to determine if specialised interfaces from the sampler + * type must be supported, e.g. {@link SharedStateSampler}. + */ + enum Specialisation { + /** Instance of {@link SharedStateSampler}. */ + SHARED_STATE_SAMPLER, + /** No specialisation. */ + NONE; + } + + /** + * A factory for creating composite samplers. + * + * <p>This interface is used to build concrete implementations + * of different sampler interfaces. + * + * @param <S> Type of sampler + */ + interface SamplerFactory<S> { + /** + * Creates a new composite sampler. + * + * <p>If the composite specialisation is a + * {@link Specialisation#SHARED_STATE_SAMPLER shared state sampler} + * the discrete sampler passed to this method will be an instance of + * {@link SharedStateDiscreteSampler}. + * + * @param discreteSampler Discrete sampler. + * @param samplers Samplers. + * @return the sampler + */ + S createSampler(DiscreteSampler discreteSampler, + List<S> samplers); + } + + /** + * Contains a weighted sampler. + * + * @param <S> Sampler type + */ + private static class WeightedSampler<S> { + /** The weight. */ + private final double weight; + /** The sampler. */ + private final S sampler; + + /** + * @param weight the weight + * @param sampler the sampler + * @throws IllegalArgumentException if {@code weight} is negative, infinite or {@code NaN}. + * @throws NullPointerException if {@code sampler} is null. + */ + WeightedSampler(double weight, S sampler) { + this.weight = requirePositiveFinite(weight, "weight"); + this.sampler = requireNonNull(sampler, "sampler"); + } + + /** + * Gets the weight. 
+ * + * @return the weight + */ + double getWeight() { + return weight; + } + + /** + * Gets the sampler. + * + * @return the sampler + */ + S getSampler() { + return sampler; + } + + /** + * Checks that the specified value is positive finite and throws a customized + * {@link IllegalArgumentException} if it is not. + * + * @param value the value + * @param message detail message to be used in the event that a {@code + * IllegalArgumentException} is thrown + * @return {@code value} if positive finite + * @throws IllegalArgumentException if {@code weight} is negative, infinite or {@code NaN}. + */ + private static double requirePositiveFinite(double value, String message) { + // Must be positive finite + if (!(value >= 0 && value < Double.POSITIVE_INFINITY)) { + throw new IllegalArgumentException(message + " is not positive finite: " + value); + } + return value; + } + } + + /** + * @param specialisation Specialisation of the sampler. + * @param compositeFactory Factory to create the final composite sampler. + */ + SamplerBuilder(Specialisation specialisation, + SamplerFactory<S> compositeFactory) { + this.specialisation = specialisation; + this.compositeFactory = compositeFactory; + weightedSamplers = new ArrayList<WeightedSampler<S>>(); + factory = DiscreteProbabilitySampler.GUIDE_TABLE; + } + + @Override + public int size() { + return weightedSamplers.size(); + } + + @Override + public Builder<S> add(S sampler, double weight) { + // Ignore zero weights. The sampler and weight are validated by the WeightedSampler. + if (weight != 0) { + weightedSamplers.add(new WeightedSampler<S>(weight, sampler)); + } + return this; + } + + /** + * {@inheritDoc} + * + * <p>If the weights are uniform the factory is ignored and composite's discrete sampler + * is a {@link DiscreteUniformSampler uniform distribution sampler}. 
+ */ + @Override + public Builder<S> setFactory(DiscreteProbabilitySamplerFactory samplerFactory) { + this.factory = requireNonNull(samplerFactory, "factory"); + return this; + } + + /** + * {@inheritDoc} + * + * <p>If only one sampler has been added to the builder then the sampler is returned + * and the builder is reset. + * + * @throws IllegalStateException if no samplers have been added to create a composite. + */ + @Override + public S build(UniformRandomProvider rng) { + final List<WeightedSampler<S>> list = this.weightedSamplers; + final int n = list.size(); + if (n == 0) { + throw new IllegalStateException("No samplers to build the composite"); + } + if (n == 1) { + // No composite + final S sampler = list.get(0).sampler; + reset(); + return sampler; + } + + // Extract the weights and samplers. + final double[] weights = new double[n]; + final ArrayList<S> samplers = new ArrayList<S>(n); + for (int i = 0; i < n; i++) { + final WeightedSampler<S> weightedItem = list.get(i); + weights[i] = weightedItem.getWeight(); + samplers.add(weightedItem.getSampler()); + } + + reset(); + + final DiscreteSampler discreteSampler = createDiscreteSampler(rng, weights); + + return compositeFactory.createSampler(discreteSampler, samplers); + } + + /** + * Reset the builder. + */ + private void reset() { + weightedSamplers.clear(); + } + + /** + * Creates the discrete sampler of the enumerated probability distribution. + * + * <p>If the specialisation is a {@link Specialisation#SHARED_STATE_SAMPLER shared state sampler} + * the discrete sampler will be an instance of {@link SharedStateDiscreteSampler}. + * + * @param rng Generator of uniformly distributed random numbers. + * @param weights Weight associated to each item. + * @return the sampler + */ + private DiscreteSampler createDiscreteSampler(UniformRandomProvider rng, + double[] weights) { + // Edge case. Detect uniform weights. + final int n = weights.length; + if (uniform(weights)) { + // Uniformly sample from the size. 
+ // Note: Upper bound is inclusive. + return DiscreteUniformSampler.of(rng, 0, n - 1); + } + + // If possible normalise with a simple sum. + final double sum = sum(weights); + if (sum < Double.POSITIVE_INFINITY) { + // Do not use f = 1.0 / sum and multiplication by f. + // Use of divide handles a sub-normal sum. + for (int i = 0; i < n; i++) { + weights[i] /= sum; + } + } else { + // The sum is not finite. We know the weights are all positive finite. + // Compute the mean without overflow and divide by the mean and number of items. + final double mean = mean(weights); + for (int i = 0; i < n; i++) { + // Two step division avoids using the denominator (mean * n) + weights[i] = weights[i] / mean / n; + } + } + + // Create the sampler from the factory. + // Check if a SharedStateSampler is required. + // If a default factory then the result is a SharedStateDiscreteSampler, + // otherwise the sampler must be checked. + if (specialisation == Specialisation.SHARED_STATE_SAMPLER && + !(factory instanceof DiscreteProbabilitySampler)) { + // If the factory was user defined then clone the weights as they may be required + // to create a SharedStateDiscreteProbabilitySampler. + final DiscreteSampler sampler = factory.create(rng, weights.clone()); + return sampler instanceof SharedStateDiscreteSampler ? + sampler : + new SharedStateDiscreteProbabilitySampler(sampler, factory, weights); + } + + return factory.create(rng, weights); + } + + /** + * Check if all the values are the same. + * + * <p>Warning: This method assumes there are input values. If the length is zero an + * {@link ArrayIndexOutOfBoundsException} will be thrown. + * + * @param values the values + * @return true if all values are the same + */ + private static boolean uniform(double[] values) { + final double value = values[0]; + for (int i = 1; i < values.length; i++) { + if (value != values[i]) { + return false; + } + } + return true; + } + + /** + * Compute the sum of the values. 
+ * + * @param values the values + * @return the sum + */ + private static double sum(double[] values) { + double sum = 0; + for (final double value : values) { + sum += value; + } + return sum; + } + + /** + * Compute the mean of the values. Uses a rolling algorithm to avoid overflow of a simple sum. + * This method can be used to compute the mean of observed counts for normalisation to a + * probability: + * + * <pre> + * double[] values = ...; + * int n = values.length; + * double mean = mean(values); + * for (int i = 0; i < n; i++) { + * // Two step division avoids using the denominator (mean * n) + * values[i] = values[i] / mean / n; + * } + * </pre> + * + * <p>Warning: This method assumes there are input values. If the length is zero an + * {@link ArrayIndexOutOfBoundsException} will be thrown. + * + * @param values the values + * @return the mean + */ + private static double mean(double[] values) { + double mean = values[0]; + int i = 1; + while (i < values.length) { + // Deviation from the mean + final double dev = values[i] - mean; + i++; + mean += dev / i; + } + return mean; + } + } + + /** + * A composite sampler. + * + * <p>The source sampler for each sampler is chosen based on a user-defined continuous + * probability distribution. + * + * @param <S> Type of sampler + */ + private static class CompositeSampler<S> { + /** Continuous sampler to choose the individual sampler to sample. */ + protected final DiscreteSampler discreteSampler; + /** Collection of samplers to be sampled from. */ + protected final List<S> samplers; + + /** + * @param discreteSampler Continuous sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeSampler(DiscreteSampler discreteSampler, + List<S> samplers) { + this.discreteSampler = discreteSampler; + this.samplers = samplers; + } + + /** + * Gets the next sampler to use to create a sample. 
+ * + * @return the sampler + */ + S nextSampler() { + return samplers.get(discreteSampler.sample()); + } + } + + /** + * A factory for creating a composite ObjectSampler. + * + * @param <T> Type of sample + */ + private static class ObjectSamplerFactory<T> implements + SamplerBuilder.SamplerFactory<ObjectSampler<T>> { + /** The instance. */ + @SuppressWarnings("rawtypes") + private static final ObjectSamplerFactory INSTANCE = new ObjectSamplerFactory(); + + /** + * Get an instance. + * + * @param <T> Type of sample + * @return the factory + */ + static <T> ObjectSamplerFactory<T> instance() { + return INSTANCE; + } + + @Override + public ObjectSampler<T> createSampler(DiscreteSampler discreteSampler, + List<ObjectSampler<T>> samplers) { + return new CompositeObjectSampler<T>(discreteSampler, samplers); + } + + /** + * A composite object sampler. + * + * @param <T> Type of sample + */ + private static class CompositeObjectSampler<T> + extends CompositeSampler<ObjectSampler<T>> + implements ObjectSampler<T> { + /** + * @param discreteSampler Discrete sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeObjectSampler(DiscreteSampler discreteSampler, + List<ObjectSampler<T>> samplers) { + super(discreteSampler, samplers); + } + + @Override + public T sample() { + return nextSampler().sample(); + } + } + } + + /** + * A factory for creating a composite SharedStateObjectSampler. + * + * @param <T> Type of sample + */ + private static class SharedStateObjectSamplerFactory<T> implements + SamplerBuilder.SamplerFactory<SharedStateObjectSampler<T>> { + /** The instance. */ + @SuppressWarnings("rawtypes") + private static final SharedStateObjectSamplerFactory INSTANCE = new SharedStateObjectSamplerFactory(); + + /** + * Get an instance. 
+ * + * @param <T> Type of sample + * @return the factory + */ + static <T> SharedStateObjectSamplerFactory<T> instance() { + return INSTANCE; + } + + @Override + public SharedStateObjectSampler<T> createSampler(DiscreteSampler discreteSampler, + List<SharedStateObjectSampler<T>> samplers) { + // The input discrete sampler is assumed to be a SharedStateDiscreteSampler + return new CompositeSharedStateObjectSampler<T>( + (SharedStateDiscreteSampler) discreteSampler, samplers); + } + + /** + * A composite object sampler with shared state support. + * + * <p>The source sampler for each sampler is chosen based on a user-defined + * discrete probability distribution. + * + * @param <T> Type of sample + */ + private static class CompositeSharedStateObjectSampler<T> + extends CompositeSampler<SharedStateObjectSampler<T>> + implements SharedStateObjectSampler<T> { + /** + * @param discreteSampler Discrete sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeSharedStateObjectSampler(SharedStateDiscreteSampler discreteSampler, + List<SharedStateObjectSampler<T>> samplers) { + super(discreteSampler, samplers); + } + + @Override + public T sample() { + return nextSampler().sample(); + } + + @Override + public CompositeSharedStateObjectSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { + // Duplicate each sampler with the same source of randomness + return new CompositeSharedStateObjectSampler<T>( + ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng), + copy(samplers, rng)); + } + } + } + + /** + * A factory for creating a composite DiscreteSampler. + */ + private static class DiscreteSamplerFactory implements + SamplerBuilder.SamplerFactory<DiscreteSampler> { + /** The instance. 
*/ + static final DiscreteSamplerFactory INSTANCE = new DiscreteSamplerFactory(); + + @Override + public DiscreteSampler createSampler(DiscreteSampler discreteSampler, + List<DiscreteSampler> samplers) { + return new CompositeDiscreteSampler(discreteSampler, samplers); + } + + /** + * A composite discrete sampler. + */ + private static class CompositeDiscreteSampler + extends CompositeSampler<DiscreteSampler> + implements DiscreteSampler { + /** + * @param discreteSampler Discrete sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeDiscreteSampler(DiscreteSampler discreteSampler, + List<DiscreteSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public int sample() { + return nextSampler().sample(); + } + } + } + + /** + * A factory for creating a composite SharedStateDiscreteSampler. + */ + private static class SharedStateDiscreteSamplerFactory implements + SamplerBuilder.SamplerFactory<SharedStateDiscreteSampler> { + /** The instance. */ + static final SharedStateDiscreteSamplerFactory INSTANCE = new SharedStateDiscreteSamplerFactory(); + + @Override + public SharedStateDiscreteSampler createSampler(DiscreteSampler discreteSampler, + List<SharedStateDiscreteSampler> samplers) { + // The input discrete sampler is assumed to be a SharedStateDiscreteSampler + return new CompositeSharedStateDiscreteSampler( + (SharedStateDiscreteSampler) discreteSampler, samplers); + } + + /** + * A composite discrete sampler with shared state support. + */ + private static class CompositeSharedStateDiscreteSampler + extends CompositeSampler<SharedStateDiscreteSampler> + implements SharedStateDiscreteSampler { + /** + * @param discreteSampler Discrete sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. 
+ */ + CompositeSharedStateDiscreteSampler(SharedStateDiscreteSampler discreteSampler, + List<SharedStateDiscreteSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public int sample() { + return nextSampler().sample(); + } + + @Override + public CompositeSharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + // Duplicate each sampler with the same source of randomness + return new CompositeSharedStateDiscreteSampler( + ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng), + copy(samplers, rng)); + } + } + } + + /** + * A factory for creating a composite ContinuousSampler. + */ + private static class ContinuousSamplerFactory implements + SamplerBuilder.SamplerFactory<ContinuousSampler> { + /** The instance. */ + static final ContinuousSamplerFactory INSTANCE = new ContinuousSamplerFactory(); + + @Override + public ContinuousSampler createSampler(DiscreteSampler discreteSampler, + List<ContinuousSampler> samplers) { + return new CompositeContinuousSampler(discreteSampler, samplers); + } + + /** + * A composite continuous sampler. + */ + private static class CompositeContinuousSampler + extends CompositeSampler<ContinuousSampler> + implements ContinuousSampler { + /** + * @param discreteSampler Continuous sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeContinuousSampler(DiscreteSampler discreteSampler, + List<ContinuousSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public double sample() { + return nextSampler().sample(); + } + } + } + + /** + * A factory for creating a composite SharedStateContinuousSampler. + */ + private static class SharedStateContinuousSamplerFactory implements + SamplerBuilder.SamplerFactory<SharedStateContinuousSampler> { + /** The instance. 
*/ + static final SharedStateContinuousSamplerFactory INSTANCE = new SharedStateContinuousSamplerFactory(); + + @Override + public SharedStateContinuousSampler createSampler(DiscreteSampler discreteSampler, + List<SharedStateContinuousSampler> samplers) { + // The sampler is assumed to be a SharedStateContinuousSampler + return new CompositeSharedStateContinuousSampler( + (SharedStateDiscreteSampler) discreteSampler, samplers); + } + + /** + * A composite continuous sampler with shared state support. + */ + private static class CompositeSharedStateContinuousSampler + extends CompositeSampler<SharedStateContinuousSampler> + implements SharedStateContinuousSampler { + /** + * @param discreteSampler Continuous sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeSharedStateContinuousSampler(SharedStateDiscreteSampler discreteSampler, + List<SharedStateContinuousSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public double sample() { + return nextSampler().sample(); + } + + @Override + public CompositeSharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { + // Duplicate each sampler with the same source of randomness + return new CompositeSharedStateContinuousSampler( + ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng), + copy(samplers, rng)); + } + } + } + + /** + * A factory for creating a composite LongSampler. + */ + private static class LongSamplerFactory implements + SamplerBuilder.SamplerFactory<LongSampler> { + /** The instance. */ + static final LongSamplerFactory INSTANCE = new LongSamplerFactory(); + + @Override + public LongSampler createSampler(DiscreteSampler discreteSampler, + List<LongSampler> samplers) { + return new CompositeLongSampler(discreteSampler, samplers); + } + + /** + * A composite long sampler. 
+ */ + private static class CompositeLongSampler + extends CompositeSampler<LongSampler> + implements LongSampler { + /** + * @param discreteSampler Long sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. + */ + CompositeLongSampler(DiscreteSampler discreteSampler, + List<LongSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public long sample() { + return nextSampler().sample(); + } + } + } + + /** + * A factory for creating a composite SharedStateLongSampler. + */ + private static class SharedStateLongSamplerFactory implements + SamplerBuilder.SamplerFactory<SharedStateLongSampler> { + /** The instance. */ + static final SharedStateLongSamplerFactory INSTANCE = new SharedStateLongSamplerFactory(); + + @Override + public SharedStateLongSampler createSampler(DiscreteSampler discreteSampler, + List<SharedStateLongSampler> samplers) { + // The input discrete sampler is assumed to be a SharedStateLongSampler + return new CompositeSharedStateLongSampler( + (SharedStateDiscreteSampler) discreteSampler, samplers); + } + + /** + * A composite long sampler with shared state support. + */ + private static class CompositeSharedStateLongSampler + extends CompositeSampler<SharedStateLongSampler> + implements SharedStateLongSampler { + /** + * @param discreteSampler Long sampler to choose the individual sampler to sample. + * @param samplers Collection of samplers to be sampled from. 
+ */ + CompositeSharedStateLongSampler(SharedStateDiscreteSampler discreteSampler, + List<SharedStateLongSampler> samplers) { + super(discreteSampler, samplers); + } + + @Override + public long sample() { + return nextSampler().sample(); + } + + @Override + public CompositeSharedStateLongSampler withUniformRandomProvider(UniformRandomProvider rng) { + // Duplicate each sampler with the same source of randomness + return new CompositeSharedStateLongSampler( + ((SharedStateDiscreteSampler) this.discreteSampler).withUniformRandomProvider(rng), + copy(samplers, rng)); + } + } + } + + /** No public instances. */ + private CompositeSamplers() {} + + /** + * Create a new builder for a composite {@link ObjectSampler}. + * + * <p>Note: If the compiler cannot infer the type parameter of the sampler it can be specified + * within the diamond operator {@code <T>} preceding the call to + * {@code newObjectSamplerBuilder()}, for example: + * + * <pre>{@code + * CompositeSamplers.<double[]>newObjectSamplerBuilder() + * }</pre> + * + * @param <T> Type of the sample. + * @return the builder + */ + public static <T> Builder<ObjectSampler<T>> newObjectSamplerBuilder() { + final SamplerBuilder.SamplerFactory<ObjectSampler<T>> factory = ObjectSamplerFactory.instance(); + return new SamplerBuilder<ObjectSampler<T>>( + SamplerBuilder.Specialisation.NONE, factory); + } + + /** + * Create a new builder for a composite {@link SharedStateObjectSampler}. + * + * <p>Note: If the compiler cannot infer the type parameter of the sampler it can be specified + * within the diamond operator {@code <T>} preceding the call to + * {@code newSharedStateObjectSamplerBuilder()}, for example: + * + * <pre>{@code + * CompositeSamplers.<double[]>newSharedStateObjectSamplerBuilder() + * }</pre> + * + * @param <T> Type of the sample. 
+ * @return the builder + */ + public static <T> Builder<SharedStateObjectSampler<T>> newSharedStateObjectSamplerBuilder() { + final SamplerBuilder.SamplerFactory<SharedStateObjectSampler<T>> factory = + SharedStateObjectSamplerFactory.instance(); + return new SamplerBuilder<SharedStateObjectSampler<T>>( + SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, factory); + } + + /** + * Create a new builder for a composite {@link DiscreteSampler}. + * + * @return the builder + */ + public static Builder<DiscreteSampler> newDiscreteSamplerBuilder() { + return new SamplerBuilder<DiscreteSampler>( + SamplerBuilder.Specialisation.NONE, DiscreteSamplerFactory.INSTANCE); + } + + /** + * Create a new builder for a composite {@link SharedStateDiscreteSampler}. + * + * @return the builder + */ + public static Builder<SharedStateDiscreteSampler> newSharedStateDiscreteSamplerBuilder() { + return new SamplerBuilder<SharedStateDiscreteSampler>( + SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateDiscreteSamplerFactory.INSTANCE); + } + + /** + * Create a new builder for a composite {@link ContinuousSampler}. + * + * @return the builder + */ + public static Builder<ContinuousSampler> newContinuousSamplerBuilder() { + return new SamplerBuilder<ContinuousSampler>( + SamplerBuilder.Specialisation.NONE, ContinuousSamplerFactory.INSTANCE); + } + + /** + * Create a new builder for a composite {@link SharedStateContinuousSampler}. + * + * @return the builder + */ + public static Builder<SharedStateContinuousSampler> newSharedStateContinuousSamplerBuilder() { + return new SamplerBuilder<SharedStateContinuousSampler>( + SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateContinuousSamplerFactory.INSTANCE); + } + + /** + * Create a new builder for a composite {@link LongSampler}. 
+ * + * @return the builder + */ + public static Builder<LongSampler> newLongSamplerBuilder() { + return new SamplerBuilder<LongSampler>( + SamplerBuilder.Specialisation.NONE, LongSamplerFactory.INSTANCE); + } + + /** + * Create a new builder for a composite {@link SharedStateLongSampler}. + * + * @return the builder + */ + public static Builder<SharedStateLongSampler> newSharedStateLongSamplerBuilder() { + return new SamplerBuilder<SharedStateLongSampler>( + SamplerBuilder.Specialisation.SHARED_STATE_SAMPLER, SharedStateLongSamplerFactory.INSTANCE); + } + + /** + * Checks that the specified object reference is not {@code null} and throws a + * customized {@link NullPointerException} if it is. + * + * <P>Note: This method is to be replaced with + * {@code java.util.Objects.requireNonNull} when the source requires Java 8. + * + * @param obj the object reference to check for nullity + * @param message detail message to be used in the event that a {@code + * NullPointerException} is thrown + * @param <T> the type of the reference + * @return {@code obj} if not {@code null} + * @throws NullPointerException if {@code obj} is {@code null} + */ + private static <T> T requireNonNull(T obj, String message) { + if (obj == null) { + throw new NullPointerException(message); + } + return obj; + } + + /** + * Create a copy instance of each sampler in the list of samplers using the given + * uniform random provider as the source of randomness. + * + * @param <T> the type of sampler + * @param samplers Source to copy. + * @param rng Generator of uniformly distributed random numbers. 
+ * @return the copy + */ + private static <T extends SharedStateSampler<T>> List<T> copy(List<T> samplers, + UniformRandomProvider rng) { + final ArrayList<T> newSamplers = new ArrayList<T>(samplers.size()); + for (final T s : samplers) { + newSamplers.add(s.withUniformRandomProvider(rng)); + } + return newSamplers; + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CompositeSamplersTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CompositeSamplersTest.java new file mode 100644 index 0000000..2081590 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CompositeSamplersTest.java @@ -0,0 +1,1004 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.commons.rng.sampling; + +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.commons.math3.stat.inference.ChiSquareTest; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.CompositeSamplers.Builder; +import org.apache.commons.rng.sampling.CompositeSamplers.DiscreteProbabilitySampler; +import org.apache.commons.rng.sampling.CompositeSamplers.DiscreteProbabilitySamplerFactory; +import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.ContinuousSampler; +import org.apache.commons.rng.sampling.distribution.DiscreteSampler; +import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.LongSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateLongSampler; +import org.apache.commons.rng.simple.RandomSource; + +/** + * Test class for {@link CompositeSamplers}. + */ +public class CompositeSamplersTest { + /** + * Test the default implementations of the discrete probability sampler factory. 
+ */ + @Test + public void testDiscreteProbabilitySampler() { + final UniformRandomProvider rng = RandomSource.MWC_256.create(78979L); + final double[] probabilities = {0.1, 0.2, 0.3, 0.4}; + final double mean = 0.2 + 2 * 0.3 + 3 * 0.4; + final int n = 1000000; + for (final DiscreteProbabilitySampler item : DiscreteProbabilitySampler.values()) { + final DiscreteSampler sampler = item.create(rng, probabilities.clone()); + long sum = 0; + for (int i = 0; i < n; i++) { + sum += sampler.sample(); + } + Assert.assertEquals(item.name(), mean, (double) sum / n, 1e-3); + } + } + + /** + * Test an empty builder cannot build a sampler. + */ + @Test(expected = IllegalStateException.class) + public void testEmptyBuilderThrows() { + final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(0L); + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + Assert.assertEquals(0, builder.size()); + builder.build(rng); + } + + /** + * Test adding null sampler to a builder. + */ + @Test(expected = NullPointerException.class) + public void testNullSharedStateObjectSamplerThrows() { + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + builder.add(null, 1.0); + } + + /** + * Test invalid weights (zero, negative, NaN, infinte). 
+ */ + @Test + public void testInvalidWeights() { + final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(0L); + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + final RangeSampler sampler = new RangeSampler(45, 63, rng); + // Zero weight is ignored + Assert.assertEquals(0, builder.size()); + builder.add(sampler, 0.0); + Assert.assertEquals(0, builder.size()); + + final double[] bad = {-1, Double.NaN, Double.POSITIVE_INFINITY}; + for (final double weight : bad) { + try { + builder.add(sampler, weight); + Assert.fail("Did not detect invalid weight: " + weight); + } catch (final IllegalArgumentException ex) { + // Expected + } + } + } + + /** + * Test a single sampler added to the builder is returned without a composite. + */ + @Test + public void testSingleSharedStateObjectSampler() { + final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(0L); + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + final RangeSampler sampler = new RangeSampler(45, 63, rng); + builder.add(sampler, 1.0); + Assert.assertEquals(1, builder.size()); + final SharedStateObjectSampler<Integer> composite = builder.build(rng); + Assert.assertSame(sampler, composite); + } + + /** + * Test sampling is uniform across several ObjectSampler samplers. + */ + @Test + public void testObjectSamplerSamples() { + final Builder<ObjectSampler<Integer>> builder = CompositeSamplers.newObjectSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.PCG_XSH_RR_32_OS.create(345); + final int n = 15; + final int min = -134; + final int max = 2097; + addObjectSamplers(builder, n, min, max, rng); + assertObjectSamplerSamples(builder.build(rng), min, max); + } + + /** + * Test sampling is uniform across several SharedStateObjectSampler samplers. 
+ */ + @Test + public void testSharedStateObjectSamplerSamples() { + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.PCG_XSH_RS_32_OS.create(299); + final int n = 11; + final int min = 42; + final int max = 678; + addObjectSamplers(builder, n, min, max, rng); + // Exercise the shared state interface + final UniformRandomProvider rng1 = RandomSource.XO_SHI_RO_256_PLUS.create(0x9a8c6f5e); + assertObjectSamplerSamples(builder.build(rng).withUniformRandomProvider(rng1), min, max); + } + + /** + * Test sampling is uniform across several SharedStateObjectSampler samplers + * using a custom factory that implements SharedStateDiscreteSampler. + */ + @Test + public void testSharedStateObjectSamplerSamplesWithCustomSharedStateDiscreteSamplerFactory() { + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + final AtomicInteger factoryCount = new AtomicInteger(); + builder.setFactory(new DiscreteProbabilitySamplerFactory() { + @Override + public DiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + factoryCount.incrementAndGet(); + // Use an expanded table with a non-default alpha + return AliasMethodDiscreteSampler.of(rng, probabilities, 2); + } + }); + final UniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(0xa6b7c9); + final int n = 7; + final int min = -610; + final int max = 745; + addObjectSamplers(builder, n, min, max, rng); + + // Exercise the shared state interface + final UniformRandomProvider rng1 = RandomSource.XO_SHI_RO_256_PLUS.create(0x1f2e3d); + assertObjectSamplerSamples(builder.build(rng).withUniformRandomProvider(rng1), min, max); + + Assert.assertEquals("Factory should not be used to create the shared state sampler", 1, factoryCount.get()); + } + + /** + * Test sampling is uniform across several SharedStateObjectSampler samplers + * 
using a custom factory that implements DiscreteSampler (so must be wrapped). + */ + @Test + public void testSharedStateObjectSamplerSamplesWithCustomDiscreteSamplerFactory() { + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + final AtomicInteger factoryCount = new AtomicInteger(); + builder.setFactory(new DiscreteProbabilitySamplerFactory() { + @Override + public DiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + factoryCount.incrementAndGet(); + // Wrap so it is not a SharedStateSamplerInstance. + final DiscreteSampler sampler = GuideTableDiscreteSampler.of(rng, probabilities, 2); + // Destroy the probabilities to check that custom factories are not trusted. + Arrays.fill(probabilities, Double.NaN); + return new DiscreteSampler() { + @Override + public int sample() { + return sampler.sample(); + } + }; + } + }); + final UniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(0x263478628L); + final int n = 14; + final int min = 56; + final int max = 2033; + addObjectSamplers(builder, n, min, max, rng); + + // Exercise the shared state interface. + // This tests the custom factory is used twice. + final UniformRandomProvider rng1 = RandomSource.XO_SHI_RO_256_PLUS.create(0x8c7b6a); + assertObjectSamplerSamples(builder.build(rng).withUniformRandomProvider(rng1), min, max); + + Assert.assertEquals("Factory should be used to create the shared state sampler", 2, factoryCount.get()); + } + + /** + * Test sampling is uniform across several ObjectSampler samplers with a uniform + * weighting. This tests an edge case where there is no requirement for a + * sampler from a discrete probability distribution as the distribution is + * uniform. 
+ */ + @Test + public void testObjectSamplerSamplesWithUniformWeights() { + final Builder<ObjectSampler<Integer>> builder = CompositeSamplers.newObjectSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.JSF_64.create(678345); + final int max = 60; + final int interval = 10; + for (int min = 0; min < max; min += interval) { + builder.add(new RangeSampler(min, min + interval, rng), 1.0); + } + assertObjectSamplerSamples(builder.build(rng), 0, max); + } + + /** + * Test sampling is uniform across several ObjectSampler samplers with very + * large weights. This tests an edge case where the weights with sum to + * infinity. + */ + @Test + public void testObjectSamplerSamplesWithVeryLargeWeights() { + final Builder<ObjectSampler<Integer>> builder = CompositeSamplers.newObjectSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.SFC_64.create(267934293); + // Ratio 4:4:2:1 + // The weights will sum to infinity as they are more than 2^1024. + final double w4 = 0x1.0p1023; + final double w2 = 0x1.0p1022; + final double w1 = 0x1.0p1021; + Assert.assertEquals(Double.POSITIVE_INFINITY, w4 + w4 + w2 + w1, 0.0); + builder.add(new RangeSampler(0, 40, rng), w4); + builder.add(new RangeSampler(40, 80, rng), w4); + builder.add(new RangeSampler(80, 100, rng), w2); + builder.add(new RangeSampler(100, 110, rng), w1); + assertObjectSamplerSamples(builder.build(rng), 0, 110); + } + + /** + * Test sampling is uniform across several ObjectSampler samplers with very + * small weights. This tests an edge case where the weights divided by their sum + * are valid (due to accurate floating-point division) but cannot be multiplied + * by the reciprocal of the sum. 
+ */ + @Test + public void testObjectSamplerSamplesWithSubNormalWeights() { + final Builder<ObjectSampler<Integer>> builder = CompositeSamplers.newObjectSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.MSWS.create(6786); + // Ratio 4:4:2:1 + // The weights are very small sub-normal numbers + final double w4 = Double.MIN_VALUE * 4; + final double w2 = Double.MIN_VALUE * 2; + final double w1 = Double.MIN_VALUE; + final double sum = w4 + w4 + w2 + w1; + // Cannot do a divide by multiplying by the reciprocal + Assert.assertEquals(Double.POSITIVE_INFINITY, 1.0 / sum, 0.0); + // A divide works so the sampler should work + Assert.assertEquals(4.0 / 11, w4 / sum, 0.0); + Assert.assertEquals(2.0 / 11, w2 / sum, 0.0); + Assert.assertEquals(1.0 / 11, w1 / sum, 0.0); + builder.add(new RangeSampler(0, 40, rng), w4); + builder.add(new RangeSampler(40, 80, rng), w4); + builder.add(new RangeSampler(80, 100, rng), w2); + builder.add(new RangeSampler(100, 110, rng), w1); + assertObjectSamplerSamples(builder.build(rng), 0, 110); + } + + /** + * Add samplers to the builder that sample from contiguous ranges between the + * minimum and maximum. Note: {@code max - min >= n} + * + * @param builder the builder + * @param n the number of samplers (must be {@code >= 2}) + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + private static void addObjectSamplers(Builder<? super SharedStateObjectSampler<Integer>> builder, int n, int min, + int max, UniformRandomProvider rng) { + // Create the ranges using n-1 random ticks in the range (min, max), + // adding the limits and then sorting in ascending order. + // The samplers are then constructed: + // + // min-------A-----B----max + // Sampler 1 = [min, A) + // Sampler 2 = [A, B) + // Sampler 3 = [B, max) + + // Use a combination sampler to ensure the ticks are unique in the range. + // This will throw if the range is negative. 
+ final int range = max - min - 1; + int[] ticks = new CombinationSampler(rng, range, n - 1).sample(); + // Shift the ticks into the range + for (int i = 0; i < ticks.length; i++) { + ticks[i] += min + 1; + } + // Add the min and max + ticks = Arrays.copyOf(ticks, n + 1); + ticks[n - 1] = min; + ticks[n] = max; + Arrays.sort(ticks); + + // Sample within the ranges between the ticks + final int before = builder.size(); + for (int i = 1; i < ticks.length; i++) { + final RangeSampler sampler = new RangeSampler(ticks[i - 1], ticks[i], rng); + // Weight using the range + builder.add(sampler, sampler.range); + } + + Assert.assertEquals("Failed to add the correct number of samplers", n, builder.size() - before); + } + + /** + * Assert sampling is uniform between the minimum and maximum. + * + * @param sampler the sampler + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + */ + private static void assertObjectSamplerSamples(ObjectSampler<Integer> sampler, int min, int max) { + final int n = 100000; + final long[] observed = new long[max - min]; + for (int i = 0; i < n; i++) { + observed[sampler.sample() - min]++; + } + + final double[] expected = new double[observed.length]; + Arrays.fill(expected, (double) n / expected.length); + final double p = new ChiSquareTest().chiSquareTest(expected, observed); + Assert.assertFalse("p-value too small: " + p, p < 0.001); + } + + /** + * Test sampling is uniform across several DiscreteSampler samplers. + */ + @Test + public void testDiscreteSamplerSamples() { + final Builder<DiscreteSampler> builder = CompositeSamplers.newDiscreteSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.PCG_XSH_RR_32_OS.create(345); + final int n = 15; + final int min = -134; + final int max = 2097; + addDiscreteSamplers(builder, n, min, max, rng); + assertDiscreteSamplerSamples(builder.build(rng), min, max); + } + + /** + * Test sampling is uniform across several SharedStateDiscreteSampler samplers. 
+ */ + @Test + public void testSharedStateDiscreteSamplerSamples() { + final Builder<SharedStateDiscreteSampler> builder = CompositeSamplers.newSharedStateDiscreteSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.PCG_XSH_RS_32_OS.create(299); + final int n = 11; + final int min = 42; + final int max = 678; + addDiscreteSamplers(builder, n, min, max, rng); + assertDiscreteSamplerSamples(builder.build(rng), min, max); + } + + /** + * Add samplers to the builder that sample from contiguous ranges between the + * minimum and maximum. Note: {@code max - min >= n} + * + * @param builder the builder + * @param n the number of samplers (must be {@code >= 2}) + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + private static void addDiscreteSamplers(Builder<? super SharedStateDiscreteSampler> builder, int n, int min, + int max, UniformRandomProvider rng) { + // Create the ranges using n-1 random ticks in the range (min, max), + // adding the limits and then sorting in ascending order. + // The samplers are then constructed: + // + // min-------A-----B----max + // Sampler 1 = [min, A) + // Sampler 2 = [A, B) + // Sampler 3 = [B, max) + + // Use a combination sampler to ensure the ticks are unique in the range. + // This will throw if the range is negative. 
+ final int range = max - min - 1; + int[] ticks = new CombinationSampler(rng, range, n - 1).sample(); + // Shift the ticks into the range + for (int i = 0; i < ticks.length; i++) { + ticks[i] += min + 1; + } + // Add the min and max + ticks = Arrays.copyOf(ticks, n + 1); + ticks[n - 1] = min; + ticks[n] = max; + Arrays.sort(ticks); + + // Sample within the ranges between the ticks + final int before = builder.size(); + for (int i = 1; i < ticks.length; i++) { + final IntRangeSampler sampler = new IntRangeSampler(ticks[i - 1], ticks[i], rng); + // Weight using the range + builder.add(sampler, sampler.range); + } + + Assert.assertEquals("Failed to add the correct number of samplers", n, builder.size() - before); + } + + /** + * Assert sampling is uniform between the minimum and maximum. + * + * @param sampler the sampler + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + */ + private static void assertDiscreteSamplerSamples(DiscreteSampler sampler, int min, int max) { + final int n = 100000; + final long[] observed = new long[max - min]; + for (int i = 0; i < n; i++) { + observed[sampler.sample() - min]++; + } + + final double[] expected = new double[observed.length]; + Arrays.fill(expected, (double) n / expected.length); + final double p = new ChiSquareTest().chiSquareTest(expected, observed); + Assert.assertFalse("p-value too small: " + p, p < 0.001); + } + + /** + * Test sampling is uniform across several ContinuousSampler samplers. 
+ */ + @Test + public void testContinuousSamplerSamples() { + final Builder<ContinuousSampler> builder = CompositeSamplers.newContinuousSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.XO_SHI_RO_256_PP.create(9283756); + final int n = 15; + final double min = 67.2; + final double max = 2033.8; + addContinuousSamplers(builder, n, min, max, rng); + assertContinuousSamplerSamples(builder.build(rng), min, max); + } + + /** + * Test sampling is uniform across several SharedStateContinuousSampler samplers. + */ + @Test + public void testSharedStateContinuousSamplerSamples() { + final Builder<SharedStateContinuousSampler> builder = CompositeSamplers + .newSharedStateContinuousSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.PCG_RXS_M_XS_64_OS.create(0x567567345L); + final int n = 11; + final double min = -15.7; + final double max = 123.4; + addContinuousSamplers(builder, n, min, max, rng); + assertContinuousSamplerSamples(builder.build(rng), min, max); + } + + /** + * Add samplers to the builder that sample from contiguous ranges between the + * minimum and maximum. Note: {@code max - min >= n} + * + * @param builder the builder + * @param n the number of samplers (must be {@code >= 2}) + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + private static void addContinuousSamplers(Builder<? super SharedStateContinuousSampler> builder, int n, double min, + double max, UniformRandomProvider rng) { + // Create the ranges using n-1 random ticks in the range (min, max), + // adding the limits and then sorting in ascending order. + // The samplers are then constructed: + // + // min-------A-----B----max + // Sampler 1 = [min, A) + // Sampler 2 = [A, B) + // Sampler 3 = [B, max) + + // For double values it is extremely unlikely the same value will be generated. + // An assertion is performed to ensure we create the correct number of samplers. 
+ DoubleRangeSampler sampler = new DoubleRangeSampler(min, max, rng); + final double[] ticks = new double[n + 1]; + ticks[0] = min; + ticks[1] = max; + // Shift the ticks into the range + for (int i = 2; i < ticks.length; i++) { + ticks[i] = sampler.sample(); + } + Arrays.sort(ticks); + + // Sample within the ranges between the ticks + final int before = builder.size(); + for (int i = 1; i < ticks.length; i++) { + sampler = new DoubleRangeSampler(ticks[i - 1], ticks[i], rng); + // Weight using the range + builder.add(sampler, sampler.range()); + } + + Assert.assertEquals("Failed to add the correct number of samplers", n, builder.size() - before); + } + + /** + * Assert sampling is uniform between the minimum and maximum. + * + * @param sampler the sampler + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + */ + private static void assertContinuousSamplerSamples(ContinuousSampler sampler, double min, double max) { + final int n = 100000; + final int bins = 200; + final long[] observed = new long[bins]; + final double scale = bins / (max - min); + for (int i = 0; i < n; i++) { + // scale the sample into a bin within the range: + // bin = bins * (x - min) / (max - min) + observed[(int) (scale * (sampler.sample() - min))]++; + } + + final double[] expected = new double[observed.length]; + Arrays.fill(expected, (double) n / expected.length); + final double p = new ChiSquareTest().chiSquareTest(expected, observed); + Assert.assertFalse("p-value too small: " + p, p < 0.001); + } + + /** + * Test sampling is uniform across several LongSampler samplers. 
+ */ + @Test + public void testLongSamplerSamples() { + final Builder<LongSampler> builder = CompositeSamplers.newLongSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.KISS.create(67842321783L); + final int n = 15; + final long min = -134; + final long max = 1L << 54; + addLongSamplers(builder, n, min, max, rng); + assertLongSamplerSamples(builder.build(rng), min, max); + } + + /** + * Test sampling is uniform across several SharedStateLongSampler samplers. + */ + @Test + public void testSharedStateLongSamplerSamples() { + final Builder<SharedStateLongSampler> builder = CompositeSamplers.newSharedStateLongSamplerBuilder(); + final UniformRandomProvider rng = RandomSource.KISS.create(12369279382030L); + final int n = 11; + final long min = 42; + final long max = 1L << 53; + addLongSamplers(builder, n, min, max, rng); + assertLongSamplerSamples(builder.build(rng), min, max); + } + + /** + * Add samplers to the builder that sample from contiguous ranges between the + * minimum and maximum. Note: {@code max - min >= n} + * + * @param builder the builder + * @param n the number of samplers (must be {@code >= 2}) + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + private static void addLongSamplers(Builder<? super SharedStateLongSampler> builder, int n, long min, + long max, UniformRandomProvider rng) { + // Create the ranges using n-1 random ticks in the range (min, max), + // adding the limits and then sorting in ascending order. + // The samplers are then constructed: + // + // min-------A-----B----max + // Sampler 1 = [min, A) + // Sampler 2 = [A, B) + // Sampler 3 = [B, max) + + // For long values it is extremely unlikely the same value will be generated. + // An assertion is performed to ensure we create the correct number of samplers. 
+ LongRangeSampler sampler = new LongRangeSampler(min, max, rng); + final long[] ticks = new long[n + 1]; + ticks[0] = min; + ticks[1] = max; + // Shift the ticks into the range + for (int i = 2; i < ticks.length; i++) { + ticks[i] = sampler.sample(); + } + Arrays.sort(ticks); + + + // Sample within the ranges between the ticks + final int before = builder.size(); + for (int i = 1; i < ticks.length; i++) { + sampler = new LongRangeSampler(ticks[i - 1], ticks[i], rng); + // Weight using the range + builder.add(sampler, sampler.range); + } + + Assert.assertEquals("Failed to add the correct number of samplers", n, builder.size() - before); + } + + /** + * Assert sampling is uniform between the minimum and maximum. + * + * @param sampler the sampler + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + */ + private static void assertLongSamplerSamples(LongSampler sampler, long min, long max) { + final int n = 100000; + final int bins = 200; + final long[] observed = new long[bins]; + final long range = max - min; + for (int i = 0; i < n; i++) { + // scale the sample into a bin within the range: + observed[(int) (bins * (sampler.sample() - min) / range)]++; + } + + final double[] expected = new double[observed.length]; + Arrays.fill(expected, (double) n / expected.length); + final double p = new ChiSquareTest().chiSquareTest(expected, observed); + Assert.assertFalse("p-value too small: " + p, p < 0.001); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateObjectSampler. + */ + @Test + public void testSharedStateObjectSampler() { + testSharedStateObjectSampler(false); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateObjectSampler with a factory that does not support a shared state sampler. 
+ */ + @Test + public void testSharedStateObjectSamplerWithCustomFactory() { + testSharedStateObjectSampler(true); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateObjectSampler. + * + * @param customFactory Set to true to use a custom discrete sampler factory that does not + * support a shared stated sampler. + */ + private static void testSharedStateObjectSampler(boolean customFactory) { + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L); + + final Builder<SharedStateObjectSampler<Integer>> builder = CompositeSamplers + .newSharedStateObjectSamplerBuilder(); + + if (customFactory) { + addFactoryWithNoSharedStateSupport(builder); + } + + // Sample within the ranges between the ticks + final int[] ticks = {6, 13, 42, 99}; + for (int i = 1; i < ticks.length; i++) { + final RangeSampler sampler = new RangeSampler(ticks[i - 1], ticks[i], rng1); + // Weight using the range + builder.add(sampler, sampler.range); + } + + final SharedStateObjectSampler<Integer> sampler1 = builder.build(rng1); + final SharedStateObjectSampler<Integer> sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(new RandomAssert.Sampler<Integer>() { + @Override + public Integer sample() { + return sampler1.sample(); + } + }, new RandomAssert.Sampler<Integer>() { + @Override + public Integer sample() { + return sampler2.sample(); + } + }); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateDiscreteSampler. + */ + @Test + public void testSharedStateDiscreteSampler() { + testSharedStateDiscreteSampler(false); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateDiscreteSampler with a factory that does not support a shared state sampler. 
+ */ + @Test + public void testSharedStateDiscreteSamplerWithCustomFactory() { + testSharedStateDiscreteSampler(true); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateDiscreteSampler. + * + * @param customFactory Set to true to use a custom discrete sampler factory that does not + * support a shared stated sampler. + */ + private static void testSharedStateDiscreteSampler(boolean customFactory) { + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L); + + final Builder<SharedStateDiscreteSampler> builder = CompositeSamplers.newSharedStateDiscreteSamplerBuilder(); + + if (customFactory) { + addFactoryWithNoSharedStateSupport(builder); + } + + // Sample within the ranges between the ticks + final int[] ticks = {-3, 5, 14, 22}; + for (int i = 1; i < ticks.length; i++) { + final IntRangeSampler sampler = new IntRangeSampler(ticks[i - 1], ticks[i], rng1); + // Weight using the range + builder.add(sampler, sampler.range); + } + + final SharedStateDiscreteSampler sampler1 = builder.build(rng1); + final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateContinuousSampler. + */ + @Test + public void testSharedStateContinuousSampler() { + testSharedStateContinuousSampler(false); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateContinuousSampler with a factory that does not support a shared state sampler. + */ + @Test + public void testSharedStateContinuousSamplerWithCustomFactory() { + testSharedStateContinuousSampler(true); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateContinuousSampler. 
+ * + * @param customFactory Set to true to use a custom discrete sampler factory that does not + * support a shared state sampler. + */ + private static void testSharedStateContinuousSampler(boolean customFactory) { + // Two providers with the same seed: the sampler created by withUniformRandomProvider(rng2) + // must reproduce the sequence of the sampler built with rng1. + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L); + + final Builder<SharedStateContinuousSampler> builder = CompositeSamplers + .newSharedStateContinuousSamplerBuilder(); + + if (customFactory) { + addFactoryWithNoSharedStateSupport(builder); + } + + // Sample within the ranges between the ticks + final double[] ticks = {7.89, 13.99, 21.7, 35.6, 45.5}; + for (int i = 1; i < ticks.length; i++) { + final DoubleRangeSampler sampler = new DoubleRangeSampler(ticks[i - 1], ticks[i], rng1); + // Weight using the range (the absolute width of the interval between the ticks) + builder.add(sampler, sampler.range()); + } + + final SharedStateContinuousSampler sampler1 = builder.build(rng1); + final SharedStateContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Adds a custom {@link DiscreteProbabilitySamplerFactory} to the builder. The factory creates + * samplers that are plain {@link DiscreteSampler} instances and so do not support a shared state + * sampler. The factory also overwrites the probabilities array it is given with NaN. + * + * @param builder the builder + */ + private static void addFactoryWithNoSharedStateSupport(Builder<?> builder) { + builder.setFactory(new DiscreteProbabilitySamplerFactory() { + @Override + public DiscreteSampler create(UniformRandomProvider rng, double[] probabilities) { + // Create the sampler; it is wrapped below so the result is not a SharedStateSampler instance. + final DiscreteSampler sampler = GuideTableDiscreteSampler.of(rng, probabilities, 2); + // Destroy the probabilities (overwrite with NaN) to check that custom factories are not trusted. + Arrays.fill(probabilities, Double.NaN); + return new DiscreteSampler() { + @Override + public int sample() { + return sampler.sample(); + } + }; + } + }); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateLongSampler.
+ */ + @Test + public void testSharedStateLongSampler() { + testSharedStateLongSampler(false); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateLongSampler with a factory that does not support a shared state sampler. + */ + @Test + public void testSharedStateLongSamplerWithCustomFactory() { + testSharedStateLongSampler(true); + } + + /** + * Test the SharedStateSampler implementation for the composite + * SharedStateLongSampler. + * + * @param customFactory Set to true to use a custom discrete sampler factory that does not + * support a shared state sampler. + */ + private static void testSharedStateLongSampler(boolean customFactory) { + // Two providers with the same seed: the sampler created by withUniformRandomProvider(rng2) + // must reproduce the sequence of the sampler built with rng1. + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L); + + final Builder<SharedStateLongSampler> builder = CompositeSamplers.newSharedStateLongSamplerBuilder(); + + if (customFactory) { + addFactoryWithNoSharedStateSupport(builder); + } + + // Sample within the ranges between the ticks + final long[] ticks = {-32634628368L, 516234712, 1472839427384234L, 72364572187368423L}; + for (int i = 1; i < ticks.length; i++) { + final LongRangeSampler sampler = new LongRangeSampler(ticks[i - 1], ticks[i], rng1); + // Weight using the range (the width of the interval [ticks[i - 1], ticks[i])) + builder.add(sampler, sampler.range); + } + + final SharedStateLongSampler sampler1 = builder.build(rng1); + final SharedStateLongSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Sample an object {@code Integer} from a range.
+ */ + private static class RangeSampler implements SharedStateObjectSampler<Integer> { + private final int min; + // Width of the range (max - min); the tests also read this field as the component weight. + private final int range; + private final UniformRandomProvider rng; + + /** + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + RangeSampler(int min, int max, UniformRandomProvider rng) { + this.min = min; + this.range = max - min; + this.rng = rng; + } + + @Override + public Integer sample() { + // nextInt(range) is in [0, range) so the sample is in [min, max) + return min + rng.nextInt(range); + } + + @Override + public SharedStateObjectSampler<Integer> withUniformRandomProvider(UniformRandomProvider generator) { + // min + range recovers the original (exclusive) max + return new RangeSampler(min, min + range, generator); + } + } + + /** + * Sample a primitive {@code int} from a range. + */ + private static class IntRangeSampler implements SharedStateDiscreteSampler { + private final int min; + // Width of the range (max - min); the tests also read this field as the component weight. + private final int range; + private final UniformRandomProvider rng; + + /** + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + IntRangeSampler(int min, int max, UniformRandomProvider rng) { + this.min = min; + this.range = max - min; + this.rng = rng; + } + + @Override + public int sample() { + // nextInt(range) is in [0, range) so the sample is in [min, max) + return min + rng.nextInt(range); + } + + @Override + public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider generator) { + // min + range recovers the original (exclusive) max + return new IntRangeSampler(min, min + range, generator); + } + } + + /** + * Sample a primitive {@code double} from a range between a and b. + */ + private static class DoubleRangeSampler implements SharedStateContinuousSampler { + private final double a; + private final double b; + private final UniformRandomProvider rng; + + /** + * @param a bound a + * @param b bound b + * @param rng the source of randomness + */ + DoubleRangeSampler(double a, double b, UniformRandomProvider rng) { + this.a = a; + this.b = b; + this.rng = rng; + } + + /** + * Get the width of the range between a and b.
+ * + * <p>This is the absolute difference so it is non-negative whatever the order of the bounds. + * + * @return the range + */ + double range() { + return Math.abs(b - a); + } + + @Override + public double sample() { + // Linear interpolation between the bounds: a + u * (b - a) == u * b + (1 - u) * a + final double u = rng.nextDouble(); + return u * b + (1 - u) * a; + } + + @Override + public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider generator) { + return new DoubleRangeSampler(a, b, generator); + } + } + + /** + * Sample a primitive {@code long} from a range. + */ + private static class LongRangeSampler implements SharedStateLongSampler { + private final long min; + // Width of the range (max - min); the tests also read this field as the component weight. + private final long range; + private final UniformRandomProvider rng; + + /** + * @param min the minimum (inclusive) + * @param max the maximum (exclusive) + * @param rng the source of randomness + */ + LongRangeSampler(long min, long max, UniformRandomProvider rng) { + this.min = min; + this.range = max - min; + this.rng = rng; + } + + @Override + public long sample() { + // nextLong(range) is in [0, range) so the sample is in [min, max) + return min + rng.nextLong(range); + } + + @Override + public SharedStateLongSampler withUniformRandomProvider(UniformRandomProvider generator) { + // min + range recovers the original (exclusive) max + return new LongRangeSampler(min, min + range, generator); + } + } +} diff --git a/src/main/resources/pmd/pmd-ruleset.xml b/src/main/resources/pmd/pmd-ruleset.xml index a86c1ae..bf05492 100644 --- a/src/main/resources/pmd/pmd-ruleset.xml +++ b/src/main/resources/pmd/pmd-ruleset.xml @@ -75,7 +75,7 @@ <!-- Array is generated internally in this case.
--> <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[@SimpleName='PoissonSamplerCache' or @SimpleName='AliasMethodDiscreteSampler' - or @SimpleName='GuideTableDiscreteSampler']"/> + or @SimpleName='GuideTableDiscreteSampler' or @SimpleName='SharedStateDiscreteProbabilitySampler']"/> </properties> </rule> <rule ref="category/java/bestpractices.xml/SystemPrintln"> @@ -117,6 +117,12 @@ value="//ClassOrInterfaceDeclaration[@SimpleName='ProbabilityDensityApproximationCommand']"/> </properties> </rule> + <rule ref="category/java/codestyle.xml/LinguisticNaming"> + <properties> + <!-- Allow Builder set methods to return the Builder (not void) --> + <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[matches(@SimpleName, '^.*Builder$')]"/> + </properties> + </rule> <rule ref="category/java/design.xml/NPathComplexity"> <properties> @@ -152,7 +158,8 @@ <rule ref="category/java/design.xml/ExcessiveClassLength"> <properties> <!-- The length is due to multiple implementations as inner classes --> - <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[@SimpleName='MarsagliaTsangWangDiscreteSampler']"/> + <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[@SimpleName='MarsagliaTsangWangDiscreteSampler' + or @SimpleName='CompositeSamplers']"/> </properties> </rule> <rule ref="category/java/design.xml/LogicInversion"> @@ -171,12 +178,24 @@ or @SimpleName='UniformSamplingVisualCheckCommand']"/> </properties> </rule> + <rule ref="category/java/design.xml/AvoidThrowingNullPointerException"> + <properties> + <!-- Local implementation of Objects.requireNonNull --> + <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[@SimpleName='CompositeSamplers']"/> + </properties> + </rule> <rule ref="category/java/errorprone.xml/AvoidLiteralsInIfCondition"> <properties> <property name="ignoreMagicNumbers" value="-1,0,1" /> </properties> </rule> + <rule 
ref="category/java/errorprone.xml/AvoidFieldNameMatchingMethodName"> + <properties> + <!-- The field INSTANCE matches the method instance(), which returns a generically typed version of the instance. --> + <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[matches(@SimpleName, '^.*ObjectSamplerFactory$')]"/> + </properties> + </rule> <rule ref="category/java/multithreading.xml/UseConcurrentHashMap"> <properties>