commit bb304796f466b4fd15a4adac7ffbb81ace95d2a4 (tree)
parent 333055ced72ae73c1ec74a66aa9436291d8b7d1d
Author: Kendall Condon <goon.pri.low@gmail.com>
Date: Wed, 18 Feb 2026 18:17:15 -0500
optimize flate decompression
Matches now use memcpy and memset when possible.
Block loops have been rewritten to be more optimizer friendly.
Reworks Symbol and HuffmanDecoder
* Symbol now only includes the value and number of code bits.
decodeSymbol returns only the value.
* HuffmanDecoder now takes the regular bits instead of the reversed.
* Code table construction now uses buckets instead of sorting.
* For linked codes, the value field of Symbol is now used as the next
index. The actual value is the element index.
* InvalidCode is now detected only once with a special linked index.
Performance is 39.7% faster than before and 1.1% faster than gzip using
a sample created from compressing a tar of the src directory.
Diffstat:
1 file changed, 170 insertions(+), 189 deletions(-)
diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig
@@ -229,11 +229,11 @@ fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usiz
}
}
-fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol {
+fn decodeSymbol(self: *Decompress, decoder: anytype) !u16 {
// Maximum code len is 15 bits.
- const sym = try decoder.find(@bitReverse(try self.peekIntBitsShort(u15)));
+ const sym = try decoder.find(try self.peekIntBitsShort(u15));
try self.tossBitsShort(sym.code_bits);
- return sym;
+ return sym.value;
}
fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
@@ -348,10 +348,10 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
var dec_lens: [286 + 30]u4 = @splat(0);
var pos: usize = 0;
while (pos < hlit + hdist) {
- const peeked = @bitReverse(try d.peekIntBitsShort(u7));
+ const peeked = try d.peekIntBitsShort(u7);
const sym = try cl_dec.find(peeked);
try d.tossBitsShort(sym.code_bits);
- pos += try d.dynamicCodeLength(sym.symbol, &dec_lens, pos);
+ pos += try d.dynamicCodeLength(sym.value, &dec_lens, pos);
}
if (pos > hlit + hdist) {
return error.InvalidDynamicBlockHeader;
@@ -383,35 +383,34 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
w.advance(n);
return @intFromEnum(limit) - remaining + n;
},
- .fixed_block => {
- while (remaining > 0) {
- const code = try d.readFixedCode();
- switch (code) {
- 0...255 => {
- if (remaining != 0) {
- @branchHint(.likely);
- try w.writeBytePreserve(flate.history_len, @intCast(code));
- remaining -= 1;
- } else {
- d.state = .{ .fixed_block_literal = @intCast(code) };
- return @intFromEnum(limit) - remaining;
- }
- },
- 256 => {
- d.state = if (d.final_block) .protocol_footer else .block_header;
- return @intFromEnum(limit) - remaining;
- },
- 257...285 => {
- // Handles fixed block non literal (length) code.
- // Length code is followed by 5 bits of distance code.
- const length = try d.decodeLength(@intCast(code - 257));
- continue :sw .{ .fixed_block_match = length };
- },
- else => return error.InvalidCode,
+ .fixed_block => while (true) {
+ // Consume bytes
+ const sym = try d.readFixedCode();
+
+ if (sym >= 256) {
+ @branchHint(.unlikely);
+
+ if (sym == 256) {
+ @branchHint(.unlikely);
+ // End
+ d.state = if (d.final_block) .protocol_footer else .block_header;
+ continue :sw d.state;
}
+
+ // Match
+ const length = try d.decodeLength(@intCast(sym - 257));
+ continue :sw .{ .fixed_block_match = length };
+ }
+
+ const byte: u8 = @intCast(sym);
+ if (remaining != 0) {
+ @branchHint(.likely);
+ remaining -= 1;
+ try w.writeBytePreserve(flate.history_len, byte);
+ } else {
+ d.state = .{ .fixed_block_literal = byte };
+ return @intFromEnum(limit) - remaining;
}
- d.state = .fixed_block;
- return @intFromEnum(limit) - remaining;
},
.fixed_block_literal => |symbol| {
assert(remaining != 0);
@@ -431,32 +430,35 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
return @intFromEnum(limit) - remaining;
}
},
- .dynamic_block => {
- // In larger archives most blocks are usually dynamic, so
- // decompression performance depends on this logic.
- var sym = try d.decodeSymbol(&d.lit_dec);
- sym: switch (sym.kind) {
- .literal => {
- if (remaining != 0) {
- @branchHint(.likely);
- remaining -= 1;
- try w.writeBytePreserve(flate.history_len, sym.symbol);
- sym = try d.decodeSymbol(&d.lit_dec);
- continue :sym sym.kind;
- } else {
- d.state = .{ .dynamic_block_literal = sym.symbol };
- return @intFromEnum(limit) - remaining;
- }
- },
- .match => {
- // Decode match backreference <length, distance>
- const length = try d.decodeLength(@intCast(sym.symbol));
- continue :sw .{ .dynamic_block_match = length };
- },
- .end_of_block => {
+ // In larger archives most blocks are usually dynamic, so
+ // decompression performance depends on this logic.
+ .dynamic_block => while (true) {
+ // Consume bytes
+ const sym = try d.decodeSymbol(&d.lit_dec);
+
+ if (sym >= 256) {
+ @branchHint(.unlikely);
+
+ if (sym == 256) {
+ @branchHint(.unlikely);
+ // End
d.state = if (d.final_block) .protocol_footer else .block_header;
continue :sw d.state;
- },
+ }
+
+ // Match
+ const length = try d.decodeLength(@intCast(sym - 257));
+ continue :sw .{ .dynamic_block_match = length };
+ }
+
+ const byte: u8 = @intCast(sym);
+ if (remaining != 0) {
+ @branchHint(.likely);
+ remaining -= 1;
+ try w.writeBytePreserve(flate.history_len, byte);
+ } else {
+ d.state = .{ .dynamic_block_literal = byte };
+ return @intFromEnum(limit) - remaining;
}
},
.dynamic_block_literal => |symbol| {
@@ -470,7 +472,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
@branchHint(.likely);
remaining -= length;
const dsm = try d.decodeSymbol(&d.dst_dec);
- const distance = try d.decodeDistance(@intCast(dsm.symbol));
+ const distance = try d.decodeDistance(@intCast(dsm));
try writeMatch(w, length, distance);
continue :sw .dynamic_block;
} else {
@@ -501,17 +503,25 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
/// back from current write position, and `length` of bytes.
fn writeMatch(w: *Writer, length: u16, distance: u16) !void {
if (w.end < distance) return error.InvalidMatch;
- if (length < token.min_length) return error.InvalidMatch;
- if (length > token.max_length) return error.InvalidMatch;
- if (distance < token.min_distance) return error.InvalidMatch;
- if (distance > token.max_distance) return error.InvalidMatch;
+ assert(length >= token.min_length);
+ assert(length <= token.max_length);
+ assert(distance >= token.min_distance);
+ assert(distance <= token.max_distance);
// This is not a @memmove; it intentionally repeats patterns caused by
// iterating one byte at a time.
const dest = try w.writableSlicePreserve(flate.history_len, length);
const end = dest.ptr - w.buffer.ptr;
const src = w.buffer[end - distance ..][0..length];
- for (dest, src) |*d, s| d.* = s;
+ if (distance >= length) {
+ @memcpy(dest, src);
+ } else if (distance == 1) {
+ // Repeating copy of single byte
+ @memset(dest, src[0]);
+ } else {
+ // Repeating copy of multiple bytes
+ for (dest, src) |*d, s| d.* = s;
+ }
}
fn peekBits(d: *Decompress, n: u4) !u16 {
@@ -603,31 +613,9 @@ fn readFixedCode(d: *Decompress) !u16 {
};
}
-pub const Symbol = packed struct {
- pub const Kind = enum(u2) {
- literal,
- end_of_block,
- match,
- };
-
- symbol: u8 = 0, // symbol from alphabet
+pub const Symbol = packed struct(u16) {
+ value: u12 = 0,
code_bits: u4 = 0, // number of bits in code 0-15
- kind: Kind = .literal,
-
- code: u16 = 0, // huffman code of the symbol
- next: u16 = 0, // pointer to the next symbol in linked list
- // it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup
-
- // Sorting less than function.
- pub fn asc(_: void, a: Symbol, b: Symbol) bool {
- if (a.code_bits == b.code_bits) {
- if (a.kind == b.kind) {
- return a.symbol < b.symbol;
- }
- return @intFromEnum(a.kind) < @intFromEnum(b.kind);
- }
- return a.code_bits < b.code_bits;
- }
};
pub const LiteralDecoder = HuffmanDecoder(286, 15, 9);
@@ -646,69 +634,85 @@ pub const CodegenDecoder = HuffmanDecoder(19, 7, 7);
/// Small lookup table is optimization for faster search.
/// It is variation of the algorithm explained in [zlib](https://github.com/madler/zlib/blob/643e17b7498d12ab8d15565662880579692f769d/doc/algorithm.txt#L92)
/// with difference that we here use statically allocated arrays.
-///
fn HuffmanDecoder(
comptime alphabet_size: u16,
comptime max_code_bits: u4,
comptime lookup_bits: u4,
) type {
const lookup_shift = max_code_bits - lookup_bits;
+ const lookup_mask = (1 << lookup_bits) - 1;
return struct {
- // all symbols in alaphabet, sorted by code_len, symbol
- symbols: [alphabet_size]Symbol = undefined,
// lookup table code -> symbol
+ // for values with code_bits == 0, symbol is the index of the first node in linked
+ // if the index of the first node is 0xfff, it is an invalid code
lookup: [1 << lookup_bits]Symbol = undefined,
+ linked: if (lookup_bits == max_code_bits) void else [alphabet_size]struct {
+ // sym.value is the next index in linked where the current index ends the chain
+ // the actual symbol is this nodes's index
+ sym: Symbol,
+ code: u16,
+ } = undefined,
const Self = @This();
+ fn reverseIdx(idx: usize) u16 {
+ return @bitReverse(@as(@Int(.unsigned, lookup_bits), @intCast(idx)));
+ }
+
/// Generates symbols and lookup tables from list of code lens for each symbol.
pub fn generate(self: *Self, lens: []const u4) !void {
try checkCompleteness(lens);
- // init alphabet with code_bits
- for (self.symbols, 0..) |_, i| {
- const cb: u4 = if (i < lens.len) lens[i] else 0;
- self.symbols[i] = if (i < 256)
- .{ .kind = .literal, .symbol = @intCast(i), .code_bits = cb }
- else if (i == 256)
- .{ .kind = .end_of_block, .symbol = 0xff, .code_bits = cb }
- else
- .{ .kind = .match, .symbol = @intCast(i - 257), .code_bits = cb };
- }
- std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc);
-
- // reset lookup table
- for (0..self.lookup.len) |i| {
- self.lookup[i] = .{};
+ var buckets: [1 + @as(usize, max_code_bits)][alphabet_size]Symbol = undefined;
+ var bucket_len: [buckets.len]u16 = @splat(0);
+ for (0.., lens) |symbol, bits| {
+ buckets[bits][bucket_len[bits]] = .{
+ .value = @intCast(symbol),
+ .code_bits = bits,
+ };
+ bucket_len[bits] += 1;
}
- // assign code to symbols
- // reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639
var code: u16 = 0;
var idx: u16 = 0;
- for (&self.symbols, 0..) |*sym, pos| {
- if (sym.code_bits == 0) continue; // skip unused
- sym.code = code;
-
- const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits));
- const next_idx = next_code >> lookup_shift;
-
- if (next_idx > self.lookup.len or idx >= self.lookup.len) break;
- if (sym.code_bits <= lookup_bits) {
- // fill small lookup table
- for (idx..next_idx) |j|
- self.lookup[j] = sym.*;
- } else {
- // insert into linked table starting at root
- const root = &self.lookup[idx];
- const root_next = root.next;
- root.next = @intCast(pos);
- sym.next = root_next;
+ for (1..lookup_bits + 1) |bits| {
+ const inc = @as(u16, 1) << @intCast(max_code_bits - bits);
+ for (buckets[bits][0..bucket_len[bits]]) |lookup_sym| {
+ const next_code = code + inc;
+ const next_idx = next_code >> lookup_shift;
+ for (idx..next_idx) |i| {
+ self.lookup[reverseIdx(i)] = lookup_sym;
+ }
+ code = next_code;
+ idx = next_idx;
+ }
+ }
+ for (lookup_bits + 1..buckets.len) |bits| {
+ const inc = @as(u16, 1) << @intCast(max_code_bits - bits);
+ for (buckets[bits][0..bucket_len[bits]]) |linked_sym| {
+ const next_code = code + inc;
+ const next_idx = next_code >> lookup_shift;
+
+ const ri = reverseIdx(idx);
+ const next: Symbol = .{
+ .value = self.lookup[ri].value,
+ .code_bits = linked_sym.code_bits,
+ };
+ self.linked[linked_sym.value] = .{
+ .sym = next,
+ .code = @bitReverse(@as(@Int(.unsigned, max_code_bits), @intCast(code))),
+ };
+ self.lookup[ri] = .{ .value = linked_sym.value, .code_bits = 0 };
+
+ code = next_code;
+ idx = next_idx;
}
+ }
- idx = next_idx;
- code = next_code;
+ // Invalid codes
+ for (idx..self.lookup.len) |i| {
+ self.lookup[reverseIdx(i)] = .{ .value = 0xfff, .code_bits = 0 };
}
}
@@ -748,23 +752,25 @@ fn HuffmanDecoder(
/// Finds symbol for lookup table code.
pub fn find(self: *Self, code: u16) !Symbol {
// try to find in lookup table
- const idx = code >> lookup_shift;
+ const idx = code & lookup_mask;
const sym = self.lookup[idx];
if (sym.code_bits != 0) return sym;
// if not use linked list of symbols with same prefix
- return self.findLinked(code, sym.next);
+ return self.findLinked(code, sym.value);
}
fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
+ if (start == 0xfff) return error.InvalidCode;
+ if (lookup_bits == max_code_bits) unreachable;
var pos = start;
- while (pos > 0) {
- const sym = self.symbols[pos];
- const shift = max_code_bits - sym.code_bits;
+ while (true) {
+ const node = self.linked[pos];
+ const shift = -%node.sym.code_bits;
// compare code_bits number of upper bits
- if ((code ^ sym.code) >> shift == 0) return sym;
- pos = sym.next;
+ if ((code ^ node.code) << shift == 0)
+ return .{ .value = @intCast(pos), .code_bits = node.sym.code_bits };
+ pos = node.sym.value;
}
- return error.InvalidCode;
}
};
}
@@ -775,74 +781,49 @@ test "init/find" {
var h: CodegenDecoder = .{};
try h.generate(&code_lens);
- const expected = [_]struct {
- sym: Symbol,
- code: u16,
- }{
- .{
- .code = 0b00_00000,
- .sym = .{ .symbol = 3, .code_bits = 2 },
- },
- .{
- .code = 0b01_00000,
- .sym = .{ .symbol = 18, .code_bits = 2 },
- },
- .{
- .code = 0b100_0000,
- .sym = .{ .symbol = 1, .code_bits = 3 },
- },
- .{
- .code = 0b101_0000,
- .sym = .{ .symbol = 4, .code_bits = 3 },
- },
- .{
- .code = 0b110_0000,
- .sym = .{ .symbol = 17, .code_bits = 3 },
- },
- .{
- .code = 0b1110_000,
- .sym = .{ .symbol = 0, .code_bits = 4 },
- },
- .{
- .code = 0b1111_000,
- .sym = .{ .symbol = 16, .code_bits = 4 },
- },
- };
-
- // unused symbols
- for (0..12) |i| {
- try testing.expectEqual(0, h.symbols[i].code_bits);
- }
- // used, from index 12
- for (expected, 12..) |e, i| {
- try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol);
- try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits);
- const sym_from_code = try h.find(e.code);
- try testing.expectEqual(e.sym.symbol, sym_from_code.symbol);
- }
-
// All possible codes for each symbol.
// Lookup table has 126 elements, to cover all possible 7 bit codes.
for (0b0000_000..0b0100_000) |c| // 0..32 (32)
- try testing.expectEqual(3, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 3, .code_bits = 2 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b0100_000..0b1000_000) |c| // 32..64 (32)
- try testing.expectEqual(18, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 18, .code_bits = 2 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b1000_000..0b1010_000) |c| // 64..80 (16)
- try testing.expectEqual(1, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 1, .code_bits = 3 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b1010_000..0b1100_000) |c| // 80..96 (16)
- try testing.expectEqual(4, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 4, .code_bits = 3 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b1100_000..0b1110_000) |c| // 96..112 (16)
- try testing.expectEqual(17, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 17, .code_bits = 3 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b1110_000..0b1111_000) |c| // 112..120 (8)
- try testing.expectEqual(0, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 0, .code_bits = 4 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
for (0b1111_000..0b1_0000_000) |c| // 120...128 (8)
- try testing.expectEqual(16, (try h.find(@intCast(c))).symbol);
+ try testing.expectEqual(
+ Symbol{ .value = 16, .code_bits = 4 },
+ try h.find(@bitReverse(@as(u7, @intCast(c)))),
+ );
}
test "encode/decode literals" {
@@ -867,8 +848,8 @@ test "encode/decode literals" {
if (bits == 0) continue;
for (0..1 << (max_bits - bits)) |extra| {
const full = (@as(u16, code) << (max_bits - bits)) | @as(u16, @intCast(extra));
- const symbol = try decoder.find(full);
- try testing.expectEqual(i, symbol.symbol);
+ const symbol = try decoder.find(@bitReverse(@as(u5, @intCast(full))));
+ try testing.expectEqual(i, symbol.value);
try testing.expectEqual(bits, symbol.code_bits);
}
}