diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 4ce77a90c4..bc284f5517 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -313,11 +313,20 @@ pub const Request = struct { var first_buffer: [500]u8 = undefined; var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; + } h.fixedWriter().print("{s} {d} {s}\r\n", .{ @tagName(options.version), @intFromEnum(options.status), phrase, }) catch unreachable; - if (keep_alive) - h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, @@ -452,25 +461,35 @@ pub const Request = struct { var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; - if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); - if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } + if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; + } else { + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); + } + + for (o.extra_headers) |header| { + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } - for (o.extra_headers) |header| { - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); h.appendSliceAssumeCapacity("\r\n"); - } - - h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; return .{ .stream = request.server.connection.stream, @@ -478,16 +497,20 @@ pub const Request = struct { .send_buffer_start = 0, .send_buffer_end = h.items.len, .content_length = options.content_length, - .elide_body = request.head.method == .HEAD, + .elide_body = elide_body, .chunk_len = 0, }; } - pub const ReadError = net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize }; + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; + const remaining_content_length = &request.reader_state.remaining_content_length; if (remaining_content_length.* == 0) { s.state = .ready; @@ -515,6 +538,7 @@ pub const Request = struct { fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; + const cp = &request.reader_state.chunk_parser; const head_end = request.head_end; @@ -599,11 +623,33 @@ pub const Request = struct { return out_end; } - pub fn reader(request: *Request) std.io.AnyReader { + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. + /// + /// Asserts that this function is only called once. + pub fn reader(request: *Request) ReaderError!std.io.AnyReader { const s = request.server; assert(s.state == .received_head); s.state = .receiving_body; s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } + } + switch (request.head.transfer_encoding) { .chunked => { request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; @@ -639,7 +685,8 @@ pub const Request = struct { const s = request.server; if (keep_alive and request.head.keep_alive) switch (s.state) { .received_head => { - _ = request.reader().discard() catch return false; + const r = request.reader() catch return false; + _ = r.discard() catch return false; assert(s.state == .ready); return true; }, diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index abb98f28e1..108d0aba56 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -136,7 +136,7 @@ test "HTTP server handles a chunked transfer coding request" { try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try request.reader().readAll(&buf); + const n = try (try request.reader()).readAll(&buf); try expect(std.mem.eql(u8, buf[0..n], "ABCD")); try request.respond("message from server!\n", .{ @@ -187,13 +187,13 @@ test "echo content server" { const server_thread = try std.Thread.spawn(.{}, (struct { fn handleRequest(request: *std.http.Server.Request) !void { - std.debug.print("server received {s} {s} {s}\n", .{ - @tagName(request.head.method), - @tagName(request.head.version), - request.head.target, - }); + //std.debug.print("server received {s} {s} {s}\n", .{ + // @tagName(request.head.method), + // @tagName(request.head.version), + // request.head.target, + //}); - const body = try request.reader().readAllAlloc(std.testing.allocator, 8192); + const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192); defer std.testing.allocator.free(body); try testing.expect(std.mem.startsWith(u8, request.head.target, "/echo-content")); @@ -217,7 +217,7 @@ test "echo content server" { try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); - std.debug.print(" server finished responding\n", .{}); + //std.debug.print(" server finished responding\n", .{}); } fn run(net_server: *std.net.Server) anyerror!void { @@ -237,6 +237,13 @@ test "echo content server" { if (std.mem.eql(u8, request.head.target, "/end")) { return request.respond("", .{ .keep_alive = false }); } + if (request.head.expect) |expect| { + if (std.mem.eql(u8, expect, "garbage")) { + try testing.expectError(error.HttpExpectationFailed, request.reader()); + try request.respond("", .{ .keep_alive = false }); + continue; + } + } handleRequest(&request) catch |err| { // This message helps the person troubleshooting determine whether // output comes from the server thread or the client thread. diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 5b44a14032..ff6467fc6c 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -26,17 +26,7 @@ fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { request.head.target, }); - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - @panic("test failure, didn't handle expect 100-continue"); - } else { - return request.respond("", .{ - .status = .expectation_failed, - }); - } - } - - const body = try request.reader().readAllAlloc(salloc, 8192); + const body = try (try request.reader()).readAllAlloc(salloc, 8192); defer salloc.free(body); var send_buffer: [100]u8 = undefined;