diff --git a/lib/std/bit_set.zig b/lib/std/bit_set.zig index 5101f934bc..d839512c07 100644 --- a/lib/std/bit_set.zig +++ b/lib/std/bit_set.zig @@ -110,6 +110,31 @@ pub fn IntegerBitSet(comptime size: u16) type { self.mask |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + if (MaskInt == u0) return; + + const start_bit = @intCast(ShiftInt, range.start); + + var mask = std.math.boolMask(MaskInt, true) << start_bit; + if (range.end != bit_length) { + const end_bit = @intCast(ShiftInt, range.end); + mask &= std.math.boolMask(MaskInt, true) >> @truncate(ShiftInt, @as(usize, @bitSizeOf(MaskInt)) - @as(usize, end_bit)); + } + self.mask &= ~mask; + + mask = std.math.boolMask(MaskInt, value) << start_bit; + if (range.end != bit_length) { + const end_bit = @intCast(ShiftInt, range.end); + mask &= std.math.boolMask(MaskInt, value) >> @truncate(ShiftInt, @as(usize, @bitSizeOf(MaskInt)) - @as(usize, end_bit)); + } + self.mask |= mask; + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < bit_length); @@ -345,6 +370,51 @@ pub fn ArrayBitSet(comptime MaskIntType: type, comptime size: usize) type { self.masks[maskIndex(index)] |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + if (num_masks == 0) return; + + const start_mask_index = maskIndex(range.start); + const start_bit = @truncate(ShiftInt, range.start); + + const end_mask_index = maskIndex(range.end); + const end_bit = @truncate(ShiftInt, range.end); + + if (start_mask_index == end_mask_index) { + var mask1 = std.math.boolMask(MaskInt, true) << start_bit; + var mask2 = std.math.boolMask(MaskInt, true) >> (mask_len - 1) - (end_bit - 1); + self.masks[start_mask_index] &= ~(mask1 & mask2); + + mask1 = std.math.boolMask(MaskInt, value) << start_bit; + mask2 = std.math.boolMask(MaskInt, value) >> (mask_len - 1) - (end_bit - 1); + self.masks[start_mask_index] |= mask1 & mask2; + } else { + var bulk_mask_index: usize = undefined; + if (start_bit > 0) { + self.masks[start_mask_index] = + (self.masks[start_mask_index] & ~(std.math.boolMask(MaskInt, true) << start_bit)) | + (std.math.boolMask(MaskInt, value) << start_bit); + bulk_mask_index = start_mask_index + 1; + } else { + bulk_mask_index = start_mask_index; + } + + while (bulk_mask_index < end_mask_index) : (bulk_mask_index += 1) { + self.masks[bulk_mask_index] = std.math.boolMask(MaskInt, value); + } + + if (end_bit > 0) { + self.masks[end_mask_index] = + (self.masks[end_mask_index] & (std.math.boolMask(MaskInt, true) << end_bit)) | + (std.math.boolMask(MaskInt, value) >> ((@bitSizeOf(MaskInt) - 1) - (end_bit - 1))); + } + } + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < bit_length); @@ -608,6 +678,50 @@ pub const DynamicBitSetUnmanaged = struct { self.masks[maskIndex(index)] |= maskBit(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + assert(range.end <= self.bit_length); + assert(range.start <= range.end); + if (range.start == range.end) return; + + const start_mask_index = maskIndex(range.start); + const start_bit = @truncate(ShiftInt, range.start); + + const end_mask_index = maskIndex(range.end); + const end_bit = @truncate(ShiftInt, range.end); + + if (start_mask_index == end_mask_index) { + var mask1 = std.math.boolMask(MaskInt, true) << start_bit; + var mask2 = std.math.boolMask(MaskInt, true) >> (@bitSizeOf(MaskInt) - 1) - (end_bit - 1); + self.masks[start_mask_index] &= ~(mask1 & mask2); + + mask1 = std.math.boolMask(MaskInt, value) << start_bit; + mask2 = std.math.boolMask(MaskInt, value) >> (@bitSizeOf(MaskInt) - 1) - (end_bit - 1); + self.masks[start_mask_index] |= mask1 & mask2; + } else { + var bulk_mask_index: usize = undefined; + if (start_bit > 0) { + self.masks[start_mask_index] = + (self.masks[start_mask_index] & ~(std.math.boolMask(MaskInt, true) << start_bit)) | + (std.math.boolMask(MaskInt, value) << start_bit); + bulk_mask_index = start_mask_index + 1; + } else { + bulk_mask_index = start_mask_index; + } + + while (bulk_mask_index < end_mask_index) : (bulk_mask_index += 1) { + self.masks[bulk_mask_index] = std.math.boolMask(MaskInt, value); + } + + if (end_bit > 0) { + self.masks[end_mask_index] = + (self.masks[end_mask_index] & (std.math.boolMask(MaskInt, true) << end_bit)) | + (std.math.boolMask(MaskInt, value) >> ((@bitSizeOf(MaskInt) - 1) - (end_bit - 1))); + } + } + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { assert(index < self.bit_length); @@ -811,6 +925,12 @@ pub const DynamicBitSet = struct { self.unmanaged.set(index); } + /// Changes the value of all bits in the specified range to + /// match the passed boolean. + pub fn setRangeValue(self: *Self, range: Range, value: bool) void { + self.unmanaged.setRangeValue(range, value); + } + /// Removes a specific bit from the bit set pub fn unset(self: *Self, index: usize) void { self.unmanaged.unset(index); @@ -990,6 +1110,14 @@ fn BitSetIterator(comptime MaskInt: type, comptime options: IteratorOptions) typ }; } +/// A range of indices within a bitset. +pub const Range = struct { + /// The index of the first bit of interest. + start: usize, + /// The index immediately after the last bit of interest. + end: usize, +}; + // ---------------- Tests ----------------- const testing = std.testing; @@ -1144,6 +1272,52 @@ fn testBitSet(a: anytype, b: anytype, len: usize) !void { try testing.expectEqual(@as(?usize, null), a.findFirstSet()); try testing.expectEqual(@as(?usize, null), a.toggleFirstSet()); try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, true); + try testing.expectEqual(len, a.count()); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = 0 }, true); + try testing.expectEqual(@as(usize, 0), a.count()); + + a.setRangeValue(.{ .start = len, .end = len }, true); + try testing.expectEqual(@as(usize, 0), a.count()); + + if (len >= 1) { + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = 1 }, true); + try testing.expectEqual(@as(usize, 1), a.count()); + try testing.expect(a.isSet(0)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 0, .end = len - 1 }, true); + try testing.expectEqual(len - 1, a.count()); + try testing.expect(!a.isSet(len - 1)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 1, .end = len }, true); + try testing.expectEqual(@as(usize, len - 1), a.count()); + try testing.expect(!a.isSet(0)); + + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = len - 1, .end = len }, true); + try testing.expectEqual(@as(usize, 1), a.count()); + try testing.expect(a.isSet(len - 1)); + + if (len >= 4) { + a.setRangeValue(.{ .start = 0, .end = len }, false); + a.setRangeValue(.{ .start = 1, .end = len - 2 }, true); + try testing.expectEqual(@as(usize, len - 3), a.count()); + try testing.expect(!a.isSet(0)); + try testing.expect(a.isSet(1)); + try testing.expect(a.isSet(len - 3)); + try testing.expect(!a.isSet(len - 2)); + try testing.expect(!a.isSet(len - 1)); + } + } } fn testStaticBitSet(comptime Set: type) !void {