diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index ac926c4b41..0498c9c297 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -138,6 +138,35 @@ pub const AlertLevel = enum(u8) { }; pub const AlertDescription = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; + close_notify = 0, unexpected_message = 10, bad_record_mac = 20, @@ -166,6 +195,39 @@ pub const AlertDescription = enum(u8) { certificate_required = 116, no_application_protocol = 120, _, + + pub fn toError(alert: AlertDescription) Error!void { + return switch (alert) { + .close_notify => {}, // not an error + .unexpected_message => error.TlsAlertUnexpectedMessage, + .bad_record_mac => error.TlsAlertBadRecordMac, + .record_overflow => error.TlsAlertRecordOverflow, + .handshake_failure => error.TlsAlertHandshakeFailure, + .bad_certificate => error.TlsAlertBadCertificate, + .unsupported_certificate => error.TlsAlertUnsupportedCertificate, + .certificate_revoked => error.TlsAlertCertificateRevoked, + .certificate_expired => error.TlsAlertCertificateExpired, + .certificate_unknown => error.TlsAlertCertificateUnknown, + .illegal_parameter => error.TlsAlertIllegalParameter, + .unknown_ca => error.TlsAlertUnknownCa, + .access_denied => error.TlsAlertAccessDenied, + .decode_error => error.TlsAlertDecodeError, + .decrypt_error => error.TlsAlertDecryptError, + .protocol_version => error.TlsAlertProtocolVersion, + .insufficient_security => error.TlsAlertInsufficientSecurity, + .internal_error => error.TlsAlertInternalError, + .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => error.TlsAlertMissingExtension, + .unsupported_extension => error.TlsAlertUnsupportedExtension, + .unrecognized_name => error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, + .certificate_required => error.TlsAlertCertificateRequired, + .no_application_protocol => error.TlsAlertNoApplicationProtocol, + _ => error.TlsAlertUnknown, + }; + } }; pub const SignatureScheme = enum(u16) { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0d404d29ac..2745bd4e6f 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -89,12 +89,11 @@ pub const StreamInterface = struct { }; pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{ + return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ InsufficientEntropy, DiskQuota, LockViolation, NotOpenForWriting, - TlsAlert, TlsUnexpectedMessage, TlsIllegalParameter, TlsDecryptFailure, @@ -251,8 +250,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const level = ptd.decode(tls.AlertLevel); const desc = ptd.decode(tls.AlertDescription); _ = level; - _ = desc; - return error.TlsAlert; + + // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake + try desc.toError(); + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; }, .handshake => { try ptd.ensure(4); @@ -1071,8 +1073,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) const level = @intToEnum(tls.AlertLevel, frag[in]); const desc = @intToEnum(tls.AlertDescription, frag[in + 1]); _ = level; - _ = desc; - return error.TlsAlert; + + try desc.toError(); + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; }, .application_data => { const cleartext = switch (c.application_cipher) { @@ -1112,7 +1116,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) return vp.total; } _ = level; - return error.TlsAlert; + + try desc.toError(); + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; }, .handshake => { var ct_i: usize = 0; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 023bdd28bc..91b688a25c 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -36,21 +36,7 @@ pub const ConnectionPool = struct { is_tls: bool, }; - pub const StoredConnection = struct { - buffered: BufferedConnection, - host: []u8, - port: u16, - - proxied: bool = false, - closing: bool = false, - - pub fn deinit(self: *StoredConnection, client: *Client) void { - self.buffered.close(client); - client.allocator.free(self.host); - } - }; - - const Queue = std.TailQueue(StoredConnection); + const Queue = std.TailQueue(Connection); pub const Node = Queue.Node; mutex: std.Thread.Mutex = .{}, @@ -69,7 +55,7 @@ pub const ConnectionPool = struct { var next = pool.free.last; while (next) |node| : (next = node.prev) { - if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue; + if ((node.data.protocol == .tls) != criteria.is_tls) continue; if (node.data.port != criteria.port) continue; if (!mem.eql(u8, node.data.host, criteria.host)) continue; @@ -160,45 +146,105 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + pub const Protocol = enum { plain, tls }; + stream: net.Stream, /// undefined unless protocol is tls. tls_client: *std.crypto.tls.Client, + protocol: Protocol, + host: []u8, + port: u16, - pub const Protocol = enum { plain, tls }; + proxied: bool = false, + closing: bool = false, - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.read(buffer), - .tls => conn.tls_client.read(conn.stream, buffer), - } catch |err| switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.TlsAlert => return error.TlsAlert, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } + read_start: u16 = 0, + read_end: u16 = 0, + read_buf: [buffer_size]u8 = undefined, - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { return switch (conn.protocol) { .plain => conn.stream.readAtLeast(buffer, len), .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.TlsAlert => return error.TlsAlert, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, + } catch |err| { + // TODO: https://github.com/ziglang/zig/issues/2473 + if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + + switch (err) { + error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } }; } + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(u16, nread); + } + + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + pub fn drop(conn: *Connection, num: u16) void { + conn.read_start += num; + } + + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + + var out_index: u16 = 0; + while (out_index < len) { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len - out_index; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + out_index += @intCast(u16, available_buffer); + conn.read_start += @intCast(u16, available_buffer); + + break; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + out_index += available_read; + conn.read_start += available_read; + + if (out_index >= len) break; + } + + const leftover_buffer = available_buffer - available_read; + const leftover_len = len - out_index; + + if (leftover_buffer > conn.read_buf.len) { + // skip the buffer if the output is large enough + return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + } + + try conn.fill(); + } + + return out_index; + } + + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); + } + pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure, + EndOfStream, }; pub const Reader = std.io.Reader(*Connection, ReadError, read); @@ -247,111 +293,10 @@ pub const Connection = struct { conn.stream.close(); } -}; -/// A buffered (and peekable) Connection. -pub const BufferedConnection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - - conn: Connection, - read_buf: [buffer_size]u8 = undefined, - read_start: u16 = 0, - read_end: u16 = 0, - - write_buf: [buffer_size]u8 = undefined, - write_end: u16 = 0, - - pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.read_end != bconn.read_start) return; - - const nread = try bconn.conn.read(bconn.read_buf[0..]); - if (nread == 0) return error.EndOfStream; - bconn.read_start = 0; - bconn.read_end = @intCast(u16, nread); - } - - pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.read_buf[bconn.read_start..bconn.read_end]; - } - - pub fn clear(bconn: *BufferedConnection, num: u16) void { - bconn.read_start += num; - } - - pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = bconn.read_end - bconn.read_start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @intCast(u16, @min(available, left)); - - @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); - out_index += can_read; - bconn.read_start += can_read; - - continue; - } - - if (left > bconn.read_buf.len) { - // skip the buffer if the output is large enough - return bconn.conn.read(buffer[out_index..]); - } - - try bconn.fill(); - } - - return out_index; - } - - pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { - return bconn.readAtLeast(buffer, 1); - } - - pub const ReadError = Connection.ReadError || error{EndOfStream}; - pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); - - pub fn reader(bconn: *BufferedConnection) Reader { - return Reader{ .context = bconn }; - } - - pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - } else { - try bconn.flush(); - try bconn.conn.writeAll(buffer); - } - } - - pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - - return buffer.len; - } else { - try bconn.flush(); - return try bconn.conn.write(buffer); - } - } - - pub fn flush(bconn: *BufferedConnection) WriteError!void { - defer bconn.write_end = 0; - return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - } - - pub const WriteError = Connection.WriteError; - pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); - - pub fn writer(bconn: *BufferedConnection) Writer { - return Writer{ .context = bconn }; - } - - pub fn close(bconn: *BufferedConnection, client: *const Client) void { - bconn.conn.close(client); + pub fn deinit(conn: *Connection, client: *const Client) void { + conn.close(client); + client.allocator.free(conn.host); } }; @@ -585,11 +530,12 @@ pub const Request = struct { }; } - pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; /// Send the request to the server. pub fn start(req: *Request) StartError!void { - const w = req.connection.data.buffered.writer(); + var buffered = std.io.bufferedWriter(req.connection.data.writer()); + const w = buffered.writer(); try w.writeAll(@tagName(req.method)); try w.writeByte(' '); @@ -663,10 +609,10 @@ pub const Request = struct { try w.writeAll("\r\n"); - try req.connection.data.buffered.flush(); + try buffered.flush(); } - pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; + pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); @@ -679,7 +625,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip); + const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip); if (amt == 0 and req.response.parser.done) break; index += amt; } @@ -697,10 +643,10 @@ pub const Request = struct { pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers - try req.connection.data.buffered.fill(); + try req.connection.data.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); - req.connection.data.buffered.clear(@intCast(u16, nchecked)); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek()); + req.connection.data.drop(@intCast(u16, nchecked)); if (req.response.parser.state.isContent()) break; } @@ -816,10 +762,10 @@ pub const Request = struct { const has_trail = !req.response.parser.state.isContent(); while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.data.buffered.fill(); + try req.connection.data.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); - req.connection.data.buffered.clear(@intCast(u16, nchecked)); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek()); + req.connection.data.drop(@intCast(u16, nchecked)); } if (has_trail) { @@ -845,7 +791,7 @@ pub const Request = struct { return index; } - pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; pub const Writer = std.io.Writer(*Request, WriteError, write); @@ -857,16 +803,16 @@ pub const Request = struct { pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { - try req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.data.buffered.writeAll(bytes); - try req.connection.data.buffered.writeAll("\r\n"); + try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.data.writeAll(bytes); + try req.connection.data.writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.data.buffered.write(bytes); + const amt = try req.connection.data.write(bytes); len.* -= amt; return amt; }, @@ -886,12 +832,10 @@ pub const Request = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => try req.connection.data.buffered.writeAll("0\r\n\r\n"), + .chunked => try req.connection.data.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } - - try req.connection.data.buffered.flush(); } }; @@ -948,11 +892,10 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: errdefer stream.close(); conn.data = .{ - .buffered = .{ .conn = .{ - .stream = stream, - .tls_client = undefined, - .protocol = protocol, - } }, + .stream = stream, + .tls_client = undefined, + .protocol = protocol, + .host = try client.allocator.dupe(u8, host), .port = port, }; @@ -961,13 +904,13 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: switch (protocol) { .plain => {}, .tls => { - conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client); + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. - conn.data.buffered.conn.tls_client.allow_truncation_attacks = true; + conn.data.tls_client.allow_truncation_attacks = true; }, } @@ -1003,7 +946,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio } } -pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{ +pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 6b5db6725f..67641eab00 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -16,39 +16,92 @@ socket: net.StreamServer, /// An interface to either a plain or TLS connection. pub const Connection = struct { + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + pub const Protocol = enum { plain }; + stream: net.Stream, protocol: Protocol, closing: bool = true, - pub const Protocol = enum { plain }; + read_buf: [buffer_size]u8 = undefined, + read_start: u16 = 0, + read_end: u16 = 0, - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { return switch (conn.protocol) { - .plain => conn.stream.read(buffer), - // .tls => return conn.tls_client.read(conn.stream, buffer), - } catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, + .plain => conn.stream.readAtLeast(buffer, len), + // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), + } catch |err| { + switch (err) { + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + } }; } + pub fn fill(conn: *Connection) ReadError!void { + if (conn.read_end != conn.read_start) return; + + const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + if (nread == 0) return error.EndOfStream; + conn.read_start = 0; + conn.read_end = @intCast(u16, nread); + } + + pub fn peek(conn: *Connection) []const u8 { + return conn.read_buf[conn.read_start..conn.read_end]; + } + + pub fn drop(conn: *Connection, num: u16) void { + conn.read_start += num; + } + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; + assert(len <= buffer.len); + + var out_index: u16 = 0; + while (out_index < len) { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len - out_index; + + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + out_index += @intCast(u16, available_buffer); + conn.read_start += @intCast(u16, available_buffer); + + break; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + out_index += available_read; + conn.read_start += available_read; + + if (out_index >= len) break; + } + + const leftover_buffer = available_buffer - available_read; + const leftover_len = len - out_index; + + if (leftover_buffer > conn.read_buf.len) { + // skip the buffer if the output is large enough + return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + } + + try conn.fill(); + } + + return out_index; + } + + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); } pub const ReadError = error{ ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure, + EndOfStream, }; pub const Reader = std.io.Reader(*Connection, ReadError, read); @@ -93,112 +146,6 @@ pub const Connection = struct { } }; -/// A buffered (and peekable) Connection. -pub const BufferedConnection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - - conn: Connection, - read_buf: [buffer_size]u8 = undefined, - read_start: u16 = 0, - read_end: u16 = 0, - - write_buf: [buffer_size]u8 = undefined, - write_end: u16 = 0, - - pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.read_end != bconn.read_start) return; - - const nread = try bconn.conn.read(bconn.read_buf[0..]); - if (nread == 0) return error.EndOfStream; - bconn.read_start = 0; - bconn.read_end = @intCast(u16, nread); - } - - pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.read_buf[bconn.read_start..bconn.read_end]; - } - - pub fn clear(bconn: *BufferedConnection, num: u16) void { - bconn.read_start += num; - } - - pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = bconn.read_end - bconn.read_start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @intCast(u16, @min(available, left)); - - @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); - out_index += can_read; - bconn.read_start += can_read; - - continue; - } - - if (left > bconn.read_buf.len) { - // skip the buffer if the output is large enough - return bconn.conn.read(buffer[out_index..]); - } - - try bconn.fill(); - } - - return out_index; - } - - pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { - return bconn.readAtLeast(buffer, 1); - } - - pub const ReadError = Connection.ReadError || error{EndOfStream}; - pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); - - pub fn reader(bconn: *BufferedConnection) Reader { - return Reader{ .context = bconn }; - } - - pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - } else { - try bconn.flush(); - try bconn.conn.writeAll(buffer); - } - } - - pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - if (bconn.write_buf.len - bconn.write_end >= buffer.len) { - @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); - bconn.write_end += @intCast(u16, buffer.len); - - return buffer.len; - } else { - try bconn.flush(); - return try bconn.conn.write(buffer); - } - } - - pub fn flush(bconn: *BufferedConnection) WriteError!void { - defer bconn.write_end = 0; - return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); - } - - pub const WriteError = Connection.WriteError; - pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); - - pub fn writer(bconn: *BufferedConnection) Writer { - return Writer{ .context = bconn }; - } - - pub fn close(bconn: *BufferedConnection) void { - bconn.conn.close(); - } -}; - /// The mode of transport for responses. pub const ResponseTransfer = union(enum) { content_length: u64, @@ -351,7 +298,7 @@ pub const Response = struct { allocator: Allocator, address: net.Address, - connection: BufferedConnection, + connection: Connection, headers: http.Headers, request: Request, @@ -388,7 +335,7 @@ pub const Response = struct { if (!res.request.parser.done) { // If the response wasn't fully read, then we need to close the connection. - res.connection.conn.closing = true; + res.connection.closing = true; return .closing; } @@ -402,9 +349,9 @@ pub const Response = struct { const req_connection = res.request.headers.getFirstValue("connection"); const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); if (req_keepalive and (res_keepalive or res_connection == null)) { - res.connection.conn.closing = false; + res.connection.closing = false; } else { - res.connection.conn.closing = true; + res.connection.closing = true; } switch (res.request.compression) { @@ -434,14 +381,14 @@ pub const Response = struct { .parser = res.request.parser, }; - if (res.connection.conn.closing) { + if (res.connection.closing) { return .closing; } else { return .reset; } } - pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; /// Send the response headers. pub fn do(res: *Response) !void { @@ -450,7 +397,8 @@ pub const Response = struct { .first, .start, .responded, .finished => unreachable, } - const w = res.connection.writer(); + var buffered = std.io.bufferedWriter(res.connection.writer()); + const w = buffered.writer(); try w.writeAll(@tagName(res.version)); try w.writeByte(' '); @@ -508,10 +456,10 @@ pub const Response = struct { try w.writeAll("\r\n"); - try res.connection.flush(); + try buffered.flush(); } - pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; + pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); @@ -532,7 +480,7 @@ pub const Response = struct { return index; } - pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; + pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; /// Wait for the client to send a complete request head. pub fn wait(res: *Response) WaitError!void { @@ -545,7 +493,7 @@ pub const Response = struct { try res.connection.fill(); const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.clear(@intCast(u16, nchecked)); + res.connection.drop(@intCast(u16, nchecked)); if (res.request.parser.state.isContent()) break; } @@ -612,7 +560,7 @@ pub const Response = struct { try res.connection.fill(); const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.clear(@intCast(u16, nchecked)); + res.connection.drop(@intCast(u16, nchecked)); } if (has_trail) { @@ -637,7 +585,7 @@ pub const Response = struct { return index; } - pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; + pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; pub const Writer = std.io.Writer(*Response, WriteError, write); @@ -692,8 +640,6 @@ pub const Response = struct { .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } - - try res.connection.flush(); } }; @@ -742,10 +688,10 @@ pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { return Response{ .allocator = options.allocator, .address = in.address, - .connection = .{ .conn = .{ + .connection = .{ .stream = in.stream, .protocol = .plain, - } }, + }, .headers = .{ .allocator = options.allocator }, .request = .{ .version = undefined, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index c6bdd76272..b001b3cddf 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -513,8 +513,8 @@ pub const HeadersParser = struct { /// /// If `skip` is true, the buffer will be unused and the body will be skipped. /// - /// See `std.http.Client.BufferedConnection for an example of `bconn`. - pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize { + /// See `std.http.Client.BufferedConnection for an example of `conn`. + pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { assert(r.state.isContent()); if (r.done) return 0; @@ -526,10 +526,10 @@ pub const HeadersParser = struct { const data_avail = r.next_chunk_length; if (skip) { - try bconn.fill(); + try conn.fill(); - const nread = @min(bconn.peek().len, data_avail); - bconn.clear(@intCast(u16, nread)); + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(u16, nread)); r.next_chunk_length -= nread; if (r.next_chunk_length == 0) r.done = true; @@ -539,7 +539,7 @@ pub const HeadersParser = struct { const out_avail = buffer.len; const can_read = @intCast(usize, @min(data_avail, out_avail)); - const nread = try bconn.read(buffer[0..can_read]); + const nread = try conn.read(buffer[0..can_read]); r.next_chunk_length -= nread; if (r.next_chunk_length == 0) r.done = true; @@ -548,15 +548,15 @@ pub const HeadersParser = struct { } }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - try bconn.fill(); + try conn.fill(); - const i = r.findChunkedLen(bconn.peek()); - bconn.clear(@intCast(u16, i)); + const i = r.findChunkedLen(conn.peek()); + conn.drop(@intCast(u16, i)); switch (r.state) { .invalid => return error.HttpChunkInvalid, .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, bconn.peek(), "\r\n")) { + if (std.mem.eql(u8, conn.peek(), "\r\n")) { r.state = .finished; } else { // The trailer section is formatted identically to the header section. @@ -576,14 +576,14 @@ pub const HeadersParser = struct { const out_avail = buffer.len - out_index; if (skip) { - try bconn.fill(); + try conn.fill(); - const nread = @min(bconn.peek().len, data_avail); - bconn.clear(@intCast(u16, nread)); + const nread = @min(conn.peek().len, data_avail); + conn.drop(@intCast(u16, nread)); r.next_chunk_length -= nread; } else { const can_read = @intCast(usize, @min(data_avail, out_avail)); - const nread = try bconn.read(buffer[out_index..][0..can_read]); + const nread = try conn.read(buffer[out_index..][0..can_read]); r.next_chunk_length -= nread; out_index += nread; } @@ -628,74 +628,74 @@ const MockBufferedConnection = struct { start: u16 = 0, end: u16 = 0, - pub fn fill(bconn: *MockBufferedConnection) ReadError!void { - if (bconn.end != bconn.start) return; + pub fn fill(conn: *MockBufferedConnection) ReadError!void { + if (conn.end != conn.start) return; - const nread = try bconn.conn.read(bconn.buf[0..]); + const nread = try conn.conn.read(conn.buf[0..]); if (nread == 0) return error.EndOfStream; - bconn.start = 0; - bconn.end = @truncate(u16, nread); + conn.start = 0; + conn.end = @truncate(u16, nread); } - pub fn peek(bconn: *MockBufferedConnection) []const u8 { - return bconn.buf[bconn.start..bconn.end]; + pub fn peek(conn: *MockBufferedConnection) []const u8 { + return conn.buf[conn.start..conn.end]; } - pub fn clear(bconn: *MockBufferedConnection, num: u16) void { - bconn.start += num; + pub fn drop(conn: *MockBufferedConnection, num: u16) void { + conn.start += num; } - pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { + pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { var out_index: u16 = 0; while (out_index < len) { - const available = bconn.end - bconn.start; + const available = conn.end - conn.start; const left = buffer.len - out_index; if (available > 0) { const can_read = @truncate(u16, @min(available, left)); - @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); + @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); out_index += can_read; - bconn.start += can_read; + conn.start += can_read; continue; } - if (left > bconn.buf.len) { + if (left > conn.buf.len) { // skip the buffer if the output is large enough - return bconn.conn.read(buffer[out_index..]); + return conn.conn.read(buffer[out_index..]); } - try bconn.fill(); + try conn.fill(); } return out_index; } - pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return bconn.readAtLeast(buffer, 1); + pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { + return conn.readAtLeast(buffer, 1); } pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); - pub fn reader(bconn: *MockBufferedConnection) Reader { - return Reader{ .context = bconn }; + pub fn reader(conn: *MockBufferedConnection) Reader { + return Reader{ .context = conn }; } - pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return bconn.conn.writeAll(buffer); + pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { + return conn.conn.writeAll(buffer); } - pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return bconn.conn.write(buffer); + pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { + return conn.conn.write(buffer); } pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); - pub fn writer(bconn: *MockBufferedConnection) Writer { - return Writer{ .context = bconn }; + pub fn writer(conn: *MockBufferedConnection) Writer { + return Writer{ .context = conn }; } }; @@ -753,15 +753,15 @@ test "HeadersParser.read length" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; var fbs = std.io.fixedBufferStream(data); - var bconn = MockBufferedConnection{ + var conn = MockBufferedConnection{ .conn = fbs, }; while (true) { // read headers - try bconn.fill(); + try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); - bconn.clear(@intCast(u16, nchecked)); + const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); + conn.drop(@intCast(u16, nchecked)); if (r.state.isContent()) break; } @@ -769,7 +769,7 @@ test "HeadersParser.read length" { var buf: [8]u8 = undefined; r.next_chunk_length = 5; - const len = try r.read(&bconn, &buf, false); + const len = try r.read(&conn, &buf, false); try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqualStrings("Hello", buf[0..len]); @@ -784,22 +784,22 @@ test "HeadersParser.read chunked" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; var fbs = std.io.fixedBufferStream(data); - var bconn = MockBufferedConnection{ + var conn = MockBufferedConnection{ .conn = fbs, }; while (true) { // read headers - try bconn.fill(); + try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); - bconn.clear(@intCast(u16, nchecked)); + const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); + conn.drop(@intCast(u16, nchecked)); if (r.state.isContent()) break; } var buf: [8]u8 = undefined; r.state = .chunk_head_size; - const len = try r.read(&bconn, &buf, false); + const len = try r.read(&conn, &buf, false); try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqualStrings("Hello", buf[0..len]); @@ -814,30 +814,30 @@ test "HeadersParser.read chunked trailer" { const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; var fbs = std.io.fixedBufferStream(data); - var bconn = MockBufferedConnection{ + var conn = MockBufferedConnection{ .conn = fbs, }; while (true) { // read headers - try bconn.fill(); + try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); - bconn.clear(@intCast(u16, nchecked)); + const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); + conn.drop(@intCast(u16, nchecked)); if (r.state.isContent()) break; } var buf: [8]u8 = undefined; r.state = .chunk_head_size; - const len = try r.read(&bconn, &buf, false); + const len = try r.read(&conn, &buf, false); try std.testing.expectEqual(@as(usize, 5), len); try std.testing.expectEqualStrings("Hello", buf[0..len]); while (true) { // read headers - try bconn.fill(); + try conn.fill(); - const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); - bconn.clear(@intCast(u16, nchecked)); + const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek()); + conn.drop(@intCast(u16, nchecked)); if (r.state.isContent()) break; } diff --git a/test/standalone/http.zig b/test/standalone/http.zig index 13dc278b6d..ffb7a59276 100644 --- a/test/standalone/http.zig +++ b/test/standalone/http.zig @@ -86,7 +86,6 @@ fn handleRequest(res: *Server.Response) !void { try res.writeAll("World!\n"); // try res.finish(); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); - try res.connection.flush(); } else if (mem.eql(u8, res.request.target, "/redirect/1")) { res.transfer_encoding = .chunked;