This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 9f492d05b47c51427228ffce1e48ce5c2099accb
Author: aherbert <aherb...@apache.org>
AuthorDate: Thu Sep 16 13:45:02 2021 +0100

    Allow ziggurat sampling from only the overhangs in the performance test
    
    Add additional ternary variant for testing.
---
 .../distribution/ZigguratSamplerPerformance.java   | 205 +++++++++++++++++++--
 .../sampling/distribution/ZigguratSamplerTest.java |   1 +
 2 files changed, 193 insertions(+), 13 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
index 3a52182..0d17f52 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
@@ -115,6 +115,8 @@ public class ZigguratSamplerPerformance {
     static final String MOD_EXPONENTIAL_E_MAX_2 = "ModExponentialEmax2";
     /** The name for the {@link ModifiedZigguratExponentialSamplerTernary}. */
     static final String MOD_EXPONENTIAL_TERNARY = "ModExponentialTernary";
+    /** The name for the {@link ModifiedZigguratExponentialSamplerTernarySubtract}. */
+    static final String MOD_EXPONENTIAL_TERNARY_SUBTRACT = "ModExponentialTernarySubtract";
     /** The name for the {@link ModifiedZigguratExponentialSampler512} using a table size of 512. */
     static final String MOD_EXPONENTIAL_512 = "ModExponential512";
 
@@ -580,8 +582,57 @@ public class ZigguratSamplerPerformance {
                 MOD_EXPONENTIAL_LOOP, MOD_EXPONENTIAL_LOOP2,
                 MOD_EXPONENTIAL_RECURSION, MOD_EXPONENTIAL_INT_MAP,
                 MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2,
-                MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_512})
-        protected String type;
+                MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_TERNARY_SUBTRACT, 
MOD_EXPONENTIAL_512})
+        private String type;
+
+
+        /**
+         * Flag to indicate that the sample targets the overhangs.
+         * This is applicable to the McFarland Ziggurat sampler and
+         * requires manipulation of the final bits of the RNG.
+         *
+         * <p>When {@code false} the sampler for the configured {@code type} is
+         * created unmodified. When {@code true} the flag is ignored for the
+         * Marsaglia Ziggurat types ({@code GAUSSIAN_128}, {@code GAUSSIAN_256}
+         * and {@code EXPONENTIAL}) because their overhang sampling cannot be forced.
+         */
+        @Param({"true", "false"})
+        private boolean overhang;
+
+        /**
+         * Creates the sampler.
+         *
+         * <p>This may return a specialisation for the McFarland sampler that only samples
+         * from overhangs.
+         *
+         * <p>If {@code overhang} is {@code false}, or the {@code type} is one of the
+         * Marsaglia Ziggurat samplers ({@code GAUSSIAN_128}, {@code GAUSSIAN_256},
+         * {@code EXPONENTIAL}), the result is exactly {@code createSampler(type, rng)}.
+         * Otherwise the sampler is assumed to be a McFarland Ziggurat sampler. It is
+         * created on a {@link ModifiedRNG} wrapping {@code rng}, and every call to
+         * {@code sample()} first sets the lowest 8 bits (9 bits when the type name
+         * contains "512") of the next {@code nextLong()} value. Only the first
+         * {@code nextLong()} of each {@code sample()} call is modified.
+         * This is intended to force sampling from the overhangs/tail. It assumes the
+         * look-up table index is taken from those low bits, which should be confirmed
+         * for each McFarland variant.
+         *
+         * @param rng RNG
+         * @return the sampler
+         */
+        protected ContinuousSampler createSampler(UniformRandomProvider rng) {
+            if (!overhang) {
+                return createSampler(type, rng);
+            }
+            // For the Marsaglia Ziggurat sampler overhangs are only tested once then
+            // the method recurses the entire sample method. Overhang sampling cannot be forced
+            // for this sampler.
+            if (GAUSSIAN_128.equals(type) ||
+                GAUSSIAN_256.equals(type) ||
+                EXPONENTIAL.equals(type)) {
+                return createSampler(type, rng);
+            }
+            // Assume the sampler is a McFarland Ziggurat sampler.
+            // Manipulate the final bits of the long from the RNG to force sampling
+            // from the overhang. Assume most of the samplers use an 8-bit look-up table.
+            int numberOfBits = 8;
+            if (type.contains("512")) {
+                // 9-bit look-up table (type name heuristic)
+                numberOfBits = 9;
+            }
+            // Use an RNG that can set the lower bits of the long.
+            final ModifiedRNG modRNG = new ModifiedRNG(rng, numberOfBits);
+            final ContinuousSampler sampler = createSampler(type, modRNG);
+            // Create a sampler where each call should force overhangs/tail sampling.
+            return new ContinuousSampler() {
+                @Override
+                public double sample() {
+                    modRNG.modifyNextLong();
+                    return sampler.sample();
+                }
+            };
+        }
 
         /**
          * Creates the sampler.
@@ -641,12 +692,85 @@ public class ZigguratSamplerPerformance {
                 return new ModifiedZigguratExponentialSamplerEMax2(rng);
             } else if (MOD_EXPONENTIAL_TERNARY.equals(type)) {
                 return new ModifiedZigguratExponentialSamplerTernary(rng);
+            } else if (MOD_EXPONENTIAL_TERNARY_SUBTRACT.equals(type)) {
+                return new 
ModifiedZigguratExponentialSamplerTernarySubtract(rng);
             } else if (MOD_EXPONENTIAL_512.equals(type)) {
                 return new ModifiedZigguratExponentialSampler512(rng);
             } else {
                 throw new IllegalStateException("Unknown type: " + type);
             }
         }
+
+        /**
+         * A class that can modify the lower bits to be all set for the next invocation of
+         * {@link UniformRandomProvider#nextLong()}.
+         *
+         * <p>Only {@link #nextLong()} is supported. It delegates to the wrapped provider and
+         * ORs in the pending bits. All other methods throw an {@link IllegalStateException}.
+         * NOTE(review): {@link UnsupportedOperationException} is the conventional type for
+         * unsupported operations.
+         */
+        private static class ModifiedRNG implements UniformRandomProvider {
+            /** Underlying source of randomness. */
+            private final UniformRandomProvider rng;
+            /** The bits to set in the output long using a bitwise or ('|'). */
+            private final long bits;
+            /**
+             * The next bits to set in the output long using a bitwise or ('|').
+             * Set to {@code bits} by {@link #modifyNextLong()} and reset to zero by
+             * {@link #nextLong()}, so a modification applies to a single call only.
+             */
+            private long nextBits;
+
+            /**
+             * Create an instance.
+             *
+             * <p>NOTE(review): {@code numberOfBits} is expected to be below 64. Java masks
+             * the shift distance, so {@code 1L << 64 == 1L} would result in no bits set.
+             *
+             * @param rng Underlying source of randomness
+             * @param numberOfBits Number of least significant bits to set for a call to nextLong()
+             */
+            ModifiedRNG(UniformRandomProvider rng, int numberOfBits) {
+                this.rng = rng;
+                bits = (1L << numberOfBits) - 1;
+            }
+
+            /**
+             * Set the state to modify the lower bits on the next call to nextLong().
+             * Only that one call is affected; subsequent calls return unmodified values.
+             */
+            void modifyNextLong() {
+                nextBits = bits;
+            }
+
+            @Override
+            public long nextLong() {
+                final long x = rng.nextLong() | nextBits;
+                nextBits = 0;
+                return x;
+            }
+
+            // The following methods should not be used; they throw IllegalStateException.
+
+            @Override
+            public void nextBytes(byte[] bytes) {
+                throw new IllegalStateException();
+            }
+            @Override
+            public void nextBytes(byte[] bytes, int start, int len) {
+                throw new IllegalStateException();
+            }
+            @Override
+            public int nextInt() {
+                throw new IllegalStateException();
+            }
+            @Override
+            public int nextInt(int n) {
+                throw new IllegalStateException();
+            }
+            @Override
+            public long nextLong(long n) {
+                throw new IllegalStateException();
+            }
+            @Override
+            public boolean nextBoolean() {
+                throw new IllegalStateException();
+            }
+            @Override
+            public float nextFloat() {
+                throw new IllegalStateException();
+            }
+            @Override
+            public double nextDouble() {
+                throw new IllegalStateException();
+            }
+        }
     }
 
     /**
@@ -683,7 +807,7 @@ public class ZigguratSamplerPerformance {
         public void setup() {
             final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
             final UniformRandomProvider rng = randomSource.create();
-            sampler = createSampler(type, rng);
+            sampler = createSampler(rng); // Applies the 'overhang' parameter
         }
     }
 
@@ -736,7 +860,7 @@ public class ZigguratSamplerPerformance {
         public void setup() {
             final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
             final UniformRandomProvider rng = randomSource.create();
-            final ContinuousSampler s = createSampler(type, rng);
+            final ContinuousSampler s = createSampler(rng); // Applies the 'overhang' parameter
             sampler = createSequentialSampler(size, s);
         }
 
@@ -1194,9 +1318,9 @@ public class ZigguratSamplerPerformance {
         // Ziggurat volumes:
         // Inside the layers              = 98.8281%  (253/256)
         // Fraction outside the layers:
-        // concave overhangs              = 76.1941%
+        // convex overhangs               = 76.1941%
         // inflection overhang            =  0.1358%
-        // convex overhangs               = 21.3072%
+        // concave overhangs              = 21.3072%
         // tail                           =  2.3629%
 
         /** The number of layers in the ziggurat. Maximum i value for early 
exit. */
@@ -2267,9 +2391,9 @@ public class ZigguratSamplerPerformance {
         // Ziggurat volumes:
         // Inside the layers              = 98.8281%  (253/256)
         // Fraction outside the layers:
-        // concave overhangs              = 76.1941%
+        // convex overhangs               = 76.1941%
         // inflection overhang            =  0.1358%
-        // convex overhangs               = 21.3072%
+        // concave overhangs              = 21.3072%
         // tail                           =  2.3629%
 
         // Separation of convex overhangs:
@@ -2479,9 +2603,11 @@ public class ZigguratSamplerPerformance {
                     // Concave overhang
                     for (;;) {
                         // If u2 < u1 then reflect in the hypotenuse by 
swapping u1 and u2.
+                        // Create a second uniform deviate (as u1 is recycled).
                         final long ua = u1;
                         final long ub = randomInt63();
-                        // Sort u1 < u2 to sample the lower-left triangle
+                        // Sort u1 < u2 to sample the lower-left triangle.
+                        // Use conditional ternary to avoid a 50/50 branch 
statement to swap the pair.
                         u1 = ua < ub ? ua : ub;
                         final long u2 = ua < ub ? ub : ua;
                         x = sampleX(X, j, u1);
@@ -2518,9 +2644,9 @@ public class ZigguratSamplerPerformance {
         // Ziggurat volumes:
         // Inside the layers              = 99.4141%  (509/512)
         // Fraction outside the layers:
-        // concave overhangs              = 75.5775%
+        // convex overhangs               = 75.5775%
         // inflection overhang            =  0.0675%
-        // convex overhangs               = 22.2196%
+        // concave overhangs              = 22.2196%
         // tail                           =  2.1354%
 
         /** The number of layers in the ziggurat. Maximum i value for early 
exit. */
@@ -4121,7 +4247,7 @@ public class ZigguratSamplerPerformance {
      * <p>Uses the algorithm from McFarland, C.D. (2016).
      *
      * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
-     * a ternary operator to sort the two random long values.
+     * two ternary operators to sort the two random long values.
      */
     static class ModifiedZigguratExponentialSamplerTernary
         extends ModifiedZigguratExponentialSampler {
@@ -4149,7 +4275,8 @@ public class ZigguratSamplerPerformance {
             // If u2 < u1 then reflect in the hypotenuse by swapping u1 and u2.
             final long ua = randomInt63();
             final long ub = randomInt63();
-            // Sort u1 < u2 to sample the lower-left triangle
+            // Sort u1 < u2 to sample the lower-left triangle.
+            // Use conditional ternary to avoid a 50/50 branch statement to 
swap the pair.
             final long u1 = ua < ub ? ua : ub;
             final long u2 = ua < ub ? ub : ua;
             final double x = sampleX(X, j, u1);
@@ -4168,6 +4295,58 @@ public class ZigguratSamplerPerformance {
      * <p>Uses the algorithm from McFarland, C.D. (2016).
      *
      * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
+     * a ternary operator to sort the two random long values and a subtraction
+     * to get the difference.
+     */
+    static class ModifiedZigguratExponentialSamplerTernarySubtract
+        extends ModifiedZigguratExponentialSampler {
+
+        /**
+         * Create an instance.
+         *
+         * @param rng Generator of uniformly distributed random numbers.
+         */
+        ModifiedZigguratExponentialSamplerTernarySubtract(UniformRandomProvider rng) {
+            super(rng);
+        }
+
+        @Override
+        protected double sampleOverhang(int j) {
+            // Sample from the triangle:
+            //    X[j],Y[j]
+            //        |\-->u1
+            //        | \  |
+            //        |  \ |
+            //        |   \|    Overhang j (with hypotenuse not pdf(x))
+            //        |    \
+            //        |    |\
+            //        |    | \
+            //        |    u2 \
+            //        +-------- X[j-1],Y[j-1]
+            // If u2 < u1 then reflect in the hypotenuse by swapping u1 and u2.
+            final long ua = randomInt63();
+            final long ub = randomInt63();
+            // Sort u1 < u2 to sample the lower-left triangle.
+            // Use conditional ternary to avoid a 50/50 branch statement to swap the pair.
+            final long u1 = ua < ub ? ua : ub;
+            final double x = sampleX(X, j, u1);
+            // u2 = ua + ub - u1 (ua + ub may overflow; long arithmetic wraps so the result is exact)
+            // uDistance = ua + ub - u1 - u1 = |ua - ub| (assumes randomInt63() is non-negative)
+            final long uDistance = ua + ub - (u1 << 1);
+            if (uDistance >= E_MAX) {
+                // Early Exit: x < y - epsilon
+                return x;
+            }
+
+            // u2 = u1 + uDistance. On rejection the overhang is re-sampled recursively.
+            return sampleY(Y, j, u1 + uDistance) <= Math.exp(-x) ? x : sampleOverhang(j);
+        }
+    }
+
+    /**
+     * Modified Ziggurat method for sampling from an exponential distribution.
+     *
+     * <p>Uses the algorithm from McFarland, C.D. (2016).
+     *
+     * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
      * a table size of 512.
      */
     static class ModifiedZigguratExponentialSampler512 implements 
ContinuousSampler {
diff --git a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
index 25d0670..177ca93 100644
--- a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
+++ b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
@@ -117,6 +117,7 @@ class ZigguratSamplerTest {
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
+                
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY_SUBTRACT),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
     }
 

Reply via email to