commit a07490a4bc501334727e70b9dca387d03c24fe2e (tree)
parent 4d747d452fba0987332244c777dfd3a29d5a20bd
Author: Andrew Kelley <superjoe30@gmail.com>
Date: Sun, 25 Nov 2018 11:44:08 -0500
Merge pull request #1783 from ziglang/rand-range
Use better rand range implementations
Diffstat:
| M | std/rand/index.zig | | | 195 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------- |
1 file changed, 155 insertions(+), 40 deletions(-)
diff --git a/std/rand/index.zig b/std/rand/index.zig
@@ -57,6 +57,18 @@ pub const Random = struct {
return @bitCast(T, unsigned_result);
}
+ /// Constant-time implementation of ::uintLessThan.
+ /// Returns a random unsigned integer `0 <= i < less_than`; the results of this function may be biased.
+ pub fn uintLessThanBiased(r: *Random, comptime T: type, less_than: T) T {
+ comptime assert(T.is_signed == false);
+ comptime assert(T.bit_count <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
+ assert(0 < less_than);
+ if (T.bit_count <= 32) {
+ return @intCast(T, limitRangeBiased(u32, r.int(u32), less_than));
+ } else {
+ return @intCast(T, limitRangeBiased(u64, r.int(u64), less_than));
+ }
+ }
/// Returns an evenly distributed random unsigned integer `0 <= i < less_than`.
/// This function assumes that the underlying ::fillFn produces evenly distributed values.
/// Within this assumption, the runtime of this function is exponentially distributed.
@@ -64,29 +76,53 @@ pub const Random = struct {
/// the runtime of this function would technically be unbounded.
/// However, if ::fillFn is backed by any evenly distributed pseudo random number generator,
/// this function is guaranteed to return.
- /// If you need deterministic runtime bounds, consider instead using `r.int(T) % less_than`,
- /// which will usually be biased toward smaller values.
+ /// If you need deterministic runtime bounds, use `::uintLessThanBiased`.
pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T {
- assert(T.is_signed == false);
+ comptime assert(T.is_signed == false);
+ comptime assert(T.bit_count <= 64); // TODO: workaround: LLVM ERROR: Unsupported library call operation!
assert(0 < less_than);
-
- const last_group_size_minus_one: T = maxInt(T) % less_than;
- if (last_group_size_minus_one == less_than - 1) {
- // less_than is a power of two.
- assert(math.floorPowerOfTwo(T, less_than) == less_than);
- // There is no retry zone. The optimal retry_zone_start would be maxInt(T) + 1.
- return r.int(T) % less_than;
- }
- const retry_zone_start = maxInt(T) - last_group_size_minus_one;
-
- while (true) {
- const rand_val = r.int(T);
- if (rand_val < retry_zone_start) {
- return rand_val % less_than;
+ // Small is T's bit count rounded up to a multiple of 32 (typically u32)
+ const Small = @IntType(false, @divTrunc(T.bit_count + 31, 32) * 32);
+ // Large is twice as wide as Small (typically u64), so the product below cannot overflow
+ const Large = @IntType(false, Small.bit_count * 2);
+
+ // adapted from:
+ // http://www.pcg-random.org/posts/bounded-rands.html
+ // "Lemire's (with an extra tweak from me)"
+ var x: Small = r.int(Small);
+ var m: Large = Large(x) * Large(less_than);
+ var l: Small = @truncate(Small, m);
+ if (l < less_than) {
+ // TODO: workaround for https://github.com/ziglang/zig/issues/1770
+ // should be:
+ // var t: Small = -%less_than;
+ var t: Small = @bitCast(Small, -%@bitCast(@IntType(true, Small.bit_count), Small(less_than)));
+
+ if (t >= less_than) {
+ t -= less_than;
+ if (t >= less_than) {
+ t %= less_than;
+ }
+ }
+ while (l < t) {
+ x = r.int(Small);
+ m = Large(x) * Large(less_than);
+ l = @truncate(Small, m);
}
}
+ return @intCast(T, m >> Small.bit_count);
}
+ /// Constant-time implementation of ::uintAtMost.
+ /// Returns a random unsigned integer `0 <= i <= at_most`; the results of this function may be biased.
+ pub fn uintAtMostBiased(r: *Random, comptime T: type, at_most: T) T {
+ comptime assert(T.is_signed == false);
+ if (at_most == maxInt(T)) {
+ // have the full range; `at_most + 1` below would overflow T
+ return r.int(T);
+ }
+ return r.uintLessThanBiased(T, at_most + 1);
+ }
/// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`.
/// See ::uintLessThan, which this function uses in most cases,
/// for commentary on the runtime of this function.
@@ -99,6 +135,22 @@ pub const Random = struct {
return r.uintLessThan(T, at_most + 1);
}
+ /// Constant-time implementation of ::intRangeLessThan.
+ /// Returns a random integer `at_least <= i < less_than`; the results of this function may be biased.
+ pub fn intRangeLessThanBiased(r: *Random, comptime T: type, at_least: T, less_than: T) T {
+ assert(at_least < less_than);
+ if (T.is_signed) {
+ // Two's complement makes this math pretty easy.
+ const UnsignedT = @IntType(false, T.bit_count);
+ const lo = @bitCast(UnsignedT, at_least);
+ const hi = @bitCast(UnsignedT, less_than);
+ const result = lo +% r.uintLessThanBiased(UnsignedT, hi -% lo);
+ return @bitCast(T, result);
+ } else {
+ // The signed implementation would work fine, but we can use stricter arithmetic operators here.
+ return at_least + r.uintLessThanBiased(T, less_than - at_least);
+ }
+ }
/// Returns an evenly distributed random integer `at_least <= i < less_than`.
/// See ::uintLessThan, which this function uses in most cases,
/// for commentary on the runtime of this function.
@@ -117,6 +169,22 @@ pub const Random = struct {
}
}
+ /// Constant-time implementation of ::intRangeAtMost.
+ /// Returns a random integer `at_least <= i <= at_most`; the results of this function may be biased.
+ pub fn intRangeAtMostBiased(r: *Random, comptime T: type, at_least: T, at_most: T) T {
+ assert(at_least <= at_most);
+ if (T.is_signed) {
+ // Two's complement makes this math pretty easy.
+ const UnsignedT = @IntType(false, T.bit_count);
+ const lo = @bitCast(UnsignedT, at_least);
+ const hi = @bitCast(UnsignedT, at_most);
+ const result = lo +% r.uintAtMostBiased(UnsignedT, hi -% lo);
+ return @bitCast(T, result);
+ } else {
+ // The signed implementation would work fine, but we can use stricter arithmetic operators here.
+ return at_least + r.uintAtMostBiased(T, at_most - at_least);
+ }
+ }
/// Returns an evenly distributed random integer `at_least <= i <= at_most`.
/// See ::uintLessThan, which this function uses in most cases,
/// for commentary on the runtime of this function.
@@ -135,15 +203,11 @@ pub const Random = struct {
}
}
- /// Return a random integer/boolean type.
/// TODO: deprecated. use ::boolean or ::int instead.
pub fn scalar(r: *Random, comptime T: type) T {
- if (T == bool) return r.boolean();
- return r.int(T);
+ return if (T == bool) r.boolean() else r.int(T);
}
- /// Return a random integer with even distribution between `start`
- /// inclusive and `end` exclusive. `start` must be less than `end`.
/// TODO: deprecated. renamed to ::intRangeLessThan
pub fn range(r: *Random, comptime T: type, start: T, end: T) T {
return r.intRangeLessThan(T, start, end);
@@ -206,6 +270,20 @@ pub const Random = struct {
}
};
+/// Convert a random integer 0 <= random_int <= maxInt(T),
+/// into an integer 0 <= result < less_than.
+/// This function introduces a minor bias.
+pub fn limitRangeBiased(comptime T: type, random_int: T, less_than: T) T {
+ comptime assert(T.is_signed == false);
+ const T2 = @IntType(false, T.bit_count * 2);
+
+ // adapted from:
+ // http://www.pcg-random.org/posts/bounded-rands.html
+ // "Integer Multiplication (Biased)"
+ const m: T2 = T2(random_int) * T2(less_than); // double-width product cannot overflow
+ return @intCast(T, m >> T.bit_count);
+}
+
const SequentialPrng = struct {
const Self = @This();
random: Random,
@@ -294,10 +372,19 @@ fn testRandomIntLessThan() void {
var r = SequentialPrng.init();
r.next_value = 0xff;
assert(r.random.uintLessThan(u8, 4) == 3);
- r.next_value = 0xff;
- assert(r.random.uintLessThan(u8, 3) == 0);
+ assert(r.next_value == 0);
+ assert(r.random.uintLessThan(u8, 4) == 0);
assert(r.next_value == 1);
+ r.next_value = 0;
+ assert(r.random.uintLessThan(u64, 32) == 0);
+
+ // trigger the bias rejection code path
+ r.next_value = 0;
+ assert(r.random.uintLessThan(u8, 3) == 0);
+ // verify we incremented twice
+ assert(r.next_value == 2);
+
r.next_value = 0xff;
assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
r.next_value = 0xff;
@@ -311,16 +398,9 @@ fn testRandomIntLessThan() void {
assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1);
r.next_value = 0xff;
- assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1);
- r.next_value = 0xff;
assert(r.random.intRangeLessThan(i3, -4, 0) == -1);
r.next_value = 0xff;
assert(r.random.intRangeLessThan(i3, -2, 2) == 1);
-
- // test retrying and eventually getting a good value
- // start just out of bounds
- r.next_value = 0x81;
- assert(r.random.uintLessThan(u8, 0x81) == 0);
}
test "Random intAtMost" {
@@ -332,9 +412,14 @@ fn testRandomIntAtMost() void {
var r = SequentialPrng.init();
r.next_value = 0xff;
assert(r.random.uintAtMost(u8, 3) == 3);
- r.next_value = 0xff;
+ assert(r.next_value == 0);
+ assert(r.random.uintAtMost(u8, 3) == 0);
+
+ // trigger the bias rejection code path
+ r.next_value = 0;
assert(r.random.uintAtMost(u8, 2) == 0);
- assert(r.next_value == 1);
+ // verify we incremented twice
+ assert(r.next_value == 2);
r.next_value = 0xff;
assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
@@ -349,16 +434,42 @@ fn testRandomIntAtMost() void {
assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1);
r.next_value = 0xff;
- assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1);
- r.next_value = 0xff;
assert(r.random.intRangeAtMost(i3, -4, -1) == -1);
r.next_value = 0xff;
assert(r.random.intRangeAtMost(i3, -2, 1) == 1);
- // test retrying and eventually getting a good value
- // start just out of bounds
- r.next_value = 0x81;
- assert(r.random.uintAtMost(u8, 0x80) == 0);
+ assert(r.random.uintAtMost(u0, 0) == 0);
+}
+
+test "Random Biased" {
+ var r = DefaultPrng.init(0);
+ // This is a smoke test: it does not thoroughly check the logic.
+ // It only executes all the code paths with different integer types.
+
+ assert(r.random.uintLessThanBiased(u1, 1) == 0);
+ assert(r.random.uintLessThanBiased(u32, 10) < 10);
+ assert(r.random.uintLessThanBiased(u64, 20) < 20);
+
+ assert(r.random.uintAtMostBiased(u0, 0) == 0);
+ assert(r.random.uintAtMostBiased(u1, 0) <= 0);
+ assert(r.random.uintAtMostBiased(u32, 10) <= 10);
+ assert(r.random.uintAtMostBiased(u64, 20) <= 20);
+
+ assert(r.random.intRangeLessThanBiased(u1, 0, 1) == 0);
+ assert(r.random.intRangeLessThanBiased(i1, -1, 0) == -1);
+ assert(r.random.intRangeLessThanBiased(u32, 10, 20) >= 10);
+ assert(r.random.intRangeLessThanBiased(i32, 10, 20) >= 10);
+ assert(r.random.intRangeLessThanBiased(u64, 20, 40) >= 20);
+ assert(r.random.intRangeLessThanBiased(i64, 20, 40) >= 20);
+
+ // TODO: uncommenting the next line causes a broken module error:
+ //assert(r.random.intRangeAtMostBiased(u0, 0, 0) == 0);
+ assert(r.random.intRangeAtMostBiased(u1, 0, 1) >= 0);
+ assert(r.random.intRangeAtMostBiased(i1, -1, 0) >= -1);
+ assert(r.random.intRangeAtMostBiased(u32, 10, 20) >= 10);
+ assert(r.random.intRangeAtMostBiased(i32, 10, 20) >= 10);
+ assert(r.random.intRangeAtMostBiased(u64, 20, 40) >= 20);
+ assert(r.random.intRangeAtMostBiased(i64, 20, 40) >= 20);
}
// Generator to extend 64-bit seed values into longer sequences.
@@ -870,12 +981,16 @@ test "Random range" {
}
fn testRange(r: *Random, start: i8, end: i8) void {
+ testRangeBias(r, start, end, true);
+ testRangeBias(r, start, end, false);
+}
+fn testRangeBias(r: *Random, start: i8, end: i8, biased: bool) void {
const count = @intCast(usize, i32(end) - i32(start));
var values_buffer = []bool{false} ** 0x100;
const values = values_buffer[0..count];
var i: usize = 0;
while (i < count) {
- const value: i32 = r.intRangeLessThan(i8, start, end);
+ const value: i32 = if (biased) r.intRangeLessThanBiased(i8, start, end) else r.intRangeLessThan(i8, start, end);
const index = @intCast(usize, value - start);
if (!values[index]) {
i += 1;