zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit aa090a49d94155c4804644377db110f3b13f0500 (tree)
parent 5d40338f21b468c82d4bc2a1ac0a35c643126e74
Author: Nameless <truemedian@gmail.com>
Date:   Tue, 22 Aug 2023 10:05:03 -0500

std.http: handle expect:100-continue and continue responses

Diffstat:
Mlib/std/http/Client.zig | 45++++++++++++++++++++++++++++++++++++++++-----
Mlib/std/http/Server.zig | 74+++++++++++++++++++++++++++++++++++++++++---------------------------------
Mlib/std/http/protocol.zig | 9++++++---
Mtest/standalone/http.zig | 70+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
4 files changed, 156 insertions(+), 42 deletions(-)

diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig @@ -478,6 +478,7 @@ pub const Request = struct { .zstd => |*zstd| zstd.deinit(), } + req.headers.deinit(); req.response.headers.deinit(); if (req.response.parser.header_bytes_owned) { @@ -667,17 +668,19 @@ pub const Request = struct { try req.response.parse(req.response.parser.header_bytes.items, false); - if (req.response.status == .switching_protocols) { - req.connection.?.data.closing = false; - req.response.parser.done = true; + if (req.response.status == .@"continue") { + req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response + req.response.parser.reset(); + break; } - if (req.method == .CONNECT and req.response.status == .ok) { + // we're switching protocols, so this connection is no longer doing http + if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) { req.connection.?.data.closing = false; req.response.parser.done = true; } - // we default to using keep-alive if not provided + // we default to using keep-alive if not provided in the client if the server asks for it const req_connection = req.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); @@ -955,6 +958,38 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: return conn; } +pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{NameTooLong} || std.os.ConnectError; + +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node { + if (client.connection_pool.findConnection(.{ + .host = path, + .port = 0, + .is_tls = false, + })) |node| + return node; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + const stream = try std.net.connectUnixSocket(path); + errdefer stream.close(); + + conn.data = .{ + .stream = stream, + .tls_client = undefined, + .protocol = .plain, + + .host = try client.allocator.dupe(u8, path), + .port = 0, + }; + errdefer client.allocator.free(conn.data.host); + + client.connection_pool.addUsed(conn); + + return conn; +} + // Prevents a dependency loop in request() const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused }; pub const ConnectError = ConnectErrorPartial || RequestError; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig @@ -411,48 +411,52 @@ pub const Response = struct { } try w.writeAll("\r\n"); - if (!res.headers.contains("server")) { - try w.writeAll("Server: zig (std.http)\r\n"); - } + if (res.status == .@"continue") { + res.state = .waited; // we still need to send another request after this + } else { + if (!res.headers.contains("server")) { + try w.writeAll("Server: zig (std.http)\r\n"); + } - if (!res.headers.contains("connection")) { - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + if (!res.headers.contains("connection")) { + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (req_keepalive) { - try w.writeAll("Connection: keep-alive\r\n"); - } else { - try w.writeAll("Connection: close\r\n"); + if (req_keepalive) { + try w.writeAll("Connection: keep-alive\r\n"); + } else { + try w.writeAll("Connection: close\r\n"); + } } - } - const has_transfer_encoding = res.headers.contains("transfer-encoding"); - const has_content_length = res.headers.contains("content-length"); + const has_transfer_encoding = res.headers.contains("transfer-encoding"); + const has_content_length = res.headers.contains("content-length"); - if (!has_transfer_encoding and !has_content_length) { - switch (res.transfer_encoding) { - .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), - .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), - .none => {}, - } - } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - res.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { - const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; - if (std.mem.eql(u8, transfer_encoding, "chunked")) { - res.transfer_encoding = .chunked; - } else { - return error.UnsupportedTransferEncoding; + if (!has_transfer_encoding and !has_content_length) { + switch (res.transfer_encoding) { + .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), + .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), + .none => {}, } } else { - res.transfer_encoding = .none; + if (has_content_length) { + const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; + + res.transfer_encoding = .{ .content_length = content_length }; + } else if (has_transfer_encoding) { + const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; + if (std.mem.eql(u8, transfer_encoding, "chunked")) { + res.transfer_encoding = .chunked; + } else { + return error.UnsupportedTransferEncoding; + } + } else { + res.transfer_encoding = .none; + } } - } - try w.print("{}", .{res.headers}); + try w.print("{}", .{res.headers}); + } try w.writeAll("\r\n"); @@ -516,6 +520,10 @@ pub const Response = struct { res.request.parser.done = true; } + if (res.request.method == .HEAD) { + res.request.parser.done = true; + } + if (!res.request.parser.done) { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig @@ -534,9 +534,9 @@ pub const HeadersParser = struct { if (r.next_chunk_length == 0) r.done = true; - return 0; - } else { - const out_avail = buffer.len; + return out_index; + } else if (out_index < buffer.len) { + const out_avail = buffer.len - out_index; const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); const nread = try conn.read(buffer[0..can_read]); @@ -545,6 +545,8 @@ pub const HeadersParser = struct { if (r.next_chunk_length == 0) r.done = true; return nread; + } else { + return out_index; } }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { @@ -558,6 +560,7 @@ pub const HeadersParser = struct { .chunk_data => if (r.next_chunk_length == 0) { if (std.mem.eql(u8, conn.peek(), "\r\n")) { r.state = .finished; + r.done = true; } else { // The trailer section is formatted identically to the header section. r.state = .seen_rn; diff --git a/test/standalone/http.zig b/test/standalone/http.zig @@ -22,6 +22,18 @@ fn handleRequest(res: *Server.Response) !void { log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target }); + if (res.request.headers.contains("expect")) { + if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) { + res.status = .@"continue"; + try res.do(); + res.status = .ok; + } else { + res.status = .expectation_failed; + try res.do(); + return; + } + } + const body = try res.reader().readAllAlloc(salloc, 8192); defer salloc.free(body); @@ -62,7 +74,7 @@ fn handleRequest(res: *Server.Response) !void { } try res.finish(); - } else if (mem.eql(u8, res.request.target, "/echo-content")) { + } else if (mem.startsWith(u8, res.request.target, "/echo-content")) { try testing.expectEqualStrings("Hello, World!\n", body); try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?); @@ -592,6 +604,62 @@ pub fn main() !void { try testing.expectEqualStrings("Hello, World!\n", res.body.?); } + { // expect: 100-continue + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("expect", "100-continue"); + try h.append("content-type", "text/plain"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.start(); + try req.wait(); + try testing.expectEqual(http.Status.@"continue", req.response.status); + + try req.writeAll("Hello, "); + try req.writeAll("World!\n"); + try req.finish(); + + try req.wait(); + try testing.expectEqual(http.Status.ok, req.response.status); + + const body = try req.reader().readAllAlloc(calloc, 8192); + defer calloc.free(body); + + try testing.expectEqualStrings("Hello, World!\n", body); + } + + { // expect: garbage + var h = http.Headers{ .allocator = calloc }; + defer h.deinit(); + + try h.append("content-type", "text/plain"); + try h.append("expect", "garbage"); + + const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port}); + defer calloc.free(location); + const uri = try std.Uri.parse(location); + + log.info("{s}", .{location}); + var req = try client.request(.POST, uri, h, .{}); + defer req.deinit(); + + req.transfer_encoding = .chunked; + + try req.start(); + try req.wait(); + try testing.expectEqual(http.Status.expectation_failed, req.response.status); + } + { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port}); defer calloc.free(location);