Rewrite bit_reader and bit_writer to take advantage of current zig semantics and enhance readability (#21689)
Co-authored-by: Tanner Schultz <tgschultz@tgschultz-dl.tail7ba92.ts.net>
This commit is contained in:
@@ -405,7 +405,7 @@ pub const DecodeState = struct {
|
||||
};
|
||||
fn readLiteralsBits(
|
||||
self: *DecodeState,
|
||||
bit_count_to_read: usize,
|
||||
bit_count_to_read: u16,
|
||||
) LiteralBitsError!u16 {
|
||||
return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
|
||||
if (self.literal_streams == .four and self.literal_stream_index < 3) {
|
||||
|
||||
@@ -63,7 +63,7 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
|
||||
|
||||
fn assignWeights(
|
||||
huff_bits: *readers.ReverseBitReader,
|
||||
accuracy_log: usize,
|
||||
accuracy_log: u16,
|
||||
entries: *[1 << 6]Table.Fse,
|
||||
weights: *[256]u4,
|
||||
) !usize {
|
||||
@@ -73,7 +73,7 @@ fn assignWeights(
|
||||
|
||||
while (i < 254) {
|
||||
const even_data = entries[even_state];
|
||||
var read_bits: usize = 0;
|
||||
var read_bits: u16 = 0;
|
||||
const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
|
||||
weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
|
||||
i += 1;
|
||||
|
||||
@@ -42,11 +42,11 @@ pub const ReverseBitReader = struct {
|
||||
if (i == 8) return error.BitStreamHasNoStartBit;
|
||||
}
|
||||
|
||||
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
|
||||
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) error{EndOfStream}!U {
|
||||
return self.bit_reader.readBitsNoEof(U, num_bits);
|
||||
}
|
||||
|
||||
pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
|
||||
pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) error{}!U {
|
||||
return try self.bit_reader.readBits(U, num_bits, out_bits);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ pub const ReverseBitReader = struct {
|
||||
}
|
||||
|
||||
pub fn isEmpty(self: ReverseBitReader) bool {
|
||||
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
|
||||
return self.byte_reader.remaining_bytes == 0 and self.bit_reader.count == 0;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -63,11 +63,11 @@ pub fn BitReader(comptime Reader: type) type {
|
||||
return struct {
|
||||
underlying: std.io.BitReader(.little, Reader),
|
||||
|
||||
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
|
||||
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: u16) !U {
|
||||
return self.underlying.readBitsNoEof(U, num_bits);
|
||||
}
|
||||
|
||||
pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
|
||||
pub fn readBits(self: *@This(), comptime U: type, num_bits: u16, out_bits: *u16) !U {
|
||||
return self.underlying.readBits(U, num_bits, out_bits);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,176 +1,179 @@
|
||||
const std = @import("../std.zig");
|
||||
const io = std.io;
|
||||
const assert = std.debug.assert;
|
||||
const testing = std.testing;
|
||||
const meta = std.meta;
|
||||
const math = std.math;
|
||||
|
||||
/// Creates a stream which allows for reading bit fields from another stream
|
||||
pub fn BitReader(comptime endian: std.builtin.Endian, comptime ReaderType: type) type {
|
||||
//General note on endianess:
|
||||
//Big endian is packed starting in the most significant part of the byte and subsequent
|
||||
// bytes contain less significant bits. Thus we always take bits from the high
|
||||
// end and place them below existing bits in our output.
|
||||
//Little endian is packed starting in the least significant part of the byte and
|
||||
// subsequent bytes contain more significant bits. Thus we always take bits from
|
||||
// the low end and place them above existing bits in our output.
|
||||
//Regardless of endianess, within any given byte the bits are always in most
|
||||
// to least significant order.
|
||||
//Also regardless of endianess, the buffer always aligns bits to the low end
|
||||
// of the byte.
|
||||
|
||||
/// Creates a bit reader which allows for reading bits from an underlying standard reader
|
||||
pub fn BitReader(comptime endian: std.builtin.Endian, comptime Reader: type) type {
|
||||
return struct {
|
||||
forward_reader: ReaderType,
|
||||
bit_buffer: u7,
|
||||
bit_count: u3,
|
||||
reader: Reader,
|
||||
bits: u8 = 0,
|
||||
count: u4 = 0,
|
||||
|
||||
pub const Error = ReaderType.Error;
|
||||
pub const Reader = io.Reader(*Self, Error, read);
|
||||
const low_bit_mask = [9]u8{
|
||||
0b00000000,
|
||||
0b00000001,
|
||||
0b00000011,
|
||||
0b00000111,
|
||||
0b00001111,
|
||||
0b00011111,
|
||||
0b00111111,
|
||||
0b01111111,
|
||||
0b11111111,
|
||||
};
|
||||
|
||||
const Self = @This();
|
||||
const u8_bit_count = @bitSizeOf(u8);
|
||||
const u7_bit_count = @bitSizeOf(u7);
|
||||
const u4_bit_count = @bitSizeOf(u4);
|
||||
|
||||
pub fn init(forward_reader: ReaderType) Self {
|
||||
return Self{
|
||||
.forward_reader = forward_reader,
|
||||
.bit_buffer = 0,
|
||||
.bit_count = 0,
|
||||
fn Bits(comptime T: type) type {
|
||||
return struct {
|
||||
T,
|
||||
u16,
|
||||
};
|
||||
}
|
||||
|
||||
/// Reads `bits` bits from the stream and returns a specified unsigned int type
|
||||
fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
|
||||
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
|
||||
return .{
|
||||
@bitCast(@as(UT, @intCast(out))),
|
||||
num,
|
||||
};
|
||||
}
|
||||
|
||||
/// Reads `bits` bits from the reader and returns a specified type
|
||||
/// containing them in the least significant end, returning an error if the
|
||||
/// specified number of bits could not be read.
|
||||
pub fn readBitsNoEof(self: *Self, comptime U: type, bits: usize) !U {
|
||||
var n: usize = undefined;
|
||||
const result = try self.readBits(U, bits, &n);
|
||||
if (n < bits) return error.EndOfStream;
|
||||
return result;
|
||||
pub fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
|
||||
const b, const c = try self.readBitsTuple(T, num);
|
||||
if (c < num) return error.EndOfStream;
|
||||
return b;
|
||||
}
|
||||
|
||||
/// Reads `bits` bits from the stream and returns a specified unsigned int type
|
||||
/// Reads `bits` bits from the reader and returns a specified type
|
||||
/// containing them in the least significant end. The number of bits successfully
|
||||
/// read is placed in `out_bits`, as reaching the end of the stream is not an error.
|
||||
pub fn readBits(self: *Self, comptime U: type, bits: usize, out_bits: *usize) Error!U {
|
||||
//by extending the buffer to a minimum of u8 we can cover a number of edge cases
|
||||
// related to shifting and casting.
|
||||
const u_bit_count = @bitSizeOf(U);
|
||||
const buf_bit_count = bc: {
|
||||
assert(u_bit_count >= bits);
|
||||
break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
|
||||
};
|
||||
const Buf = std.meta.Int(.unsigned, buf_bit_count);
|
||||
const BufShift = math.Log2Int(Buf);
|
||||
pub fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
|
||||
const b, const c = try self.readBitsTuple(T, num);
|
||||
out_bits.* = c;
|
||||
return b;
|
||||
}
|
||||
|
||||
out_bits.* = @as(usize, 0);
|
||||
if (U == u0 or bits == 0) return 0;
|
||||
var out_buffer = @as(Buf, 0);
|
||||
/// Reads `bits` bits from the reader and returns a tuple of the specified type
|
||||
/// containing them in the least significant end, and the number of bits successfully
|
||||
/// read. Reaching the end of the stream is not an error.
|
||||
pub fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
|
||||
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
|
||||
const U = if (@bitSizeOf(T) < 8) u8 else UT; //it is a pain to work with <u8
|
||||
|
||||
if (self.bit_count > 0) {
|
||||
const n = if (self.bit_count >= bits) @as(u3, @intCast(bits)) else self.bit_count;
|
||||
const shift = u7_bit_count - n;
|
||||
switch (endian) {
|
||||
.big => {
|
||||
out_buffer = @as(Buf, self.bit_buffer >> shift);
|
||||
if (n >= u7_bit_count)
|
||||
self.bit_buffer = 0
|
||||
else
|
||||
self.bit_buffer <<= n;
|
||||
},
|
||||
.little => {
|
||||
const value = (self.bit_buffer << shift) >> shift;
|
||||
out_buffer = @as(Buf, value);
|
||||
if (n >= u7_bit_count)
|
||||
self.bit_buffer = 0
|
||||
else
|
||||
self.bit_buffer >>= n;
|
||||
},
|
||||
}
|
||||
self.bit_count -= n;
|
||||
out_bits.* = n;
|
||||
}
|
||||
//at this point we know bit_buffer is empty
|
||||
//dump any bits in our buffer first
|
||||
if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);
|
||||
|
||||
//copy bytes until we have enough bits, then leave the rest in bit_buffer
|
||||
while (out_bits.* < bits) {
|
||||
const n = bits - out_bits.*;
|
||||
const next_byte = self.forward_reader.readByte() catch |err| switch (err) {
|
||||
error.EndOfStream => return @as(U, @intCast(out_buffer)),
|
||||
var out_count: u16 = self.count;
|
||||
var out: U = self.removeBits(self.count);
|
||||
|
||||
//grab all the full bytes we need and put their
|
||||
//bits where they belong
|
||||
const full_bytes_left = (num - out_count) / 8;
|
||||
|
||||
for (0..full_bytes_left) |_| {
|
||||
const byte = self.reader.readByte() catch |err| switch (err) {
|
||||
error.EndOfStream => return initBits(T, out, out_count),
|
||||
else => |e| return e,
|
||||
};
|
||||
|
||||
switch (endian) {
|
||||
.big => {
|
||||
if (n >= u8_bit_count) {
|
||||
out_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
|
||||
out_buffer <<= 1;
|
||||
out_buffer |= @as(Buf, next_byte);
|
||||
out_bits.* += u8_bit_count;
|
||||
continue;
|
||||
}
|
||||
|
||||
const shift = @as(u3, @intCast(u8_bit_count - n));
|
||||
out_buffer <<= @as(BufShift, @intCast(n));
|
||||
out_buffer |= @as(Buf, next_byte >> shift);
|
||||
out_bits.* += n;
|
||||
self.bit_buffer = @as(u7, @truncate(next_byte << @as(u3, @intCast(n - 1))));
|
||||
self.bit_count = shift;
|
||||
if (U == u8) out = 0 else out <<= 8; //shifting u8 by 8 is illegal in Zig
|
||||
out |= byte;
|
||||
},
|
||||
.little => {
|
||||
if (n >= u8_bit_count) {
|
||||
out_buffer |= @as(Buf, next_byte) << @as(BufShift, @intCast(out_bits.*));
|
||||
out_bits.* += u8_bit_count;
|
||||
continue;
|
||||
}
|
||||
|
||||
const shift = @as(u3, @intCast(u8_bit_count - n));
|
||||
const value = (next_byte << shift) >> shift;
|
||||
out_buffer |= @as(Buf, value) << @as(BufShift, @intCast(out_bits.*));
|
||||
out_bits.* += n;
|
||||
self.bit_buffer = @as(u7, @truncate(next_byte >> @as(u3, @intCast(n))));
|
||||
self.bit_count = shift;
|
||||
const pos = @as(U, byte) << @intCast(out_count);
|
||||
out |= pos;
|
||||
},
|
||||
}
|
||||
out_count += 8;
|
||||
}
|
||||
|
||||
return @as(U, @intCast(out_buffer));
|
||||
}
|
||||
const bits_left = num - out_count;
|
||||
const keep = 8 - bits_left;
|
||||
|
||||
pub fn alignToByte(self: *Self) void {
|
||||
self.bit_buffer = 0;
|
||||
self.bit_count = 0;
|
||||
}
|
||||
if (bits_left == 0) return initBits(T, out, out_count);
|
||||
|
||||
pub fn read(self: *Self, buffer: []u8) Error!usize {
|
||||
var out_bits: usize = undefined;
|
||||
var out_bits_total = @as(usize, 0);
|
||||
//@NOTE: I'm not sure this is a good idea, maybe alignToByte should be forced
|
||||
if (self.bit_count > 0) {
|
||||
for (buffer) |*b| {
|
||||
b.* = try self.readBits(u8, u8_bit_count, &out_bits);
|
||||
out_bits_total += out_bits;
|
||||
}
|
||||
const incomplete_byte = @intFromBool(out_bits_total % u8_bit_count > 0);
|
||||
return (out_bits_total / u8_bit_count) + incomplete_byte;
|
||||
const final_byte = self.reader.readByte() catch |err| switch (err) {
|
||||
error.EndOfStream => return initBits(T, out, out_count),
|
||||
else => |e| return e,
|
||||
};
|
||||
|
||||
switch (endian) {
|
||||
.big => {
|
||||
out <<= @intCast(bits_left);
|
||||
out |= final_byte >> @intCast(keep);
|
||||
self.bits = final_byte & low_bit_mask[keep];
|
||||
},
|
||||
.little => {
|
||||
const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
|
||||
out |= pos;
|
||||
self.bits = final_byte >> @intCast(bits_left);
|
||||
},
|
||||
}
|
||||
|
||||
return self.forward_reader.read(buffer);
|
||||
self.count = @intCast(keep);
|
||||
return initBits(T, out, num);
|
||||
}
|
||||
|
||||
pub fn reader(self: *Self) Reader {
|
||||
return .{ .context = self };
|
||||
//convenience function for removing bits from
|
||||
//the appropriate part of the buffer based on
|
||||
//endianess.
|
||||
fn removeBits(self: *@This(), num: u4) u8 {
|
||||
if (num == 8) {
|
||||
self.count = 0;
|
||||
return self.bits;
|
||||
}
|
||||
|
||||
const keep = self.count - num;
|
||||
const bits = switch (endian) {
|
||||
.big => self.bits >> @intCast(keep),
|
||||
.little => self.bits & low_bit_mask[num],
|
||||
};
|
||||
switch (endian) {
|
||||
.big => self.bits &= low_bit_mask[keep],
|
||||
.little => self.bits >>= @intCast(num),
|
||||
}
|
||||
|
||||
self.count = keep;
|
||||
return bits;
|
||||
}
|
||||
|
||||
pub fn alignToByte(self: *@This()) void {
|
||||
self.bits = 0;
|
||||
self.count = 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn bitReader(
|
||||
comptime endian: std.builtin.Endian,
|
||||
underlying_stream: anytype,
|
||||
) BitReader(endian, @TypeOf(underlying_stream)) {
|
||||
return BitReader(endian, @TypeOf(underlying_stream)).init(underlying_stream);
|
||||
pub fn bitReader(comptime endian: std.builtin.Endian, reader: anytype) BitReader(endian, @TypeOf(reader)) {
|
||||
return .{ .reader = reader };
|
||||
}
|
||||
|
||||
///////////////////////////////
|
||||
|
||||
test "api coverage" {
|
||||
const mem_be = [_]u8{ 0b11001101, 0b00001011 };
|
||||
const mem_le = [_]u8{ 0b00011101, 0b10010101 };
|
||||
|
||||
var mem_in_be = io.fixedBufferStream(&mem_be);
|
||||
var mem_in_be = std.io.fixedBufferStream(&mem_be);
|
||||
var bit_stream_be = bitReader(.big, mem_in_be.reader());
|
||||
|
||||
var out_bits: usize = undefined;
|
||||
var out_bits: u16 = undefined;
|
||||
|
||||
const expect = testing.expect;
|
||||
const expectError = testing.expectError;
|
||||
const expect = std.testing.expect;
|
||||
const expectError = std.testing.expectError;
|
||||
|
||||
try expect(1 == try bit_stream_be.readBits(u2, 1, &out_bits));
|
||||
try expect(out_bits == 1);
|
||||
@@ -186,12 +189,12 @@ test "api coverage" {
|
||||
try expect(out_bits == 1);
|
||||
|
||||
mem_in_be.pos = 0;
|
||||
bit_stream_be.bit_count = 0;
|
||||
bit_stream_be.count = 0;
|
||||
try expect(0b110011010000101 == try bit_stream_be.readBits(u15, 15, &out_bits));
|
||||
try expect(out_bits == 15);
|
||||
|
||||
mem_in_be.pos = 0;
|
||||
bit_stream_be.bit_count = 0;
|
||||
bit_stream_be.count = 0;
|
||||
try expect(0b1100110100001011 == try bit_stream_be.readBits(u16, 16, &out_bits));
|
||||
try expect(out_bits == 16);
|
||||
|
||||
@@ -201,7 +204,7 @@ test "api coverage" {
|
||||
try expect(out_bits == 0);
|
||||
try expectError(error.EndOfStream, bit_stream_be.readBitsNoEof(u1, 1));
|
||||
|
||||
var mem_in_le = io.fixedBufferStream(&mem_le);
|
||||
var mem_in_le = std.io.fixedBufferStream(&mem_le);
|
||||
var bit_stream_le = bitReader(.little, mem_in_le.reader());
|
||||
|
||||
try expect(1 == try bit_stream_le.readBits(u2, 1, &out_bits));
|
||||
@@ -218,12 +221,12 @@ test "api coverage" {
|
||||
try expect(out_bits == 1);
|
||||
|
||||
mem_in_le.pos = 0;
|
||||
bit_stream_le.bit_count = 0;
|
||||
bit_stream_le.count = 0;
|
||||
try expect(0b001010100011101 == try bit_stream_le.readBits(u15, 15, &out_bits));
|
||||
try expect(out_bits == 15);
|
||||
|
||||
mem_in_le.pos = 0;
|
||||
bit_stream_le.bit_count = 0;
|
||||
bit_stream_le.count = 0;
|
||||
try expect(0b1001010100011101 == try bit_stream_le.readBits(u16, 16, &out_bits));
|
||||
try expect(out_bits == 16);
|
||||
|
||||
|
||||
@@ -1,153 +1,138 @@
|
||||
const std = @import("../std.zig");
|
||||
const io = std.io;
|
||||
const testing = std.testing;
|
||||
const assert = std.debug.assert;
|
||||
const math = std.math;
|
||||
|
||||
/// Creates a stream which allows for writing bit fields to another stream
|
||||
pub fn BitWriter(comptime endian: std.builtin.Endian, comptime WriterType: type) type {
|
||||
//General note on endianess:
|
||||
//Big endian is packed starting in the most significant part of the byte and subsequent
|
||||
// bytes contain less significant bits. Thus we write out bits from the high end
|
||||
// of our input first.
|
||||
//Little endian is packed starting in the least significant part of the byte and
|
||||
// subsequent bytes contain more significant bits. Thus we write out bits from
|
||||
// the low end of our input first.
|
||||
//Regardless of endianess, within any given byte the bits are always in most
|
||||
// to least significant order.
|
||||
//Also regardless of endianess, the buffer always aligns bits to the low end
|
||||
// of the byte.
|
||||
|
||||
/// Creates a bit writer which allows for writing bits to an underlying standard writer
|
||||
pub fn BitWriter(comptime endian: std.builtin.Endian, comptime Writer: type) type {
|
||||
return struct {
|
||||
forward_writer: WriterType,
|
||||
bit_buffer: u8,
|
||||
bit_count: u4,
|
||||
writer: Writer,
|
||||
bits: u8 = 0,
|
||||
count: u4 = 0,
|
||||
|
||||
pub const Error = WriterType.Error;
|
||||
pub const Writer = io.Writer(*Self, Error, write);
|
||||
const low_bit_mask = [9]u8{
|
||||
0b00000000,
|
||||
0b00000001,
|
||||
0b00000011,
|
||||
0b00000111,
|
||||
0b00001111,
|
||||
0b00011111,
|
||||
0b00111111,
|
||||
0b01111111,
|
||||
0b11111111,
|
||||
};
|
||||
|
||||
const Self = @This();
|
||||
const u8_bit_count = @bitSizeOf(u8);
|
||||
const u4_bit_count = @bitSizeOf(u4);
|
||||
|
||||
pub fn init(forward_writer: WriterType) Self {
|
||||
return Self{
|
||||
.forward_writer = forward_writer,
|
||||
.bit_buffer = 0,
|
||||
.bit_count = 0,
|
||||
};
|
||||
}
|
||||
|
||||
/// Write the specified number of bits to the stream from the least significant bits of
|
||||
/// the specified unsigned int value. Bits will only be written to the stream when there
|
||||
/// Write the specified number of bits to the writer from the least significant bits of
|
||||
/// the specified value. Bits will only be written to the writer when there
|
||||
/// are enough to fill a byte.
|
||||
pub fn writeBits(self: *Self, value: anytype, bits: usize) Error!void {
|
||||
if (bits == 0) return;
|
||||
pub fn writeBits(self: *@This(), value: anytype, num: u16) !void {
|
||||
const T = @TypeOf(value);
|
||||
const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
|
||||
const U = if (@bitSizeOf(T) < 8) u8 else UT; //<u8 is a pain to work with
|
||||
|
||||
const U = @TypeOf(value);
|
||||
comptime assert(@typeInfo(U).int.signedness == .unsigned);
|
||||
var in: U = @as(UT, @bitCast(value));
|
||||
var in_count: u16 = num;
|
||||
|
||||
//by extending the buffer to a minimum of u8 we can cover a number of edge cases
|
||||
// related to shifting and casting.
|
||||
const u_bit_count = @bitSizeOf(U);
|
||||
const buf_bit_count = bc: {
|
||||
assert(u_bit_count >= bits);
|
||||
break :bc if (u_bit_count <= u8_bit_count) u8_bit_count else u_bit_count;
|
||||
};
|
||||
const Buf = std.meta.Int(.unsigned, buf_bit_count);
|
||||
const BufShift = math.Log2Int(Buf);
|
||||
if (self.count > 0) {
|
||||
//if we can't fill the buffer, add what we have
|
||||
const bits_free = 8 - self.count;
|
||||
if (num < bits_free) {
|
||||
self.addBits(@truncate(in), @intCast(num));
|
||||
return;
|
||||
}
|
||||
|
||||
const buf_value = @as(Buf, @intCast(value));
|
||||
//finish filling the buffer and flush it
|
||||
if (num == bits_free) {
|
||||
self.addBits(@truncate(in), @intCast(num));
|
||||
return self.flushBits();
|
||||
}
|
||||
|
||||
const high_byte_shift = @as(BufShift, @intCast(buf_bit_count - u8_bit_count));
|
||||
var in_buffer = switch (endian) {
|
||||
.big => buf_value << @as(BufShift, @intCast(buf_bit_count - bits)),
|
||||
.little => buf_value,
|
||||
};
|
||||
var in_bits = bits;
|
||||
|
||||
if (self.bit_count > 0) {
|
||||
const bits_remaining = u8_bit_count - self.bit_count;
|
||||
const n = @as(u3, @intCast(if (bits_remaining > bits) bits else bits_remaining));
|
||||
switch (endian) {
|
||||
.big => {
|
||||
const shift = @as(BufShift, @intCast(high_byte_shift + self.bit_count));
|
||||
const v = @as(u8, @intCast(in_buffer >> shift));
|
||||
self.bit_buffer |= v;
|
||||
in_buffer <<= n;
|
||||
const bits = in >> @intCast(in_count - bits_free);
|
||||
self.addBits(@truncate(bits), bits_free);
|
||||
},
|
||||
.little => {
|
||||
const v = @as(u8, @truncate(in_buffer)) << @as(u3, @intCast(self.bit_count));
|
||||
self.bit_buffer |= v;
|
||||
in_buffer >>= n;
|
||||
self.addBits(@truncate(in), bits_free);
|
||||
in >>= @intCast(bits_free);
|
||||
},
|
||||
}
|
||||
self.bit_count += n;
|
||||
in_bits -= n;
|
||||
|
||||
//if we didn't fill the buffer, it's because bits < bits_remaining;
|
||||
if (self.bit_count != u8_bit_count) return;
|
||||
try self.forward_writer.writeByte(self.bit_buffer);
|
||||
self.bit_buffer = 0;
|
||||
self.bit_count = 0;
|
||||
in_count -= bits_free;
|
||||
try self.flushBits();
|
||||
}
|
||||
//at this point we know bit_buffer is empty
|
||||
|
||||
//copy bytes until we can't fill one anymore, then leave the rest in bit_buffer
|
||||
while (in_bits >= u8_bit_count) {
|
||||
//write full bytes while we can
|
||||
const full_bytes_left = in_count / 8;
|
||||
for (0..full_bytes_left) |_| {
|
||||
switch (endian) {
|
||||
.big => {
|
||||
const v = @as(u8, @intCast(in_buffer >> high_byte_shift));
|
||||
try self.forward_writer.writeByte(v);
|
||||
in_buffer <<= @as(u3, @intCast(u8_bit_count - 1));
|
||||
in_buffer <<= 1;
|
||||
const bits = in >> @intCast(in_count - 8);
|
||||
try self.writer.writeByte(@truncate(bits));
|
||||
},
|
||||
.little => {
|
||||
const v = @as(u8, @truncate(in_buffer));
|
||||
try self.forward_writer.writeByte(v);
|
||||
in_buffer >>= @as(u3, @intCast(u8_bit_count - 1));
|
||||
in_buffer >>= 1;
|
||||
try self.writer.writeByte(@truncate(in));
|
||||
if (U == u8) in = 0 else in >>= 8;
|
||||
},
|
||||
}
|
||||
in_bits -= u8_bit_count;
|
||||
in_count -= 8;
|
||||
}
|
||||
|
||||
if (in_bits > 0) {
|
||||
self.bit_count = @as(u4, @intCast(in_bits));
|
||||
self.bit_buffer = switch (endian) {
|
||||
.big => @as(u8, @truncate(in_buffer >> high_byte_shift)),
|
||||
.little => @as(u8, @truncate(in_buffer)),
|
||||
};
|
||||
//save the remaining bits in the buffer
|
||||
self.addBits(@truncate(in), @intCast(in_count));
|
||||
}
|
||||
|
||||
//convenience funciton for adding bits to the buffer
|
||||
//in the appropriate position based on endianess
|
||||
fn addBits(self: *@This(), bits: u8, num: u4) void {
|
||||
if (num == 8) self.bits = bits else switch (endian) {
|
||||
.big => {
|
||||
self.bits <<= @intCast(num);
|
||||
self.bits |= bits & low_bit_mask[num];
|
||||
},
|
||||
.little => {
|
||||
const pos = bits << @intCast(self.count);
|
||||
self.bits |= pos;
|
||||
},
|
||||
}
|
||||
self.count += num;
|
||||
}
|
||||
|
||||
/// Flush any remaining bits to the stream.
|
||||
pub fn flushBits(self: *Self) Error!void {
|
||||
if (self.bit_count == 0) return;
|
||||
try self.forward_writer.writeByte(self.bit_buffer);
|
||||
self.bit_buffer = 0;
|
||||
self.bit_count = 0;
|
||||
}
|
||||
|
||||
pub fn write(self: *Self, buffer: []const u8) Error!usize {
|
||||
// TODO: I'm not sure this is a good idea, maybe flushBits should be forced
|
||||
if (self.bit_count > 0) {
|
||||
for (buffer) |b|
|
||||
try self.writeBits(b, u8_bit_count);
|
||||
return buffer.len;
|
||||
}
|
||||
|
||||
return self.forward_writer.write(buffer);
|
||||
}
|
||||
|
||||
pub fn writer(self: *Self) Writer {
|
||||
return .{ .context = self };
|
||||
/// Flush any remaining bits to the writer, filling
|
||||
/// unused bits with 0s.
|
||||
pub fn flushBits(self: *@This()) !void {
|
||||
if (self.count == 0) return;
|
||||
if (endian == .big) self.bits <<= @intCast(8 - self.count);
|
||||
try self.writer.writeByte(self.bits);
|
||||
self.bits = 0;
|
||||
self.count = 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn bitWriter(
|
||||
comptime endian: std.builtin.Endian,
|
||||
underlying_stream: anytype,
|
||||
) BitWriter(endian, @TypeOf(underlying_stream)) {
|
||||
return BitWriter(endian, @TypeOf(underlying_stream)).init(underlying_stream);
|
||||
pub fn bitWriter(comptime endian: std.builtin.Endian, writer: anytype) BitWriter(endian, @TypeOf(writer)) {
|
||||
return .{ .writer = writer };
|
||||
}
|
||||
|
||||
///////////////////////////////
|
||||
|
||||
test "api coverage" {
|
||||
var mem_be = [_]u8{0} ** 2;
|
||||
var mem_le = [_]u8{0} ** 2;
|
||||
|
||||
var mem_out_be = io.fixedBufferStream(&mem_be);
|
||||
var mem_out_be = std.io.fixedBufferStream(&mem_be);
|
||||
var bit_stream_be = bitWriter(.big, mem_out_be.writer());
|
||||
|
||||
const testing = std.testing;
|
||||
|
||||
try bit_stream_be.writeBits(@as(u2, 1), 1);
|
||||
try bit_stream_be.writeBits(@as(u5, 2), 2);
|
||||
try bit_stream_be.writeBits(@as(u128, 3), 3);
|
||||
@@ -169,7 +154,7 @@ test "api coverage" {
|
||||
|
||||
try bit_stream_be.writeBits(@as(u0, 0), 0);
|
||||
|
||||
var mem_out_le = io.fixedBufferStream(&mem_le);
|
||||
var mem_out_le = std.io.fixedBufferStream(&mem_le);
|
||||
var bit_stream_le = bitWriter(.little, mem_out_le.writer());
|
||||
|
||||
try bit_stream_le.writeBits(@as(u2, 1), 1);
|
||||
|
||||
@@ -82,7 +82,7 @@ test "BitStreams with File Stream" {
|
||||
|
||||
var bit_stream = io.bitReader(native_endian, file.reader());
|
||||
|
||||
var out_bits: usize = undefined;
|
||||
var out_bits: u16 = undefined;
|
||||
|
||||
try expect(1 == try bit_stream.readBits(u2, 1, &out_bits));
|
||||
try expect(out_bits == 1);
|
||||
|
||||
Reference in New Issue
Block a user