From 2a415a033cce07b579a490eca2556e7d700c04b1 Mon Sep 17 00:00:00 2001 From: Arnavion Date: Wed, 22 Dec 2021 16:48:46 -0800 Subject: [PATCH] std.bit_set: add setRangeValue(Range, bool) For large ranges, this is faster than having the caller call setValue() for each index in the range. Masks wholly covered by the range can be set to the new mask value in one go, and the two masks at either end that are partially covered can each set the covered range of bits in one go. --- lib/std/bit_set.zig | 174 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/lib/std/bit_set.zig b/lib/std/bit_set.zig index 5101f934bc..d839512c07 100644 --- a/lib/std/bit_set.zig +++ b/lib/std/bit_set.zig @@ -110,6 +110,31 @@ pub fn IntegerBitSet(comptime size: u16) type { self.mask |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + if (MaskInt == u0) return; + + const start_bit = @intCast(ShiftInt, range.start); + + var mask = std.math.boolMask(MaskInt, true) << start_bit; + if (range.end != bit_length) { + const end_bit = @intCast(ShiftInt, range.end); + mask &= std.math.boolMask(MaskInt, true) >> @truncate(ShiftInt, @as(usize, @bitSizeOf(MaskInt)) - @as(usize, end_bit)); + } + self.mask &= ~mask; + + mask = std.math.boolMask(MaskInt, value) << start_bit; + if (range.end != bit_length) { + const end_bit = @intCast(ShiftInt, range.end); + mask &= std.math.boolMask(MaskInt, value) >> @truncate(ShiftInt, @as(usize, @bitSizeOf(MaskInt)) - @as(usize, end_bit)); + } + self.mask |= mask; + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < bit_length); @@ -345,6 +370,51 @@ pub fn ArrayBitSet(comptime MaskIntType: type, comptime size: usize) type { self.masks[maskIndex(index)] |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + if (num_masks == 0) return; + + const start_mask_index = maskIndex(range.start); + const start_bit = @truncate(ShiftInt, range.start); + + const end_mask_index = maskIndex(range.end); + const end_bit = @truncate(ShiftInt, range.end); + + if (start_mask_index == end_mask_index) { + var mask1 = std.math.boolMask(MaskInt, true) << start_bit; + var mask2 = std.math.boolMask(MaskInt, true) >> (mask_len - 1) - (end_bit - 1); + self.masks[start_mask_index] &= ~(mask1 & mask2); + + mask1 = std.math.boolMask(MaskInt, value) << start_bit; + mask2 = std.math.boolMask(MaskInt, value) >> (mask_len - 1) - (end_bit - 1); + self.masks[start_mask_index] |= mask1 & mask2; + } else { + var bulk_mask_index: usize = undefined; + if (start_bit > 0) { + self.masks[start_mask_index] = + (self.masks[start_mask_index] & ~(std.math.boolMask(MaskInt, true) << start_bit)) | + (std.math.boolMask(MaskInt, value) << start_bit); + bulk_mask_index = start_mask_index + 1; + } else { + bulk_mask_index = start_mask_index; + } + + while (bulk_mask_index < end_mask_index) : (bulk_mask_index += 1) { + self.masks[bulk_mask_index] = std.math.boolMask(MaskInt, value); + } + + if (end_bit > 0) { + self.masks[end_mask_index] = + (self.masks[end_mask_index] & (std.math.boolMask(MaskInt, true) << end_bit)) | + (std.math.boolMask(MaskInt, value) >> ((@bitSizeOf(MaskInt) - 1) - (end_bit - 1))); + } + } + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < bit_length); @@ -608,6 +678,50 @@ pub const DynamicBitSetUnmanaged = struct { self.masks[maskIndex(index)] |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= self.bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + + const start_mask_index = maskIndex(range.start); + const start_bit = @truncate(ShiftInt, range.start); + + const end_mask_index = maskIndex(range.end); + const end_bit = @truncate(ShiftInt, range.end); + + if (start_mask_index == end_mask_index) { + var mask1 = std.math.boolMask(MaskInt, true) << start_bit; + var mask2 = std.math.boolMask(MaskInt, true) >> (@bitSizeOf(MaskInt) - 1) - (end_bit - 1); + self.masks[start_mask_index] &= ~(mask1 & mask2); + + mask1 = std.math.boolMask(MaskInt, value) << start_bit; + mask2 = std.math.boolMask(MaskInt, value) >> (@bitSizeOf(MaskInt) - 1) - (end_bit - 1); + self.masks[start_mask_index] |= mask1 & mask2; + } else { + var bulk_mask_index: usize = undefined; + if (start_bit > 0) { + self.masks[start_mask_index] = + (self.masks[start_mask_index] & ~(std.math.boolMask(MaskInt, true) << start_bit)) | + (std.math.boolMask(MaskInt, value) << start_bit); + bulk_mask_index = start_mask_index + 1; + } else { + bulk_mask_index = start_mask_index; + } + + while (bulk_mask_index < end_mask_index) : (bulk_mask_index += 1) { + self.masks[bulk_mask_index] = std.math.boolMask(MaskInt, value); + } + + if (end_bit > 0) { + self.masks[end_mask_index] = + (self.masks[end_mask_index] & (std.math.boolMask(MaskInt, true) << end_bit)) | + (std.math.boolMask(MaskInt, value) >> ((@bitSizeOf(MaskInt) - 1) - (end_bit - 1))); + } + } + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < self.bit_length); @@ -811,6 +925,12 @@ pub const DynamicBitSet = struct { self.unmanaged.set(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + self.unmanaged.setRangeValue(range, value); + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { self.unmanaged.unset(index); @@ -990,6 +1110,14 @@ fn BitSetIterator(comptime MaskInt: type, comptime options: IteratorOptions) typ }; } +/// A range of indices within a bitset. +pub const Range = struct { + /// The index of the first bit of interest. + start: usize, + /// The index immediately after the last bit of interest. + end: usize, +}; + // ---------------- Tests ----------------- const testing = std.testing; @@ -1144,6 +1272,52 @@ fn testBitSet(a: anytype, b: anytype, len: usize) !void { try testing.expectEqual(@as(?usize, null), a.findFirstSet()); try testing.expectEqual(@as(?usize, null), a.toggleFirstSet()); try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, true); + try testing.expectEqual(len, a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = 0 }, true); + try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = len, .end = len }, true); + try testing.expectEqual(@as(usize, 0), a.count()); + + if (len >= 1) { + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = 1 }, true); + try testing.expectEqual(@as(usize, 1), a.count()); + try testing.expect(a.isSet(0)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = len - 1 }, true); + try testing.expectEqual(len - 1, a.count()); + try testing.expect(!a.isSet(len - 1)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 1, .end = len }, true); + try testing.expectEqual(@as(usize, len - 1), a.count()); + try testing.expect(!a.isSet(0)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = len - 1, .end = len }, true); + try testing.expectEqual(@as(usize, 1), a.count()); + try testing.expect(a.isSet(len - 1)); + + if (len >= 4) { + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 1, .end = len - 2 }, true); + try testing.expectEqual(@as(usize, len - 3), a.count()); + try testing.expect(!a.isSet(0)); + try testing.expect(a.isSet(1)); + try testing.expect(a.isSet(len - 3)); + try testing.expect(!a.isSet(len - 2)); + try testing.expect(!a.isSet(len - 1)); + } + } } fn testStaticBitSet(comptime Set: type) !void {