zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit bb304796f466b4fd15a4adac7ffbb81ace95d2a4 (tree)
parent 333055ced72ae73c1ec74a66aa9436291d8b7d1d
Author: Kendall Condon <goon.pri.low@gmail.com>
Date:   Wed, 18 Feb 2026 18:17:15 -0500

optimize flate decompression

Matches now use memcpy and memset when possible.

Block loops have been rewritten to be more optimizer friendly.

Reworks Symbol and HuffmanDecoder
* Symbol now only includes the value and number of code bits.
  decodeSymbol returns only the value.
* HuffmanDecoder now takes the regular bits instead of the reversed.
* Code table construction now uses buckets instead of sorting.
* For linked codes, the value field of Symbol is now used as the next
  index. The actual value is the element index.
* InvalidCode is now detected only once with a special linked index.

Performance is 39.7% faster than before and 1.1% faster than gzip using
a sample created from compressing a tar of the src directory.

Diffstat:
Mlib/std/compress/flate/Decompress.zig | 359+++++++++++++++++++++++++++++++++++++------------------------------------------
1 file changed, 170 insertions(+), 189 deletions(-)

diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig @@ -229,11 +229,11 @@ fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usiz } } -fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol { +fn decodeSymbol(self: *Decompress, decoder: anytype) !u16 { // Maximum code len is 15 bits. - const sym = try decoder.find(@bitReverse(try self.peekIntBitsShort(u15))); + const sym = try decoder.find(try self.peekIntBitsShort(u15)); try self.tossBitsShort(sym.code_bits); - return sym; + return sym.value; } fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { @@ -348,10 +348,10 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader var dec_lens: [286 + 30]u4 = @splat(0); var pos: usize = 0; while (pos < hlit + hdist) { - const peeked = @bitReverse(try d.peekIntBitsShort(u7)); + const peeked = try d.peekIntBitsShort(u7); const sym = try cl_dec.find(peeked); try d.tossBitsShort(sym.code_bits); - pos += try d.dynamicCodeLength(sym.symbol, &dec_lens, pos); + pos += try d.dynamicCodeLength(sym.value, &dec_lens, pos); } if (pos > hlit + hdist) { return error.InvalidDynamicBlockHeader; @@ -383,35 +383,34 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader w.advance(n); return @intFromEnum(limit) - remaining + n; }, - .fixed_block => { - while (remaining > 0) { - const code = try d.readFixedCode(); - switch (code) { - 0...255 => { - if (remaining != 0) { - @branchHint(.likely); - try w.writeBytePreserve(flate.history_len, @intCast(code)); - remaining -= 1; - } else { - d.state = .{ .fixed_block_literal = @intCast(code) }; - return @intFromEnum(limit) - remaining; - } - }, - 256 => { - d.state = if (d.final_block) .protocol_footer else .block_header; - return @intFromEnum(limit) - remaining; - }, - 257...285 => { - // Handles fixed block non literal (length) code. - // Length code is followed by 5 bits of distance code. - const length = try d.decodeLength(@intCast(code - 257)); - continue :sw .{ .fixed_block_match = length }; - }, - else => return error.InvalidCode, + .fixed_block => while (true) { + // Consume bytes + const sym = try d.readFixedCode(); + + if (sym >= 256) { + @branchHint(.unlikely); + + if (sym == 256) { + @branchHint(.unlikely); + // End + d.state = if (d.final_block) .protocol_footer else .block_header; + continue :sw d.state; } + + // Match + const length = try d.decodeLength(@intCast(sym - 257)); + continue :sw .{ .fixed_block_match = length }; + } + + const byte: u8 = @intCast(sym); + if (remaining != 0) { + @branchHint(.likely); + remaining -= 1; + try w.writeBytePreserve(flate.history_len, byte); + } else { + d.state = .{ .fixed_block_literal = byte }; + return @intFromEnum(limit) - remaining; } - d.state = .fixed_block; - return @intFromEnum(limit) - remaining; }, .fixed_block_literal => |symbol| { assert(remaining != 0); @@ -431,32 +430,35 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader return @intFromEnum(limit) - remaining; } }, - .dynamic_block => { - // In larger archives most blocks are usually dynamic, so - // decompression performance depends on this logic. - var sym = try d.decodeSymbol(&d.lit_dec); - sym: switch (sym.kind) { - .literal => { - if (remaining != 0) { - @branchHint(.likely); - remaining -= 1; - try w.writeBytePreserve(flate.history_len, sym.symbol); - sym = try d.decodeSymbol(&d.lit_dec); - continue :sym sym.kind; - } else { - d.state = .{ .dynamic_block_literal = sym.symbol }; - return @intFromEnum(limit) - remaining; - } - }, - .match => { - // Decode match backreference <length, distance> - const length = try d.decodeLength(@intCast(sym.symbol)); - continue :sw .{ .dynamic_block_match = length }; - }, - .end_of_block => { + // In larger archives most blocks are usually dynamic, so + // decompression performance depends on this logic. + .dynamic_block => while (true) { + // Consume bytes + const sym = try d.decodeSymbol(&d.lit_dec); + + if (sym >= 256) { + @branchHint(.unlikely); + + if (sym == 256) { + @branchHint(.unlikely); + // End d.state = if (d.final_block) .protocol_footer else .block_header; continue :sw d.state; - }, + } + + // Match + const length = try d.decodeLength(@intCast(sym - 257)); + continue :sw .{ .dynamic_block_match = length }; + } + + const byte: u8 = @intCast(sym); + if (remaining != 0) { + @branchHint(.likely); + remaining -= 1; + try w.writeBytePreserve(flate.history_len, byte); + } else { + d.state = .{ .dynamic_block_literal = byte }; + return @intFromEnum(limit) - remaining; } }, .dynamic_block_literal => |symbol| { @@ -470,7 +472,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader @branchHint(.likely); remaining -= length; const dsm = try d.decodeSymbol(&d.dst_dec); - const distance = try d.decodeDistance(@intCast(dsm.symbol)); + const distance = try d.decodeDistance(@intCast(dsm)); try writeMatch(w, length, distance); continue :sw .dynamic_block; } else { @@ -501,17 +503,25 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader /// back from current write position, and `length` of bytes. fn writeMatch(w: *Writer, length: u16, distance: u16) !void { if (w.end < distance) return error.InvalidMatch; - if (length < token.min_length) return error.InvalidMatch; - if (length > token.max_length) return error.InvalidMatch; - if (distance < token.min_distance) return error.InvalidMatch; - if (distance > token.max_distance) return error.InvalidMatch; + assert(length >= token.min_length); + assert(length <= token.max_length); + assert(distance >= token.min_distance); + assert(distance <= token.max_distance); // This is not a @memmove; it intentionally repeats patterns caused by // iterating one byte at a time. const dest = try w.writableSlicePreserve(flate.history_len, length); const end = dest.ptr - w.buffer.ptr; const src = w.buffer[end - distance ..][0..length]; - for (dest, src) |*d, s| d.* = s; + if (distance >= length) { + @memcpy(dest, src); + } else if (distance == 1) { + // Repeating copy of single byte + @memset(dest, src[0]); + } else { + // Repeating copy of multiple bytes + for (dest, src) |*d, s| d.* = s; + } } fn peekBits(d: *Decompress, n: u4) !u16 { @@ -603,31 +613,9 @@ fn readFixedCode(d: *Decompress) !u16 { }; } -pub const Symbol = packed struct { - pub const Kind = enum(u2) { - literal, - end_of_block, - match, - }; - - symbol: u8 = 0, // symbol from alphabet +pub const Symbol = packed struct(u16) { + value: u12 = 0, code_bits: u4 = 0, // number of bits in code 0-15 - kind: Kind = .literal, - - code: u16 = 0, // huffman code of the symbol - next: u16 = 0, // pointer to the next symbol in linked list - // it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup - - // Sorting less than function. - pub fn asc(_: void, a: Symbol, b: Symbol) bool { - if (a.code_bits == b.code_bits) { - if (a.kind == b.kind) { - return a.symbol < b.symbol; - } - return @intFromEnum(a.kind) < @intFromEnum(b.kind); - } - return a.code_bits < b.code_bits; - } }; pub const LiteralDecoder = HuffmanDecoder(286, 15, 9); @@ -646,69 +634,85 @@ pub const CodegenDecoder = HuffmanDecoder(19, 7, 7); /// Small lookup table is optimization for faster search. /// It is variation of the algorithm explained in [zlib](https://github.com/madler/zlib/blob/643e17b7498d12ab8d15565662880579692f769d/doc/algorithm.txt#L92) /// with difference that we here use statically allocated arrays. -/// fn HuffmanDecoder( comptime alphabet_size: u16, comptime max_code_bits: u4, comptime lookup_bits: u4, ) type { const lookup_shift = max_code_bits - lookup_bits; + const lookup_mask = (1 << lookup_bits) - 1; return struct { - // all symbols in alaphabet, sorted by code_len, symbol - symbols: [alphabet_size]Symbol = undefined, // lookup table code -> symbol + // for values with code_bits == 0, symbol is the index of the first node in linked + // if the index of the first node is 0xfff, it is an invalid code lookup: [1 << lookup_bits]Symbol = undefined, + linked: if (lookup_bits == max_code_bits) void else [alphabet_size]struct { + // sym.value is the next index in linked where the current index ends the chain + // the actual symbol is this nodes's index + sym: Symbol, + code: u16, + } = undefined, const Self = @This(); + fn reverseIdx(idx: usize) u16 { + return @bitReverse(@as(@Int(.unsigned, lookup_bits), @intCast(idx))); + } + /// Generates symbols and lookup tables from list of code lens for each symbol. pub fn generate(self: *Self, lens: []const u4) !void { try checkCompleteness(lens); - // init alphabet with code_bits - for (self.symbols, 0..) |_, i| { - const cb: u4 = if (i < lens.len) lens[i] else 0; - self.symbols[i] = if (i < 256) - .{ .kind = .literal, .symbol = @intCast(i), .code_bits = cb } - else if (i == 256) - .{ .kind = .end_of_block, .symbol = 0xff, .code_bits = cb } - else - .{ .kind = .match, .symbol = @intCast(i - 257), .code_bits = cb }; - } - std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc); - - // reset lookup table - for (0..self.lookup.len) |i| { - self.lookup[i] = .{}; + var buckets: [1 + @as(usize, max_code_bits)][alphabet_size]Symbol = undefined; + var bucket_len: [buckets.len]u16 = @splat(0); + for (0.., lens) |symbol, bits| { + buckets[bits][bucket_len[bits]] = .{ + .value = @intCast(symbol), + .code_bits = bits, + }; + bucket_len[bits] += 1; } - // assign code to symbols - // reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639 var code: u16 = 0; var idx: u16 = 0; - for (&self.symbols, 0..) |*sym, pos| { - if (sym.code_bits == 0) continue; // skip unused - sym.code = code; - - const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits)); - const next_idx = next_code >> lookup_shift; - - if (next_idx > self.lookup.len or idx >= self.lookup.len) break; - if (sym.code_bits <= lookup_bits) { - // fill small lookup table - for (idx..next_idx) |j| - self.lookup[j] = sym.*; - } else { - // insert into linked table starting at root - const root = &self.lookup[idx]; - const root_next = root.next; - root.next = @intCast(pos); - sym.next = root_next; + for (1..lookup_bits + 1) |bits| { + const inc = @as(u16, 1) << @intCast(max_code_bits - bits); + for (buckets[bits][0..bucket_len[bits]]) |lookup_sym| { + const next_code = code + inc; + const next_idx = next_code >> lookup_shift; + for (idx..next_idx) |i| { + self.lookup[reverseIdx(i)] = lookup_sym; + } + code = next_code; + idx = next_idx; + } + } + for (lookup_bits + 1..buckets.len) |bits| { + const inc = @as(u16, 1) << @intCast(max_code_bits - bits); + for (buckets[bits][0..bucket_len[bits]]) |linked_sym| { + const next_code = code + inc; + const next_idx = next_code >> lookup_shift; + + const ri = reverseIdx(idx); + const next: Symbol = .{ + .value = self.lookup[ri].value, + .code_bits = linked_sym.code_bits, + }; + self.linked[linked_sym.value] = .{ + .sym = next, + .code = @bitReverse(@as(@Int(.unsigned, max_code_bits), @intCast(code))), + }; + self.lookup[ri] = .{ .value = linked_sym.value, .code_bits = 0 }; + + code = next_code; + idx = next_idx; } + } - idx = next_idx; - code = next_code; + // Invalid codes + for (idx..self.lookup.len) |i| { + self.lookup[reverseIdx(i)] = .{ .value = 0xfff, .code_bits = 0 }; } } @@ -748,23 +752,25 @@ fn HuffmanDecoder( /// Finds symbol for lookup table code. pub fn find(self: *Self, code: u16) !Symbol { // try to find in lookup table - const idx = code >> lookup_shift; + const idx = code & lookup_mask; const sym = self.lookup[idx]; if (sym.code_bits != 0) return sym; // if not use linked list of symbols with same prefix - return self.findLinked(code, sym.next); + return self.findLinked(code, sym.value); } fn findLinked(self: *Self, code: u16, start: u16) !Symbol { + if (start == 0xfff) return error.InvalidCode; + if (lookup_bits == max_code_bits) unreachable; var pos = start; - while (pos > 0) { - const sym = self.symbols[pos]; - const shift = max_code_bits - sym.code_bits; + while (true) { + const node = self.linked[pos]; + const shift = -%node.sym.code_bits; // compare code_bits number of upper bits - if ((code ^ sym.code) >> shift == 0) return sym; - pos = sym.next; + if ((code ^ node.code) << shift == 0) + return .{ .value = @intCast(pos), .code_bits = node.sym.code_bits }; + pos = node.sym.value; } - return error.InvalidCode; } }; } @@ -775,74 +781,49 @@ test "init/find" { var h: CodegenDecoder = .{}; try h.generate(&code_lens); - const expected = [_]struct { - sym: Symbol, - code: u16, - }{ - .{ - .code = 0b00_00000, - .sym = .{ .symbol = 3, .code_bits = 2 }, - }, - .{ - .code = 0b01_00000, - .sym = .{ .symbol = 18, .code_bits = 2 }, - }, - .{ - .code = 0b100_0000, - .sym = .{ .symbol = 1, .code_bits = 3 }, - }, - .{ - .code = 0b101_0000, - .sym = .{ .symbol = 4, .code_bits = 3 }, - }, - .{ - .code = 0b110_0000, - .sym = .{ .symbol = 17, .code_bits = 3 }, - }, - .{ - .code = 0b1110_000, - .sym = .{ .symbol = 0, .code_bits = 4 }, - }, - .{ - .code = 0b1111_000, - .sym = .{ .symbol = 16, .code_bits = 4 }, - }, - }; - - // unused symbols - for (0..12) |i| { - try testing.expectEqual(0, h.symbols[i].code_bits); - } - // used, from index 12 - for (expected, 12..) |e, i| { - try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol); - try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits); - const sym_from_code = try h.find(e.code); - try testing.expectEqual(e.sym.symbol, sym_from_code.symbol); - } - // All possible codes for each symbol. // Lookup table has 126 elements, to cover all possible 7 bit codes. for (0b0000_000..0b0100_000) |c| // 0..32 (32) - try testing.expectEqual(3, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 3, .code_bits = 2 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b0100_000..0b1000_000) |c| // 32..64 (32) - try testing.expectEqual(18, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 18, .code_bits = 2 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b1000_000..0b1010_000) |c| // 64..80 (16) - try testing.expectEqual(1, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 1, .code_bits = 3 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b1010_000..0b1100_000) |c| // 80..96 (16) - try testing.expectEqual(4, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 4, .code_bits = 3 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b1100_000..0b1110_000) |c| // 96..112 (16) - try testing.expectEqual(17, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 17, .code_bits = 3 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b1110_000..0b1111_000) |c| // 112..120 (8) - try testing.expectEqual(0, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 0, .code_bits = 4 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); for (0b1111_000..0b1_0000_000) |c| // 120...128 (8) - try testing.expectEqual(16, (try h.find(@intCast(c))).symbol); + try testing.expectEqual( + Symbol{ .value = 16, .code_bits = 4 }, + try h.find(@bitReverse(@as(u7, @intCast(c)))), + ); } test "encode/decode literals" { @@ -867,8 +848,8 @@ test "encode/decode literals" { if (bits == 0) continue; for (0..1 << (max_bits - bits)) |extra| { const full = (@as(u16, code) << (max_bits - bits)) | @as(u16, @intCast(extra)); - const symbol = try decoder.find(full); - try testing.expectEqual(i, symbol.symbol); + const symbol = try decoder.find(@bitReverse(@as(u5, @intCast(full)))); + try testing.expectEqual(i, symbol.value); try testing.expectEqual(bits, symbol.code_bits); } }