Merge pull request #24740 from ziglang/http-plus-fixes

fetch, tls, and http fixes
This commit is contained in:
Andrew Kelley
2025-08-08 12:33:53 -07:00
committed by GitHub
7 changed files with 172 additions and 155 deletions

View File

@@ -25,9 +25,7 @@ pub const VTable = struct {
///
/// Returns the number of bytes written, which will be at minimum `0` and
/// at most `limit`. The number returned, including zero, does not indicate
/// end of stream. `limit` is guaranteed to be at least as large as the
/// buffer capacity of `w`, a value whose minimum size is determined by the
/// stream implementation.
/// end of stream.
///
/// The reader's internal logical seek position moves forward in accordance
/// with the number of bytes returned from this function.

View File

@@ -61,9 +61,6 @@ pub const ReadError = error{
TlsUnexpectedMessage,
TlsIllegalParameter,
TlsSequenceOverflow,
/// The buffer provided to the read function was not at least
/// `min_buffer_len`.
OutputBufferUndersize,
};
pub const SslKeyLog = struct {
@@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
};
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
return error.TlsBadRecordMac;
cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len;
// TODO use scalar, non-slice version
cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len;
},
}
read_seq += 1;
@@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_fragment_buf[0..message_len];
const ad = std.mem.toBytes(big(read_seq)) ++
const ad = mem.toBytes(big(read_seq)) ++
record_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = record_decoder.array(P.record_iv_length).*;
const masked_read_seq = read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
&.{ "server finished", &p.transcript_hash.finalResult() },
P.verify_data_length,
),
.app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
.app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block),
} };
const pv = &p.version.tls_1_2;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
client_verify_cleartext.len ..][0..client_verify_cleartext.len],
client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
&client_verify_cleartext,
std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
nonce,
pv.app_cipher.client_write_key,
);
@@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
.input = input,
.reader = .{
.buffer = options.read_buffer,
.vtable = &.{ .stream = stream },
.vtable = &.{
.stream = stream,
.readVec = readVec,
},
.seek = 0,
.end = 0,
},
@@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord(
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.write_seq));
const operand: V = pad ++ mem.toBytes(big(c.write_seq));
break :nonce @as(V, pv.client_iv) ^ operand;
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
@@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord(
record_header.* = .{@intFromEnum(inner_content_type)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
int(u16, P.record_iv_length + message_len + P.mac_length);
const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
ciphertext_end += P.record_iv_length;
const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool {
}
/// `Reader.VTable.stream` implementation for the TLS client.
/// Decrypted plaintext is written directly into the reader's own buffer by
/// `readIndirect`, so the destination writer and the limit are intentionally
/// unused here.
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
    // Output goes straight to `r.buffer`; discard the unused parameters.
    _ = limit;
    _ = w;
    const client: *Client = @alignCast(@fieldParentPtr("reader", r));
    return readIndirect(client);
}
/// `Reader.VTable.readVec` implementation for the TLS client.
/// Like `stream`, the decrypted bytes land in the reader's internal buffer,
/// so the caller-provided vectors are intentionally ignored.
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
    // Output goes straight to `r.buffer`; the scatter list is unused.
    _ = data;
    const client: *Client = @alignCast(@fieldParentPtr("reader", r));
    return readIndirect(client);
}
fn readIndirect(c: *Client) Reader.Error!usize {
const r = &c.reader;
if (c.eof()) return error.EndOfStream;
const input = c.input;
// If at least one full encrypted record is not buffered, read once.
@@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
if (record_end > input.buffered().len) return 0;
}
var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
if (r.seek == r.end) {
r.seek = 0;
r.end = 0;
}
const cleartext_buffer = r.buffer[r.end..];
const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
.tls_1_3 => {
const pv = &p.tls_1_3;
@@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
const nonce = nonce: {
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
const operand: V = pad ++ mem.toBytes(big(c.read_seq));
break :nonce @as(V, pv.server_iv) ^ operand;
};
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
return failRead(c, error.TlsBadRecordMac);
// TODO use scalar, non-slice version
const msg = mem.trimRight(u8, cleartext, "\x00");
break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) };
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
const ad = std.mem.toBytes(big(c.read_seq)) ++
const ad = mem.toBytes(big(c.read_seq)) ++
ad_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
mem.toBytes(big(message_len));
const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
const masked_read_seq = c.read_seq &
comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
};
const ciphertext = input.take(message_len) catch unreachable; // already peeked
const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
const cleartext = cleartext_stack_buffer[0..ciphertext.len];
const cleartext = cleartext_buffer[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
return failRead(c, error.TlsBadRecordMac);
break :cleartext .{ cleartext, ct };
break :cleartext .{ cleartext.len, ct };
},
else => unreachable,
},
};
const cleartext = cleartext_buffer[0..cleartext_len];
c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
switch (inner_ct) {
.alert => {
@@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
return 0;
},
.application_data => {
if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
try w.writeAll(cleartext);
return cleartext.len;
r.end += cleartext.len;
return 0;
},
else => return failRead(c, error.TlsUnexpectedMessage),
}

