RNG-30: Sampling from generic collection. Some of this code was copied and adapted from the development version of "Commons Math" (class "o.a.c.m.random.RandomUtils").
Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/b6cce0d9 Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/b6cce0d9 Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/b6cce0d9 Branch: refs/heads/master Commit: b6cce0d92fd4dee4c17e62a390d48173015b9872 Parents: 0fce78b Author: Gilles <er...@apache.org> Authored: Tue Nov 15 17:25:00 2016 +0100 Committer: Gilles <er...@apache.org> Committed: Tue Nov 15 17:25:00 2016 +0100 ---------------------------------------------------------------------- .../commons/rng/sampling/CollectionSampler.java | 134 ++++++++++++++ .../rng/sampling/CollectionSamplerTest.java | 173 +++++++++++++++++++ 2 files changed, 307 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-rng/blob/b6cce0d9/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java new file mode 100644 index 0000000..222c8a5 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.rng.sampling; + +import java.util.Collection; +import java.util.List; +import java.util.ArrayList; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Sampling from a {@link Collection}. + * + * @param <T> Type of items in the collection. + * + * This class also contains utilities for shuffling a generic {@link List}. + * + * @since 1.0 + */ +public class CollectionSampler<T> { + /** Collection to be sampled from. */ + private final ArrayList<T> items; + /** Permutation. */ + private final PermutationSampler permutation; + /** Size of returned list. */ + private final int size; + + /** + * Creates a sampler. + * + * The {@link #sample()} method will generate a collection of + * size {@code k} whose entries are selected randomly, without + * repetition, from the items in the given {@code collection}. + * + * @param rng Generator of uniformly distributed random numbers. + * @param collection Collection to be sampled. + * A (shallow) copy will be stored in the created instance. + * @param k Size of the permutation. + * @throws IllegalArgumentException if {@code k <= 0} or + * {@code k > collection.size()}. 
+ */ + public CollectionSampler(UniformRandomProvider rng, + Collection<T> collection, + int k) { + permutation = new PermutationSampler(rng, collection.size(), k); + items = new ArrayList<T>(collection); + size = k; + } + + /** + * Creates a list of objects selected randomly, without repetition, + * from the collection provided at + * {@link #CollectionSampler(UniformRandomProvider,Collection,int) + * construction}. + * + * <p> + * Sampling is without replacement; but if the source collection + * contains identical objects, the sample may include repeats. + * </p> + * <p> + * There is no guarantee that the concrete type of the returned + * collection is the same as the source collection. + * </p> + * + * @return a random sample. + */ + public Collection<T> sample() { + final ArrayList<T> result = new ArrayList<T>(size); + final int[] index = permutation.sample(); + + for (int i = 0; i < size; i++) { + result.add(items.get(index[i])); + } + + return result; + } + /** + * Shuffles the entries of the given array. + * + * @see #shuffle(List,int,boolean,UniformRandomProvider) + * + * @param <S> Type of the list items. + * @param list List whose entries will be shuffled (in-place). + * @param rng Random number generator. + */ + public static <S> void shuffle(List<S> list, + UniformRandomProvider rng) { + shuffle(list, 0, false, rng); + } + + /** + * Shuffles the entries of the given array, using the + * <a href="http://en.wikipedia.org/wiki/FisherâYates_shuffle#The_modern_algorithm"> + * Fisher-Yates</a> algorithm. + * The {@code start} and {@code pos} parameters select which part + * of the array is randomized and which is left untouched. + * + * @param <S> Type of the list items. + * @param list List whose entries will be shuffled (in-place). + * @param start Index at which shuffling begins. 
+ * @param towardHead Shuffling is performed for index positions between + * {@code start} and either the end (if {@code false}) or the beginning + * (if {@code true}) of the array. + * @param rng Random number generator. + */ + public static <S> void shuffle(List<S> list, + int start, + boolean towardHead, + UniformRandomProvider rng) { + final int len = list.size(); + final int[] indices = PermutationSampler.natural(len); + PermutationSampler.shuffle(indices, start, towardHead, rng); + + final ArrayList<S> items = new ArrayList<S>(list); + for (int i = 0; i < len; i++) { + list.set(i, items.get(indices[i])); + } + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/b6cce0d9/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java new file mode 100644 index 0000000..3bd897c --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling; + +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.commons.math3.stat.inference.ChiSquareTest; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; + +/** + * Tests for {@link CollectionSampler}. + */ +public class CollectionSamplerTest { + private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 6543432321L); + private final ChiSquareTest chiSquareTest = new ChiSquareTest(); + + @Test + public void testSample() { + final String[][] c = { { "0", "1" }, { "0", "2" }, { "0", "3" }, { "0", "4" }, + { "1", "2" }, { "1", "3" }, { "1", "4" }, + { "2", "3" }, { "2", "4" }, + { "3", "4" } }; + final long[] observed = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + final double[] expected = { 100, 100, 100, 100, 100, 100, 100, 100, 100, 100 }; + + final HashSet<String> cPop = new HashSet<String>(); // {0, 1, 2, 3, 4}. + for (int i = 0; i < 5; i++) { + cPop.add(Integer.toString(i)); + } + + final List<Set<String>> sets = new ArrayList<Set<String>>(); // 2-sets from 5. + for (int i = 0; i < 10; i++) { + final HashSet<String> hs = new HashSet<String>(); + hs.add(c[i][0]); + hs.add(c[i][1]); + sets.add(hs); + } + + final CollectionSampler<String> sampler = new CollectionSampler<String>(rng, cPop, 2); + for (int i = 0; i < 1000; i++) { + observed[findSample(sets, sampler.sample())]++; + } + + // Pass if we cannot reject null hypothesis that distributions are the same + Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); + } + + @Test + public void testSampleWhole() { + // A sample whose size equals the collection's size must contain every item of the collection. 
+ final HashSet<String> hs = new HashSet<String>(); + hs.add("one"); + + final CollectionSampler<String> sampler = new CollectionSampler<String>(rng, hs, 1); + final Collection<String> one = sampler.sample(); + Assert.assertEquals(1, one.size()); + Assert.assertTrue(one.contains("one")); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition1() { + // Must fail for sample size > collection size. + final HashSet<String> hs = new HashSet<String>(); + hs.add("one"); + new CollectionSampler<String>(rng, hs, 2).sample(); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition2() { + // Must fail for an empty collection (sample size 0). + final HashSet<String> hs = new HashSet<String>(); + new CollectionSampler<String>(rng, hs, 0).sample(); + } + + @Test + public void testShuffleTail() { + final List<Integer> orig = new ArrayList<Integer>(); + for (int i = 0; i < 10; i++) { + orig.add((i + 1) * rng.nextInt()); + } + final List<Integer> list = new ArrayList<Integer>(orig); + + final int start = 4; + CollectionSampler.shuffle(list, start, false, rng); + + // Ensure that all entries below index "start" did not move. + for (int i = 0; i < start; i++) { + Assert.assertEquals(orig.get(i), list.get(i)); + } + + // Ensure that at least one entry has moved. + boolean ok = false; + for (int i = start; i < orig.size(); i++) { + if (!orig.get(i).equals(list.get(i))) { + ok = true; + break; + } + } + Assert.assertTrue(ok); + } + + @Test + public void testShuffleHead() { + final List<Integer> orig = new ArrayList<Integer>(); + for (int i = 0; i < 10; i++) { + orig.add((i + 1) * rng.nextInt()); + } + final List<Integer> list = new ArrayList<Integer>(orig); + + final int start = 4; + CollectionSampler.shuffle(list, start, true, rng); + + // Ensure that all entries above index "start" did not move. 
+ for (int i = start + 1; i < orig.size(); i++) { + Assert.assertEquals(orig.get(i), list.get(i)); + } + + // Ensure that at least one entry has moved. + boolean ok = false; + for (int i = 0; i <= start; i++) { + if (!orig.get(i).equals(list.get(i))) { + ok = true; + break; + } + } + Assert.assertTrue(ok); + } + + //// Support methods. + + private <T extends Set<String>> int findSample(List<T> u, + Collection<String> sampList) { + final String[] samp = sampList.toArray(new String[sampList.size()]); + final HashSet<String> sampSet = new HashSet<String>(); // Built once, compared against every candidate set. + for (int j = 0; j < samp.length; j++) { + sampSet.add(samp[j]); + } + for (int i = 0; i < u.size(); i++) { + final T set = u.get(i); + if (set.equals(sampSet)) { + return i; + } + } + + Assert.fail("Sample not found: { " + + samp[0] + ", " + samp[1] + " }"); + return -1; + } +}