std.compress.zstandard: verify content size and fix crash

This commit is contained in:
dweiller
2023-02-05 22:27:00 +11:00
parent a9c8376305
commit ece52e0771
2 changed files with 70 additions and 23 deletions

View File

@@ -334,6 +334,8 @@ pub const DecodeState = struct {
/// mean the literal stream or the sequence is malformed).
/// - `error.InvalidBitStream` if the FSE sequence bitstream is malformed
/// - `error.EndOfStream` if `bit_reader` does not contain enough bits
/// - `error.DestTooSmall` if `dest` is not large enough to hold the
/// decompressed sequence
pub fn decodeSequenceSlice(
self: *DecodeState,
dest: []u8,
@@ -341,10 +343,11 @@ pub const DecodeState = struct {
bit_reader: *readers.ReverseBitReader,
sequence_size_limit: usize,
last_sequence: bool,
) DecodeSequenceError!usize {
) (error{DestTooSmall} || DecodeSequenceError)!usize {
const sequence = try self.nextSequence(bit_reader);
const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
if (sequence_length > sequence_size_limit) return error.MalformedSequence;
if (sequence_length > dest[write_pos..].len) return error.DestTooSmall;
try self.executeSequenceSlice(dest, write_pos, sequence);
if (!last_sequence) {
@@ -583,6 +586,8 @@ pub const DecodeState = struct {
/// - `error.MalformedRleBlock` if the block is an RLE block and `src.len < 1`
/// - `error.MalformedCompressedBlock` if there are errors decoding a
/// compressed block
/// - `error.DestTooSmall` if `dest` is not large enough to hold the
/// decompressed block
pub fn decodeBlock(
dest: []u8,
src: []const u8,
@@ -590,13 +595,14 @@ pub fn decodeBlock(
decode_state: *DecodeState,
consumed_count: *usize,
written_count: usize,
) Error!usize {
) (error{DestTooSmall} || Error)!usize {
const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
const block_size = block_header.block_size;
if (block_size_max < block_size) return error.BlockSizeOverMaximum;
switch (block_header.block_type) {
.raw => {
if (src.len < block_size) return error.MalformedBlockSize;
if (dest[written_count..].len < block_size) return error.DestTooSmall;
const data = src[0..block_size];
std.mem.copy(u8, dest[written_count..], data);
consumed_count.* += block_size;
@@ -604,6 +610,7 @@ pub fn decodeBlock(
},
.rle => {
if (src.len < 1) return error.MalformedRleBlock;
if (dest[written_count..].len < block_size) return error.DestTooSmall;
var write_pos: usize = written_count;
while (write_pos < block_size + written_count) : (write_pos += 1) {
dest[write_pos] = src[0];
@@ -644,7 +651,10 @@ pub fn decodeBlock(
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
) catch return error.MalformedCompressedBlock;
) catch |err| switch (err) {
error.DestTooSmall => return error.DestTooSmall,
else => return error.MalformedCompressedBlock,
};
bytes_written += decompressed_size;
sequence_size_limit -= decompressed_size;
}
@@ -655,6 +665,7 @@ pub fn decodeBlock(
if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count;
if (len > dest[written_count + bytes_written ..].len) return error.DestTooSmall;
decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
return error.MalformedCompressedBlock;
bytes_written += len;

View File

@@ -96,6 +96,7 @@ pub fn decodeAlloc(
/// - `error.UnknownContentSizeUnsupported` if the frame does not declare the
/// uncompressed content size
/// - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data
/// size declared by the frame header
/// - `error.BadMagic` if the first 4 bytes of `src` are not a valid magic
/// number for a Zstandard or Skippable frame
/// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
@@ -180,6 +181,7 @@ pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
const FrameError = error{
DictionaryIdFlagUnsupported,
ChecksumFailure,
BadContentSize,
EndOfStream,
} || InvalidBit || block.Error;
@@ -191,7 +193,7 @@ const FrameError = error{
/// - `error.UnknownContentSizeUnsupported` if the frame does not declare the
/// uncompressed content size
/// - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data
/// size declared by the frame header
/// - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
/// - `error.ChecksumFailure` if `verify_checksum` is true and the frame
/// contains a checksum that does not match the checksum of the decompressed
@@ -200,39 +202,51 @@ const FrameError = error{
/// - `error.UnusedBitSet` if the unused bit of the frame header is set
/// - `error.EndOfStream` if `src` does not contain a complete frame
/// - an error in `block.Error` if there are errors decoding a block
/// - `error.BadContentSize` if the content size declared by the frame does
/// not equal the actual size of decompressed data
pub fn decodeZstandardFrame(
dest: []u8,
src: []const u8,
verify_checksum: bool,
) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount {
) (error{
UnknownContentSizeUnsupported,
ContentTooLarge,
ContentSizeTooLarge,
WindowSizeUnknown,
} || FrameError)!ReadWriteCount {
assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number);
var consumed_count: usize = 4;
var fbs = std.io.fixedBufferStream(src[consumed_count..]);
var source = fbs.reader();
const frame_header = try decodeZstandardHeader(source);
consumed_count += fbs.pos;
var frame_context = context: {
var fbs = std.io.fixedBufferStream(src[consumed_count..]);
var source = fbs.reader();
const frame_header = try decodeZstandardHeader(source);
consumed_count += fbs.pos;
break :context FrameContext.init(frame_header, std.math.maxInt(usize), verify_checksum) catch |err| switch (err) {
error.WindowTooLarge => unreachable,
inline else => |e| return e,
};
};
if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported;
const content_size = frame_context.content_size orelse return error.UnknownContentSizeUnsupported;
if (dest.len < content_size) return error.ContentTooLarge;
const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
const written_count = try decodeFrameBlocks(
dest,
const written_count = decodeFrameBlocks(
dest[0..content_size],
src[consumed_count..],
&consumed_count,
if (hasher_opt) |*hasher| hasher else null,
);
if (frame_context.hasher_opt) |*hasher| hasher else null,
) catch |err| switch (err) {
error.DestTooSmall => return error.BadContentSize,
inline else => |e| return e,
};
if (frame_header.descriptor.content_checksum_flag) {
if (written_count != content_size) return error.BadContentSize;
if (frame_context.has_checksum) {
if (src.len < consumed_count + 4) return error.EndOfStream;
const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
consumed_count += 4;
if (hasher_opt) |*hasher| {
if (frame_context.hasher_opt) |*hasher| {
if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
}
}
@@ -244,8 +258,14 @@ pub const FrameContext = struct {
window_size: usize,
has_checksum: bool,
block_size_max: usize,
content_size: ?usize,
const Error = error{ DictionaryIdFlagUnsupported, WindowSizeUnknown, WindowTooLarge };
const Error = error{
DictionaryIdFlagUnsupported,
WindowSizeUnknown,
WindowTooLarge,
ContentSizeTooLarge,
};
/// Validates `frame_header` and returns the associated `FrameContext`.
///
/// Errors returned:
@@ -266,11 +286,18 @@ pub const FrameContext = struct {
@intCast(usize, window_size_raw);
const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
const content_size = if (frame_header.content_size) |size|
std.math.cast(usize, size) orelse return error.ContentSizeTooLarge
else
null;
return .{
.hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null,
.window_size = window_size,
.has_checksum = frame_header.descriptor.content_checksum_flag,
.block_size_max = @min(1 << 17, window_size),
.content_size = content_size,
};
}
};
@@ -294,6 +321,8 @@ pub const FrameContext = struct {
/// - `error.EndOfStream` if `src` does not contain a complete frame
/// - `error.OutOfMemory` if `allocator` cannot allocate enough memory
/// - an error in `block.Error` if there are errors decoding a block
/// - `error.BadContentSize` if the content size declared by the frame does
/// not equal the size of decompressed data
pub fn decodeZstandardFrameAlloc(
allocator: Allocator,
src: []const u8,
@@ -321,6 +350,7 @@ pub fn decodeZstandardFrameArrayList(
window_size_max: usize,
) (error{OutOfMemory} || FrameContext.Error || FrameError)!usize {
assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number);
const initial_len = dest.items.len;
var consumed_count: usize = 4;
var frame_context = context: {
@@ -364,6 +394,12 @@ pub fn decodeZstandardFrameArrayList(
hasher.update(written_slice.second);
}
}
const added_len = dest.items.len - initial_len;
if (frame_context.content_size) |size| {
if (added_len != size) {
return error.BadContentSize;
}
}
if (block_header.last_block) break;
}
@@ -384,7 +420,7 @@ fn decodeFrameBlocks(
src: []const u8,
consumed_count: *usize,
hash: ?*std.hash.XxHash64,
) (error{EndOfStream} || block.Error)!usize {
) (error{ EndOfStream, DestTooSmall } || block.Error)!usize {
// These tables take 7680 bytes
var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;