From 6b85373875e2a7a8ef9d74e20e8e827dc622a29c Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Tue, 24 Jan 2023 13:07:58 +1100 Subject: [PATCH] std.compress.zstandard: validate sequence lengths --- lib/std/compress/zstandard/decompress.zig | 38 ++++++++++++++++------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index c80492c3f1..315bc0196c 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -271,10 +271,9 @@ pub const DecodeState = struct { literals: LiteralsSection, sequence: Sequence, ) !void { - try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length); + if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence; - // TODO: should we validate offset against max_window_size? - assert(sequence.offset <= write_pos + sequence.literal_length); + try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length); const copy_start = write_pos + sequence.literal_length - sequence.offset; const copy_end = copy_start + sequence.match_length; // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr @@ -288,8 +287,9 @@ pub const DecodeState = struct { literals: LiteralsSection, sequence: Sequence, ) !void { + if (sequence.offset > dest.data.len) return error.MalformedSequence; + try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length); - // TODO: check that ring buffer window is full enough for match copies const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length); // TODO: would std.mem.copy and figuring out dest slice be better/faster? for (copy_slice.first) |b| dest.writeAssumeCapacity(b); @@ -302,9 +302,13 @@ pub const DecodeState = struct { write_pos: usize, literals: LiteralsSection, bit_reader: anytype, + sequence_size_limit: usize, last_sequence: bool, ) !usize { const sequence = try self.nextSequence(bit_reader); + const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; + if (sequence_length > sequence_size_limit) return error.MalformedSequence; + try self.executeSequenceSlice(dest, write_pos, literals, sequence); log.debug("sequence decompressed into '{x}'", .{ std.fmt.fmtSliceHexUpper(dest[write_pos .. write_pos + sequence.literal_length + sequence.match_length]), @@ -314,7 +318,7 @@ pub const DecodeState = struct { try self.updateState(.match, bit_reader); try self.updateState(.offset, bit_reader); } - return sequence.match_length + sequence.literal_length; + return sequence_length; } pub fn decodeSequenceRingBuffer( @@ -322,12 +326,15 @@ pub const DecodeState = struct { dest: *RingBuffer, literals: LiteralsSection, bit_reader: anytype, + sequence_size_limit: usize, last_sequence: bool, ) !usize { const sequence = try self.nextSequence(bit_reader); + const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length; + if (sequence_length > sequence_size_limit) return error.MalformedSequence; + try self.executeSequenceRingBuffer(dest, literals, sequence); if (std.options.log_level == .debug) { - const sequence_length = sequence.literal_length + sequence.match_length; const written_slice = dest.sliceLast(sequence_length); log.debug("sequence decompressed into '{x}{x}'", .{ std.fmt.fmtSliceHexUpper(written_slice.first), @@ -339,7 +346,7 @@ pub const DecodeState = struct { try self.updateState(.match, bit_reader); try self.updateState(.offset, bit_reader); } - return sequence.match_length + sequence.literal_length; + return sequence_length; } fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void { @@ -717,9 +724,9 @@ pub fn decodeBlock( consumed_count: *usize, written_count: usize, ) !usize { - const block_maximum_size = 1 << 17; // 128KiB + const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB const block_size = block_header.block_size; - if (block_maximum_size < block_size) return error.BlockSizeOverMaximum; + if (block_size_max < block_size) return error.BlockSizeOverMaximum; // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks) switch (block_header.block_type) { .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count), @@ -739,17 +746,21 @@ pub fn decodeBlock( try decode_state.readInitialFseState(&bit_stream); + var sequence_size_limit = block_size_max; var i: usize = 0; while (i < sequences_header.sequence_count) : (i += 1) { log.debug("decoding sequence {d}", .{i}); + const write_pos = written_count + bytes_written; const decompressed_size = try decode_state.decodeSequenceSlice( dest, - written_count + bytes_written, + write_pos, literals, &bit_stream, + sequence_size_limit, i == sequences_header.sequence_count - 1, ); bytes_written += decompressed_size; + sequence_size_limit -= decompressed_size; } bytes_read += bit_stream_bytes.len; @@ -781,10 +792,10 @@ pub fn decodeBlockRingBuffer( block_header: frame.ZStandard.Block.Header, decode_state: *DecodeState, consumed_count: *usize, - block_size_maximum: usize, + block_size_max: usize, ) !usize { const block_size = block_header.block_size; - if (block_size_maximum < block_size) return error.BlockSizeOverMaximum; + if (block_size_max < block_size) return error.BlockSizeOverMaximum; // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks) switch (block_header.block_type) { .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count), @@ -804,6 +815,7 @@ pub fn decodeBlockRingBuffer( try decode_state.readInitialFseState(&bit_stream); + var sequence_size_limit = block_size_max; var i: usize = 0; while (i < sequences_header.sequence_count) : (i += 1) { log.debug("decoding sequence {d}", .{i}); @@ -811,9 +823,11 @@ pub fn decodeBlockRingBuffer( dest, literals, &bit_stream, + sequence_size_limit, i == sequences_header.sequence_count - 1, ); bytes_written += decompressed_size; + sequence_size_limit -= decompressed_size; } bytes_read += bit_stream_bytes.len;