diff --git a/lib/std/http.zig b/lib/std/http.zig index 7c2a2da605..d4cc259f19 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -253,6 +253,16 @@ pub const TransferEncoding = enum { gzip, }; +pub const Connection = enum { + keep_alive, + close, +}; + +pub const CustomHeader = struct { + name: []const u8, + value: []const u8, +}; + const std = @import("std.zig"); test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index d4d8f85ad1..cac6571798 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -21,27 +21,51 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, -connection_pool: std.TailQueue(Connection) = .{}, +connection_mutex: std.Thread.Mutex = .{}, +connection_pool: ConnectionPool = .{}, +connection_used: ConnectionPool = .{}, const ConnectionPool = std.TailQueue(Connection); const ConnectionNode = ConnectionPool.Node; -pub fn release(client: *Client, node: *ConnectionNode) void { - if (node.data.unusable) return node.data.close(client); +/// Acquires an existing connection from the connection pool. This function is threadsafe. +pub fn acquire(client: *Client, node: *ConnectionNode) void { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + client.connection_pool.remove(node); + client.connection_used.append(node); +} + +/// Tries to release a connection back to the connection pool. This function is threadsafe. +/// If the connection is marked as closing, it will be closed instead. +pub fn release(client: *Client, node: *ConnectionNode) void { + if (node.data.closing) { + node.data.close(client); + + return client.allocator.destroy(node); + } + + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.remove(node); client.connection_pool.append(node); } +const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); +const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); + pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. + tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. protocol: Protocol, host: []u8, port: u16, // This connection has been part of a non keepalive request and cannot be added to the pool. - unusable: bool = false, + closing: bool = false, pub const Protocol = enum { plain, tls }; @@ -59,6 +83,24 @@ pub const Connection = struct { } } + pub const ReadError = std.net.Stream.ReadError || error{ + TlsConnectionTruncated, + TlsRecordOverflow, + TlsDecodeError, + TlsAlert, + TlsBadRecordMac, + Overflow, + TlsBadLength, + TlsIllegalParameter, + TlsUnexpectedMessage, + }; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + pub fn writeAll(conn: *Connection, buffer: []const u8) !void { switch (conn.protocol) { .plain => return conn.stream.writeAll(buffer), @@ -73,10 +115,18 @@ pub const Connection = struct { } } + pub const WriteError = std.net.Stream.WriteError || error{}; + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + pub fn close(conn: *Connection, client: *const Client) void { if (conn.protocol == .tls) { // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + client.allocator.destroy(conn.tls_client); } conn.stream.close(); @@ -85,10 +135,10 @@ pub const Connection = struct { } }; -/// TODO: emit error.UnexpectedEndOfStream or something like that when the read -/// data does not match the content length. This is necessary since HTTPS disables -/// close_notify protection on underlying TLS streams. pub const Request = struct { + const read_buffer_size = 8192; + const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); + client: *Client, connection: *ConnectionNode, redirects_left: u32, @@ -97,6 +147,11 @@ pub const Request = struct { /// redirects. headers: Headers, + /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. + read_buffer: [read_buffer_size]u8 = undefined, + read_buffer_start: ReadBufferIndex = 0, + read_buffer_len: ReadBufferIndex = 0, + pub const Response = struct { headers: Response.Headers, state: State, @@ -106,15 +161,24 @@ pub const Request = struct { header_bytes: std.ArrayListUnmanaged(u8), max_header_bytes: usize, next_chunk_length: u64, - done: bool, + done: bool = false, + + compression: union(enum) { + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + none: void, + } = .none, pub const Headers = struct { status: http.Status, version: http.Version, location: ?[]const u8 = null, content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - connection_close: bool = true, + transfer_encoding: ?http.TransferEncoding = null, // This should only ever be chunked, compression is handled separately. + transfer_compression: ?http.TransferEncoding = null, + connection: http.Connection = .close, + + number_of_headers: usize = 0, pub fn parse(bytes: []const u8) !Response.Headers { var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n"); @@ -137,6 +201,8 @@ pub const Request = struct { }; while (it.next()) |line| { + headers.number_of_headers += 1; + if (line.len == 0) return error.HttpHeadersInvalid; switch (line[0]) { ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, @@ -152,14 +218,65 @@ pub const Request = struct { if (headers.content_length != null) return error.HttpHeadersInvalid; headers.content_length = try std.fmt.parseInt(u64, header_value, 10); } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse + if (headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid; + + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = std.mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, first, " "), + ) orelse + return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => headers.transfer_encoding = .chunked, + .compress => headers.transfer_compression = .compress, + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } + } + + if (iter.next()) |second| { + if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, second, " "), + ) orelse + return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => return error.HttpHeadersInvalid, // chunked must come last + .compress => return error.HttpTransferEncodingUnsupported, // compress not supported + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + + const kind = std.meta.stringToEnum( + http.TransferEncoding, + std.mem.trim(u8, header_value, " "), + ) orelse return error.HttpTransferEncodingUnsupported; + + switch (kind) { + .chunked => return error.HttpHeadersInvalid, // not transfer encoding + .compress => return error.HttpTransferEncodingUnsupported, // compress not supported + .deflate => headers.transfer_compression = .deflate, + .gzip => headers.transfer_compression = .gzip, + } } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection_close = false; + headers.connection = .keep_alive; } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection_close = true; + headers.connection = .close; } else { return error.HttpConnectionHeaderUnsupported; } @@ -238,7 +355,6 @@ pub const Request = struct { .max_header_bytes = max, .header_bytes_owned = true, .next_chunk_length = undefined, - .done = false, }; } @@ -250,7 +366,6 @@ pub const Request = struct { .max_header_bytes = buf.len, .header_bytes_owned = false, .next_chunk_length = undefined, - .done = false, }; } @@ -537,10 +652,19 @@ pub const Request = struct { } }; + pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, + }; + pub const Headers = struct { version: http.Version = .@"HTTP/1.1", method: http.Method = .GET, - connection_close: bool = false, + connection: http.Connection = .keep_alive, + transfer_encoding: RequestTransfer = .none, + + custom: []const http.CustomHeader = &[_]http.CustomHeader{}, }; pub const Options = struct { @@ -561,167 +685,131 @@ pub const Request = struct { }; }; - /// May be skipped if header strategy is buffer. + /// Frees all resources associated with the request. pub fn deinit(req: *Request) void { + switch (req.response.compression) { + .none => {}, + .deflate => |*deflate| deflate.deinit(), + .gzip => |*gzip| gzip.deinit(), + } + if (req.response.header_bytes_owned) { req.response.header_bytes.deinit(req.client.allocator); } + + if (!req.response.done) { + // If the response wasn't fully read, then we need to close the connection. + req.connection.data.closing = true; + req.client.release(req.connection); + } + req.* = undefined; } - pub const Reader = std.io.Reader(*Request, ReadError, read); - - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - pub fn readAll(req: *Request, buffer: []u8) !usize { - return readAtLeast(req, buffer, buffer.len); - } - - pub const ReadError = net.Stream.ReadError || error{ - // From HTTP protocol - HttpHeadersInvalid, - HttpHeadersExceededSizeLimit, - HttpRedirectMissingLocation, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - HttpContentLengthUnknown, + const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{ + UnexpectedEndOfStream, TooManyHttpRedirects, - ShortHttpStatusLine, - BadHttpVersion, - HttpHeaderContinuationsUnsupported, - UnsupportedUrlScheme, - UriMissingHost, - UnknownHostName, - - // Network problems - NetworkUnreachable, - HostLacksNetworkAddresses, - TemporaryNameServerFailure, - NameServerFailure, - ProtocolFamilyNotAvailable, - ProtocolNotSupported, - - // System resource problems - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - OutOfMemory, - - // TLS problems - InsufficientEntropy, - TlsConnectionTruncated, - TlsRecordOverflow, - TlsDecodeError, - TlsAlert, - TlsBadRecordMac, - TlsBadLength, - TlsIllegalParameter, - TlsUnexpectedMessage, - TlsDecryptFailure, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - TlsDecryptError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - CertificateAuthorityBundleTooBig, - - // TODO: convert to higher level errors - InvalidFormat, - InvalidPort, - UnexpectedCharacter, - Overflow, - InvalidCharacter, - AddressFamilyNotSupported, - AddressInUse, - AddressNotAvailable, - ConnectionPending, - ConnectionRefused, - FileNotFound, - PermissionDenied, - ServiceUnavailable, - SocketTypeNotSupported, - FileTooBig, - LockViolation, - NoSpaceLeft, - NotOpenForWriting, - InvalidEncoding, - IdentityElement, - NonCanonical, - SignatureVerificationFailed, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - DiskQuota, - InvalidEnd, - Incomplete, - InvalidIpv4Mapping, - InvalidIPAddressFormat, - BadPathName, - DeviceBusy, - FileBusy, - FileLocksNotSupported, - InvalidHandle, - InvalidUtf8, - NameTooLong, - NoDevice, - PathAlreadyExists, - PipeBusy, - SharingViolation, - SymLinkLoop, - FileSystem, - InterfaceNotFound, - AlreadyBound, - FileDescriptorNotASocket, - NetworkSubsystemFailed, - NotDir, - ReadOnlyFileSystem, - Unseekable, - MissingEndCertificateMarker, - InvalidPadding, - EndOfStream, - InvalidArgument, + HttpRedirectMissingLocation, + HttpHeadersInvalid, }; - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - return readAtLeast(req, buffer, 1); - } + const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw); + + /// Read from the underlying stream, without decompressing or parsing the headers. Must be called + /// after waitForCompleteHead() has returned successfully. + pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize { + assert(req.response.state.isContent()); - pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { - assert(len <= buffer.len); var index: usize = 0; - while (index < len) { - const amt = try readAdvanced(req, buffer[index..]); + while (index == 0) { + const amt = try req.readRawAdvanced(buffer[index..]); const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect; if (amt == 0 and zero_means_end) break; index += amt; } + return index; } + fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { + switch (req.response.state) { + .invalid => unreachable, + .start, .seen_r, .seen_rn, .seen_rnr => {}, + else => return 0, // No more headers to read. + } + + const i = req.response.findHeadersEnd(buffer[0..]); + if (req.response.state == .invalid) return error.HttpHeadersInvalid; + + const headers_data = buffer[0..i]; + if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { + return error.HttpHeadersExceededSizeLimit; + } + try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); + + if (req.response.state == .finished) { + req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); + + if (req.response.headers.connection == .keep_alive) { + req.connection.data.closing = false; + } else { + req.connection.data.closing = true; + } + + if (req.response.headers.transfer_encoding) |transfer_encoding| { + switch (transfer_encoding) { + .chunked => { + req.response.next_chunk_length = 0; + req.response.state = .chunk_size; + }, + .compress => unreachable, + .deflate => unreachable, + .gzip => unreachable, + } + } else if (req.response.headers.content_length) |content_length| { + req.response.next_chunk_length = content_length; + } else { + req.response.done = true; + } + + return i; + } + + return 0; + } + + pub const WaitForCompleteHeadError = ReadRawError || error { + UnexpectedEndOfStream, + + HttpHeadersExceededSizeLimit, + ShortHttpStatusLine, + BadHttpVersion, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + }; + + /// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent. + pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void { + if (req.response.state.isContent()) return; + + while (true) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]); + + if (amt != 0) { + req.read_buffer_start = @intCast(ReadBufferIndex, amt); + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + return; + } else if (nread == 0) { + return error.UnexpectedEndOfStream; + } + } + } + /// This one can return 0 without meaning EOF. - /// TODO change to readvAdvanced - pub fn readAdvanced(req: *Request, buffer: []u8) !usize { + fn readRawAdvanced(req: *Request, buffer: []u8) !usize { if (req.response.done) { if (req.response.headers.status.class() == .redirect) { if (req.redirects_left == 0) return error.TooManyHttpRedirects; @@ -744,82 +832,56 @@ pub const Request = struct { } } - var in = buffer[0..try req.connection.data.read(buffer)]; + // var in: []const u8 = undefined; + if (req.read_buffer_start == req.read_buffer_len) { + const nread = try req.connection.data.read(req.read_buffer[0..]); + if (nread == 0) return error.UnexpectedEndOfStream; + + req.read_buffer_start = 0; + req.read_buffer_len = @intCast(ReadBufferIndex, nread); + } + var out_index: usize = 0; while (true) { switch (req.response.state) { - .invalid => unreachable, - .start, .seen_r, .seen_rn, .seen_rnr => { - const i = req.response.findHeadersEnd(in); - if (req.response.state == .invalid) return error.HttpHeadersInvalid; - - const headers_data = in[0..i]; - if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } - try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); - - if (req.response.state == .finished) { - req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - - if (req.response.headers.connection_close == true) { - req.connection.data.unusable = true; - } else { - req.connection.data.unusable = false; - } - - if (req.response.headers.transfer_encoding) |transfer_encoding| { - switch (transfer_encoding) { - .chunked => { - req.response.next_chunk_length = 0; - req.response.state = .chunk_size; - }, - .compress => return error.HttpTransferEncodingUnsupported, - .deflate => return error.HttpTransferEncodingUnsupported, - .gzip => return error.HttpTransferEncodingUnsupported, - } - } else if (req.response.headers.content_length) |content_length| { - req.response.next_chunk_length = content_length; - } else { - return error.HttpContentLengthUnknown; - } - - in = in[i..]; - continue; - } - - assert(out_index == 0); - return 0; - }, + .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable, .finished => { - const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); - req.response.next_chunk_length -= sub_amt; + // TODO https://github.com/ziglang/zig/issues/14039 + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len; + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + continue; + } + + return 0; // skip over as much data as possible + } + + const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start += @intCast(ReadBufferIndex, can_read); if (req.response.next_chunk_length == 0) { req.client.release(req.connection); req.connection = undefined; - req.response.done = true; - assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary. - - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; - - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - return out_index + sub_amt; } - if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0; - - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - return out_index + sub_amt; - } + return can_read; }, - .chunk_size_prefix_r => switch (in.len) { + .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) { 0 => return out_index, - 1 => switch (in[0]) { + 1 => switch (req.read_buffer[req.read_buffer_start]) { '\r' => { req.response.state = .chunk_size_prefix_n; return out_index; @@ -829,9 +891,9 @@ pub const Request = struct { return error.HttpHeadersInvalid; }, }, - else => switch (int16(in[0..2])) { + else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) { int16("\r\n") => { - in = in[2..]; + req.read_buffer_start += 2; req.response.state = .chunk_size; continue; }, @@ -841,11 +903,11 @@ pub const Request = struct { }, }, }, - .chunk_size_prefix_n => switch (in.len) { + .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) { 0 => return out_index, - else => switch (in[0]) { + else => switch (req.read_buffer[req.read_buffer_start]) { '\n' => { - in = in[1..]; + req.read_buffer_start += 1; req.response.state = .chunk_size; continue; }, @@ -856,7 +918,7 @@ pub const Request = struct { }, }, .chunk_size, .chunk_r => { - const i = req.response.findChunkedLen(in); + const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]); switch (req.response.state) { .invalid => return error.HttpHeadersInvalid, .chunk_data => { @@ -867,7 +929,8 @@ pub const Request = struct { return out_index; } - in = in[i..]; + + req.read_buffer_start += @intCast(ReadBufferIndex, i); continue; }, .chunk_size => return out_index, @@ -876,34 +939,129 @@ pub const Request = struct { }, .chunk_data => { // TODO https://github.com/ziglang/zig/issues/14039 - const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); - req.response.next_chunk_length -= sub_amt; + const buf_avail = req.read_buffer_len - req.read_buffer_start; + const data_avail = req.response.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) { + const can_read = @intCast(usize, @min(buf_avail, data_avail)); + req.response.next_chunk_length -= can_read; + + if (req.response.next_chunk_length == 0) { + req.client.release(req.connection); + req.connection = undefined; + req.response.done = true; + continue; + } + + return 0; // skip over as much data as possible + } + + const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); + req.response.next_chunk_length -= can_read; + + mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]); + req.read_buffer_start += @intCast(ReadBufferIndex, can_read); + out_index += can_read; if (req.response.next_chunk_length == 0) { req.response.state = .chunk_size_prefix_r; - in = in[sub_amt..]; - if (req.response.headers.status.class() == .redirect) continue; - - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; continue; } - if (req.response.headers.status.class() == .redirect) return 0; - - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - return out_index; - } + return out_index; }, } } } + pub const ReadError = DeflateDecompressor.Error || GzipDecompressor.Error || WaitForCompleteHeadError || error{ + BadHeader, + InvalidCompression, + StreamTooLong, + InvalidWindowSize, + }; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + pub fn read(req: *Request, buffer: []u8) ReadError!usize { + if (!req.response.state.isContent()) try req.waitForCompleteHead(); + + if (req.response.compression == .none and req.response.state.isContent()) { + if (req.response.headers.transfer_compression) |compression| { + switch (compression) { + .compress => unreachable, + .deflate => req.response.compression = .{ + .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }), + }, + .gzip => req.response.compression = .{ + .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }), + }, + .chunked => unreachable, + } + } + } + + return switch (req.response.compression) { + .deflate => |*deflate| try deflate.read(buffer), + .gzip => |*gzip| try gzip.read(buffer), + else => try req.readRaw(buffer), + }; + } + + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(req, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = Connection.WriteError || error{MessageTooLong}; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + pub fn write(req: *Request, bytes: []const u8) !usize { + switch (req.headers.transfer_encoding) { + .chunked => { + try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.data.writeAll(bytes); + try req.connection.data.writeAll("\r\n"); + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try req.connection.data.write(bytes); + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Finish the body of a request. This notifies the server that you have no more data to send. + pub fn finish(req: *Request) !void { + switch (req.headers.transfer_encoding) { + .chunked => try req.connection.data.writeAll("0\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } + } + inline fn int16(array: *const [2]u8) u16 { return @bitCast(u16, array.*); } @@ -917,6 +1075,10 @@ pub const Request = struct { } test { + const builtin = @import("builtin"); + + if (builtin.os.tag == .wasi) return error.SkipZigTest; + _ = Response; } }; @@ -931,23 +1093,39 @@ pub fn deinit(client: *Client) void { client.allocator.destroy(node); } + next = client.connection_used.first; + while (next) |node| { + next = node.next; + + node.data.close(client); + + client.allocator.destroy(node); + } + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode { - var potential = client.connection_pool.last; - while (potential) |node| { - const same_host = mem.eql(u8, node.data.host, host); - const same_port = node.data.port == port; - const same_protocol = node.data.protocol == protocol; +pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream); - if (same_host and same_port and same_protocol) { - client.connection_pool.remove(node); - return node; +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode { + { // Search through the connection pool for a potential connection. + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + var potential = client.connection_pool.last; + while (potential) |node| { + const same_host = mem.eql(u8, node.data.host, host); + const same_port = node.data.port == port; + const same_protocol = node.data.protocol == protocol; + + if (same_host and same_port and same_protocol) { + client.acquire(node); + return node; + } + + potential = node.prev; } - - potential = node.prev; } const conn = try client.allocator.create(ConnectionNode); @@ -964,17 +1142,35 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio switch (protocol) { .plain => {}, .tls => { - conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. conn.data.tls_client.allow_truncation_attacks = true; }, } + { + client.connection_mutex.lock(); + defer client.connection_mutex.unlock(); + + client.connection_used.append(conn); + } + return conn; } -pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) !Request { +pub const RequestError = ConnectError || Connection.WriteError || error{ + UnsupportedUrlScheme, + UriMissingHost, + + CertificateAuthorityBundleTooBig, + InvalidPadding, + MissingEndCertificateMarker, + Unseekable, +}; + +pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http")) .plain else if (mem.eql(u8, uri.scheme, "https")) @@ -990,8 +1186,13 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; if (client.next_https_rescan_certs and protocol == .tls) { - try client.ca_bundle.rescan(client.allocator); - client.next_https_rescan_certs = false; + client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. + defer client.connection_mutex.unlock(); + + if (client.next_https_rescan_certs) { + try client.ca_bundle.rescan(client.allocator); + client.next_https_rescan_certs = false; + } } var req: Request = .{ @@ -1006,23 +1207,39 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req }; { - var h = try std.BoundedArray(u8, 1000).init(0); - try h.appendSlice(@tagName(headers.method)); - try h.appendSlice(" "); - try h.appendSlice(uri.path); - try h.appendSlice(" "); - try h.appendSlice(@tagName(headers.version)); - try h.appendSlice("\r\nHost: "); - try h.appendSlice(host); - if (headers.connection_close) { - try h.appendSlice("\r\nConnection: close"); - } else { - try h.appendSlice("\r\nConnection: keep-alive"); - } - try h.appendSlice("\r\n\r\n"); + var buffered = std.io.bufferedWriter(req.connection.data.writer()); + const writer = buffered.writer(); - const header_bytes = h.slice(); - try req.connection.data.writeAll(header_bytes); + try writer.writeAll(@tagName(headers.method)); + try writer.writeByte(' '); + try writer.writeAll(uri.path); + try writer.writeByte(' '); + try writer.writeAll(@tagName(headers.version)); + try writer.writeAll("\r\nHost: "); + try writer.writeAll(host); + if (headers.connection == .close) { + try writer.writeAll("\r\nConnection: close"); + } else { + try writer.writeAll("\r\nConnection: keep-alive"); + } + try writer.writeAll("\r\nAccept-Encoding: gzip, deflate"); + + switch (headers.transfer_encoding) { + .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"), + .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}), + .none => {}, + } + + for (headers.custom) |header| { + try writer.writeAll("\r\n"); + try writer.writeAll(header.name); + try writer.writeAll(": "); + try writer.writeAll(header.value); + } + + try writer.writeAll("\r\n\r\n"); + + try buffered.flush(); } return req; @@ -1036,5 +1253,7 @@ test { return error.SkipZigTest; } + if (builtin.os.tag == .wasi) return error.SkipZigTest; + _ = Request; }