Merge pull request #15927 from truemedian/http-bugs
std.http: fix infinite read loop, deduplicate connection code, add TlsAlert errors
This commit is contained in:
@@ -138,6 +138,35 @@ pub const AlertLevel = enum(u8) {
|
||||
};
|
||||
|
||||
pub const AlertDescription = enum(u8) {
|
||||
pub const Error = error{
|
||||
TlsAlertUnexpectedMessage,
|
||||
TlsAlertBadRecordMac,
|
||||
TlsAlertRecordOverflow,
|
||||
TlsAlertHandshakeFailure,
|
||||
TlsAlertBadCertificate,
|
||||
TlsAlertUnsupportedCertificate,
|
||||
TlsAlertCertificateRevoked,
|
||||
TlsAlertCertificateExpired,
|
||||
TlsAlertCertificateUnknown,
|
||||
TlsAlertIllegalParameter,
|
||||
TlsAlertUnknownCa,
|
||||
TlsAlertAccessDenied,
|
||||
TlsAlertDecodeError,
|
||||
TlsAlertDecryptError,
|
||||
TlsAlertProtocolVersion,
|
||||
TlsAlertInsufficientSecurity,
|
||||
TlsAlertInternalError,
|
||||
TlsAlertInappropriateFallback,
|
||||
TlsAlertMissingExtension,
|
||||
TlsAlertUnsupportedExtension,
|
||||
TlsAlertUnrecognizedName,
|
||||
TlsAlertBadCertificateStatusResponse,
|
||||
TlsAlertUnknownPskIdentity,
|
||||
TlsAlertCertificateRequired,
|
||||
TlsAlertNoApplicationProtocol,
|
||||
TlsAlertUnknown,
|
||||
};
|
||||
|
||||
close_notify = 0,
|
||||
unexpected_message = 10,
|
||||
bad_record_mac = 20,
|
||||
@@ -166,6 +195,39 @@ pub const AlertDescription = enum(u8) {
|
||||
certificate_required = 116,
|
||||
no_application_protocol = 120,
|
||||
_,
|
||||
|
||||
pub fn toError(alert: AlertDescription) Error!void {
|
||||
return switch (alert) {
|
||||
.close_notify => {}, // not an error
|
||||
.unexpected_message => error.TlsAlertUnexpectedMessage,
|
||||
.bad_record_mac => error.TlsAlertBadRecordMac,
|
||||
.record_overflow => error.TlsAlertRecordOverflow,
|
||||
.handshake_failure => error.TlsAlertHandshakeFailure,
|
||||
.bad_certificate => error.TlsAlertBadCertificate,
|
||||
.unsupported_certificate => error.TlsAlertUnsupportedCertificate,
|
||||
.certificate_revoked => error.TlsAlertCertificateRevoked,
|
||||
.certificate_expired => error.TlsAlertCertificateExpired,
|
||||
.certificate_unknown => error.TlsAlertCertificateUnknown,
|
||||
.illegal_parameter => error.TlsAlertIllegalParameter,
|
||||
.unknown_ca => error.TlsAlertUnknownCa,
|
||||
.access_denied => error.TlsAlertAccessDenied,
|
||||
.decode_error => error.TlsAlertDecodeError,
|
||||
.decrypt_error => error.TlsAlertDecryptError,
|
||||
.protocol_version => error.TlsAlertProtocolVersion,
|
||||
.insufficient_security => error.TlsAlertInsufficientSecurity,
|
||||
.internal_error => error.TlsAlertInternalError,
|
||||
.inappropriate_fallback => error.TlsAlertInappropriateFallback,
|
||||
.user_canceled => {}, // not an error
|
||||
.missing_extension => error.TlsAlertMissingExtension,
|
||||
.unsupported_extension => error.TlsAlertUnsupportedExtension,
|
||||
.unrecognized_name => error.TlsAlertUnrecognizedName,
|
||||
.bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse,
|
||||
.unknown_psk_identity => error.TlsAlertUnknownPskIdentity,
|
||||
.certificate_required => error.TlsAlertCertificateRequired,
|
||||
.no_application_protocol => error.TlsAlertNoApplicationProtocol,
|
||||
_ => error.TlsAlertUnknown,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const SignatureScheme = enum(u16) {
|
||||
|
||||
@@ -89,12 +89,11 @@ pub const StreamInterface = struct {
|
||||
};
|
||||
|
||||
pub fn InitError(comptime Stream: type) type {
|
||||
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{
|
||||
return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
|
||||
InsufficientEntropy,
|
||||
DiskQuota,
|
||||
LockViolation,
|
||||
NotOpenForWriting,
|
||||
TlsAlert,
|
||||
TlsUnexpectedMessage,
|
||||
TlsIllegalParameter,
|
||||
TlsDecryptFailure,
|
||||
@@ -251,8 +250,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
|
||||
const level = ptd.decode(tls.AlertLevel);
|
||||
const desc = ptd.decode(tls.AlertDescription);
|
||||
_ = level;
|
||||
_ = desc;
|
||||
return error.TlsAlert;
|
||||
|
||||
// if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
|
||||
try desc.toError();
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
.handshake => {
|
||||
try ptd.ensure(4);
|
||||
@@ -1071,8 +1073,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
|
||||
const level = @intToEnum(tls.AlertLevel, frag[in]);
|
||||
const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
|
||||
_ = level;
|
||||
_ = desc;
|
||||
return error.TlsAlert;
|
||||
|
||||
try desc.toError();
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
.application_data => {
|
||||
const cleartext = switch (c.application_cipher) {
|
||||
@@ -1112,7 +1116,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
|
||||
return vp.total;
|
||||
}
|
||||
_ = level;
|
||||
return error.TlsAlert;
|
||||
|
||||
try desc.toError();
|
||||
// TODO: handle server-side closures
|
||||
return error.TlsUnexpectedMessage;
|
||||
},
|
||||
.handshake => {
|
||||
var ct_i: usize = 0;
|
||||
|
||||
@@ -36,21 +36,7 @@ pub const ConnectionPool = struct {
|
||||
is_tls: bool,
|
||||
};
|
||||
|
||||
pub const StoredConnection = struct {
|
||||
buffered: BufferedConnection,
|
||||
host: []u8,
|
||||
port: u16,
|
||||
|
||||
proxied: bool = false,
|
||||
closing: bool = false,
|
||||
|
||||
pub fn deinit(self: *StoredConnection, client: *Client) void {
|
||||
self.buffered.close(client);
|
||||
client.allocator.free(self.host);
|
||||
}
|
||||
};
|
||||
|
||||
const Queue = std.TailQueue(StoredConnection);
|
||||
const Queue = std.TailQueue(Connection);
|
||||
pub const Node = Queue.Node;
|
||||
|
||||
mutex: std.Thread.Mutex = .{},
|
||||
@@ -69,7 +55,7 @@ pub const ConnectionPool = struct {
|
||||
|
||||
var next = pool.free.last;
|
||||
while (next) |node| : (next = node.prev) {
|
||||
if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
|
||||
if ((node.data.protocol == .tls) != criteria.is_tls) continue;
|
||||
if (node.data.port != criteria.port) continue;
|
||||
if (!mem.eql(u8, node.data.host, criteria.host)) continue;
|
||||
|
||||
@@ -160,45 +146,105 @@ pub const ConnectionPool = struct {
|
||||
|
||||
/// An interface to either a plain or TLS connection.
|
||||
pub const Connection = struct {
|
||||
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
|
||||
pub const Protocol = enum { plain, tls };
|
||||
|
||||
stream: net.Stream,
|
||||
/// undefined unless protocol is tls.
|
||||
tls_client: *std.crypto.tls.Client,
|
||||
|
||||
protocol: Protocol,
|
||||
host: []u8,
|
||||
port: u16,
|
||||
|
||||
pub const Protocol = enum { plain, tls };
|
||||
proxied: bool = false,
|
||||
closing: bool = false,
|
||||
|
||||
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
|
||||
return switch (conn.protocol) {
|
||||
.plain => conn.stream.read(buffer),
|
||||
.tls => conn.tls_client.read(conn.stream, buffer),
|
||||
} catch |err| switch (err) {
|
||||
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
|
||||
error.TlsAlert => return error.TlsAlert,
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
};
|
||||
}
|
||||
read_start: u16 = 0,
|
||||
read_end: u16 = 0,
|
||||
read_buf: [buffer_size]u8 = undefined,
|
||||
|
||||
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
|
||||
pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
|
||||
return switch (conn.protocol) {
|
||||
.plain => conn.stream.readAtLeast(buffer, len),
|
||||
.tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
|
||||
} catch |err| switch (err) {
|
||||
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
|
||||
error.TlsAlert => return error.TlsAlert,
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
} catch |err| {
|
||||
// TODO: https://github.com/ziglang/zig/issues/2473
|
||||
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
|
||||
|
||||
switch (err) {
|
||||
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn fill(conn: *Connection) ReadError!void {
|
||||
if (conn.read_end != conn.read_start) return;
|
||||
|
||||
const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
conn.read_start = 0;
|
||||
conn.read_end = @intCast(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(conn: *Connection) []const u8 {
|
||||
return conn.read_buf[conn.read_start..conn.read_end];
|
||||
}
|
||||
|
||||
pub fn drop(conn: *Connection, num: u16) void {
|
||||
conn.read_start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
|
||||
assert(len <= buffer.len);
|
||||
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available_read = conn.read_end - conn.read_start;
|
||||
const available_buffer = buffer.len - out_index;
|
||||
|
||||
if (available_read > available_buffer) { // partially read buffered data
|
||||
@memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
|
||||
out_index += @intCast(u16, available_buffer);
|
||||
conn.read_start += @intCast(u16, available_buffer);
|
||||
|
||||
break;
|
||||
} else if (available_read > 0) { // fully read buffered data
|
||||
@memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
|
||||
out_index += available_read;
|
||||
conn.read_start += available_read;
|
||||
|
||||
if (out_index >= len) break;
|
||||
}
|
||||
|
||||
const leftover_buffer = available_buffer - available_read;
|
||||
const leftover_len = len - out_index;
|
||||
|
||||
if (leftover_buffer > conn.read_buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
|
||||
}
|
||||
|
||||
try conn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
|
||||
return conn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = error{
|
||||
TlsFailure,
|
||||
TlsAlert,
|
||||
ConnectionTimedOut,
|
||||
ConnectionResetByPeer,
|
||||
UnexpectedReadFailure,
|
||||
EndOfStream,
|
||||
};
|
||||
|
||||
pub const Reader = std.io.Reader(*Connection, ReadError, read);
|
||||
@@ -247,111 +293,10 @@ pub const Connection = struct {
|
||||
|
||||
conn.stream.close();
|
||||
}
|
||||
};
|
||||
|
||||
/// A buffered (and peekable) Connection.
|
||||
pub const BufferedConnection = struct {
|
||||
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
|
||||
|
||||
conn: Connection,
|
||||
read_buf: [buffer_size]u8 = undefined,
|
||||
read_start: u16 = 0,
|
||||
read_end: u16 = 0,
|
||||
|
||||
write_buf: [buffer_size]u8 = undefined,
|
||||
write_end: u16 = 0,
|
||||
|
||||
pub fn fill(bconn: *BufferedConnection) ReadError!void {
|
||||
if (bconn.read_end != bconn.read_start) return;
|
||||
|
||||
const nread = try bconn.conn.read(bconn.read_buf[0..]);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
bconn.read_start = 0;
|
||||
bconn.read_end = @intCast(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(bconn: *BufferedConnection) []const u8 {
|
||||
return bconn.read_buf[bconn.read_start..bconn.read_end];
|
||||
}
|
||||
|
||||
pub fn clear(bconn: *BufferedConnection, num: u16) void {
|
||||
bconn.read_start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available = bconn.read_end - bconn.read_start;
|
||||
const left = buffer.len - out_index;
|
||||
|
||||
if (available > 0) {
|
||||
const can_read = @intCast(u16, @min(available, left));
|
||||
|
||||
@memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]);
|
||||
out_index += can_read;
|
||||
bconn.read_start += can_read;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left > bconn.read_buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return bconn.conn.read(buffer[out_index..]);
|
||||
}
|
||||
|
||||
try bconn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return bconn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = Connection.ReadError || error{EndOfStream};
|
||||
pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
|
||||
|
||||
pub fn reader(bconn: *BufferedConnection) Reader {
|
||||
return Reader{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
|
||||
if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
|
||||
@memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
|
||||
bconn.write_end += @intCast(u16, buffer.len);
|
||||
} else {
|
||||
try bconn.flush();
|
||||
try bconn.conn.writeAll(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
|
||||
@memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
|
||||
bconn.write_end += @intCast(u16, buffer.len);
|
||||
|
||||
return buffer.len;
|
||||
} else {
|
||||
try bconn.flush();
|
||||
return try bconn.conn.write(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(bconn: *BufferedConnection) WriteError!void {
|
||||
defer bconn.write_end = 0;
|
||||
return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError;
|
||||
pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
|
||||
|
||||
pub fn writer(bconn: *BufferedConnection) Writer {
|
||||
return Writer{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn close(bconn: *BufferedConnection, client: *const Client) void {
|
||||
bconn.conn.close(client);
|
||||
pub fn deinit(conn: *Connection, client: *const Client) void {
|
||||
conn.close(client);
|
||||
client.allocator.free(conn.host);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -585,11 +530,12 @@ pub const Request = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
|
||||
pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
|
||||
|
||||
/// Send the request to the server.
|
||||
pub fn start(req: *Request) StartError!void {
|
||||
const w = req.connection.data.buffered.writer();
|
||||
var buffered = std.io.bufferedWriter(req.connection.data.writer());
|
||||
const w = buffered.writer();
|
||||
|
||||
try w.writeAll(@tagName(req.method));
|
||||
try w.writeByte(' ');
|
||||
@@ -663,10 +609,10 @@ pub const Request = struct {
|
||||
|
||||
try w.writeAll("\r\n");
|
||||
|
||||
try req.connection.data.buffered.flush();
|
||||
try buffered.flush();
|
||||
}
|
||||
|
||||
pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
|
||||
pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
|
||||
|
||||
pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
|
||||
|
||||
@@ -679,7 +625,7 @@ pub const Request = struct {
|
||||
|
||||
var index: usize = 0;
|
||||
while (index == 0) {
|
||||
const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
|
||||
const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
|
||||
if (amt == 0 and req.response.parser.done) break;
|
||||
index += amt;
|
||||
}
|
||||
@@ -697,10 +643,10 @@ pub const Request = struct {
|
||||
pub fn wait(req: *Request) WaitError!void {
|
||||
while (true) { // handle redirects
|
||||
while (true) { // read headers
|
||||
try req.connection.data.buffered.fill();
|
||||
try req.connection.data.fill();
|
||||
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
|
||||
req.connection.data.buffered.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
|
||||
req.connection.data.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (req.response.parser.state.isContent()) break;
|
||||
}
|
||||
@@ -816,10 +762,10 @@ pub const Request = struct {
|
||||
const has_trail = !req.response.parser.state.isContent();
|
||||
|
||||
while (!req.response.parser.state.isContent()) { // read trailing headers
|
||||
try req.connection.data.buffered.fill();
|
||||
try req.connection.data.fill();
|
||||
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
|
||||
req.connection.data.buffered.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
|
||||
req.connection.data.drop(@intCast(u16, nchecked));
|
||||
}
|
||||
|
||||
if (has_trail) {
|
||||
@@ -845,7 +791,7 @@ pub const Request = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
|
||||
pub const Writer = std.io.Writer(*Request, WriteError, write);
|
||||
|
||||
@@ -857,16 +803,16 @@ pub const Request = struct {
|
||||
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
|
||||
switch (req.transfer_encoding) {
|
||||
.chunked => {
|
||||
try req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.data.buffered.writeAll(bytes);
|
||||
try req.connection.data.buffered.writeAll("\r\n");
|
||||
try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
|
||||
try req.connection.data.writeAll(bytes);
|
||||
try req.connection.data.writeAll("\r\n");
|
||||
|
||||
return bytes.len;
|
||||
},
|
||||
.content_length => |*len| {
|
||||
if (len.* < bytes.len) return error.MessageTooLong;
|
||||
|
||||
const amt = try req.connection.data.buffered.write(bytes);
|
||||
const amt = try req.connection.data.write(bytes);
|
||||
len.* -= amt;
|
||||
return amt;
|
||||
},
|
||||
@@ -886,12 +832,10 @@ pub const Request = struct {
|
||||
/// Finish the body of a request. This notifies the server that you have no more data to send.
|
||||
pub fn finish(req: *Request) FinishError!void {
|
||||
switch (req.transfer_encoding) {
|
||||
.chunked => try req.connection.data.buffered.writeAll("0\r\n\r\n"),
|
||||
.chunked => try req.connection.data.writeAll("0\r\n\r\n"),
|
||||
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
|
||||
.none => {},
|
||||
}
|
||||
|
||||
try req.connection.data.buffered.flush();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -948,11 +892,10 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
|
||||
errdefer stream.close();
|
||||
|
||||
conn.data = .{
|
||||
.buffered = .{ .conn = .{
|
||||
.stream = stream,
|
||||
.tls_client = undefined,
|
||||
.protocol = protocol,
|
||||
} },
|
||||
.stream = stream,
|
||||
.tls_client = undefined,
|
||||
.protocol = protocol,
|
||||
|
||||
.host = try client.allocator.dupe(u8, host),
|
||||
.port = port,
|
||||
};
|
||||
@@ -961,13 +904,13 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
|
||||
switch (protocol) {
|
||||
.plain => {},
|
||||
.tls => {
|
||||
conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
|
||||
errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
|
||||
conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
|
||||
errdefer client.allocator.destroy(conn.data.tls_client);
|
||||
|
||||
conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
|
||||
conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
|
||||
// This is appropriate for HTTPS because the HTTP headers contain
|
||||
// the content length which is used to detect truncation attacks.
|
||||
conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
|
||||
conn.data.tls_client.allow_truncation_attacks = true;
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1003,7 +946,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
|
||||
}
|
||||
}
|
||||
|
||||
pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
|
||||
pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{
|
||||
UnsupportedUrlScheme,
|
||||
UriMissingHost,
|
||||
|
||||
|
||||
@@ -16,39 +16,92 @@ socket: net.StreamServer,
|
||||
|
||||
/// An interface to either a plain or TLS connection.
|
||||
pub const Connection = struct {
|
||||
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
|
||||
pub const Protocol = enum { plain };
|
||||
|
||||
stream: net.Stream,
|
||||
protocol: Protocol,
|
||||
|
||||
closing: bool = true,
|
||||
|
||||
pub const Protocol = enum { plain };
|
||||
read_buf: [buffer_size]u8 = undefined,
|
||||
read_start: u16 = 0,
|
||||
read_end: u16 = 0,
|
||||
|
||||
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
|
||||
pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
|
||||
return switch (conn.protocol) {
|
||||
.plain => conn.stream.read(buffer),
|
||||
// .tls => return conn.tls_client.read(conn.stream, buffer),
|
||||
} catch |err| switch (err) {
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
.plain => conn.stream.readAtLeast(buffer, len),
|
||||
// .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
|
||||
} catch |err| {
|
||||
switch (err) {
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn fill(conn: *Connection) ReadError!void {
|
||||
if (conn.read_end != conn.read_start) return;
|
||||
|
||||
const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
conn.read_start = 0;
|
||||
conn.read_end = @intCast(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(conn: *Connection) []const u8 {
|
||||
return conn.read_buf[conn.read_start..conn.read_end];
|
||||
}
|
||||
|
||||
pub fn drop(conn: *Connection, num: u16) void {
|
||||
conn.read_start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
|
||||
return switch (conn.protocol) {
|
||||
.plain => conn.stream.readAtLeast(buffer, len),
|
||||
// .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
|
||||
} catch |err| switch (err) {
|
||||
error.ConnectionTimedOut => return error.ConnectionTimedOut,
|
||||
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
|
||||
else => return error.UnexpectedReadFailure,
|
||||
};
|
||||
assert(len <= buffer.len);
|
||||
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available_read = conn.read_end - conn.read_start;
|
||||
const available_buffer = buffer.len - out_index;
|
||||
|
||||
if (available_read > available_buffer) { // partially read buffered data
|
||||
@memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
|
||||
out_index += @intCast(u16, available_buffer);
|
||||
conn.read_start += @intCast(u16, available_buffer);
|
||||
|
||||
break;
|
||||
} else if (available_read > 0) { // fully read buffered data
|
||||
@memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
|
||||
out_index += available_read;
|
||||
conn.read_start += available_read;
|
||||
|
||||
if (out_index >= len) break;
|
||||
}
|
||||
|
||||
const leftover_buffer = available_buffer - available_read;
|
||||
const leftover_len = len - out_index;
|
||||
|
||||
if (leftover_buffer > conn.read_buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
|
||||
}
|
||||
|
||||
try conn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
|
||||
return conn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = error{
|
||||
ConnectionTimedOut,
|
||||
ConnectionResetByPeer,
|
||||
UnexpectedReadFailure,
|
||||
EndOfStream,
|
||||
};
|
||||
|
||||
pub const Reader = std.io.Reader(*Connection, ReadError, read);
|
||||
@@ -93,112 +146,6 @@ pub const Connection = struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// A buffered (and peekable) Connection.
|
||||
pub const BufferedConnection = struct {
|
||||
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
|
||||
|
||||
conn: Connection,
|
||||
read_buf: [buffer_size]u8 = undefined,
|
||||
read_start: u16 = 0,
|
||||
read_end: u16 = 0,
|
||||
|
||||
write_buf: [buffer_size]u8 = undefined,
|
||||
write_end: u16 = 0,
|
||||
|
||||
pub fn fill(bconn: *BufferedConnection) ReadError!void {
|
||||
if (bconn.read_end != bconn.read_start) return;
|
||||
|
||||
const nread = try bconn.conn.read(bconn.read_buf[0..]);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
bconn.read_start = 0;
|
||||
bconn.read_end = @intCast(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(bconn: *BufferedConnection) []const u8 {
|
||||
return bconn.read_buf[bconn.read_start..bconn.read_end];
|
||||
}
|
||||
|
||||
pub fn clear(bconn: *BufferedConnection, num: u16) void {
|
||||
bconn.read_start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available = bconn.read_end - bconn.read_start;
|
||||
const left = buffer.len - out_index;
|
||||
|
||||
if (available > 0) {
|
||||
const can_read = @intCast(u16, @min(available, left));
|
||||
|
||||
@memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]);
|
||||
out_index += can_read;
|
||||
bconn.read_start += can_read;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left > bconn.read_buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return bconn.conn.read(buffer[out_index..]);
|
||||
}
|
||||
|
||||
try bconn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return bconn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = Connection.ReadError || error{EndOfStream};
|
||||
pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
|
||||
|
||||
pub fn reader(bconn: *BufferedConnection) Reader {
|
||||
return Reader{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
|
||||
if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
|
||||
@memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
|
||||
bconn.write_end += @intCast(u16, buffer.len);
|
||||
} else {
|
||||
try bconn.flush();
|
||||
try bconn.conn.writeAll(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
|
||||
@memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
|
||||
bconn.write_end += @intCast(u16, buffer.len);
|
||||
|
||||
return buffer.len;
|
||||
} else {
|
||||
try bconn.flush();
|
||||
return try bconn.conn.write(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(bconn: *BufferedConnection) WriteError!void {
|
||||
defer bconn.write_end = 0;
|
||||
return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
|
||||
}
|
||||
|
||||
pub const WriteError = Connection.WriteError;
|
||||
pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
|
||||
|
||||
pub fn writer(bconn: *BufferedConnection) Writer {
|
||||
return Writer{ .context = bconn };
|
||||
}
|
||||
|
||||
pub fn close(bconn: *BufferedConnection) void {
|
||||
bconn.conn.close();
|
||||
}
|
||||
};
|
||||
|
||||
/// The mode of transport for responses.
|
||||
pub const ResponseTransfer = union(enum) {
|
||||
content_length: u64,
|
||||
@@ -351,7 +298,7 @@ pub const Response = struct {
|
||||
|
||||
allocator: Allocator,
|
||||
address: net.Address,
|
||||
connection: BufferedConnection,
|
||||
connection: Connection,
|
||||
|
||||
headers: http.Headers,
|
||||
request: Request,
|
||||
@@ -388,7 +335,7 @@ pub const Response = struct {
|
||||
|
||||
if (!res.request.parser.done) {
|
||||
// If the response wasn't fully read, then we need to close the connection.
|
||||
res.connection.conn.closing = true;
|
||||
res.connection.closing = true;
|
||||
return .closing;
|
||||
}
|
||||
|
||||
@@ -402,9 +349,9 @@ pub const Response = struct {
|
||||
const req_connection = res.request.headers.getFirstValue("connection");
|
||||
const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
|
||||
if (req_keepalive and (res_keepalive or res_connection == null)) {
|
||||
res.connection.conn.closing = false;
|
||||
res.connection.closing = false;
|
||||
} else {
|
||||
res.connection.conn.closing = true;
|
||||
res.connection.closing = true;
|
||||
}
|
||||
|
||||
switch (res.request.compression) {
|
||||
@@ -434,14 +381,14 @@ pub const Response = struct {
|
||||
.parser = res.request.parser,
|
||||
};
|
||||
|
||||
if (res.connection.conn.closing) {
|
||||
if (res.connection.closing) {
|
||||
return .closing;
|
||||
} else {
|
||||
return .reset;
|
||||
}
|
||||
}
|
||||
|
||||
pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
|
||||
pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
|
||||
|
||||
/// Send the response headers.
|
||||
pub fn do(res: *Response) !void {
|
||||
@@ -450,7 +397,8 @@ pub const Response = struct {
|
||||
.first, .start, .responded, .finished => unreachable,
|
||||
}
|
||||
|
||||
const w = res.connection.writer();
|
||||
var buffered = std.io.bufferedWriter(res.connection.writer());
|
||||
const w = buffered.writer();
|
||||
|
||||
try w.writeAll(@tagName(res.version));
|
||||
try w.writeByte(' ');
|
||||
@@ -508,10 +456,10 @@ pub const Response = struct {
|
||||
|
||||
try w.writeAll("\r\n");
|
||||
|
||||
try res.connection.flush();
|
||||
try buffered.flush();
|
||||
}
|
||||
|
||||
pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
|
||||
pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
|
||||
|
||||
pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
|
||||
|
||||
@@ -532,7 +480,7 @@ pub const Response = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
|
||||
pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
|
||||
|
||||
/// Wait for the client to send a complete request head.
|
||||
pub fn wait(res: *Response) WaitError!void {
|
||||
@@ -545,7 +493,7 @@ pub const Response = struct {
|
||||
try res.connection.fill();
|
||||
|
||||
const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
|
||||
res.connection.clear(@intCast(u16, nchecked));
|
||||
res.connection.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (res.request.parser.state.isContent()) break;
|
||||
}
|
||||
@@ -612,7 +560,7 @@ pub const Response = struct {
|
||||
try res.connection.fill();
|
||||
|
||||
const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek());
|
||||
res.connection.clear(@intCast(u16, nchecked));
|
||||
res.connection.drop(@intCast(u16, nchecked));
|
||||
}
|
||||
|
||||
if (has_trail) {
|
||||
@@ -637,7 +585,7 @@ pub const Response = struct {
|
||||
return index;
|
||||
}
|
||||
|
||||
pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
|
||||
|
||||
pub const Writer = std.io.Writer(*Response, WriteError, write);
|
||||
|
||||
@@ -692,8 +640,6 @@ pub const Response = struct {
|
||||
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
|
||||
.none => {},
|
||||
}
|
||||
|
||||
try res.connection.flush();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -742,10 +688,10 @@ pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response {
|
||||
return Response{
|
||||
.allocator = options.allocator,
|
||||
.address = in.address,
|
||||
.connection = .{ .conn = .{
|
||||
.connection = .{
|
||||
.stream = in.stream,
|
||||
.protocol = .plain,
|
||||
} },
|
||||
},
|
||||
.headers = .{ .allocator = options.allocator },
|
||||
.request = .{
|
||||
.version = undefined,
|
||||
|
||||
@@ -513,8 +513,8 @@ pub const HeadersParser = struct {
|
||||
///
|
||||
/// If `skip` is true, the buffer will be unused and the body will be skipped.
|
||||
///
|
||||
/// See `std.http.Client.BufferedConnection for an example of `bconn`.
|
||||
pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize {
|
||||
/// See `std.http.Client.BufferedConnection for an example of `conn`.
|
||||
pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
|
||||
assert(r.state.isContent());
|
||||
if (r.done) return 0;
|
||||
|
||||
@@ -526,10 +526,10 @@ pub const HeadersParser = struct {
|
||||
const data_avail = r.next_chunk_length;
|
||||
|
||||
if (skip) {
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nread = @min(bconn.peek().len, data_avail);
|
||||
bconn.clear(@intCast(u16, nread));
|
||||
const nread = @min(conn.peek().len, data_avail);
|
||||
conn.drop(@intCast(u16, nread));
|
||||
r.next_chunk_length -= nread;
|
||||
|
||||
if (r.next_chunk_length == 0) r.done = true;
|
||||
@@ -539,7 +539,7 @@ pub const HeadersParser = struct {
|
||||
const out_avail = buffer.len;
|
||||
|
||||
const can_read = @intCast(usize, @min(data_avail, out_avail));
|
||||
const nread = try bconn.read(buffer[0..can_read]);
|
||||
const nread = try conn.read(buffer[0..can_read]);
|
||||
r.next_chunk_length -= nread;
|
||||
|
||||
if (r.next_chunk_length == 0) r.done = true;
|
||||
@@ -548,15 +548,15 @@ pub const HeadersParser = struct {
|
||||
}
|
||||
},
|
||||
.chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const i = r.findChunkedLen(bconn.peek());
|
||||
bconn.clear(@intCast(u16, i));
|
||||
const i = r.findChunkedLen(conn.peek());
|
||||
conn.drop(@intCast(u16, i));
|
||||
|
||||
switch (r.state) {
|
||||
.invalid => return error.HttpChunkInvalid,
|
||||
.chunk_data => if (r.next_chunk_length == 0) {
|
||||
if (std.mem.eql(u8, bconn.peek(), "\r\n")) {
|
||||
if (std.mem.eql(u8, conn.peek(), "\r\n")) {
|
||||
r.state = .finished;
|
||||
} else {
|
||||
// The trailer section is formatted identically to the header section.
|
||||
@@ -576,14 +576,14 @@ pub const HeadersParser = struct {
|
||||
const out_avail = buffer.len - out_index;
|
||||
|
||||
if (skip) {
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nread = @min(bconn.peek().len, data_avail);
|
||||
bconn.clear(@intCast(u16, nread));
|
||||
const nread = @min(conn.peek().len, data_avail);
|
||||
conn.drop(@intCast(u16, nread));
|
||||
r.next_chunk_length -= nread;
|
||||
} else {
|
||||
const can_read = @intCast(usize, @min(data_avail, out_avail));
|
||||
const nread = try bconn.read(buffer[out_index..][0..can_read]);
|
||||
const nread = try conn.read(buffer[out_index..][0..can_read]);
|
||||
r.next_chunk_length -= nread;
|
||||
out_index += nread;
|
||||
}
|
||||
@@ -628,74 +628,74 @@ const MockBufferedConnection = struct {
|
||||
start: u16 = 0,
|
||||
end: u16 = 0,
|
||||
|
||||
pub fn fill(bconn: *MockBufferedConnection) ReadError!void {
|
||||
if (bconn.end != bconn.start) return;
|
||||
pub fn fill(conn: *MockBufferedConnection) ReadError!void {
|
||||
if (conn.end != conn.start) return;
|
||||
|
||||
const nread = try bconn.conn.read(bconn.buf[0..]);
|
||||
const nread = try conn.conn.read(conn.buf[0..]);
|
||||
if (nread == 0) return error.EndOfStream;
|
||||
bconn.start = 0;
|
||||
bconn.end = @truncate(u16, nread);
|
||||
conn.start = 0;
|
||||
conn.end = @truncate(u16, nread);
|
||||
}
|
||||
|
||||
pub fn peek(bconn: *MockBufferedConnection) []const u8 {
|
||||
return bconn.buf[bconn.start..bconn.end];
|
||||
pub fn peek(conn: *MockBufferedConnection) []const u8 {
|
||||
return conn.buf[conn.start..conn.end];
|
||||
}
|
||||
|
||||
pub fn clear(bconn: *MockBufferedConnection, num: u16) void {
|
||||
bconn.start += num;
|
||||
pub fn drop(conn: *MockBufferedConnection, num: u16) void {
|
||||
conn.start += num;
|
||||
}
|
||||
|
||||
pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
|
||||
var out_index: u16 = 0;
|
||||
while (out_index < len) {
|
||||
const available = bconn.end - bconn.start;
|
||||
const available = conn.end - conn.start;
|
||||
const left = buffer.len - out_index;
|
||||
|
||||
if (available > 0) {
|
||||
const can_read = @truncate(u16, @min(available, left));
|
||||
|
||||
@memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]);
|
||||
@memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
|
||||
out_index += can_read;
|
||||
bconn.start += can_read;
|
||||
conn.start += can_read;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (left > bconn.buf.len) {
|
||||
if (left > conn.buf.len) {
|
||||
// skip the buffer if the output is large enough
|
||||
return bconn.conn.read(buffer[out_index..]);
|
||||
return conn.conn.read(buffer[out_index..]);
|
||||
}
|
||||
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
}
|
||||
|
||||
return out_index;
|
||||
}
|
||||
|
||||
pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return bconn.readAtLeast(buffer, 1);
|
||||
pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
|
||||
return conn.readAtLeast(buffer, 1);
|
||||
}
|
||||
|
||||
pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
|
||||
pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
|
||||
|
||||
pub fn reader(bconn: *MockBufferedConnection) Reader {
|
||||
return Reader{ .context = bconn };
|
||||
pub fn reader(conn: *MockBufferedConnection) Reader {
|
||||
return Reader{ .context = conn };
|
||||
}
|
||||
|
||||
pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
|
||||
return bconn.conn.writeAll(buffer);
|
||||
pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
|
||||
return conn.conn.writeAll(buffer);
|
||||
}
|
||||
|
||||
pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
return bconn.conn.write(buffer);
|
||||
pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
|
||||
return conn.conn.write(buffer);
|
||||
}
|
||||
|
||||
pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
|
||||
pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
|
||||
|
||||
pub fn writer(bconn: *MockBufferedConnection) Writer {
|
||||
return Writer{ .context = bconn };
|
||||
pub fn writer(conn: *MockBufferedConnection) Writer {
|
||||
return Writer{ .context = conn };
|
||||
}
|
||||
};
|
||||
|
||||
@@ -753,15 +753,15 @@ test "HeadersParser.read length" {
|
||||
const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
|
||||
var fbs = std.io.fixedBufferStream(data);
|
||||
|
||||
var bconn = MockBufferedConnection{
|
||||
var conn = MockBufferedConnection{
|
||||
.conn = fbs,
|
||||
};
|
||||
|
||||
while (true) { // read headers
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
|
||||
bconn.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
|
||||
conn.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (r.state.isContent()) break;
|
||||
}
|
||||
@@ -769,7 +769,7 @@ test "HeadersParser.read length" {
|
||||
var buf: [8]u8 = undefined;
|
||||
|
||||
r.next_chunk_length = 5;
|
||||
const len = try r.read(&bconn, &buf, false);
|
||||
const len = try r.read(&conn, &buf, false);
|
||||
try std.testing.expectEqual(@as(usize, 5), len);
|
||||
try std.testing.expectEqualStrings("Hello", buf[0..len]);
|
||||
|
||||
@@ -784,22 +784,22 @@ test "HeadersParser.read chunked" {
|
||||
const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
|
||||
var fbs = std.io.fixedBufferStream(data);
|
||||
|
||||
var bconn = MockBufferedConnection{
|
||||
var conn = MockBufferedConnection{
|
||||
.conn = fbs,
|
||||
};
|
||||
|
||||
while (true) { // read headers
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
|
||||
bconn.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
|
||||
conn.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (r.state.isContent()) break;
|
||||
}
|
||||
var buf: [8]u8 = undefined;
|
||||
|
||||
r.state = .chunk_head_size;
|
||||
const len = try r.read(&bconn, &buf, false);
|
||||
const len = try r.read(&conn, &buf, false);
|
||||
try std.testing.expectEqual(@as(usize, 5), len);
|
||||
try std.testing.expectEqualStrings("Hello", buf[0..len]);
|
||||
|
||||
@@ -814,30 +814,30 @@ test "HeadersParser.read chunked trailer" {
|
||||
const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
|
||||
var fbs = std.io.fixedBufferStream(data);
|
||||
|
||||
var bconn = MockBufferedConnection{
|
||||
var conn = MockBufferedConnection{
|
||||
.conn = fbs,
|
||||
};
|
||||
|
||||
while (true) { // read headers
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
|
||||
bconn.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
|
||||
conn.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (r.state.isContent()) break;
|
||||
}
|
||||
var buf: [8]u8 = undefined;
|
||||
|
||||
r.state = .chunk_head_size;
|
||||
const len = try r.read(&bconn, &buf, false);
|
||||
const len = try r.read(&conn, &buf, false);
|
||||
try std.testing.expectEqual(@as(usize, 5), len);
|
||||
try std.testing.expectEqualStrings("Hello", buf[0..len]);
|
||||
|
||||
while (true) { // read headers
|
||||
try bconn.fill();
|
||||
try conn.fill();
|
||||
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
|
||||
bconn.clear(@intCast(u16, nchecked));
|
||||
const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
|
||||
conn.drop(@intCast(u16, nchecked));
|
||||
|
||||
if (r.state.isContent()) break;
|
||||
}
|
||||
|
||||
@@ -86,7 +86,6 @@ fn handleRequest(res: *Server.Response) !void {
|
||||
try res.writeAll("World!\n");
|
||||
// try res.finish();
|
||||
try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
|
||||
try res.connection.flush();
|
||||
} else if (mem.eql(u8, res.request.target, "/redirect/1")) {
|
||||
res.transfer_encoding = .chunked;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user