RNG-30: Generate random permutations. Code copied and adapted from the development version of "Commons Math" (classes "o.a.c.m.util.MathArrays" and "o.a.c.m.random.RandomUtils").
Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/0fce78bc Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/0fce78bc Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/0fce78bc Branch: refs/heads/master Commit: 0fce78bc73fe626ccb3ed2e74fac0c57026181e4 Parents: d3a7be3 Author: Gilles <er...@apache.org> Authored: Tue Nov 15 17:22:40 2016 +0100 Committer: Gilles <er...@apache.org> Committed: Tue Nov 15 17:22:40 2016 +0100 ---------------------------------------------------------------------- .../rng/sampling/PermutationSampler.java | 152 +++++++++++++++ .../rng/sampling/PermutationSamplerTest.java | 194 +++++++++++++++++++ 2 files changed, 346 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-rng/blob/0fce78bc/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java new file mode 100644 index 0000000..d53cfca --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.rng.sampling; + +import java.util.Arrays; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Class for representing permutations of a sequence of integers. + * + * This class also contains utilities for shuffling an {@code int[]} array. + */ +public class PermutationSampler { + /** Domain of the permutation. */ + private final int[] domain; + /** Size of the permutation. */ + private final int size; + /** RNG. */ + private final UniformRandomProvider rng; + + /** + * Creates a generator of permutations. + * + * The {@link #sample()} method will generate an integer array of + * length {@code k} whose entries are selected randomly, without + * repetition, from the integers 0, 1, ..., {@code n}-1 (inclusive). + * The returned array represents a permutation of {@code n} taken + * {@code k}. + * + * @param rng Generator of uniformly distributed random numbers. + * @param n Domain of the permutation. + * @param k Size of the permutation. + * @throws IllegalArgumentException if {@code n < 0} or {@code k <= 0} + * or {@code k > n}. + */ + public PermutationSampler(UniformRandomProvider rng, + int n, + int k) { + if (n < 0) { + throw new IllegalArgumentException(n + " < " + 0); + } + if (k <= 0) { + throw new IllegalArgumentException(k + " <= " + 0); + } + if (k > n) { + throw new IllegalArgumentException(k + " > " + n); + } + + domain = natural(n); + size = k; + this.rng = rng; + } + + /** + * @return a random permutation. 
+ * + * @see #PermutationSampler(UniformRandomProvider,int,int) + */ + public int[] sample() { + shuffle(domain, rng); + return Arrays.copyOf(domain, size); + } + + /** + * Shuffles the entries of the given array. + * + * @see #shuffle(int[],int,boolean,UniformRandomProvider) + * + * @param list Array whose entries will be shuffled (in-place). + * @param rng Random number generator. + */ + public static void shuffle(int[] list, + UniformRandomProvider rng) { + shuffle(list, 0, false, rng); + } + + /** + * Shuffles the entries of the given array, using the + * <a href="http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm"> + * Fisher-Yates</a> algorithm. + * The {@code start} and {@code towardHead} parameters select which part + * of the array is randomized and which is left untouched. + * + * @param list Array whose entries will be shuffled (in-place). + * @param start Index at which shuffling begins. + * @param towardHead Shuffling is performed for index positions between + * {@code start} and either the end (if {@code false}) or the beginning + * (if {@code true}) of the array. + * @param rng Random number generator. + */ + public static void shuffle(int[] list, + int start, + boolean towardHead, + UniformRandomProvider rng) { + if (towardHead) { + for (int i = 0; i <= start; i++) { + final int target; + if (i == start) { + target = start; + } else { + target = rng.nextInt(start - i + 1) + i; + } + final int temp = list[target]; + list[target] = list[i]; + list[i] = temp; + } + } else { + for (int i = list.length - 1; i >= start; i--) { + final int target; + if (i == start) { + target = start; + } else { + target = rng.nextInt(i - start + 1) + start; + } + final int temp = list[target]; + list[target] = list[i]; + list[i] = temp; + } + } + } + + /** + * Creates an array representing the natural number {@code n}. + * + * @param n Natural number. + * @return an array whose entries are the numbers 0, 1, ..., {@code n}-1. 
+ * If {@code n == 0}, the returned array is empty. + */ + public static int[] natural(int n) { + final int[] a = new int[n]; + for (int i = 0; i < n; i++) { + a[i] = i; + } + return a; + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/0fce78bc/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java new file mode 100644 index 0000000..bd32c05 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.rng.sampling; + +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.commons.math3.stat.inference.ChiSquareTest; +import org.apache.commons.math3.util.MathArrays; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; + +/** + * Tests for {@link PermutationSampler}. + */ +public class PermutationSamplerTest { + private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 1232343456L); + private final ChiSquareTest chiSquareTest = new ChiSquareTest(); + + @Test + public void testSample() { + final int[][] p = { { 0, 1, 2 }, { 0, 2, 1 }, + { 1, 0, 2 }, { 1, 2, 0 }, + { 2, 0, 1 }, { 2, 1, 0 } }; + final int len = p.length; + final long[] observed = new long[len]; + final int numSamples = 6000; + final double numExpected = numSamples / (double) len; + final double[] expected = new double[len]; + Arrays.fill(expected, numExpected); + + final PermutationSampler sampler = new PermutationSampler(rng, 3, 3); + for (int i = 0; i < numSamples; i++) { + observed[findPerm(p, sampler.sample())]++; + } + + // Pass if we cannot reject null hypothesis that distributions are the same + Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); + } + + @Test + public void testSampleBoundaryCase() { + // Check size = 1 boundary case. + final PermutationSampler sampler = new PermutationSampler(rng, 1, 1); + final int[] perm = sampler.sample(); + Assert.assertEquals(1, perm.length); + Assert.assertEquals(0, perm[0]); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition1() { + // Must fail for k > n. + new PermutationSampler(rng, 2, 3); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition2() { + // Must fail for n = 0. 
+ new PermutationSampler(rng, 0, 0); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition3() { + // Must fail for n < 0. + new PermutationSampler(rng, -1, 0); + } + + @Test(expected=IllegalArgumentException.class) + public void testSamplePrecondition4() { + // Must fail for k < 0. + new PermutationSampler(rng, 1, -1); + } + + @Test + public void testNatural() { + final int n = 4; + final int[] expected = {0, 1, 2, 3}; + + final int[] natural = PermutationSampler.natural(n); + for (int i = 0; i < n; i++) { + Assert.assertEquals(expected[i], natural[i]); + } + } + + @Test + public void testNaturalZero() { + final int[] natural = PermutationSampler.natural(0); + Assert.assertEquals(0, natural.length); + } + + @Test + public void testShuffleNoDuplicates() { + final int n = 100; + final int[] orig = PermutationSampler.natural(n); + PermutationSampler.shuffle(orig, rng); + + // Test that all (unique) entries exist in the shuffled array. + final int[] count = new int[n]; + for (int i = 0; i < n; i++) { + count[orig[i]] += 1; + } + + for (int i = 0; i < n; i++) { + Assert.assertEquals(1, count[i]); + } + } + + @Test + public void testShuffleTail() { + final int[] orig = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + final int[] list = orig.clone(); + final int start = 4; + PermutationSampler.shuffle(list, start, false, rng); + + // Ensure that all entries below index "start" did not move. + for (int i = 0; i < start; i++) { + Assert.assertEquals(orig[i], list[i]); + } + + // Ensure that at least one entry has moved. 
+ boolean ok = false; + for (int i = start; i < orig.length - 1; i++) { + if (orig[i] != list[i]) { + ok = true; + break; + } + } + Assert.assertTrue(ok); + } + + @Test + public void testShuffleHead() { + final int[] orig = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + final int[] list = orig.clone(); + final int start = 4; + PermutationSampler.shuffle(list, start, true, rng); + + // Ensure that all entries above index "start" did not move. + for (int i = start + 1; i < orig.length; i++) { + Assert.assertEquals(orig[i], list[i]); + } + + // Ensure that at least one entry has moved. + boolean ok = false; + for (int i = 0; i <= start; i++) { + if (orig[i] != list[i]) { + ok = true; + break; + } + } + Assert.assertTrue(ok); + } + + //// Support methods. + + private int findPerm(int[][] p, + int[] samp) { + for (int i = 0; i < p.length; i++) { + boolean good = true; + for (int j = 0; j < samp.length; j++) { + if (samp[j] != p[i][j]) { + good = false; + } + } + if (good) { + return i; + } + } + + Assert.fail("Permutation not found"); + return -1; + } +}