commit d4ebfa87633aa632a290dfffa38712468e5512db (tree)
parent 66204b780699ca6f043d312b4c637c25d6d9f3eb
Author: Andrew Kelley <andrew@ziglang.org>
Date: Mon, 4 Oct 2021 15:31:08 -0400
Merge pull request #9880 from squeek502/deflate-construct-errors
deflate: Better Huffman.construct errors and error handling
Diffstat:
1 file changed, 48 insertions(+), 7 deletions(-)
diff --git a/lib/std/compress/deflate.zig b/lib/std/compress/deflate.zig
@@ -45,7 +45,9 @@ const Huffman = struct {
min_code_len: u16,
- fn construct(self: *Huffman, code_length: []const u16) !void {
+ const ConstructError = error{ Oversubscribed, IncompleteSet };
+
+ fn construct(self: *Huffman, code_length: []const u16) ConstructError!void {
for (self.count) |*val| {
val.* = 0;
}
@@ -70,7 +72,7 @@ const Huffman = struct {
// Make sure the number of codes with this length isn't too high.
left -= @as(isize, @bitCast(i16, val));
if (left < 0)
- return error.InvalidTree;
+ return error.Oversubscribed;
}
// Compute the offset of the first symbol represented by a code of a
@@ -125,6 +127,9 @@ const Huffman = struct {
self.last_code = codes[PREFIX_LUT_BITS + 1];
self.last_index = offset[PREFIX_LUT_BITS + 1] - self.count[PREFIX_LUT_BITS + 1];
+
+ if (left > 0)
+ return error.IncompleteSet;
}
};
@@ -324,7 +329,13 @@ pub fn InflateStream(comptime ReaderType: type) type {
try lencode.construct(len_lengths[0..]);
const dist_lengths = [_]u16{5} ** MAXDCODES;
- try distcode.construct(dist_lengths[0..]);
+ distcode.construct(dist_lengths[0..]) catch |err| switch (err) {
+ // This error is expected because we only compute distance codes
+ // 0-29, which is fine since "distance codes 30-31 will never actually
+ // occur in the compressed data" (from section 3.2.6 of RFC1951).
+ error.IncompleteSet => {},
+ else => return err,
+ };
}
self.hlen = &lencode;
@@ -359,7 +370,7 @@ pub fn InflateStream(comptime ReaderType: type) type {
lengths[val] = @intCast(u16, try self.readBits(3));
}
- try lencode.construct(lengths[0..]);
+ lencode.construct(lengths[0..]) catch return error.InvalidTree;
}
// Read the length/literal and distance code length tables.
@@ -408,8 +419,24 @@ pub fn InflateStream(comptime ReaderType: type) type {
if (lengths[256] == 0)
return error.MissingEOBCode;
- try self.huffman_tables[0].construct(lengths[0..nlen]);
- try self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]);
+ self.huffman_tables[0].construct(lengths[0..nlen]) catch |err| switch (err) {
+ error.Oversubscribed => return error.InvalidTree,
+ error.IncompleteSet => {
+ // incomplete code ok only for single length 1 code
+ if (nlen != self.huffman_tables[0].count[0] + self.huffman_tables[0].count[1]) {
+ return error.InvalidTree;
+ }
+ },
+ };
+ self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]) catch |err| switch (err) {
+ error.Oversubscribed => return error.InvalidTree,
+ error.IncompleteSet => {
+ // incomplete code ok only for single length 1 code
+ if (ndist != self.huffman_tables[1].count[0] + self.huffman_tables[1].count[1]) {
+ return error.InvalidTree;
+ }
+ },
+ };
self.hlen = &self.huffman_tables[0];
self.hdist = &self.huffman_tables[1];
@@ -684,8 +711,22 @@ test "distance past beginning of output stream" {
test "inflateStream fuzzing" {
// see https://github.com/ziglang/zig/issues/9842
- try std.testing.expectError(error.EndOfStream, testInflate("\x950000"));
+ try std.testing.expectError(error.EndOfStream, testInflate("\x95\x90=o\xc20\x10\x86\xf30"));
try std.testing.expectError(error.OutOfCodes, testInflate("\x950\x00\x0000000"));
+
+ // Huffman.construct errors
+ // lencode
+ try std.testing.expectError(error.InvalidTree, testInflate("\x950000"));
+ try std.testing.expectError(error.InvalidTree, testInflate("\x05000"));
+ // hlen
+ try std.testing.expectError(error.InvalidTree, testInflate("\x05\xea\x01\t\x00\x00\x00\x01\x00\\\xbf.\t\x00"));
+ // hdist
+ try std.testing.expectError(error.InvalidTree, testInflate("\x05\xe0\x01A\x00\x00\x00\x00\x10\\\xbf."));
+
+ // Huffman.construct -> error.IncompleteSet returns that shouldn't give error.InvalidTree
+ // (like the "empty distance alphabet" test but for ndist instead of nlen)
+ try std.testing.expectError(error.EndOfStream, testInflate("\x05\xe0\x01\t\x00\x00\x00\x00\x10\\\xbf\xce"));
+ try testInflate("\x15\xe0\x01\t\x00\x00\x00\x00\x10\\\xbf.0");
}
fn testInflate(data: []const u8) !void {