View File

@@ -292,6 +292,14 @@ pub const ContentEncoding = enum {
});
return map.get(s);
}
/// Returns the minimum decompression-buffer capacity required for this
/// content encoding: the decompressor's window size for compressed
/// encodings, or zero when no decompression takes place.
pub fn minBufferCapacity(ce: ContentEncoding) usize {
    switch (ce) {
        // No decompressor is constructed for these, so no window is needed.
        .compress, .identity => return 0,
        .gzip, .deflate => return std.compress.flate.max_window_len,
        .zstd => return std.compress.zstd.default_window_len,
    }
}
};
pub const Connection = enum {
@@ -412,7 +420,7 @@ pub const Reader = struct {
/// * `interfaceDecompressing`
pub fn bodyReader(
reader: *Reader,
buffer: []u8,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
) *std.Io.Reader {
@@ -421,7 +429,7 @@ pub const Reader = struct {
.chunked => {
reader.state = .{ .body_remaining_chunk_len = .head };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@@ -435,7 +443,7 @@ pub const Reader = struct {
if (content_length) |len| {
reader.state = .{ .body_remaining_content_length = len };
reader.interface = .{
.buffer = buffer,
.buffer = transfer_buffer,
.seek = 0,
.end = 0,
.vtable = &.{
@@ -460,11 +468,12 @@ pub const Reader = struct {
/// * `interface`
pub fn bodyReaderDecompressing(
reader: *Reader,
transfer_buffer: []u8,
transfer_encoding: TransferEncoding,
content_length: ?u64,
content_encoding: ContentEncoding,
decompressor: *Decompressor,
decompression_buffer: []u8,
decompress: *Decompress,
decompress_buffer: []u8,
) *std.Io.Reader {
if (transfer_encoding == .none and content_length == null) {
assert(reader.state == .received_head);
@@ -474,22 +483,22 @@ pub const Reader = struct {
return reader.in;
},
.deflate => {
decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .zlib, decompress_buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(reader.in, .gzip, decompress_buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(reader.in, decompress_buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
}
const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
const transfer_reader = bodyReader(reader, transfer_buffer, transfer_encoding, content_length);
return decompress.init(transfer_reader, decompress_buffer, content_encoding);
}
fn contentLengthStream(
@@ -691,33 +700,33 @@ pub const Reader = struct {
}
};
pub const Decompressor = union(enum) {
pub const Decompress = union(enum) {
flate: std.compress.flate.Decompress,
zstd: std.compress.zstd.Decompress,
none: *std.Io.Reader,
pub fn init(
decompressor: *Decompressor,
decompress: *Decompress,
transfer_reader: *std.Io.Reader,
buffer: []u8,
content_encoding: ContentEncoding,
) *std.Io.Reader {
switch (content_encoding) {
.identity => {
decompressor.* = .{ .none = transfer_reader };
decompress.* = .{ .none = transfer_reader };
return transfer_reader;
},
.deflate => {
decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
return &decompress.flate.reader;
},
.gzip => {
decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompressor.flate.reader;
decompress.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
return &decompress.flate.reader;
},
.zstd => {
decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompressor.zstd.reader;
decompress.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
return &decompress.zstd.reader;
},
.compress => unreachable,
}
@@ -794,7 +803,7 @@ pub const BodyWriter = struct {
}
/// When using content-length, asserts that the amount of data sent matches
/// the value sent in the header, then flushes.
/// the value sent in the header, then flushes `http_protocol_output`.
///
/// When using transfer-encoding: chunked, writes the end-of-stream message
/// with empty trailers, then flushes the stream to the system. Asserts any
@@ -818,10 +827,13 @@ pub const BodyWriter = struct {
///
/// Respects the value of `isEliding` to omit all data after the headers.
///
/// Does not flush `http_protocol_output`, but does flush `writer`.
///
/// See also:
/// * `end`
/// * `endChunked`
pub fn endUnflushed(w: *BodyWriter) Error!void {
try w.writer.flush();
switch (w.state) {
.end => unreachable,
.content_length => |len| {

View File

@@ -13,8 +13,8 @@ const net = std.net;
const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;
const Writer = std.io.Writer;
const Reader = std.io.Reader;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
const Client = @This();
@@ -704,12 +704,12 @@ pub const Response = struct {
///
/// See also:
/// * `readerDecompressing`
pub fn reader(response: *Response, buffer: []u8) *Reader {
pub fn reader(response: *Response, transfer_buffer: []u8) *Reader {
response.head.invalidateStrings();
const req = response.request;
if (!req.method.responseHasBody()) return .ending;
const head = &response.head;
return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length);
return req.reader.bodyReader(transfer_buffer, head.transfer_encoding, head.content_length);
}
/// If compressed body has been negotiated this will return decompressed bytes.
@@ -723,17 +723,19 @@ pub const Response = struct {
/// * `reader`
pub fn readerDecompressing(
response: *Response,
decompressor: *http.Decompressor,
decompression_buffer: []u8,
transfer_buffer: []u8,
decompress: *http.Decompress,
decompress_buffer: []u8,
) *Reader {
response.head.invalidateStrings();
const head = &response.head;
return response.request.reader.bodyReaderDecompressing(
transfer_buffer,
head.transfer_encoding,
head.content_length,
head.content_encoding,
decompressor,
decompression_buffer,
decompress,
decompress_buffer,
);
}
@@ -1322,7 +1324,7 @@ pub const basic_authorization = struct {
const user: Uri.Component = uri.user orelse .empty;
const password: Uri.Component = uri.password orelse .empty;
var dw: std.io.Writer.Discarding = .init(&.{});
var dw: Writer.Discarding = .init(&.{});
user.formatUser(&dw.writer) catch unreachable; // discarding
const user_len = dw.count + dw.writer.end;
@@ -1696,8 +1698,8 @@ pub const FetchOptions = struct {
/// `null` means it will be heap-allocated.
decompress_buffer: ?[]u8 = null,
redirect_behavior: ?Request.RedirectBehavior = null,
/// If the server sends a body, it will be stored here.
response_storage: ?ResponseStorage = null,
/// If the server sends a body, it will be written here.
response_writer: ?*Writer = null,
location: Location,
method: ?http.Method = null,
@@ -1725,7 +1727,7 @@ pub const FetchOptions = struct {
list: *std.ArrayListUnmanaged(u8),
/// If null then only the existing capacity will be used.
allocator: ?Allocator = null,
append_limit: std.io.Limit = .unlimited,
append_limit: std.Io.Limit = .unlimited,
};
};
@@ -1778,7 +1780,7 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
var response = try req.receiveHead(redirect_buffer);
const storage = options.response_storage orelse {
const response_writer = options.response_writer orelse {
const reader = response.reader(&.{});
_ = reader.discardRemaining() catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
@@ -1794,21 +1796,14 @@ pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
};
defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
var decompressor: http.Decompressor = undefined;
const reader = response.readerDecompressing(&decompressor, decompress_buffer);
const list = storage.list;
var transfer_buffer: [64]u8 = undefined;
var decompress: http.Decompress = undefined;
const reader = response.readerDecompressing(&transfer_buffer, &decompress, decompress_buffer);
if (storage.allocator) |allocator| {
reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
else => |e| return e,
};
} else {
const buf = storage.append_limit.slice(list.unusedCapacitySlice());
list.items.len += reader.readSliceShort(buf) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
};
}
_ = reader.streamRemaining(response_writer) catch |err| switch (err) {
error.ReadFailed => return response.bodyErr().?,
else => |e| return e,
};
return .{ .status = response.head.status };
}

View File

@@ -1006,8 +1006,9 @@ fn echoTests(client: *http.Client, port: u16) !void {
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
defer gpa.free(location);
var body: std.ArrayListUnmanaged(u8) = .empty;
defer body.deinit(gpa);
var body: std.Io.Writer.Allocating = .init(gpa);
defer body.deinit();
try body.ensureUnusedCapacity(64);
const res = try client.fetch(.{
.location = .{ .url = location },
@@ -1016,10 +1017,10 @@ fn echoTests(client: *http.Client, port: u16) !void {
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
},
.response_storage = .{ .allocator = gpa, .list = &body },
.response_writer = &body.writer,
});
try expectEqual(.ok, res.status);
try expectEqualStrings("Hello, World!\n", body.items);
try expectEqualStrings("Hello, World!\n", body.getWritten());
}
{ // expect: 100-continue