From abde76a808df816ea12a8a2dbf8e6b53ff9b110f Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 21 Feb 2024 23:47:35 -0700 Subject: [PATCH] std.http.Server: handle expect: 100-continue requests The API automatically handles these requests as expected. After receiveHead(), the server has a chance to notice the expectation and do something about it. If it does not, then the Server implementation will handle it by sending the continuation header when the read stream is created. Both respond() and respondStreaming() send the continuation header as part of discarding the request body, only if the read stream has not already been created. --- lib/std/http/Server.zig | 91 ++++++++++++++++++++++++++++++---------- lib/std/http/test.zig | 23 ++++++---- test/standalone/http.zig | 12 +----- 3 files changed, 85 insertions(+), 41 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 4ce77a90c4..bc284f5517 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -313,11 +313,20 @@ pub const Request = struct { var first_buffer: [500]u8 = undefined; var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; + } h.fixedWriter().print("{s} {d} {s}\r\n", .{ @tagName(options.version), @intFromEnum(options.status), phrase, }) catch unreachable; - if (keep_alive) - h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, @@ -452,25 +461,35 @@ pub const Request = struct { var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; - if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); - if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } + if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; + } else { + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); + } + + for (o.extra_headers) |header| { + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } - for (o.extra_headers) |header| { - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); h.appendSliceAssumeCapacity("\r\n"); - } - - h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; return .{ .stream = request.server.connection.stream, @@ -478,16 +497,20 @@ pub const Request = struct { .send_buffer_start = 0, .send_buffer_end = h.items.len, .content_length = options.content_length, - .elide_body = request.head.method == .HEAD, + .elide_body = elide_body, .chunk_len = 0, }; } - pub const ReadError = net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize }; + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; + const remaining_content_length = &request.reader_state.remaining_content_length; if (remaining_content_length.* == 0) { s.state = .ready; @@ -515,6 +538,7 @@ pub const Request = struct { fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; + const cp = &request.reader_state.chunk_parser; const head_end = request.head_end; @@ -599,11 +623,33 @@ pub const Request = struct { return out_end; } - pub fn reader(request: *Request) std.io.AnyReader { + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. + /// + /// Asserts that this function is only called once. + pub fn reader(request: *Request) ReaderError!std.io.AnyReader { const s = request.server; assert(s.state == .received_head); s.state = .receiving_body; s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } + } + switch (request.head.transfer_encoding) { .chunked => { request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; @@ -639,7 +685,8 @@ pub const Request = struct { const s = request.server; if (keep_alive and request.head.keep_alive) switch (s.state) { .received_head => { - _ = request.reader().discard() catch return false; + const r = request.reader() catch return false; + _ = r.discard() catch return false; assert(s.state == .ready); return true; }, diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index abb98f28e1..108d0aba56 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -136,7 +136,7 @@ test "HTTP server handles a chunked transfer coding request" { try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try request.reader().readAll(&buf); + const n = try (try request.reader()).readAll(&buf); try expect(std.mem.eql(u8, buf[0..n], "ABCD")); try request.respond("message from server!\n", .{ @@ -187,13 +187,13 @@ test "echo content server" { const server_thread = try std.Thread.spawn(.{}, (struct { fn handleRequest(request: *std.http.Server.Request) !void { - std.debug.print("server received {s} {s} {s}\n", .{ - @tagName(request.head.method), - @tagName(request.head.version), - request.head.target, - }); + //std.debug.print("server received {s} {s} {s}\n", .{ + // @tagName(request.head.method), + // @tagName(request.head.version), + // request.head.target, + //}); - const body = try request.reader().readAllAlloc(std.testing.allocator, 8192); + const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192); defer std.testing.allocator.free(body); try testing.expect(std.mem.startsWith(u8, request.head.target, "/echo-content")); @@ -217,7 +217,7 @@ test "echo content server" { try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); - std.debug.print(" server finished responding\n", .{}); + //std.debug.print(" server finished responding\n", .{}); } fn run(net_server: *std.net.Server) anyerror!void { @@ -237,6 +237,13 @@ test "echo content server" { if (std.mem.eql(u8, request.head.target, "/end")) { return request.respond("", .{ .keep_alive = false }); } + if (request.head.expect) |expect| { + if (std.mem.eql(u8, expect, "garbage")) { + try testing.expectError(error.HttpExpectationFailed, request.reader()); + try request.respond("", .{ .keep_alive = false }); + continue; + } + } handleRequest(&request) catch |err| { // This message helps the person troubleshooting determine whether // output comes from the server thread or the client thread. diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 5b44a14032..ff6467fc6c 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -26,17 +26,7 @@ fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { request.head.target, }); - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - @panic("test failure, didn't handle expect 100-continue"); - } else { - return request.respond("", .{ - .status = .expectation_failed, - }); - } - } - - const body = try request.reader().readAllAlloc(salloc, 8192); + const body = try (try request.reader()).readAllAlloc(salloc, 8192); defer salloc.free(body); var send_buffer: [100]u8 = undefined;