std.compress.zstd.Decompress fixes

* std.Io.Reader: appendRemaining no longer supports alignment and has
  different rules about how exceeding the limit is handled. Fixed a bug
  where it would return success instead of error.StreamTooLong as it was
  supposed to.

* std.Io.Reader: simplify appendRemaining and appendRemainingUnlimited
  to be implemented based on std.Io.Writer.Allocating

* std.Io.Writer: introduce unreachableRebase

* std.Io.Writer: remove minimum_unused_capacity from Allocating. maybe
  that flexibility could have been handy, but let's see if anyone
  actually needs it. The field is redundant with the superlinear growth
  of ArrayList capacity.

* std.Io.Writer: growingRebase also ensures total capacity on the
  preserve parameter, making it no longer necessary to do
  ensureTotalCapacity at the usage site of decompression streams.

* std.compress.flate.Decompress: fix rebase not taking into account seek

* std.compress.zstd.Decompress: split into "direct" and "indirect" usage
  patterns depending on whether a buffer is provided to init, matching
  how flate works. Remove some overzealous asserts that prevented buffer
  expansion from within rebase implementation.

* std.zig: fix readSourceFileToAlloc returning an overaligned slice
  which was difficult to free correctly.

fixes #24608
This commit is contained in:
Andrew Kelley
2025-08-14 20:34:44 -07:00
parent 6d7c6a0f4e
commit 30b41dc510
10 changed files with 168 additions and 154 deletions

View File

@@ -8,7 +8,7 @@ const Writer = std.io.Writer;
const assert = std.debug.assert; const assert = std.debug.assert;
const testing = std.testing; const testing = std.testing;
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
const ArrayList = std.ArrayListUnmanaged; const ArrayList = std.ArrayList;
const Limit = std.io.Limit; const Limit = std.io.Limit;
pub const Limited = @import("Reader/Limited.zig"); pub const Limited = @import("Reader/Limited.zig");
@@ -290,103 +290,63 @@ pub const LimitedAllocError = Allocator.Error || ShortError || error{StreamTooLo
pub fn allocRemaining(r: *Reader, gpa: Allocator, limit: Limit) LimitedAllocError![]u8 { pub fn allocRemaining(r: *Reader, gpa: Allocator, limit: Limit) LimitedAllocError![]u8 {
var buffer: ArrayList(u8) = .empty; var buffer: ArrayList(u8) = .empty;
defer buffer.deinit(gpa); defer buffer.deinit(gpa);
try appendRemaining(r, gpa, null, &buffer, limit); try appendRemaining(r, gpa, &buffer, limit);
return buffer.toOwnedSlice(gpa); return buffer.toOwnedSlice(gpa);
} }
/// Transfers all bytes from the current position to the end of the stream, up /// Transfers all bytes from the current position to the end of the stream, up
/// to `limit`, appending them to `list`. /// to `limit`, appending them to `list`.
/// ///
/// If `limit` would be exceeded, `error.StreamTooLong` is returned instead. In /// If `limit` is reached or exceeded, `error.StreamTooLong` is returned
/// such case, the next byte that would be read will be the first one to exceed /// instead. In such case, the next byte that would be read will be the first
/// `limit`, and all preceeding bytes have been appended to `list`. /// one to exceed `limit`, and all preceeding bytes have been appended to
/// /// `list`.
/// If `limit` is not `Limit.unlimited`, asserts `buffer` has nonzero capacity.
/// ///
/// See also: /// See also:
/// * `allocRemaining` /// * `allocRemaining`
pub fn appendRemaining( pub fn appendRemaining(
r: *Reader, r: *Reader,
gpa: Allocator, gpa: Allocator,
comptime alignment: ?std.mem.Alignment, list: *ArrayList(u8),
list: *std.ArrayListAlignedUnmanaged(u8, alignment),
limit: Limit, limit: Limit,
) LimitedAllocError!void { ) LimitedAllocError!void {
if (limit == .unlimited) return appendRemainingUnlimited(r, gpa, alignment, list, 1); var a: std.Io.Writer.Allocating = .initOwnedSlice(gpa, list.items);
assert(r.buffer.len != 0); // Needed to detect limit exceeded without losing data. a.writer.end = list.items.len;
const buffer_contents = r.buffer[r.seek..r.end]; list.* = .empty;
const copy_len = limit.minInt(buffer_contents.len); defer {
try list.appendSlice(gpa, r.buffer[0..copy_len]); list.* = .{
r.seek += copy_len; .items = a.writer.buffer[0..a.writer.end],
if (buffer_contents.len - copy_len != 0) return error.StreamTooLong; .capacity = a.writer.buffer.len,
r.seek = 0; };
r.end = 0; }
var remaining = @intFromEnum(limit) - copy_len; var remaining = limit;
// From here, we leave `buffer` empty, appending directly to `list`. while (remaining.nonzero()) {
var writer: Writer = .{ const n = stream(r, &a.writer, remaining) catch |err| switch (err) {
.buffer = undefined,
.end = undefined,
.vtable = &.{ .drain = Writer.fixedDrain },
};
while (true) {
try list.ensureUnusedCapacity(gpa, 2);
const cap = list.unusedCapacitySlice();
const dest = cap[0..@min(cap.len, remaining + 1)];
writer.buffer = list.allocatedSlice();
writer.end = list.items.len;
const n = r.vtable.stream(r, &writer, .limited(dest.len)) catch |err| switch (err) {
error.WriteFailed => unreachable, // Prevented by the limit.
error.EndOfStream => return, error.EndOfStream => return,
error.WriteFailed => return error.OutOfMemory,
error.ReadFailed => return error.ReadFailed, error.ReadFailed => return error.ReadFailed,
}; };
list.items.len += n; remaining = remaining.subtract(n).?;
if (n > remaining) {
// Move the byte to `Reader.buffer` so it is not lost.
assert(n - remaining == 1);
assert(r.end == 0);
r.buffer[0] = list.items[list.items.len - 1];
list.items.len -= 1;
r.end = 1;
return;
}
remaining -= n;
} }
return error.StreamTooLong;
} }
pub const UnlimitedAllocError = Allocator.Error || ShortError; pub const UnlimitedAllocError = Allocator.Error || ShortError;
pub fn appendRemainingUnlimited( pub fn appendRemainingUnlimited(r: *Reader, gpa: Allocator, list: *ArrayList(u8)) UnlimitedAllocError!void {
r: *Reader, var a: std.Io.Writer.Allocating = .initOwnedSlice(gpa, list.items);
gpa: Allocator, a.writer.end = list.items.len;
comptime alignment: ?std.mem.Alignment, list.* = .empty;
list: *std.ArrayListAlignedUnmanaged(u8, alignment), defer {
bump: usize, list.* = .{
) UnlimitedAllocError!void { .items = a.writer.buffer[0..a.writer.end],
const buffer_contents = r.buffer[r.seek..r.end]; .capacity = a.writer.buffer.len,
try list.ensureUnusedCapacity(gpa, buffer_contents.len + bump);
list.appendSliceAssumeCapacity(buffer_contents);
// If statement protects `ending`.
if (r.end != 0) {
r.seek = 0;
r.end = 0;
}
// From here, we leave `buffer` empty, appending directly to `list`.
var writer: Writer = .{
.buffer = undefined,
.end = undefined,
.vtable = &.{ .drain = Writer.fixedDrain },
};
while (true) {
try list.ensureUnusedCapacity(gpa, bump);
writer.buffer = list.allocatedSlice();
writer.end = list.items.len;
const n = r.vtable.stream(r, &writer, .limited(list.unusedCapacitySlice().len)) catch |err| switch (err) {
error.WriteFailed => unreachable, // Prevented by the limit.
error.EndOfStream => return,
error.ReadFailed => return error.ReadFailed,
}; };
list.items.len += n;
} }
_ = streamRemaining(r, &a.writer) catch |err| switch (err) {
error.WriteFailed => return error.OutOfMemory,
error.ReadFailed => return error.ReadFailed,
};
} }
/// Writes bytes from the internally tracked stream position to `data`. /// Writes bytes from the internally tracked stream position to `data`.
@@ -1295,7 +1255,10 @@ fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Resu
/// Ensures `capacity` more data can be buffered without rebasing. /// Ensures `capacity` more data can be buffered without rebasing.
pub fn rebase(r: *Reader, capacity: usize) RebaseError!void { pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
if (r.end + capacity <= r.buffer.len) return; if (r.end + capacity <= r.buffer.len) {
@branchHint(.likely);
return;
}
return r.vtable.rebase(r, capacity); return r.vtable.rebase(r, capacity);
} }

