Rewrite bit_reader and bit_writer to take advantage of current zig semantics and enhance readability (#21689)

Co-authored-by: Tanner Schultz <tgschultz@tgschultz-dl.tail7ba92.ts.net>
tgschultz
2024-10-13 20:44:42 -05:00
committed by GitHub
parent e2e79960d2
commit ba569bb8e9
6 changed files with 238 additions and 250 deletions

View File

@@ -405,7 +405,7 @@ pub const DecodeState = struct {
};
fn readLiteralsBits(
self: *DecodeState,
bit_count_to_read: usize,
bit_count_to_read: u16,
) LiteralBitsError!u16 {
return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
if (self.literal_streams == .four and self.literal_stream_index < 3) {

View File

@@ -63,7 +63,7 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
fn assignWeights(
huff_bits: *readers.ReverseBitReader,
accuracy_log: usize,
accuracy_log: u16,
entries: *[1 << 6]Table.Fse,
weights: *[256]u4,
) !usize {
@@ -73,7 +73,7 @@ fn assignWeights(
while (i < 254) {
const even_data = entries[even_state];
var read_bits: usize = 0;
var read_bits: u16 = 0;
const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
i += 1;

View File

@@ -42,11 +42,11 @@ pub const ReverseBitReader = struct {
if (i == 8) return error.BitStreamHasNoStartBit;
}
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) error{EndOfStream}!U {
return self.bit_reader.readBitsNoEof(U, num_bits);
}
pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) error{}!U {
return try self.bit_reader.readBits(U, num_bits, out_bits);
}
@@ -55,7 +55,7 @@ pub const ReverseBitReader = struct {
}
pub fn isEmpty(self: ReverseBitReader) bool {
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.count == 0;
}
};
@@ -63,11 +63,11 @@ pub fn BitReader(comptime Reader: type) type {
return struct {
underlying: std.io.BitReader(.little, Reader),
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) !U {
return self.underlying.readBitsNoEof(U, num_bits);
}
pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) !U {
return self.underlying.readBits(U, num_bits, out_bits);
}
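
These zstd-facing wrappers forward to the rewritten std.io bit reader below. As a self-contained sketch of the out-parameter behavior they rely on (not part of this commit; bytes and expected values chosen for illustration):

    const std = @import("std");

    test "readBits out-parameter sketch" {
        const bytes = [_]u8{0b10110100};
        var fbs = std.io.fixedBufferStream(&bytes);
        var br = std.io.bitReader(.little, fbs.reader());

        var n: u16 = undefined; // previously usize
        const first = try br.readBits(u8, 6, &n); // plenty of bits available
        try std.testing.expectEqual(@as(u16, 6), n);
        try std.testing.expectEqual(@as(u8, 0b110100), first);

        const second = try br.readBits(u8, 6, &n); // only 2 bits left in the stream
        try std.testing.expectEqual(@as(u16, 2), n); // short read is reported, not an error
        try std.testing.expectEqual(@as(u8, 0b10), second); // missing high bits stay zero
    }

A short read surfaces through the count rather than through the error set, which is what lets ReverseBitReader.readBits above declare error{} and hand back zero-filled bits when the reverse stream runs dry.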

View File

@@ -1,176 +1,179 @@
const std = @import("../std.zig");
const io = std.io;
const assert = std.debug.assert;
const testing = std.testing;
const meta = std.meta;
const math = std.math;
/// Creates a stream which allows for reading bit fields from another stream
pub fn BitReader(comptime endian: std.builtin.Endian, comptime ReaderType: type) type {
//General note on endianness:
//Big endian is packed starting in the most significant part of the byte and subsequent
// bytes contain less significant bits. Thus we always take bits from the high
// end and place them below existing bits in our output.
//Little endian is packed starting in the least significant part of the byte and
// subsequent bytes contain more significant bits. Thus we always take bits from
// the low end and place them above existing bits in our output.
//Regardless of endianness, within any given byte the bits are always in most
// to least significant order.
//Also regardless of endianness, the buffer always aligns bits to the low end
// of the byte.
/// Creates a bit reader which allows for reading bits from an underlying standard reader
pub fn BitReader(comptime endian: std.builtin.Endian, comptime Reader: type) type {
return struct {
forward_reader: ReaderType,
bit_buffer: u7,
bit_count: u3,
reader: Reader,
bits: u8 = 0,
count: u4 = 0,
pub const Error = ReaderType.Error;
pub const Reader = io.Reader(*Self, Error, read);
const low_bit_mask = [9]u8{
0b00000000,
0b00000001,
0b00000011,
0b00000111,
0b00001111,
0b00011111,
0b00111111,
0b01111111,
0b11111111,
};
const Self = @This();
const u8_bit_count = @bitSizeOf(u8);
const u7_bit_count = @bitSizeOf(u7);
const u4_bit_count = @bitSizeOf(u4);
pub fn init(forward_reader: ReaderType) Self {
return Self{
.forward_reader = forward_reader,
.bit_buffer = 0,
.bit_count = 0,
fn Bits(comptime T: type) type {
return struct {
T,
u16,
};
}
/// Reads `bits` bits from the stream and returns a specified unsigned int type
fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
return .{
@bitCast(@as(UT, @intCast(out))),
num,
};
}
/// Reads `bits` bits from the reader and returns a specified type
/// containing them in the least significant end, returning an error if the
/// specified number of bits could not be read.
pub fn readBitsNoEof(self: *Self, comptime U: type, bits: usize) !U {
var n: usize = undefined;
const result = try self.readBits(U, bits, &n);
if (n < bits) return error.EndOfStream;
return result;
pub fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
const b, const c = try self.readBitsTuple(T, num);
if (c < num) return error.EndOfStream;
return b;
}
/// Reads `bits` bits from the stream and returns a specified unsigned int type
/// Reads `bits` bits from the reader and returns a specified type
/// containing them in the least significant end. The number of bits successfully
/// read is placed in `out_bits`, as reaching the end of the stream is not an error.
pub fn readBits(self: *Self, comptime U: type, bits: usize, out_bits: *usize) Error!U {
//by extending the buffer to a minimum of u8 we can cover a number of edge cases
// related to shifting and casting.
const u_bit_count = @bitSizeOf(U);
const buf_bit_count = bc: {
assert(u_bit_count >= bits);
break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
};
const Buf = std.meta.Int(.unsigned, buf_bit_count);
const BufShift = math.Log2Int(Buf);
pub fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
const b, const c = try self.readBitsTuple(T, num);
out_bits.* = c;
return b;
}
out_bits.* = @as(usize, 0);
if (U == u0 or bits == 0) return 0;
var out_buffer = @as(Buf, 0);
/// Reads `bits` bits from the reader and returns a tuple of the specified type
/// containing them in the least significant end, and the number of bits successfully
/// read. Reaching the end of the stream is not an error.
pub fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
const U = if (@bitSizeOf(T) < 8) u8 else UT; //it is a pain to work with <u8
if (self.bit_count > 0) {
const n = if (self.bit_count >= bits) @as(u3, @intCast(bits)) else self.bit_count;
const shift = u7_bit_count - n;
switch (endian) {
.big => {
out_buffer = @as(Buf, self.bit_buffer >> shift);
if (n >= u7_bit_count)
self.bit_buffer = 0
else
self.bit_buffer <<= n;
},
.little => {
const value = (self.bit_buffer << shift) >> shift;
out_buffer = @as(Buf, value);
if (n >= u7_bit_count)
self.bit_buffer = 0
else
self.bit_buffer >>= n;
},
}
self.bit_count -= n;
out_bits.* = n;
}
//at this point we know bit_buffer is empty
//dump any bits in our buffer first
if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);
//copy bytes until we have enough bits, then leave the rest in bit_buffer
while (out_bits.* < bits) {
const n = bits - out_bits.*;
const next_byte = self.forward_reader.readByte() catch |err| switch (err) {
error.EndOfStream => return @as(U, @intCast(out_buffer)),
var out_count: u16 = self.count;
var out: U = self.removeBits(self.count);
//grab all the full bytes we need and put their
//bits where they belong
const full_bytes_left = (num - out_count) / 8;
for (0..full_bytes_left) |_| {
const byte = self.reader.readByte() catch |err| switch (err) {
error.EndOfStream => return initBits(T, out, out_count),
else => |e| return e,
};
switch (endian) {
.big => {
if (n >= u8_bit_count) {
out_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
out_buffer <<= 1;
out_buffer |= @as(Buf, next_byte);
out_bits.* += u8_bit_count;
continue;
}
const shift = @as(u3, @intCast(u8_bit_count - n));
out_buffer <<= @as(BufShift, @intCast(n));
out_buffer |= @as(Buf, next_byte >> shift);
out_bits.* += n;
self.bit_buffer = @as(u7, @truncate(next_byte << @as(u3, @intCast(n - 1))));
self.bit_count = shift;
if (U == u8) out = 0 else out <<= 8; //shifting u8 by 8 is illegal in Zig
out |= byte;
},
.little => {
if (n >= u8_bit_count) {
out_buffer |= @as(Buf, next_byte) << @as(BufShift, @intCast(out_bits.*));
out_bits.* += u8_bit_count;
continue;
}
const shift = @as(u3, @intCast(u8_bit_count - n));
const value = (next_byte << shift) >> shift;
out_buffer |= @as(Buf, value) << @as(BufShift, @intCast(out_bits.*));
out_bits.* += n;
self.bit_buffer = @as(u7, @truncate(next_byte >> @as(u3, @intCast(n))));
self.bit_count = shift;
const pos = @as(U, byte) << @intCast(out_count);
out |= pos;
},
}
out_count += 8;
}
return @as(U, @intCast(out_buffer));
}
const bits_left = num - out_count;
const keep = 8 - bits_left;
pub fn alignToByte(self: *Self) void {
self.bit_buffer = 0;
self.bit_count = 0;
}
if (bits_left == 0) return initBits(T, out, out_count);
pub fn read(self: *Self, buffer: []u8) Error!usize {
var out_bits: usize = undefined;
var out_bits_total = @as(usize, 0);
//@NOTE: I'm not sure this is a good idea, maybe alignToByte should be forced
if (self.bit_count > 0) {
for (buffer) |*b| {
b.* = try self.readBits(u8, u8_bit_count, &out_bits);
out_bits_total += out_bits;
}
const incomplete_byte = @intFromBool(out_bits_total % u8_bit_count > 0);
return (out_bits_total / u8_bit_count) + incomplete_byte;
const final_byte = self.reader.readByte() catch |err| switch (err) {
error.EndOfStream => return initBits(T, out, out_count),
else => |e| return e,
};
switch (endian) {
.big => {
out <<= @intCast(bits_left);
out |= final_byte >> @intCast(keep);
self.bits = final_byte & low_bit_mask[keep];
},
.little => {
const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
out |= pos;
self.bits = final_byte >> @intCast(bits_left);
},
}
return self.forward_reader.read(buffer);
self.count = @intCast(keep);
return initBits(T, out, num);
}
pub fn reader(self: *Self) Reader {
return .{ .context = self };
//convenience function for removing bits from
//the appropriate part of the buffer based on
//endianness.
fn removeBits(self: *@This(), num: u4) u8 {
if (num == 8) {
self.count = 0;
return self.bits;
}
const keep = self.count - num;
const bits = switch (endian) {
.big => self.bits >> @intCast(keep),
.little => self.bits & low_bit_mask[num],
};
switch (endian) {
.big => self.bits &= low_bit_mask[keep],
.little => self.bits >>= @intCast(num),
}
self.count = keep;
return bits;
}
pub fn alignToByte(self: *@This()) void {
self.bits = 0;
self.count = 0;
}
};
}
pub fn bitReader(
comptime endian: std.builtin.Endian,
underlying_stream: anytype,
) BitReader(endian, @TypeOf(underlying_stream)) {
return BitReader(endian, @TypeOf(underlying_stream)).init(underlying_stream);
pub fn bitReader(comptime endian: std.builtin.Endian, reader: anytype) BitReader(endian, @TypeOf(reader)) {
return .{ .reader = reader };
}
///////////////////////////////
test "api coverage" {
const mem_be = [_]u8{ 0b11001101, 0b00001011 };
const mem_le = [_]u8{ 0b00011101, 0b10010101 };
var mem_in_be = io.fixedBufferStream(&mem_be);
var mem_in_be = std.io.fixedBufferStream(&mem_be);
var bit_stream_be = bitReader(.big, mem_in_be.reader());
var out_bits: usize = undefined;
var out_bits: u16 = undefined;
const expect = testing.expect;
const expectError = testing.expectError;
const expect = std.testing.expect;
const expectError = std.testing.expectError;
try expect(1 == try bit_stream_be.readBits(u2, 1, &out_bits));
try expect(out_bits == 1);
@@ -186,12 +189,12 @@ test "api coverage" {
try expect(out_bits == 1);
mem_in_be.pos = 0;
bit_stream_be.bit_count = 0;
bit_stream_be.count = 0;
try expect(0b110011010000101 == try bit_stream_be.readBits(u15, 15, &out_bits));
try expect(out_bits == 15);
mem_in_be.pos = 0;
bit_stream_be.bit_count = 0;
bit_stream_be.count = 0;
try expect(0b1100110100001011 == try bit_stream_be.readBits(u16, 16, &out_bits));
try expect(out_bits == 16);
@@ -201,7 +204,7 @@ test "api coverage" {
try expect(out_bits == 0);
try expectError(error.EndOfStream, bit_stream_be.readBitsNoEof(u1, 1));
var mem_in_le = io.fixedBufferStream(&mem_le);
var mem_in_le = std.io.fixedBufferStream(&mem_le);
var bit_stream_le = bitReader(.little, mem_in_le.reader());
try expect(1 == try bit_stream_le.readBits(u2, 1, &out_bits));
@@ -218,12 +221,12 @@ test "api coverage" {
try expect(out_bits == 1);
mem_in_le.pos = 0;
bit_stream_le.bit_count = 0;
bit_stream_le.count = 0;
try expect(0b001010100011101 == try bit_stream_le.readBits(u15, 15, &out_bits));
try expect(out_bits == 15);
mem_in_le.pos = 0;
bit_stream_le.bit_count = 0;
bit_stream_le.count = 0;
try expect(0b1001010100011101 == try bit_stream_le.readBits(u16, 16, &out_bits));
try expect(out_bits == 16);
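
To tie the new reader surface together, a minimal sketch exercising the three public entry points over a fixed buffer. It mirrors the test above; the expected values are derived from the big-endian packing rules and are not taken from the commit:

    const std = @import("std");

    test "bit reader entry points sketch" {
        const bytes = [_]u8{ 0b11001101, 0b00001011 };
        var fbs = std.io.fixedBufferStream(&bytes);
        var br = std.io.bitReader(.big, fbs.reader());

        // readBitsNoEof: errors unless the full count can be read
        try std.testing.expectEqual(@as(u3, 0b110), try br.readBitsNoEof(u3, 3));

        // readBits: reports the bits actually read through a *u16 out-parameter
        var n: u16 = undefined;
        const rest = try br.readBits(u16, 13, &n);
        try std.testing.expectEqual(@as(u16, 13), n);
        try std.testing.expectEqual(@as(u16, 0b0110100001011), rest);

        // readBitsTuple: same semantics, but returns value and count as a tuple
        const value, const count = try br.readBitsTuple(u8, 8);
        try std.testing.expectEqual(@as(u8, 0), value); // stream is exhausted
        try std.testing.expectEqual(@as(u16, 0), count);
    }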

View File

@@ -1,153 +1,138 @@
const std = @import("../std.zig");
const io = std.io;
const testing = std.testing;
const assert = std.debug.assert;
const math = std.math;
/// Creates a stream which allows for writing bit fields to another stream
pub fn BitWriter(comptime endian: std.builtin.Endian, comptime WriterType: type) type {
//General note on endianness:
//Big endian is packed starting in the most significant part of the byte and subsequent
// bytes contain less significant bits. Thus we write out bits from the high end
// of our input first.
//Little endian is packed starting in the least significant part of the byte and
// subsequent bytes contain more significant bits. Thus we write out bits from
// the low end of our input first.
//Regardless of endianness, within any given byte the bits are always in most
// to least significant order.
//Also regardless of endianness, the buffer always aligns bits to the low end
// of the byte.
/// Creates a bit writer which allows for writing bits to an underlying standard writer
pub fn BitWriter(comptime endian: std.builtin.Endian, comptime Writer: type) type {
return struct {
forward_writer: WriterType,
bit_buffer: u8,
bit_count: u4,
writer: Writer,
bits: u8 = 0,
count: u4 = 0,
pub const Error = WriterType.Error;
pub const Writer = io.Writer(*Self, Error, write);
const low_bit_mask = [9]u8{
0b00000000,
0b00000001,
0b00000011,
0b00000111,
0b00001111,
0b00011111,
0b00111111,
0b01111111,
0b11111111,
};
const Self = @This();
const u8_bit_count = @bitSizeOf(u8);
const u4_bit_count = @bitSizeOf(u4);
pub fn init(forward_writer: WriterType) Self {
return Self{
.forward_writer = forward_writer,
.bit_buffer = 0,
.bit_count = 0,
};
}
/// Write the specified number of bits to the stream from the least significant bits of
/// the specified unsigned int value. Bits will only be written to the stream when there
/// Write the specified number of bits to the writer from the least significant bits of
/// the specified value. Bits will only be written to the writer when there
/// are enough to fill a byte.
pub fn writeBits(self: *Self, value: anytype, bits: usize) Error!void {
if (bits == 0) return;
pub fn writeBits(self: *@This(), value: anytype, num: u16) !void {
const T = @TypeOf(value);
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
const U = if (@bitSizeOf(T) < 8) u8 else UT; //<u8 is a pain to work with
const U = @TypeOf(value);
comptime assert(@typeInfo(U).int.signedness == .unsigned);
var in: U = @as(UT, @bitCast(value));
var in_count: u16 = num;
//by extending the buffer to a minimum of u8 we can cover a number of edge cases
// related to shifting and casting.
const u_bit_count = @bitSizeOf(U);
const buf_bit_count = bc: {
assert(u_bit_count >= bits);
break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
};
const Buf = std.meta.Int(.unsigned, buf_bit_count);
const BufShift = math.Log2Int(Buf);
if (self.count > 0) {
//if we can't fill the buffer, add what we have
const bits_free = 8 - self.count;
if (num < bits_free) {
self.addBits(@truncate(in), @intCast(num));
return;
}
const buf_value = @as(Buf, @intCast(value));
//finish filling the buffer and flush it
if (num == bits_free) {
self.addBits(@truncate(in), @intCast(num));
return self.flushBits();
}
const high_byte_shift = @as(BufShift, @intCast(buf_bit_count - u8_bit_count));
var in_buffer = switch (endian) {
.big => buf_value << @as(BufShift, @intCast(buf_bit_count - bits)),
.little => buf_value,
};
var in_bits = bits;
if (self.bit_count > 0) {
const bits_remaining = u8_bit_count - self.bit_count;
const n = @as(u3, @intCast(if (bits_remaining > bits) bits else bits_remaining));
switch (endian) {
.big => {
const shift = @as(BufShift, @intCast(high_byte_shift + self.bit_count));
const v = @as(u8, @intCast(in_buffer >> shift));
self.bit_buffer |= v;
in_buffer <<= n;
const bits = in >> @intCast(in_count - bits_free);
self.addBits(@truncate(bits), bits_free);
},
.little => {
const v = @as(u8, @truncate(in_buffer)) << @as(u3, @intCast(self.bit_count));
self.bit_buffer |= v;
in_buffer >>= n;
self.addBits(@truncate(in), bits_free);
in >>= @intCast(bits_free);
},
}
self.bit_count += n;
in_bits -= n;
//if we didn't fill the buffer, it's because bits < bits_remaining;
if (self.bit_count != u8_bit_count) return;
try self.forward_writer.writeByte(self.bit_buffer);
self.bit_buffer = 0;
self.bit_count = 0;
in_count -= bits_free;
try self.flushBits();
}
//at this point we know bit_buffer is empty
//copy bytes until we can't fill one anymore, then leave the rest in bit_buffer
while (in_bits >= u8_bit_count) {
//write full bytes while we can
const full_bytes_left = in_count / 8;
for (0..full_bytes_left) |_| {
switch (endian) {
.big => {
const v = @as(u8, @intCast(in_buffer >> high_byte_shift));
try self.forward_writer.writeByte(v);
in_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
in_buffer <<= 1;
const bits = in >> @intCast(in_count - 8);
try self.writer.writeByte(@truncate(bits));
},
.little => {
const v = @as(u8, @truncate(in_buffer));
try self.forward_writer.writeByte(v);
in_buffer >>= @as(u3, @intCast(u8_bit_count - 1));
in_buffer >>= 1;
try self.writer.writeByte(@truncate(in));
if (U == u8) in = 0 else in >>= 8;
},
}
in_bits -= u8_bit_count;
in_count -= 8;
}
if (in_bits > 0) {
self.bit_count = @as(u4, @intCast(in_bits));
self.bit_buffer = switch (endian) {
.big => @as(u8, @truncate(in_buffer >> high_byte_shift)),
.little => @as(u8, @truncate(in_buffer)),
};
//save the remaining bits in the buffer
self.addBits(@truncate(in), @intCast(in_count));
}
//convenience function for adding bits to the buffer
//in the appropriate position based on endianness
fn addBits(self: *@This(), bits: u8, num: u4) void {
if (num == 8) self.bits = bits else switch (endian) {
.big => {
self.bits <<= @intCast(num);
self.bits |= bits & low_bit_mask[num];
},
.little => {
const pos = bits << @intCast(self.count);
self.bits |= pos;
},
}
self.count += num;
}
/// Flush any remaining bits to the stream.
pub fn flushBits(self: *Self) Error!void {
if (self.bit_count == 0) return;
try self.forward_writer.writeByte(self.bit_buffer);
self.bit_buffer = 0;
self.bit_count = 0;
}
pub fn write(self: *Self, buffer: []const u8) Error!usize {
// TODO: I'm not sure this is a good idea, maybe flushBits should be forced
if (self.bit_count > 0) {
for (buffer) |b|
try self.writeBits(b, u8_bit_count);
return buffer.len;
}
return self.forward_writer.write(buffer);
}
pub fn writer(self: *Self) Writer {
return .{ .context = self };
/// Flush any remaining bits to the writer, filling
/// unused bits with 0s.
pub fn flushBits(self: *@This()) !void {
if (self.count == 0) return;
if (endian == .big) self.bits <<= @intCast(8 - self.count);
try self.writer.writeByte(self.bits);
self.bits = 0;
self.count = 0;
}
};
}
pub fn bitWriter(
comptime endian: std.builtin.Endian,
underlying_stream: anytype,
) BitWriter(endian, @TypeOf(underlying_stream)) {
return BitWriter(endian, @TypeOf(underlying_stream)).init(underlying_stream);
pub fn bitWriter(comptime endian: std.builtin.Endian, writer: anytype) BitWriter(endian, @TypeOf(writer)) {
return .{ .writer = writer };
}
///////////////////////////////
test "api coverage" {
var mem_be = [_]u8{0} ** 2;
var mem_le = [_]u8{0} ** 2;
var mem_out_be = io.fixedBufferStream(&mem_be);
var mem_out_be = std.io.fixedBufferStream(&mem_be);
var bit_stream_be = bitWriter(.big, mem_out_be.writer());
const testing = std.testing;
try bit_stream_be.writeBits(@as(u2, 1), 1);
try bit_stream_be.writeBits(@as(u5, 2), 2);
try bit_stream_be.writeBits(@as(u128, 3), 3);
@@ -169,7 +154,7 @@ test "api coverage" {
try bit_stream_be.writeBits(@as(u0, 0), 0);
var mem_out_le = io.fixedBufferStream(&mem_le);
var mem_out_le = std.io.fixedBufferStream(&mem_le);
var bit_stream_le = bitWriter(.little, mem_out_le.writer());
try bit_stream_le.writeBits(@as(u2, 1), 1);
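
The writer side follows the same shape; a minimal sketch assuming a two-byte buffer (not one of the commit's test cases). Bits accumulate in the internal byte until it fills, and flushBits writes out a partial byte with the unused bits zeroed:

    const std = @import("std");

    test "bit writer sketch" {
        var buf = [_]u8{0} ** 2;
        var fbs = std.io.fixedBufferStream(&buf);
        var bw = std.io.bitWriter(.little, fbs.writer());

        // little endian packs the first bits written into the low end of the byte
        try bw.writeBits(@as(u4, 0b1101), 4);
        try bw.writeBits(@as(u8, 0b0001), 4); // fills the byte: 0b00011101 is flushed
        try bw.writeBits(@as(u16, 0b101), 3); // stays buffered...
        try bw.flushBits(); // ...until flushed, with the unused high bits zero

        try std.testing.expectEqual(@as(u8, 0b00011101), buf[0]);
        try std.testing.expectEqual(@as(u8, 0b00000101), buf[1]);
    }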

View File

@@ -82,7 +82,7 @@ test "BitStreams with File Stream" {
var bit_stream = io.bitReader(native_endian, file.reader());
var out_bits: usize = undefined;
var out_bits: u16 = undefined;
try expect(1 == try bit_stream.readBits(u2, 1, &out_bits));
try expect(out_bits == 1);
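
The usize-to-u16 change ripples into every call site that keeps an out_bits local, as in the hunk above. A hypothetical helper, not from this commit, showing the updated pattern where end of stream is detected through the count rather than an error:

    const std = @import("std");

    // hypothetical: count the set bits remaining in any bit reader
    fn countSetBits(bit_stream: anytype) !usize {
        var total: usize = 0;
        while (true) {
            var out_bits: u16 = undefined; // u16 after this commit; previously usize
            const bit = try bit_stream.readBits(u1, 1, &out_bits);
            if (out_bits == 0) return total; // short read: the stream is exhausted
            total += bit;
        }
    }

    test "countSetBits sketch" {
        const bytes = [_]u8{0b10100110};
        var fbs = std.io.fixedBufferStream(&bytes);
        var br = std.io.bitReader(.big, fbs.reader());
        try std.testing.expectEqual(@as(usize, 4), try countSetBits(&br));
    }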