From a9c8376305d4c99591432e0ed267fad665bb4f5f Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Sat, 4 Feb 2023 13:49:53 +1100 Subject: [PATCH] std.compress.zstandard: make ZstandardStream decode multiple frames --- lib/std/compress/zstandard.zig | 144 +++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 50 deletions(-) diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 5b6f928db5..00f475df00 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -13,6 +13,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool allocator: Allocator, in_reader: ReaderType, + state: enum { NewFrame, InFrame }, decode_state: decompress.block.DecodeState, frame_context: decompress.FrameContext, buffer: RingBuffer, @@ -24,16 +25,43 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool sequence_buffer: []u8, checksum: if (verify_checksum) ?u32 else void, - pub const Error = ReaderType.Error || error{ MalformedBlock, MalformedFrame }; + pub const Error = ReaderType.Error || error{ ChecksumFailure, MalformedBlock, MalformedFrame, OutOfMemory }; pub const Reader = std.io.Reader(*Self, Error, read); pub fn init(allocator: Allocator, source: ReaderType) !Self { - switch (try decompress.decodeFrameType(source)) { - .skippable => return error.SkippableFrame, + return Self{ + .allocator = allocator, + .in_reader = source, + .state = .NewFrame, + .decode_state = undefined, + .frame_context = undefined, + .buffer = undefined, + .last_block = undefined, + .literal_fse_buffer = undefined, + .match_fse_buffer = undefined, + .offset_fse_buffer = undefined, + .literals_buffer = undefined, + .sequence_buffer = undefined, + .checksum = undefined, + }; + } + + fn frameInit(self: *Self) !void { + var bytes: [4]u8 = undefined; + const bytes_read = try self.in_reader.readAll(&bytes); + if (bytes_read == 0) return error.NoBytes; + if (bytes_read < 4) return error.EndOfStream; + const frame_type = try decompress.frameType(std.mem.readIntLittle(u32, &bytes)); + switch (frame_type) { + .skippable => { + const size = try self.in_reader.readIntLittle(u32); + try self.in_reader.skipBytes(size, .{}); + self.state = .NewFrame; + }, .zstandard => { const frame_context = context: { - const frame_header = try decompress.decodeZstandardHeader(source); + const frame_header = try decompress.decodeZstandardHeader(self.in_reader); break :context try decompress.FrameContext.init( frame_header, window_size_max, @@ -41,56 +69,58 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool ); }; - const literal_fse_buffer = try allocator.alloc( + const literal_fse_buffer = try self.allocator.alloc( types.compressed_block.Table.Fse, types.compressed_block.table_size_max.literal, ); - errdefer allocator.free(literal_fse_buffer); + errdefer self.allocator.free(literal_fse_buffer); - const match_fse_buffer = try allocator.alloc( + const match_fse_buffer = try self.allocator.alloc( types.compressed_block.Table.Fse, types.compressed_block.table_size_max.match, ); - errdefer allocator.free(match_fse_buffer); + errdefer self.allocator.free(match_fse_buffer); - const offset_fse_buffer = try allocator.alloc( + const offset_fse_buffer = try self.allocator.alloc( types.compressed_block.Table.Fse, types.compressed_block.table_size_max.offset, ); - errdefer allocator.free(offset_fse_buffer); + errdefer self.allocator.free(offset_fse_buffer); const decode_state = decompress.block.DecodeState.init( literal_fse_buffer, match_fse_buffer, offset_fse_buffer, ); - const buffer = try RingBuffer.init(allocator, frame_context.window_size); + const buffer = try RingBuffer.init(self.allocator, frame_context.window_size); - const literals_data = try allocator.alloc(u8, window_size_max); - errdefer allocator.free(literals_data); + const literals_data = try self.allocator.alloc(u8, window_size_max); + errdefer self.allocator.free(literals_data); - const sequence_data = try allocator.alloc(u8, window_size_max); - errdefer allocator.free(sequence_data); + const sequence_data = try self.allocator.alloc(u8, window_size_max); + errdefer self.allocator.free(sequence_data); - return Self{ - .allocator = allocator, - .in_reader = source, - .decode_state = decode_state, - .frame_context = frame_context, - .buffer = buffer, - .checksum = if (verify_checksum) null else {}, - .last_block = false, - .literal_fse_buffer = literal_fse_buffer, - .match_fse_buffer = match_fse_buffer, - .offset_fse_buffer = offset_fse_buffer, - .literals_buffer = literals_data, - .sequence_buffer = sequence_data, - }; + self.literal_fse_buffer = literal_fse_buffer; + self.match_fse_buffer = match_fse_buffer; + self.offset_fse_buffer = offset_fse_buffer; + self.literals_buffer = literals_data; + self.sequence_buffer = sequence_data; + + self.buffer = buffer; + + self.decode_state = decode_state; + self.frame_context = frame_context; + + self.checksum = if (verify_checksum) null else {}; + self.last_block = false; + + self.state = .InFrame; }, } } pub fn deinit(self: *Self) void { + if (self.state == .NewFrame) return; self.allocator.free(self.decode_state.literal_fse_buffer); self.allocator.free(self.decode_state.match_fse_buffer); self.allocator.free(self.decode_state.offset_fse_buffer); @@ -105,6 +135,19 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool pub fn read(self: *Self, buffer: []u8) Error!usize { if (buffer.len == 0) return 0; + while (self.state == .NewFrame) { + self.frameInit() catch |err| switch (err) { + error.NoBytes => return 0, + error.OutOfMemory => return error.OutOfMemory, + else => return error.MalformedFrame, + }; + } + + return self.readInner(buffer); + } + + fn readInner(self: *Self, buffer: []u8) Error!usize { + std.debug.assert(self.state == .InFrame); if (self.buffer.isEmpty() and !self.last_block) { const header_bytes = self.in_reader.readBytesNoEof(3) catch return error.MalformedFrame; @@ -127,9 +170,15 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool hasher.update(written_slice.first); hasher.update(written_slice.second); } - if (block_header.last_block and self.frame_context.has_checksum) { - const checksum = self.in_reader.readIntLittle(u32) catch return error.MalformedFrame; - if (verify_checksum) self.checksum = checksum; + if (block_header.last_block) { + if (self.frame_context.has_checksum) { + const checksum = self.in_reader.readIntLittle(u32) catch return error.MalformedFrame; + if (comptime verify_checksum) { + if (self.frame_context.hasher_opt) |*hasher| { + if (checksum != decompress.computeChecksum(hasher)) return error.ChecksumFailure; + } + } + } } } @@ -138,18 +187,16 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool while (written_count < decoded_data_len and written_count < buffer.len) : (written_count += 1) { buffer[written_count] = self.buffer.read().?; } - return written_count; - } - - pub fn verifyChecksum(self: *Self) !bool { - if (verify_checksum) { - if (self.checksum) |checksum| { - if (self.frame_context.hasher_opt) |*hasher| { - return checksum == decompress.computeChecksum(hasher); - } - } + if (self.buffer.len() == 0) { + self.state = .NewFrame; + self.allocator.free(self.literal_fse_buffer); + self.allocator.free(self.match_fse_buffer); + self.allocator.free(self.offset_fse_buffer); + self.allocator.free(self.literals_buffer); + self.allocator.free(self.sequence_buffer); + self.buffer.deinit(self.allocator); } - return true; + return written_count; } }; } @@ -163,7 +210,6 @@ fn testDecompress(data: []const u8) ![]u8 { var stream = try zstandardStream(std.testing.allocator, in_stream.reader()); defer stream.deinit(); const result = stream.reader().readAllAlloc(std.testing.allocator, std.math.maxInt(usize)); - try std.testing.expect(try stream.verifyChecksum()); return result; } @@ -181,14 +227,12 @@ test "decompression" { var buffer = try std.testing.allocator.alloc(u8, uncompressed.len); defer std.testing.allocator.free(buffer); - const res3 = try decompress.decodeFrame(buffer, compressed3, true); - try std.testing.expectEqual(compressed3.len, res3.read_count); - try std.testing.expectEqual(uncompressed.len, res3.write_count); + const res3 = try decompress.decode(buffer, compressed3, true); + try std.testing.expectEqual(uncompressed.len, res3); try std.testing.expectEqualSlices(u8, uncompressed, buffer); - const res19 = try decompress.decodeFrame(buffer, compressed19, true); - try std.testing.expectEqual(compressed19.len, res19.read_count); - try std.testing.expectEqual(uncompressed.len, res19.write_count); + const res19 = try decompress.decode(buffer, compressed19, true); + try std.testing.expectEqual(uncompressed.len, res19); try std.testing.expectEqualSlices(u8, uncompressed, buffer); try testReader(compressed3, uncompressed);