View File

@@ -329,7 +329,7 @@ pub fn rebase(w: *Writer, preserve: usize, unused_capacity_len: usize) Error!voi
@branchHint(.likely); @branchHint(.likely);
return; return;
} }
try w.vtable.rebase(w, preserve, unused_capacity_len); return w.vtable.rebase(w, preserve, unused_capacity_len);
} }
pub fn defaultRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void { pub fn defaultRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void {
@@ -2349,6 +2349,13 @@ pub fn unreachableDrain(w: *Writer, data: []const []const u8, splat: usize) Erro
unreachable; unreachable;
} }
pub fn unreachableRebase(w: *Writer, preserve: usize, capacity: usize) Error!void {
_ = w;
_ = preserve;
_ = capacity;
unreachable;
}
/// Provides a `Writer` implementation based on calling `Hasher.update`, sending /// Provides a `Writer` implementation based on calling `Hasher.update`, sending
/// all data also to an underlying `Writer`. /// all data also to an underlying `Writer`.
/// ///
@@ -2489,10 +2496,6 @@ pub fn Hashing(comptime Hasher: type) type {
pub const Allocating = struct { pub const Allocating = struct {
allocator: Allocator, allocator: Allocator,
writer: Writer, writer: Writer,
/// Every call to `drain` ensures at least this amount of unused capacity
/// before it returns. This prevents an infinite loop in interface logic
/// that calls `drain`.
minimum_unused_capacity: usize = 1,
pub fn init(allocator: Allocator) Allocating { pub fn init(allocator: Allocator) Allocating {
return .{ return .{
@@ -2604,13 +2607,12 @@ pub const Allocating = struct {
const gpa = a.allocator; const gpa = a.allocator;
const pattern = data[data.len - 1]; const pattern = data[data.len - 1];
const splat_len = pattern.len * splat; const splat_len = pattern.len * splat;
const bump = a.minimum_unused_capacity;
var list = a.toArrayList(); var list = a.toArrayList();
defer setArrayList(a, list); defer setArrayList(a, list);
const start_len = list.items.len; const start_len = list.items.len;
assert(data.len != 0); assert(data.len != 0);
for (data) |bytes| { for (data) |bytes| {
list.ensureUnusedCapacity(gpa, bytes.len + splat_len + bump) catch return error.WriteFailed; list.ensureUnusedCapacity(gpa, bytes.len + splat_len + 1) catch return error.WriteFailed;
list.appendSliceAssumeCapacity(bytes); list.appendSliceAssumeCapacity(bytes);
} }
if (splat == 0) { if (splat == 0) {
@@ -2641,11 +2643,12 @@ pub const Allocating = struct {
} }
fn growingRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void { fn growingRebase(w: *Writer, preserve: usize, minimum_len: usize) Error!void {
_ = preserve; // This implementation always preserves the entire buffer.
const a: *Allocating = @fieldParentPtr("writer", w); const a: *Allocating = @fieldParentPtr("writer", w);
const gpa = a.allocator; const gpa = a.allocator;
var list = a.toArrayList(); var list = a.toArrayList();
defer setArrayList(a, list); defer setArrayList(a, list);
const total = std.math.add(usize, preserve, minimum_len) catch return error.WriteFailed;
list.ensureTotalCapacity(gpa, total) catch return error.WriteFailed;
list.ensureUnusedCapacity(gpa, minimum_len) catch return error.WriteFailed; list.ensureUnusedCapacity(gpa, minimum_len) catch return error.WriteFailed;
} }

View File

@@ -1033,7 +1033,7 @@ pub fn Aligned(comptime T: type, comptime alignment: ?mem.Alignment) type {
pub fn print(self: *Self, gpa: Allocator, comptime fmt: []const u8, args: anytype) error{OutOfMemory}!void { pub fn print(self: *Self, gpa: Allocator, comptime fmt: []const u8, args: anytype) error{OutOfMemory}!void {
comptime assert(T == u8); comptime assert(T == u8);
try self.ensureUnusedCapacity(gpa, fmt.len); try self.ensureUnusedCapacity(gpa, fmt.len);
var aw: std.io.Writer.Allocating = .fromArrayList(gpa, self); var aw: std.Io.Writer.Allocating = .fromArrayList(gpa, self);
defer self.* = aw.toArrayList(); defer self.* = aw.toArrayList();
return aw.writer.print(fmt, args) catch |err| switch (err) { return aw.writer.print(fmt, args) catch |err| switch (err) {
error.WriteFailed => return error.OutOfMemory, error.WriteFailed => return error.OutOfMemory,

View File

@@ -62,7 +62,7 @@ pub const Error = Container.Error || error{
const direct_vtable: Reader.VTable = .{ const direct_vtable: Reader.VTable = .{
.stream = streamDirect, .stream = streamDirect,
.rebase = rebaseFallible, .rebase = rebaseFallible,
.discard = discard, .discard = discardDirect,
.readVec = readVec, .readVec = readVec,
}; };
@@ -105,17 +105,16 @@ fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void {
fn rebase(r: *Reader, capacity: usize) void { fn rebase(r: *Reader, capacity: usize) void {
assert(capacity <= r.buffer.len - flate.history_len); assert(capacity <= r.buffer.len - flate.history_len);
assert(r.end + capacity > r.buffer.len); assert(r.end + capacity > r.buffer.len);
const discard_n = r.end - flate.history_len; const discard_n = @min(r.seek, r.end - flate.history_len);
const keep = r.buffer[discard_n..r.end]; const keep = r.buffer[discard_n..r.end];
@memmove(r.buffer[0..keep.len], keep); @memmove(r.buffer[0..keep.len], keep);
assert(keep.len != 0);
r.end = keep.len; r.end = keep.len;
r.seek -= discard_n; r.seek -= discard_n;
} }
/// This could be improved so that when an amount is discarded that includes an /// This could be improved so that when an amount is discarded that includes an
/// entire frame, skip decoding that frame. /// entire frame, skip decoding that frame.
fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len); if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len);
var writer: Writer = .{ var writer: Writer = .{
.vtable = &.{ .vtable = &.{
@@ -167,11 +166,14 @@ fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
fn streamIndirectInner(d: *Decompress) Reader.Error!usize { fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
const r = &d.reader; const r = &d.reader;
if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len); if (r.buffer.len - r.end < flate.history_len) rebase(r, flate.history_len);
var writer: Writer = .{ var writer: Writer = .{
.buffer = r.buffer, .buffer = r.buffer,
.end = r.end, .end = r.end,
.vtable = &.{ .drain = Writer.unreachableDrain }, .vtable = &.{
.drain = Writer.unreachableDrain,
.rebase = Writer.unreachableRebase,
},
}; };
defer r.end = writer.end; defer r.end = writer.end;
_ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) { _ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
@@ -1251,8 +1253,6 @@ test "zlib should not overshoot" {
fn testFailure(container: Container, in: []const u8, expected_err: anyerror) !void { fn testFailure(container: Container, in: []const u8, expected_err: anyerror) !void {
var reader: Reader = .fixed(in); var reader: Reader = .fixed(in);
var aw: Writer.Allocating = .init(testing.allocator); var aw: Writer.Allocating = .init(testing.allocator);
aw.minimum_unused_capacity = flate.history_len;
try aw.ensureUnusedCapacity(flate.max_window_len);
defer aw.deinit(); defer aw.deinit();
var decompress: Decompress = .init(&reader, container, &.{}); var decompress: Decompress = .init(&reader, container, &.{});
@@ -1263,8 +1263,6 @@ fn testFailure(container: Container, in: []const u8, expected_err: anyerror) !vo
fn testDecompress(container: Container, compressed: []const u8, expected_plain: []const u8) !void { fn testDecompress(container: Container, compressed: []const u8, expected_plain: []const u8) !void {
var in: std.Io.Reader = .fixed(compressed); var in: std.Io.Reader = .fixed(compressed);
var aw: std.Io.Writer.Allocating = .init(testing.allocator); var aw: std.Io.Writer.Allocating = .init(testing.allocator);
aw.minimum_unused_capacity = flate.history_len;
try aw.ensureUnusedCapacity(flate.max_window_len);
defer aw.deinit(); defer aw.deinit();
var decompress: Decompress = .init(&in, container, &.{}); var decompress: Decompress = .init(&in, container, &.{});

View File

@@ -78,15 +78,14 @@ pub const table_size_max = struct {
}; };
fn testDecompress(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 { fn testDecompress(gpa: std.mem.Allocator, compressed: []const u8) ![]u8 {
var out: std.ArrayListUnmanaged(u8) = .empty; var out: std.Io.Writer.Allocating = .init(gpa);
defer out.deinit(gpa); defer out.deinit();
try out.ensureUnusedCapacity(gpa, default_window_len);
var in: std.io.Reader = .fixed(compressed); var in: std.Io.Reader = .fixed(compressed);
var zstd_stream: Decompress = .init(&in, &.{}, .{}); var zstd_stream: Decompress = .init(&in, &.{}, .{});
try zstd_stream.reader.appendRemaining(gpa, null, &out, .unlimited); _ = try zstd_stream.reader.streamRemaining(&out.writer);
return out.toOwnedSlice(gpa); return out.toOwnedSlice();
} }
fn testExpectDecompress(uncompressed: []const u8, compressed: []const u8) !void { fn testExpectDecompress(uncompressed: []const u8, compressed: []const u8) !void {
@@ -99,15 +98,14 @@ fn testExpectDecompress(uncompressed: []const u8, compressed: []const u8) !void
fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void { fn testExpectDecompressError(err: anyerror, compressed: []const u8) !void {
const gpa = std.testing.allocator; const gpa = std.testing.allocator;
var out: std.ArrayListUnmanaged(u8) = .empty; var out: std.Io.Writer.Allocating = .init(gpa);
defer out.deinit(gpa); defer out.deinit();
try out.ensureUnusedCapacity(gpa, default_window_len);
var in: std.io.Reader = .fixed(compressed); var in: std.Io.Reader = .fixed(compressed);
var zstd_stream: Decompress = .init(&in, &.{}, .{}); var zstd_stream: Decompress = .init(&in, &.{}, .{});
try std.testing.expectError( try std.testing.expectError(
error.ReadFailed, error.ReadFailed,
zstd_stream.reader.appendRemaining(gpa, null, &out, .unlimited), zstd_stream.reader.streamRemaining(&out.writer),
); );
try std.testing.expectError(err, zstd_stream.err orelse {}); try std.testing.expectError(err, zstd_stream.err orelse {});
} }

View File

@@ -73,6 +73,20 @@ pub const Error = error{
WindowSizeUnknown, WindowSizeUnknown,
}; };
const direct_vtable: Reader.VTable = .{
.stream = streamDirect,
.rebase = rebaseFallible,
.discard = discardDirect,
.readVec = readVec,
};
const indirect_vtable: Reader.VTable = .{
.stream = streamIndirect,
.rebase = rebaseFallible,
.discard = discardIndirect,
.readVec = readVec,
};
/// When connecting `reader` to a `Writer`, `buffer` should be empty, and /// When connecting `reader` to a `Writer`, `buffer` should be empty, and
/// `Writer.buffer` capacity has requirements based on `Options.window_len`. /// `Writer.buffer` capacity has requirements based on `Options.window_len`.
/// ///
@@ -84,12 +98,7 @@ pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress {
.verify_checksum = options.verify_checksum, .verify_checksum = options.verify_checksum,
.window_len = options.window_len, .window_len = options.window_len,
.reader = .{ .reader = .{
.vtable = &.{ .vtable = if (buffer.len == 0) &direct_vtable else &indirect_vtable,
.stream = stream,
.rebase = rebase,
.discard = discard,
.readVec = readVec,
},
.buffer = buffer, .buffer = buffer,
.seek = 0, .seek = 0,
.end = 0, .end = 0,
@@ -97,11 +106,27 @@ pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress {
}; };
} }
fn rebase(r: *Reader, capacity: usize) Reader.RebaseError!void { fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
return stream(d, w, limit);
}
fn streamIndirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
_ = limit;
_ = w;
return streamIndirectInner(d);
}
fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void {
rebase(r, capacity);
}
fn rebase(r: *Reader, capacity: usize) void {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
assert(capacity <= r.buffer.len - d.window_len); assert(capacity <= r.buffer.len - d.window_len);
assert(r.end + capacity > r.buffer.len); assert(r.end + capacity > r.buffer.len);
const discard_n = r.end - d.window_len; const discard_n = @min(r.seek, r.end - d.window_len);
const keep = r.buffer[discard_n..r.end]; const keep = r.buffer[discard_n..r.end];
@memmove(r.buffer[0..keep.len], keep); @memmove(r.buffer[0..keep.len], keep);
r.end = keep.len; r.end = keep.len;
@@ -110,9 +135,9 @@ fn rebase(r: *Reader, capacity: usize) Reader.RebaseError!void {
/// This could be improved so that when an amount is discarded that includes an /// This could be improved so that when an amount is discarded that includes an
/// entire frame, skip decoding that frame. /// entire frame, skip decoding that frame.
fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize { fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
r.rebase(d.window_len) catch unreachable; rebase(r, d.window_len);
var writer: Writer = .{ var writer: Writer = .{
.vtable = &.{ .vtable = &.{
.drain = std.Io.Writer.Discarding.drain, .drain = std.Io.Writer.Discarding.drain,
@@ -134,25 +159,53 @@ fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
return n; return n;
} }
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize { fn discardIndirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
_ = data;
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r)); const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
assert(r.seek == r.end); rebase(r, d.window_len);
r.rebase(d.window_len) catch unreachable;
var writer: Writer = .{ var writer: Writer = .{
.buffer = r.buffer, .buffer = r.buffer,
.end = r.end, .end = r.end,
.vtable = &.{ .drain = Writer.fixedDrain }, .vtable = &.{ .drain = Writer.unreachableDrain },
}; };
r.end += r.vtable.stream(r, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) { {
defer r.end = writer.end;
_ = stream(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
error.WriteFailed => unreachable,
else => |e| return e,
};
}
const n = limit.minInt(r.end - r.seek);
r.seek += n;
return n;
}
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
_ = data;
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
return streamIndirectInner(d);
}
fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
const r = &d.reader;
if (r.buffer.len - r.end < zstd.block_size_max) rebase(r, zstd.block_size_max);
assert(r.buffer.len - r.end >= zstd.block_size_max);
var writer: Writer = .{
.buffer = r.buffer,
.end = r.end,
.vtable = &.{
.drain = Writer.unreachableDrain,
.rebase = Writer.unreachableRebase,
},
};
defer r.end = writer.end;
_ = stream(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
error.WriteFailed => unreachable, error.WriteFailed => unreachable,
else => |e| return e, else => |e| return e,
}; };
return 0; return 0;
} }
fn stream(r: *Reader, w: *Writer, limit: Limit) Reader.StreamError!usize { fn stream(d: *Decompress, w: *Writer, limit: Limit) Reader.StreamError!usize {
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
const in = d.input; const in = d.input;
state: switch (d.state) { state: switch (d.state) {
@@ -170,7 +223,7 @@ fn stream(r: *Reader, w: *Writer, limit: Limit) Reader.StreamError!usize {
else => |e| return e, else => |e| return e,
}; };
const magic = try in.takeEnumNonexhaustive(Frame.Magic, .little); const magic = try in.takeEnumNonexhaustive(Frame.Magic, .little);
initFrame(d, w.buffer.len, magic) catch |err| { initFrame(d, magic) catch |err| {
d.err = err; d.err = err;
return error.ReadFailed; return error.ReadFailed;
}; };
@@ -198,13 +251,13 @@ fn stream(r: *Reader, w: *Writer, limit: Limit) Reader.StreamError!usize {
} }
} }
fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void { fn initFrame(d: *Decompress, magic: Frame.Magic) !void {
const in = d.input; const in = d.input;
switch (magic.kind() orelse return error.BadMagic) { switch (magic.kind() orelse return error.BadMagic) {
.zstandard => { .zstandard => {
const header = try Frame.Zstandard.Header.decode(in); const header = try Frame.Zstandard.Header.decode(in);
d.state = .{ .in_frame = .{ d.state = .{ .in_frame = .{
.frame = try Frame.init(header, window_size_max, d.verify_checksum), .frame = try Frame.init(header, d.window_len, d.verify_checksum),
.checksum = null, .checksum = null,
.decompressed_size = 0, .decompressed_size = 0,
.decode = .init, .decode = .init,
@@ -258,7 +311,6 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
try decode.readInitialFseState(&bit_stream); try decode.readInitialFseState(&bit_stream);
// Ensures the following calls to `decodeSequence` will not flush. // Ensures the following calls to `decodeSequence` will not flush.
if (window_len + frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize;
const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max]; const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max];
const write_pos = dest.ptr - w.buffer.ptr; const write_pos = dest.ptr - w.buffer.ptr;
for (0..sequences_header.sequence_count - 1) |_| { for (0..sequences_header.sequence_count - 1) |_| {
@@ -775,7 +827,6 @@ pub const Frame = struct {
try w.splatByteAll(d.literal_streams.one[0], len); try w.splatByteAll(d.literal_streams.one[0], len);
}, },
.compressed, .treeless => { .compressed, .treeless => {
if (len > w.buffer.len) return error.OutputBufferUndersize;
const buf = try w.writableSlice(len); const buf = try w.writableSlice(len);
const huffman_tree = d.huffman_tree.?; const huffman_tree = d.huffman_tree.?;
const max_bit_count = huffman_tree.max_bit_count; const max_bit_count = huffman_tree.max_bit_count;

View File

@@ -2247,7 +2247,7 @@ pub const ElfModule = struct {
var decompress: std.compress.flate.Decompress = .init(&section_reader, .zlib, &.{}); var decompress: std.compress.flate.Decompress = .init(&section_reader, .zlib, &.{});
var decompressed_section: ArrayList(u8) = .empty; var decompressed_section: ArrayList(u8) = .empty;
defer decompressed_section.deinit(gpa); defer decompressed_section.deinit(gpa);
decompress.reader.appendRemainingUnlimited(gpa, null, &decompressed_section, std.compress.flate.history_len) catch { decompress.reader.appendRemainingUnlimited(gpa, &decompressed_section) catch {
invalidDebugInfoDetected(); invalidDebugInfoDetected();
continue; continue;
}; };

View File

@@ -149,9 +149,8 @@ test "HTTP server handles a chunked transfer coding request" {
"content-type: text/plain\r\n" ++ "content-type: text/plain\r\n" ++
"\r\n" ++ "\r\n" ++
"message from server!\n"; "message from server!\n";
var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&.{});
var stream_reader = stream.reader(&tiny_buffer); const response = try stream_reader.interface().allocRemaining(gpa, .limited(expected_response.len + 1));
const response = try stream_reader.interface().allocRemaining(gpa, .limited(expected_response.len));
defer gpa.free(response); defer gpa.free(response);
try expectEqualStrings(expected_response, response); try expectEqualStrings(expected_response, response);
} }
@@ -293,8 +292,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
var stream_writer = stream.writer(&.{}); var stream_writer = stream.writer(&.{});
try stream_writer.interface.writeAll(request_bytes); try stream_writer.interface.writeAll(request_bytes);
var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&.{});
var stream_reader = stream.reader(&tiny_buffer);
const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
defer gpa.free(response); defer gpa.free(response);
@@ -364,8 +362,7 @@ test "receiving arbitrary http headers from the client" {
var stream_writer = stream.writer(&.{}); var stream_writer = stream.writer(&.{});
try stream_writer.interface.writeAll(request_bytes); try stream_writer.interface.writeAll(request_bytes);
var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&.{});
var stream_reader = stream.reader(&tiny_buffer);
const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
defer gpa.free(response); defer gpa.free(response);

View File

@@ -4,6 +4,7 @@ const assert = std.debug.assert;
const testing = std.testing; const testing = std.testing;
const mem = std.mem; const mem = std.mem;
const native_endian = builtin.cpu.arch.endian(); const native_endian = builtin.cpu.arch.endian();
const Allocator = std.mem.Allocator;
/// Use this to replace an unknown, unrecognized, or unrepresentable character. /// Use this to replace an unknown, unrecognized, or unrepresentable character.
/// ///
@@ -921,7 +922,7 @@ fn utf16LeToUtf8ArrayListImpl(
comptime surrogates: Surrogates, comptime surrogates: Surrogates,
) (switch (surrogates) { ) (switch (surrogates) {
.cannot_encode_surrogate_half => Utf16LeToUtf8AllocError, .cannot_encode_surrogate_half => Utf16LeToUtf8AllocError,
.can_encode_surrogate_half => mem.Allocator.Error, .can_encode_surrogate_half => Allocator.Error,
})!void { })!void {
assert(result.unusedCapacitySlice().len >= utf16le.len); assert(result.unusedCapacitySlice().len >= utf16le.len);
@@ -965,15 +966,15 @@ fn utf16LeToUtf8ArrayListImpl(
} }
} }
pub const Utf16LeToUtf8AllocError = mem.Allocator.Error || Utf16LeToUtf8Error; pub const Utf16LeToUtf8AllocError = Allocator.Error || Utf16LeToUtf8Error;
pub fn utf16LeToUtf8ArrayList(result: *std.array_list.Managed(u8), utf16le: []const u16) Utf16LeToUtf8AllocError!void { pub fn utf16LeToUtf8ArrayList(result: *std.array_list.Managed(u8), utf16le: []const u16) Utf16LeToUtf8AllocError!void {
try result.ensureUnusedCapacity(utf16le.len); try result.ensureUnusedCapacity(utf16le.len);
return utf16LeToUtf8ArrayListImpl(result, utf16le, .cannot_encode_surrogate_half); return utf16LeToUtf8ArrayListImpl(result, utf16le, .cannot_encode_surrogate_half);
} }
/// Caller must free returned memory. /// Caller owns returned memory.
pub fn utf16LeToUtf8Alloc(allocator: mem.Allocator, utf16le: []const u16) Utf16LeToUtf8AllocError![]u8 { pub fn utf16LeToUtf8Alloc(allocator: Allocator, utf16le: []const u16) Utf16LeToUtf8AllocError![]u8 {
// optimistically guess that it will all be ascii. // optimistically guess that it will all be ascii.
var result = try std.array_list.Managed(u8).initCapacity(allocator, utf16le.len); var result = try std.array_list.Managed(u8).initCapacity(allocator, utf16le.len);
errdefer result.deinit(); errdefer result.deinit();
@@ -982,8 +983,8 @@ pub fn utf16LeToUtf8Alloc(allocator: mem.Allocator, utf16le: []const u16) Utf16L
return result.toOwnedSlice(); return result.toOwnedSlice();
} }
/// Caller must free returned memory. /// Caller owns returned memory.
pub fn utf16LeToUtf8AllocZ(allocator: mem.Allocator, utf16le: []const u16) Utf16LeToUtf8AllocError![:0]u8 { pub fn utf16LeToUtf8AllocZ(allocator: Allocator, utf16le: []const u16) Utf16LeToUtf8AllocError![:0]u8 {
// optimistically guess that it will all be ascii (and allocate space for the null terminator) // optimistically guess that it will all be ascii (and allocate space for the null terminator)
var result = try std.array_list.Managed(u8).initCapacity(allocator, utf16le.len + 1); var result = try std.array_list.Managed(u8).initCapacity(allocator, utf16le.len + 1);
errdefer result.deinit(); errdefer result.deinit();
@@ -1160,7 +1161,7 @@ pub fn utf8ToUtf16LeArrayList(result: *std.array_list.Managed(u16), utf8: []cons
return utf8ToUtf16LeArrayListImpl(result, utf8, .cannot_encode_surrogate_half); return utf8ToUtf16LeArrayListImpl(result, utf8, .cannot_encode_surrogate_half);
} }
pub fn utf8ToUtf16LeAlloc(allocator: mem.Allocator, utf8: []const u8) error{ InvalidUtf8, OutOfMemory }![]u16 { pub fn utf8ToUtf16LeAlloc(allocator: Allocator, utf8: []const u8) error{ InvalidUtf8, OutOfMemory }![]u16 {
// optimistically guess that it will not require surrogate pairs // optimistically guess that it will not require surrogate pairs
var result = try std.array_list.Managed(u16).initCapacity(allocator, utf8.len); var result = try std.array_list.Managed(u16).initCapacity(allocator, utf8.len);
errdefer result.deinit(); errdefer result.deinit();
@@ -1169,7 +1170,7 @@ pub fn utf8ToUtf16LeAlloc(allocator: mem.Allocator, utf8: []const u8) error{ Inv
return result.toOwnedSlice(); return result.toOwnedSlice();
} }
pub fn utf8ToUtf16LeAllocZ(allocator: mem.Allocator, utf8: []const u8) error{ InvalidUtf8, OutOfMemory }![:0]u16 { pub fn utf8ToUtf16LeAllocZ(allocator: Allocator, utf8: []const u8) error{ InvalidUtf8, OutOfMemory }![:0]u16 {
// optimistically guess that it will not require surrogate pairs // optimistically guess that it will not require surrogate pairs
var result = try std.array_list.Managed(u16).initCapacity(allocator, utf8.len + 1); var result = try std.array_list.Managed(u16).initCapacity(allocator, utf8.len + 1);
errdefer result.deinit(); errdefer result.deinit();
@@ -1750,13 +1751,13 @@ pub const Wtf8Iterator = struct {
} }
}; };
pub fn wtf16LeToWtf8ArrayList(result: *std.array_list.Managed(u8), utf16le: []const u16) mem.Allocator.Error!void { pub fn wtf16LeToWtf8ArrayList(result: *std.array_list.Managed(u8), utf16le: []const u16) Allocator.Error!void {
try result.ensureUnusedCapacity(utf16le.len); try result.ensureUnusedCapacity(utf16le.len);
return utf16LeToUtf8ArrayListImpl(result, utf16le, .can_encode_surrogate_half); return utf16LeToUtf8ArrayListImpl(result, utf16le, .can_encode_surrogate_half);
} }
/// Caller must free returned memory. /// Caller must free returned memory.
pub fn wtf16LeToWtf8Alloc(allocator: mem.Allocator, wtf16le: []const u16) mem.Allocator.Error![]u8 { pub fn wtf16LeToWtf8Alloc(allocator: Allocator, wtf16le: []const u16) Allocator.Error![]u8 {
// optimistically guess that it will all be ascii. // optimistically guess that it will all be ascii.
var result = try std.array_list.Managed(u8).initCapacity(allocator, wtf16le.len); var result = try std.array_list.Managed(u8).initCapacity(allocator, wtf16le.len);
errdefer result.deinit(); errdefer result.deinit();
@@ -1766,7 +1767,7 @@ pub fn wtf16LeToWtf8Alloc(allocator: mem.Allocator, wtf16le: []const u16) mem.Al
} }
/// Caller must free returned memory. /// Caller must free returned memory.
pub fn wtf16LeToWtf8AllocZ(allocator: mem.Allocator, wtf16le: []const u16) mem.Allocator.Error![:0]u8 { pub fn wtf16LeToWtf8AllocZ(allocator: Allocator, wtf16le: []const u16) Allocator.Error![:0]u8 {
// optimistically guess that it will all be ascii (and allocate space for the null terminator) // optimistically guess that it will all be ascii (and allocate space for the null terminator)
var result = try std.array_list.Managed(u8).initCapacity(allocator, wtf16le.len + 1); var result = try std.array_list.Managed(u8).initCapacity(allocator, wtf16le.len + 1);
errdefer result.deinit(); errdefer result.deinit();
@@ -1784,7 +1785,7 @@ pub fn wtf8ToWtf16LeArrayList(result: *std.array_list.Managed(u16), wtf8: []cons
return utf8ToUtf16LeArrayListImpl(result, wtf8, .can_encode_surrogate_half); return utf8ToUtf16LeArrayListImpl(result, wtf8, .can_encode_surrogate_half);
} }
pub fn wtf8ToWtf16LeAlloc(allocator: mem.Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![]u16 { pub fn wtf8ToWtf16LeAlloc(allocator: Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![]u16 {
// optimistically guess that it will not require surrogate pairs // optimistically guess that it will not require surrogate pairs
var result = try std.array_list.Managed(u16).initCapacity(allocator, wtf8.len); var result = try std.array_list.Managed(u16).initCapacity(allocator, wtf8.len);
errdefer result.deinit(); errdefer result.deinit();
@@ -1793,7 +1794,7 @@ pub fn wtf8ToWtf16LeAlloc(allocator: mem.Allocator, wtf8: []const u8) error{ Inv
return result.toOwnedSlice(); return result.toOwnedSlice();
} }
pub fn wtf8ToWtf16LeAllocZ(allocator: mem.Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![:0]u16 { pub fn wtf8ToWtf16LeAllocZ(allocator: Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![:0]u16 {
// optimistically guess that it will not require surrogate pairs // optimistically guess that it will not require surrogate pairs
var result = try std.array_list.Managed(u16).initCapacity(allocator, wtf8.len + 1); var result = try std.array_list.Managed(u16).initCapacity(allocator, wtf8.len + 1);
errdefer result.deinit(); errdefer result.deinit();
@@ -1870,7 +1871,7 @@ pub fn wtf8ToUtf8Lossy(utf8: []u8, wtf8: []const u8) error{InvalidWtf8}!void {
} }
} }
pub fn wtf8ToUtf8LossyAlloc(allocator: mem.Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![]u8 { pub fn wtf8ToUtf8LossyAlloc(allocator: Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![]u8 {
const utf8 = try allocator.alloc(u8, wtf8.len); const utf8 = try allocator.alloc(u8, wtf8.len);
errdefer allocator.free(utf8); errdefer allocator.free(utf8);
@@ -1879,7 +1880,7 @@ pub fn wtf8ToUtf8LossyAlloc(allocator: mem.Allocator, wtf8: []const u8) error{ I
return utf8; return utf8;
} }
pub fn wtf8ToUtf8LossyAllocZ(allocator: mem.Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![:0]u8 { pub fn wtf8ToUtf8LossyAllocZ(allocator: Allocator, wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }![:0]u8 {
const utf8 = try allocator.allocSentinel(u8, wtf8.len, 0); const utf8 = try allocator.allocSentinel(u8, wtf8.len, 0);
errdefer allocator.free(utf8); errdefer allocator.free(utf8);

View File

@@ -554,8 +554,11 @@ test isUnderscore {
try std.testing.expect(!isUnderscore("\\x5f")); try std.testing.expect(!isUnderscore("\\x5f"));
} }
/// If the source can be UTF-16LE encoded, this function asserts that `gpa`
/// will align a byte-sized allocation to at least 2. Allocators that don't do
/// this are rare.
pub fn readSourceFileToEndAlloc(gpa: Allocator, file_reader: *std.fs.File.Reader) ![:0]u8 { pub fn readSourceFileToEndAlloc(gpa: Allocator, file_reader: *std.fs.File.Reader) ![:0]u8 {
var buffer: std.ArrayListAlignedUnmanaged(u8, .@"2") = .empty; var buffer: std.ArrayList(u8) = .empty;
defer buffer.deinit(gpa); defer buffer.deinit(gpa);
if (file_reader.getSize()) |size| { if (file_reader.getSize()) |size| {
@@ -564,7 +567,7 @@ pub fn readSourceFileToEndAlloc(gpa: Allocator, file_reader: *std.fs.File.Reader
try buffer.ensureTotalCapacityPrecise(gpa, casted_size + 1); try buffer.ensureTotalCapacityPrecise(gpa, casted_size + 1);
} else |_| {} } else |_| {}
try file_reader.interface.appendRemaining(gpa, .@"2", &buffer, .limited(max_src_size)); try file_reader.interface.appendRemaining(gpa, &buffer, .limited(max_src_size));
// Detect unsupported file types with their Byte Order Mark // Detect unsupported file types with their Byte Order Mark
const unsupported_boms = [_][]const u8{ const unsupported_boms = [_][]const u8{
@@ -581,7 +584,7 @@ pub fn readSourceFileToEndAlloc(gpa: Allocator, file_reader: *std.fs.File.Reader
// If the file starts with a UTF-16 little endian BOM, translate it to UTF-8 // If the file starts with a UTF-16 little endian BOM, translate it to UTF-8
if (std.mem.startsWith(u8, buffer.items, "\xff\xfe")) { if (std.mem.startsWith(u8, buffer.items, "\xff\xfe")) {
if (buffer.items.len % 2 != 0) return error.InvalidEncoding; if (buffer.items.len % 2 != 0) return error.InvalidEncoding;
return std.unicode.utf16LeToUtf8AllocZ(gpa, @ptrCast(buffer.items)) catch |err| switch (err) { return std.unicode.utf16LeToUtf8AllocZ(gpa, @ptrCast(@alignCast(buffer.items))) catch |err| switch (err) {
error.DanglingSurrogateHalf => error.UnsupportedEncoding, error.DanglingSurrogateHalf => error.UnsupportedEncoding,
error.ExpectedSecondSurrogateHalf => error.UnsupportedEncoding, error.ExpectedSecondSurrogateHalf => error.UnsupportedEncoding,
error.UnexpectedSecondSurrogateHalf => error.UnsupportedEncoding, error.UnexpectedSecondSurrogateHalf => error.UnsupportedEncoding,