commit 83e578a181e33eedd57666376dab371b7ae58d5b (tree)
parent 97aa5f7b8a059f91b78ce7cd70cba0f3aa2c5118
Author: Andrew Kelley <andrew@ziglang.org>
Date: Thu, 7 Mar 2024 18:46:47 -0800
Merge pull request #19163 from ianic/zlib_no_lookahead
compress.zlib: don't overshoot underlying reader
Diffstat:
6 files changed, 201 insertions(+), 60 deletions(-)
diff --git a/lib/std/compress/flate.zig b/lib/std/compress/flate.zig
@@ -13,7 +13,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
- return inflate.Inflate(.raw, ReaderType);
+ return inflate.Decompressor(.raw, ReaderType);
}
/// Create Decompressor which will read compressed data from reader.
diff --git a/lib/std/compress/flate/bit_reader.zig b/lib/std/compress/flate/bit_reader.zig
@@ -2,8 +2,16 @@ const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
-pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
- return BitReader(@TypeOf(reader)).init(reader);
+pub fn bitReader(comptime T: type, reader: anytype) BitReader(T, @TypeOf(reader)) {
+ return BitReader(T, @TypeOf(reader)).init(reader);
+}
+
+pub fn BitReader64(comptime ReaderType: type) type {
+ return BitReader(u64, ReaderType);
+}
+
+pub fn BitReader32(comptime ReaderType: type) type {
+ return BitReader(u32, ReaderType);
}
/// Bit reader used during inflate (decompression). Has internal buffer of 64
@@ -15,12 +23,16 @@ pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
/// fill buffer from forward_reader by calling fill in advance and readF with
/// buffered flag set.
///
-pub fn BitReader(comptime ReaderType: type) type {
+pub fn BitReader(comptime T: type, comptime ReaderType: type) type {
+ assert(T == u32 or T == u64);
+ const t_bytes: usize = @sizeOf(T);
+ const Tshift = if (T == u64) u6 else u5;
+
return struct {
// Underlying reader used for filling internal bits buffer
forward_reader: ReaderType = undefined,
// Internal buffer of 64 bits
- bits: u64 = 0,
+ bits: T = 0,
// Number of bits in the buffer
nbits: u32 = 0,
@@ -44,21 +56,21 @@ pub fn BitReader(comptime ReaderType: type) type {
/// that number of bits available. If end of forward stream is reached
/// it may be some extra zero bits in buffer.
pub inline fn fill(self: *Self, nice: u6) !void {
- if (self.nbits >= nice) {
+ if (self.nbits >= nice and nice != 0) {
return; // We have enought bits
}
// Read more bits from forward reader
// Number of empty bytes in bits, round nbits to whole bytes.
const empty_bytes =
- @as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise
+ @as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // t_bytes when nbits is byte-aligned (0, 8, 16, ...), t_bytes - 1 otherwise
(self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8
- var buf: [8]u8 = [_]u8{0} ** 8;
+ var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes;
const bytes_read = self.forward_reader.readAll(buf[0..empty_bytes]) catch 0;
if (bytes_read > 0) {
- const u: u64 = std.mem.readInt(u64, buf[0..8], .little);
- self.bits |= u << @as(u6, @intCast(self.nbits));
+ const u: T = std.mem.readInt(T, buf[0..t_bytes], .little);
+ self.bits |= u << @as(Tshift, @intCast(self.nbits));
self.nbits += 8 * @as(u8, @intCast(bytes_read));
return;
}
@@ -99,7 +111,17 @@ pub fn BitReader(comptime ReaderType: type) type {
/// Read with flags provided.
pub fn readF(self: *Self, comptime U: type, comptime how: u3) !U {
- const n: u6 = @bitSizeOf(U);
+ if (U == T) {
+ assert(how == 0);
+ assert(self.alignBits() == 0);
+ try self.fill(@bitSizeOf(T));
+ if (self.nbits != @bitSizeOf(T)) return error.EndOfStream;
+ const v = self.bits;
+ self.nbits = 0;
+ self.bits = 0;
+ return v;
+ }
+ const n: Tshift = @bitSizeOf(U);
switch (how) {
0 => { // `normal` read
try self.fill(n); // ensure that there are n bits in the buffer
@@ -157,7 +179,7 @@ pub fn BitReader(comptime ReaderType: type) type {
}
/// Advance buffer for n bits.
- pub fn shift(self: *Self, n: u6) !void {
+ pub fn shift(self: *Self, n: Tshift) !void {
if (n > self.nbits) return error.EndOfStream;
self.bits >>= n;
self.nbits -= n;
@@ -218,10 +240,10 @@ pub fn BitReader(comptime ReaderType: type) type {
};
}
-test "BitReader" {
+test "readF" {
var fbs = std.io.fixedBufferStream(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 });
- var br = bitReader(fbs.reader());
- const F = BitReader(@TypeOf(fbs.reader())).flag;
+ var br = bitReader(u64, fbs.reader());
+ const F = BitReader64(@TypeOf(fbs.reader())).flag;
try testing.expectEqual(@as(u8, 48), br.nbits);
try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits);
@@ -254,36 +276,38 @@ test "BitReader" {
}
test "read block type 1 data" {
- const data = [_]u8{
- 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
- 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
- 0x0c, 0x01, 0x02, 0x03, //
- 0xaa, 0xbb, 0xcc, 0xdd,
- };
- var fbs = std.io.fixedBufferStream(&data);
- var br = bitReader(fbs.reader());
- const F = BitReader(@TypeOf(fbs.reader())).flag;
+ inline for ([_]type{ u64, u32 }) |T| {
+ const data = [_]u8{
+ 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
+ 0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
+ 0x0c, 0x01, 0x02, 0x03, //
+ 0xaa, 0xbb, 0xcc, 0xdd,
+ };
+ var fbs = std.io.fixedBufferStream(&data);
+ var br = bitReader(T, fbs.reader());
+ const F = BitReader(T, @TypeOf(fbs.reader())).flag;
- try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
- try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
+ try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
+ try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
- for ("Hello world\n") |c| {
- try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
+ for ("Hello world\n") |c| {
+ try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
+ }
+ try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
+ br.alignToByte();
+ try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
+ try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
+ try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
}
- try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
- br.alignToByte();
- try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
- try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
- try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
}
-test "init" {
+test "shift/fill" {
const data = [_]u8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
};
var fbs = std.io.fixedBufferStream(&data);
- var br = bitReader(fbs.reader());
+ var br = bitReader(u64, fbs.reader());
try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
try br.shift(8);
@@ -303,31 +327,96 @@ test "init" {
}
test "readAll" {
- const data = [_]u8{
- 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
- 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
- };
- var fbs = std.io.fixedBufferStream(&data);
- var br = bitReader(fbs.reader());
+ inline for ([_]type{ u64, u32 }) |T| {
+ const data = [_]u8{
+ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
+ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
+ };
+ var fbs = std.io.fixedBufferStream(&data);
+ var br = bitReader(T, fbs.reader());
- try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
+ switch (T) {
+ u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits),
+ u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits),
+ else => unreachable,
+ }
- var out: [16]u8 = undefined;
- try br.readAll(out[0..]);
- try testing.expect(br.nbits == 0);
- try testing.expect(br.bits == 0);
+ var out: [16]u8 = undefined;
+ try br.readAll(out[0..]);
+ try testing.expect(br.nbits == 0);
+ try testing.expect(br.bits == 0);
- try testing.expectEqualSlices(u8, data[0..16], &out);
+ try testing.expectEqualSlices(u8, data[0..16], &out);
+ }
}
test "readFixedCode" {
- const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
+ inline for ([_]type{ u64, u32 }) |T| {
+ const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
- var fbs = std.io.fixedBufferStream(&fixed_codes);
- var rdr = bitReader(fbs.reader());
+ var fbs = std.io.fixedBufferStream(&fixed_codes);
+ var rdr = bitReader(T, fbs.reader());
- for (0..286) |c| {
- try testing.expectEqual(c, try rdr.readFixedCode());
+ for (0..286) |c| {
+ try testing.expectEqual(c, try rdr.readFixedCode());
+ }
+ try testing.expect(rdr.nbits == 0);
}
- try testing.expect(rdr.nbits == 0);
+}
+
+test "u32 leaves no bits on u32 reads" {
+ const data = [_]u8{
+ 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+ };
+ var fbs = std.io.fixedBufferStream(&data);
+ var br = bitReader(u32, fbs.reader());
+
+ _ = try br.read(u3);
+ try testing.expectEqual(29, br.nbits);
+ br.alignToByte();
+ try testing.expectEqual(24, br.nbits);
+ try testing.expectEqual(0x04_03_02_01, try br.read(u32));
+ try testing.expectEqual(0, br.nbits);
+ try testing.expectEqual(0x08_07_06_05, try br.read(u32));
+ try testing.expectEqual(0, br.nbits);
+
+ _ = try br.read(u9);
+ try testing.expectEqual(23, br.nbits);
+ br.alignToByte();
+ try testing.expectEqual(16, br.nbits);
+ try testing.expectEqual(0x0e_0d_0c_0b, try br.read(u32));
+ try testing.expectEqual(0, br.nbits);
+}
+
+test "u64 need fill after alignToByte" {
+ const data = [_]u8{
+ 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
+ 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+ };
+
+ // without fill
+ var fbs = std.io.fixedBufferStream(&data);
+ var br = bitReader(u64, fbs.reader());
+ _ = try br.read(u23);
+ try testing.expectEqual(41, br.nbits);
+ br.alignToByte();
+ try testing.expectEqual(40, br.nbits);
+ try testing.expectEqual(0x06_05_04_03, try br.read(u32));
+ try testing.expectEqual(8, br.nbits);
+ try testing.expectEqual(0x0a_09_08_07, try br.read(u32));
+ try testing.expectEqual(32, br.nbits);
+
+ // fill after align ensures all bits filled
+ fbs.reset();
+ br = bitReader(u64, fbs.reader());
+ _ = try br.read(u23);
+ try testing.expectEqual(41, br.nbits);
+ br.alignToByte();
+ try br.fill(0);
+ try testing.expectEqual(64, br.nbits);
+ try testing.expectEqual(0x06_05_04_03, try br.read(u32));
+ try testing.expectEqual(32, br.nbits);
+ try testing.expectEqual(0x0a_09_08_07, try br.read(u32));
+ try testing.expectEqual(0, br.nbits);
}
diff --git a/lib/std/compress/flate/container.zig b/lib/std/compress/flate/container.zig
@@ -154,6 +154,7 @@ pub const Container = enum {
pub fn parseFooter(comptime wrap: Container, hasher: *Hasher(wrap), reader: anytype) !void {
switch (wrap) {
.gzip => {
+ try reader.fill(0);
if (try reader.read(u32) != hasher.chksum()) return error.WrongGzipChecksum;
if (try reader.read(u32) != hasher.bytesRead()) return error.WrongGzipSize;
},
diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig
@@ -17,8 +17,16 @@ pub fn decompress(comptime container: Container, reader: anytype, writer: anytyp
}
/// Inflate decompressor for the reader type.
-pub fn decompressor(comptime container: Container, reader: anytype) Inflate(container, @TypeOf(reader)) {
- return Inflate(container, @TypeOf(reader)).init(reader);
+pub fn decompressor(comptime container: Container, reader: anytype) Decompressor(container, @TypeOf(reader)) {
+ return Decompressor(container, @TypeOf(reader)).init(reader);
+}
+
+pub fn Decompressor(comptime container: Container, comptime ReaderType: type) type {
+ // zlib has a 4-byte footer, so a 4-byte lookahead ensures that we will not overshoot.
+ // gzip has an 8-byte footer, so we will not overshoot even with 8 bytes of lookahead.
+ // For raw deflate there is always a possibility of overshooting, so we use an 8-byte lookahead.
+ const lookahead: type = if (container == .zlib) u32 else u64;
+ return Inflate(container, lookahead, ReaderType);
}
/// Inflate decompresses deflate bit stream. Reads compressed data from reader
@@ -40,9 +48,12 @@ pub fn decompressor(comptime container: Container, reader: anytype) Inflate(cont
/// * 64K for history (CircularBuffer)
/// * ~10K huffman decoders (Literal and DistanceDecoder)
///
-pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
+pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comptime ReaderType: type) type {
+ assert(LookaheadType == u32 or LookaheadType == u64);
+ const BitReaderType = BitReader(LookaheadType, ReaderType);
+
return struct {
- const BitReaderType = BitReader(ReaderType);
+ // BitReaderType is declared above, parameterized by LookaheadType.
const F = BitReaderType.flag;
bits: BitReaderType = .{},
@@ -219,9 +230,14 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
switch (sym.kind) {
.literal => self.hist.write(sym.symbol),
.match => { // Decode match backreference <length, distance>
- try self.bits.fill(5 + 15 + 13); // so we can use buffered reads
+ // fill so we can use buffered reads
+ if (LookaheadType == u32)
+ try self.bits.fill(5 + 15)
+ else
+ try self.bits.fill(5 + 15 + 13);
const length = try self.decodeLength(sym.symbol);
const dsm = try self.decodeSymbol(&self.dst_dec);
+ if (LookaheadType == u32) try self.bits.fill(13);
const distance = try self.decodeDistance(dsm.symbol);
try self.hist.writeMatch(length, distance);
},
diff --git a/lib/std/compress/gzip.zig b/lib/std/compress/gzip.zig
@@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
- return inflate.Inflate(.gzip, ReaderType);
+ return inflate.Decompressor(.gzip, ReaderType);
}
/// Create Decompressor which will read compressed data from reader.
diff --git a/lib/std/compress/zlib.zig b/lib/std/compress/zlib.zig
@@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
- return inflate.Inflate(.zlib, ReaderType);
+ return inflate.Decompressor(.zlib, ReaderType);
}
/// Create Decompressor which will read compressed data from reader.
@@ -64,3 +64,38 @@ pub const store = struct {
return deflate.store.compressor(.zlib, writer);
}
};
+
+test "should not overshoot" {
+ const std = @import("std");
+
+ // Compressed zlib data with 4 extra bytes at the end.
+ const data = [_]u8{
+ 0x78, 0x9c, 0x73, 0xce, 0x2f, 0xa8, 0x2c, 0xca, 0x4c, 0xcf, 0x28, 0x51, 0x08, 0xcf, 0xcc, 0xc9,
+ 0x49, 0xcd, 0x55, 0x28, 0x4b, 0xcc, 0x53, 0x08, 0x4e, 0xce, 0x48, 0xcc, 0xcc, 0xd6, 0x51, 0x08,
+ 0xce, 0xcc, 0x4b, 0x4f, 0x2c, 0xc8, 0x2f, 0x4a, 0x55, 0x30, 0xb4, 0xb4, 0x34, 0xd5, 0xb5, 0x34,
+ 0x03, 0x00, 0x8b, 0x61, 0x0f, 0xa4, 0x52, 0x5a, 0x94, 0x12,
+ };
+
+ var stream = std.io.fixedBufferStream(data[0..]);
+ const reader = stream.reader();
+
+ var dcp = decompressor(reader);
+ var out: [128]u8 = undefined;
+
+ // Decompress
+ var n = try dcp.reader().readAll(out[0..]);
+
+ // Expected decompressed data
+ try std.testing.expectEqual(46, n);
+ try std.testing.expectEqualStrings("Copyright Willem van Schaik, Singapore 1995-96", out[0..n]);
+
+ // The decompressor doesn't overshoot the underlying reader.
+ // It leaves the reader positioned at the end of the compressed data chunk.
+ try std.testing.expectEqual(data.len - 4, stream.getPos());
+ try std.testing.expectEqual(0, dcp.unreadBytes());
+
+ // The 4 bytes after the compressed chunk are still available in the reader.
+ n = try reader.readAll(out[0..]);
+ try std.testing.expectEqual(n, 4);
+ try std.testing.expectEqualSlices(u8, data[data.len - 4 .. data.len], out[0..n]);
+}