From ece52e0771560726a72a11cfafb59bb1dc9ad221 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Sun, 5 Feb 2023 22:27:00 +1100 Subject: [PATCH] std.compress.zstandard: verify content size and fix crash --- lib/std/compress/zstandard/decode/block.zig | 17 ++++- lib/std/compress/zstandard/decompress.zig | 76 +++++++++++++++------ 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index 8f97bea399..4182996d43 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -334,6 +334,8 @@ pub const DecodeState = struct { /// mean the literal stream or the sequence is malformed). /// - `error.InvalidBitStream` if the FSE sequence bitstream is malformed /// - `error.EndOfStream` if `bit_reader` does not contain enough bits + /// - `error.DestTooSmall` if `dest` is not large enough to holde the + /// decompressed sequence pub fn decodeSequenceSlice( self: *DecodeState, dest: []u8, @@ -341,10 +343,11 @@ pub const DecodeState = struct { bit_reader: *readers.ReverseBitReader, sequence_size_limit: usize, last_sequence: bool, - ) DecodeSequenceError!usize { + ) (error{DestTooSmall} || DecodeSequenceError)!usize { const sequence = try self.nextSequence(bit_reader); const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; if (sequence_length > sequence_size_limit) return error.MalformedSequence; + if (sequence_length > dest[write_pos..].len) return error.DestTooSmall; try self.executeSequenceSlice(dest, write_pos, sequence); if (!last_sequence) { @@ -583,6 +586,8 @@ pub const DecodeState = struct { /// - `error.MalformedRleBlock` if the block is an RLE block and `src.len < 1` /// - `error.MalformedCompressedBlock` if there are errors decoding a /// compressed block +/// - `error.DestTooSmall` is `dest` is not large enough to hold the +/// decompressed block pub fn decodeBlock( dest: []u8, src: []const u8, @@ -590,13 +595,14 @@ pub fn decodeBlock( decode_state: *DecodeState, consumed_count: *usize, written_count: usize, -) Error!usize { +) (error{DestTooSmall} || Error)!usize { const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB const block_size = block_header.block_size; if (block_size_max < block_size) return error.BlockSizeOverMaximum; switch (block_header.block_type) { .raw => { if (src.len < block_size) return error.MalformedBlockSize; + if (dest[written_count..].len < block_size) return error.DestTooSmall; const data = src[0..block_size]; std.mem.copy(u8, dest[written_count..], data); consumed_count.* += block_size; @@ -604,6 +610,7 @@ pub fn decodeBlock( }, .rle => { if (src.len < 1) return error.MalformedRleBlock; + if (dest[written_count..].len < block_size) return error.DestTooSmall; var write_pos: usize = written_count; while (write_pos < block_size + written_count) : (write_pos += 1) { dest[write_pos] = src[0]; @@ -644,7 +651,10 @@ pub fn decodeBlock( &bit_stream, sequence_size_limit, i == sequences_header.sequence_count - 1, - ) catch return error.MalformedCompressedBlock; + ) catch |err| switch (err) { + error.DestTooSmall => return error.DestTooSmall, + else => return error.MalformedCompressedBlock, + }; bytes_written += decompressed_size; sequence_size_limit -= decompressed_size; } @@ -655,6 +665,7 @@ pub fn decodeBlock( if (decode_state.literal_written_count < literals.header.regenerated_size) { const len = literals.header.regenerated_size - decode_state.literal_written_count; + if (len > dest[written_count + bytes_written ..].len) return error.DestTooSmall; decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch return error.MalformedCompressedBlock; bytes_written += len; diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 5b164ded35..73e9196657 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -96,6 +96,7 @@ pub fn decodeAlloc( /// - `error.UnknownContentSizeUnsupported` if the frame does not declare the /// uncompressed content size /// - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data +/// size declared by the frame header /// - `error.BadMagic` if the first 4 bytes of `src` is not a valid magic /// number for a Zstandard or Skippable frame /// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary @@ -180,6 +181,7 @@ pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 { const FrameError = error{ DictionaryIdFlagUnsupported, ChecksumFailure, + BadContentSize, EndOfStream, } || InvalidBit || block.Error; @@ -191,7 +193,7 @@ const FrameError = error{ /// - `error.UnknownContentSizeUnsupported` if the frame does not declare the /// uncompressed content size /// - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data -/// number for a Zstandard or Skippable frame +/// size declared by the frame header /// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary /// - `error.ChecksumFailure` if `verify_checksum` is true and the frame /// contains a checksum that does not match the checksum of the decompressed @@ -200,39 +202,51 @@ const FrameError = error{ /// - `error.UnusedBitSet` if the unused bit of the frame header is set /// - `error.EndOfStream` if `src` does not contain a complete frame /// - an error in `block.Error` if there are errors decoding a block +/// - `error.BadContentSize` if the content size declared by the frame does +/// not equal the actual size of decompressed data pub fn decodeZstandardFrame( dest: []u8, src: []const u8, verify_checksum: bool, -) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount { +) (error{ + UnknownContentSizeUnsupported, + ContentTooLarge, + ContentSizeTooLarge, + WindowSizeUnknown, +} || FrameError)!ReadWriteCount { assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number); var consumed_count: usize = 4; - var fbs = std.io.fixedBufferStream(src[consumed_count..]); - var source = fbs.reader(); - const frame_header = try decodeZstandardHeader(source); - consumed_count += fbs.pos; + var frame_context = context: { + var fbs = std.io.fixedBufferStream(src[consumed_count..]); + var source = fbs.reader(); + const frame_header = try decodeZstandardHeader(source); + consumed_count += fbs.pos; + break :context FrameContext.init(frame_header, std.math.maxInt(usize), verify_checksum) catch |err| switch (err) { + error.WindowTooLarge => unreachable, + inline else => |e| return e, + }; + }; - if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported; - - const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported; + const content_size = frame_context.content_size orelse return error.UnknownContentSizeUnsupported; if (dest.len < content_size) return error.ContentTooLarge; - const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum; - var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null; - - const written_count = try decodeFrameBlocks( - dest, + const written_count = decodeFrameBlocks( + dest[0..content_size], src[consumed_count..], &consumed_count, - if (hasher_opt) |*hasher| hasher else null, - ); + if (frame_context.hasher_opt) |*hasher| hasher else null, + ) catch |err| switch (err) { + error.DestTooSmall => return error.BadContentSize, + inline else => |e| return e, + }; - if (frame_header.descriptor.content_checksum_flag) { + if (written_count != content_size) return error.BadContentSize; + if (frame_context.has_checksum) { if (src.len < consumed_count + 4) return error.EndOfStream; const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]); consumed_count += 4; - if (hasher_opt) |*hasher| { + if (frame_context.hasher_opt) |*hasher| { if (checksum != computeChecksum(hasher)) return error.ChecksumFailure; } } @@ -244,8 +258,14 @@ pub const FrameContext = struct { window_size: usize, has_checksum: bool, block_size_max: usize, + content_size: ?usize, - const Error = error{ DictionaryIdFlagUnsupported, WindowSizeUnknown, WindowTooLarge }; + const Error = error{ + DictionaryIdFlagUnsupported, + WindowSizeUnknown, + WindowTooLarge, + ContentSizeTooLarge, + }; /// Validates `frame_header` and returns the associated `FrameContext`. /// /// Errors returned: @@ -266,11 +286,18 @@ pub const FrameContext = struct { @intCast(usize, window_size_raw); const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum; + + const content_size = if (frame_header.content_size) |size| + std.math.cast(usize, size) orelse return error.ContentSizeTooLarge + else + null; + return .{ .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null, .window_size = window_size, .has_checksum = frame_header.descriptor.content_checksum_flag, .block_size_max = @min(1 << 17, window_size), + .content_size = content_size, }; } }; @@ -294,6 +321,8 @@ pub const FrameContext = struct { /// - `error.EndOfStream` if `src` does not contain a complete frame /// - `error.OutOfMemory` if `allocator` cannot allocate enough memory /// - an error in `block.Error` if there are errors decoding a block +/// - `error.BadContentSize` if the content size declared by the frame does +/// not equal the size of decompressed data pub fn decodeZstandardFrameAlloc( allocator: Allocator, src: []const u8, @@ -321,6 +350,7 @@ pub fn decodeZstandardFrameArrayList( window_size_max: usize, ) (error{OutOfMemory} || FrameContext.Error || FrameError)!usize { assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number); + const initial_len = dest.items.len; var consumed_count: usize = 4; var frame_context = context: { @@ -364,6 +394,12 @@ pub fn decodeZstandardFrameArrayList( hasher.update(written_slice.second); } } + const added_len = dest.items.len - initial_len; + if (frame_context.content_size) |size| { + if (added_len != size) { + return error.BadContentSize; + } + } if (block_header.last_block) break; } @@ -384,7 +420,7 @@ fn decodeFrameBlocks( src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64, -) (error{EndOfStream} || block.Error)!usize { +) (error{ EndOfStream, DestTooSmall } || block.Error)!usize { // These tables take 7680 bytes var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined; var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;