commit 5998a8cebe3973d70c258b2a1440c5c3252d3539 (tree)
parent 2cf15bee0325321e9da496580b55310d4ba1053f
Author: Andrew Kelley <andrew@ziglang.org>
Date: Thu, 7 Aug 2025 19:54:25 -0700
Merge pull request #24698 from ziglang/http
std: rework HTTP and TLS for new I/O API
Diffstat:
31 files changed, 3743 insertions(+), 5295 deletions(-)
diff --git a/lib/compiler/resinator/cli.zig b/lib/compiler/resinator/cli.zig
@@ -1141,6 +1141,8 @@ pub fn parse(allocator: Allocator, args: []const []const u8, diagnostics: *Diagn
}
output_format = .res;
}
+ } else {
+ output_format_source = .output_format_arg;
}
options.output_source = .{ .filename = try filepathWithExtension(allocator, options.input_source.filename, output_format.?.extension()) };
} else {
@@ -1529,21 +1531,21 @@ fn testParseOutput(args: []const []const u8, expected_output: []const u8) !?Opti
var diagnostics = Diagnostics.init(std.testing.allocator);
defer diagnostics.deinit();
- var output = std.ArrayList(u8).init(std.testing.allocator);
+ var output: std.io.Writer.Allocating = .init(std.testing.allocator);
defer output.deinit();
var options = parse(std.testing.allocator, args, &diagnostics) catch |err| switch (err) {
error.ParseError => {
- try diagnostics.renderToWriter(args, output.writer(), .no_color);
- try std.testing.expectEqualStrings(expected_output, output.items);
+ try diagnostics.renderToWriter(args, &output.writer, .no_color);
+ try std.testing.expectEqualStrings(expected_output, output.getWritten());
return null;
},
else => |e| return e,
};
errdefer options.deinit();
- try diagnostics.renderToWriter(args, output.writer(), .no_color);
- try std.testing.expectEqualStrings(expected_output, output.items);
+ try diagnostics.renderToWriter(args, &output.writer, .no_color);
+ try std.testing.expectEqualStrings(expected_output, output.getWritten());
return options;
}
diff --git a/lib/compiler/resinator/compile.zig b/lib/compiler/resinator/compile.zig
@@ -550,7 +550,7 @@ pub const Compiler = struct {
// so get it here to simplify future usage.
const filename_token = node.filename.getFirstToken();
- const file = self.searchForFile(filename_utf8) catch |err| switch (err) {
+ const file_handle = self.searchForFile(filename_utf8) catch |err| switch (err) {
error.OutOfMemory => |e| return e,
else => |e| {
const filename_string_index = try self.diagnostics.putString(filename_utf8);
@@ -564,13 +564,15 @@ pub const Compiler = struct {
});
},
};
- defer file.close();
+ defer file_handle.close();
+ var file_buffer: [2048]u8 = undefined;
+ var file_reader = file_handle.reader(&file_buffer);
if (maybe_predefined_type) |predefined_type| {
switch (predefined_type) {
.GROUP_ICON, .GROUP_CURSOR => {
// Check for animated icon first
- if (ani.isAnimatedIcon(file.deprecatedReader())) {
+ if (ani.isAnimatedIcon(file_reader.interface.adaptToOldInterface())) {
// Animated icons are just put into the resource unmodified,
// and the resource type changes to ANIICON/ANICURSOR
@@ -582,18 +584,18 @@ pub const Compiler = struct {
header.type_value.ordinal = @intFromEnum(new_predefined_type);
header.memory_flags = MemoryFlags.defaults(new_predefined_type);
header.applyMemoryFlags(node.common_resource_attributes, self.source);
- header.data_size = @intCast(try file.getEndPos());
+ header.data_size = @intCast(try file_reader.getSize());
try header.write(writer, self.errContext(node.id));
- try file.seekTo(0);
- try writeResourceData(writer, file.deprecatedReader(), header.data_size);
+ try file_reader.seekTo(0);
+ try writeResourceData(writer, &file_reader.interface, header.data_size);
return;
}
// isAnimatedIcon moved the file cursor so reset to the start
- try file.seekTo(0);
+ try file_reader.seekTo(0);
- const icon_dir = ico.read(self.allocator, file.deprecatedReader(), try file.getEndPos()) catch |err| switch (err) {
+ const icon_dir = ico.read(self.allocator, file_reader.interface.adaptToOldInterface(), try file_reader.getSize()) catch |err| switch (err) {
error.OutOfMemory => |e| return e,
else => |e| {
return self.iconReadError(
@@ -671,15 +673,15 @@ pub const Compiler = struct {
try writer.writeInt(u16, entry.type_specific_data.cursor.hotspot_y, .little);
}
- try file.seekTo(entry.data_offset_from_start_of_file);
- var header_bytes = file.deprecatedReader().readBytesNoEof(16) catch {
+ try file_reader.seekTo(entry.data_offset_from_start_of_file);
+ var header_bytes = (file_reader.interface.takeArray(16) catch {
return self.iconReadError(
error.UnexpectedEOF,
filename_utf8,
filename_token,
predefined_type,
);
- };
+ }).*;
const image_format = ico.ImageFormat.detect(&header_bytes);
if (!image_format.validate(&header_bytes)) {
@@ -802,8 +804,8 @@ pub const Compiler = struct {
},
}
- try file.seekTo(entry.data_offset_from_start_of_file);
- try writeResourceDataNoPadding(writer, file.deprecatedReader(), entry.data_size_in_bytes);
+ try file_reader.seekTo(entry.data_offset_from_start_of_file);
+ try writeResourceDataNoPadding(writer, &file_reader.interface, entry.data_size_in_bytes);
try writeDataPadding(writer, full_data_size);
if (self.state.icon_id == std.math.maxInt(u16)) {
@@ -857,9 +859,9 @@ pub const Compiler = struct {
},
.BITMAP => {
header.applyMemoryFlags(node.common_resource_attributes, self.source);
- const file_size = try file.getEndPos();
+ const file_size = try file_reader.getSize();
- const bitmap_info = bmp.read(file.deprecatedReader(), file_size) catch |err| {
+ const bitmap_info = bmp.read(file_reader.interface.adaptToOldInterface(), file_size) catch |err| {
const filename_string_index = try self.diagnostics.putString(filename_utf8);
return self.addErrorDetailsAndFail(.{
.err = .bmp_read_error,
@@ -921,18 +923,17 @@ pub const Compiler = struct {
header.data_size = bmp_bytes_to_write;
try header.write(writer, self.errContext(node.id));
- try file.seekTo(bmp.file_header_len);
- const file_reader = file.deprecatedReader();
- try writeResourceDataNoPadding(writer, file_reader, bitmap_info.dib_header_size);
+ try file_reader.seekTo(bmp.file_header_len);
+ try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.dib_header_size);
if (bitmap_info.getBitmasksByteLen() > 0) {
- try writeResourceDataNoPadding(writer, file_reader, bitmap_info.getBitmasksByteLen());
+ try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.getBitmasksByteLen());
}
if (bitmap_info.getExpectedPaletteByteLen() > 0) {
- try writeResourceDataNoPadding(writer, file_reader, @intCast(bitmap_info.getActualPaletteByteLen()));
+ try writeResourceDataNoPadding(writer, &file_reader.interface, @intCast(bitmap_info.getActualPaletteByteLen()));
}
- try file.seekTo(bitmap_info.pixel_data_offset);
+ try file_reader.seekTo(bitmap_info.pixel_data_offset);
const pixel_bytes: u32 = @intCast(file_size - bitmap_info.pixel_data_offset);
- try writeResourceDataNoPadding(writer, file_reader, pixel_bytes);
+ try writeResourceDataNoPadding(writer, &file_reader.interface, pixel_bytes);
try writeDataPadding(writer, bmp_bytes_to_write);
return;
},
@@ -956,7 +957,7 @@ pub const Compiler = struct {
return;
}
header.applyMemoryFlags(node.common_resource_attributes, self.source);
- const file_size = try file.getEndPos();
+ const file_size = try file_reader.getSize();
if (file_size > std.math.maxInt(u32)) {
return self.addErrorDetailsAndFail(.{
.err = .resource_data_size_exceeds_max,
@@ -968,8 +969,9 @@ pub const Compiler = struct {
header.data_size = @intCast(file_size);
try header.write(writer, self.errContext(node.id));
- var header_slurping_reader = headerSlurpingReader(148, file.deprecatedReader());
- try writeResourceData(writer, header_slurping_reader.reader(), header.data_size);
+ var header_slurping_reader = headerSlurpingReader(148, file_reader.interface.adaptToOldInterface());
+ var adapter = header_slurping_reader.reader().adaptToNewApi(&.{});
+ try writeResourceData(writer, &adapter.new_interface, header.data_size);
try self.state.font_dir.add(self.arena, FontDir.Font{
.id = header.name_value.ordinal,
@@ -992,7 +994,7 @@ pub const Compiler = struct {
}
// Fallback to just writing out the entire contents of the file
- const data_size = try file.getEndPos();
+ const data_size = try file_reader.getSize();
if (data_size > std.math.maxInt(u32)) {
return self.addErrorDetailsAndFail(.{
.err = .resource_data_size_exceeds_max,
@@ -1002,7 +1004,7 @@ pub const Compiler = struct {
// We now know that the data size will fit in a u32
header.data_size = @intCast(data_size);
try header.write(writer, self.errContext(node.id));
- try writeResourceData(writer, file.deprecatedReader(), header.data_size);
+ try writeResourceData(writer, &file_reader.interface, header.data_size);
}
fn iconReadError(
@@ -1250,8 +1252,8 @@ pub const Compiler = struct {
const data_len: u32 = @intCast(data_buffer.items.len);
try self.writeResourceHeader(writer, node.id, node.type, data_len, node.common_resource_attributes, self.state.language);
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_len);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_len);
}
pub fn writeResourceHeader(self: *Compiler, writer: anytype, id_token: Token, type_token: Token, data_size: u32, common_resource_attributes: []Token, language: res.Language) !void {
@@ -1266,15 +1268,15 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(id_token));
}
- pub fn writeResourceDataNoPadding(writer: anytype, data_reader: anytype, data_size: u32) !void {
- var limited_reader = std.io.limitedReader(data_reader, data_size);
-
- const FifoBuffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 });
- var fifo = FifoBuffer.init();
- try fifo.pump(limited_reader.reader(), writer);
+ pub fn writeResourceDataNoPadding(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void {
+ var adapted = writer.adaptToNewApi();
+ var buffer: [128]u8 = undefined;
+ adapted.new_interface.buffer = &buffer;
+ try data_reader.streamExact(&adapted.new_interface, data_size);
+ try adapted.new_interface.flush();
}
- pub fn writeResourceData(writer: anytype, data_reader: anytype, data_size: u32) !void {
+ pub fn writeResourceData(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void {
try writeResourceDataNoPadding(writer, data_reader, data_size);
try writeDataPadding(writer, data_size);
}
@@ -1339,8 +1341,8 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(node.id));
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_size);
}
/// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to
@@ -1732,8 +1734,8 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(node.id));
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_size);
}
fn writeDialogHeaderAndStrings(
@@ -2046,8 +2048,8 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(node.id));
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_size);
}
/// Weight and italic carry over from previous FONT statements within a single resource,
@@ -2121,8 +2123,8 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(node.id));
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_size);
}
/// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to
@@ -2386,8 +2388,8 @@ pub const Compiler = struct {
try header.write(writer, self.errContext(node.id));
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try writeResourceData(writer, &data_fbs, data_size);
}
/// Expects writer to be a LimitedWriter limited to u16, meaning all writes to
@@ -3321,8 +3323,8 @@ pub const StringTable = struct {
// we fully control and know are numbers, so they have a fixed size.
try header.writeAssertNoOverflow(writer);
- var data_fbs = std.io.fixedBufferStream(data_buffer.items);
- try Compiler.writeResourceData(writer, data_fbs.reader(), data_size);
+ var data_fbs: std.Io.Reader = .fixed(data_buffer.items);
+ try Compiler.writeResourceData(writer, &data_fbs, data_size);
}
};
diff --git a/lib/compiler/resinator/cvtres.zig b/lib/compiler/resinator/cvtres.zig
@@ -65,7 +65,7 @@ pub const ParseResOptions = struct {
};
/// The returned ParsedResources should be freed by calling its `deinit` function.
-pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions) !ParsedResources {
+pub fn parseRes(allocator: Allocator, reader: *std.Io.Reader, options: ParseResOptions) !ParsedResources {
var resources = ParsedResources.init(allocator);
errdefer resources.deinit();
@@ -74,7 +74,7 @@ pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions)
return resources;
}
-pub fn parseResInto(resources: *ParsedResources, reader: anytype, options: ParseResOptions) !void {
+pub fn parseResInto(resources: *ParsedResources, reader: *std.Io.Reader, options: ParseResOptions) !void {
const allocator = resources.allocator;
var bytes_remaining: u64 = options.max_size;
{
@@ -103,43 +103,38 @@ pub const ResourceAndSize = struct {
total_size: u64,
};
-pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !ResourceAndSize {
- var header_counting_reader = std.io.countingReader(reader);
- const header_reader = header_counting_reader.reader();
- const data_size = try header_reader.readInt(u32, .little);
- const header_size = try header_reader.readInt(u32, .little);
+pub fn parseResource(allocator: Allocator, reader: *std.Io.Reader, max_size: u64) !ResourceAndSize {
+ const data_size = try reader.takeInt(u32, .little);
+ const header_size = try reader.takeInt(u32, .little);
const total_size: u64 = @as(u64, header_size) + data_size;
if (total_size > max_size) return error.ImpossibleSize;
- var header_bytes_available = header_size -| 8;
- var type_reader = std.io.limitedReader(header_reader, header_bytes_available);
- const type_value = try parseNameOrOrdinal(allocator, type_reader.reader());
+ const remaining_header_bytes = try reader.take(header_size -| 8);
+ var remaining_header_reader: std.Io.Reader = .fixed(remaining_header_bytes);
+ const type_value = try parseNameOrOrdinal(allocator, &remaining_header_reader);
errdefer type_value.deinit(allocator);
- header_bytes_available -|= @intCast(type_value.byteLen());
- var name_reader = std.io.limitedReader(header_reader, header_bytes_available);
- const name_value = try parseNameOrOrdinal(allocator, name_reader.reader());
+ const name_value = try parseNameOrOrdinal(allocator, &remaining_header_reader);
errdefer name_value.deinit(allocator);
- const padding_after_name = numPaddingBytesNeeded(@intCast(header_counting_reader.bytes_read));
- try header_reader.skipBytes(padding_after_name, .{ .buf_size = 3 });
+ const padding_after_name = numPaddingBytesNeeded(@intCast(remaining_header_reader.seek));
+ try remaining_header_reader.discardAll(padding_after_name);
- std.debug.assert(header_counting_reader.bytes_read % 4 == 0);
- const data_version = try header_reader.readInt(u32, .little);
- const memory_flags: MemoryFlags = @bitCast(try header_reader.readInt(u16, .little));
- const language: Language = @bitCast(try header_reader.readInt(u16, .little));
- const version = try header_reader.readInt(u32, .little);
- const characteristics = try header_reader.readInt(u32, .little);
+ std.debug.assert(remaining_header_reader.seek % 4 == 0);
+ const data_version = try remaining_header_reader.takeInt(u32, .little);
+ const memory_flags: MemoryFlags = @bitCast(try remaining_header_reader.takeInt(u16, .little));
+ const language: Language = @bitCast(try remaining_header_reader.takeInt(u16, .little));
+ const version = try remaining_header_reader.takeInt(u32, .little);
+ const characteristics = try remaining_header_reader.takeInt(u32, .little);
- const header_bytes_read = header_counting_reader.bytes_read;
- if (header_size != header_bytes_read) return error.HeaderSizeMismatch;
+ if (remaining_header_reader.seek != remaining_header_reader.end) return error.HeaderSizeMismatch;
const data = try allocator.alloc(u8, data_size);
errdefer allocator.free(data);
- try reader.readNoEof(data);
+ try reader.readSliceAll(data);
const padding_after_data = numPaddingBytesNeeded(@intCast(data_size));
- try reader.skipBytes(padding_after_data, .{ .buf_size = 3 });
+ try reader.discardAll(padding_after_data);
return .{
.resource = .{
@@ -156,10 +151,10 @@ pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !Reso
};
}
-pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal {
- const first_code_unit = try reader.readInt(u16, .little);
+pub fn parseNameOrOrdinal(allocator: Allocator, reader: *std.Io.Reader) !NameOrOrdinal {
+ const first_code_unit = try reader.takeInt(u16, .little);
if (first_code_unit == 0xFFFF) {
- const ordinal_value = try reader.readInt(u16, .little);
+ const ordinal_value = try reader.takeInt(u16, .little);
return .{ .ordinal = ordinal_value };
}
var name_buf = try std.ArrayListUnmanaged(u16).initCapacity(allocator, 16);
@@ -167,7 +162,7 @@ pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal
var code_unit = first_code_unit;
while (code_unit != 0) {
try name_buf.append(allocator, std.mem.nativeToLittle(u16, code_unit));
- code_unit = try reader.readInt(u16, .little);
+ code_unit = try reader.takeInt(u16, .little);
}
return .{ .name = try name_buf.toOwnedSliceSentinel(allocator, 0) };
}
diff --git a/lib/compiler/resinator/errors.zig b/lib/compiler/resinator/errors.zig
@@ -1078,11 +1078,9 @@ const CorrespondingLines = struct {
at_eof: bool = false,
span: SourceMappings.CorrespondingSpan,
file: std.fs.File,
- buffered_reader: BufferedReaderType,
+ buffered_reader: std.fs.File.Reader,
code_page: SupportedCodePage,
- const BufferedReaderType = std.io.BufferedReader(512, std.fs.File.DeprecatedReader);
-
pub fn init(cwd: std.fs.Dir, err_details: ErrorDetails, line_for_comparison: []const u8, corresponding_span: SourceMappings.CorrespondingSpan, corresponding_file: []const u8) !CorrespondingLines {
// We don't do line comparison for this error, so don't print the note if the line
// number is different
@@ -1101,9 +1099,7 @@ const CorrespondingLines = struct {
.buffered_reader = undefined,
.code_page = err_details.code_page,
};
- corresponding_lines.buffered_reader = BufferedReaderType{
- .unbuffered_reader = corresponding_lines.file.deprecatedReader(),
- };
+ corresponding_lines.buffered_reader = corresponding_lines.file.reader(&.{});
errdefer corresponding_lines.deinit();
var fbs = std.io.fixedBufferStream(&corresponding_lines.line_buf);
@@ -1111,7 +1107,7 @@ const CorrespondingLines = struct {
try corresponding_lines.writeLineFromStreamVerbatim(
writer,
- corresponding_lines.buffered_reader.reader(),
+ corresponding_lines.buffered_reader.interface.adaptToOldInterface(),
corresponding_span.start_line,
);
@@ -1154,7 +1150,7 @@ const CorrespondingLines = struct {
try self.writeLineFromStreamVerbatim(
writer,
- self.buffered_reader.reader(),
+ self.buffered_reader.interface.adaptToOldInterface(),
self.line_num,
);
diff --git a/lib/compiler/resinator/ico.zig b/lib/compiler/resinator/ico.zig
@@ -14,8 +14,9 @@ pub fn read(allocator: std.mem.Allocator, reader: anytype, max_size: u64) ReadEr
// Some Reader implementations have an empty ReadError error set which would
// cause 'unreachable else' if we tried to use an else in the switch, so we
// need to detect this case and not try to translate to ReadError
+ const anyerror_reader_errorset = @TypeOf(reader).Error == anyerror;
const empty_reader_errorset = @typeInfo(@TypeOf(reader).Error).error_set == null or @typeInfo(@TypeOf(reader).Error).error_set.?.len == 0;
- if (empty_reader_errorset) {
+ if (empty_reader_errorset and !anyerror_reader_errorset) {
return readAnyError(allocator, reader, max_size) catch |err| switch (err) {
error.EndOfStream => error.UnexpectedEOF,
else => |e| return e,
diff --git a/lib/compiler/resinator/main.zig b/lib/compiler/resinator/main.zig
@@ -325,8 +325,8 @@ pub fn main() !void {
std.debug.assert(options.output_format == .coff);
// TODO: Maybe use a buffered file reader instead of reading file into memory -> fbs
- var fbs = std.io.fixedBufferStream(res_data.bytes);
- break :resources cvtres.parseRes(allocator, fbs.reader(), .{ .max_size = res_data.bytes.len }) catch |err| {
+ var res_reader: std.Io.Reader = .fixed(res_data.bytes);
+ break :resources cvtres.parseRes(allocator, &res_reader, .{ .max_size = res_data.bytes.len }) catch |err| {
// TODO: Better errors
try error_handler.emitMessage(allocator, .err, "unable to parse res from '{s}': {s}", .{ res_stream.name, @errorName(err) });
std.process.exit(1);
diff --git a/lib/docs/wasm/markdown.zig b/lib/docs/wasm/markdown.zig
@@ -145,13 +145,12 @@ fn mainImpl() !void {
var parser = try Parser.init(gpa);
defer parser.deinit();
- var stdin_buf = std.io.bufferedReader(std.fs.File.stdin().deprecatedReader());
- var line_buf = std.ArrayList(u8).init(gpa);
- defer line_buf.deinit();
- while (stdin_buf.reader().streamUntilDelimiter(line_buf.writer(), '\n', null)) {
- if (line_buf.getLastOrNull() == '\r') _ = line_buf.pop();
- try parser.feedLine(line_buf.items);
- line_buf.clearRetainingCapacity();
+ var stdin_buffer: [1024]u8 = undefined;
+ var stdin_reader = std.fs.File.stdin().reader(&stdin_buffer);
+
+ while (stdin_reader.takeDelimiterExclusive('\n')) |line| {
+ const trimmed = std.mem.trimRight(u8, line, '\r');
+ try parser.feedLine(trimmed);
} else |err| switch (err) {
error.EndOfStream => {},
else => |e| return e,
diff --git a/lib/std/Build/Fuzz.zig b/lib/std/Build/Fuzz.zig
@@ -234,7 +234,7 @@ pub const Previous = struct {
};
pub fn sendUpdate(
fuzz: *Fuzz,
- socket: *std.http.WebSocket,
+ socket: *std.http.Server.WebSocket,
prev: *Previous,
) !void {
fuzz.coverage_mutex.lock();
@@ -263,36 +263,36 @@ pub fn sendUpdate(
.string_bytes_len = @intCast(coverage_map.coverage.string_bytes.items.len),
.start_timestamp = coverage_map.start_timestamp,
};
- const iovecs: [5]std.posix.iovec_const = .{
- makeIov(@ptrCast(&header)),
- makeIov(@ptrCast(coverage_map.coverage.directories.keys())),
- makeIov(@ptrCast(coverage_map.coverage.files.keys())),
- makeIov(@ptrCast(coverage_map.source_locations)),
- makeIov(coverage_map.coverage.string_bytes.items),
+ var iovecs: [5][]const u8 = .{
+ @ptrCast(&header),
+ @ptrCast(coverage_map.coverage.directories.keys()),
+ @ptrCast(coverage_map.coverage.files.keys()),
+ @ptrCast(coverage_map.source_locations),
+ coverage_map.coverage.string_bytes.items,
};
- try socket.writeMessagev(&iovecs, .binary);
+ try socket.writeMessageVec(&iovecs, .binary);
}
const header: abi.CoverageUpdateHeader = .{
.n_runs = n_runs,
.unique_runs = unique_runs,
};
- const iovecs: [2]std.posix.iovec_const = .{
- makeIov(@ptrCast(&header)),
- makeIov(@ptrCast(seen_pcs)),
+ var iovecs: [2][]const u8 = .{
+ @ptrCast(&header),
+ @ptrCast(seen_pcs),
};
- try socket.writeMessagev(&iovecs, .binary);
+ try socket.writeMessageVec(&iovecs, .binary);
prev.unique_runs = unique_runs;
}
if (prev.entry_points != coverage_map.entry_points.items.len) {
const header: abi.EntryPointHeader = .init(@intCast(coverage_map.entry_points.items.len));
- const iovecs: [2]std.posix.iovec_const = .{
- makeIov(@ptrCast(&header)),
- makeIov(@ptrCast(coverage_map.entry_points.items)),
+ var iovecs: [2][]const u8 = .{
+ @ptrCast(&header),
+ @ptrCast(coverage_map.entry_points.items),
};
- try socket.writeMessagev(&iovecs, .binary);
+ try socket.writeMessageVec(&iovecs, .binary);
prev.entry_points = coverage_map.entry_points.items.len;
}
@@ -448,10 +448,3 @@ fn addEntryPoint(fuzz: *Fuzz, coverage_id: u64, addr: u64) error{ AlreadyReporte
}
try coverage_map.entry_points.append(fuzz.ws.gpa, @intCast(index));
}
-
-fn makeIov(s: []const u8) std.posix.iovec_const {
- return .{
- .base = s.ptr,
- .len = s.len,
- };
-}
diff --git a/lib/std/Build/WebServer.zig b/lib/std/Build/WebServer.zig
@@ -251,48 +251,44 @@ pub fn now(s: *const WebServer) i64 {
fn accept(ws: *WebServer, connection: std.net.Server.Connection) void {
defer connection.stream.close();
- var read_buf: [0x4000]u8 = undefined;
- var server: std.http.Server = .init(connection, &read_buf);
+ var send_buffer: [4096]u8 = undefined;
+ var recv_buffer: [4096]u8 = undefined;
+ var connection_reader = connection.stream.reader(&recv_buffer);
+ var connection_writer = connection.stream.writer(&send_buffer);
+ var server: http.Server = .init(connection_reader.interface(), &connection_writer.interface);
while (true) {
var request = server.receiveHead() catch |err| switch (err) {
error.HttpConnectionClosing => return,
- else => {
- log.err("failed to receive http request: {s}", .{@errorName(err)});
- return;
- },
+ else => return log.err("failed to receive http request: {t}", .{err}),
};
- var ws_send_buf: [0x4000]u8 = undefined;
- var ws_recv_buf: [0x4000]u8 align(4) = undefined;
- if (std.http.WebSocket.init(&request, &ws_send_buf, &ws_recv_buf) catch |err| {
- log.err("failed to initialize websocket connection: {s}", .{@errorName(err)});
- return;
- }) |ws_init| {
- var web_socket = ws_init;
- ws.serveWebSocket(&web_socket) catch |err| {
- log.err("failed to serve websocket: {s}", .{@errorName(err)});
- return;
- };
- comptime unreachable;
- } else {
- ws.serveRequest(&request) catch |err| switch (err) {
- error.AlreadyReported => return,
- else => {
- log.err("failed to serve '{s}': {s}", .{ request.head.target, @errorName(err) });
+ switch (request.upgradeRequested()) {
+ .websocket => |opt_key| {
+ const key = opt_key orelse return log.err("missing websocket key", .{});
+ var web_socket = request.respondWebSocket(.{ .key = key }) catch {
+ return log.err("failed to respond web socket: {t}", .{connection_writer.err.?});
+ };
+ ws.serveWebSocket(&web_socket) catch |err| {
+ log.err("failed to serve websocket: {t}", .{err});
return;
- },
- };
+ };
+ comptime unreachable;
+ },
+ .other => |name| return log.err("unknown upgrade request: {s}", .{name}),
+ .none => {
+ ws.serveRequest(&request) catch |err| switch (err) {
+ error.AlreadyReported => return,
+ else => {
+ log.err("failed to serve '{s}': {t}", .{ request.head.target, err });
+ return;
+ },
+ };
+ },
}
}
}
-fn makeIov(s: []const u8) std.posix.iovec_const {
- return .{
- .base = s.ptr,
- .len = s.len,
- };
-}
-fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn {
+fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn {
var prev_build_status = ws.build_status.load(.monotonic);
const prev_step_status_bits = try ws.gpa.alloc(u8, ws.step_status_bits.len);
@@ -312,11 +308,8 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn {
.timestamp = ws.now(),
.steps_len = @intCast(ws.all_steps.len),
};
- try sock.writeMessagev(&.{
- makeIov(@ptrCast(&hello_header)),
- makeIov(ws.step_names_trailing),
- makeIov(prev_step_status_bits),
- }, .binary);
+ var bufs: [3][]const u8 = .{ @ptrCast(&hello_header), ws.step_names_trailing, prev_step_status_bits };
+ try sock.writeMessageVec(&bufs, .binary);
}
var prev_fuzz: Fuzz.Previous = .init;
@@ -380,7 +373,7 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn {
std.Thread.Futex.timedWait(&ws.update_id, start_update_id, std.time.ns_per_ms * default_update_interval_ms) catch {};
}
}
-fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void {
+fn recvWebSocketMessages(ws: *WebServer, sock: *http.Server.WebSocket) void {
while (true) {
const msg = sock.readSmallMessage() catch return;
if (msg.opcode != .binary) continue;
@@ -402,7 +395,7 @@ fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void {
}
}
-fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void {
+fn serveRequest(ws: *WebServer, req: *http.Server.Request) !void {
// Strip an optional leading '/debug' component from the request.
const target: []const u8, const debug: bool = target: {
if (mem.eql(u8, req.head.target, "/debug")) break :target .{ "/", true };
@@ -431,7 +424,7 @@ fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void {
fn serveLibFile(
ws: *WebServer,
- request: *std.http.Server.Request,
+ request: *http.Server.Request,
sub_path: []const u8,
content_type: []const u8,
) !void {
@@ -442,7 +435,7 @@ fn serveLibFile(
}
fn serveClientWasm(
ws: *WebServer,
- req: *std.http.Server.Request,
+ req: *http.Server.Request,
optimize_mode: std.builtin.OptimizeMode,
) !void {
var arena_state: std.heap.ArenaAllocator = .init(ws.gpa);
@@ -456,12 +449,12 @@ fn serveClientWasm(
pub fn serveFile(
ws: *WebServer,
- request: *std.http.Server.Request,
+ request: *http.Server.Request,
path: Cache.Path,
content_type: []const u8,
) !void {
const gpa = ws.gpa;
- // The desired API is actually sendfile, which will require enhancing std.http.Server.
+ // The desired API is actually sendfile, which will require enhancing http.Server.
// We load the file with every request so that the user can make changes to the file
// and refresh the HTML page without restarting this server.
const file_contents = path.root_dir.handle.readFileAlloc(gpa, path.sub_path, 10 * 1024 * 1024) catch |err| {
@@ -478,14 +471,13 @@ pub fn serveFile(
}
pub fn serveTarFile(
ws: *WebServer,
- request: *std.http.Server.Request,
+ request: *http.Server.Request,
paths: []const Cache.Path,
) !void {
const gpa = ws.gpa;
- var send_buf: [0x4000]u8 = undefined;
- var response = request.respondStreaming(.{
- .send_buffer = &send_buf,
+ var send_buffer: [0x4000]u8 = undefined;
+ var response = try request.respondStreaming(&send_buffer, .{
.respond_options = .{
.extra_headers = &.{
.{ .name = "Content-Type", .value = "application/x-tar" },
@@ -497,10 +489,7 @@ pub fn serveTarFile(
var cached_cwd_path: ?[]const u8 = null;
defer if (cached_cwd_path) |p| gpa.free(p);
- var response_buf: [1024]u8 = undefined;
- var adapter = response.writer().adaptToNewApi();
- adapter.new_interface.buffer = &response_buf;
- var archiver: std.tar.Writer = .{ .underlying_writer = &adapter.new_interface };
+ var archiver: std.tar.Writer = .{ .underlying_writer = &response.writer };
for (paths) |path| {
var file = path.root_dir.handle.openFile(path.sub_path, .{}) catch |err| {
@@ -526,7 +515,6 @@ pub fn serveTarFile(
}
// intentionally not calling `archiver.finishPedantically`
- try adapter.new_interface.flush();
try response.end();
}
@@ -804,7 +792,7 @@ pub fn wait(ws: *WebServer) RunnerRequest {
}
}
-const cache_control_header: std.http.Header = .{
+const cache_control_header: http.Header = .{
.name = "Cache-Control",
.value = "max-age=0, must-revalidate",
};
@@ -819,5 +807,6 @@ const Build = std.Build;
const Cache = Build.Cache;
const Fuzz = Build.Fuzz;
const abi = Build.abi;
+const http = std.http;
const WebServer = @This();
diff --git a/lib/std/Io.zig b/lib/std/Io.zig
@@ -428,19 +428,9 @@ pub const BufferedWriter = @import("Io/buffered_writer.zig").BufferedWriter;
/// Deprecated in favor of `Writer`.
pub const bufferedWriter = @import("Io/buffered_writer.zig").bufferedWriter;
/// Deprecated in favor of `Reader`.
-pub const BufferedReader = @import("Io/buffered_reader.zig").BufferedReader;
-/// Deprecated in favor of `Reader`.
-pub const bufferedReader = @import("Io/buffered_reader.zig").bufferedReader;
-/// Deprecated in favor of `Reader`.
-pub const bufferedReaderSize = @import("Io/buffered_reader.zig").bufferedReaderSize;
-/// Deprecated in favor of `Reader`.
pub const FixedBufferStream = @import("Io/fixed_buffer_stream.zig").FixedBufferStream;
/// Deprecated in favor of `Reader`.
pub const fixedBufferStream = @import("Io/fixed_buffer_stream.zig").fixedBufferStream;
-/// Deprecated in favor of `Reader.Limited`.
-pub const LimitedReader = @import("Io/limited_reader.zig").LimitedReader;
-/// Deprecated in favor of `Reader.Limited`.
-pub const limitedReader = @import("Io/limited_reader.zig").limitedReader;
/// Deprecated with no replacement; inefficient pattern
pub const CountingWriter = @import("Io/counting_writer.zig").CountingWriter;
/// Deprecated with no replacement; inefficient pattern
@@ -926,7 +916,6 @@ pub fn PollFiles(comptime StreamEnum: type) type {
test {
_ = Reader;
_ = Writer;
- _ = BufferedReader;
_ = BufferedWriter;
_ = CountingWriter;
_ = CountingReader;
diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig
@@ -367,8 +367,11 @@ pub fn appendRemainingUnlimited(
const buffer_contents = r.buffer[r.seek..r.end];
try list.ensureUnusedCapacity(gpa, buffer_contents.len + bump);
list.appendSliceAssumeCapacity(buffer_contents);
- r.seek = 0;
- r.end = 0;
+ // If statement protects `ending`.
+ if (r.end != 0) {
+ r.seek = 0;
+ r.end = 0;
+ }
// From here, we leave `buffer` empty, appending directly to `list`.
var writer: Writer = .{
.buffer = undefined,
@@ -1306,31 +1309,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
r.end = data.len;
}
-/// Advances the stream and decreases the size of the storage buffer by `n`,
-/// returning the range of bytes no longer accessible by `r`.
-///
-/// This action can be undone by `restitute`.
-///
-/// Asserts there are at least `n` buffered bytes already.
-///
-/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state.
-pub fn steal(r: *Reader, n: usize) []u8 {
- assert(r.seek == 0);
- assert(n <= r.end);
- const stolen = r.buffer[0..n];
- r.buffer = r.buffer[n..];
- r.end -= n;
- return stolen;
-}
-
-/// Expands the storage buffer, undoing the effects of `steal`
-/// Assumes that `n` does not exceed the total number of stolen bytes.
-pub fn restitute(r: *Reader, n: usize) void {
- r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n];
- r.end += n;
- r.seek += n;
-}
-
test fixed {
var r: Reader = .fixed("a\x02");
try testing.expect((try r.takeByte()) == 'a');
diff --git a/lib/std/Io/Writer.zig b/lib/std/Io/Writer.zig
@@ -191,29 +191,87 @@ pub fn writeSplatHeader(
data: []const []const u8,
splat: usize,
) Error!usize {
- const new_end = w.end + header.len;
- if (new_end <= w.buffer.len) {
- @memcpy(w.buffer[w.end..][0..header.len], header);
- w.end = new_end;
- return header.len + try writeSplat(w, data, splat);
- }
- var vecs: [8][]const u8 = undefined; // Arbitrarily chosen size.
- var i: usize = 1;
- vecs[0] = header;
- for (data[0 .. data.len - 1]) |buf| {
- if (buf.len == 0) continue;
- vecs[i] = buf;
- i += 1;
- if (vecs.len - i == 0) break;
+ return writeSplatHeaderLimit(w, header, data, splat, .unlimited);
+}
+
+/// Equivalent to `writeSplatHeader` but writes at most `limit` bytes.
+pub fn writeSplatHeaderLimit(
+ w: *Writer,
+ header: []const u8,
+ data: []const []const u8,
+ splat: usize,
+ limit: Limit,
+) Error!usize {
+ var remaining = @intFromEnum(limit);
+ {
+ const copy_len = @min(header.len, w.buffer.len - w.end, remaining);
+ if (header.len - copy_len != 0) return writeSplatHeaderLimitFinish(w, header, data, splat, remaining);
+ @memcpy(w.buffer[w.end..][0..copy_len], header[0..copy_len]);
+ w.end += copy_len;
+ remaining -= copy_len;
+ }
+ for (data[0 .. data.len - 1], 0..) |buf, i| {
+ const copy_len = @min(buf.len, w.buffer.len - w.end, remaining);
+ if (buf.len - copy_len != 0) return @intFromEnum(limit) - remaining +
+ try writeSplatHeaderLimitFinish(w, &.{}, data[i..], splat, remaining);
+ @memcpy(w.buffer[w.end..][0..copy_len], buf[0..copy_len]);
+ w.end += copy_len;
+ remaining -= copy_len;
}
const pattern = data[data.len - 1];
- const new_splat = s: {
- if (pattern.len == 0 or vecs.len - i == 0) break :s 1;
+ const splat_n = pattern.len * splat;
+ if (splat_n > @min(w.buffer.len - w.end, remaining)) {
+ const buffered_n = @intFromEnum(limit) - remaining;
+ const written = try writeSplatHeaderLimitFinish(w, &.{}, data[data.len - 1 ..][0..1], splat, remaining);
+ return buffered_n + written;
+ }
+
+ for (0..splat) |_| {
+ @memcpy(w.buffer[w.end..][0..pattern.len], pattern);
+ w.end += pattern.len;
+ }
+
+ remaining -= splat_n;
+ return @intFromEnum(limit) - remaining;
+}
+
+fn writeSplatHeaderLimitFinish(
+ w: *Writer,
+ header: []const u8,
+ data: []const []const u8,
+ splat: usize,
+ limit: usize,
+) Error!usize {
+ var remaining = limit;
+ var vecs: [8][]const u8 = undefined;
+ var i: usize = 0;
+ v: {
+ if (header.len != 0) {
+ const copy_len = @min(header.len, remaining);
+ vecs[i] = header[0..copy_len];
+ i += 1;
+ remaining -= copy_len;
+ if (remaining == 0) break :v;
+ }
+ for (data[0 .. data.len - 1]) |buf| if (buf.len != 0) {
+ const copy_len = @min(header.len, remaining);
+ vecs[i] = buf;
+ i += 1;
+ remaining -= copy_len;
+ if (remaining == 0) break :v;
+ if (vecs.len - i == 0) break :v;
+ };
+ const pattern = data[data.len - 1];
+ if (splat == 1) {
+ vecs[i] = pattern[0..@min(remaining, pattern.len)];
+ i += 1;
+ break :v;
+ }
vecs[i] = pattern;
i += 1;
- break :s splat;
- };
- return w.vtable.drain(w, vecs[0..i], new_splat);
+ return w.vtable.drain(w, (&vecs)[0..i], @min(remaining / pattern.len, splat));
+ }
+ return w.vtable.drain(w, (&vecs)[0..i], 1);
}
test "writeSplatHeader splatting avoids buffer aliasing temptation" {
diff --git a/lib/std/Io/buffered_reader.zig b/lib/std/Io/buffered_reader.zig
@@ -1,201 +0,0 @@
-const std = @import("../std.zig");
-const io = std.io;
-const mem = std.mem;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) type {
- return struct {
- unbuffered_reader: ReaderType,
- buf: [buffer_size]u8 = undefined,
- start: usize = 0,
- end: usize = 0,
-
- pub const Error = ReaderType.Error;
- pub const Reader = io.GenericReader(*Self, Error, read);
-
- const Self = @This();
-
- pub fn read(self: *Self, dest: []u8) Error!usize {
- // First try reading from the already buffered data onto the destination.
- const current = self.buf[self.start..self.end];
- if (current.len != 0) {
- const to_transfer = @min(current.len, dest.len);
- @memcpy(dest[0..to_transfer], current[0..to_transfer]);
- self.start += to_transfer;
- return to_transfer;
- }
-
- // If dest is large, read from the unbuffered reader directly into the destination.
- if (dest.len >= buffer_size) {
- return self.unbuffered_reader.read(dest);
- }
-
- // If dest is small, read from the unbuffered reader into our own internal buffer,
- // and then transfer to destination.
- self.end = try self.unbuffered_reader.read(&self.buf);
- const to_transfer = @min(self.end, dest.len);
- @memcpy(dest[0..to_transfer], self.buf[0..to_transfer]);
- self.start = to_transfer;
- return to_transfer;
- }
-
- pub fn reader(self: *Self) Reader {
- return .{ .context = self };
- }
- };
-}
-
-pub fn bufferedReader(reader: anytype) BufferedReader(4096, @TypeOf(reader)) {
- return .{ .unbuffered_reader = reader };
-}
-
-pub fn bufferedReaderSize(comptime size: usize, reader: anytype) BufferedReader(size, @TypeOf(reader)) {
- return .{ .unbuffered_reader = reader };
-}
-
-test "OneByte" {
- const OneByteReadReader = struct {
- str: []const u8,
- curr: usize,
-
- const Error = error{NoError};
- const Self = @This();
- const Reader = io.GenericReader(*Self, Error, read);
-
- fn init(str: []const u8) Self {
- return Self{
- .str = str,
- .curr = 0,
- };
- }
-
- fn read(self: *Self, dest: []u8) Error!usize {
- if (self.str.len <= self.curr or dest.len == 0)
- return 0;
-
- dest[0] = self.str[self.curr];
- self.curr += 1;
- return 1;
- }
-
- fn reader(self: *Self) Reader {
- return .{ .context = self };
- }
- };
-
- const str = "This is a test";
- var one_byte_stream = OneByteReadReader.init(str);
- var buf_reader = bufferedReader(one_byte_stream.reader());
- const stream = buf_reader.reader();
-
- const res = try stream.readAllAlloc(testing.allocator, str.len + 1);
- defer testing.allocator.free(res);
- try testing.expectEqualSlices(u8, str, res);
-}
-
-fn smallBufferedReader(underlying_stream: anytype) BufferedReader(8, @TypeOf(underlying_stream)) {
- return .{ .unbuffered_reader = underlying_stream };
-}
-test "Block" {
- const BlockReader = struct {
- block: []const u8,
- reads_allowed: usize,
- curr_read: usize,
-
- const Error = error{NoError};
- const Self = @This();
- const Reader = io.GenericReader(*Self, Error, read);
-
- fn init(block: []const u8, reads_allowed: usize) Self {
- return Self{
- .block = block,
- .reads_allowed = reads_allowed,
- .curr_read = 0,
- };
- }
-
- fn read(self: *Self, dest: []u8) Error!usize {
- if (self.curr_read >= self.reads_allowed) return 0;
- @memcpy(dest[0..self.block.len], self.block);
-
- self.curr_read += 1;
- return self.block.len;
- }
-
- fn reader(self: *Self) Reader {
- return .{ .context = self };
- }
- };
-
- const block = "0123";
-
- // len out == block
- {
- var test_buf_reader: BufferedReader(4, BlockReader) = .{
- .unbuffered_reader = BlockReader.init(block, 2),
- };
- const reader = test_buf_reader.reader();
- var out_buf: [4]u8 = undefined;
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, block);
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, block);
- try testing.expectEqual(try reader.readAll(&out_buf), 0);
- }
-
- // len out < block
- {
- var test_buf_reader: BufferedReader(4, BlockReader) = .{
- .unbuffered_reader = BlockReader.init(block, 2),
- };
- const reader = test_buf_reader.reader();
- var out_buf: [3]u8 = undefined;
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, "012");
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, "301");
- const n = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, out_buf[0..n], "23");
- try testing.expectEqual(try reader.readAll(&out_buf), 0);
- }
-
- // len out > block
- {
- var test_buf_reader: BufferedReader(4, BlockReader) = .{
- .unbuffered_reader = BlockReader.init(block, 2),
- };
- const reader = test_buf_reader.reader();
- var out_buf: [5]u8 = undefined;
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, "01230");
- const n = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, out_buf[0..n], "123");
- try testing.expectEqual(try reader.readAll(&out_buf), 0);
- }
-
- // len out == 0
- {
- var test_buf_reader: BufferedReader(4, BlockReader) = .{
- .unbuffered_reader = BlockReader.init(block, 2),
- };
- const reader = test_buf_reader.reader();
- var out_buf: [0]u8 = undefined;
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, "");
- }
-
- // len bufreader buf > block
- {
- var test_buf_reader: BufferedReader(5, BlockReader) = .{
- .unbuffered_reader = BlockReader.init(block, 2),
- };
- const reader = test_buf_reader.reader();
- var out_buf: [4]u8 = undefined;
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, block);
- _ = try reader.readAll(&out_buf);
- try testing.expectEqualSlices(u8, &out_buf, block);
- try testing.expectEqual(try reader.readAll(&out_buf), 0);
- }
-}
diff --git a/lib/std/Io/limited_reader.zig b/lib/std/Io/limited_reader.zig
@@ -1,45 +0,0 @@
-const std = @import("../std.zig");
-const io = std.io;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-pub fn LimitedReader(comptime ReaderType: type) type {
- return struct {
- inner_reader: ReaderType,
- bytes_left: u64,
-
- pub const Error = ReaderType.Error;
- pub const Reader = io.GenericReader(*Self, Error, read);
-
- const Self = @This();
-
- pub fn read(self: *Self, dest: []u8) Error!usize {
- const max_read = @min(self.bytes_left, dest.len);
- const n = try self.inner_reader.read(dest[0..max_read]);
- self.bytes_left -= n;
- return n;
- }
-
- pub fn reader(self: *Self) Reader {
- return .{ .context = self };
- }
- };
-}
-
-/// Returns an initialised `LimitedReader`.
-/// `bytes_left` is a `u64` to be able to take 64 bit file offsets
-pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) {
- return .{ .inner_reader = inner_reader, .bytes_left = bytes_left };
-}
-
-test "basic usage" {
- const data = "hello world";
- var fbs = std.io.fixedBufferStream(data);
- var early_stream = limitedReader(fbs.reader(), 3);
-
- var buf: [5]u8 = undefined;
- try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf));
- try testing.expectEqualSlices(u8, data[0..3], buf[0..3]);
- try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf));
- try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{}));
-}
diff --git a/lib/std/Io/test.zig b/lib/std/Io/test.zig
@@ -45,9 +45,9 @@ test "write a file, read it, then delete it" {
const expected_file_size: u64 = "begin".len + data.len + "end".len;
try expectEqual(expected_file_size, file_size);
- var buf_stream = io.bufferedReader(file.deprecatedReader());
- const st = buf_stream.reader();
- const contents = try st.readAllAlloc(std.testing.allocator, 2 * 1024);
+ var file_buffer: [1024]u8 = undefined;
+ var file_reader = file.reader(&file_buffer);
+ const contents = try file_reader.interface.allocRemaining(std.testing.allocator, .limited(2 * 1024));
defer std.testing.allocator.free(contents);
try expect(mem.eql(u8, contents[0.."begin".len], "begin"));
diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig
@@ -4,6 +4,8 @@
const std = @import("std.zig");
const testing = std.testing;
const Uri = @This();
+const Allocator = std.mem.Allocator;
+const Writer = std.Io.Writer;
scheme: []const u8,
user: ?Component = null,
@@ -14,6 +16,32 @@ path: Component = Component.empty,
query: ?Component = null,
fragment: ?Component = null,
+pub const host_name_max = 255;
+
+/// Returned value may point into `buffer` or be the original string.
+///
+/// Suggested buffer length: `host_name_max`.
+///
+/// See also:
+/// * `getHostAlloc`
+pub fn getHost(uri: Uri, buffer: []u8) error{ UriMissingHost, UriHostTooLong }![]const u8 {
+ const component = uri.host orelse return error.UriMissingHost;
+ return component.toRaw(buffer) catch |err| switch (err) {
+ error.NoSpaceLeft => return error.UriHostTooLong,
+ };
+}
+
+/// Returned value may point into `buffer` or be the original string.
+///
+/// See also:
+/// * `getHost`
+pub fn getHostAlloc(uri: Uri, arena: Allocator) error{ UriMissingHost, UriHostTooLong, OutOfMemory }![]const u8 {
+ const component = uri.host orelse return error.UriMissingHost;
+ const result = try component.toRawMaybeAlloc(arena);
+ if (result.len > host_name_max) return error.UriHostTooLong;
+ return result;
+}
+
pub const Component = union(enum) {
/// Invalid characters in this component must be percent encoded
/// before being printed as part of a URI.
@@ -30,11 +58,19 @@ pub const Component = union(enum) {
};
}
+ /// Returned value may point into `buffer` or be the original string.
+ pub fn toRaw(component: Component, buffer: []u8) error{NoSpaceLeft}![]const u8 {
+ return switch (component) {
+ .raw => |raw| raw,
+ .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_|
+ try std.fmt.bufPrint(buffer, "{f}", .{std.fmt.alt(component, .formatRaw)})
+ else
+ percent_encoded,
+ };
+ }
+
/// Allocates the result with `arena` only if needed, so the result should not be freed.
- pub fn toRawMaybeAlloc(
- component: Component,
- arena: std.mem.Allocator,
- ) std.mem.Allocator.Error![]const u8 {
+ pub fn toRawMaybeAlloc(component: Component, arena: Allocator) Allocator.Error![]const u8 {
return switch (component) {
.raw => |raw| raw,
.percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_|
@@ -44,7 +80,7 @@ pub const Component = union(enum) {
};
}
- pub fn formatRaw(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatRaw(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try w.writeAll(raw),
.percent_encoded => |percent_encoded| {
@@ -67,56 +103,56 @@ pub const Component = union(enum) {
}
}
- pub fn formatEscaped(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatEscaped(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isUnreserved),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatUser(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatUser(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isUserChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatPassword(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatPassword(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isPasswordChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatHost(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatHost(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isHostChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatPath(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatPath(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isPathChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatQuery(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatQuery(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isQueryChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn formatFragment(component: Component, w: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn formatFragment(component: Component, w: *Writer) Writer.Error!void {
switch (component) {
.raw => |raw| try percentEncode(w, raw, isFragmentChar),
.percent_encoded => |percent_encoded| try w.writeAll(percent_encoded),
}
}
- pub fn percentEncode(w: *std.io.Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) std.io.Writer.Error!void {
+ pub fn percentEncode(w: *Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) Writer.Error!void {
var start: usize = 0;
for (raw, 0..) |char, index| {
if (isValidChar(char)) continue;
@@ -165,17 +201,15 @@ pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort };
/// The return value will contain strings pointing into the original `text`.
/// Each component that is provided, will be non-`null`.
pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri {
- var reader = SliceReader{ .slice = text };
-
var uri: Uri = .{ .scheme = scheme, .path = undefined };
+ var i: usize = 0;
- if (reader.peekPrefix("//")) a: { // authority part
- std.debug.assert(reader.get().? == '/');
- std.debug.assert(reader.get().? == '/');
-
- const authority = reader.readUntil(isAuthoritySeparator);
+ if (std.mem.startsWith(u8, text, "//")) a: {
+ i = std.mem.indexOfAnyPos(u8, text, 2, &authority_sep) orelse text.len;
+ const authority = text[2..i];
if (authority.len == 0) {
- if (reader.peekPrefix("/")) break :a else return error.InvalidFormat;
+ if (!std.mem.startsWith(u8, text[2..], "/")) return error.InvalidFormat;
+ break :a;
}
var start_of_host: usize = 0;
@@ -225,26 +259,28 @@ pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri {
uri.host = .{ .percent_encoded = authority[start_of_host..end_of_host] };
}
- uri.path = .{ .percent_encoded = reader.readUntil(isPathSeparator) };
+ const path_start = i;
+ i = std.mem.indexOfAnyPos(u8, text, path_start, &path_sep) orelse text.len;
+ uri.path = .{ .percent_encoded = text[path_start..i] };
- if ((reader.peek() orelse 0) == '?') { // query part
- std.debug.assert(reader.get().? == '?');
- uri.query = .{ .percent_encoded = reader.readUntil(isQuerySeparator) };
+ if (std.mem.startsWith(u8, text[i..], "?")) {
+ const query_start = i + 1;
+ i = std.mem.indexOfScalarPos(u8, text, query_start, '#') orelse text.len;
+ uri.query = .{ .percent_encoded = text[query_start..i] };
}
- if ((reader.peek() orelse 0) == '#') { // fragment part
- std.debug.assert(reader.get().? == '#');
- uri.fragment = .{ .percent_encoded = reader.readUntilEof() };
+ if (std.mem.startsWith(u8, text[i..], "#")) {
+ uri.fragment = .{ .percent_encoded = text[i + 1 ..] };
}
return uri;
}
-pub fn format(uri: *const Uri, writer: *std.io.Writer) std.io.Writer.Error!void {
+pub fn format(uri: *const Uri, writer: *Writer) Writer.Error!void {
return writeToStream(uri, writer, .all);
}
-pub fn writeToStream(uri: *const Uri, writer: *std.io.Writer, flags: Format.Flags) std.io.Writer.Error!void {
+pub fn writeToStream(uri: *const Uri, writer: *Writer, flags: Format.Flags) Writer.Error!void {
if (flags.scheme) {
try writer.print("{s}:", .{uri.scheme});
if (flags.authority and uri.host != null) {
@@ -318,7 +354,7 @@ pub const Format = struct {
};
};
- pub fn default(f: Format, writer: *std.io.Writer) std.io.Writer.Error!void {
+ pub fn default(f: Format, writer: *Writer) Writer.Error!void {
return writeToStream(f.uri, writer, f.flags);
}
};
@@ -327,41 +363,34 @@ pub fn fmt(uri: *const Uri, flags: Format.Flags) std.fmt.Formatter(Format, Forma
return .{ .data = .{ .uri = uri, .flags = flags } };
}
-/// Parses the URI or returns an error.
-/// The return value will contain strings pointing into the
-/// original `text`. Each component that is provided, will be non-`null`.
+/// The return value will contain strings pointing into the original `text`.
+/// Each component that is provided will be non-`null`.
pub fn parse(text: []const u8) ParseError!Uri {
- var reader: SliceReader = .{ .slice = text };
- const scheme = reader.readWhile(isSchemeChar);
-
- // after the scheme, a ':' must appear
- if (reader.get()) |c| {
- if (c != ':')
- return error.UnexpectedCharacter;
- } else {
- return error.InvalidFormat;
- }
-
- return parseAfterScheme(scheme, reader.readUntilEof());
+ const end = for (text, 0..) |byte, i| {
+ if (!isSchemeChar(byte)) break i;
+ } else text.len;
+ // After the scheme, a ':' must appear.
+ if (end >= text.len) return error.InvalidFormat;
+ if (text[end] != ':') return error.UnexpectedCharacter;
+ return parseAfterScheme(text[0..end], text[end + 1 ..]);
}
pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft};
-/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5.
-/// Copies `new` to the beginning of `aux_buf.*`, allowing the slices to overlap,
-/// then parses `new` as a URI, and then resolves the path in place.
+/// Resolves a URI against a base URI, conforming to
+/// [RFC 3986, Section 5](https://www.rfc-editor.org/rfc/rfc3986#section-5)
+///
+/// Assumes new location is already copied to the beginning of `aux_buf.*`.
+/// Parses that new location as a URI, and then resolves the path in place.
+///
/// If a merge needs to take place, the newly constructed path will be stored
-/// in `aux_buf.*` just after the copied `new`, and `aux_buf.*` will be modified
-/// to only contain the remaining unused space.
-pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: *[]u8) ResolveInPlaceError!Uri {
- std.mem.copyForwards(u8, aux_buf.*, new);
- // At this point, new is an invalid pointer.
- const new_mut = aux_buf.*[0..new.len];
- aux_buf.* = aux_buf.*[new.len..];
-
- const new_parsed = parse(new_mut) catch |err|
- (parseAfterScheme("", new_mut) catch return err);
- // As you can see above, `new_mut` is not a const pointer.
+/// in `aux_buf.*` just after the copied location, and `aux_buf.*` will be
+/// modified to only contain the remaining unused space.
+pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceError!Uri {
+ const new = aux_buf.*[0..new_len];
+ const new_parsed = parse(new) catch |err| (parseAfterScheme("", new) catch return err);
+ aux_buf.* = aux_buf.*[new_len..];
+ // As you can see above, `new` is not a const pointer.
const new_path: []u8 = @constCast(new_parsed.path.percent_encoded);
if (new_parsed.scheme.len > 0) return .{
@@ -461,7 +490,7 @@ test remove_dot_segments {
/// 5.2.3. Merge Paths
fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Component {
- var aux: std.io.Writer = .fixed(aux_buf.*);
+ var aux: Writer = .fixed(aux_buf.*);
if (!base.isEmpty()) {
base.formatPath(&aux) catch return error.NoSpaceLeft;
aux.end = std.mem.lastIndexOfScalar(u8, aux.buffered(), '/') orelse return remove_dot_segments(new);
@@ -472,59 +501,6 @@ fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Co
return merged_path;
}
-const SliceReader = struct {
- const Self = @This();
-
- slice: []const u8,
- offset: usize = 0,
-
- fn get(self: *Self) ?u8 {
- if (self.offset >= self.slice.len)
- return null;
- const c = self.slice[self.offset];
- self.offset += 1;
- return c;
- }
-
- fn peek(self: Self) ?u8 {
- if (self.offset >= self.slice.len)
- return null;
- return self.slice[self.offset];
- }
-
- fn readWhile(self: *Self, comptime predicate: fn (u8) bool) []const u8 {
- const start = self.offset;
- var end = start;
- while (end < self.slice.len and predicate(self.slice[end])) {
- end += 1;
- }
- self.offset = end;
- return self.slice[start..end];
- }
-
- fn readUntil(self: *Self, comptime predicate: fn (u8) bool) []const u8 {
- const start = self.offset;
- var end = start;
- while (end < self.slice.len and !predicate(self.slice[end])) {
- end += 1;
- }
- self.offset = end;
- return self.slice[start..end];
- }
-
- fn readUntilEof(self: *Self) []const u8 {
- const start = self.offset;
- self.offset = self.slice.len;
- return self.slice[start..];
- }
-
- fn peekPrefix(self: Self, prefix: []const u8) bool {
- if (self.offset + prefix.len > self.slice.len)
- return false;
- return std.mem.eql(u8, self.slice[self.offset..][0..prefix.len], prefix);
- }
-};
-
/// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
fn isSchemeChar(c: u8) bool {
return switch (c) {
@@ -533,19 +509,6 @@ fn isSchemeChar(c: u8) bool {
};
}
-/// reserved = gen-delims / sub-delims
-fn isReserved(c: u8) bool {
- return isGenLimit(c) or isSubLimit(c);
-}
-
-/// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@"
-fn isGenLimit(c: u8) bool {
- return switch (c) {
- ':', ',', '?', '#', '[', ']', '@' => true,
- else => false,
- };
-}
-
/// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
/// / "*" / "+" / "," / ";" / "="
fn isSubLimit(c: u8) bool {
@@ -585,26 +548,8 @@ fn isQueryChar(c: u8) bool {
const isFragmentChar = isQueryChar;
-fn isAuthoritySeparator(c: u8) bool {
- return switch (c) {
- '/', '?', '#' => true,
- else => false,
- };
-}
-
-fn isPathSeparator(c: u8) bool {
- return switch (c) {
- '?', '#' => true,
- else => false,
- };
-}
-
-fn isQuerySeparator(c: u8) bool {
- return switch (c) {
- '#' => true,
- else => false,
- };
-}
+const authority_sep: [3]u8 = .{ '/', '?', '#' };
+const path_sep: [2]u8 = .{ '?', '#' };
test "basic" {
const parsed = try parse("https://ziglang.org/download");
diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig
@@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{
};
pub const close_notify_alert = [_]u8{
- @intFromEnum(AlertLevel.warning),
- @intFromEnum(AlertDescription.close_notify),
+ @intFromEnum(Alert.Level.warning),
+ @intFromEnum(Alert.Description.close_notify),
};
pub const ProtocolVersion = enum(u16) {
@@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) {
_,
};
-pub const AlertLevel = enum(u8) {
- warning = 1,
- fatal = 2,
- _,
-};
+pub const Alert = struct {
+ level: Level,
+ description: Description,
-pub const AlertDescription = enum(u8) {
- pub const Error = error{
- TlsAlertUnexpectedMessage,
- TlsAlertBadRecordMac,
- TlsAlertRecordOverflow,
- TlsAlertHandshakeFailure,
- TlsAlertBadCertificate,
- TlsAlertUnsupportedCertificate,
- TlsAlertCertificateRevoked,
- TlsAlertCertificateExpired,
- TlsAlertCertificateUnknown,
- TlsAlertIllegalParameter,
- TlsAlertUnknownCa,
- TlsAlertAccessDenied,
- TlsAlertDecodeError,
- TlsAlertDecryptError,
- TlsAlertProtocolVersion,
- TlsAlertInsufficientSecurity,
- TlsAlertInternalError,
- TlsAlertInappropriateFallback,
- TlsAlertMissingExtension,
- TlsAlertUnsupportedExtension,
- TlsAlertUnrecognizedName,
- TlsAlertBadCertificateStatusResponse,
- TlsAlertUnknownPskIdentity,
- TlsAlertCertificateRequired,
- TlsAlertNoApplicationProtocol,
- TlsAlertUnknown,
+ pub const Level = enum(u8) {
+ warning = 1,
+ fatal = 2,
+ _,
};
- close_notify = 0,
- unexpected_message = 10,
- bad_record_mac = 20,
- record_overflow = 22,
- handshake_failure = 40,
- bad_certificate = 42,
- unsupported_certificate = 43,
- certificate_revoked = 44,
- certificate_expired = 45,
- certificate_unknown = 46,
- illegal_parameter = 47,
- unknown_ca = 48,
- access_denied = 49,
- decode_error = 50,
- decrypt_error = 51,
- protocol_version = 70,
- insufficient_security = 71,
- internal_error = 80,
- inappropriate_fallback = 86,
- user_canceled = 90,
- missing_extension = 109,
- unsupported_extension = 110,
- unrecognized_name = 112,
- bad_certificate_status_response = 113,
- unknown_psk_identity = 115,
- certificate_required = 116,
- no_application_protocol = 120,
- _,
+ pub const Description = enum(u8) {
+ pub const Error = error{
+ TlsAlertUnexpectedMessage,
+ TlsAlertBadRecordMac,
+ TlsAlertRecordOverflow,
+ TlsAlertHandshakeFailure,
+ TlsAlertBadCertificate,
+ TlsAlertUnsupportedCertificate,
+ TlsAlertCertificateRevoked,
+ TlsAlertCertificateExpired,
+ TlsAlertCertificateUnknown,
+ TlsAlertIllegalParameter,
+ TlsAlertUnknownCa,
+ TlsAlertAccessDenied,
+ TlsAlertDecodeError,
+ TlsAlertDecryptError,
+ TlsAlertProtocolVersion,
+ TlsAlertInsufficientSecurity,
+ TlsAlertInternalError,
+ TlsAlertInappropriateFallback,
+ TlsAlertMissingExtension,
+ TlsAlertUnsupportedExtension,
+ TlsAlertUnrecognizedName,
+ TlsAlertBadCertificateStatusResponse,
+ TlsAlertUnknownPskIdentity,
+ TlsAlertCertificateRequired,
+ TlsAlertNoApplicationProtocol,
+ TlsAlertUnknown,
+ };
- pub fn toError(alert: AlertDescription) Error!void {
- switch (alert) {
- .close_notify => {}, // not an error
- .unexpected_message => return error.TlsAlertUnexpectedMessage,
- .bad_record_mac => return error.TlsAlertBadRecordMac,
- .record_overflow => return error.TlsAlertRecordOverflow,
- .handshake_failure => return error.TlsAlertHandshakeFailure,
- .bad_certificate => return error.TlsAlertBadCertificate,
- .unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
- .certificate_revoked => return error.TlsAlertCertificateRevoked,
- .certificate_expired => return error.TlsAlertCertificateExpired,
- .certificate_unknown => return error.TlsAlertCertificateUnknown,
- .illegal_parameter => return error.TlsAlertIllegalParameter,
- .unknown_ca => return error.TlsAlertUnknownCa,
- .access_denied => return error.TlsAlertAccessDenied,
- .decode_error => return error.TlsAlertDecodeError,
- .decrypt_error => return error.TlsAlertDecryptError,
- .protocol_version => return error.TlsAlertProtocolVersion,
- .insufficient_security => return error.TlsAlertInsufficientSecurity,
- .internal_error => return error.TlsAlertInternalError,
- .inappropriate_fallback => return error.TlsAlertInappropriateFallback,
- .user_canceled => {}, // not an error
- .missing_extension => return error.TlsAlertMissingExtension,
- .unsupported_extension => return error.TlsAlertUnsupportedExtension,
- .unrecognized_name => return error.TlsAlertUnrecognizedName,
- .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
- .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
- .certificate_required => return error.TlsAlertCertificateRequired,
- .no_application_protocol => return error.TlsAlertNoApplicationProtocol,
- _ => return error.TlsAlertUnknown,
+ close_notify = 0,
+ unexpected_message = 10,
+ bad_record_mac = 20,
+ record_overflow = 22,
+ handshake_failure = 40,
+ bad_certificate = 42,
+ unsupported_certificate = 43,
+ certificate_revoked = 44,
+ certificate_expired = 45,
+ certificate_unknown = 46,
+ illegal_parameter = 47,
+ unknown_ca = 48,
+ access_denied = 49,
+ decode_error = 50,
+ decrypt_error = 51,
+ protocol_version = 70,
+ insufficient_security = 71,
+ internal_error = 80,
+ inappropriate_fallback = 86,
+ user_canceled = 90,
+ missing_extension = 109,
+ unsupported_extension = 110,
+ unrecognized_name = 112,
+ bad_certificate_status_response = 113,
+ unknown_psk_identity = 115,
+ certificate_required = 116,
+ no_application_protocol = 120,
+ _,
+
+ pub fn toError(description: Description) Error!void {
+ switch (description) {
+ .close_notify => {}, // not an error
+ .unexpected_message => return error.TlsAlertUnexpectedMessage,
+ .bad_record_mac => return error.TlsAlertBadRecordMac,
+ .record_overflow => return error.TlsAlertRecordOverflow,
+ .handshake_failure => return error.TlsAlertHandshakeFailure,
+ .bad_certificate => return error.TlsAlertBadCertificate,
+ .unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
+ .certificate_revoked => return error.TlsAlertCertificateRevoked,
+ .certificate_expired => return error.TlsAlertCertificateExpired,
+ .certificate_unknown => return error.TlsAlertCertificateUnknown,
+ .illegal_parameter => return error.TlsAlertIllegalParameter,
+ .unknown_ca => return error.TlsAlertUnknownCa,
+ .access_denied => return error.TlsAlertAccessDenied,
+ .decode_error => return error.TlsAlertDecodeError,
+ .decrypt_error => return error.TlsAlertDecryptError,
+ .protocol_version => return error.TlsAlertProtocolVersion,
+ .insufficient_security => return error.TlsAlertInsufficientSecurity,
+ .internal_error => return error.TlsAlertInternalError,
+ .inappropriate_fallback => return error.TlsAlertInappropriateFallback,
+ .user_canceled => {}, // not an error
+ .missing_extension => return error.TlsAlertMissingExtension,
+ .unsupported_extension => return error.TlsAlertUnsupportedExtension,
+ .unrecognized_name => return error.TlsAlertUnrecognizedName,
+ .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
+ .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
+ .certificate_required => return error.TlsAlertCertificateRequired,
+ .no_application_protocol => return error.TlsAlertNoApplicationProtocol,
+ _ => return error.TlsAlertUnknown,
+ }
}
- }
+ };
};
pub const SignatureScheme = enum(u16) {
@@ -650,7 +655,7 @@ pub const Decoder = struct {
}
/// Use this function to increase `their_end`.
- pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+ pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
@@ -658,14 +663,16 @@ pub const Decoder = struct {
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
- const actual_amt = try stream.readAtLeast(dest, request_amt);
- if (actual_amt < request_amt) return error.TlsConnectionTruncated;
- d.cap += actual_amt;
+ stream.readSlice(dest[0..request_amt]) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ d.cap += request_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
- pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+ pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
@@ -1,11 +1,15 @@
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
const std = @import("../../std.zig");
const tls = std.crypto.tls;
const Client = @This();
-const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
const Certificate = std.crypto.Certificate;
+const Reader = std.Io.Reader;
+const Writer = std.Io.Writer;
const max_ciphertext_len = tls.max_ciphertext_len;
const hmacExpandLabel = tls.hmacExpandLabel;
@@ -13,44 +17,60 @@ const hkdfExpandLabel = tls.hkdfExpandLabel;
const int = tls.int;
const array = tls.array;
+/// The encrypted stream from the server to the client. Bytes are pulled from
+/// here via `reader`.
+///
+/// The buffer is asserted to have capacity at least `min_buffer_len`.
+input: *Reader,
+/// Decrypted stream from the server to the client.
+reader: Reader,
+
+/// The encrypted stream from the client to the server. Bytes are pushed here
+/// via `writer`.
+///
+/// The buffer is asserted to have capacity at least `min_buffer_len`.
+output: *Writer,
+/// The plaintext stream from the client to the server.
+writer: Writer,
+
+/// Populated when `error.TlsAlert` is returned.
+alert: ?tls.Alert = null,
+read_err: ?ReadError = null,
tls_version: tls.ProtocolVersion,
read_seq: u64,
write_seq: u64,
-/// The starting index of cleartext bytes inside `partially_read_buffer`.
-partial_cleartext_idx: u15,
-/// The ending index of cleartext bytes inside `partially_read_buffer` as well
-/// as the starting index of ciphertext bytes.
-partial_ciphertext_idx: u15,
-/// The ending index of ciphertext bytes inside `partially_read_buffer`.
-partial_ciphertext_end: u15,
/// When this is true, the stream may still not be at the end because there
-/// may be data in `partially_read_buffer`.
+/// may be data in the input buffer.
received_close_notify: bool,
-/// By default, reaching the end-of-stream when reading from the server will
-/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
-/// message has been received. By setting this flag to `true`, instead, the
-/// end-of-stream will be forwarded to the application layer above TLS.
-/// This makes the application vulnerable to truncation attacks unless the
-/// application layer itself verifies that the amount of data received equals
-/// the amount of data expected, such as HTTP with the Content-Length header.
allow_truncation_attacks: bool,
application_cipher: tls.ApplicationCipher,
-/// The size is enough to contain exactly one TLSCiphertext record.
-/// This buffer is segmented into four parts:
-/// 0. unused
-/// 1. cleartext
-/// 2. ciphertext
-/// 3. unused
-/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
-/// `partial_ciphertext_end` describe the span of the segments.
-partially_read_buffer: [tls.max_ciphertext_record_len]u8,
-/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
-/// programs with access to that file to decrypt all traffic over this connection.
-ssl_key_log: ?struct {
+
+/// If non-null, ssl secrets are logged to a stream. Creating such a log file
+/// allows other programs with access to that file to decrypt all traffic over
+/// this connection.
+ssl_key_log: ?*SslKeyLog,
+
+pub const ReadError = error{
+ /// The alert description will be stored in `alert`.
+ TlsAlert,
+ TlsBadLength,
+ TlsBadRecordMac,
+ TlsConnectionTruncated,
+ TlsDecodeError,
+ TlsRecordOverflow,
+ TlsUnexpectedMessage,
+ TlsIllegalParameter,
+ TlsSequenceOverflow,
+ /// The buffer provided to the read function was not at least
+ /// `min_buffer_len`.
+ OutputBufferUndersize,
+};
+
+pub const SslKeyLog = struct {
client_key_seq: u64,
server_key_seq: u64,
client_random: [32]u8,
- file: std.fs.File,
+ writer: *Writer,
fn clientCounter(key_log: *@This()) u64 {
defer key_log.client_key_seq += 1;
@@ -61,51 +81,12 @@ ssl_key_log: ?struct {
defer key_log.server_key_seq += 1;
return key_log.server_key_seq;
}
-},
-
-/// This is an example of the type that is needed by the read and write
-/// functions. It can have any fields but it must at least have these
-/// functions.
-///
-/// Note that `std.net.Stream` conforms to this interface.
-///
-/// This declaration serves as documentation only.
-pub const StreamInterface = struct {
- /// Can be any error set.
- pub const ReadError = error{};
-
- /// Returns the number of bytes read. The number read may be less than the
- /// buffer space provided. End-of-stream is indicated by a return value of 0.
- ///
- /// The `iovecs` parameter is mutable because so that function may to
- /// mutate the fields in order to handle partial reads from the underlying
- /// stream layer.
- pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize {
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
-
- /// Can be any error set.
- pub const WriteError = error{};
-
- /// Returns the number of bytes read, which may be less than the buffer
- /// space provided. A short read does not indicate end-of-stream.
- pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize {
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
-
- /// Returns the number of bytes read, which may be less than the buffer
- /// space provided, indicating end-of-stream.
- /// The `iovecs` parameter is mutable in case this function needs to mutate
- /// the fields in order to handle partial writes from the underlying layer.
- pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize {
- // This can be implemented in terms of writev, or specialized if desired.
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
};
+/// The `Reader` supplied to `init` requires a buffer capacity
+/// at least this amount.
+pub const min_buffer_len = tls.max_ciphertext_record_len;
+
pub const Options = struct {
/// How to perform host verification of server certificates.
host: union(enum) {
@@ -127,64 +108,85 @@ pub const Options = struct {
/// Verify that the server certificate is authorized by a given ca bundle.
bundle: Certificate.Bundle,
},
- /// If non-null, ssl secrets are logged to this file. Creating such a log file allows
+ /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows
/// other programs with access to that file to decrypt all traffic over this connection.
- ssl_key_log_file: ?std.fs.File = null,
+ ///
+ /// Only the `writer` field is observed during the handshake (`init`).
+ /// After that, the other fields are populated.
+ ssl_key_log: ?*SslKeyLog = null,
+ /// By default, reaching the end-of-stream when reading from the server will
+ /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
+ /// message has been received. By setting this flag to `true`, instead, the
+ /// end-of-stream will be forwarded to the application layer above TLS.
+ ///
+ /// This makes the application vulnerable to truncation attacks unless the
+ /// application layer itself verifies that the amount of data received equals
+ /// the amount of data expected, such as HTTP with the Content-Length header.
+ allow_truncation_attacks: bool = false,
+ write_buffer: []u8,
+ read_buffer: []u8,
+ /// Populated when `error.TlsAlert` is returned from `init`.
+ alert: ?*tls.Alert = null,
};
-pub fn InitError(comptime Stream: type) type {
- return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
- InsufficientEntropy,
- DiskQuota,
- LockViolation,
- NotOpenForWriting,
- TlsUnexpectedMessage,
- TlsIllegalParameter,
- TlsDecryptFailure,
- TlsRecordOverflow,
- TlsBadRecordMac,
- CertificateFieldHasInvalidLength,
- CertificateHostMismatch,
- CertificatePublicKeyInvalid,
- CertificateExpired,
- CertificateFieldHasWrongDataType,
- CertificateIssuerMismatch,
- CertificateNotYetValid,
- CertificateSignatureAlgorithmMismatch,
- CertificateSignatureAlgorithmUnsupported,
- CertificateSignatureInvalid,
- CertificateSignatureInvalidLength,
- CertificateSignatureNamedCurveUnsupported,
- CertificateSignatureUnsupportedBitCount,
- TlsCertificateNotVerified,
- TlsBadSignatureScheme,
- TlsBadRsaSignatureBitCount,
- InvalidEncoding,
- IdentityElement,
- SignatureVerificationFailed,
- TlsDecryptError,
- TlsConnectionTruncated,
- TlsDecodeError,
- UnsupportedCertificateVersion,
- CertificateTimeInvalid,
- CertificateHasUnrecognizedObjectId,
- CertificateHasInvalidBitString,
- MessageTooLong,
- NegativeIntoUnsigned,
- TargetTooSmall,
- BufferTooSmall,
- InvalidSignature,
- NotSquare,
- NonCanonical,
- WeakPublicKey,
- };
-}
+const InitError = error{
+ WriteFailed,
+ ReadFailed,
+ InsufficientEntropy,
+ DiskQuota,
+ LockViolation,
+ NotOpenForWriting,
+ /// The alert description will be stored in `alert`.
+ TlsAlert,
+ TlsUnexpectedMessage,
+ TlsIllegalParameter,
+ TlsDecryptFailure,
+ TlsRecordOverflow,
+ TlsBadRecordMac,
+ CertificateFieldHasInvalidLength,
+ CertificateHostMismatch,
+ CertificatePublicKeyInvalid,
+ CertificateExpired,
+ CertificateFieldHasWrongDataType,
+ CertificateIssuerMismatch,
+ CertificateNotYetValid,
+ CertificateSignatureAlgorithmMismatch,
+ CertificateSignatureAlgorithmUnsupported,
+ CertificateSignatureInvalid,
+ CertificateSignatureInvalidLength,
+ CertificateSignatureNamedCurveUnsupported,
+ CertificateSignatureUnsupportedBitCount,
+ TlsCertificateNotVerified,
+ TlsBadSignatureScheme,
+ TlsBadRsaSignatureBitCount,
+ InvalidEncoding,
+ IdentityElement,
+ SignatureVerificationFailed,
+ TlsDecryptError,
+ TlsConnectionTruncated,
+ TlsDecodeError,
+ UnsupportedCertificateVersion,
+ CertificateTimeInvalid,
+ CertificateHasUnrecognizedObjectId,
+ CertificateHasInvalidBitString,
+ MessageTooLong,
+ NegativeIntoUnsigned,
+ TargetTooSmall,
+ BufferTooSmall,
+ InvalidSignature,
+ NotSquare,
+ NonCanonical,
+ WeakPublicKey,
+};
-/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which
-/// must conform to `StreamInterface`.
+/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session.
///
/// `host` is only borrowed during this function call.
-pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client {
+///
+/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
+pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
+ assert(input.buffer.len >= min_buffer_len);
+ assert(output.buffer.len >= min_buffer_len);
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@@ -276,11 +278,9 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
};
{
- var iovecs = [_]std.posix.iovec_const{
- .{ .base = cleartext_header.ptr, .len = cleartext_header.len },
- .{ .base = host.ptr, .len = host.len },
- };
- try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
+ var iovecs: [2][]const u8 = .{ cleartext_header, host };
+ try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
+ try output.flush();
}
var tls_version: tls.ProtocolVersion = undefined;
@@ -329,20 +329,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
var cleartext_fragment_start: usize = 0;
var cleartext_fragment_end: usize = 0;
var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
- var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
- var d: tls.Decoder = .{ .buf = &handshake_buffer };
fragment: while (true) {
- try d.readAtLeastOurAmt(stream, tls.record_header_len);
- const record_header = d.buf[d.idx..][0..tls.record_header_len];
- const record_ct = d.decode(tls.ContentType);
- d.skip(2); // legacy_version
- const record_len = d.decode(u16);
- try d.readAtLeast(stream, record_len);
- var record_decoder = try d.sub(record_len);
+ // Ensure the input buffer pointer is stable in this scope.
+ input.rebase(tls.max_ciphertext_record_len) catch |err| switch (err) {
+ error.EndOfStream => {}, // We have assurance the remainder of stream can be buffered.
+ };
+ const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked
+ input.toss(2); // legacy_version
+ const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked
+ if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow;
+ const record_buffer = input.take(record_len) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer);
var ctd, const ct = content: switch (cipher_state) {
.cleartext => .{ record_decoder, record_ct },
.handshake => {
- std.debug.assert(tls_version == .tls_1_3);
+ assert(tls_version == .tls_1_3);
if (record_ct != .application_data) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -374,7 +382,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
},
.application => {
- std.debug.assert(tls_version == .tls_1_2);
+ assert(tls_version == .tls_1_2);
if (record_ct != .handshake) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -412,14 +420,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
switch (ct) {
.alert => {
ctd.ensure(2) catch continue :fragment;
- const level = ctd.decode(tls.AlertLevel);
- const desc = ctd.decode(tls.AlertDescription);
- _ = level;
-
- // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
- try desc.toError();
- // TODO: handle server-side closures
- return error.TlsUnexpectedMessage;
+ if (options.alert) |a| a.* = .{
+ .level = ctd.decode(tls.Alert.Level),
+ .description = ctd.decode(tls.Alert.Description),
+ };
+ return error.TlsAlert;
},
.change_cipher_spec => {
ctd.ensure(1) catch continue :fragment;
@@ -533,7 +538,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.client_random = &client_hello_rand,
}, .{
.SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret,
@@ -707,7 +712,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
&client_hello_rand,
&server_hello_rand,
}, 48);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.client_random = &client_hello_rand,
}, .{
.CLIENT_RANDOM = &master_secret,
@@ -755,11 +760,13 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
nonce,
pv.app_cipher.client_write_key,
);
- const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg;
- var all_msgs_vec = [_]std.posix.iovec_const{
- .{ .base = &all_msgs, .len = all_msgs.len },
+ var all_msgs_vec: [3][]const u8 = .{
+ &client_key_exchange_msg,
+ &client_change_cipher_spec_msg,
+ &client_verify_msg,
};
- try stream.writevAll(&all_msgs_vec);
+ try output.writeVecAll(&all_msgs_vec);
+ try output.flush();
},
}
write_seq += 1;
@@ -820,15 +827,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const nonce = pv.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key);
- const all_msgs = client_change_cipher_spec_msg ++ finished_msg;
- var all_msgs_vec = [_]std.posix.iovec_const{
- .{ .base = &all_msgs, .len = all_msgs.len },
+ var all_msgs_vec: [2][]const u8 = .{
+ &client_change_cipher_spec_msg,
+ &finished_msg,
};
- try stream.writevAll(&all_msgs_vec);
+ try output.writeVecAll(&all_msgs_vec);
+ try output.flush();
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.counter = key_seq,
.client_random = &client_hello_rand,
}, .{
@@ -855,8 +863,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
else => unreachable,
},
};
- const leftover = d.rest();
- var client: Client = .{
+ if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
+ .client_key_seq = key_seq,
+ .server_key_seq = key_seq,
+ .client_random = client_hello_rand,
+ .writer = ssl_key_log.writer,
+ };
+ return .{
+ .input = input,
+ .reader = .{
+ .buffer = options.read_buffer,
+ .vtable = &.{ .stream = stream },
+ .seek = 0,
+ .end = 0,
+ },
+ .output = output,
+ .writer = .{
+ .buffer = options.write_buffer,
+ .vtable = &.{
+ .drain = drain,
+ .flush = flush,
+ },
+ },
.tls_version = tls_version,
.read_seq = switch (tls_version) {
.tls_1_3 => 0,
@@ -868,22 +896,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
.tls_1_2 => write_seq,
else => unreachable,
},
- .partial_cleartext_idx = 0,
- .partial_ciphertext_idx = 0,
- .partial_ciphertext_end = @intCast(leftover.len),
.received_close_notify = false,
- .allow_truncation_attacks = false,
+ .allow_truncation_attacks = options.allow_truncation_attacks,
.application_cipher = app_cipher,
- .partially_read_buffer = undefined,
- .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
- .client_key_seq = key_seq,
- .server_key_seq = key_seq,
- .client_random = client_hello_rand,
- .file = key_log_file,
- } else null,
+ .ssl_key_log = options.ssl_key_log,
};
- @memcpy(client.partially_read_buffer[0..leftover.len], leftover);
- return client;
},
else => return error.TlsUnexpectedMessage,
}
@@ -897,94 +914,73 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
}
}
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
-pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
- return writeEnd(c, stream, bytes, false);
-}
-
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try c.write(stream, bytes[index..]);
+fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
+ const c: *Client = @alignCast(@fieldParentPtr("writer", w));
+ const output = c.output;
+ const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+ var ciphertext_end: usize = 0;
+ var total_clear: usize = 0;
+ done: {
+ {
+ const buf = w.buffered();
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+ total_clear += prepared.cleartext_len;
+ ciphertext_end += prepared.ciphertext_end;
+ if (prepared.cleartext_len < buf.len) break :done;
+ }
+ for (data[0 .. data.len - 1]) |buf| {
+ if (buf.len < min_buffer_len) break :done;
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+ total_clear += prepared.cleartext_len;
+ ciphertext_end += prepared.ciphertext_end;
+ if (prepared.cleartext_len < buf.len) break :done;
+ }
+ const buf = data[data.len - 1];
+ for (0..splat) |_| {
+ if (buf.len < min_buffer_len) break :done;
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+ total_clear += prepared.cleartext_len;
+ ciphertext_end += prepared.ciphertext_end;
+ if (prepared.cleartext_len < buf.len) break :done;
+ }
}
+ output.advance(ciphertext_end);
+ return w.consume(total_clear);
}
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// If `end` is true, then this function additionally sends a `close_notify` alert,
-/// which is necessary for the server to distinguish between a properly finished
-/// TLS session, or a truncation attack.
-pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try c.writeEnd(stream, bytes[index..], end);
- }
+fn flush(w: *Writer) Writer.Error!void {
+ const c: *Client = @alignCast(@fieldParentPtr("writer", w));
+ const output = c.output;
+ const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data);
+ output.advance(prepared.ciphertext_end);
+ w.end = 0;
}
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
-/// If `end` is true, then this function additionally sends a `close_notify` alert,
-/// which is necessary for the server to distinguish between a properly finished
-/// TLS session, or a truncation attack.
-pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
- var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
- var iovecs_buf: [6]std.posix.iovec_const = undefined;
- var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
- if (end) {
- prepared.iovec_end += prepareCiphertextRecord(
- c,
- iovecs_buf[prepared.iovec_end..],
- ciphertext_buf[prepared.ciphertext_end..],
- &tls.close_notify_alert,
- .alert,
- ).iovec_end;
- }
-
- const iovec_end = prepared.iovec_end;
- const overhead_len = prepared.overhead_len;
-
- // Ideally we would call writev exactly once here, however, we must ensure
- // that we don't return with a record partially written.
- var i: usize = 0;
- var total_amt: usize = 0;
- while (true) {
- var amt = try stream.writev(iovecs_buf[i..iovec_end]);
- while (amt >= iovecs_buf[i].len) {
- const encrypted_amt = iovecs_buf[i].len;
- total_amt += encrypted_amt - overhead_len;
- amt -= encrypted_amt;
- i += 1;
- // Rely on the property that iovecs delineate records, meaning that
- // if amt equals zero here, we have fortunately found ourselves
- // with a short read that aligns at the record boundary.
- if (i >= iovec_end) return total_amt;
- // We also cannot return on a vector boundary if the final close_notify is
- // not sent; otherwise the caller would not know to retry the call.
- if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
- }
- iovecs_buf[i].base += amt;
- iovecs_buf[i].len -= amt;
- }
+/// Sends a `close_notify` alert, which is necessary for the server to
+/// distinguish between a properly finished TLS session, or a truncation
+/// attack.
+pub fn end(c: *Client) Writer.Error!void {
+ try flush(&c.writer);
+ const output = c.output;
+ const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
+ output.advance(prepared.ciphertext_end);
}
fn prepareCiphertextRecord(
c: *Client,
- iovecs: []std.posix.iovec_const,
ciphertext_buf: []u8,
bytes: []const u8,
inner_content_type: tls.ContentType,
) struct {
- iovec_end: usize,
ciphertext_end: usize,
- /// How many bytes are taken up by overhead per record.
- overhead_len: usize,
+ cleartext_len: usize,
} {
// Due to the trailing inner content type byte in the ciphertext, we need
// an additional buffer for storing the cleartext into before encrypting.
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
var ciphertext_end: usize = 0;
- var iovec_end: usize = 0;
var bytes_i: usize = 0;
switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
@@ -992,18 +988,15 @@ fn prepareCiphertextRecord(
const pv = &p.tls_1_3;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
- const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const encrypted_content_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
- ciphertext_buf.len -|
- (close_notify_alert_reserved + overhead_len + ciphertext_end),
+ ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (encrypted_content_len == 0) return .{
- .iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
- .overhead_len = overhead_len,
+ .cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]);
@@ -1012,7 +1005,6 @@ fn prepareCiphertextRecord(
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
- const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
@@ -1030,38 +1022,27 @@ fn prepareCiphertextRecord(
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
c.write_seq += 1; // TODO send key_update on overflow
-
- const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs[iovec_end] = .{
- .base = record.ptr,
- .len = record.len,
- };
- iovec_end += 1;
}
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length;
- const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const message_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
- ciphertext_buf.len -|
- (close_notify_alert_reserved + overhead_len + ciphertext_end),
+ ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (message_len == 0) return .{
- .iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
- .overhead_len = overhead_len,
+ .cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]);
bytes_i += message_len;
const cleartext = cleartext_buf[0..message_len];
- const record_start = ciphertext_end;
const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ciphertext_end += tls.record_header_len;
record_header.* = .{@intFromEnum(inner_content_type)} ++
@@ -1083,13 +1064,6 @@ fn prepareCiphertextRecord(
ciphertext_end += P.mac_length;
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key);
c.write_seq += 1; // TODO send key_update on overflow
-
- const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs[iovec_end] = .{
- .base = record.ptr,
- .len = record.len,
- };
- iovec_end += 1;
}
},
else => unreachable,
@@ -1098,421 +1072,194 @@ fn prepareCiphertextRecord(
}
pub fn eof(c: Client) bool {
- return c.received_close_notify and
- c.partial_cleartext_idx >= c.partial_ciphertext_idx and
- c.partial_ciphertext_idx >= c.partial_ciphertext_end;
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read, calling the underlying read function the
-/// minimal number of times until the buffer has at least `len` bytes filled.
-/// If the number read is less than `len` it means the stream reached the end.
-/// Reaching the end of the stream is not an error condition.
-pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
- var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }};
- return readvAtLeast(c, stream, &iovecs, len);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
- return readAtLeast(c, stream, buffer, 1);
+ return c.received_close_notify;
}
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read. If the number read is smaller than
-/// `buffer.len`, it means the stream reached the end. Reaching the end of the
-/// stream is not an error condition.
-pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
- return readAtLeast(c, stream, buffer, buffer.len);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read. If the number read is less than the space
-/// provided it means the stream reached the end. Reaching the end of the
-/// stream is not an error condition.
-/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
-/// order to handle partial reads from the underlying stream layer.
-pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
- return readvAtLeast(c, stream, iovecs, 1);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read, calling the underlying read function the
-/// minimal number of times until the iovecs have at least `len` bytes filled.
-/// If the number read is less than `len` it means the stream reached the end.
-/// Reaching the end of the stream is not an error condition.
-/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
-/// order to handle partial reads from the underlying stream layer.
-pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize {
- if (c.eof()) return 0;
-
- var off_i: usize = 0;
- var vec_i: usize = 0;
- while (true) {
- var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
- off_i += amt;
- if (c.eof() or off_i >= len) return off_i;
- while (amt >= iovecs[vec_i].len) {
- amt -= iovecs[vec_i].len;
- vec_i += 1;
- }
- iovecs[vec_i].base += amt;
- iovecs[vec_i].len -= amt;
- }
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns number of bytes that have been read, populated inside `iovecs`. A
-/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
-/// for the end of stream. The `eof()` may be true after any call to
-/// `read`, including when greater than zero bytes are returned, and this
-/// function asserts that `eof()` is `false`.
-/// See `readv` for a higher level function that has the same, familiar API as
-/// other read functions, such as `std.fs.File.read`.
-pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize {
- var vp: VecPut = .{ .iovecs = iovecs };
-
- // Give away the buffered cleartext we have, if any.
- const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
- if (partial_cleartext.len > 0) {
- const amt: u15 = @intCast(vp.put(partial_cleartext));
- c.partial_cleartext_idx += amt;
-
- if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
- c.partial_ciphertext_end == c.partial_ciphertext_idx)
- {
- // The buffer is now empty.
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = 0;
- }
-
- if (c.received_close_notify) {
- c.partial_ciphertext_end = 0;
- assert(vp.total == amt);
- return amt;
- } else if (amt > 0) {
- // We don't need more data, so don't call read.
- assert(vp.total == amt);
- return amt;
- }
+fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
+ const c: *Client = @alignCast(@fieldParentPtr("reader", r));
+ if (c.eof()) return error.EndOfStream;
+ const input = c.input;
+ // If at least one full encrypted record is not buffered, read once.
+ const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+ error.EndOfStream => {
+ // This is either a truncation attack, a bug in the server, or an
+ // intentional omission of the close_notify message due to truncation
+ // detection handled above the TLS layer.
+ if (c.allow_truncation_attacks) {
+ c.received_close_notify = true;
+ return error.EndOfStream;
+ } else {
+ return failRead(c, error.TlsConnectionTruncated);
+ }
+ },
+ error.ReadFailed => return error.ReadFailed,
+ };
+ const ct: tls.ContentType = @enumFromInt(record_header[0]);
+ const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big);
+ _ = legacy_version;
+ const record_len = mem.readInt(u16, record_header[3..][0..2], .big);
+ if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
+ const record_end = 5 + record_len;
+ if (record_end > input.buffered().len) {
+ input.fillMore() catch |err| switch (err) {
+ error.EndOfStream => return failRead(c, error.TlsConnectionTruncated),
+ error.ReadFailed => return error.ReadFailed,
+ };
+ if (record_end > input.buffered().len) return 0;
}
- assert(!c.received_close_notify);
-
- // Ideally, this buffer would never be used. It is needed when `iovecs` are
- // too small to fit the cleartext, which may be as large as `max_ciphertext_len`.
var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
- // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`.
- var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
- // How many bytes left in the user's buffer.
- const free_size = vp.freeSize();
- // The amount of the user's buffer that we need to repurpose for storing
- // ciphertext. The end of the buffer will be used for such purposes.
- const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len;
- // The amount of the user's buffer that will be used to give cleartext. The
- // beginning of the buffer will be used for such purposes.
- const cleartext_buf_len = free_size - ciphertext_buf_len;
-
- // Recoup `partially_read_buffer` space. This is necessary because it is assumed
- // below that `frag0` is big enough to hold at least one record.
- limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx);
- c.partial_ciphertext_end -= c.partial_ciphertext_idx;
- c.partial_ciphertext_idx = 0;
- c.partial_cleartext_idx = 0;
- const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
-
- var ask_iovecs_buf: [2]std.posix.iovec = .{
- .{
- .base = first_iov.ptr,
- .len = first_iov.len,
- },
- .{
- .base = &in_stack_buffer,
- .len = in_stack_buffer.len,
+ const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
+ inline else => |*p| switch (c.tls_version) {
+ .tls_1_3 => {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const ad = input.take(tls.record_header_len) catch unreachable; // already peeked
+ const ciphertext_len = record_len - P.AEAD.tag_length;
+ const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked
+ const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked
+ const nonce = nonce: {
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
+ break :nonce @as(V, pv.server_iv) ^ operand;
+ };
+ const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
+ return failRead(c, error.TlsBadRecordMac);
+ const msg = mem.trimRight(u8, cleartext, "\x00");
+ break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+ },
+ .tls_1_2 => {
+ const pv = &p.tls_1_2;
+ const P = @TypeOf(p.*);
+ const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
+ const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
+ const ad = std.mem.toBytes(big(c.read_seq)) ++
+ ad_header[0 .. 1 + 2] ++
+ std.mem.toBytes(big(message_len));
+ const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
+ const masked_read_seq = c.read_seq &
+ comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
+ const nonce: [P.AEAD.nonce_length]u8 = nonce: {
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
+ break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
+ };
+ const ciphertext = input.take(message_len) catch unreachable; // already peeked
+ const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
+ const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
+ return failRead(c, error.TlsBadRecordMac);
+ break :cleartext .{ cleartext, ct };
+ },
+ else => unreachable,
},
};
-
- // Cleartext capacity of output buffer, in records. Minimum one full record.
- const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
- const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
- const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end;
- const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
- const actual_read_len = try stream.readv(ask_iovecs);
- if (actual_read_len == 0) {
- // This is either a truncation attack, a bug in the server, or an
- // intentional omission of the close_notify message due to truncation
- // detection handled above the TLS layer.
- if (c.allow_truncation_attacks) {
- c.received_close_notify = true;
- } else {
- return error.TlsConnectionTruncated;
- }
- }
-
- // There might be more bytes inside `in_stack_buffer` that need to be processed,
- // but at least frag0 will have one complete ciphertext record.
- const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
- const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
- var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
- // We need to decipher frag0 and frag1 but there may be a ciphertext record
- // straddling the boundary. We can handle this with two memcpy() calls to
- // assemble the straddling record in between handling the two sides.
- var frag = frag0;
- var in: usize = 0;
- while (true) {
- if (in == frag.len) {
- // Perfect split.
- if (frag.ptr == frag1.ptr) {
- c.partial_ciphertext_end = c.partial_ciphertext_idx;
- return vp.total;
- }
- frag = frag1;
- in = 0;
- continue;
- }
-
- if (in + tls.record_header_len > frag.len) {
- if (frag.ptr == frag1.ptr)
- return finishRead(c, frag, in, vp.total);
-
- const first = frag[in..];
-
- if (frag1.len < tls.record_header_len)
- return finishRead2(c, first, frag1, vp.total);
-
- // A record straddles the two fragments. Copy into the now-empty first fragment.
- const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
- const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
- const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
- if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
-
- const full_record_len = record_len + tls.record_header_len;
- const second_len = full_record_len - first.len;
- if (frag1.len < second_len)
- return finishRead2(c, first, frag1, vp.total);
-
- limitedOverlapCopy(frag, in);
- @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
- frag = frag[0..full_record_len];
- frag1 = frag1[second_len..];
- in = 0;
- continue;
- }
- const ct: tls.ContentType = @enumFromInt(frag[in]);
- in += 1;
- const legacy_version = mem.readInt(u16, frag[in..][0..2], .big);
- in += 2;
- _ = legacy_version;
- const record_len = mem.readInt(u16, frag[in..][0..2], .big);
- if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
- in += 2;
- const end = in + record_len;
- if (end > frag.len) {
- // We need the record header on the next iteration of the loop.
- in -= tls.record_header_len;
-
- if (frag.ptr == frag1.ptr)
- return finishRead(c, frag, in, vp.total);
-
- // A record straddles the two fragments. Copy into the now-empty first fragment.
- const first = frag[in..];
- const full_record_len = record_len + tls.record_header_len;
- const second_len = full_record_len - first.len;
- if (frag1.len < second_len)
- return finishRead2(c, first, frag1, vp.total);
-
- limitedOverlapCopy(frag, in);
- @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
- frag = frag[0..full_record_len];
- frag1 = frag1[second_len..];
- in = 0;
- continue;
- }
- const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
- inline else => |*p| switch (c.tls_version) {
- .tls_1_3 => {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len];
- const ciphertext_len = record_len - P.AEAD.tag_length;
- const ciphertext = frag[in..][0..ciphertext_len];
- in += ciphertext_len;
- const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
- const nonce = nonce: {
- const V = @Vector(P.AEAD.nonce_length, u8);
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
- break :nonce @as(V, pv.server_iv) ^ operand;
- };
- const out_buf = vp.peek();
- const cleartext_buf = if (ciphertext.len <= out_buf.len)
- out_buf
- else
- &cleartext_stack_buffer;
- const cleartext = cleartext_buf[0..ciphertext.len];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
- return error.TlsBadRecordMac;
- const msg = mem.trimEnd(u8, cleartext, "\x00");
- break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+ c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
+ switch (inner_ct) {
+ .alert => {
+ if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
+ const alert: tls.Alert = .{
+ .level = @enumFromInt(cleartext[0]),
+ .description = @enumFromInt(cleartext[1]),
+ };
+ switch (alert.description) {
+ .close_notify => {
+ c.received_close_notify = true;
+ return 0;
},
- .tls_1_2 => {
- const pv = &p.tls_1_2;
- const P = @TypeOf(p.*);
- const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
- const ad = std.mem.toBytes(big(c.read_seq)) ++
- frag[in - tls.record_header_len ..][0 .. 1 + 2] ++
- std.mem.toBytes(big(message_len));
- const record_iv = frag[in..][0..P.record_iv_length].*;
- in += P.record_iv_length;
- const masked_read_seq = c.read_seq &
- comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
- const nonce: [P.AEAD.nonce_length]u8 = nonce: {
- const V = @Vector(P.AEAD.nonce_length, u8);
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
- break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
- };
- const ciphertext = frag[in..][0..message_len];
- in += message_len;
- const auth_tag = frag[in..][0..P.mac_length].*;
- in += P.mac_length;
- const out_buf = vp.peek();
- const cleartext_buf = if (message_len <= out_buf.len)
- out_buf
- else
- &cleartext_stack_buffer;
- const cleartext = cleartext_buf[0..ciphertext.len];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
- return error.TlsBadRecordMac;
- break :cleartext .{ cleartext, ct };
+ .user_canceled => {
+ // TODO: handle server-side closures
+ return failRead(c, error.TlsUnexpectedMessage);
},
- else => unreachable,
- },
- };
- c.read_seq = try std.math.add(u64, c.read_seq, 1);
- switch (inner_ct) {
- .alert => {
- if (cleartext.len != 2) return error.TlsDecodeError;
- const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
- const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
- if (desc == .close_notify) {
- c.received_close_notify = true;
- c.partial_ciphertext_end = c.partial_ciphertext_idx;
- return vp.total;
- }
- _ = level;
-
- try desc.toError();
- // TODO: handle server-side closures
- return error.TlsUnexpectedMessage;
- },
- .handshake => {
- var ct_i: usize = 0;
- while (true) {
- const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
- ct_i += 1;
- const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
- ct_i += 3;
- const next_handshake_i = ct_i + handshake_len;
- if (next_handshake_i > cleartext.len)
- return error.TlsBadLength;
- const handshake = cleartext[ct_i..next_handshake_i];
- switch (handshake_type) {
- .new_session_ticket => {
- // This client implementation ignores new session tickets.
- },
- .key_update => {
- switch (c.application_cipher) {
- inline else => |*p| {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
- if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
- .counter = key_log.serverCounter(),
- .client_random = &key_log.client_random,
- }, .{
- .SERVER_TRAFFIC_SECRET = &server_secret,
- });
- pv.server_secret = server_secret;
- pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
- pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
- },
- }
- c.read_seq = 0;
-
- switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
- .update_requested => {
- switch (c.application_cipher) {
- inline else => |*p| {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
- if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
- .counter = key_log.clientCounter(),
- .client_random = &key_log.client_random,
- }, .{
- .CLIENT_TRAFFIC_SECRET = &client_secret,
- });
- pv.client_secret = client_secret;
- pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
- pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
- },
- }
- c.write_seq = 0;
- },
- .update_not_requested => {},
- _ => return error.TlsIllegalParameter,
- }
- },
- else => {
- return error.TlsUnexpectedMessage;
- },
- }
- ct_i = next_handshake_i;
- if (ct_i >= cleartext.len) break;
- }
- },
- .application_data => {
- // Determine whether the output buffer or a stack
- // buffer was used for storing the cleartext.
- if (cleartext.ptr == &cleartext_stack_buffer) {
- // Stack buffer was used, so we must copy to the output buffer.
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // We have already run out of room in iovecs. Continue
- // appending to `partially_read_buffer`.
- @memcpy(
- c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len],
- cleartext,
- );
- c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len);
- } else {
- const amt = vp.put(cleartext);
- if (amt < cleartext.len) {
- const rest = cleartext[amt..];
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = @intCast(rest.len);
- @memcpy(c.partially_read_buffer[0..rest.len], rest);
+ else => {
+ c.alert = alert;
+ return failRead(c, error.TlsAlert);
+ },
+ }
+ },
+ .handshake => {
+ var ct_i: usize = 0;
+ while (true) {
+ const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
+ ct_i += 1;
+ const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
+ ct_i += 3;
+ const next_handshake_i = ct_i + handshake_len;
+ if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
+ const handshake = cleartext[ct_i..next_handshake_i];
+ switch (handshake_type) {
+ .new_session_ticket => {
+ // This client implementation ignores new session tickets.
+ },
+ .key_update => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
+ if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+ .counter = key_log.serverCounter(),
+ .client_random = &key_log.client_random,
+ }, .{
+ .SERVER_TRAFFIC_SECRET = &server_secret,
+ });
+ pv.server_secret = server_secret;
+ pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+ pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.read_seq = 0;
+
+ switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
+ .update_requested => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
+ if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+ .counter = key_log.clientCounter(),
+ .client_random = &key_log.client_random,
+ }, .{
+ .CLIENT_TRAFFIC_SECRET = &client_secret,
+ });
+ pv.client_secret = client_secret;
+ pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+ pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.write_seq = 0;
+ },
+ .update_not_requested => {},
+ _ => return failRead(c, error.TlsIllegalParameter),
}
- }
- } else {
- // Output buffer was used directly which means no
- // memory copying needs to occur, and we can move
- // on to the next ciphertext record.
- vp.next(cleartext.len);
+ },
+ else => return failRead(c, error.TlsUnexpectedMessage),
}
- },
- else => return error.TlsUnexpectedMessage,
- }
- in = end;
+ ct_i = next_handshake_i;
+ if (ct_i >= cleartext.len) break;
+ }
+ return 0;
+ },
+ .application_data => {
+ if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
+ try w.writeAll(cleartext);
+ return cleartext.len;
+ },
+ else => return failRead(c, error.TlsUnexpectedMessage),
}
}
-fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
- const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
- defer if (locked) key_log_file.unlock();
- key_log_file.seekFromEnd(0) catch {};
- inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.deprecatedWriter().print("{s}" ++
+fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
+ c.read_err = err;
+ return error.ReadFailed;
+}
+
+fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void {
+ inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
context.client_random,
@@ -1520,62 +1267,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
}) catch {};
}
-fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
- const saved_buf = frag[in..];
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // There is cleartext at the beginning already which we need to preserve.
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len);
- @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf);
- } else {
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = @intCast(saved_buf.len);
- @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
- }
- return out;
-}
-
-/// Note that `first` usually overlaps with `c.partially_read_buffer`.
-fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // There is cleartext at the beginning already which we need to preserve.
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
- // TODO: eliminate this call to copyForwards
- std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
- @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
- } else {
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = @intCast(first.len + frag1.len);
- // TODO: eliminate this call to copyForwards
- std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
- @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
- }
- return out;
-}
-
-fn limitedOverlapCopy(frag: []u8, in: usize) void {
- const first = frag[in..];
- if (first.len <= in) {
- // A single, non-overlapping memcpy suffices.
- @memcpy(frag[0..first.len], first);
- } else {
- // One memcpy call would overlap, so just do this instead.
- std.mem.copyForwards(u8, frag, first);
- }
-}
-
-fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
- if (index < s1.len) {
- return s1[index];
- } else {
- return s2[index - s1.len];
- }
-}
-
-const builtin = @import("builtin");
-const native_endian = builtin.cpu.arch.endian();
-
fn big(x: anytype) @TypeOf(x) {
return switch (native_endian) {
.big => x,
@@ -1836,81 +1527,6 @@ const CertificatePublicKey = struct {
}
};
-/// Abstraction for sending multiple byte buffers to a slice of iovecs.
-const VecPut = struct {
- iovecs: []const std.posix.iovec,
- idx: usize = 0,
- off: usize = 0,
- total: usize = 0,
-
- /// Returns the amount actually put which is always equal to bytes.len
- /// unless the vectors ran out of space.
- fn put(vp: *VecPut, bytes: []const u8) usize {
- if (vp.idx >= vp.iovecs.len) return 0;
- var bytes_i: usize = 0;
- while (true) {
- const v = vp.iovecs[vp.idx];
- const dest = v.base[vp.off..v.len];
- const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
- @memcpy(dest[0..src.len], src);
- bytes_i += src.len;
- vp.off += src.len;
- if (vp.off >= v.len) {
- vp.off = 0;
- vp.idx += 1;
- if (vp.idx >= vp.iovecs.len) {
- vp.total += bytes_i;
- return bytes_i;
- }
- }
- if (bytes_i >= bytes.len) {
- vp.total += bytes_i;
- return bytes_i;
- }
- }
- }
-
- /// Returns the next buffer that consecutive bytes can go into.
- fn peek(vp: VecPut) []u8 {
- if (vp.idx >= vp.iovecs.len) return &.{};
- const v = vp.iovecs[vp.idx];
- return v.base[vp.off..v.len];
- }
-
- // After writing to the result of peek(), one can call next() to
- // advance the cursor.
- fn next(vp: *VecPut, len: usize) void {
- vp.total += len;
- vp.off += len;
- if (vp.off >= vp.iovecs[vp.idx].len) {
- vp.off = 0;
- vp.idx += 1;
- }
- }
-
- fn freeSize(vp: VecPut) usize {
- if (vp.idx >= vp.iovecs.len) return 0;
- var total: usize = 0;
- total += vp.iovecs[vp.idx].len - vp.off;
- if (vp.idx + 1 >= vp.iovecs.len) return total;
- for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len;
- return total;
- }
-};
-
-/// Limit iovecs to a specific byte size.
-fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec {
- var bytes_left: usize = len;
- for (iovecs, 0..) |*iovec, vec_i| {
- if (bytes_left <= iovec.len) {
- iovec.len = bytes_left;
- return iovecs[0 .. vec_i + 1];
- }
- bytes_left -= iovec.len;
- }
- return iovecs;
-}
-
/// The priority order here is chosen based on what crypto algorithms Zig has
/// available in the standard library as well as what is faster. Following are
/// a few data points on the relative performance of these algorithms.
@@ -1954,7 +1570,3 @@ else
.AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
});
-
-test {
- _ = StreamInterface;
-}
diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig
@@ -1,548 +0,0 @@
-// FIFO of fixed size items
-// Usually used for e.g. byte buffers
-
-const std = @import("std");
-const math = std.math;
-const mem = std.mem;
-const Allocator = mem.Allocator;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-pub const LinearFifoBufferType = union(enum) {
- /// The buffer is internal to the fifo; it is of the specified size.
- Static: usize,
-
- /// The buffer is passed as a slice to the initialiser.
- Slice,
-
- /// The buffer is managed dynamically using a `mem.Allocator`.
- Dynamic,
-};
-
-pub fn LinearFifo(
- comptime T: type,
- comptime buffer_type: LinearFifoBufferType,
-) type {
- const autoalign = false;
-
- const powers_of_two = switch (buffer_type) {
- .Static => std.math.isPowerOfTwo(buffer_type.Static),
- .Slice => false, // Any size slice could be passed in
- .Dynamic => true, // This could be configurable in future
- };
-
- return struct {
- allocator: if (buffer_type == .Dynamic) Allocator else void,
- buf: if (buffer_type == .Static) [buffer_type.Static]T else []T,
- head: usize,
- count: usize,
-
- const Self = @This();
- pub const Reader = std.io.GenericReader(*Self, error{}, readFn);
- pub const Writer = std.io.GenericWriter(*Self, error{OutOfMemory}, appendWrite);
-
- // Type of Self argument for slice operations.
- // If buffer is inline (Static) then we need to ensure we haven't
- // returned a slice into a copy on the stack
- const SliceSelfArg = if (buffer_type == .Static) *Self else Self;
-
- pub const init = switch (buffer_type) {
- .Static => initStatic,
- .Slice => initSlice,
- .Dynamic => initDynamic,
- };
-
- fn initStatic() Self {
- comptime assert(buffer_type == .Static);
- return .{
- .allocator = {},
- .buf = undefined,
- .head = 0,
- .count = 0,
- };
- }
-
- fn initSlice(buf: []T) Self {
- comptime assert(buffer_type == .Slice);
- return .{
- .allocator = {},
- .buf = buf,
- .head = 0,
- .count = 0,
- };
- }
-
- fn initDynamic(allocator: Allocator) Self {
- comptime assert(buffer_type == .Dynamic);
- return .{
- .allocator = allocator,
- .buf = &.{},
- .head = 0,
- .count = 0,
- };
- }
-
- pub fn deinit(self: Self) void {
- if (buffer_type == .Dynamic) self.allocator.free(self.buf);
- }
-
- pub fn realign(self: *Self) void {
- if (self.buf.len - self.head >= self.count) {
- mem.copyForwards(T, self.buf[0..self.count], self.buf[self.head..][0..self.count]);
- self.head = 0;
- } else {
- var tmp: [4096 / 2 / @sizeOf(T)]T = undefined;
-
- while (self.head != 0) {
- const n = @min(self.head, tmp.len);
- const m = self.buf.len - n;
- @memcpy(tmp[0..n], self.buf[0..n]);
- mem.copyForwards(T, self.buf[0..m], self.buf[n..][0..m]);
- @memcpy(self.buf[m..][0..n], tmp[0..n]);
- self.head -= n;
- }
- }
- { // set unused area to undefined
- const unused = mem.sliceAsBytes(self.buf[self.count..]);
- @memset(unused, undefined);
- }
- }
-
- /// Reduce allocated capacity to `size`.
- pub fn shrink(self: *Self, size: usize) void {
- assert(size >= self.count);
- if (buffer_type == .Dynamic) {
- self.realign();
- self.buf = self.allocator.realloc(self.buf, size) catch |e| switch (e) {
- error.OutOfMemory => return, // no problem, capacity is still correct then.
- };
- }
- }
-
- /// Ensure that the buffer can fit at least `size` items
- pub fn ensureTotalCapacity(self: *Self, size: usize) !void {
- if (self.buf.len >= size) return;
- if (buffer_type == .Dynamic) {
- self.realign();
- const new_size = if (powers_of_two) math.ceilPowerOfTwo(usize, size) catch return error.OutOfMemory else size;
- self.buf = try self.allocator.realloc(self.buf, new_size);
- } else {
- return error.OutOfMemory;
- }
- }
-
- /// Makes sure at least `size` items are unused
- pub fn ensureUnusedCapacity(self: *Self, size: usize) error{OutOfMemory}!void {
- if (self.writableLength() >= size) return;
-
- return try self.ensureTotalCapacity(math.add(usize, self.count, size) catch return error.OutOfMemory);
- }
-
- /// Returns number of items currently in fifo
- pub fn readableLength(self: Self) usize {
- return self.count;
- }
-
- /// Returns a writable slice from the 'read' end of the fifo
- fn readableSliceMut(self: SliceSelfArg, offset: usize) []T {
- if (offset > self.count) return &[_]T{};
-
- var start = self.head + offset;
- if (start >= self.buf.len) {
- start -= self.buf.len;
- return self.buf[start .. start + (self.count - offset)];
- } else {
- const end = @min(self.head + self.count, self.buf.len);
- return self.buf[start..end];
- }
- }
-
- /// Returns a readable slice from `offset`
- pub fn readableSlice(self: SliceSelfArg, offset: usize) []const T {
- return self.readableSliceMut(offset);
- }
-
- pub fn readableSliceOfLen(self: *Self, len: usize) []const T {
- assert(len <= self.count);
- const buf = self.readableSlice(0);
- if (buf.len >= len) {
- return buf[0..len];
- } else {
- self.realign();
- return self.readableSlice(0)[0..len];
- }
- }
-
- /// Discard first `count` items in the fifo
- pub fn discard(self: *Self, count: usize) void {
- assert(count <= self.count);
- { // set old range to undefined. Note: may be wrapped around
- const slice = self.readableSliceMut(0);
- if (slice.len >= count) {
- const unused = mem.sliceAsBytes(slice[0..count]);
- @memset(unused, undefined);
- } else {
- const unused = mem.sliceAsBytes(slice[0..]);
- @memset(unused, undefined);
- const unused2 = mem.sliceAsBytes(self.readableSliceMut(slice.len)[0 .. count - slice.len]);
- @memset(unused2, undefined);
- }
- }
- if (autoalign and self.count == count) {
- self.head = 0;
- self.count = 0;
- } else {
- var head = self.head + count;
- if (powers_of_two) {
- // Note it is safe to do a wrapping subtract as
- // bitwise & with all 1s is a noop
- head &= self.buf.len -% 1;
- } else {
- head %= self.buf.len;
- }
- self.head = head;
- self.count -= count;
- }
- }
-
- /// Read the next item from the fifo
- pub fn readItem(self: *Self) ?T {
- if (self.count == 0) return null;
-
- const c = self.buf[self.head];
- self.discard(1);
- return c;
- }
-
- /// Read data from the fifo into `dst`, returns number of items copied.
- pub fn read(self: *Self, dst: []T) usize {
- var dst_left = dst;
-
- while (dst_left.len > 0) {
- const slice = self.readableSlice(0);
- if (slice.len == 0) break;
- const n = @min(slice.len, dst_left.len);
- @memcpy(dst_left[0..n], slice[0..n]);
- self.discard(n);
- dst_left = dst_left[n..];
- }
-
- return dst.len - dst_left.len;
- }
-
- /// Same as `read` except it returns an error union
- /// The purpose of this function existing is to match `std.io.GenericReader` API.
- fn readFn(self: *Self, dest: []u8) error{}!usize {
- return self.read(dest);
- }
-
- pub fn reader(self: *Self) Reader {
- return .{ .context = self };
- }
-
- /// Returns number of items available in fifo
- pub fn writableLength(self: Self) usize {
- return self.buf.len - self.count;
- }
-
- /// Returns the first section of writable buffer.
- /// Note that this may be of length 0
- pub fn writableSlice(self: SliceSelfArg, offset: usize) []T {
- if (offset > self.buf.len) return &[_]T{};
-
- const tail = self.head + offset + self.count;
- if (tail < self.buf.len) {
- return self.buf[tail..];
- } else {
- return self.buf[tail - self.buf.len ..][0 .. self.writableLength() - offset];
- }
- }
-
- /// Returns a writable buffer of at least `size` items, allocating memory as needed.
- /// Use `fifo.update` once you've written data to it.
- pub fn writableWithSize(self: *Self, size: usize) ![]T {
- try self.ensureUnusedCapacity(size);
-
- // try to avoid realigning buffer
- var slice = self.writableSlice(0);
- if (slice.len < size) {
- self.realign();
- slice = self.writableSlice(0);
- }
- return slice;
- }
-
- /// Update the tail location of the buffer (usually follows use of writable/writableWithSize)
- pub fn update(self: *Self, count: usize) void {
- assert(self.count + count <= self.buf.len);
- self.count += count;
- }
-
- /// Appends the data in `src` to the fifo.
- /// You must have ensured there is enough space.
- pub fn writeAssumeCapacity(self: *Self, src: []const T) void {
- assert(self.writableLength() >= src.len);
-
- var src_left = src;
- while (src_left.len > 0) {
- const writable_slice = self.writableSlice(0);
- assert(writable_slice.len != 0);
- const n = @min(writable_slice.len, src_left.len);
- @memcpy(writable_slice[0..n], src_left[0..n]);
- self.update(n);
- src_left = src_left[n..];
- }
- }
-
- /// Write a single item to the fifo
- pub fn writeItem(self: *Self, item: T) !void {
- try self.ensureUnusedCapacity(1);
- return self.writeItemAssumeCapacity(item);
- }
-
- pub fn writeItemAssumeCapacity(self: *Self, item: T) void {
- var tail = self.head + self.count;
- if (powers_of_two) {
- tail &= self.buf.len - 1;
- } else {
- tail %= self.buf.len;
- }
- self.buf[tail] = item;
- self.update(1);
- }
-
- /// Appends the data in `src` to the fifo.
- /// Allocates more memory as necessary
- pub fn write(self: *Self, src: []const T) !void {
- try self.ensureUnusedCapacity(src.len);
-
- return self.writeAssumeCapacity(src);
- }
-
- /// Same as `write` except it returns the number of bytes written, which is always the same
- /// as `bytes.len`. The purpose of this function existing is to match `std.io.GenericWriter` API.
- fn appendWrite(self: *Self, bytes: []const u8) error{OutOfMemory}!usize {
- try self.write(bytes);
- return bytes.len;
- }
-
- pub fn writer(self: *Self) Writer {
- return .{ .context = self };
- }
-
- /// Make `count` items available before the current read location
- fn rewind(self: *Self, count: usize) void {
- assert(self.writableLength() >= count);
-
- var head = self.head + (self.buf.len - count);
- if (powers_of_two) {
- head &= self.buf.len - 1;
- } else {
- head %= self.buf.len;
- }
- self.head = head;
- self.count += count;
- }
-
- /// Place data back into the read stream
- pub fn unget(self: *Self, src: []const T) !void {
- try self.ensureUnusedCapacity(src.len);
-
- self.rewind(src.len);
-
- const slice = self.readableSliceMut(0);
- if (src.len < slice.len) {
- @memcpy(slice[0..src.len], src);
- } else {
- @memcpy(slice, src[0..slice.len]);
- const slice2 = self.readableSliceMut(slice.len);
- @memcpy(slice2[0 .. src.len - slice.len], src[slice.len..]);
- }
- }
-
- /// Returns the item at `offset`.
- /// Asserts offset is within bounds.
- pub fn peekItem(self: Self, offset: usize) T {
- assert(offset < self.count);
-
- var index = self.head + offset;
- if (powers_of_two) {
- index &= self.buf.len - 1;
- } else {
- index %= self.buf.len;
- }
- return self.buf[index];
- }
-
- /// Pump data from a reader into a writer.
- /// Stops when reader returns 0 bytes (EOF).
- /// Buffer size must be set before calling; a buffer length of 0 is invalid.
- pub fn pump(self: *Self, src_reader: anytype, dest_writer: anytype) !void {
- assert(self.buf.len > 0);
- while (true) {
- if (self.writableLength() > 0) {
- const n = try src_reader.read(self.writableSlice(0));
- if (n == 0) break; // EOF
- self.update(n);
- }
- self.discard(try dest_writer.write(self.readableSlice(0)));
- }
- // flush remaining data
- while (self.readableLength() > 0) {
- self.discard(try dest_writer.write(self.readableSlice(0)));
- }
- }
-
- pub fn toOwnedSlice(self: *Self) Allocator.Error![]T {
- if (self.head != 0) self.realign();
- assert(self.head == 0);
- assert(self.count <= self.buf.len);
- const allocator = self.allocator;
- if (allocator.resize(self.buf, self.count)) {
- const result = self.buf[0..self.count];
- self.* = Self.init(allocator);
- return result;
- }
- const new_memory = try allocator.dupe(T, self.buf[0..self.count]);
- allocator.free(self.buf);
- self.* = Self.init(allocator);
- return new_memory;
- }
- };
-}
-
-test "LinearFifo(u8, .Dynamic) discard(0) from empty buffer should not error on overflow" {
- var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator);
- defer fifo.deinit();
-
- // If overflow is not explicitly allowed this will crash in debug / safe mode
- fifo.discard(0);
-}
-
-test "LinearFifo(u8, .Dynamic)" {
- var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator);
- defer fifo.deinit();
-
- try fifo.write("HELLO");
- try testing.expectEqual(@as(usize, 5), fifo.readableLength());
- try testing.expectEqualSlices(u8, "HELLO", fifo.readableSlice(0));
-
- {
- var i: usize = 0;
- while (i < 5) : (i += 1) {
- try fifo.write(&[_]u8{fifo.peekItem(i)});
- }
- try testing.expectEqual(@as(usize, 10), fifo.readableLength());
- try testing.expectEqualSlices(u8, "HELLOHELLO", fifo.readableSlice(0));
- }
-
- {
- try testing.expectEqual(@as(u8, 'H'), fifo.readItem().?);
- try testing.expectEqual(@as(u8, 'E'), fifo.readItem().?);
- try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?);
- try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?);
- try testing.expectEqual(@as(u8, 'O'), fifo.readItem().?);
- }
- try testing.expectEqual(@as(usize, 5), fifo.readableLength());
-
- { // Writes that wrap around
- try testing.expectEqual(@as(usize, 11), fifo.writableLength());
- try testing.expectEqual(@as(usize, 6), fifo.writableSlice(0).len);
- fifo.writeAssumeCapacity("6<chars<11");
- try testing.expectEqualSlices(u8, "HELLO6<char", fifo.readableSlice(0));
- try testing.expectEqualSlices(u8, "s<11", fifo.readableSlice(11));
- try testing.expectEqualSlices(u8, "11", fifo.readableSlice(13));
- try testing.expectEqualSlices(u8, "", fifo.readableSlice(15));
- fifo.discard(11);
- try testing.expectEqualSlices(u8, "s<11", fifo.readableSlice(0));
- fifo.discard(4);
- try testing.expectEqual(@as(usize, 0), fifo.readableLength());
- }
-
- {
- const buf = try fifo.writableWithSize(12);
- try testing.expectEqual(@as(usize, 12), buf.len);
- var i: u8 = 0;
- while (i < 10) : (i += 1) {
- buf[i] = i + 'a';
- }
- fifo.update(10);
- try testing.expectEqualSlices(u8, "abcdefghij", fifo.readableSlice(0));
- }
-
- {
- try fifo.unget("prependedstring");
- var result: [30]u8 = undefined;
- try testing.expectEqualSlices(u8, "prependedstringabcdefghij", result[0..fifo.read(&result)]);
- try fifo.unget("b");
- try fifo.unget("a");
- try testing.expectEqualSlices(u8, "ab", result[0..fifo.read(&result)]);
- }
-
- fifo.shrink(0);
-
- {
- try fifo.writer().print("{s}, {s}!", .{ "Hello", "World" });
- var result: [30]u8 = undefined;
- try testing.expectEqualSlices(u8, "Hello, World!", result[0..fifo.read(&result)]);
- try testing.expectEqual(@as(usize, 0), fifo.readableLength());
- }
-
- {
- try fifo.writer().writeAll("This is a test");
- var result: [30]u8 = undefined;
- try testing.expectEqualSlices(u8, "This", (try fifo.reader().readUntilDelimiterOrEof(&result, ' ')).?);
- try testing.expectEqualSlices(u8, "is", (try fifo.reader().readUntilDelimiterOrEof(&result, ' ')).?);
- try testing.expectEqualSlices(u8, "a", (try fifo.reader().readUntilDelimiterOrEof(&result, ' ')).?);
- try testing.expectEqualSlices(u8, "test", (try fifo.reader().readUntilDelimiterOrEof(&result, ' ')).?);
- }
-
- {
- try fifo.ensureTotalCapacity(1);
- var in_fbs = std.io.fixedBufferStream("pump test");
- var out_buf: [50]u8 = undefined;
- var out_fbs = std.io.fixedBufferStream(&out_buf);
- try fifo.pump(in_fbs.reader(), out_fbs.writer());
- try testing.expectEqualSlices(u8, in_fbs.buffer, out_fbs.getWritten());
- }
-}
-
-test LinearFifo {
- inline for ([_]type{ u1, u8, u16, u64 }) |T| {
- inline for ([_]LinearFifoBufferType{ LinearFifoBufferType{ .Static = 32 }, .Slice, .Dynamic }) |bt| {
- const FifoType = LinearFifo(T, bt);
- var buf: if (bt == .Slice) [32]T else void = undefined;
- var fifo = switch (bt) {
- .Static => FifoType.init(),
- .Slice => FifoType.init(buf[0..]),
- .Dynamic => FifoType.init(testing.allocator),
- };
- defer fifo.deinit();
-
- try fifo.write(&[_]T{ 0, 1, 1, 0, 1 });
- try testing.expectEqual(@as(usize, 5), fifo.readableLength());
-
- {
- try testing.expectEqual(@as(T, 0), fifo.readItem().?);
- try testing.expectEqual(@as(T, 1), fifo.readItem().?);
- try testing.expectEqual(@as(T, 1), fifo.readItem().?);
- try testing.expectEqual(@as(T, 0), fifo.readItem().?);
- try testing.expectEqual(@as(T, 1), fifo.readItem().?);
- try testing.expectEqual(@as(usize, 0), fifo.readableLength());
- }
-
- {
- try fifo.writeItem(1);
- try fifo.writeItem(1);
- try fifo.writeItem(1);
- try testing.expectEqual(@as(usize, 3), fifo.readableLength());
- }
-
- {
- var readBuf: [3]T = undefined;
- const n = fifo.read(&readBuf);
- try testing.expectEqual(@as(usize, 3), n); // NOTE: It should be the number of items.
- }
- }
- }
-}
diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig
@@ -1351,8 +1351,7 @@ pub const Reader = struct {
}
r.pos += n;
if (n > data_size) {
- io_reader.seek = 0;
- io_reader.end = n - data_size;
+ io_reader.end += n - data_size;
return data_size;
}
return n;
@@ -1386,8 +1385,7 @@ pub const Reader = struct {
}
r.pos += n;
if (n > data_size) {
- io_reader.seek = 0;
- io_reader.end = n - data_size;
+ io_reader.end += n - data_size;
return data_size;
}
return n;
diff --git a/lib/std/http.zig b/lib/std/http.zig
@@ -1,14 +1,14 @@
const builtin = @import("builtin");
const std = @import("std.zig");
const assert = std.debug.assert;
+const Writer = std.Io.Writer;
+const File = std.fs.File;
pub const Client = @import("http/Client.zig");
pub const Server = @import("http/Server.zig");
-pub const protocol = @import("http/protocol.zig");
pub const HeadParser = @import("http/HeadParser.zig");
pub const ChunkParser = @import("http/ChunkParser.zig");
pub const HeaderIterator = @import("http/HeaderIterator.zig");
-pub const WebSocket = @import("http/WebSocket.zig");
pub const Version = enum {
@"HTTP/1.0",
@@ -20,51 +20,32 @@ pub const Version = enum {
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition
///
/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
-pub const Method = enum(u64) {
- GET = parse("GET"),
- HEAD = parse("HEAD"),
- POST = parse("POST"),
- PUT = parse("PUT"),
- DELETE = parse("DELETE"),
- CONNECT = parse("CONNECT"),
- OPTIONS = parse("OPTIONS"),
- TRACE = parse("TRACE"),
- PATCH = parse("PATCH"),
-
- _,
-
- /// Converts `s` into a type that may be used as a `Method` field.
- /// Asserts that `s` is 24 or fewer bytes.
- pub fn parse(s: []const u8) u64 {
- var x: u64 = 0;
- const len = @min(s.len, @sizeOf(@TypeOf(x)));
- @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]);
- return x;
- }
-
- pub fn format(self: Method, w: *std.io.Writer) std.io.Writer.Error!void {
- const bytes: []const u8 = @ptrCast(&@intFromEnum(self));
- const str = std.mem.sliceTo(bytes, 0);
- try w.writeAll(str);
- }
+pub const Method = enum {
+ GET,
+ HEAD,
+ POST,
+ PUT,
+ DELETE,
+ CONNECT,
+ OPTIONS,
+ TRACE,
+ PATCH,
/// Returns true if a request of this method is allowed to have a body
/// Actual behavior from servers may vary and should still be checked
- pub fn requestHasBody(self: Method) bool {
- return switch (self) {
+ pub fn requestHasBody(m: Method) bool {
+ return switch (m) {
.POST, .PUT, .PATCH => true,
.GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
- else => true,
};
}
/// Returns true if a response to this method is allowed to have a body
/// Actual behavior from clients may vary and should still be checked
- pub fn responseHasBody(self: Method) bool {
- return switch (self) {
+ pub fn responseHasBody(m: Method) bool {
+ return switch (m) {
.GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
.HEAD, .PUT, .TRACE => false,
- else => true,
};
}
@@ -73,11 +54,10 @@ pub const Method = enum(u64) {
/// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
///
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
- pub fn safe(self: Method) bool {
- return switch (self) {
+ pub fn safe(m: Method) bool {
+ return switch (m) {
.GET, .HEAD, .OPTIONS, .TRACE => true,
.POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
- else => false,
};
}
@@ -88,11 +68,10 @@ pub const Method = enum(u64) {
/// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
///
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
- pub fn idempotent(self: Method) bool {
- return switch (self) {
+ pub fn idempotent(m: Method) bool {
+ return switch (m) {
.GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
.CONNECT, .POST, .PATCH => false,
- else => false,
};
}
@@ -102,11 +81,10 @@ pub const Method = enum(u64) {
/// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
///
/// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
- pub fn cacheable(self: Method) bool {
- return switch (self) {
+ pub fn cacheable(m: Method) bool {
+ return switch (m) {
.GET, .HEAD => true,
.POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
- else => false,
};
}
};
@@ -296,13 +274,24 @@ pub const TransferEncoding = enum {
};
pub const ContentEncoding = enum {
- identity,
- compress,
- @"x-compress",
- deflate,
- gzip,
- @"x-gzip",
zstd,
+ gzip,
+ deflate,
+ compress,
+ identity,
+
+ pub fn fromString(s: []const u8) ?ContentEncoding {
+ const map = std.StaticStringMap(ContentEncoding).initComptime(.{
+ .{ "zstd", .zstd },
+ .{ "gzip", .gzip },
+ .{ "x-gzip", .gzip },
+ .{ "deflate", .deflate },
+ .{ "compress", .compress },
+ .{ "x-compress", .compress },
+ .{ "identity", .identity },
+ });
+ return map.get(s);
+ }
};
pub const Connection = enum {
@@ -315,15 +304,790 @@ pub const Header = struct {
value: []const u8,
};
+pub const Reader = struct {
+ in: *std.Io.Reader,
+ /// This is preallocated memory that might be used by `bodyReader`. That
+ /// function might return a pointer to this field, or a different
+ /// `*std.Io.Reader`. Advisable to not access this field directly.
+ interface: std.Io.Reader,
+ /// Keeps track of whether the stream is ready to accept a new request,
+ /// making invalid API usage cause assertion failures rather than HTTP
+ /// protocol violations.
+ state: State,
+ /// HTTP trailer bytes. These are at the end of a transfer-encoding:
+ /// chunked message. This data is available only after calling one of the
+ /// "end" functions and points to data inside the buffer of `in`, and is
+ /// therefore invalidated on the next call to `receiveHead`, or any other
+ /// read from `in`.
+ trailers: []const u8 = &.{},
+ body_err: ?BodyError = null,
+
+ pub const RemainingChunkLen = enum(u64) {
+ head = 0,
+ n = 1,
+ rn = 2,
+ _,
+
+ pub fn init(integer: u64) RemainingChunkLen {
+ return @enumFromInt(integer);
+ }
+
+ pub fn int(rcl: RemainingChunkLen) u64 {
+ return @intFromEnum(rcl);
+ }
+ };
+
+ pub const State = union(enum) {
+ /// The stream is available to be used for the first time, or reused.
+ ready,
+ received_head,
+ /// The stream goes until the connection is closed.
+ body_none,
+ body_remaining_content_length: u64,
+ body_remaining_chunk_len: RemainingChunkLen,
+ /// The stream would be eligible for another HTTP request, however the
+ /// client and server did not negotiate a persistent connection.
+ closing,
+ };
+
+ pub const BodyError = error{
+ HttpChunkInvalid,
+ HttpChunkTruncated,
+ HttpHeadersOversize,
+ };
+
+ pub const HeadError = error{
+ /// Too many bytes of HTTP headers.
+ ///
+ /// The HTTP specification suggests to respond with a 431 status code
+ /// before closing the connection.
+ HttpHeadersOversize,
+ /// Partial HTTP request was received but the connection was closed
+ /// before fully receiving the headers.
+ HttpRequestTruncated,
+ /// The client sent 0 bytes of headers before closing the stream. This
+ /// happens when a keep-alive connection is finally closed.
+ HttpConnectionClosing,
+ /// Transitive error occurred reading from `in`.
+ ReadFailed,
+ };
+
+ /// Buffers the entire head inside `in`.
+ ///
+ /// The resulting memory is invalidated by any subsequent consumption of
+ /// the input stream.
+ pub fn receiveHead(reader: *Reader) HeadError![]const u8 {
+ reader.trailers = &.{};
+ const in = reader.in;
+ var hp: HeadParser = .{};
+ var head_len: usize = 0;
+ while (true) {
+ if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize;
+ const remaining = in.buffered()[head_len..];
+ if (remaining.len == 0) {
+ in.fillMore() catch |err| switch (err) {
+ error.EndOfStream => switch (head_len) {
+ 0 => return error.HttpConnectionClosing,
+ else => return error.HttpRequestTruncated,
+ },
+ error.ReadFailed => return error.ReadFailed,
+ };
+ continue;
+ }
+ head_len += hp.feed(remaining);
+ if (hp.state == .finished) {
+ reader.state = .received_head;
+ const head_buffer = in.buffered()[0..head_len];
+ in.toss(head_len);
+ return head_buffer;
+ }
+ }
+ }
+
+ /// If compressed body has been negotiated this will return compressed bytes.
+ ///
+ /// Asserts only called once and after `receiveHead`.
+ ///
+ /// See also:
+ /// * `interfaceDecompressing`
+ pub fn bodyReader(
+ reader: *Reader,
+ buffer: []u8,
+ transfer_encoding: TransferEncoding,
+ content_length: ?u64,
+ ) *std.Io.Reader {
+ assert(reader.state == .received_head);
+ switch (transfer_encoding) {
+ .chunked => {
+ reader.state = .{ .body_remaining_chunk_len = .head };
+ reader.interface = .{
+ .buffer = buffer,
+ .seek = 0,
+ .end = 0,
+ .vtable = &.{
+ .stream = chunkedStream,
+ .discard = chunkedDiscard,
+ },
+ };
+ return &reader.interface;
+ },
+ .none => {
+ if (content_length) |len| {
+ reader.state = .{ .body_remaining_content_length = len };
+ reader.interface = .{
+ .buffer = buffer,
+ .seek = 0,
+ .end = 0,
+ .vtable = &.{
+ .stream = contentLengthStream,
+ .discard = contentLengthDiscard,
+ },
+ };
+ return &reader.interface;
+ } else {
+ reader.state = .body_none;
+ return reader.in;
+ }
+ },
+ }
+ }
+
+ /// If compressed body has been negotiated this will return decompressed bytes.
+ ///
+ /// Asserts only called once and after `receiveHead`.
+ ///
+ /// See also:
+ /// * `interface`
+ pub fn bodyReaderDecompressing(
+ reader: *Reader,
+ transfer_encoding: TransferEncoding,
+ content_length: ?u64,
+ content_encoding: ContentEncoding,
+ decompressor: *Decompressor,
+ decompression_buffer: []u8,
+ ) *std.Io.Reader {
+ if (transfer_encoding == .none and content_length == null) {
+ assert(reader.state == .received_head);
+ reader.state = .body_none;
+ switch (content_encoding) {
+ .identity => {
+ return reader.in;
+ },
+ .deflate => {
+ decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) };
+ return &decompressor.flate.reader;
+ },
+ .gzip => {
+ decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) };
+ return &decompressor.flate.reader;
+ },
+ .zstd => {
+ decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) };
+ return &decompressor.zstd.reader;
+ },
+ .compress => unreachable,
+ }
+ }
+ const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length);
+ return decompressor.init(transfer_reader, decompression_buffer, content_encoding);
+ }
+
+ fn contentLengthStream(
+ io_r: *std.Io.Reader,
+ w: *Writer,
+ limit: std.Io.Limit,
+ ) std.Io.Reader.StreamError!usize {
+ const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r));
+ const remaining_content_length = &reader.state.body_remaining_content_length;
+ const remaining = remaining_content_length.*;
+ if (remaining == 0) {
+ reader.state = .ready;
+ return error.EndOfStream;
+ }
+ const n = try reader.in.stream(w, limit.min(.limited64(remaining)));
+ remaining_content_length.* = remaining - n;
+ return n;
+ }
+
+ fn contentLengthDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize {
+ const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r));
+ const remaining_content_length = &reader.state.body_remaining_content_length;
+ const remaining = remaining_content_length.*;
+ if (remaining == 0) {
+ reader.state = .ready;
+ return error.EndOfStream;
+ }
+ const n = try reader.in.discard(limit.min(.limited64(remaining)));
+ remaining_content_length.* = remaining - n;
+ return n;
+ }
+
+ fn chunkedStream(io_r: *std.Io.Reader, w: *Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+ const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r));
+ const chunk_len_ptr = switch (reader.state) {
+ .ready => return error.EndOfStream,
+ .body_remaining_chunk_len => |*x| x,
+ else => unreachable,
+ };
+ return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) {
+ error.ReadFailed => return error.ReadFailed,
+ error.WriteFailed => return error.WriteFailed,
+ error.EndOfStream => {
+ reader.body_err = error.HttpChunkTruncated;
+ return error.ReadFailed;
+ },
+ else => |e| {
+ reader.body_err = e;
+ return error.ReadFailed;
+ },
+ };
+ }
+
+ fn chunkedReadEndless(
+ reader: *Reader,
+ w: *Writer,
+ limit: std.Io.Limit,
+ chunk_len_ptr: *RemainingChunkLen,
+ ) (BodyError || std.Io.Reader.StreamError)!usize {
+ const in = reader.in;
+ len: switch (chunk_len_ptr.*) {
+ .head => {
+ var cp: ChunkParser = .init;
+ while (true) {
+ const i = cp.feed(in.buffered());
+ switch (cp.state) {
+ .invalid => return error.HttpChunkInvalid,
+ .data => {
+ in.toss(i);
+ break;
+ },
+ else => {
+ in.toss(i);
+ try in.fillMore();
+ continue;
+ },
+ }
+ }
+ if (cp.chunk_len == 0) return parseTrailers(reader, 0);
+ const n = try in.stream(w, limit.min(.limited64(cp.chunk_len)));
+ chunk_len_ptr.* = .init(cp.chunk_len + 2 - n);
+ return n;
+ },
+ .n => {
+ if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid;
+ in.toss(1);
+ continue :len .head;
+ },
+ .rn => {
+ const rn = try in.peekArray(2);
+ if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid;
+ in.toss(2);
+ continue :len .head;
+ },
+ else => |remaining_chunk_len| {
+ const n = try in.stream(w, limit.min(.limited64(@intFromEnum(remaining_chunk_len) - 2)));
+ chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n);
+ return n;
+ },
+ }
+ }
+
+ fn chunkedDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize {
+ const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r));
+ const chunk_len_ptr = switch (reader.state) {
+ .ready => return error.EndOfStream,
+ .body_remaining_chunk_len => |*x| x,
+ else => unreachable,
+ };
+ return chunkedDiscardEndless(reader, limit, chunk_len_ptr) catch |err| switch (err) {
+ error.ReadFailed => return error.ReadFailed,
+ error.EndOfStream => {
+ reader.body_err = error.HttpChunkTruncated;
+ return error.ReadFailed;
+ },
+ else => |e| {
+ reader.body_err = e;
+ return error.ReadFailed;
+ },
+ };
+ }
+
+ fn chunkedDiscardEndless(
+ reader: *Reader,
+ limit: std.Io.Limit,
+ chunk_len_ptr: *RemainingChunkLen,
+ ) (BodyError || std.Io.Reader.Error)!usize {
+ const in = reader.in;
+ len: switch (chunk_len_ptr.*) {
+ .head => {
+ var cp: ChunkParser = .init;
+ while (true) {
+ const i = cp.feed(in.buffered());
+ switch (cp.state) {
+ .invalid => return error.HttpChunkInvalid,
+ .data => {
+ in.toss(i);
+ break;
+ },
+ else => {
+ in.toss(i);
+ try in.fillMore();
+ continue;
+ },
+ }
+ }
+ if (cp.chunk_len == 0) return parseTrailers(reader, 0);
+ const n = try in.discard(limit.min(.limited64(cp.chunk_len)));
+ chunk_len_ptr.* = .init(cp.chunk_len + 2 - n);
+ return n;
+ },
+ .n => {
+ if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid;
+ in.toss(1);
+ continue :len .head;
+ },
+ .rn => {
+ const rn = try in.peekArray(2);
+ if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid;
+ in.toss(2);
+ continue :len .head;
+ },
+ else => |remaining_chunk_len| {
+ const n = try in.discard(limit.min(.limited64(remaining_chunk_len.int() - 2)));
+ chunk_len_ptr.* = .init(remaining_chunk_len.int() - n);
+ return n;
+ },
+ }
+ }
+
+ /// Called when next bytes in the stream are trailers, or "\r\n" to indicate
+ /// end of chunked body.
+ fn parseTrailers(reader: *Reader, amt_read: usize) (BodyError || std.Io.Reader.Error)!usize {
+ const in = reader.in;
+ const rn = try in.peekArray(2);
+ if (rn[0] == '\r' and rn[1] == '\n') {
+ in.toss(2);
+ reader.state = .ready;
+ assert(reader.trailers.len == 0);
+ return amt_read;
+ }
+ var hp: HeadParser = .{ .state = .seen_rn };
+ var trailers_len: usize = 2;
+ while (true) {
+ if (in.buffer.len - trailers_len == 0) return error.HttpHeadersOversize;
+ const remaining = in.buffered()[trailers_len..];
+ if (remaining.len == 0) {
+ try in.fillMore();
+ continue;
+ }
+ trailers_len += hp.feed(remaining);
+ if (hp.state == .finished) {
+ reader.state = .ready;
+ reader.trailers = in.buffered()[0..trailers_len];
+ in.toss(trailers_len);
+ return amt_read;
+ }
+ }
+ }
+};
+
+pub const Decompressor = union(enum) {
+ flate: std.compress.flate.Decompress,
+ zstd: std.compress.zstd.Decompress,
+ none: *std.Io.Reader,
+
+ pub fn init(
+ decompressor: *Decompressor,
+ transfer_reader: *std.Io.Reader,
+ buffer: []u8,
+ content_encoding: ContentEncoding,
+ ) *std.Io.Reader {
+ switch (content_encoding) {
+ .identity => {
+ decompressor.* = .{ .none = transfer_reader };
+ return transfer_reader;
+ },
+ .deflate => {
+ decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
+ return &decompressor.flate.reader;
+ },
+ .gzip => {
+ decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
+ return &decompressor.flate.reader;
+ },
+ .zstd => {
+ decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
+ return &decompressor.zstd.reader;
+ },
+ .compress => unreachable,
+ }
+ }
+};
+
+/// Request or response body.
+pub const BodyWriter = struct {
+ /// Until the lifetime of `BodyWriter` ends, it is illegal to modify the
+ /// state of this other than via methods of `BodyWriter`.
+ http_protocol_output: *Writer,
+ state: State,
+ writer: Writer,
+
+ pub const Error = Writer.Error;
+
+ /// How many zeroes to reserve for hex-encoded chunk length.
+ const chunk_len_digits = 8;
+ const max_chunk_len: usize = std.math.pow(u64, 16, chunk_len_digits) - 1;
+ const chunk_header_template = ("0" ** chunk_len_digits) ++ "\r\n";
+
+ comptime {
+ assert(max_chunk_len == std.math.maxInt(u32));
+ }
+
+ pub const State = union(enum) {
+ /// End of connection signals the end of the stream.
+ none,
+ /// As a debugging utility, counts down to zero as bytes are written.
+ content_length: u64,
+ /// Each chunk is wrapped in a header and trailer.
+ chunked: Chunked,
+ /// Cleanly finished stream; connection can be reused.
+ end,
+
+ pub const Chunked = union(enum) {
+ /// Index to the start of the hex-encoded chunk length in the chunk
+ /// header within the buffer of `BodyWriter.http_protocol_output`.
+ /// Buffered chunk data starts here plus length of `chunk_header_template`.
+ offset: usize,
+ /// We are in the middle of a chunk and this is how many bytes are
+ /// left until the next header. This includes +2 for "\r"\n", and
+ /// is zero for the beginning of the stream.
+ chunk_len: usize,
+
+ pub const init: Chunked = .{ .chunk_len = 0 };
+ };
+ };
+
+ pub fn isEliding(w: *const BodyWriter) bool {
+ return w.writer.vtable.drain == elidingDrain;
+ }
+
+ /// Sends all buffered data across `BodyWriter.http_protocol_output`.
+ pub fn flush(w: *BodyWriter) Error!void {
+ const out = w.http_protocol_output;
+ switch (w.state) {
+ .end, .none, .content_length => return out.flush(),
+ .chunked => |*chunked| switch (chunked.*) {
+ .offset => |offset| {
+ const chunk_len = out.end - offset - chunk_header_template.len;
+ if (chunk_len > 0) {
+ writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len);
+ chunked.* = .{ .chunk_len = 2 };
+ } else {
+ out.end = offset;
+ chunked.* = .{ .chunk_len = 0 };
+ }
+ try out.flush();
+ },
+ .chunk_len => return out.flush(),
+ },
+ }
+ }
+
+ /// When using content-length, asserts that the amount of data sent matches
+ /// the value sent in the header, then flushes.
+ ///
+ /// When using transfer-encoding: chunked, writes the end-of-stream message
+ /// with empty trailers, then flushes the stream to the system. Asserts any
+ /// started chunk has been completely finished.
+ ///
+ /// Respects the value of `isEliding` to omit all data after the headers.
+ ///
+ /// See also:
+ /// * `endUnflushed`
+ /// * `endChunked`
+ pub fn end(w: *BodyWriter) Error!void {
+ try endUnflushed(w);
+ try w.http_protocol_output.flush();
+ }
+
+ /// When using content-length, asserts that the amount of data sent matches
+ /// the value sent in the header.
+ ///
+ /// Otherwise, transfer-encoding: chunked is being used, and it writes the
+ /// end-of-stream message with empty trailers.
+ ///
+ /// Respects the value of `isEliding` to omit all data after the headers.
+ ///
+ /// See also:
+ /// * `end`
+ /// * `endChunked`
+ pub fn endUnflushed(w: *BodyWriter) Error!void {
+ switch (w.state) {
+ .end => unreachable,
+ .content_length => |len| {
+ assert(len == 0); // Trips when end() called before all bytes written.
+ w.state = .end;
+ },
+ .none => {},
+ .chunked => return endChunkedUnflushed(w, .{}),
+ }
+ }
+
+ pub const EndChunkedOptions = struct {
+ trailers: []const Header = &.{},
+ };
+
+ /// Writes the end-of-stream message and any optional trailers, flushing
+ /// the underlying stream.
+ ///
+ /// Asserts that the BodyWriter is using transfer-encoding: chunked.
+ ///
+ /// Respects the value of `isEliding` to omit all data after the headers.
+ ///
+ /// See also:
+ /// * `endChunkedUnflushed`
+ /// * `end`
+ pub fn endChunked(w: *BodyWriter, options: EndChunkedOptions) Error!void {
+ try endChunkedUnflushed(w, options);
+ try w.http_protocol_output.flush();
+ }
+
+ /// Writes the end-of-stream message and any optional trailers.
+ ///
+ /// Does not flush.
+ ///
+ /// Asserts that the BodyWriter is using transfer-encoding: chunked.
+ ///
+ /// Respects the value of `isEliding` to omit all data after the headers.
+ ///
+ /// See also:
+ /// * `endChunked`
+ /// * `endUnflushed`
+ /// * `end`
+ pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void {
+ const chunked = &w.state.chunked;
+ if (w.isEliding()) {
+ w.state = .end;
+ return;
+ }
+ const bw = w.http_protocol_output;
+ switch (chunked.*) {
+ .offset => |offset| {
+ const chunk_len = bw.end - offset - chunk_header_template.len;
+ writeHex(bw.buffer[offset..][0..chunk_len_digits], chunk_len);
+ try bw.writeAll("\r\n");
+ },
+ .chunk_len => |chunk_len| switch (chunk_len) {
+ 0 => {},
+ 1 => try bw.writeByte('\n'),
+ 2 => try bw.writeAll("\r\n"),
+ else => unreachable, // An earlier write call indicated more data would follow.
+ },
+ }
+ try bw.writeAll("0\r\n");
+ for (options.trailers) |trailer| {
+ try bw.writeAll(trailer.name);
+ try bw.writeAll(": ");
+ try bw.writeAll(trailer.value);
+ try bw.writeAll("\r\n");
+ }
+ try bw.writeAll("\r\n");
+ w.state = .end;
+ }
+
+ pub fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const out = bw.http_protocol_output;
+ const n = try out.writeSplatHeader(w.buffered(), data, splat);
+ bw.state.content_length -= n;
+ return w.consume(n);
+ }
+
+ pub fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const out = bw.http_protocol_output;
+ const n = try out.writeSplatHeader(w.buffered(), data, splat);
+ return w.consume(n);
+ }
+
+ pub fn elidingDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ const slice = data[0 .. data.len - 1];
+ const pattern = data[slice.len];
+ var written: usize = pattern.len * splat;
+ for (slice) |bytes| written += bytes.len;
+ switch (bw.state) {
+ .content_length => |*len| len.* -= written + w.end,
+ else => {},
+ }
+ w.end = 0;
+ return written;
+ }
+
+ pub fn elidingSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ if (File.Handle == void) return error.Unimplemented;
+ if (builtin.zig_backend == .stage2_aarch64) return error.Unimplemented;
+ switch (bw.state) {
+ .content_length => |*len| len.* -= w.end,
+ else => {},
+ }
+ w.end = 0;
+ if (limit == .nothing) return 0;
+ if (file_reader.getSize()) |size| {
+ const n = limit.minInt64(size - file_reader.pos);
+ if (n == 0) return error.EndOfStream;
+ file_reader.seekBy(@intCast(n)) catch return error.Unimplemented;
+ switch (bw.state) {
+ .content_length => |*len| len.* -= n,
+ else => {},
+ }
+ return n;
+ } else |_| {
+ // Error is observable on `file_reader` instance, and it is better to
+ // treat the file as a pipe.
+ return error.Unimplemented;
+ }
+ }
+
+ /// Returns `null` if size cannot be computed without making any syscalls.
+ pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const out = bw.http_protocol_output;
+ const n = try out.sendFileHeader(w.buffered(), file_reader, limit);
+ return w.consume(n);
+ }
+
+ pub fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const out = bw.http_protocol_output;
+ const n = try out.sendFileHeader(w.buffered(), file_reader, limit);
+ bw.state.content_length -= n;
+ return w.consume(n);
+ }
+
+ pub fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const data_len = Writer.countSendFileLowerBound(w.end, file_reader, limit) orelse {
+ // If the file size is unknown, we cannot lower to a `sendFile` since we would
+ // have to flush the chunk header before knowing the chunk length.
+ return error.Unimplemented;
+ };
+ const out = bw.http_protocol_output;
+ const chunked = &bw.state.chunked;
+ state: switch (chunked.*) {
+ .offset => |off| {
+ // TODO: is it better perf to read small files into the buffer?
+ const buffered_len = out.end - off - chunk_header_template.len;
+ const chunk_len = data_len + buffered_len;
+ writeHex(out.buffer[off..][0..chunk_len_digits], chunk_len);
+ const n = try out.sendFileHeader(w.buffered(), file_reader, limit);
+ chunked.* = .{ .chunk_len = data_len + 2 - n };
+ return w.consume(n);
+ },
+ .chunk_len => |chunk_len| l: switch (chunk_len) {
+ 0 => {
+ const off = out.end;
+ const header_buf = try out.writableArray(chunk_header_template.len);
+ @memcpy(header_buf, chunk_header_template);
+ chunked.* = .{ .offset = off };
+ continue :state .{ .offset = off };
+ },
+ 1 => {
+ try out.writeByte('\n');
+ chunked.chunk_len = 0;
+ continue :l 0;
+ },
+ 2 => {
+ try out.writeByte('\r');
+ chunked.chunk_len = 1;
+ continue :l 1;
+ },
+ else => {
+ const new_limit = limit.min(.limited(chunk_len - 2));
+ const n = try out.sendFileHeader(w.buffered(), file_reader, new_limit);
+ chunked.chunk_len = chunk_len - n;
+ return w.consume(n);
+ },
+ },
+ }
+ }
+
+ pub fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
+ const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w));
+ assert(!bw.isEliding());
+ const out = bw.http_protocol_output;
+ const data_len = w.end + Writer.countSplat(data, splat);
+ const chunked = &bw.state.chunked;
+ state: switch (chunked.*) {
+ .offset => |offset| {
+ if (out.unusedCapacityLen() >= data_len) {
+ return w.consume(out.writeSplatHeader(w.buffered(), data, splat) catch unreachable);
+ }
+ const buffered_len = out.end - offset - chunk_header_template.len;
+ const chunk_len = data_len + buffered_len;
+ writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len);
+ const n = try out.writeSplatHeader(w.buffered(), data, splat);
+ chunked.* = .{ .chunk_len = data_len + 2 - n };
+ return w.consume(n);
+ },
+ .chunk_len => |chunk_len| l: switch (chunk_len) {
+ 0 => {
+ const offset = out.end;
+ const header_buf = try out.writableArray(chunk_header_template.len);
+ @memcpy(header_buf, chunk_header_template);
+ chunked.* = .{ .offset = offset };
+ continue :state .{ .offset = offset };
+ },
+ 1 => {
+ try out.writeByte('\n');
+ chunked.chunk_len = 0;
+ continue :l 0;
+ },
+ 2 => {
+ try out.writeByte('\r');
+ chunked.chunk_len = 1;
+ continue :l 1;
+ },
+ else => {
+ const n = try out.writeSplatHeaderLimit(w.buffered(), data, splat, .limited(chunk_len - 2));
+ chunked.chunk_len = chunk_len - n;
+ return w.consume(n);
+ },
+ },
+ }
+ }
+
+ /// Writes an integer as base 16 to `buf`, right-aligned, assuming the
+ /// buffer has already been filled with zeroes.
+ fn writeHex(buf: []u8, x: usize) void {
+ assert(std.mem.allEqual(u8, buf, '0'));
+ const base = 16;
+ var index: usize = buf.len;
+ var a = x;
+ while (a > 0) {
+ const digit = a % base;
+ index -= 1;
+ buf[index] = std.fmt.digitToChar(@intCast(digit), .lower);
+ a /= base;
+ }
+ }
+};
+
test {
+ _ = Server;
+ _ = Status;
+ _ = Method;
+ _ = ChunkParser;
+ _ = HeadParser;
+
if (builtin.os.tag != .wasi) {
_ = Client;
- _ = Method;
- _ = Server;
- _ = Status;
- _ = HeadParser;
- _ = ChunkParser;
- _ = WebSocket;
_ = @import("http/test.zig");
}
}
diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig
@@ -1,5 +1,8 @@
//! Parser for transfer-encoding: chunked.
+const ChunkParser = @This();
+const std = @import("std");
+
state: State,
chunk_len: u64,
@@ -97,9 +100,6 @@ pub fn feed(p: *ChunkParser, bytes: []const u8) usize {
return bytes.len;
}
-const ChunkParser = @This();
-const std = @import("std");
-
test feed {
const testing = std.testing;
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
@@ -13,9 +13,10 @@ const net = std.net;
const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;
+const Writer = std.io.Writer;
+const Reader = std.io.Reader;
const Client = @This();
-const proto = @import("protocol.zig");
pub const disable_tls = std.options.http_disable_tls;
@@ -24,6 +25,12 @@ allocator: Allocator,
ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{},
ca_bundle_mutex: std.Thread.Mutex = .{},
+/// Used both for the reader and writer buffers.
+tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
+/// If non-null, ssl secrets are logged to a stream. Creating such a stream
+/// allows other processes with access to that stream to decrypt all
+/// traffic over connections created with this `Client`.
+ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null,
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
@@ -31,6 +38,13 @@ next_https_rescan_certs: bool = true,
/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
+/// Each `Connection` allocates this amount for the reader buffer.
+///
+/// If the entire HTTP header cannot fit in this amount of bytes,
+/// `error.HttpHeadersOversize` will be returned from `Request.wait`.
+read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
+/// Each `Connection` allocates this amount for the writer buffer.
+write_buffer_size: usize = 1024,
/// If populated, all http traffic travels through this third party.
/// This field cannot be modified while the client has active connections.
@@ -41,7 +55,7 @@ http_proxy: ?*Proxy = null,
/// Pointer to externally-owned memory.
https_proxy: ?*Proxy = null,
-/// A set of linked lists of connections that can be reused.
+/// A Least-Recently-Used cache of open connections to be reused.
pub const ConnectionPool = struct {
mutex: std.Thread.Mutex = .{},
/// Open connections that are currently in use.
@@ -55,23 +69,25 @@ pub const ConnectionPool = struct {
pub const Criteria = struct {
host: []const u8,
port: u16,
- protocol: Connection.Protocol,
+ protocol: Protocol,
};
- /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
+ /// Finds and acquires a connection from the connection pool matching the criteria.
/// If no connection is found, null is returned.
+ ///
+ /// Threadsafe.
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
pool.mutex.lock();
defer pool.mutex.unlock();
var next = pool.free.last;
while (next) |node| : (next = node.prev) {
- const connection: *Connection = @fieldParentPtr("pool_node", node);
+ const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
if (connection.protocol != criteria.protocol) continue;
if (connection.port != criteria.port) continue;
// Domain names are case-insensitive (RFC 5890, Section 2.3.2.4)
- if (!std.ascii.eqlIgnoreCase(connection.host, criteria.host)) continue;
+ if (!std.ascii.eqlIgnoreCase(connection.host(), criteria.host)) continue;
pool.acquireUnsafe(connection);
return connection;
@@ -96,28 +112,23 @@ pub const ConnectionPool = struct {
return pool.acquireUnsafe(connection);
}
- /// Tries to release a connection back to the connection pool. This function is threadsafe.
+ /// Tries to release a connection back to the connection pool.
/// If the connection is marked as closing, it will be closed instead.
///
- /// The allocator must be the owner of all nodes in this pool.
- /// The allocator must be the owner of all resources associated with the connection.
- pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
+ /// Threadsafe.
+ pub fn release(pool: *ConnectionPool, connection: *Connection) void {
pool.mutex.lock();
defer pool.mutex.unlock();
pool.used.remove(&connection.pool_node);
- if (connection.closing or pool.free_size == 0) {
- connection.close(allocator);
- return allocator.destroy(connection);
- }
+ if (connection.closing or pool.free_size == 0) return connection.destroy();
if (pool.free_len >= pool.free_size) {
- const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?);
+ const popped: *Connection = @alignCast(@fieldParentPtr("pool_node", pool.free.popFirst().?));
pool.free_len -= 1;
- popped.close(allocator);
- allocator.destroy(popped);
+ popped.destroy();
}
if (connection.proxied) {
@@ -138,9 +149,11 @@ pub const ConnectionPool = struct {
pool.used.append(&connection.pool_node);
}
- /// Resizes the connection pool. This function is threadsafe.
+ /// Resizes the connection pool.
///
/// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size.
+ ///
+ /// Threadsafe.
pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void {
pool.mutex.lock();
defer pool.mutex.unlock();
@@ -158,538 +171,612 @@ pub const ConnectionPool = struct {
pool.free_size = new_size;
}
- /// Frees the connection pool and closes all connections within. This function is threadsafe.
+ /// Frees the connection pool and closes all connections within.
///
/// All future operations on the connection pool will deadlock.
- pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void {
+ ///
+ /// Threadsafe.
+ pub fn deinit(pool: *ConnectionPool) void {
pool.mutex.lock();
var next = pool.free.first;
while (next) |node| {
- const connection: *Connection = @fieldParentPtr("pool_node", node);
+ const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
next = node.next;
- connection.close(allocator);
- allocator.destroy(connection);
+ connection.destroy();
}
next = pool.used.first;
while (next) |node| {
- const connection: *Connection = @fieldParentPtr("pool_node", node);
+ const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
next = node.next;
- connection.close(allocator);
- allocator.destroy(node);
+ connection.destroy();
}
pool.* = undefined;
}
};
-/// An interface to either a plain or TLS connection.
-pub const Connection = struct {
- stream: net.Stream,
- /// undefined unless protocol is tls.
- tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
-
- /// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
- pool_node: std.DoublyLinkedList.Node,
-
- /// The protocol that this connection is using.
- protocol: Protocol,
-
- /// The host that this connection is connected to.
- host: []u8,
+pub const Protocol = enum {
+ plain,
+ tls,
- /// The port that this connection is connected to.
- port: u16,
-
- /// Whether this connection is proxied and is not directly connected.
- proxied: bool = false,
-
- /// Whether this connection is closing when we're done with it.
- closing: bool = false,
-
- read_start: BufferSize = 0,
- read_end: BufferSize = 0,
- write_end: BufferSize = 0,
- read_buf: [buffer_size]u8 = undefined,
- write_buf: [buffer_size]u8 = undefined,
-
- pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
- const BufferSize = std.math.IntFittingRange(0, buffer_size);
-
- pub const Protocol = enum { plain, tls };
-
- pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
- return conn.tls_client.readv(conn.stream, buffers) catch |err| {
- // https://github.com/ziglang/zig/issues/2473
- if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
-
- switch (err) {
- error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
- error.ConnectionTimedOut => return error.ConnectionTimedOut,
- error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
- else => return error.UnexpectedReadFailure,
- }
+ fn port(protocol: Protocol) u16 {
+ return switch (protocol) {
+ .plain => 80,
+ .tls => 443,
};
}
- pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
- if (conn.protocol == .tls) {
- if (disable_tls) unreachable;
-
- return conn.readvDirectTls(buffers);
- }
-
- return conn.stream.readv(buffers) catch |err| switch (err) {
- error.ConnectionTimedOut => return error.ConnectionTimedOut,
- error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
- else => return error.UnexpectedReadFailure,
- };
- }
-
- /// Refills the read buffer with data from the connection.
- pub fn fill(conn: *Connection) ReadError!void {
- if (conn.read_end != conn.read_start) return;
-
- var iovecs = [1]std.posix.iovec{
- .{ .base = &conn.read_buf, .len = conn.read_buf.len },
- };
- const nread = try conn.readvDirect(&iovecs);
- if (nread == 0) return error.EndOfStream;
- conn.read_start = 0;
- conn.read_end = @intCast(nread);
+ pub fn fromScheme(scheme: []const u8) ?Protocol {
+ const protocol_map = std.StaticStringMap(Protocol).initComptime(.{
+ .{ "http", .plain },
+ .{ "ws", .plain },
+ .{ "https", .tls },
+ .{ "wss", .tls },
+ });
+ return protocol_map.get(scheme);
}
- /// Returns the current slice of buffered data.
- pub fn peek(conn: *Connection) []const u8 {
- return conn.read_buf[conn.read_start..conn.read_end];
+ pub fn fromUri(uri: Uri) ?Protocol {
+ return fromScheme(uri.scheme);
}
+};
- /// Discards the given number of bytes from the read buffer.
- pub fn drop(conn: *Connection, num: BufferSize) void {
- conn.read_start += num;
- }
+pub const Connection = struct {
+ client: *Client,
+ stream_writer: net.Stream.Writer,
+ stream_reader: net.Stream.Reader,
+ /// Entry in `ConnectionPool.used` or `ConnectionPool.free`.
+ pool_node: std.DoublyLinkedList.Node,
+ port: u16,
+ host_len: u8,
+ proxied: bool,
+ closing: bool,
+ protocol: Protocol,
- /// Reads data from the connection into the given buffer.
- pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
- const available_read = conn.read_end - conn.read_start;
- const available_buffer = buffer.len;
+ const Plain = struct {
+ connection: Connection,
+
+ fn create(
+ client: *Client,
+ remote_host: []const u8,
+ port: u16,
+ stream: net.Stream,
+ ) error{OutOfMemory}!*Plain {
+ const gpa = client.allocator;
+ const alloc_len = allocLen(client, remote_host.len);
+ const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len);
+ errdefer gpa.free(base);
+ const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len];
+ const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size];
+ const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size];
+ assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
+ @memcpy(host_buffer, remote_host);
+ const plain: *Plain = @ptrCast(base);
+ plain.* = .{
+ .connection = .{
+ .client = client,
+ .stream_writer = stream.writer(socket_write_buffer),
+ .stream_reader = stream.reader(socket_read_buffer),
+ .pool_node = .{},
+ .port = port,
+ .host_len = @intCast(remote_host.len),
+ .proxied = false,
+ .closing = false,
+ .protocol = .plain,
+ },
+ };
+ return plain;
+ }
- if (available_read > available_buffer) { // partially read buffered data
- @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
- conn.read_start += @intCast(available_buffer);
+ fn destroy(plain: *Plain) void {
+ const c = &plain.connection;
+ const gpa = c.client.allocator;
+ const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain);
+ gpa.free(base[0..allocLen(c.client, c.host_len)]);
+ }
- return available_buffer;
- } else if (available_read > 0) { // fully read buffered data
- @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
- conn.read_start += available_read;
+ fn allocLen(client: *Client, host_len: usize) usize {
+ return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size;
+ }
- return available_read;
+ fn host(plain: *Plain) []u8 {
+ const base: [*]u8 = @ptrCast(plain);
+ return base[@sizeOf(Plain)..][0..plain.connection.host_len];
}
+ };
- var iovecs = [2]std.posix.iovec{
- .{ .base = buffer.ptr, .len = buffer.len },
- .{ .base = &conn.read_buf, .len = conn.read_buf.len },
- };
- const nread = try conn.readvDirect(&iovecs);
+ const Tls = struct {
+ client: std.crypto.tls.Client,
+ connection: Connection,
+
+ fn create(
+ client: *Client,
+ remote_host: []const u8,
+ port: u16,
+ stream: net.Stream,
+ ) error{ OutOfMemory, TlsInitializationFailed }!*Tls {
+ const gpa = client.allocator;
+ const alloc_len = allocLen(client, remote_host.len);
+ const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
+ errdefer gpa.free(base);
+ const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
+ const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
+ const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
+ const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
+ const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
+ assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
+ @memcpy(host_buffer, remote_host);
+ const tls: *Tls = @ptrCast(base);
+ tls.* = .{
+ .connection = .{
+ .client = client,
+ .stream_writer = stream.writer(tls_write_buffer),
+ .stream_reader = stream.reader(tls_read_buffer),
+ .pool_node = .{},
+ .port = port,
+ .host_len = @intCast(remote_host.len),
+ .proxied = false,
+ .closing = false,
+ .protocol = .tls,
+ },
+ // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true
+ .client = std.crypto.tls.Client.init(
+ tls.connection.stream_reader.interface(),
+ &tls.connection.stream_writer.interface,
+ .{
+ .host = .{ .explicit = remote_host },
+ .ca = .{ .bundle = client.ca_bundle },
+ .ssl_key_log = client.ssl_key_log,
+ .read_buffer = read_buffer,
+ .write_buffer = write_buffer,
+ // This is appropriate for HTTPS because the HTTP headers contain
+ // the content length which is used to detect truncation attacks.
+ .allow_truncation_attacks = true,
+ },
+ ) catch return error.TlsInitializationFailed,
+ };
+ return tls;
+ }
- if (nread > buffer.len) {
- conn.read_start = 0;
- conn.read_end = @intCast(nread - buffer.len);
- return buffer.len;
+ fn destroy(tls: *Tls) void {
+ const c = &tls.connection;
+ const gpa = c.client.allocator;
+ const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls);
+ gpa.free(base[0..allocLen(c.client, c.host_len)]);
}
- return nread;
- }
+ fn allocLen(client: *Client, host_len: usize) usize {
+ return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
+ client.write_buffer_size + client.read_buffer_size;
+ }
- pub const ReadError = error{
- TlsFailure,
- TlsAlert,
- ConnectionTimedOut,
- ConnectionResetByPeer,
- UnexpectedReadFailure,
- EndOfStream,
+ fn host(tls: *Tls) []u8 {
+ const base: [*]u8 = @ptrCast(tls);
+ return base[@sizeOf(Tls)..][0..tls.connection.host_len];
+ }
};
- pub const Reader = std.io.GenericReader(*Connection, ReadError, read);
-
- pub fn reader(conn: *Connection) Reader {
- return Reader{ .context = conn };
- }
+ pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;
- pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
- return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) {
- error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
- else => return error.UnexpectedWriteFailure,
+ pub fn getReadError(c: *const Connection) ?ReadError {
+ return switch (c.protocol) {
+ .tls => {
+ if (disable_tls) unreachable;
+ const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
+ return tls.client.read_err orelse c.stream_reader.getError();
+ },
+ .plain => {
+ return c.stream_reader.getError();
+ },
};
}
- pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
- if (conn.protocol == .tls) {
- if (disable_tls) unreachable;
-
- return conn.writeAllDirectTls(buffer);
- }
+ fn getStream(c: *Connection) net.Stream {
+ return c.stream_reader.getStream();
+ }
- return conn.stream.writeAll(buffer) catch |err| switch (err) {
- error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
- else => return error.UnexpectedWriteFailure,
+ fn host(c: *Connection) []u8 {
+ return switch (c.protocol) {
+ .tls => {
+ if (disable_tls) unreachable;
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ return tls.host();
+ },
+ .plain => {
+ const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
+ return plain.host();
+ },
};
}
- /// Writes the given buffer to the connection.
- pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
- if (conn.write_buf.len - conn.write_end < buffer.len) {
- try conn.flush();
-
- if (buffer.len > conn.write_buf.len) {
- try conn.writeAllDirect(buffer);
- return buffer.len;
- }
+ /// If this is called without calling `flush` or `end`, data will be
+ /// dropped unsent.
+ pub fn destroy(c: *Connection) void {
+ c.getStream().close();
+ switch (c.protocol) {
+ .tls => {
+ if (disable_tls) unreachable;
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ tls.destroy();
+ },
+ .plain => {
+ const plain: *Plain = @alignCast(@fieldParentPtr("connection", c));
+ plain.destroy();
+ },
}
-
- @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer);
- conn.write_end += @intCast(buffer.len);
-
- return buffer.len;
}
- /// Returns a buffer to be filled with exactly len bytes to write to the connection.
- pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 {
- if (conn.write_buf.len - conn.write_end < len) try conn.flush();
- defer conn.write_end += len;
- return conn.write_buf[conn.write_end..][0..len];
+ /// HTTP protocol from client to server.
+ /// This either goes directly to `stream_writer`, or to a TLS client.
+ pub fn writer(c: *Connection) *Writer {
+ return switch (c.protocol) {
+ .tls => {
+ if (disable_tls) unreachable;
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ return &tls.client.writer;
+ },
+ .plain => &c.stream_writer.interface,
+ };
}
- /// Flushes the write buffer to the connection.
- pub fn flush(conn: *Connection) WriteError!void {
- if (conn.write_end == 0) return;
-
- try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
- conn.write_end = 0;
+ /// HTTP protocol from server to client.
+ /// This either comes directly from `stream_reader`, or from a TLS client.
+ pub fn reader(c: *Connection) *Reader {
+ return switch (c.protocol) {
+ .tls => {
+ if (disable_tls) unreachable;
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ return &tls.client.reader;
+ },
+ .plain => c.stream_reader.interface(),
+ };
}
- pub const WriteError = error{
- ConnectionResetByPeer,
- UnexpectedWriteFailure,
- };
-
- pub const Writer = std.io.GenericWriter(*Connection, WriteError, write);
-
- pub fn writer(conn: *Connection) Writer {
- return Writer{ .context = conn };
+ pub fn flush(c: *Connection) Writer.Error!void {
+ if (c.protocol == .tls) {
+ if (disable_tls) unreachable;
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ try tls.client.writer.flush();
+ }
+ try c.stream_writer.interface.flush();
}
- /// Closes the connection.
- pub fn close(conn: *Connection, allocator: Allocator) void {
- if (conn.protocol == .tls) {
+ /// If the connection is a TLS connection, sends the close_notify alert.
+ ///
+ /// Flushes all buffers.
+ pub fn end(c: *Connection) Writer.Error!void {
+ if (c.protocol == .tls) {
if (disable_tls) unreachable;
-
- // try to cleanly close the TLS connection, for any server that cares.
- _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
- if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
- allocator.destroy(conn.tls_client);
+ const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
+ try tls.client.end();
}
-
- conn.stream.close();
- allocator.free(conn.host);
+ try c.stream_writer.interface.flush();
}
};
-/// The mode of transport for requests.
-pub const RequestTransfer = union(enum) {
- content_length: u64,
- chunked: void,
- none: void,
-};
-
-/// The decompressor for response messages.
-pub const Compression = union(enum) {
- //deflate: std.compress.flate.Decompress,
- //gzip: std.compress.flate.Decompress,
- // https://github.com/ziglang/zig/issues/18937
- //zstd: ZstdDecompressor,
- none: void,
-};
-
-/// A HTTP response originating from a server.
pub const Response = struct {
- version: http.Version,
- status: http.Status,
- reason: []const u8,
+ request: *Request,
+ /// Pointers in this struct are invalidated when the response body stream
+ /// is initialized.
+ head: Head,
+
+ pub const Head = struct {
+ bytes: []const u8,
+ version: http.Version,
+ status: http.Status,
+ reason: []const u8,
+ location: ?[]const u8 = null,
+ content_type: ?[]const u8 = null,
+ content_disposition: ?[]const u8 = null,
+
+ keep_alive: bool,
+
+ /// If present, the number of bytes in the response body.
+ content_length: ?u64 = null,
+
+ transfer_encoding: http.TransferEncoding = .none,
+ content_encoding: http.ContentEncoding = .identity,
+
+ pub const ParseError = error{
+ HttpConnectionHeaderUnsupported,
+ HttpContentEncodingUnsupported,
+ HttpHeaderContinuationsUnsupported,
+ HttpHeadersInvalid,
+ HttpTransferEncodingUnsupported,
+ InvalidContentLength,
+ };
- /// Points into the user-provided `server_header_buffer`.
- location: ?[]const u8 = null,
- /// Points into the user-provided `server_header_buffer`.
- content_type: ?[]const u8 = null,
- /// Points into the user-provided `server_header_buffer`.
- content_disposition: ?[]const u8 = null,
+ pub fn parse(bytes: []const u8) ParseError!Head {
+ var res: Head = .{
+ .bytes = bytes,
+ .status = undefined,
+ .reason = undefined,
+ .version = undefined,
+ .keep_alive = false,
+ };
+ var it = mem.splitSequence(u8, bytes, "\r\n");
- keep_alive: bool,
+ const first_line = it.first();
+ if (first_line.len < 12) return error.HttpHeadersInvalid;
- /// If present, the number of bytes in the response body.
- content_length: ?u64 = null,
+ const version: http.Version = switch (int64(first_line[0..8])) {
+ int64("HTTP/1.0") => .@"HTTP/1.0",
+ int64("HTTP/1.1") => .@"HTTP/1.1",
+ else => return error.HttpHeadersInvalid,
+ };
+ if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+ const status: http.Status = @enumFromInt(parseInt3(first_line[9..12]));
+ const reason = mem.trimLeft(u8, first_line[12..], " ");
+
+ res.version = version;
+ res.status = status;
+ res.reason = reason;
+ res.keep_alive = switch (version) {
+ .@"HTTP/1.0" => false,
+ .@"HTTP/1.1" => true,
+ };
- /// If present, the transfer encoding of the response body, otherwise none.
- transfer_encoding: http.TransferEncoding = .none,
+ while (it.next()) |line| {
+ if (line.len == 0) return res;
+ switch (line[0]) {
+ ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+ else => {},
+ }
- /// If present, the compression of the response body, otherwise identity (no compression).
- transfer_compression: http.ContentEncoding = .identity,
+ var line_it = mem.splitScalar(u8, line, ':');
+ const header_name = line_it.next().?;
+ const header_value = mem.trim(u8, line_it.rest(), " \t");
+ if (header_name.len == 0) return error.HttpHeadersInvalid;
+
+ if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+ res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
+ res.content_type = header_value;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "location")) {
+ res.location = header_value;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) {
+ res.content_disposition = header_value;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+ // Transfer-Encoding: second, first
+ // Transfer-Encoding: deflate, chunked
+ var iter = mem.splitBackwardsScalar(u8, header_value, ',');
+
+ const first = iter.first();
+ const trimmed_first = mem.trim(u8, first, " ");
+
+ var next: ?[]const u8 = first;
+ if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
+ if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding
+ res.transfer_encoding = transfer;
+
+ next = iter.next();
+ }
- parser: proto.HeadersParser,
- compression: Compression = .none,
+ if (next) |second| {
+ const trimmed_second = mem.trim(u8, second, " ");
- /// Whether the response body should be skipped. Any data read from the
- /// response body will be discarded.
- skip: bool = false,
+ if (http.ContentEncoding.fromString(trimmed_second)) |transfer| {
+ if (res.content_encoding != .identity) return error.HttpHeadersInvalid; // double compression is not supported
+ res.content_encoding = transfer;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
- pub const ParseError = error{
- HttpHeadersInvalid,
- HttpHeaderContinuationsUnsupported,
- HttpTransferEncodingUnsupported,
- HttpConnectionHeaderUnsupported,
- InvalidContentLength,
- CompressionUnsupported,
- };
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+ const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
- pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
- var it = mem.splitSequence(u8, bytes, "\r\n");
+ if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid;
- const first_line = it.next().?;
- if (first_line.len < 12) {
- return error.HttpHeadersInvalid;
- }
+ res.content_length = content_length;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (res.content_encoding != .identity) return error.HttpHeadersInvalid;
- const version: http.Version = switch (int64(first_line[0..8])) {
- int64("HTTP/1.0") => .@"HTTP/1.0",
- int64("HTTP/1.1") => .@"HTTP/1.1",
- else => return error.HttpHeadersInvalid,
- };
- if (first_line[8] != ' ') return error.HttpHeadersInvalid;
- const status: http.Status = @enumFromInt(parseInt3(first_line[9..12]));
- const reason = mem.trimStart(u8, first_line[12..], " ");
-
- res.version = version;
- res.status = status;
- res.reason = reason;
- res.keep_alive = switch (version) {
- .@"HTTP/1.0" => false,
- .@"HTTP/1.1" => true,
- };
+ const trimmed = mem.trim(u8, header_value, " ");
- while (it.next()) |line| {
- if (line.len == 0) return;
- switch (line[0]) {
- ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
- else => {},
- }
-
- var line_it = mem.splitScalar(u8, line, ':');
- const header_name = line_it.next().?;
- const header_value = mem.trim(u8, line_it.rest(), " \t");
- if (header_name.len == 0) return error.HttpHeadersInvalid;
-
- if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
- res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
- res.content_type = header_value;
- } else if (std.ascii.eqlIgnoreCase(header_name, "location")) {
- res.location = header_value;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) {
- res.content_disposition = header_value;
- } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
- // Transfer-Encoding: second, first
- // Transfer-Encoding: deflate, chunked
- var iter = mem.splitBackwardsScalar(u8, header_value, ',');
-
- const first = iter.first();
- const trimmed_first = mem.trim(u8, first, " ");
-
- var next: ?[]const u8 = first;
- if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
- if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding
- res.transfer_encoding = transfer;
-
- next = iter.next();
- }
-
- if (next) |second| {
- const trimmed_second = mem.trim(u8, second, " ");
-
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| {
- if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported
- res.transfer_compression = transfer;
+ if (http.ContentEncoding.fromString(trimmed)) |ce| {
+ res.content_encoding = ce;
} else {
- return error.HttpTransferEncodingUnsupported;
+ return error.HttpContentEncodingUnsupported;
}
}
-
- if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
- const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
-
- if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid;
-
- res.content_length = content_length;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
- if (res.transfer_compression != .identity) return error.HttpHeadersInvalid;
-
- const trimmed = mem.trim(u8, header_value, " ");
-
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
- res.transfer_compression = ce;
- } else {
- return error.HttpTransferEncodingUnsupported;
- }
}
+ return error.HttpHeadersInvalid; // missing empty line
}
- return error.HttpHeadersInvalid; // missing empty line
- }
- test parse {
- const response_bytes = "HTTP/1.1 200 OK\r\n" ++
- "LOcation:url\r\n" ++
- "content-tYpe: text/plain\r\n" ++
- "content-disposition:attachment; filename=example.txt \r\n" ++
- "content-Length:10\r\n" ++
- "TRansfer-encoding:\tdeflate, chunked \r\n" ++
- "connectioN:\t keep-alive \r\n\r\n";
-
- var header_buffer: [1024]u8 = undefined;
- var res = Response{
- .status = undefined,
- .reason = undefined,
- .version = undefined,
- .keep_alive = false,
- .parser = .init(&header_buffer),
- };
+ test parse {
+ const response_bytes = "HTTP/1.1 200 OK\r\n" ++
+ "LOcation:url\r\n" ++
+ "content-tYpe: text/plain\r\n" ++
+ "content-disposition:attachment; filename=example.txt \r\n" ++
+ "content-Length:10\r\n" ++
+ "TRansfer-encoding:\tdeflate, chunked \r\n" ++
+ "connectioN:\t keep-alive \r\n\r\n";
+
+ const head = try Head.parse(response_bytes);
+
+ try testing.expectEqual(.@"HTTP/1.1", head.version);
+ try testing.expectEqualStrings("OK", head.reason);
+ try testing.expectEqual(.ok, head.status);
+
+ try testing.expectEqualStrings("url", head.location.?);
+ try testing.expectEqualStrings("text/plain", head.content_type.?);
+ try testing.expectEqualStrings("attachment; filename=example.txt", head.content_disposition.?);
+
+ try testing.expectEqual(true, head.keep_alive);
+ try testing.expectEqual(10, head.content_length.?);
+ try testing.expectEqual(.chunked, head.transfer_encoding);
+ try testing.expectEqual(.deflate, head.content_encoding);
+ }
- @memcpy(header_buffer[0..response_bytes.len], response_bytes);
- res.parser.header_bytes_len = response_bytes.len;
+ pub fn iterateHeaders(h: Head) http.HeaderIterator {
+ return .init(h.bytes);
+ }
- try res.parse(response_bytes);
+ test iterateHeaders {
+ const response_bytes = "HTTP/1.1 200 OK\r\n" ++
+ "LOcation:url\r\n" ++
+ "content-tYpe: text/plain\r\n" ++
+ "content-disposition:attachment; filename=example.txt \r\n" ++
+ "content-Length:10\r\n" ++
+ "TRansfer-encoding:\tdeflate, chunked \r\n" ++
+ "connectioN:\t keep-alive \r\n\r\n";
+
+ const head = try Head.parse(response_bytes);
+ var it = head.iterateHeaders();
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("LOcation", header.name);
+ try testing.expectEqualStrings("url", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("content-tYpe", header.name);
+ try testing.expectEqualStrings("text/plain", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("content-disposition", header.name);
+ try testing.expectEqualStrings("attachment; filename=example.txt", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("content-Length", header.name);
+ try testing.expectEqualStrings("10", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("TRansfer-encoding", header.name);
+ try testing.expectEqualStrings("deflate, chunked", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ {
+ const header = it.next().?;
+ try testing.expectEqualStrings("connectioN", header.name);
+ try testing.expectEqualStrings("keep-alive", header.value);
+ try testing.expect(!it.is_trailer);
+ }
+ try testing.expectEqual(null, it.next());
+ }
- try testing.expectEqual(.@"HTTP/1.1", res.version);
- try testing.expectEqualStrings("OK", res.reason);
- try testing.expectEqual(.ok, res.status);
+ inline fn int64(array: *const [8]u8) u64 {
+ return @bitCast(array.*);
+ }
- try testing.expectEqualStrings("url", res.location.?);
- try testing.expectEqualStrings("text/plain", res.content_type.?);
- try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?);
+ fn parseInt3(text: *const [3]u8) u10 {
+ const nnn: @Vector(3, u8) = text.*;
+ const zero: @Vector(3, u8) = .{ '0', '0', '0' };
+ const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
+ return @reduce(.Add, (nnn -% zero) *% mmm);
+ }
- try testing.expectEqual(true, res.keep_alive);
- try testing.expectEqual(10, res.content_length.?);
- try testing.expectEqual(.chunked, res.transfer_encoding);
- try testing.expectEqual(.deflate, res.transfer_compression);
- }
+ test parseInt3 {
+ const expectEqual = testing.expectEqual;
+ try expectEqual(@as(u10, 0), parseInt3("000"));
+ try expectEqual(@as(u10, 418), parseInt3("418"));
+ try expectEqual(@as(u10, 999), parseInt3("999"));
+ }
- inline fn int64(array: *const [8]u8) u64 {
- return @bitCast(array.*);
- }
+ /// Help the programmer avoid bugs by calling this when the string
+ /// memory of `Head` becomes invalidated.
+ fn invalidateStrings(h: *Head) void {
+ h.bytes = undefined;
+ h.reason = undefined;
+ if (h.location) |*s| s.* = undefined;
+ if (h.content_type) |*s| s.* = undefined;
+ if (h.content_disposition) |*s| s.* = undefined;
+ }
+ };
- fn parseInt3(text: *const [3]u8) u10 {
- const nnn: @Vector(3, u8) = text.*;
- const zero: @Vector(3, u8) = .{ '0', '0', '0' };
- const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
- return @reduce(.Add, (nnn -% zero) *% mmm);
+ /// If compressed body has been negotiated this will return compressed bytes.
+ ///
+ /// If the returned `Reader` returns `error.ReadFailed` the error is
+ /// available via `bodyErr`.
+ ///
+ /// Asserts that this function is only called once.
+ ///
+ /// See also:
+ /// * `readerDecompressing`
+ pub fn reader(response: *Response, buffer: []u8) *Reader {
+ response.head.invalidateStrings();
+ const req = response.request;
+ if (!req.method.responseHasBody()) return .ending;
+ const head = &response.head;
+ return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length);
}
- test parseInt3 {
- const expectEqual = testing.expectEqual;
- try expectEqual(@as(u10, 0), parseInt3("000"));
- try expectEqual(@as(u10, 418), parseInt3("418"));
- try expectEqual(@as(u10, 999), parseInt3("999"));
+ /// If compressed body has been negotiated this will return decompressed bytes.
+ ///
+ /// If the returned `Reader` returns `error.ReadFailed` the error is
+ /// available via `bodyErr`.
+ ///
+ /// Asserts that this function is only called once.
+ ///
+ /// See also:
+ /// * `reader`
+ pub fn readerDecompressing(
+ response: *Response,
+ decompressor: *http.Decompressor,
+ decompression_buffer: []u8,
+ ) *Reader {
+ response.head.invalidateStrings();
+ const head = &response.head;
+ return response.request.reader.bodyReaderDecompressing(
+ head.transfer_encoding,
+ head.content_length,
+ head.content_encoding,
+ decompressor,
+ decompression_buffer,
+ );
}
- pub fn iterateHeaders(r: Response) http.HeaderIterator {
- return .init(r.parser.get());
+ /// After receiving `error.ReadFailed` from the `Reader` returned by
+ /// `reader` or `readerDecompressing`, this function accesses the
+ /// more specific error code.
+ pub fn bodyErr(response: *const Response) ?http.Reader.BodyError {
+ return response.request.reader.body_err;
}
- test iterateHeaders {
- const response_bytes = "HTTP/1.1 200 OK\r\n" ++
- "LOcation:url\r\n" ++
- "content-tYpe: text/plain\r\n" ++
- "content-disposition:attachment; filename=example.txt \r\n" ++
- "content-Length:10\r\n" ++
- "TRansfer-encoding:\tdeflate, chunked \r\n" ++
- "connectioN:\t keep-alive \r\n\r\n";
-
- var header_buffer: [1024]u8 = undefined;
- var res = Response{
- .status = undefined,
- .reason = undefined,
- .version = undefined,
- .keep_alive = false,
- .parser = .init(&header_buffer),
+ pub fn iterateTrailers(response: *const Response) http.HeaderIterator {
+ const r = &response.request.reader;
+ assert(r.state == .ready);
+ return .{
+ .bytes = r.trailers,
+ .index = 0,
+ .is_trailer = true,
};
-
- @memcpy(header_buffer[0..response_bytes.len], response_bytes);
- res.parser.header_bytes_len = response_bytes.len;
-
- var it = res.iterateHeaders();
- {
- const header = it.next().?;
- try testing.expectEqualStrings("LOcation", header.name);
- try testing.expectEqualStrings("url", header.value);
- try testing.expect(!it.is_trailer);
- }
- {
- const header = it.next().?;
- try testing.expectEqualStrings("content-tYpe", header.name);
- try testing.expectEqualStrings("text/plain", header.value);
- try testing.expect(!it.is_trailer);
- }
- {
- const header = it.next().?;
- try testing.expectEqualStrings("content-disposition", header.name);
- try testing.expectEqualStrings("attachment; filename=example.txt", header.value);
- try testing.expect(!it.is_trailer);
- }
- {
- const header = it.next().?;
- try testing.expectEqualStrings("content-Length", header.name);
- try testing.expectEqualStrings("10", header.value);
- try testing.expect(!it.is_trailer);
- }
- {
- const header = it.next().?;
- try testing.expectEqualStrings("TRansfer-encoding", header.name);
- try testing.expectEqualStrings("deflate, chunked", header.value);
- try testing.expect(!it.is_trailer);
- }
- {
- const header = it.next().?;
- try testing.expectEqualStrings("connectioN", header.name);
- try testing.expectEqualStrings("keep-alive", header.value);
- try testing.expect(!it.is_trailer);
- }
- try testing.expectEqual(null, it.next());
}
};
-/// A HTTP request that has been sent.
-///
-/// Order of operations: open -> send[ -> write -> finish] -> wait -> read
pub const Request = struct {
+ /// This field is provided so that clients can observe redirected URIs.
+ ///
+ /// Its backing memory is externally provided by API users when creating a
+ /// request, and then again provided externally via `redirect_buffer` to
+ /// `receiveHead`.
uri: Uri,
client: *Client,
/// This is null when the connection is released.
connection: ?*Connection,
+ reader: http.Reader,
keep_alive: bool,
method: http.Method,
version: http.Version = .@"HTTP/1.1",
- transfer_encoding: RequestTransfer,
+ transfer_encoding: TransferEncoding,
redirect_behavior: RedirectBehavior,
+ accept_encoding: @TypeOf(default_accept_encoding) = default_accept_encoding,
/// Whether the request should handle a 100-continue response before sending the request body.
handle_continue: bool,
- /// The response associated with this request.
- ///
- /// This field is undefined until `wait` is called.
- response: Response,
-
/// Standard headers that have default, but overridable, behavior.
headers: Headers,
@@ -703,6 +790,20 @@ pub const Request = struct {
/// Externally-owned; must outlive the Request.
privileged_headers: []const http.Header,
+ pub const default_accept_encoding: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = b: {
+ var result: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = @splat(false);
+ result[@intFromEnum(http.ContentEncoding.gzip)] = true;
+ result[@intFromEnum(http.ContentEncoding.deflate)] = true;
+ result[@intFromEnum(http.ContentEncoding.identity)] = true;
+ break :b result;
+ };
+
+ pub const TransferEncoding = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+ };
+
pub const Headers = struct {
host: Value = .default,
authorization: Value = .default,
@@ -728,6 +829,11 @@ pub const Request = struct {
unhandled = std.math.maxInt(u16),
_,
+ pub fn init(n: u16) RedirectBehavior {
+ assert(n != std.math.maxInt(u16));
+ return @enumFromInt(n);
+ }
+
pub fn subtractOne(rb: *RedirectBehavior) void {
switch (rb.*) {
.not_allowed => unreachable,
@@ -742,98 +848,110 @@ pub const Request = struct {
}
};
- /// Frees all resources associated with the request.
- pub fn deinit(req: *Request) void {
- if (req.connection) |connection| {
- if (!req.response.parser.done) {
- // If the response wasn't fully read, then we need to close the connection.
- connection.closing = true;
- }
- req.client.connection_pool.release(req.client.allocator, connection);
+ /// Returns the request's `Connection` back to the pool of the `Client`.
+ pub fn deinit(r: *Request) void {
+ if (r.connection) |connection| {
+ connection.closing = connection.closing or switch (r.reader.state) {
+ .ready => false,
+ .received_head => r.method.requestHasBody(),
+ else => true,
+ };
+ r.client.connection_pool.release(connection);
}
- req.* = undefined;
+ r.* = undefined;
}
- // This function must deallocate all resources associated with the request,
- // or keep those which will be used.
- // This needs to be kept in sync with deinit and request.
- fn redirect(req: *Request, uri: Uri) !void {
- assert(req.response.parser.done);
-
- req.client.connection_pool.release(req.client.allocator, req.connection.?);
- req.connection = null;
-
- var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer);
- defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..];
- const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
-
- const new_host = valid_uri.host.?.raw;
- const prev_host = req.uri.host.?.raw;
- const keep_privileged_headers =
- std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and
- std.ascii.endsWithIgnoreCase(new_host, prev_host) and
- (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.');
- if (!keep_privileged_headers) {
- // When redirecting to a different domain, strip privileged headers.
- req.privileged_headers = &.{};
- }
-
- if (switch (req.response.status) {
- .see_other => true,
- .moved_permanently, .found => req.method == .POST,
- else => false,
- }) {
- // A redirect to a GET must change the method and remove the body.
- req.method = .GET;
- req.transfer_encoding = .none;
- req.headers.content_type = .omit;
- }
-
- if (req.transfer_encoding != .none) {
- // The request body has already been sent. The request is
- // still in a valid state, but the redirect must be handled
- // manually.
- return error.RedirectRequiresResend;
- }
+ /// Sends and flushes a complete request as only HTTP head, no body.
+ pub fn sendBodiless(r: *Request) Writer.Error!void {
+ try sendBodilessUnflushed(r);
+ try r.connection.?.flush();
+ }
- req.uri = valid_uri;
- req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol);
- req.redirect_behavior.subtractOne();
- req.response.parser.reset();
-
- req.response = .{
- .version = undefined,
- .status = undefined,
- .reason = undefined,
- .keep_alive = undefined,
- .parser = req.response.parser,
- };
+ /// Sends but does not flush a complete request as only HTTP head, no body.
+ pub fn sendBodilessUnflushed(r: *Request) Writer.Error!void {
+ assert(r.transfer_encoding == .none);
+ assert(!r.method.requestHasBody());
+ try sendHead(r);
}
- pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+ /// Transfers the HTTP head over the connection and flushes.
+ ///
+ /// See also:
+ /// * `sendBodyUnflushed`
+ pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
+ const result = try sendBodyUnflushed(r, buffer);
+ try r.connection.?.flush();
+ return result;
+ }
- /// Send the HTTP request headers to the server.
- pub fn send(req: *Request) SendError!void {
- if (!req.method.requestHasBody() and req.transfer_encoding != .none)
- return error.UnsupportedTransferEncoding;
+ /// Transfers the HTTP head and body over the connection and flushes.
+ pub fn sendBodyComplete(r: *Request, body: []u8) Writer.Error!void {
+ r.transfer_encoding = .{ .content_length = body.len };
+ var bw = try sendBodyUnflushed(r, body);
+ bw.writer.end = body.len;
+ try bw.end();
+ try r.connection.?.flush();
+ }
- const connection = req.connection.?;
- var connection_writer_adapter = connection.writer().adaptToNewApi();
- const w = &connection_writer_adapter.new_interface;
- sendAdapted(req, connection, w) catch |err| switch (err) {
- error.WriteFailed => return connection_writer_adapter.err.?,
- else => |e| return e,
+ /// Transfers the HTTP head over the connection, which is not flushed until
+ /// `BodyWriter.flush` or `BodyWriter.end` is called.
+ ///
+ /// See also:
+ /// * `sendBody`
+ pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter {
+ assert(r.method.requestHasBody());
+ try sendHead(r);
+ const http_protocol_output = r.connection.?.writer();
+ return switch (r.transfer_encoding) {
+ .chunked => .{
+ .http_protocol_output = http_protocol_output,
+ .state = .{ .chunked = .init },
+ .writer = .{
+ .buffer = buffer,
+ .vtable = &.{
+ .drain = http.BodyWriter.chunkedDrain,
+ .sendFile = http.BodyWriter.chunkedSendFile,
+ },
+ },
+ },
+ .content_length => |len| .{
+ .http_protocol_output = http_protocol_output,
+ .state = .{ .content_length = len },
+ .writer = .{
+ .buffer = buffer,
+ .vtable = &.{
+ .drain = http.BodyWriter.contentLengthDrain,
+ .sendFile = http.BodyWriter.contentLengthSendFile,
+ },
+ },
+ },
+ .none => .{
+ .http_protocol_output = http_protocol_output,
+ .state = .none,
+ .writer = .{
+ .buffer = buffer,
+ .vtable = &.{
+ .drain = http.BodyWriter.noneDrain,
+ .sendFile = http.BodyWriter.noneSendFile,
+ },
+ },
+ },
};
}
- fn sendAdapted(req: *Request, connection: *Connection, w: *std.io.Writer) !void {
- try req.method.format(w);
+ /// Sends HTTP headers without flushing.
+ fn sendHead(r: *Request) Writer.Error!void {
+ const uri = r.uri;
+ const connection = r.connection.?;
+ const w = connection.writer();
+
+ try w.writeAll(@tagName(r.method));
try w.writeByte(' ');
- if (req.method == .CONNECT) {
- try req.uri.writeToStream(w, .{ .authority = true });
+ if (r.method == .CONNECT) {
+ try uri.writeToStream(w, .{ .authority = true });
} else {
- try req.uri.writeToStream(w, .{
+ try uri.writeToStream(w, .{
.scheme = connection.proxied,
.authentication = connection.proxied,
.authority = connection.proxied,
@@ -842,58 +960,64 @@ pub const Request = struct {
});
}
try w.writeByte(' ');
- try w.writeAll(@tagName(req.version));
+ try w.writeAll(@tagName(r.version));
try w.writeAll("\r\n");
- if (try emitOverridableHeader("host: ", req.headers.host, w)) {
+ if (try emitOverridableHeader("host: ", r.headers.host, w)) {
try w.writeAll("host: ");
- try req.uri.writeToStream(w, .{ .authority = true });
+ try uri.writeToStream(w, .{ .authority = true });
try w.writeAll("\r\n");
}
- if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) {
- if (req.uri.user != null or req.uri.password != null) {
+ if (try emitOverridableHeader("authorization: ", r.headers.authorization, w)) {
+ if (uri.user != null or uri.password != null) {
try w.writeAll("authorization: ");
- const authorization = try connection.allocWriteBuffer(
- @intCast(basic_authorization.valueLengthFromUri(req.uri)),
- );
- assert(basic_authorization.value(req.uri, authorization).len == authorization.len);
+ try basic_authorization.write(uri, w);
try w.writeAll("\r\n");
}
}
- if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) {
+ if (try emitOverridableHeader("user-agent: ", r.headers.user_agent, w)) {
try w.writeAll("user-agent: zig/");
try w.writeAll(builtin.zig_version_string);
try w.writeAll(" (std.http)\r\n");
}
- if (try emitOverridableHeader("connection: ", req.headers.connection, w)) {
- if (req.keep_alive) {
+ if (try emitOverridableHeader("connection: ", r.headers.connection, w)) {
+ if (r.keep_alive) {
try w.writeAll("connection: keep-alive\r\n");
} else {
try w.writeAll("connection: close\r\n");
}
}
- if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) {
- // https://github.com/ziglang/zig/issues/18937
- //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n");
- try w.writeAll("accept-encoding: gzip, deflate\r\n");
+ if (try emitOverridableHeader("accept-encoding: ", r.headers.accept_encoding, w)) {
+ try w.writeAll("accept-encoding: ");
+ for (r.accept_encoding, 0..) |enabled, i| {
+ if (!enabled) continue;
+ const tag: http.ContentEncoding = @enumFromInt(i);
+ if (tag == .identity) continue;
+ const tag_name = @tagName(tag);
+ try w.ensureUnusedCapacity(tag_name.len + 2);
+ try w.writeAll(tag_name);
+ try w.writeAll(", ");
+ }
+ w.undo(2);
+ try w.writeAll("\r\n");
}
- switch (req.transfer_encoding) {
+ switch (r.transfer_encoding) {
.chunked => try w.writeAll("transfer-encoding: chunked\r\n"),
.content_length => |len| try w.print("content-length: {d}\r\n", .{len}),
.none => {},
}
- if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) {
+ if (try emitOverridableHeader("content-type: ", r.headers.content_type, w)) {
// The default is to omit content-type if not provided because
// "application/octet-stream" is redundant.
}
- for (req.extra_headers) |header| {
+ for (r.extra_headers) |header| {
assert(header.name.len != 0);
try w.writeAll(header.name);
@@ -904,8 +1028,8 @@ pub const Request = struct {
if (connection.proxied) proxy: {
const proxy = switch (connection.protocol) {
- .plain => req.client.http_proxy,
- .tls => req.client.https_proxy,
+ .plain => r.client.http_proxy,
+ .tls => r.client.https_proxy,
} orelse break :proxy;
const authorization = proxy.authorization orelse break :proxy;
@@ -915,282 +1039,200 @@ pub const Request = struct {
}
try w.writeAll("\r\n");
-
- try connection.flush();
- }
-
- /// Returns true if the default behavior is required, otherwise handles
- /// writing (or not writing) the header.
- fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool {
- switch (v) {
- .default => return true,
- .omit => return false,
- .override => |x| {
- try w.writeAll(prefix);
- try w.writeAll(x);
- try w.writeAll("\r\n");
- return false;
- },
- }
}
- const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
-
- const TransferReader = std.io.GenericReader(*Request, TransferReadError, transferRead);
-
- fn transferReader(req: *Request) TransferReader {
- return .{ .context = req };
- }
-
- fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
- if (req.response.parser.done) return 0;
-
- var index: usize = 0;
- while (index == 0) {
- const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
- if (amt == 0 and req.response.parser.done) break;
- index += amt;
- }
-
- return index;
- }
+ pub const ReceiveHeadError = http.Reader.HeadError || ConnectError || error{
+ /// Server sent headers that did not conform to the HTTP protocol.
+ ///
+ /// To find out more detailed diagnostics, `http.Reader.head_buffer` can be
+ /// passed directly to `Request.Head.parse`.
+ HttpHeadersInvalid,
+ TooManyHttpRedirects,
+ /// This can be avoided by calling `receiveHead` before sending the
+ /// request body.
+ RedirectRequiresResend,
+ HttpRedirectLocationMissing,
+ HttpRedirectLocationOversize,
+ HttpRedirectLocationInvalid,
+ HttpContentEncodingUnsupported,
+ HttpChunkInvalid,
+ HttpChunkTruncated,
+ HttpHeadersOversize,
+ UnsupportedUriScheme,
- pub const WaitError = RequestError || SendError || TransferReadError ||
- proto.HeadersParser.CheckCompleteHeadError || Response.ParseError ||
- error{
- TooManyHttpRedirects,
- RedirectRequiresResend,
- HttpRedirectLocationMissing,
- HttpRedirectLocationInvalid,
- CompressionInitializationFailed,
- CompressionUnsupported,
- };
+ /// Sending the request failed. Error code can be found on the
+ /// `Connection` object.
+ WriteFailed,
+ };
- /// Waits for a response from the server and parses any headers that are sent.
- /// This function will block until the final response is received.
- ///
/// If handling redirects and the request has no payload, then this
- /// function will automatically follow redirects. If a request payload is
- /// present, then this function will error with
- /// error.RedirectRequiresResend.
+ /// function will automatically follow redirects.
+ ///
+ /// If a request payload is present, then this function will error with
+ /// `error.RedirectRequiresResend`.
+ ///
+ /// This function takes an auxiliary buffer to store the arbitrarily large
+ /// URI which may need to be merged with the previous URI, and that data
+ /// needs to survive across different connections, which is where the input
+ /// buffer lives.
///
- /// Must be called after `send` and, if any data was written to the request
- /// body, then also after `finish`.
- pub fn wait(req: *Request) WaitError!void {
+ /// `redirect_buffer` must outlive accesses to `Request.uri`. If this
+ /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize`
+ /// is returned instead. This buffer may be empty if no redirects are to be
+ /// handled.
+ ///
+ /// If this fails with `error.ReadFailed` then the `Connection.getReadError`
+ /// method of `r.connection` can be used to get more detailed information.
+ pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response {
+ var aux_buf = redirect_buffer;
while (true) {
- // This while loop is for handling redirects, which means the request's
- // connection may be different than the previous iteration. However, it
- // is still guaranteed to be non-null with each iteration of this loop.
- const connection = req.connection.?;
-
- while (true) { // read headers
- try connection.fill();
-
- const nchecked = try req.response.parser.checkCompleteHead(connection.peek());
- connection.drop(@intCast(nchecked));
+ const head_buffer = try r.reader.receiveHead();
+ const response: Response = .{
+ .request = r,
+ .head = Response.Head.parse(head_buffer) catch return error.HttpHeadersInvalid,
+ };
+ const head = &response.head;
- if (req.response.parser.state.isContent()) break;
+ if (head.status == .@"continue") {
+ if (r.handle_continue) continue;
+ return response; // we're not handling the 100-continue
}
- try req.response.parse(req.response.parser.get());
-
- if (req.response.status == .@"continue") {
- // We're done parsing the continue response; reset to prepare
- // for the real response.
- req.response.parser.done = true;
- req.response.parser.reset();
-
- if (req.handle_continue)
- continue;
-
- return; // we're not handling the 100-continue
- }
+ // This while loop is for handling redirects, which means the request's
+ // connection may be different than the previous iteration. However, it
+ // is still guaranteed to be non-null with each iteration of this loop.
+ const connection = r.connection.?;
- // we're switching protocols, so this connection is no longer doing http
- if (req.method == .CONNECT and req.response.status.class() == .success) {
+ if (r.method == .CONNECT and head.status.class() == .success) {
+ // This connection is no longer doing HTTP.
connection.closing = false;
- req.response.parser.done = true;
- return; // the connection is not HTTP past this point
+ return response;
}
- connection.closing = !req.response.keep_alive or !req.keep_alive;
+ connection.closing = !head.keep_alive or !r.keep_alive;
// Any response to a HEAD request and any response with a 1xx
// (Informational), 204 (No Content), or 304 (Not Modified) status
// code is always terminated by the first empty line after the
// header fields, regardless of the header fields present in the
// message.
- if (req.method == .HEAD or req.response.status.class() == .informational or
- req.response.status == .no_content or req.response.status == .not_modified)
+ if (r.method == .HEAD or head.status.class() == .informational or
+ head.status == .no_content or head.status == .not_modified)
{
- req.response.parser.done = true;
- return; // The response is empty; no further setup or redirection is necessary.
- }
-
- switch (req.response.transfer_encoding) {
- .none => {
- if (req.response.content_length) |cl| {
- req.response.parser.next_chunk_length = cl;
-
- if (cl == 0) req.response.parser.done = true;
- } else {
- // read until the connection is closed
- req.response.parser.next_chunk_length = std.math.maxInt(u64);
- }
- },
- .chunked => {
- req.response.parser.next_chunk_length = 0;
- req.response.parser.state = .chunk_head_size;
- },
+ return response;
}
- if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
- // skip the body of the redirect response, this will at least
- // leave the connection in a known good state.
- req.response.skip = true;
- assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary
-
- if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects;
-
- const location = req.response.location orelse
- return error.HttpRedirectLocationMissing;
-
- // This mutates the beginning of header_bytes_buffer and uses that
- // for the backing memory of the returned Uri.
- try req.redirect(req.uri.resolve_inplace(
- location,
- &req.response.parser.header_bytes_buffer,
- ) catch |err| switch (err) {
- error.UnexpectedCharacter,
- error.InvalidFormat,
- error.InvalidPort,
- => return error.HttpRedirectLocationInvalid,
- error.NoSpaceLeft => return error.HttpHeadersOversize,
- });
- try req.send();
- } else {
- req.response.skip = false;
- if (!req.response.parser.done) {
- switch (req.response.transfer_compression) {
- .identity => req.response.compression = .none,
- .compress, .@"x-compress" => return error.CompressionUnsupported,
- // I'm about to upstream my http.Client rewrite
- .deflate => return error.CompressionUnsupported,
- // I'm about to upstream my http.Client rewrite
- .gzip, .@"x-gzip" => return error.CompressionUnsupported,
- // https://github.com/ziglang/zig/issues/18937
- //.zstd => req.response.compression = .{
- // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
- //},
- .zstd => return error.CompressionUnsupported,
- }
+ if (head.status.class() == .redirect and r.redirect_behavior != .unhandled) {
+ if (r.redirect_behavior == .not_allowed) {
+ // Connection can still be reused by skipping the body.
+ const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length);
+ _ = reader.discardRemaining() catch |err| switch (err) {
+ error.ReadFailed => connection.closing = true,
+ };
+ return error.TooManyHttpRedirects;
}
-
- break;
+ try r.redirect(head, &aux_buf);
+ try r.sendBodiless();
+ continue;
}
- }
- }
-
- pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError ||
- error{ DecompressionFailure, InvalidTrailers };
- pub const Reader = std.io.GenericReader(*Request, ReadError, read);
+ if (!r.accept_encoding[@intFromEnum(head.content_encoding)])
+ return error.HttpContentEncodingUnsupported;
- pub fn reader(req: *Request) Reader {
- return .{ .context = req };
- }
-
- /// Reads data from the response body. Must be called after `wait`.
- pub fn read(req: *Request, buffer: []u8) ReadError!usize {
- const out_index = switch (req.response.compression) {
- // I'm about to upstream my http client rewrite
- //.deflate => |*deflate| deflate.readSlice(buffer) catch return error.DecompressionFailure,
- //.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
- // https://github.com/ziglang/zig/issues/18937
- //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
- else => try req.transferRead(buffer),
- };
- if (out_index > 0) return out_index;
-
- while (!req.response.parser.state.isContent()) { // read trailing headers
- try req.connection.?.fill();
-
- const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
- req.connection.?.drop(@intCast(nchecked));
+ return response;
}
-
- return 0;
}
- /// Reads data from the response body. Must be called after `wait`.
- pub fn readAll(req: *Request, buffer: []u8) !usize {
- var index: usize = 0;
- while (index < buffer.len) {
- const amt = try read(req, buffer[index..]);
- if (amt == 0) break;
- index += amt;
+ /// This function takes an auxiliary buffer to store the arbitrarily large
+ /// URI which may need to be merged with the previous URI, and that data
+ /// needs to survive across different connections, which is where the input
+ /// buffer lives.
+ ///
+ /// `aux_buf` must outlive accesses to `Request.uri`.
+ fn redirect(r: *Request, head: *const Response.Head, aux_buf: *[]u8) !void {
+ const new_location = head.location orelse return error.HttpRedirectLocationMissing;
+ if (new_location.len > aux_buf.*.len) return error.HttpRedirectLocationOversize;
+ const location = aux_buf.*[0..new_location.len];
+ @memcpy(location, new_location);
+ {
+ // Skip the body of the redirect response to leave the connection in
+ // the correct state. This causes `new_location` to be invalidated.
+ const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length);
+ _ = reader.discardRemaining() catch |err| switch (err) {
+ error.ReadFailed => return r.reader.body_err.?,
+ };
}
- return index;
- }
-
- pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
-
- pub const Writer = std.io.GenericWriter(*Request, WriteError, write);
+ const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) {
+ error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid,
+ error.InvalidFormat => return error.HttpRedirectLocationInvalid,
+ error.InvalidPort => return error.HttpRedirectLocationInvalid,
+ error.NoSpaceLeft => return error.HttpRedirectLocationOversize,
+ };
- pub fn writer(req: *Request) Writer {
- return .{ .context = req };
- }
+ const protocol = Protocol.fromUri(new_uri) orelse return error.UnsupportedUriScheme;
+ const old_connection = r.connection.?;
+ const old_host = old_connection.host();
+ var new_host_name_buffer: [Uri.host_name_max]u8 = undefined;
+ const new_host = try new_uri.getHost(&new_host_name_buffer);
+ const keep_privileged_headers =
+ std.ascii.eqlIgnoreCase(r.uri.scheme, new_uri.scheme) and
+ sameParentDomain(old_host, new_host);
- /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent.
- /// Must be called after `send` and before `finish`.
- pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
- switch (req.transfer_encoding) {
- .chunked => {
- if (bytes.len > 0) {
- try req.connection.?.writer().print("{x}\r\n", .{bytes.len});
- try req.connection.?.writer().writeAll(bytes);
- try req.connection.?.writer().writeAll("\r\n");
- }
+ r.client.connection_pool.release(old_connection);
+ r.connection = null;
- return bytes.len;
- },
- .content_length => |*len| {
- if (len.* < bytes.len) return error.MessageTooLong;
+ if (!keep_privileged_headers) {
+ // When redirecting to a different domain, strip privileged headers.
+ r.privileged_headers = &.{};
+ }
- const amt = try req.connection.?.write(bytes);
- len.* -= amt;
- return amt;
- },
- .none => return error.NotWriteable,
+ if (switch (head.status) {
+ .see_other => true,
+ .moved_permanently, .found => r.method == .POST,
+ else => false,
+ }) {
+ // A redirect to a GET must change the method and remove the body.
+ r.method = .GET;
+ r.transfer_encoding = .none;
+ r.headers.content_type = .omit;
}
- }
- /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent.
- /// Must be called after `send` and before `finish`.
- pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try write(req, bytes[index..]);
+ if (r.transfer_encoding != .none) {
+ // The request body has already been sent. The request is
+ // still in a valid state, but the redirect must be handled
+ // manually.
+ return error.RedirectRequiresResend;
}
- }
- pub const FinishError = WriteError || error{MessageNotCompleted};
+ const new_connection = try r.client.connect(new_host, uriPort(new_uri, protocol), protocol);
+ r.uri = new_uri;
+ r.connection = new_connection;
+ r.reader = .{
+ .in = new_connection.reader(),
+ .state = .ready,
+ // Populated when `http.Reader.bodyReader` is called.
+ .interface = undefined,
+ };
+ r.redirect_behavior.subtractOne();
+ }
- /// Finish the body of a request. This notifies the server that you have no more data to send.
- /// Must be called after `send`.
- pub fn finish(req: *Request) FinishError!void {
- switch (req.transfer_encoding) {
- .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"),
- .content_length => |len| if (len != 0) return error.MessageNotCompleted,
- .none => {},
+ /// Returns true if the default behavior is required, otherwise handles
+ /// writing (or not writing) the header.
+ fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, bw: *Writer) Writer.Error!bool {
+ switch (v) {
+ .default => return true,
+ .omit => return false,
+ .override => |x| {
+ var vecs: [3][]const u8 = .{ prefix, x, "\r\n" };
+ try bw.writeVecAll(&vecs);
+ return false;
+ },
}
-
- try req.connection.?.flush();
}
};
pub const Proxy = struct {
- protocol: Connection.Protocol,
+ protocol: Protocol,
host: []const u8,
authorization: ?[]const u8,
port: u16,
@@ -1204,10 +1246,8 @@ pub const Proxy = struct {
pub fn deinit(client: *Client) void {
assert(client.connection_pool.used.first == null); // There are still active requests.
- client.connection_pool.deinit(client.allocator);
-
- if (!disable_tls)
- client.ca_bundle.deinit(client.allocator);
+ client.connection_pool.deinit();
+ if (!disable_tls) client.ca_bundle.deinit(client.allocator);
client.* = undefined;
}
@@ -1249,24 +1289,21 @@ fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?
} else return null;
const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content);
- const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) {
- error.UnsupportedUriScheme => return null,
- error.UriMissingHost => return error.HttpProxyMissingHost,
- error.OutOfMemory => |e| return e,
- };
+ const protocol = Protocol.fromUri(uri) orelse return null;
+ const raw_host = try uri.getHostAlloc(arena);
- const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: {
- const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri));
- assert(basic_authorization.value(valid_uri, authorization).len == authorization.len);
+ const authorization: ?[]const u8 = if (uri.user != null or uri.password != null) a: {
+ const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri));
+ assert(basic_authorization.value(uri, authorization).len == authorization.len);
break :a authorization;
} else null;
const proxy = try arena.create(Proxy);
proxy.* = .{
.protocol = protocol,
- .host = valid_uri.host.?.raw,
+ .host = raw_host,
.authorization = authorization,
- .port = uriPort(valid_uri, protocol),
+ .port = uriPort(uri, protocol),
.supports_connect = true,
};
return proxy;
@@ -1277,10 +1314,8 @@ pub const basic_authorization = struct {
pub const max_password_len = 255;
pub const max_value_len = valueLength(max_user_len, max_password_len);
- const prefix = "Basic ";
-
pub fn valueLength(user_len: usize, password_len: usize) usize {
- return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len);
+ return "Basic ".len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len);
}
pub fn valueLengthFromUri(uri: Uri) usize {
@@ -1300,37 +1335,70 @@ pub const basic_authorization = struct {
}
pub fn value(uri: Uri, out: []u8) []u8 {
- const user: Uri.Component = uri.user orelse .empty;
- const password: Uri.Component = uri.password orelse .empty;
-
- var buf: [max_user_len + ":".len + max_password_len]u8 = undefined;
- var w: std.io.Writer = .fixed(&buf);
- user.formatUser(&w) catch unreachable; // fixed
- password.formatPassword(&w) catch unreachable; // fixed
+ var bw: Writer = .fixed(out);
+ write(uri, &bw) catch unreachable;
+ return bw.buffered();
+ }
- @memcpy(out[0..prefix.len], prefix);
- const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], w.buffered());
- return out[0 .. prefix.len + base64.len];
+ pub fn write(uri: Uri, out: *Writer) Writer.Error!void {
+ var buf: [max_user_len + 1 + max_password_len]u8 = undefined;
+ var w: Writer = .fixed(&buf);
+ const user: Uri.Component = uri.user orelse .empty;
+ const password: Uri.Component = uri.user orelse .empty;
+ user.formatUser(&w) catch unreachable;
+ w.writeByte(':') catch unreachable;
+ password.formatPassword(&w) catch unreachable;
+ try out.print("Basic {b64}", .{w.buffered()});
}
};
-pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
+pub const ConnectTcpError = Allocator.Error || error{
+ ConnectionRefused,
+ NetworkUnreachable,
+ ConnectionTimedOut,
+ ConnectionResetByPeer,
+ TemporaryNameServerFailure,
+ NameServerFailure,
+ UnknownHostName,
+ HostLacksNetworkAddresses,
+ UnexpectedConnectFailure,
+ TlsInitializationFailed,
+};
-/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
+/// Reuses a `Connection` if one matching `host` and `port` is already open.
///
-/// This function is threadsafe.
-pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
- if (client.connection_pool.findConnection(.{
- .host = host,
- .port = port,
- .protocol = protocol,
- })) |node| return node;
+/// Threadsafe.
+pub fn connectTcp(
+ client: *Client,
+ host: []const u8,
+ port: u16,
+ protocol: Protocol,
+) ConnectTcpError!*Connection {
+ return connectTcpOptions(client, .{ .host = host, .port = port, .protocol = protocol });
+}
+
+pub const ConnectTcpOptions = struct {
+ host: []const u8,
+ port: u16,
+ protocol: Protocol,
- if (disable_tls and protocol == .tls)
- return error.TlsInitializationFailed;
+ proxied_host: ?[]const u8 = null,
+ proxied_port: ?u16 = null,
+};
- const conn = try client.allocator.create(Connection);
- errdefer client.allocator.destroy(conn);
+pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcpError!*Connection {
+ const host = options.host;
+ const port = options.port;
+ const protocol = options.protocol;
+
+ const proxied_host = options.proxied_host orelse host;
+ const proxied_port = options.proxied_port orelse port;
+
+ if (client.connection_pool.findConnection(.{
+ .host = proxied_host,
+ .port = proxied_port,
+ .protocol = protocol,
+ })) |conn| return conn;
const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
error.ConnectionRefused => return error.ConnectionRefused,
@@ -1345,53 +1413,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
};
errdefer stream.close();
- conn.* = .{
- .stream = stream,
- .tls_client = undefined,
-
- .protocol = protocol,
- .host = try client.allocator.dupe(u8, host),
- .port = port,
-
- .pool_node = .{},
- };
- errdefer client.allocator.free(conn.host);
-
- if (protocol == .tls) {
- if (disable_tls) unreachable;
-
- conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
- errdefer client.allocator.destroy(conn.tls_client);
-
- const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: {
- const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) {
- error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null,
- error.OutOfMemory => return error.OutOfMemory,
- };
- defer client.allocator.free(ssl_key_log_path);
- break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{
- .truncate = false,
- .mode = switch (builtin.os.tag) {
- .windows, .wasi => 0,
- else => 0o600,
- },
- }) catch null;
- } else null;
- errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close();
-
- conn.tls_client.* = std.crypto.tls.Client.init(stream, .{
- .host = .{ .explicit = host },
- .ca = .{ .bundle = client.ca_bundle },
- .ssl_key_log_file = ssl_key_log_file,
- }) catch return error.TlsInitializationFailed;
- // This is appropriate for HTTPS because the HTTP headers contain
- // the content length which is used to detect truncation attacks.
- conn.tls_client.allow_truncation_attacks = true;
+ switch (protocol) {
+ .tls => {
+ if (disable_tls) return error.TlsInitializationFailed;
+ const tc = try Connection.Tls.create(client, proxied_host, proxied_port, stream);
+ client.connection_pool.addUsed(&tc.connection);
+ return &tc.connection;
+ },
+ .plain => {
+ const pc = try Connection.Plain.create(client, proxied_host, proxied_port, stream);
+ client.connection_pool.addUsed(&pc.connection);
+ return &pc.connection;
+ },
}
-
- client.connection_pool.addUsed(conn);
-
- return conn;
}
pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError;
@@ -1429,69 +1463,67 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti
return &conn.data;
}
-/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP
+/// Connect to `proxied_host:proxied_port` using the specified proxy with HTTP
/// CONNECT. This will reuse a connection if one is already open.
///
/// This function is threadsafe.
-pub fn connectTunnel(
+pub fn connectProxied(
client: *Client,
proxy: *Proxy,
- tunnel_host: []const u8,
- tunnel_port: u16,
+ proxied_host: []const u8,
+ proxied_port: u16,
) !*Connection {
if (!proxy.supports_connect) return error.TunnelNotSupported;
if (client.connection_pool.findConnection(.{
- .host = tunnel_host,
- .port = tunnel_port,
+ .host = proxied_host,
+ .port = proxied_port,
.protocol = proxy.protocol,
- })) |node|
- return node;
+ })) |node| return node;
var maybe_valid = false;
(tunnel: {
- const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
+ const connection = try client.connectTcpOptions(.{
+ .host = proxy.host,
+ .port = proxy.port,
+ .protocol = proxy.protocol,
+ .proxied_host = proxied_host,
+ .proxied_port = proxied_port,
+ });
errdefer {
- conn.closing = true;
- client.connection_pool.release(client.allocator, conn);
+ connection.closing = true;
+ client.connection_pool.release(connection);
}
- var buffer: [8096]u8 = undefined;
- var req = client.open(.CONNECT, .{
+ var req = client.request(.CONNECT, .{
.scheme = "http",
- .host = .{ .raw = tunnel_host },
- .port = tunnel_port,
+ .host = .{ .raw = proxied_host },
+ .port = proxied_port,
}, .{
.redirect_behavior = .unhandled,
- .connection = conn,
- .server_header_buffer = &buffer,
+ .connection = connection,
}) catch |err| {
- std.log.debug("err {}", .{err});
break :tunnel err;
};
defer req.deinit();
- req.send() catch |err| break :tunnel err;
- req.wait() catch |err| break :tunnel err;
+ req.sendBodiless() catch |err| break :tunnel err;
+ const response = req.receiveHead(&.{}) catch |err| break :tunnel err;
- if (req.response.status.class() == .server_error) {
+ if (response.head.status.class() == .server_error) {
maybe_valid = true;
break :tunnel error.ServerError;
}
- if (req.response.status != .ok) break :tunnel error.ConnectionRefused;
+ if (response.head.status != .ok) break :tunnel error.ConnectionRefused;
- // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized.
+ // this connection is now a tunnel, so we can't use it for anything
+ // else, it will only be released when the client is de-initialized.
req.connection = null;
- client.allocator.free(conn.host);
- conn.host = try client.allocator.dupe(u8, tunnel_host);
- errdefer client.allocator.free(conn.host);
+ connection.closing = false;
- conn.port = tunnel_port;
- conn.closing = false;
-
- return conn;
+ return connection;
}) catch {
// something went wrong with the tunnel
proxy.supports_connect = maybe_valid;
@@ -1499,12 +1531,11 @@ pub fn connectTunnel(
};
}
-// Prevents a dependency loop in open()
-const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused };
-pub const ConnectError = ConnectErrorPartial || RequestError;
+pub const ConnectError = ConnectTcpError || RequestError;
/// Connect to `host:port` using the specified protocol. This will reuse a
/// connection if one is already open.
+///
/// If a proxy is configured for the client, then the proxy will be used to
/// connect to the host.
///
@@ -1513,7 +1544,7 @@ pub fn connect(
client: *Client,
host: []const u8,
port: u16,
- protocol: Connection.Protocol,
+ protocol: Protocol,
) ConnectError!*Connection {
const proxy = switch (protocol) {
.plain => client.http_proxy,
@@ -1528,32 +1559,24 @@ pub fn connect(
}
if (proxy.supports_connect) tunnel: {
- return connectTunnel(client, proxy, host, port) catch |err| switch (err) {
+ return connectProxied(client, proxy, host, port) catch |err| switch (err) {
error.TunnelNotSupported => break :tunnel,
else => |e| return e,
};
}
// fall back to using the proxy as a normal http proxy
- const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
- errdefer {
- conn.closing = true;
- client.connection_pool.release(conn);
- }
-
- conn.proxied = true;
- return conn;
+ const connection = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
+ connection.proxied = true;
+ return connection;
}
-pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError ||
- std.fmt.ParseIntError || Connection.WriteError ||
- error{
- UnsupportedUriScheme,
- UriMissingHost,
-
- CertificateBundleLoadFailure,
- UnsupportedTransferEncoding,
- };
+pub const RequestError = ConnectTcpError || error{
+ UnsupportedUriScheme,
+ UriMissingHost,
+ UriHostTooLong,
+ CertificateBundleLoadFailure,
+};
pub const RequestOptions = struct {
version: http.Version = .@"HTTP/1.1",
@@ -1578,11 +1601,6 @@ pub const RequestOptions = struct {
/// payload or the server has acknowledged the payload).
redirect_behavior: Request.RedirectBehavior = @enumFromInt(3),
- /// Externally-owned memory used to store the server's entire HTTP header.
- /// `error.HttpHeadersOversize` is returned from read() when a
- /// client sends too many bytes of HTTP headers.
- server_header_buffer: []u8,
-
/// Must be an already acquired connection.
connection: ?*Connection = null,
@@ -1598,38 +1616,17 @@ pub const RequestOptions = struct {
privileged_headers: []const http.Header = &.{},
};
-fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } {
- const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{
- .{ "http", .plain },
- .{ "ws", .plain },
- .{ "https", .tls },
- .{ "wss", .tls },
- });
- const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme;
- var valid_uri = uri;
- // The host is always going to be needed as a raw string for hostname resolution anyway.
- valid_uri.host = .{
- .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena),
- };
- return .{ protocol, valid_uri };
-}
-
-fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 {
- return uri.port orelse switch (protocol) {
- .plain => 80,
- .tls => 443,
- };
+fn uriPort(uri: Uri, protocol: Protocol) u16 {
+ return uri.port orelse protocol.port();
}
/// Open a connection to the host specified by `uri` and prepare to send a HTTP request.
///
-/// `uri` must remain alive during the entire request.
-///
/// The caller is responsible for calling `deinit()` on the `Request`.
/// This function is threadsafe.
///
/// Asserts that "\r\n" does not occur in any header name or value.
-pub fn open(
+pub fn request(
client: *Client,
method: http.Method,
uri: Uri,
@@ -1649,59 +1646,58 @@ pub fn open(
}
}
- var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer);
- const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
+ const protocol = Protocol.fromUri(uri) orelse return error.UnsupportedUriScheme;
- if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
+ if (protocol == .tls) {
if (disable_tls) unreachable;
-
- client.ca_bundle_mutex.lock();
- defer client.ca_bundle_mutex.unlock();
-
- if (client.next_https_rescan_certs) {
- client.ca_bundle.rescan(client.allocator) catch
- return error.CertificateBundleLoadFailure;
- @atomicStore(bool, &client.next_https_rescan_certs, false, .release);
+ if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
+ client.ca_bundle_mutex.lock();
+ defer client.ca_bundle_mutex.unlock();
+
+ if (client.next_https_rescan_certs) {
+ client.ca_bundle.rescan(client.allocator) catch
+ return error.CertificateBundleLoadFailure;
+ @atomicStore(bool, &client.next_https_rescan_certs, false, .release);
+ }
}
}
- const conn = options.connection orelse
- try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol);
+ const connection = options.connection orelse c: {
+ var host_name_buffer: [Uri.host_name_max]u8 = undefined;
+ const host_name = try uri.getHost(&host_name_buffer);
+ break :c try client.connect(host_name, uriPort(uri, protocol), protocol);
+ };
- var req: Request = .{
- .uri = valid_uri,
+ return .{
+ .uri = uri,
.client = client,
- .connection = conn,
+ .connection = connection,
+ .reader = .{
+ .in = connection.reader(),
+ .state = .ready,
+ // Populated when `http.Reader.bodyReader` is called.
+ .interface = undefined,
+ },
.keep_alive = options.keep_alive,
.method = method,
.version = options.version,
.transfer_encoding = .none,
.redirect_behavior = options.redirect_behavior,
.handle_continue = options.handle_continue,
- .response = .{
- .version = undefined,
- .status = undefined,
- .reason = undefined,
- .keep_alive = undefined,
- .parser = .init(server_header.buffer[server_header.end_index..]),
- },
.headers = options.headers,
.extra_headers = options.extra_headers,
.privileged_headers = options.privileged_headers,
};
- errdefer req.deinit();
-
- return req;
}
pub const FetchOptions = struct {
- server_header_buffer: ?[]u8 = null,
+ /// `null` means it will be heap-allocated.
+ redirect_buffer: ?[]u8 = null,
+ /// `null` means it will be heap-allocated.
+ decompress_buffer: ?[]u8 = null,
redirect_behavior: ?Request.RedirectBehavior = null,
-
- /// If the server sends a body, it will be appended to this ArrayList.
- /// `max_append_size` provides an upper limit for how much they can grow.
- response_storage: ResponseStorage = .ignore,
- max_append_size: ?usize = null,
+ /// If the server sends a body, it will be stored here.
+ response_storage: ?ResponseStorage = null,
location: Location,
method: ?http.Method = null,
@@ -1725,11 +1721,11 @@ pub const FetchOptions = struct {
uri: Uri,
};
- pub const ResponseStorage = union(enum) {
- ignore,
- /// Only the existing capacity will be used.
- static: *std.ArrayListUnmanaged(u8),
- dynamic: *std.ArrayList(u8),
+ pub const ResponseStorage = struct {
+ list: *std.ArrayListUnmanaged(u8),
+ /// If null then only the existing capacity will be used.
+ allocator: ?Allocator = null,
+ append_limit: std.io.Limit = .unlimited,
};
};
@@ -1737,23 +1733,29 @@ pub const FetchResult = struct {
status: http.Status,
};
+pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadError || error{
+ StreamTooLong,
+ /// TODO provide optional diagnostics when this occurs or break into more error codes
+ WriteFailed,
+ UnsupportedCompressionMethod,
+};
+
/// Perform a one-shot HTTP request with the provided options.
///
/// This function is threadsafe.
-pub fn fetch(client: *Client, options: FetchOptions) !FetchResult {
+pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult {
const uri = switch (options.location) {
.url => |u| try Uri.parse(u),
.uri => |u| u,
};
- var server_header_buffer: [16 * 1024]u8 = undefined;
-
const method: http.Method = options.method orelse
if (options.payload != null) .POST else .GET;
- var req = try open(client, method, uri, .{
- .server_header_buffer = options.server_header_buffer orelse &server_header_buffer,
- .redirect_behavior = options.redirect_behavior orelse
- if (options.payload == null) @enumFromInt(3) else .unhandled,
+ const redirect_behavior: Request.RedirectBehavior = options.redirect_behavior orelse
+ if (options.payload == null) @enumFromInt(3) else .unhandled;
+
+ var req = try request(client, method, uri, .{
+ .redirect_behavior = redirect_behavior,
.headers = options.headers,
.extra_headers = options.extra_headers,
.privileged_headers = options.privileged_headers,
@@ -1761,44 +1763,70 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult {
});
defer req.deinit();
- if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len };
+ if (options.payload) |payload| {
+ req.transfer_encoding = .{ .content_length = payload.len };
+ var body = try req.sendBody(&.{});
+ try body.writer.writeAll(payload);
+ try body.end();
+ } else {
+ try req.sendBodiless();
+ }
- try req.send();
+ const redirect_buffer: []u8 = if (redirect_behavior == .unhandled) &.{} else options.redirect_buffer orelse
+ try client.allocator.alloc(u8, 8 * 1024);
+ defer if (options.redirect_buffer == null) client.allocator.free(redirect_buffer);
- if (options.payload) |payload| try req.writeAll(payload);
+ var response = try req.receiveHead(redirect_buffer);
- try req.finish();
- try req.wait();
+ const storage = options.response_storage orelse {
+ const reader = response.reader(&.{});
+ _ = reader.discardRemaining() catch |err| switch (err) {
+ error.ReadFailed => return response.bodyErr().?,
+ };
+ return .{ .status = response.head.status };
+ };
- switch (options.response_storage) {
- .ignore => {
- // Take advantage of request internals to discard the response body
- // and make the connection available for another request.
- req.response.skip = true;
- assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping.
- },
- .dynamic => |list| {
- const max_append_size = options.max_append_size orelse 2 * 1024 * 1024;
- try req.reader().readAllArrayList(list, max_append_size);
- },
- .static => |list| {
- const buf = b: {
- const buf = list.unusedCapacitySlice();
- if (options.max_append_size) |len| {
- if (len < buf.len) break :b buf[0..len];
- }
- break :b buf;
- };
- list.items.len += try req.reader().readAll(buf);
- },
+ const decompress_buffer: []u8 = switch (response.head.content_encoding) {
+ .identity => &.{},
+ .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len),
+ .deflate, .gzip => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.flate.max_window_len),
+ .compress => return error.UnsupportedCompressionMethod,
+ };
+ defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer);
+
+ var decompressor: http.Decompressor = undefined;
+ const reader = response.readerDecompressing(&decompressor, decompress_buffer);
+ const list = storage.list;
+
+ if (storage.allocator) |allocator| {
+ reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) {
+ error.ReadFailed => return response.bodyErr().?,
+ else => |e| return e,
+ };
+ } else {
+ const buf = storage.append_limit.slice(list.unusedCapacitySlice());
+ list.items.len += reader.readSliceShort(buf) catch |err| switch (err) {
+ error.ReadFailed => return response.bodyErr().?,
+ };
}
- return .{
- .status = req.response.status,
- };
+ return .{ .status = response.head.status };
+}
+
+pub fn sameParentDomain(parent_host: []const u8, child_host: []const u8) bool {
+ if (!std.ascii.endsWithIgnoreCase(child_host, parent_host)) return false;
+ if (child_host.len == parent_host.len) return true;
+ if (parent_host.len > child_host.len) return false;
+ return child_host[child_host.len - parent_host.len - 1] == '.';
+}
+
+test sameParentDomain {
+ try testing.expect(!sameParentDomain("foo.com", "bar.com"));
+ try testing.expect(sameParentDomain("foo.com", "foo.com"));
+ try testing.expect(sameParentDomain("foo.com", "bar.foo.com"));
+ try testing.expect(!sameParentDomain("bar.foo.com", "foo.com"));
}
test {
_ = Response;
- _ = &initDefaultProxies;
}
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
@@ -1,139 +1,69 @@
-//! Blocking HTTP server implementation.
-//! Handles a single connection's lifecycle.
-
-connection: net.Server.Connection,
-/// Keeps track of whether the Server is ready to accept a new request on the
-/// same connection, and makes invalid API usage cause assertion failures
-/// rather than HTTP protocol violations.
-state: State,
-/// User-provided buffer that must outlive this Server.
-/// Used to store the client's entire HTTP header.
-read_buffer: []u8,
-/// Amount of available data inside read_buffer.
-read_buffer_len: usize,
-/// Index into `read_buffer` of the first byte of the next HTTP request.
-next_request_start: usize,
-
-pub const State = enum {
- /// The connection is available to be used for the first time, or reused.
- ready,
- /// An error occurred in `receiveHead`.
- receiving_head,
- /// A Request object has been obtained and from there a Response can be
- /// opened.
- received_head,
- /// The client is uploading something to this Server.
- receiving_body,
- /// The connection is eligible for another HTTP request, however the client
- /// and server did not negotiate a persistent connection.
- closing,
-};
+//! Handles a single connection lifecycle.
+
+const std = @import("../std.zig");
+const http = std.http;
+const mem = std.mem;
+const Uri = std.Uri;
+const assert = std.debug.assert;
+const testing = std.testing;
+const Writer = std.Io.Writer;
+const Reader = std.Io.Reader;
+
+const Server = @This();
+
+/// Data from the HTTP server to the HTTP client.
+out: *Writer,
+reader: http.Reader,
/// Initialize an HTTP server that can respond to multiple requests on the same
/// connection.
+///
+/// The buffer of `in` must be large enough to store the client's entire HTTP
+/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`.
+///
/// The returned `Server` is ready for `receiveHead` to be called.
-pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server {
+pub fn init(in: *Reader, out: *Writer) Server {
return .{
- .connection = connection,
- .state = .ready,
- .read_buffer = read_buffer,
- .read_buffer_len = 0,
- .next_request_start = 0,
+ .reader = .{
+ .in = in,
+ .state = .ready,
+ // Populated when `http.Reader.bodyReader` is called.
+ .interface = undefined,
+ },
+ .out = out,
};
}
-pub const ReceiveHeadError = error{
- /// Client sent too many bytes of HTTP headers.
- /// The HTTP specification suggests to respond with a 431 status code
- /// before closing the connection.
- HttpHeadersOversize,
+pub const ReceiveHeadError = http.Reader.HeadError || error{
/// Client sent headers that did not conform to the HTTP protocol.
+ ///
+ /// To find out more detailed diagnostics, `Request.head_buffer` can be
+ /// passed directly to `Request.Head.parse`.
HttpHeadersInvalid,
- /// A low level I/O error occurred trying to read the headers.
- HttpHeadersUnreadable,
- /// Partial HTTP request was received but the connection was closed before
- /// fully receiving the headers.
- HttpRequestTruncated,
- /// The client sent 0 bytes of headers before closing the stream.
- /// In other words, a keep-alive connection was finally closed.
- HttpConnectionClosing,
};
-/// The header bytes reference the read buffer that Server was initialized with
-/// and remain alive until the next call to receiveHead.
pub fn receiveHead(s: *Server) ReceiveHeadError!Request {
- assert(s.state == .ready);
- s.state = .received_head;
- errdefer s.state = .receiving_head;
-
- // In case of a reused connection, move the next request's bytes to the
- // beginning of the buffer.
- if (s.next_request_start > 0) {
- if (s.read_buffer_len > s.next_request_start) {
- rebase(s, 0);
- } else {
- s.read_buffer_len = 0;
- }
- }
-
- var hp: http.HeadParser = .{};
-
- if (s.read_buffer_len > 0) {
- const bytes = s.read_buffer[0..s.read_buffer_len];
- const end = hp.feed(bytes);
- if (hp.state == .finished)
- return finishReceivingHead(s, end);
- }
-
- while (true) {
- const buf = s.read_buffer[s.read_buffer_len..];
- if (buf.len == 0)
- return error.HttpHeadersOversize;
- const read_n = s.connection.stream.read(buf) catch
- return error.HttpHeadersUnreadable;
- if (read_n == 0) {
- if (s.read_buffer_len > 0) {
- return error.HttpRequestTruncated;
- } else {
- return error.HttpConnectionClosing;
- }
- }
- s.read_buffer_len += read_n;
- const bytes = buf[0..read_n];
- const end = hp.feed(bytes);
- if (hp.state == .finished)
- return finishReceivingHead(s, s.read_buffer_len - bytes.len + end);
- }
-}
-
-fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request {
+ const head_buffer = try s.reader.receiveHead();
return .{
.server = s,
- .head_end = head_end,
- .head = Request.Head.parse(s.read_buffer[0..head_end]) catch
- return error.HttpHeadersInvalid,
- .reader_state = undefined,
+ .head_buffer = head_buffer,
+ // No need to track the returned error here since users can repeat the
+ // parse with the header buffer to get detailed diagnostics.
+ .head = Request.Head.parse(head_buffer) catch return error.HttpHeadersInvalid,
};
}
pub const Request = struct {
server: *Server,
- /// Index into Server's read_buffer.
- head_end: usize,
+ /// Pointers in this struct are invalidated when the request body stream is
+ /// initialized.
head: Head,
- reader_state: union {
- remaining_content_length: u64,
- chunk_parser: http.ChunkParser,
- },
-
- pub const Compression = union(enum) {
- pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader);
- pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader);
-
- deflate: std.compress.flate.Decompress,
- gzip: std.compress.flate.Decompress,
- zstd: std.compress.zstd.Decompress,
- none: void,
+ head_buffer: []const u8,
+ respond_err: ?RespondError = null,
+
+ pub const RespondError = error{
+ /// The request contained an `expect` header with an unrecognized value.
+ HttpExpectationFailed,
};
pub const Head = struct {
@@ -146,7 +76,6 @@ pub const Request = struct {
transfer_encoding: http.TransferEncoding,
transfer_compression: http.ContentEncoding,
keep_alive: bool,
- compression: Compression,
pub const ParseError = error{
UnknownHttpMethod,
@@ -168,10 +97,9 @@ pub const Request = struct {
const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse
return error.HttpHeadersInvalid;
- if (method_end > 24) return error.HttpHeadersInvalid;
- const method_str = first_line[0..method_end];
- const method: http.Method = @enumFromInt(http.Method.parse(method_str));
+ const method = std.meta.stringToEnum(http.Method, first_line[0..method_end]) orelse
+ return error.UnknownHttpMethod;
const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse
return error.HttpHeadersInvalid;
@@ -200,7 +128,6 @@ pub const Request = struct {
.@"HTTP/1.0" => false,
.@"HTTP/1.1" => true,
},
- .compression = .none,
};
while (it.next()) |line| {
@@ -230,7 +157,7 @@ pub const Request = struct {
const trimmed = mem.trim(u8, header_value, " ");
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ if (http.ContentEncoding.fromString(trimmed)) |ce| {
head.transfer_compression = ce;
} else {
return error.HttpTransferEncodingUnsupported;
@@ -255,7 +182,7 @@ pub const Request = struct {
if (next) |second| {
const trimmed_second = mem.trim(u8, second, " ");
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| {
+ if (http.ContentEncoding.fromString(trimmed_second)) |transfer| {
if (head.transfer_compression != .identity)
return error.HttpHeadersInvalid; // double compression is not supported
head.transfer_compression = transfer;
@@ -296,10 +223,19 @@ pub const Request = struct {
inline fn int64(array: *const [8]u8) u64 {
return @bitCast(array.*);
}
+
+ /// Help the programmer avoid bugs by calling this when the string
+ /// memory of `Head` becomes invalidated.
+ fn invalidateStrings(h: *Head) void {
+ h.target = undefined;
+ if (h.expect) |*s| s.* = undefined;
+ if (h.content_type) |*s| s.* = undefined;
+ }
};
- pub fn iterateHeaders(r: *Request) http.HeaderIterator {
- return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]);
+ pub fn iterateHeaders(r: *const Request) http.HeaderIterator {
+ assert(r.server.reader.state == .received_head);
+ return http.HeaderIterator.init(r.head_buffer);
}
test iterateHeaders {
@@ -310,22 +246,19 @@ pub const Request = struct {
"TRansfer-encoding:\tdeflate, chunked \r\n" ++
"connectioN:\t keep-alive \r\n\r\n";
- var read_buffer: [500]u8 = undefined;
- @memcpy(read_buffer[0..request_bytes.len], request_bytes);
-
var server: Server = .{
- .connection = undefined,
- .state = .ready,
- .read_buffer = &read_buffer,
- .read_buffer_len = request_bytes.len,
- .next_request_start = 0,
+ .reader = .{
+ .in = undefined,
+ .state = .received_head,
+ .interface = undefined,
+ },
+ .out = undefined,
};
var request: Request = .{
.server = &server,
- .head_end = request_bytes.len,
.head = undefined,
- .reader_state = undefined,
+ .head_buffer = @constCast(request_bytes),
};
var it = request.iterateHeaders();
@@ -384,16 +317,22 @@ pub const Request = struct {
/// no error is surfaced.
///
/// Asserts status is not `continue`.
- /// Asserts there are at most 25 extra_headers.
/// Asserts that "\r\n" does not occur in any header name or value.
pub fn respond(
request: *Request,
content: []const u8,
options: RespondOptions,
- ) Response.WriteError!void {
- const max_extra_headers = 25;
+ ) ExpectContinueError!void {
+ try respondUnflushed(request, content, options);
+ try request.server.out.flush();
+ }
+
+ pub fn respondUnflushed(
+ request: *Request,
+ content: []const u8,
+ options: RespondOptions,
+ ) ExpectContinueError!void {
assert(options.status != .@"continue");
- assert(options.extra_headers.len <= max_extra_headers);
if (std.debug.runtime_safety) {
for (options.extra_headers) |header| {
assert(header.name.len != 0);
@@ -402,6 +341,7 @@ pub const Request = struct {
assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
}
}
+ try writeExpectContinue(request);
const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none;
const server_keep_alive = !transfer_encoding_none and options.keep_alive;
@@ -409,130 +349,42 @@ pub const Request = struct {
const phrase = options.reason orelse options.status.phrase() orelse "";
- var first_buffer: [500]u8 = undefined;
- var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer);
- if (request.head.expect != null) {
- // reader() and hence discardBody() above sets expect to null if it
- // is handled. So the fact that it is not null here means unhandled.
- h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n");
- if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n");
- h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n");
- try request.server.connection.stream.writeAll(h.items);
- return;
- }
- h.fixedWriter().print("{s} {d} {s}\r\n", .{
+ const out = request.server.out;
+ try out.print("{s} {d} {s}\r\n", .{
@tagName(options.version), @intFromEnum(options.status), phrase,
- }) catch unreachable;
+ });
switch (options.version) {
- .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"),
- .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"),
+ .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
+ .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
}
if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
.none => {},
- .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
+ .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
} else {
- h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable;
+ try out.print("content-length: {d}\r\n", .{content.len});
}
- var chunk_header_buffer: [18]u8 = undefined;
- var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
- var iovecs_len: usize = 0;
-
- iovecs[iovecs_len] = .{
- .base = h.items.ptr,
- .len = h.items.len,
- };
- iovecs_len += 1;
-
for (options.extra_headers) |header| {
- iovecs[iovecs_len] = .{
- .base = header.name.ptr,
- .len = header.name.len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = ": ",
- .len = 2,
- };
- iovecs_len += 1;
-
- if (header.value.len != 0) {
- iovecs[iovecs_len] = .{
- .base = header.value.ptr,
- .len = header.value.len,
- };
- iovecs_len += 1;
- }
-
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
+ var vecs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
+ try out.writeVecAll(&vecs);
}
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
+ try out.writeAll("\r\n");
if (request.head.method != .HEAD) {
const is_chunked = (options.transfer_encoding orelse .none) == .chunked;
if (is_chunked) {
- if (content.len > 0) {
- const chunk_header = std.fmt.bufPrint(
- &chunk_header_buffer,
- "{x}\r\n",
- .{content.len},
- ) catch unreachable;
-
- iovecs[iovecs_len] = .{
- .base = chunk_header.ptr,
- .len = chunk_header.len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = content.ptr,
- .len = content.len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
- }
-
- iovecs[iovecs_len] = .{
- .base = "0\r\n\r\n",
- .len = 5,
- };
- iovecs_len += 1;
+ if (content.len > 0) try out.print("{x}\r\n{s}\r\n", .{ content.len, content });
+ try out.writeAll("0\r\n\r\n");
} else if (content.len > 0) {
- iovecs[iovecs_len] = .{
- .base = content.ptr,
- .len = content.len,
- };
- iovecs_len += 1;
+ try out.writeAll(content);
}
}
-
- try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]);
}
pub const RespondStreamingOptions = struct {
- /// An externally managed slice of memory used to batch bytes before
- /// sending. `respondStreaming` asserts this is large enough to store
- /// the full HTTP response head.
- ///
- /// Must outlive the returned Response.
- send_buffer: []u8,
/// If provided, the response will use the content-length header;
/// otherwise it will use transfer-encoding: chunked.
content_length: ?u64 = null,
@@ -540,254 +392,227 @@ pub const Request = struct {
respond_options: RespondOptions = .{},
};
- /// The header is buffered but not sent until Response.flush is called.
+ /// The header is not guaranteed to be sent until `BodyWriter.flush` or
+ /// `BodyWriter.end` is called.
///
/// If the request contains a body and the connection is to be reused,
/// discards the request body, leaving the Server in the `ready` state. If
/// this discarding fails, the connection is marked as not to be reused and
/// no error is surfaced.
///
- /// HEAD requests are handled transparently by setting a flag on the
- /// returned Response to omit the body. However it may be worth noticing
+ /// HEAD requests are handled transparently by setting the
+ /// `BodyWriter.elide` flag on the returned `BodyWriter`, causing
+ /// the response stream to omit the body. However, it may be worth noticing
/// that flag and skipping any expensive work that would otherwise need to
/// be done to satisfy the request.
///
- /// Asserts `send_buffer` is large enough to store the entire response header.
/// Asserts status is not `continue`.
- pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response {
+ pub fn respondStreaming(
+ request: *Request,
+ buffer: []u8,
+ options: RespondStreamingOptions,
+ ) ExpectContinueError!http.BodyWriter {
+ try writeExpectContinue(request);
const o = options.respond_options;
assert(o.status != .@"continue");
const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none;
const server_keep_alive = !transfer_encoding_none and o.keep_alive;
const keep_alive = request.discardBody(server_keep_alive);
const phrase = o.reason orelse o.status.phrase() orelse "";
+ const out = request.server.out;
- var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
-
- const elide_body = if (request.head.expect != null) eb: {
- // reader() and hence discardBody() above sets expect to null if it
- // is handled. So the fact that it is not null here means unhandled.
- h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n");
- if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n");
- h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n");
- break :eb true;
- } else eb: {
- h.fixedWriter().print("{s} {d} {s}\r\n", .{
- @tagName(o.version), @intFromEnum(o.status), phrase,
- }) catch unreachable;
-
- switch (o.version) {
- .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"),
- .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"),
- }
+ try out.print("{s} {d} {s}\r\n", .{
+ @tagName(o.version), @intFromEnum(o.status), phrase,
+ });
- if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
- .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
- .none => {},
- } else if (options.content_length) |len| {
- h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
- } else {
- h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
- }
+ switch (o.version) {
+ .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"),
+ .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"),
+ }
- for (o.extra_headers) |header| {
- assert(header.name.len != 0);
- h.appendSliceAssumeCapacity(header.name);
- h.appendSliceAssumeCapacity(": ");
- h.appendSliceAssumeCapacity(header.value);
- h.appendSliceAssumeCapacity("\r\n");
- }
+ if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
+ .chunked => try out.writeAll("transfer-encoding: chunked\r\n"),
+ .none => {},
+ } else if (options.content_length) |len| {
+ try out.print("content-length: {d}\r\n", .{len});
+ } else {
+ try out.writeAll("transfer-encoding: chunked\r\n");
+ }
- h.appendSliceAssumeCapacity("\r\n");
- break :eb request.head.method == .HEAD;
- };
+ for (o.extra_headers) |header| {
+ assert(header.name.len != 0);
+ var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
+ try out.writeVecAll(&bufs);
+ }
- return .{
- .stream = request.server.connection.stream,
- .send_buffer = options.send_buffer,
- .send_buffer_start = 0,
- .send_buffer_end = h.items.len,
- .transfer_encoding = if (o.transfer_encoding) |te| switch (te) {
- .chunked => .chunked,
- .none => .none,
- } else if (options.content_length) |len| .{
- .content_length = len,
- } else .chunked,
- .elide_body = elide_body,
- .chunk_len = 0,
+ try out.writeAll("\r\n");
+ const elide_body = request.head.method == .HEAD;
+ const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) {
+ .chunked => .{ .chunked = .init },
+ .none => .none,
+ } else if (options.content_length) |len| .{
+ .content_length = len,
+ } else .{ .chunked = .init };
+
+ return if (elide_body) .{
+ .http_protocol_output = request.server.out,
+ .state = state,
+ .writer = .{
+ .buffer = buffer,
+ .vtable = &.{
+ .drain = http.BodyWriter.elidingDrain,
+ .sendFile = http.BodyWriter.elidingSendFile,
+ },
+ },
+ } else .{
+ .http_protocol_output = request.server.out,
+ .state = state,
+ .writer = .{
+ .buffer = buffer,
+ .vtable = switch (state) {
+ .none => &.{
+ .drain = http.BodyWriter.noneDrain,
+ .sendFile = http.BodyWriter.noneSendFile,
+ },
+ .content_length => &.{
+ .drain = http.BodyWriter.contentLengthDrain,
+ .sendFile = http.BodyWriter.contentLengthSendFile,
+ },
+ .chunked => &.{
+ .drain = http.BodyWriter.chunkedDrain,
+ .sendFile = http.BodyWriter.chunkedSendFile,
+ },
+ .end => unreachable,
+ },
+ },
};
}
- pub const ReadError = net.Stream.ReadError || error{
- HttpChunkInvalid,
- HttpHeadersOversize,
+ pub const UpgradeRequest = union(enum) {
+ websocket: ?[]const u8,
+ other: []const u8,
+ none,
};
- fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
- const request: *Request = @ptrCast(@alignCast(@constCast(context)));
- const s = request.server;
-
- const remaining_content_length = &request.reader_state.remaining_content_length;
- if (remaining_content_length.* == 0) {
- s.state = .ready;
- return 0;
+ /// Does not invalidate `request.head`.
+ pub fn upgradeRequested(request: *const Request) UpgradeRequest {
+ switch (request.head.version) {
+ .@"HTTP/1.0" => return .none,
+ .@"HTTP/1.1" => if (request.head.method != .GET) return .none,
}
- assert(s.state == .receiving_body);
- const available = try fill(s, request.head_end);
- const len = @min(remaining_content_length.*, available.len, buffer.len);
- @memcpy(buffer[0..len], available[0..len]);
- remaining_content_length.* -= len;
- s.next_request_start += len;
- if (remaining_content_length.* == 0)
- s.state = .ready;
- return len;
- }
-
- fn fill(s: *Server, head_end: usize) ReadError![]u8 {
- const available = s.read_buffer[s.next_request_start..s.read_buffer_len];
- if (available.len > 0) return available;
- s.next_request_start = head_end;
- s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]);
- return s.read_buffer[head_end..s.read_buffer_len];
- }
- fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
- const request: *Request = @ptrCast(@alignCast(@constCast(context)));
- const s = request.server;
-
- const cp = &request.reader_state.chunk_parser;
- const head_end = request.head_end;
-
- // Protect against returning 0 before the end of stream.
- var out_end: usize = 0;
- while (out_end == 0) {
- switch (cp.state) {
- .invalid => return 0,
- .data => {
- assert(s.state == .receiving_body);
- const available = try fill(s, head_end);
- const len = @min(cp.chunk_len, available.len, buffer.len);
- @memcpy(buffer[0..len], available[0..len]);
- cp.chunk_len -= len;
- if (cp.chunk_len == 0)
- cp.state = .data_suffix;
- out_end += len;
- s.next_request_start += len;
- continue;
- },
- else => {
- assert(s.state == .receiving_body);
- const available = try fill(s, head_end);
- const n = cp.feed(available);
- switch (cp.state) {
- .invalid => return error.HttpChunkInvalid,
- .data => {
- if (cp.chunk_len == 0) {
- // The next bytes in the stream are trailers,
- // or \r\n to indicate end of chunked body.
- //
- // This function must append the trailers at
- // head_end so that headers and trailers are
- // together.
- //
- // Since returning 0 would indicate end of
- // stream, this function must read all the
- // trailers before returning.
- if (s.next_request_start > head_end) rebase(s, head_end);
- var hp: http.HeadParser = .{};
- {
- const bytes = s.read_buffer[head_end..s.read_buffer_len];
- const end = hp.feed(bytes);
- if (hp.state == .finished) {
- cp.state = .invalid;
- s.state = .ready;
- s.next_request_start = s.read_buffer_len - bytes.len + end;
- return out_end;
- }
- }
- while (true) {
- const buf = s.read_buffer[s.read_buffer_len..];
- if (buf.len == 0)
- return error.HttpHeadersOversize;
- const read_n = try s.connection.stream.read(buf);
- s.read_buffer_len += read_n;
- const bytes = buf[0..read_n];
- const end = hp.feed(bytes);
- if (hp.state == .finished) {
- cp.state = .invalid;
- s.state = .ready;
- s.next_request_start = s.read_buffer_len - bytes.len + end;
- return out_end;
- }
- }
- }
- const data = available[n..];
- const len = @min(cp.chunk_len, data.len, buffer.len);
- @memcpy(buffer[0..len], data[0..len]);
- cp.chunk_len -= len;
- if (cp.chunk_len == 0)
- cp.state = .data_suffix;
- out_end += len;
- s.next_request_start += n + len;
- continue;
- },
- else => continue,
- }
- },
+ var sec_websocket_key: ?[]const u8 = null;
+ var upgrade_name: ?[]const u8 = null;
+ var it = request.iterateHeaders();
+ while (it.next()) |header| {
+ if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
+ sec_websocket_key = header.value;
+ } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
+ upgrade_name = header.value;
}
}
- return out_end;
+
+ const name = upgrade_name orelse return .none;
+ if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key };
+ return .{ .other = name };
}
- pub const ReaderError = Response.WriteError || error{
- /// The client sent an expect HTTP header value other than
- /// "100-continue".
- HttpExpectationFailed,
+ pub const WebSocketOptions = struct {
+ /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value).
+ key: []const u8,
+ reason: ?[]const u8 = null,
+ extra_headers: []const http.Header = &.{},
};
+ /// The header is not guaranteed to be sent until `WebSocket.flush` is
+ /// called on the returned struct.
+ pub fn respondWebSocket(request: *Request, options: WebSocketOptions) ExpectContinueError!WebSocket {
+ if (request.head.expect != null) return error.HttpExpectationFailed;
+
+ const out = request.server.out;
+ const version: http.Version = .@"HTTP/1.1";
+ const status: http.Status = .switching_protocols;
+ const phrase = options.reason orelse status.phrase() orelse "";
+
+ assert(request.head.version == version);
+ assert(request.head.method == .GET);
+
+ var sha1 = std.crypto.hash.Sha1.init(.{});
+ sha1.update(options.key);
+ sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+ var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
+ sha1.final(&digest);
+ try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase });
+ try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: ");
+ const base64_digest = try out.writableArray(28);
+ assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len);
+ out.advance(base64_digest.len);
+ try out.writeAll("\r\n");
+
+ for (options.extra_headers) |header| {
+ assert(header.name.len != 0);
+ var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" };
+ try out.writeVecAll(&bufs);
+ }
+
+ try out.writeAll("\r\n");
+
+ return .{
+ .input = request.server.reader.in,
+ .output = request.server.out,
+ .key = options.key,
+ };
+ }
+
/// In the case that the request contains "expect: 100-continue", this
/// function writes the continuation header, which means it can fail with a
/// write error. After sending the continuation header, it sets the
/// request's expect field to `null`.
///
/// Asserts that this function is only called once.
- pub fn reader(request: *Request) ReaderError!std.io.AnyReader {
- const s = request.server;
- assert(s.state == .received_head);
- s.state = .receiving_body;
- s.next_request_start = request.head_end;
-
- if (request.head.expect) |expect| {
- if (mem.eql(u8, expect, "100-continue")) {
- try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n");
- request.head.expect = null;
- } else {
- return error.HttpExpectationFailed;
- }
- }
+ ///
+ /// See `readerExpectNone` for an infallible alternative that cannot write
+ /// to the server output stream.
+ pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*Reader {
+ const flush = request.head.expect != null;
+ try writeExpectContinue(request);
+ if (flush) try request.server.out.flush();
+ return readerExpectNone(request, buffer);
+ }
- switch (request.head.transfer_encoding) {
- .chunked => {
- request.reader_state = .{ .chunk_parser = http.ChunkParser.init };
- return .{
- .readFn = read_chunked,
- .context = request,
- };
- },
- .none => {
- request.reader_state = .{
- .remaining_content_length = request.head.content_length orelse 0,
- };
- return .{
- .readFn = read_cl,
- .context = request,
- };
- },
- }
+ /// Asserts the expect header is `null`. The caller must handle the
+ /// expectation manually and then set the value to `null` prior to calling
+ /// this function.
+ ///
+ /// Asserts that this function is only called once.
+ ///
+ /// Invalidates the string memory inside `Head`.
+ pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader {
+ assert(request.server.reader.state == .received_head);
+ assert(request.head.expect == null);
+ request.head.invalidateStrings();
+ if (!request.head.method.requestHasBody()) return .ending;
+ return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length);
+ }
+
+ pub const ExpectContinueError = error{
+ /// Failed to write "HTTP/1.1 100 Continue\r\n\r\n" to the stream.
+ WriteFailed,
+ /// The client sent an expect HTTP header value other than
+ /// "100-continue".
+ HttpExpectationFailed,
+ };
+
+ pub fn writeExpectContinue(request: *Request) ExpectContinueError!void {
+ const expect = request.head.expect orelse return;
+ if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed;
+ try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n");
+ request.head.expect = null;
}
/// Returns whether the connection should remain persistent.
- /// If it would fail, it instead sets the Server state to `receiving_body`
+ ///
+ /// If it would fail, it instead sets the Server state to receiving body
/// and returns false.
fn discardBody(request: *Request, keep_alive: bool) bool {
// Prepare to receive another request on the same connection.
@@ -798,350 +623,180 @@ pub const Request = struct {
// or the request body.
// If the connection won't be kept alive, then none of this matters
// because the connection will be severed after the response is sent.
- const s = request.server;
- if (keep_alive and request.head.keep_alive) switch (s.state) {
+ const r = &request.server.reader;
+ if (keep_alive and request.head.keep_alive) switch (r.state) {
.received_head => {
- const r = request.reader() catch return false;
- _ = r.discard() catch return false;
- assert(s.state == .ready);
+ if (request.head.method.requestHasBody()) {
+ assert(request.head.transfer_encoding != .none or request.head.content_length != null);
+ const reader_interface = request.readerExpectContinue(&.{}) catch return false;
+ _ = reader_interface.discardRemaining() catch return false;
+ assert(r.state == .ready);
+ } else {
+ r.state = .ready;
+ }
return true;
},
- .receiving_body, .ready => return true,
+ .body_remaining_content_length, .body_remaining_chunk_len, .body_none, .ready => return true,
else => unreachable,
};
// Avoid clobbering the state in case a reading stream already exists.
- switch (s.state) {
- .received_head => s.state = .closing,
+ switch (r.state) {
+ .received_head => r.state = .closing,
else => {},
}
return false;
}
};
-pub const Response = struct {
- stream: net.Stream,
- send_buffer: []u8,
- /// Index of the first byte in `send_buffer`.
- /// This is 0 unless a short write happens in `write`.
- send_buffer_start: usize,
- /// Index of the last byte + 1 in `send_buffer`.
- send_buffer_end: usize,
- /// `null` means transfer-encoding: chunked.
- /// As a debugging utility, counts down to zero as bytes are written.
- transfer_encoding: TransferEncoding,
- elide_body: bool,
- /// Indicates how much of the end of the `send_buffer` corresponds to a
- /// chunk. This amount of data will be wrapped by an HTTP chunk header.
- chunk_len: usize,
-
- pub const TransferEncoding = union(enum) {
- /// End of connection signals the end of the stream.
- none,
- /// As a debugging utility, counts down to zero as bytes are written.
- content_length: u64,
- /// Each chunk is wrapped in a header and trailer.
- chunked,
+/// See https://tools.ietf.org/html/rfc6455
+pub const WebSocket = struct {
+ key: []const u8,
+ input: *Reader,
+ output: *Writer,
+
+ pub const Header0 = packed struct(u8) {
+ opcode: Opcode,
+ rsv3: u1 = 0,
+ rsv2: u1 = 0,
+ rsv1: u1 = 0,
+ fin: bool,
};
- pub const WriteError = net.Stream.WriteError;
-
- /// When using content-length, asserts that the amount of data sent matches
- /// the value sent in the header, then calls `flush`.
- /// Otherwise, transfer-encoding: chunked is being used, and it writes the
- /// end-of-stream message, then flushes the stream to the system.
- /// Respects the value of `elide_body` to omit all data after the headers.
- pub fn end(r: *Response) WriteError!void {
- switch (r.transfer_encoding) {
- .content_length => |len| {
- assert(len == 0); // Trips when end() called before all bytes written.
- try flush_cl(r);
- },
- .none => {
- try flush_cl(r);
- },
- .chunked => {
- try flush_chunked(r, &.{});
- },
- }
- r.* = undefined;
- }
-
- pub const EndChunkedOptions = struct {
- trailers: []const http.Header = &.{},
+ pub const Header1 = packed struct(u8) {
+ payload_len: enum(u7) {
+ len16 = 126,
+ len64 = 127,
+ _,
+ },
+ mask: bool,
};
- /// Asserts that the Response is using transfer-encoding: chunked.
- /// Writes the end-of-stream message and any optional trailers, then
- /// flushes the stream to the system.
- /// Respects the value of `elide_body` to omit all data after the headers.
- /// Asserts there are at most 25 trailers.
- pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void {
- assert(r.transfer_encoding == .chunked);
- try flush_chunked(r, options.trailers);
- r.* = undefined;
- }
-
- /// If using content-length, asserts that writing these bytes to the client
- /// would not exceed the content-length value sent in the HTTP header.
- /// May return 0, which does not indicate end of stream. The caller decides
- /// when the end of stream occurs by calling `end`.
- pub fn write(r: *Response, bytes: []const u8) WriteError!usize {
- switch (r.transfer_encoding) {
- .content_length, .none => return write_cl(r, bytes),
- .chunked => return write_chunked(r, bytes),
- }
- }
-
- fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize {
- const r: *Response = @ptrCast(@alignCast(@constCast(context)));
+ pub const Opcode = enum(u4) {
+ continuation = 0,
+ text = 1,
+ binary = 2,
+ connection_close = 8,
+ ping = 9,
+ /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
+ /// heartbeat. A response to an unsolicited Pong frame is not expected."
+ pong = 10,
+ _,
+ };
- var trash: u64 = std.math.maxInt(u64);
- const len = switch (r.transfer_encoding) {
- .content_length => |*len| len,
- else => &trash,
- };
+ pub const ReadSmallTextMessageError = error{
+ ConnectionClose,
+ UnexpectedOpCode,
+ MessageTooBig,
+ MissingMaskBit,
+ ReadFailed,
+ EndOfStream,
+ };
- if (r.elide_body) {
- len.* -= bytes.len;
- return bytes.len;
- }
+ pub const SmallMessage = struct {
+ /// Can be text, binary, or ping.
+ opcode: Opcode,
+ data: []u8,
+ };
- if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
- const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
- var iovecs: [2]std.posix.iovec_const = .{
- .{
- .base = r.send_buffer.ptr + r.send_buffer_start,
- .len = send_buffer_len,
- },
- .{
- .base = bytes.ptr,
- .len = bytes.len,
- },
- };
- const n = try r.stream.writev(&iovecs);
-
- if (n >= send_buffer_len) {
- // It was enough to reset the buffer.
- r.send_buffer_start = 0;
- r.send_buffer_end = 0;
- const bytes_n = n - send_buffer_len;
- len.* -= bytes_n;
- return bytes_n;
+ /// Reads the next message from the WebSocket stream, failing if the
+ /// message does not fit into the input buffer. The returned memory points
+ /// into the input buffer and is invalidated on the next read.
+ pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
+ const in = ws.input;
+ while (true) {
+ const header = try in.takeArray(2);
+ const h0: Header0 = @bitCast(header[0]);
+ const h1: Header1 = @bitCast(header[1]);
+
+ switch (h0.opcode) {
+ .text, .binary, .pong, .ping => {},
+ .connection_close => return error.ConnectionClose,
+ .continuation => return error.UnexpectedOpCode,
+ _ => return error.UnexpectedOpCode,
}
- // It didn't even make it through the existing buffer, let
- // alone the new bytes provided.
- r.send_buffer_start += n;
- return 0;
- }
-
- // All bytes can be stored in the remaining space of the buffer.
- @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes);
- r.send_buffer_end += bytes.len;
- len.* -= bytes.len;
- return bytes.len;
- }
+ if (!h0.fin) return error.MessageTooBig;
+ if (!h1.mask) return error.MissingMaskBit;
- fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize {
- const r: *Response = @ptrCast(@alignCast(@constCast(context)));
- assert(r.transfer_encoding == .chunked);
-
- if (r.elide_body)
- return bytes.len;
-
- if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
- const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
- const chunk_len = r.chunk_len + bytes.len;
- var header_buf: [18]u8 = undefined;
- const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable;
-
- var iovecs: [5]std.posix.iovec_const = .{
- .{
- .base = r.send_buffer.ptr + r.send_buffer_start,
- .len = send_buffer_len - r.chunk_len,
- },
- .{
- .base = chunk_header.ptr,
- .len = chunk_header.len,
- },
- .{
- .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
- .len = r.chunk_len,
- },
- .{
- .base = bytes.ptr,
- .len = bytes.len,
- },
- .{
- .base = "\r\n",
- .len = 2,
- },
+ const len: usize = switch (h1.payload_len) {
+ .len16 => try in.takeInt(u16, .big),
+ .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig,
+ else => @intFromEnum(h1.payload_len),
+ };
+ if (len > in.buffer.len) return error.MessageTooBig;
+ const mask: u32 = @bitCast((try in.takeArray(4)).*);
+ const payload = try in.take(len);
+
+ // Skip pongs.
+ if (h0.opcode == .pong) continue;
+
+ // The last item may contain a partial word of unused data.
+ const floored_len = (payload.len / 4) * 4;
+ const u32_payload: []align(1) u32 = @ptrCast(payload[0..floored_len]);
+ for (u32_payload) |*elem| elem.* ^= mask;
+ const mask_bytes: []const u8 = @ptrCast(&mask);
+ for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m|
+ leftover.* ^= m;
+
+ return .{
+ .opcode = h0.opcode,
+ .data = payload,
};
- // TODO make this writev instead of writevAll, which involves
- // complicating the logic of this function.
- try r.stream.writevAll(&iovecs);
- r.send_buffer_start = 0;
- r.send_buffer_end = 0;
- r.chunk_len = 0;
- return bytes.len;
}
-
- // All bytes can be stored in the remaining space of the buffer.
- @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes);
- r.send_buffer_end += bytes.len;
- r.chunk_len += bytes.len;
- return bytes.len;
}
- /// If using content-length, asserts that writing these bytes to the client
- /// would not exceed the content-length value sent in the HTTP header.
- pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try write(r, bytes[index..]);
- }
+ pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
+ var bufs: [1][]const u8 = .{data};
+ try writeMessageVecUnflushed(ws, &bufs, op);
+ try ws.output.flush();
}
- /// Sends all buffered data to the client.
- /// This is redundant after calling `end`.
- /// Respects the value of `elide_body` to omit all data after the headers.
- pub fn flush(r: *Response) WriteError!void {
- switch (r.transfer_encoding) {
- .none, .content_length => return flush_cl(r),
- .chunked => return flush_chunked(r, null),
- }
+ pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void {
+ var bufs: [1][]const u8 = .{data};
+ try writeMessageVecUnflushed(ws, &bufs, op);
}
- fn flush_cl(r: *Response) WriteError!void {
- try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
- r.send_buffer_start = 0;
- r.send_buffer_end = 0;
+ pub fn writeMessageVec(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void {
+ try writeMessageVecUnflushed(ws, data, op);
+ try ws.output.flush();
}
- fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void {
- const max_trailers = 25;
- if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
- assert(r.transfer_encoding == .chunked);
-
- const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len];
-
- if (r.elide_body) {
- try r.stream.writeAll(http_headers);
- r.send_buffer_start = 0;
- r.send_buffer_end = 0;
- r.chunk_len = 0;
- return;
- }
-
- var header_buf: [18]u8 = undefined;
- const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable;
-
- var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined;
- var iovecs_len: usize = 0;
-
- iovecs[iovecs_len] = .{
- .base = http_headers.ptr,
- .len = http_headers.len,
+ pub fn writeMessageVecUnflushed(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void {
+ const total_len = l: {
+ var total_len: u64 = 0;
+ for (data) |iovec| total_len += iovec.len;
+ break :l total_len;
};
- iovecs_len += 1;
-
- if (r.chunk_len > 0) {
- iovecs[iovecs_len] = .{
- .base = chunk_header.ptr,
- .len = chunk_header.len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
- .len = r.chunk_len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
- }
-
- if (end_trailers) |trailers| {
- iovecs[iovecs_len] = .{
- .base = "0\r\n",
- .len = 3,
- };
- iovecs_len += 1;
-
- for (trailers) |trailer| {
- iovecs[iovecs_len] = .{
- .base = trailer.name.ptr,
- .len = trailer.name.len,
- };
- iovecs_len += 1;
-
- iovecs[iovecs_len] = .{
- .base = ": ",
- .len = 2,
- };
- iovecs_len += 1;
-
- if (trailer.value.len != 0) {
- iovecs[iovecs_len] = .{
- .base = trailer.value.ptr,
- .len = trailer.value.len,
- };
- iovecs_len += 1;
- }
-
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
- }
-
- iovecs[iovecs_len] = .{
- .base = "\r\n",
- .len = 2,
- };
- iovecs_len += 1;
+ const out = ws.output;
+ try out.writeByte(@bitCast(@as(Header0, .{
+ .opcode = op,
+ .fin = true,
+ })));
+ switch (total_len) {
+ 0...125 => try out.writeByte(@bitCast(@as(Header1, .{
+ .payload_len = @enumFromInt(total_len),
+ .mask = false,
+ }))),
+ 126...0xffff => {
+ try out.writeByte(@bitCast(@as(Header1, .{
+ .payload_len = .len16,
+ .mask = false,
+ })));
+ try out.writeInt(u16, @intCast(total_len), .big);
+ },
+ else => {
+ try out.writeByte(@bitCast(@as(Header1, .{
+ .payload_len = .len64,
+ .mask = false,
+ })));
+ try out.writeInt(u64, total_len, .big);
+ },
}
-
- try r.stream.writevAll(iovecs[0..iovecs_len]);
- r.send_buffer_start = 0;
- r.send_buffer_end = 0;
- r.chunk_len = 0;
+ try out.writeVecAll(data);
}
- pub fn writer(r: *Response) std.io.AnyWriter {
- return .{
- .writeFn = switch (r.transfer_encoding) {
- .none, .content_length => write_cl,
- .chunked => write_chunked,
- },
- .context = r,
- };
+ pub fn flush(ws: *WebSocket) Writer.Error!void {
+ try ws.output.flush();
}
};
-
-fn rebase(s: *Server, index: usize) void {
- const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len];
- const dest = s.read_buffer[index..][0..leftover.len];
- if (leftover.len <= s.next_request_start - index) {
- @memcpy(dest, leftover);
- } else {
- mem.copyBackwards(u8, dest, leftover);
- }
- s.read_buffer_len = index + leftover.len;
-}
-
-const std = @import("../std.zig");
-const http = std.http;
-const mem = std.mem;
-const net = std.net;
-const Uri = std.Uri;
-const assert = std.debug.assert;
-const testing = std.testing;
-
-const Server = @This();
diff --git a/lib/std/http/WebSocket.zig b/lib/std/http/WebSocket.zig
@@ -1,246 +0,0 @@
-//! See https://tools.ietf.org/html/rfc6455
-
-const builtin = @import("builtin");
-const std = @import("std");
-const WebSocket = @This();
-const assert = std.debug.assert;
-const native_endian = builtin.cpu.arch.endian();
-
-key: []const u8,
-request: *std.http.Server.Request,
-recv_fifo: std.fifo.LinearFifo(u8, .Slice),
-reader: std.io.AnyReader,
-response: std.http.Server.Response,
-/// Number of bytes that have been peeked but not discarded yet.
-outstanding_len: usize,
-
-pub const InitError = error{WebSocketUpgradeMissingKey} ||
- std.http.Server.Request.ReaderError;
-
-pub fn init(
- request: *std.http.Server.Request,
- send_buffer: []u8,
- recv_buffer: []align(4) u8,
-) InitError!?WebSocket {
- switch (request.head.version) {
- .@"HTTP/1.0" => return null,
- .@"HTTP/1.1" => if (request.head.method != .GET) return null,
- }
-
- var sec_websocket_key: ?[]const u8 = null;
- var upgrade_websocket: bool = false;
- var it = request.iterateHeaders();
- while (it.next()) |header| {
- if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
- sec_websocket_key = header.value;
- } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
- if (!std.ascii.eqlIgnoreCase(header.value, "websocket"))
- return null;
- upgrade_websocket = true;
- }
- }
- if (!upgrade_websocket)
- return null;
-
- const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
-
- var sha1 = std.crypto.hash.Sha1.init(.{});
- sha1.update(key);
- sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
- var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
- sha1.final(&digest);
- var base64_digest: [28]u8 = undefined;
- assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
-
- request.head.content_length = std.math.maxInt(u64);
-
- return .{
- .key = key,
- .recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer),
- .reader = try request.reader(),
- .response = request.respondStreaming(.{
- .send_buffer = send_buffer,
- .respond_options = .{
- .status = .switching_protocols,
- .extra_headers = &.{
- .{ .name = "upgrade", .value = "websocket" },
- .{ .name = "connection", .value = "upgrade" },
- .{ .name = "sec-websocket-accept", .value = &base64_digest },
- },
- .transfer_encoding = .none,
- },
- }),
- .request = request,
- .outstanding_len = 0,
- };
-}
-
-pub const Header0 = packed struct(u8) {
- opcode: Opcode,
- rsv3: u1 = 0,
- rsv2: u1 = 0,
- rsv1: u1 = 0,
- fin: bool,
-};
-
-pub const Header1 = packed struct(u8) {
- payload_len: enum(u7) {
- len16 = 126,
- len64 = 127,
- _,
- },
- mask: bool,
-};
-
-pub const Opcode = enum(u4) {
- continuation = 0,
- text = 1,
- binary = 2,
- connection_close = 8,
- ping = 9,
- /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
- /// heartbeat. A response to an unsolicited Pong frame is not expected."
- pong = 10,
- _,
-};
-
-pub const ReadSmallTextMessageError = error{
- ConnectionClose,
- UnexpectedOpCode,
- MessageTooBig,
- MissingMaskBit,
-} || RecvError;
-
-pub const SmallMessage = struct {
- /// Can be text, binary, or ping.
- opcode: Opcode,
- data: []u8,
-};
-
-/// Reads the next message from the WebSocket stream, failing if the message does not fit
-/// into `recv_buffer`.
-pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
- while (true) {
- const header_bytes = (try recv(ws, 2))[0..2];
- const h0: Header0 = @bitCast(header_bytes[0]);
- const h1: Header1 = @bitCast(header_bytes[1]);
-
- switch (h0.opcode) {
- .text, .binary, .pong, .ping => {},
- .connection_close => return error.ConnectionClose,
- .continuation => return error.UnexpectedOpCode,
- _ => return error.UnexpectedOpCode,
- }
-
- if (!h0.fin) return error.MessageTooBig;
- if (!h1.mask) return error.MissingMaskBit;
-
- const len: usize = switch (h1.payload_len) {
- .len16 => try recvReadInt(ws, u16),
- .len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
- else => @intFromEnum(h1.payload_len),
- };
- if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
-
- const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
- const payload = try recv(ws, len);
-
- // Skip pongs.
- if (h0.opcode == .pong) continue;
-
- // The last item may contain a partial word of unused data.
- const floored_len = (payload.len / 4) * 4;
- const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len]));
- for (u32_payload) |*elem| elem.* ^= mask;
- const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len];
- for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m;
-
- return .{
- .opcode = h0.opcode,
- .data = payload,
- };
- }
-}
-
-const RecvError = std.http.Server.Request.ReadError || error{EndOfStream};
-
-fn recv(ws: *WebSocket, len: usize) RecvError![]u8 {
- ws.recv_fifo.discard(ws.outstanding_len);
- assert(len <= ws.recv_fifo.buf.len);
- if (len > ws.recv_fifo.count) {
- const small_buf = ws.recv_fifo.writableSlice(0);
- const needed = len - ws.recv_fifo.count;
- const buf = if (small_buf.len >= needed) small_buf else b: {
- ws.recv_fifo.realign();
- break :b ws.recv_fifo.writableSlice(0);
- };
- const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed)));
- if (n < needed) return error.EndOfStream;
- ws.recv_fifo.update(n);
- }
- ws.outstanding_len = len;
- // TODO: improve the std lib API so this cast isn't necessary.
- return @constCast(ws.recv_fifo.readableSliceOfLen(len));
-}
-
-fn recvReadInt(ws: *WebSocket, comptime I: type) !I {
- const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*);
- return switch (native_endian) {
- .little => @byteSwap(unswapped),
- .big => unswapped,
- };
-}
-
-pub const WriteError = std.http.Server.Response.WriteError;
-
-pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void {
- const iovecs: [1]std.posix.iovec_const = .{
- .{ .base = message.ptr, .len = message.len },
- };
- return writeMessagev(ws, &iovecs, opcode);
-}
-
-pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void {
- const total_len = l: {
- var total_len: u64 = 0;
- for (message) |iovec| total_len += iovec.len;
- break :l total_len;
- };
-
- var header_buf: [2 + 8]u8 = undefined;
- header_buf[0] = @bitCast(@as(Header0, .{
- .opcode = opcode,
- .fin = true,
- }));
- const header = switch (total_len) {
- 0...125 => blk: {
- header_buf[1] = @bitCast(@as(Header1, .{
- .payload_len = @enumFromInt(total_len),
- .mask = false,
- }));
- break :blk header_buf[0..2];
- },
- 126...0xffff => blk: {
- header_buf[1] = @bitCast(@as(Header1, .{
- .payload_len = .len16,
- .mask = false,
- }));
- std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big);
- break :blk header_buf[0..4];
- },
- else => blk: {
- header_buf[1] = @bitCast(@as(Header1, .{
- .payload_len = .len64,
- .mask = false,
- }));
- std.mem.writeInt(u64, header_buf[2..10], total_len, .big);
- break :blk header_buf[0..10];
- },
- };
-
- const response = &ws.response;
- try response.writeAll(header);
- for (message) |iovec|
- try response.writeAll(iovec.base[0..iovec.len]);
- try response.flush();
-}
diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig
@@ -1,464 +0,0 @@
-const std = @import("../std.zig");
-const builtin = @import("builtin");
-const testing = std.testing;
-const mem = std.mem;
-
-const assert = std.debug.assert;
-
-pub const State = enum {
- invalid,
-
- // Begin header and trailer parsing states.
-
- start,
- seen_n,
- seen_r,
- seen_rn,
- seen_rnr,
- finished,
-
- // Begin transfer-encoding: chunked parsing states.
-
- chunk_head_size,
- chunk_head_ext,
- chunk_head_r,
- chunk_data,
- chunk_data_suffix,
- chunk_data_suffix_r,
-
- /// Returns true if the parser is in a content state (ie. not waiting for more headers).
- pub fn isContent(self: State) bool {
- return switch (self) {
- .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false,
- .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true,
- };
- }
-};
-
-pub const HeadersParser = struct {
- state: State = .start,
- /// A fixed buffer of len `max_header_bytes`.
- /// Pointers into this buffer are not stable until after a message is complete.
- header_bytes_buffer: []u8,
- header_bytes_len: u32,
- next_chunk_length: u64,
- /// `false`: headers. `true`: trailers.
- done: bool,
-
- /// Initializes the parser with a provided buffer `buf`.
- pub fn init(buf: []u8) HeadersParser {
- return .{
- .header_bytes_buffer = buf,
- .header_bytes_len = 0,
- .done = false,
- .next_chunk_length = 0,
- };
- }
-
- /// Reinitialize the parser.
- /// Asserts the parser is in the "done" state.
- pub fn reset(hp: *HeadersParser) void {
- assert(hp.done);
- hp.* = .{
- .state = .start,
- .header_bytes_buffer = hp.header_bytes_buffer,
- .header_bytes_len = 0,
- .done = false,
- .next_chunk_length = 0,
- };
- }
-
- pub fn get(hp: HeadersParser) []u8 {
- return hp.header_bytes_buffer[0..hp.header_bytes_len];
- }
-
- pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 {
- var hp: std.http.HeadParser = .{
- .state = switch (r.state) {
- .start => .start,
- .seen_n => .seen_n,
- .seen_r => .seen_r,
- .seen_rn => .seen_rn,
- .seen_rnr => .seen_rnr,
- .finished => .finished,
- else => unreachable,
- },
- };
- const result = hp.feed(bytes);
- r.state = switch (hp.state) {
- .start => .start,
- .seen_n => .seen_n,
- .seen_r => .seen_r,
- .seen_rn => .seen_rn,
- .seen_rnr => .seen_rnr,
- .finished => .finished,
- };
- return @intCast(result);
- }
-
- pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
- var cp: std.http.ChunkParser = .{
- .state = switch (r.state) {
- .chunk_head_size => .head_size,
- .chunk_head_ext => .head_ext,
- .chunk_head_r => .head_r,
- .chunk_data => .data,
- .chunk_data_suffix => .data_suffix,
- .chunk_data_suffix_r => .data_suffix_r,
- .invalid => .invalid,
- else => unreachable,
- },
- .chunk_len = r.next_chunk_length,
- };
- const result = cp.feed(bytes);
- r.state = switch (cp.state) {
- .head_size => .chunk_head_size,
- .head_ext => .chunk_head_ext,
- .head_r => .chunk_head_r,
- .data => .chunk_data,
- .data_suffix => .chunk_data_suffix,
- .data_suffix_r => .chunk_data_suffix_r,
- .invalid => .invalid,
- };
- r.next_chunk_length = cp.chunk_len;
- return @intCast(result);
- }
-
- /// Returns whether or not the parser has finished parsing a complete
- /// message. A message is only complete after the entire body has been read
- /// and any trailing headers have been parsed.
- pub fn isComplete(r: *HeadersParser) bool {
- return r.done and r.state == .finished;
- }
-
- pub const CheckCompleteHeadError = error{HttpHeadersOversize};
-
- /// Pushes `in` into the parser. Returns the number of bytes consumed by
- /// the header. Any header bytes are appended to `header_bytes_buffer`.
- pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 {
- if (hp.state.isContent()) return 0;
-
- const i = hp.findHeadersEnd(in);
- const data = in[0..i];
- if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len)
- return error.HttpHeadersOversize;
-
- @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data);
- hp.header_bytes_len += @intCast(data.len);
-
- return i;
- }
-
- pub const ReadError = error{
- HttpChunkInvalid,
- };
-
- /// Reads the body of the message into `buffer`. Returns the number of
- /// bytes placed in the buffer.
- ///
- /// If `skip` is true, the buffer will be unused and the body will be skipped.
- ///
- /// See `std.http.Client.Connection for an example of `conn`.
- pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
- assert(r.state.isContent());
- if (r.done) return 0;
-
- var out_index: usize = 0;
- while (true) {
- switch (r.state) {
- .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
- .finished => {
- const data_avail = r.next_chunk_length;
-
- if (skip) {
- conn.fill() catch |err| switch (err) {
- error.EndOfStream => {
- r.done = true;
- return 0;
- },
- else => |e| return e,
- };
-
- const nread = @min(conn.peek().len, data_avail);
- conn.drop(@intCast(nread));
- r.next_chunk_length -= nread;
-
- if (r.next_chunk_length == 0 or nread == 0) r.done = true;
-
- return out_index;
- } else if (out_index < buffer.len) {
- const out_avail = buffer.len - out_index;
-
- const can_read = @as(usize, @intCast(@min(data_avail, out_avail)));
- const nread = try conn.read(buffer[0..can_read]);
- r.next_chunk_length -= nread;
-
- if (r.next_chunk_length == 0 or nread == 0) r.done = true;
-
- return nread;
- } else {
- return out_index;
- }
- },
- .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
- conn.fill() catch |err| switch (err) {
- error.EndOfStream => {
- r.done = true;
- return 0;
- },
- else => |e| return e,
- };
-
- const i = r.findChunkedLen(conn.peek());
- conn.drop(@intCast(i));
-
- switch (r.state) {
- .invalid => return error.HttpChunkInvalid,
- .chunk_data => if (r.next_chunk_length == 0) {
- if (std.mem.eql(u8, conn.peek(), "\r\n")) {
- r.state = .finished;
- conn.drop(2);
- } else {
- // The trailer section is formatted identically
- // to the header section.
- r.state = .seen_rn;
- }
- r.done = true;
-
- return out_index;
- },
- else => return out_index,
- }
-
- continue;
- },
- .chunk_data => {
- const data_avail = r.next_chunk_length;
- const out_avail = buffer.len - out_index;
-
- if (skip) {
- conn.fill() catch |err| switch (err) {
- error.EndOfStream => {
- r.done = true;
- return 0;
- },
- else => |e| return e,
- };
-
- const nread = @min(conn.peek().len, data_avail);
- conn.drop(@intCast(nread));
- r.next_chunk_length -= nread;
- } else if (out_avail > 0) {
- const can_read: usize = @intCast(@min(data_avail, out_avail));
- const nread = try conn.read(buffer[out_index..][0..can_read]);
- r.next_chunk_length -= nread;
- out_index += nread;
- }
-
- if (r.next_chunk_length == 0) {
- r.state = .chunk_data_suffix;
- continue;
- }
-
- return out_index;
- },
- }
- }
- }
-};
-
-inline fn int16(array: *const [2]u8) u16 {
- return @as(u16, @bitCast(array.*));
-}
-
-inline fn int24(array: *const [3]u8) u24 {
- return @as(u24, @bitCast(array.*));
-}
-
-inline fn int32(array: *const [4]u8) u32 {
- return @as(u32, @bitCast(array.*));
-}
-
-inline fn intShift(comptime T: type, x: anytype) T {
- switch (@import("builtin").cpu.arch.endian()) {
- .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))),
- .big => return @as(T, @truncate(x)),
- }
-}
-
-/// A buffered (and peekable) Connection.
-const MockBufferedConnection = struct {
- pub const buffer_size = 0x2000;
-
- conn: std.io.FixedBufferStream([]const u8),
- buf: [buffer_size]u8 = undefined,
- start: u16 = 0,
- end: u16 = 0,
-
- pub fn fill(conn: *MockBufferedConnection) ReadError!void {
- if (conn.end != conn.start) return;
-
- const nread = try conn.conn.read(conn.buf[0..]);
- if (nread == 0) return error.EndOfStream;
- conn.start = 0;
- conn.end = @as(u16, @truncate(nread));
- }
-
- pub fn peek(conn: *MockBufferedConnection) []const u8 {
- return conn.buf[conn.start..conn.end];
- }
-
- pub fn drop(conn: *MockBufferedConnection, num: u16) void {
- conn.start += num;
- }
-
- pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
- var out_index: u16 = 0;
- while (out_index < len) {
- const available = conn.end - conn.start;
- const left = buffer.len - out_index;
-
- if (available > 0) {
- const can_read = @as(u16, @truncate(@min(available, left)));
-
- @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]);
- out_index += can_read;
- conn.start += can_read;
-
- continue;
- }
-
- if (left > conn.buf.len) {
- // skip the buffer if the output is large enough
- return conn.conn.read(buffer[out_index..]);
- }
-
- try conn.fill();
- }
-
- return out_index;
- }
-
- pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
- return conn.readAtLeast(buffer, 1);
- }
-
- pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
- pub const Reader = std.io.GenericReader(*MockBufferedConnection, ReadError, read);
-
- pub fn reader(conn: *MockBufferedConnection) Reader {
- return Reader{ .context = conn };
- }
-
- pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
- return conn.conn.writeAll(buffer);
- }
-
- pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
- return conn.conn.write(buffer);
- }
-
- pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
- pub const Writer = std.io.GenericWriter(*MockBufferedConnection, WriteError, write);
-
- pub fn writer(conn: *MockBufferedConnection) Writer {
- return Writer{ .context = conn };
- }
-};
-
-test "HeadersParser.read length" {
- // mock BufferedConnection for read
- var headers_buf: [256]u8 = undefined;
-
- var r = HeadersParser.init(&headers_buf);
- const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
-
- var conn: MockBufferedConnection = .{
- .conn = std.io.fixedBufferStream(data),
- };
-
- while (true) { // read headers
- try conn.fill();
-
- const nchecked = try r.checkCompleteHead(conn.peek());
- conn.drop(@intCast(nchecked));
-
- if (r.state.isContent()) break;
- }
-
- var buf: [8]u8 = undefined;
-
- r.next_chunk_length = 5;
- const len = try r.read(&conn, &buf, false);
- try std.testing.expectEqual(@as(usize, 5), len);
- try std.testing.expectEqualStrings("Hello", buf[0..len]);
-
- try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get());
-}
-
-test "HeadersParser.read chunked" {
- // mock BufferedConnection for read
-
- var headers_buf: [256]u8 = undefined;
- var r = HeadersParser.init(&headers_buf);
- const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
-
- var conn: MockBufferedConnection = .{
- .conn = std.io.fixedBufferStream(data),
- };
-
- while (true) { // read headers
- try conn.fill();
-
- const nchecked = try r.checkCompleteHead(conn.peek());
- conn.drop(@intCast(nchecked));
-
- if (r.state.isContent()) break;
- }
- var buf: [8]u8 = undefined;
-
- r.state = .chunk_head_size;
- const len = try r.read(&conn, &buf, false);
- try std.testing.expectEqual(@as(usize, 5), len);
- try std.testing.expectEqualStrings("Hello", buf[0..len]);
-
- try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get());
-}
-
-test "HeadersParser.read chunked trailer" {
- // mock BufferedConnection for read
-
- var headers_buf: [256]u8 = undefined;
- var r = HeadersParser.init(&headers_buf);
- const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
-
- var conn: MockBufferedConnection = .{
- .conn = std.io.fixedBufferStream(data),
- };
-
- while (true) { // read headers
- try conn.fill();
-
- const nchecked = try r.checkCompleteHead(conn.peek());
- conn.drop(@intCast(nchecked));
-
- if (r.state.isContent()) break;
- }
- var buf: [8]u8 = undefined;
-
- r.state = .chunk_head_size;
- const len = try r.read(&conn, &buf, false);
- try std.testing.expectEqual(@as(usize, 5), len);
- try std.testing.expectEqualStrings("Hello", buf[0..len]);
-
- while (true) { // read headers
- try conn.fill();
-
- const nchecked = try r.checkCompleteHead(conn.peek());
- conn.drop(@intCast(nchecked));
-
- if (r.state.isContent()) break;
- }
-
- try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get());
-}
diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig
@@ -10,32 +10,33 @@ const expectError = std.testing.expectError;
test "trailers" {
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var header_buffer: [1024]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [1024]u8 = undefined;
+ var send_buffer: [1024]u8 = undefined;
var remaining: usize = 1;
while (remaining != 0) : (remaining -= 1) {
- const conn = try net_server.accept();
- defer conn.stream.close();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var server = http.Server.init(conn, &header_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
- try expectEqual(.ready, server.state);
+ try expectEqual(.ready, server.reader.state);
var request = try server.receiveHead();
try serve(&request);
- try expectEqual(.ready, server.state);
+ try expectEqual(.ready, server.reader.state);
}
}
fn serve(request: *http.Server.Request) !void {
try expectEqualStrings(request.head.target, "/trailer");
- var send_buffer: [1024]u8 = undefined;
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
- });
- try response.writeAll("Hello, ");
+ var response = try request.respondStreaming(&.{}, .{});
+ try response.writer.writeAll("Hello, ");
try response.flush();
- try response.writeAll("World!\n");
+ try response.writer.writeAll("World!\n");
try response.flush();
try response.endChunked(.{
.trailers = &.{
@@ -58,34 +59,32 @@ test "trailers" {
const uri = try std.Uri.parse(location);
{
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
-
- const body = try req.reader().readAllAlloc(gpa, 8192);
- defer gpa.free(body);
+ try req.sendBodiless();
+ var response = try req.receiveHead(&.{});
- try expectEqualStrings("Hello, World!\n", body);
-
- var it = req.response.iterateHeaders();
{
+ var it = response.head.iterateHeaders();
const header = it.next().?;
- try expect(!it.is_trailer);
try expectEqualStrings("transfer-encoding", header.name);
try expectEqualStrings("chunked", header.value);
+ try expectEqual(null, it.next());
}
+
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
+ defer gpa.free(body);
+
+ try expectEqualStrings("Hello, World!\n", body);
+
{
+ var it = response.iterateTrailers();
const header = it.next().?;
- try expect(it.is_trailer);
try expectEqualStrings("X-Checksum", header.name);
try expectEqualStrings("aaaa", header.value);
+ try expectEqual(null, it.next());
}
- try expectEqual(null, it.next());
}
// connection has been kept alive
@@ -94,19 +93,24 @@ test "trailers" {
test "HTTP server handles a chunked transfer coding request" {
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) !void {
- var header_buffer: [8192]u8 = undefined;
- const conn = try net_server.accept();
- defer conn.stream.close();
-
- var server = http.Server.init(conn, &header_buffer);
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [8192]u8 = undefined;
+ var send_buffer: [500]u8 = undefined;
+ const connection = try net_server.accept();
+ defer connection.stream.close();
+
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
var request = try server.receiveHead();
try expect(request.head.transfer_encoding == .chunked);
var buf: [128]u8 = undefined;
- const n = try (try request.reader()).readAll(&buf);
- try expect(mem.eql(u8, buf[0..n], "ABCD"));
+ var br = try request.readerExpectContinue(&.{});
+ const n = try br.readSliceShort(&buf);
+ try expectEqualStrings("ABCD", buf[0..n]);
try request.respond("message from server!\n", .{
.extra_headers = &.{
@@ -154,16 +158,20 @@ test "HTTP server handles a chunked transfer coding request" {
test "echo content server" {
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var read_buffer: [1024]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [1024]u8 = undefined;
+ var send_buffer: [100]u8 = undefined;
- accept: while (true) {
- const conn = try net_server.accept();
- defer conn.stream.close();
+ accept: while (!test_server.shutting_down) {
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var http_server = http.Server.init(conn, &read_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface);
- while (http_server.state == .ready) {
+ while (http_server.reader.state == .ready) {
var request = http_server.receiveHead() catch |err| switch (err) {
error.HttpConnectionClosing => continue :accept,
else => |e| return e,
@@ -173,8 +181,12 @@ test "echo content server" {
}
if (request.head.expect) |expect_header_value| {
if (mem.eql(u8, expect_header_value, "garbage")) {
- try expectError(error.HttpExpectationFailed, request.reader());
- try request.respond("", .{ .keep_alive = false });
+ try expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{}));
+ request.head.expect = null;
+ try request.respond("", .{
+ .keep_alive = false,
+ .status = .expectation_failed,
+ });
continue;
}
}
@@ -195,16 +207,16 @@ test "echo content server" {
// request.head.target,
//});
- const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192);
+ try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
+ try expectEqualStrings("text/plain", request.head.content_type.?);
+
+ // head strings expire here
+ const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited);
defer std.testing.allocator.free(body);
- try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
try expectEqualStrings("Hello, World!\n", body);
- try expectEqualStrings("text/plain", request.head.content_type.?);
- var send_buffer: [100]u8 = undefined;
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
+ var response = try request.respondStreaming(&.{}, .{
.content_length = switch (request.head.transfer_encoding) {
.chunked => null,
.none => len: {
@@ -213,9 +225,8 @@ test "echo content server" {
},
},
});
-
try response.flush(); // Test an early flush to send the HTTP headers before the body.
- const w = response.writer();
+ const w = &response.writer;
try w.writeAll("Hello, ");
try w.writeAll("World!\n");
try response.end();
@@ -241,35 +252,35 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
// In this case, the response is expected to stream until the connection is
// closed, indicating the end of the body.
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var header_buffer: [1000]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [1000]u8 = undefined;
+ var send_buffer: [500]u8 = undefined;
var remaining: usize = 1;
while (remaining != 0) : (remaining -= 1) {
- const conn = try net_server.accept();
- defer conn.stream.close();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var server = http.Server.init(conn, &header_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
- try expectEqual(.ready, server.state);
+ try expectEqual(.ready, server.reader.state);
var request = try server.receiveHead();
try expectEqualStrings(request.head.target, "/foo");
- var send_buffer: [500]u8 = undefined;
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
+ var buf: [30]u8 = undefined;
+ var response = try request.respondStreaming(&buf, .{
.respond_options = .{
.transfer_encoding = .none,
},
});
- var total: usize = 0;
+ const w = &response.writer;
for (0..500) |i| {
- var buf: [30]u8 = undefined;
- const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i});
- try response.writeAll(line);
- total += line.len;
+ try w.print("{d}, ah ha ha!\n", .{i});
}
- try expectEqual(7390, total);
+ try w.flush();
try response.end();
- try expectEqual(.closing, server.state);
+ try expectEqual(.closing, server.reader.state);
}
}
});
@@ -284,7 +295,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded
var stream_reader = stream.reader(&tiny_buffer);
- const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192));
+ const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
defer gpa.free(response);
var expected_response = std.ArrayList(u8).init(gpa);
@@ -308,15 +319,20 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
test "receiving arbitrary http headers from the client" {
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var read_buffer: [666]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [666]u8 = undefined;
+ var send_buffer: [777]u8 = undefined;
var remaining: usize = 1;
while (remaining != 0) : (remaining -= 1) {
- const conn = try net_server.accept();
- defer conn.stream.close();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var server = http.Server.init(conn, &read_buffer);
- try expectEqual(.ready, server.state);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
+
+ try expectEqual(.ready, server.reader.state);
var request = try server.receiveHead();
try expectEqualStrings("/bar", request.head.target);
var it = request.iterateHeaders();
@@ -350,7 +366,7 @@ test "receiving arbitrary http headers from the client" {
var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded
var stream_reader = stream.reader(&tiny_buffer);
- const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192));
+ const response = try stream_reader.interface().allocRemaining(gpa, .unlimited);
defer gpa.free(response);
var expected_response = std.ArrayList(u8).init(gpa);
@@ -368,19 +384,21 @@ test "general client/server API coverage" {
return error.SkipZigTest;
}
- const global = struct {
- var handle_new_requests = true;
- };
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var client_header_buffer: [1024]u8 = undefined;
- outer: while (global.handle_new_requests) {
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [1024]u8 = undefined;
+ var send_buffer: [100]u8 = undefined;
+
+ outer: while (!test_server.shutting_down) {
var connection = try net_server.accept();
defer connection.stream.close();
- var http_server = http.Server.init(connection, &client_header_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface);
- while (http_server.state == .ready) {
+ while (http_server.reader.state == .ready) {
var request = http_server.receiveHead() catch |err| switch (err) {
error.HttpConnectionClosing => continue :outer,
else => |e| return e,
@@ -393,21 +411,19 @@ test "general client/server API coverage" {
fn handleRequest(request: *http.Server.Request, listen_port: u16) !void {
const log = std.log.scoped(.server);
+ const gpa = std.testing.allocator;
- log.info("{f} {s} {s}", .{
- request.head.method, @tagName(request.head.version), request.head.target,
- });
+ log.info("{t} {t} {s}", .{ request.head.method, request.head.version, request.head.target });
+ const target = try gpa.dupe(u8, request.head.target);
+ defer gpa.free(target);
- const gpa = std.testing.allocator;
- const body = try (try request.reader()).readAllAlloc(gpa, 8192);
+ const reader = (try request.readerExpectContinue(&.{}));
+ const body = try reader.allocRemaining(gpa, .unlimited);
defer gpa.free(body);
- var send_buffer: [100]u8 = undefined;
-
- if (mem.startsWith(u8, request.head.target, "/get")) {
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
- .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null)
+ if (mem.startsWith(u8, target, "/get")) {
+ var response = try request.respondStreaming(&.{}, .{
+ .content_length = if (mem.indexOf(u8, target, "?chunked") == null)
14
else
null,
@@ -417,27 +433,27 @@ test "general client/server API coverage" {
},
},
});
- const w = response.writer();
+ const w = &response.writer;
try w.writeAll("Hello, ");
try w.writeAll("World!\n");
try response.end();
// Writing again would cause an assertion failure.
- } else if (mem.startsWith(u8, request.head.target, "/large")) {
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
+ } else if (mem.startsWith(u8, target, "/large")) {
+ var response = try request.respondStreaming(&.{}, .{
.content_length = 14 * 1024 + 14 * 10,
});
try response.flush(); // Test an early flush to send the HTTP headers before the body.
- const w = response.writer();
+ const w = &response.writer;
var i: u32 = 0;
while (i < 5) : (i += 1) {
try w.writeAll("Hello, World!\n");
}
- try w.writeAll("Hello, World!\n" ** 1024);
+ var vec: [1][]const u8 = .{"Hello, World!\n"};
+ try w.writeSplatAll(&vec, 1024);
i = 0;
while (i < 5) : (i += 1) {
@@ -445,9 +461,8 @@ test "general client/server API coverage" {
}
try response.end();
- } else if (mem.eql(u8, request.head.target, "/redirect/1")) {
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
+ } else if (mem.eql(u8, target, "/redirect/1")) {
+ var response = try request.respondStreaming(&.{}, .{
.respond_options = .{
.status = .found,
.extra_headers = &.{
@@ -456,18 +471,18 @@ test "general client/server API coverage" {
},
});
- const w = response.writer();
+ const w = &response.writer;
try w.writeAll("Hello, ");
try w.writeAll("Redirected!\n");
try response.end();
- } else if (mem.eql(u8, request.head.target, "/redirect/2")) {
+ } else if (mem.eql(u8, target, "/redirect/2")) {
try request.respond("Hello, Redirected!\n", .{
.status = .found,
.extra_headers = &.{
.{ .name = "location", .value = "/redirect/1" },
},
});
- } else if (mem.eql(u8, request.head.target, "/redirect/3")) {
+ } else if (mem.eql(u8, target, "/redirect/3")) {
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{
listen_port,
});
@@ -479,23 +494,23 @@ test "general client/server API coverage" {
.{ .name = "location", .value = location },
},
});
- } else if (mem.eql(u8, request.head.target, "/redirect/4")) {
+ } else if (mem.eql(u8, target, "/redirect/4")) {
try request.respond("Hello, Redirected!\n", .{
.status = .found,
.extra_headers = &.{
.{ .name = "location", .value = "/redirect/3" },
},
});
- } else if (mem.eql(u8, request.head.target, "/redirect/5")) {
+ } else if (mem.eql(u8, target, "/redirect/5")) {
try request.respond("Hello, Redirected!\n", .{
.status = .found,
.extra_headers = &.{
.{ .name = "location", .value = "/%2525" },
},
});
- } else if (mem.eql(u8, request.head.target, "/%2525")) {
+ } else if (mem.eql(u8, target, "/%2525")) {
try request.respond("Encoded redirect successful!\n", .{});
- } else if (mem.eql(u8, request.head.target, "/redirect/invalid")) {
+ } else if (mem.eql(u8, target, "/redirect/invalid")) {
const invalid_port = try getUnusedTcpPort();
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port});
defer gpa.free(location);
@@ -506,7 +521,7 @@ test "general client/server API coverage" {
.{ .name = "location", .value = location },
},
});
- } else if (mem.eql(u8, request.head.target, "/empty")) {
+ } else if (mem.eql(u8, target, "/empty")) {
try request.respond("", .{
.extra_headers = &.{
.{ .name = "empty", .value = "" },
@@ -524,17 +539,13 @@ test "general client/server API coverage" {
return s.listen_address.in.getPort();
}
});
- defer {
- global.handle_new_requests = false;
- test_server.destroy();
- }
+ defer test_server.destroy();
const log = std.log.scoped(.client);
const gpa = std.testing.allocator;
var client: http.Client = .{ .allocator = gpa };
- errdefer client.deinit();
- // defer client.deinit(); handled below
+ defer client.deinit();
const port = test_server.port();
@@ -544,20 +555,19 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
+
+ try expectEqualStrings("text/plain", response.head.content_type.?);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
- try expectEqualStrings("text/plain", req.response.content_type.?);
}
// connection has been kept alive
@@ -569,16 +579,14 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192 * 1024);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len);
@@ -593,21 +601,20 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.HEAD, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.HEAD, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
+
+ try expectEqualStrings("text/plain", response.head.content_type.?);
+ try expectEqual(14, response.head.content_length.?);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("", body);
- try expectEqualStrings("text/plain", req.response.content_type.?);
- try expectEqual(14, req.response.content_length.?);
}
// connection has been kept alive
@@ -619,20 +626,19 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
+
+ try expectEqualStrings("text/plain", response.head.content_type.?);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
- try expectEqualStrings("text/plain", req.response.content_type.?);
}
// connection has been kept alive
@@ -644,21 +650,20 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.HEAD, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.HEAD, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ try expectEqualStrings("text/plain", response.head.content_type.?);
+ try expect(response.head.transfer_encoding == .chunked);
+
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("", body);
- try expectEqualStrings("text/plain", req.response.content_type.?);
- try expect(req.response.transfer_encoding == .chunked);
}
// connection has been kept alive
@@ -670,21 +675,21 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{
.keep_alive = false,
});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
+
+ try expectEqualStrings("text/plain", response.head.content_type.?);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
- try expectEqualStrings("text/plain", req.response.content_type.?);
}
// connection has been closed
@@ -696,32 +701,32 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{
.extra_headers = &.{
.{ .name = "empty", .value = "" },
},
});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- try std.testing.expectEqual(.ok, req.response.status);
-
- const body = try req.reader().readAllAlloc(gpa, 8192);
- defer gpa.free(body);
+ try std.testing.expectEqual(.ok, response.head.status);
- try expectEqualStrings("", body);
-
- var it = req.response.iterateHeaders();
+ var it = response.head.iterateHeaders();
{
const header = it.next().?;
try expect(!it.is_trailer);
try expectEqualStrings("content-length", header.name);
try expectEqualStrings("0", header.value);
}
+
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
+ defer gpa.free(body);
+
+ try expectEqualStrings("", body);
+
{
const header = it.next().?;
try expect(!it.is_trailer);
@@ -740,16 +745,14 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -764,16 +767,14 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -788,16 +789,14 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -812,17 +811,17 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- req.wait() catch |err| switch (err) {
+ try req.sendBodiless();
+ if (req.receiveHead(&redirect_buffer)) |_| {
+ return error.TestFailed;
+ } else |err| switch (err) {
error.TooManyHttpRedirects => {},
else => return err,
- };
+ }
}
{ // redirect to encoded url
@@ -831,16 +830,14 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Encoded redirect successful!\n", body);
@@ -855,14 +852,12 @@ test "general client/server API coverage" {
const uri = try std.Uri.parse(location);
log.info("{s}", .{location});
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- const result = req.wait();
+ try req.sendBodiless();
+ const result = req.receiveHead(&redirect_buffer);
// a proxy without an upstream is likely to return a 5xx status.
if (client.http_proxy == null) {
@@ -872,77 +867,40 @@ test "general client/server API coverage" {
// connection has been kept alive
try expect(client.http_proxy != null or client.connection_pool.free_len == 1);
-
- { // issue 16282 *** This test leaves the client in an invalid state, it must be last ***
- const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port});
- defer gpa.free(location);
- const uri = try std.Uri.parse(location);
-
- const total_connections = client.connection_pool.free_size + 64;
- var requests = try gpa.alloc(http.Client.Request, total_connections);
- defer gpa.free(requests);
-
- var header_bufs = std.ArrayList([]u8).init(gpa);
- defer header_bufs.deinit();
- defer for (header_bufs.items) |item| gpa.free(item);
-
- for (0..total_connections) |i| {
- const headers_buf = try gpa.alloc(u8, 1024);
- try header_bufs.append(headers_buf);
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = headers_buf,
- });
- req.response.parser.done = true;
- req.connection.?.closing = false;
- requests[i] = req;
- }
-
- for (0..total_connections) |i| {
- requests[i].deinit();
- }
-
- // free connections should be full now
- try expect(client.connection_pool.free_len == client.connection_pool.free_size);
- }
-
- client.deinit();
-
- {
- global.handle_new_requests = false;
-
- const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address);
- conn.close();
- }
}
test "Server streams both reading and writing" {
const test_server = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var header_buffer: [1024]u8 = undefined;
- const conn = try net_server.accept();
- defer conn.stream.close();
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [1024]u8 = undefined;
+ var send_buffer: [777]u8 = undefined;
- var server = http.Server.init(conn, &header_buffer);
- var request = try server.receiveHead();
- const reader = try request.reader();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var send_buffer: [777]u8 = undefined;
- var response = request.respondStreaming(.{
- .send_buffer = &send_buffer,
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
+ var request = try server.receiveHead();
+ var read_buffer: [100]u8 = undefined;
+ var br = try request.readerExpectContinue(&read_buffer);
+ var response = try request.respondStreaming(&.{}, .{
.respond_options = .{
.transfer_encoding = .none, // Causes keep_alive=false
},
});
- const writer = response.writer();
+ const w = &response.writer;
while (true) {
try response.flush();
- var buf: [100]u8 = undefined;
- const n = try reader.read(&buf);
- if (n == 0) break;
- const sub_buf = buf[0..n];
- for (sub_buf) |*b| b.* = std.ascii.toUpper(b.*);
- try writer.writeAll(sub_buf);
+ const buf = br.peekGreedy(1) catch |err| switch (err) {
+ error.EndOfStream => break,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ br.toss(buf.len);
+ for (buf) |*b| b.* = std.ascii.toUpper(b.*);
+ try w.writeAll(buf);
}
try response.end();
}
@@ -952,27 +910,24 @@ test "Server streams both reading and writing" {
var client: http.Client = .{ .allocator = std.testing.allocator };
defer client.deinit();
- var server_header_buffer: [555]u8 = undefined;
- var req = try client.open(.POST, .{
+ var redirect_buffer: [555]u8 = undefined;
+ var req = try client.request(.POST, .{
.scheme = "http",
.host = .{ .raw = "127.0.0.1" },
.port = test_server.port(),
.path = .{ .percent_encoded = "/" },
- }, .{
- .server_header_buffer = &server_header_buffer,
- });
+ }, .{});
defer req.deinit();
req.transfer_encoding = .chunked;
- try req.send();
- try req.wait();
-
- try req.writeAll("one ");
- try req.writeAll("fish");
+ var body_writer = try req.sendBody(&.{});
+ var response = try req.receiveHead(&redirect_buffer);
- try req.finish();
+ try body_writer.writer.writeAll("one ");
+ try body_writer.writer.writeAll("fish");
+ try body_writer.end();
- const body = try req.reader().readAllAlloc(std.testing.allocator, 8192);
+ const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .unlimited);
defer std.testing.allocator.free(body);
try expectEqualStrings("ONE FISH", body);
@@ -987,9 +942,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
defer gpa.free(location);
const uri = try std.Uri.parse(location);
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.POST, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.POST, uri, .{
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
},
@@ -998,14 +952,14 @@ fn echoTests(client: *http.Client, port: u16) !void {
req.transfer_encoding = .{ .content_length = 14 };
- try req.send();
- try req.writeAll("Hello, ");
- try req.writeAll("World!\n");
- try req.finish();
+ var body_writer = try req.sendBody(&.{});
+ try body_writer.writer.writeAll("Hello, ");
+ try body_writer.writer.writeAll("World!\n");
+ try body_writer.end();
- try req.wait();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -1021,9 +975,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
.{port},
));
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.POST, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.POST, uri, .{
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
},
@@ -1032,14 +985,14 @@ fn echoTests(client: *http.Client, port: u16) !void {
req.transfer_encoding = .chunked;
- try req.send();
- try req.writeAll("Hello, ");
- try req.writeAll("World!\n");
- try req.finish();
+ var body_writer = try req.sendBody(&.{});
+ try body_writer.writer.writeAll("Hello, ");
+ try body_writer.writer.writeAll("World!\n");
+ try body_writer.end();
- try req.wait();
+ var response = try req.receiveHead(&redirect_buffer);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -1053,8 +1006,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
defer gpa.free(location);
- var body = std.ArrayList(u8).init(gpa);
- defer body.deinit();
+ var body: std.ArrayListUnmanaged(u8) = .empty;
+ defer body.deinit(gpa);
const res = try client.fetch(.{
.location = .{ .url = location },
@@ -1063,7 +1016,7 @@ fn echoTests(client: *http.Client, port: u16) !void {
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
},
- .response_storage = .{ .dynamic = &body },
+ .response_storage = .{ .allocator = gpa, .list = &body },
});
try expectEqual(.ok, res.status);
try expectEqualStrings("Hello, World!\n", body.items);
@@ -1074,9 +1027,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
defer gpa.free(location);
const uri = try std.Uri.parse(location);
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.POST, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.POST, uri, .{
.extra_headers = &.{
.{ .name = "expect", .value = "100-continue" },
.{ .name = "content-type", .value = "text/plain" },
@@ -1086,15 +1038,15 @@ fn echoTests(client: *http.Client, port: u16) !void {
req.transfer_encoding = .chunked;
- try req.send();
- try req.writeAll("Hello, ");
- try req.writeAll("World!\n");
- try req.finish();
+ var body_writer = try req.sendBody(&.{});
+ try body_writer.writer.writeAll("Hello, ");
+ try body_writer.writer.writeAll("World!\n");
+ try body_writer.end();
- try req.wait();
- try expectEqual(.ok, req.response.status);
+ var response = try req.receiveHead(&redirect_buffer);
+ try expectEqual(.ok, response.head.status);
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("Hello, World!\n", body);
@@ -1105,9 +1057,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
defer gpa.free(location);
const uri = try std.Uri.parse(location);
- var server_header_buffer: [1024]u8 = undefined;
- var req = try client.open(.POST, uri, .{
- .server_header_buffer = &server_header_buffer,
+ var redirect_buffer: [1024]u8 = undefined;
+ var req = try client.request(.POST, uri, .{
.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
.{ .name = "expect", .value = "garbage" },
@@ -1117,23 +1068,24 @@ fn echoTests(client: *http.Client, port: u16) !void {
req.transfer_encoding = .chunked;
- try req.send();
- try req.wait();
- try expectEqual(.expectation_failed, req.response.status);
+ var body_writer = try req.sendBody(&.{});
+ try body_writer.flush();
+ var response = try req.receiveHead(&redirect_buffer);
+ try expectEqual(.expectation_failed, response.head.status);
+ _ = try response.reader(&.{}).discardRemaining();
}
-
- _ = try client.fetch(.{
- .location = .{
- .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}),
- },
- });
}
const TestServer = struct {
+ shutting_down: bool,
server_thread: std.Thread,
net_server: std.net.Server,
fn destroy(self: *@This()) void {
+ self.shutting_down = true;
+ const conn = std.net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure");
+ conn.close();
+
self.server_thread.join();
self.net_server.deinit();
std.testing.allocator.destroy(self);
@@ -1153,20 +1105,27 @@ fn createTestServer(S: type) !*TestServer {
const address = try std.net.Address.parseIp("127.0.0.1", 0);
const test_server = try std.testing.allocator.create(TestServer);
- test_server.net_server = try address.listen(.{ .reuse_address = true });
- test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server});
+ test_server.* = .{
+ .net_server = try address.listen(.{ .reuse_address = true }),
+ .server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}),
+ .shutting_down = false,
+ };
return test_server;
}
test "redirect to different connection" {
const test_server_new = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var header_buffer: [888]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [888]u8 = undefined;
+ var send_buffer: [777]u8 = undefined;
- const conn = try net_server.accept();
- defer conn.stream.close();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- var server = http.Server.init(conn, &header_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
var request = try server.receiveHead();
try expectEqualStrings(request.head.target, "/ok");
try request.respond("good job, you pass", .{});
@@ -1180,18 +1139,22 @@ test "redirect to different connection" {
global.other_port = test_server_new.port();
const test_server_orig = try createTestServer(struct {
- fn run(net_server: *std.net.Server) anyerror!void {
- var header_buffer: [999]u8 = undefined;
+ fn run(test_server: *TestServer) anyerror!void {
+ const net_server = &test_server.net_server;
+ var recv_buffer: [999]u8 = undefined;
var send_buffer: [100]u8 = undefined;
- const conn = try net_server.accept();
- defer conn.stream.close();
+ const connection = try net_server.accept();
+ defer connection.stream.close();
- const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{
+ var loc_buf: [50]u8 = undefined;
+ const new_loc = try std.fmt.bufPrint(&loc_buf, "http://127.0.0.1:{d}/ok", .{
global.other_port.?,
});
- var server = http.Server.init(conn, &header_buffer);
+ var connection_br = connection.stream.reader(&recv_buffer);
+ var connection_bw = connection.stream.writer(&send_buffer);
+ var server = http.Server.init(connection_br.interface(), &connection_bw.interface);
var request = try server.receiveHead();
try expectEqualStrings(request.head.target, "/help");
try request.respond("", .{
@@ -1216,16 +1179,15 @@ test "redirect to different connection" {
const uri = try std.Uri.parse(location);
{
- var server_header_buffer: [666]u8 = undefined;
- var req = try client.open(.GET, uri, .{
- .server_header_buffer = &server_header_buffer,
- });
+ var redirect_buffer: [666]u8 = undefined;
+ var req = try client.request(.GET, uri, .{});
defer req.deinit();
- try req.send();
- try req.wait();
+ try req.sendBodiless();
+ var response = try req.receiveHead(&redirect_buffer);
+ var reader = response.reader(&.{});
- const body = try req.reader().readAllAlloc(gpa, 8192);
+ const body = try reader.allocRemaining(gpa, .unlimited);
defer gpa.free(body);
try expectEqualStrings("good job, you pass", body);
diff --git a/lib/std/net.zig b/lib/std/net.zig
@@ -1944,7 +1944,7 @@ pub const Stream = struct {
pub const Error = ReadError;
pub fn getStream(r: *const Reader) Stream {
- return r.stream;
+ return r.net_stream;
}
pub fn getError(r: *const Reader) ?Error {
diff --git a/lib/std/std.zig b/lib/std/std.zig
@@ -57,7 +57,6 @@ pub const debug = @import("debug.zig");
pub const dwarf = @import("dwarf.zig");
pub const elf = @import("elf.zig");
pub const enums = @import("enums.zig");
-pub const fifo = @import("fifo.zig");
pub const fmt = @import("fmt.zig");
pub const fs = @import("fs.zig");
pub const gpu = @import("gpu.zig");
diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig
@@ -385,21 +385,23 @@ pub fn run(f: *Fetch) RunError!void {
var resource: Resource = .{ .dir = dir };
return f.runResource(path_or_url, &resource, null);
} else |dir_err| {
+ var server_header_buffer: [init_resource_buffer_size]u8 = undefined;
+
const file_err = if (dir_err == error.NotDir) e: {
if (fs.cwd().openFile(path_or_url, .{})) |file| {
- var resource: Resource = .{ .file = file };
+ var resource: Resource = .{ .file = file.reader(&server_header_buffer) };
return f.runResource(path_or_url, &resource, null);
} else |err| break :e err;
} else dir_err;
const uri = std.Uri.parse(path_or_url) catch |uri_err| {
return f.fail(0, try eb.printString(
- "'{s}' could not be recognized as a file path ({s}) or an URL ({s})",
- .{ path_or_url, @errorName(file_err), @errorName(uri_err) },
+ "'{s}' could not be recognized as a file path ({t}) or an URL ({t})",
+ .{ path_or_url, file_err, uri_err },
));
};
- var server_header_buffer: [header_buffer_size]u8 = undefined;
- var resource = try f.initResource(uri, &server_header_buffer);
+ var resource: Resource = undefined;
+ try f.initResource(uri, &resource, &server_header_buffer);
return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, null);
}
},
@@ -464,8 +466,9 @@ pub fn run(f: *Fetch) RunError!void {
f.location_tok,
try eb.printString("invalid URI: {s}", .{@errorName(err)}),
);
- var server_header_buffer: [header_buffer_size]u8 = undefined;
- var resource = try f.initResource(uri, &server_header_buffer);
+ var buffer: [init_resource_buffer_size]u8 = undefined;
+ var resource: Resource = undefined;
+ try f.initResource(uri, &resource, &buffer);
return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, remote.hash);
}
@@ -866,8 +869,8 @@ fn fail(f: *Fetch, msg_tok: std.zig.Ast.TokenIndex, msg_str: u32) RunError {
}
const Resource = union(enum) {
- file: fs.File,
- http_request: std.http.Client.Request,
+ file: fs.File.Reader,
+ http_request: HttpRequest,
git: Git,
dir: fs.Dir,
@@ -877,10 +880,16 @@ const Resource = union(enum) {
want_oid: git.Oid,
};
+ const HttpRequest = struct {
+ request: std.http.Client.Request,
+ response: std.http.Client.Response,
+ buffer: []u8,
+ };
+
fn deinit(resource: *Resource) void {
switch (resource.*) {
- .file => |*file| file.close(),
- .http_request => |*req| req.deinit(),
+ .file => |*file_reader| file_reader.file.close(),
+ .http_request => |*http_request| http_request.request.deinit(),
.git => |*git_resource| {
git_resource.fetch_stream.deinit();
git_resource.session.deinit();
@@ -890,21 +899,13 @@ const Resource = union(enum) {
resource.* = undefined;
}
- fn reader(resource: *Resource) std.io.AnyReader {
- return .{
- .context = resource,
- .readFn = read,
- };
- }
-
- fn read(context: *const anyopaque, buffer: []u8) anyerror!usize {
- const resource: *Resource = @ptrCast(@alignCast(@constCast(context)));
- switch (resource.*) {
- .file => |*f| return f.read(buffer),
- .http_request => |*r| return r.read(buffer),
- .git => |*g| return g.fetch_stream.read(buffer),
+ fn reader(resource: *Resource) *std.Io.Reader {
+ return switch (resource.*) {
+ .file => |*file_reader| return &file_reader.interface,
+ .http_request => |*http_request| return http_request.response.reader(http_request.buffer),
+ .git => |*g| return &g.fetch_stream.reader,
.dir => unreachable,
- }
+ };
}
};
@@ -967,20 +968,22 @@ const FileType = enum {
}
};
-const header_buffer_size = 16 * 1024;
+const init_resource_buffer_size = git.Packet.max_data_length;
-fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Resource {
+fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u8) RunError!void {
const gpa = f.arena.child_allocator;
const arena = f.arena.allocator();
const eb = &f.error_bundle;
if (ascii.eqlIgnoreCase(uri.scheme, "file")) {
const path = try uri.path.toRawMaybeAlloc(arena);
- return .{ .file = f.parent_package_root.openFile(path, .{}) catch |err| {
- return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {s}", .{
- f.parent_package_root, path, @errorName(err),
+ const file = f.parent_package_root.openFile(path, .{}) catch |err| {
+ return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {t}", .{
+ f.parent_package_root, path, err,
}));
- } };
+ };
+ resource.* = .{ .file = file.reader(reader_buffer) };
+ return;
}
const http_client = f.job_queue.http_client;
@@ -988,37 +991,35 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re
if (ascii.eqlIgnoreCase(uri.scheme, "http") or
ascii.eqlIgnoreCase(uri.scheme, "https"))
{
- var req = http_client.open(.GET, uri, .{
- .server_header_buffer = server_header_buffer,
- }) catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "unable to connect to server: {s}",
- .{@errorName(err)},
- ));
- };
- errdefer req.deinit(); // releases more than memory
-
- req.send() catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "HTTP request failed: {s}",
- .{@errorName(err)},
- ));
- };
- req.wait() catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "invalid HTTP response: {s}",
- .{@errorName(err)},
- ));
+ resource.* = .{ .http_request = .{
+ .request = http_client.request(.GET, uri, .{}) catch |err|
+ return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err})),
+ .response = undefined,
+ .buffer = reader_buffer,
+ } };
+ const request = &resource.http_request.request;
+ errdefer request.deinit();
+
+ request.sendBodiless() catch |err|
+ return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
+
+ var redirect_buffer: [1024]u8 = undefined;
+ const response = &resource.http_request.response;
+ response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) {
+ error.ReadFailed => {
+ return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{
+ request.connection.?.getReadError().?,
+ }));
+ },
+ else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})),
};
- if (req.response.status != .ok) {
- return f.fail(f.location_tok, try eb.printString(
- "bad HTTP response code: '{d} {s}'",
- .{ @intFromEnum(req.response.status), req.response.status.phrase() orelse "" },
- ));
- }
+ if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString(
+ "bad HTTP response code: '{d} {s}'",
+ .{ response.head.status, response.head.status.phrase() orelse "" },
+ ));
- return .{ .http_request = req };
+ return;
}
if (ascii.eqlIgnoreCase(uri.scheme, "git+http") or
@@ -1026,7 +1027,7 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re
{
var transport_uri = uri;
transport_uri.scheme = uri.scheme["git+".len..];
- var session = git.Session.init(gpa, http_client, transport_uri, server_header_buffer) catch |err| {
+ var session = git.Session.init(gpa, http_client, transport_uri, reader_buffer) catch |err| {
return f.fail(f.location_tok, try eb.printString(
"unable to discover remote git server capabilities: {s}",
.{@errorName(err)},
@@ -1042,16 +1043,12 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re
const want_ref_head = try std.fmt.allocPrint(arena, "refs/heads/{s}", .{want_ref});
const want_ref_tag = try std.fmt.allocPrint(arena, "refs/tags/{s}", .{want_ref});
- var ref_iterator = session.listRefs(.{
+ var ref_iterator: git.Session.RefIterator = undefined;
+ session.listRefs(&ref_iterator, .{
.ref_prefixes = &.{ want_ref, want_ref_head, want_ref_tag },
.include_peeled = true,
- .server_header_buffer = server_header_buffer,
- }) catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "unable to list refs: {s}",
- .{@errorName(err)},
- ));
- };
+ .buffer = reader_buffer,
+ }) catch |err| return f.fail(f.location_tok, try eb.printString("unable to list refs: {t}", .{err}));
defer ref_iterator.deinit();
while (ref_iterator.next() catch |err| {
return f.fail(f.location_tok, try eb.printString(
@@ -1089,25 +1086,21 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re
var want_oid_buf: [git.Oid.max_formatted_length]u8 = undefined;
_ = std.fmt.bufPrint(&want_oid_buf, "{f}", .{want_oid}) catch unreachable;
- var fetch_stream = session.fetch(&.{&want_oid_buf}, server_header_buffer) catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "unable to create fetch stream: {s}",
- .{@errorName(err)},
- ));
+ var fetch_stream: git.Session.FetchStream = undefined;
+ session.fetch(&fetch_stream, &.{&want_oid_buf}, reader_buffer) catch |err| {
+ return f.fail(f.location_tok, try eb.printString("unable to create fetch stream: {t}", .{err}));
};
errdefer fetch_stream.deinit();
- return .{ .git = .{
+ resource.* = .{ .git = .{
.session = session,
.fetch_stream = fetch_stream,
.want_oid = want_oid,
} };
+ return;
}
- return f.fail(f.location_tok, try eb.printString(
- "unsupported URL scheme: {s}",
- .{uri.scheme},
- ));
+ return f.fail(f.location_tok, try eb.printString("unsupported URL scheme: {s}", .{uri.scheme}));
}
fn unpackResource(
@@ -1121,9 +1114,11 @@ fn unpackResource(
.file => FileType.fromPath(uri_path) orelse
return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path})),
- .http_request => |req| ft: {
+ .http_request => |*http_request| ft: {
+ const head = &http_request.response.head;
+
// Content-Type takes first precedence.
- const content_type = req.response.content_type orelse
+ const content_type = head.content_type orelse
return f.fail(f.location_tok, try eb.addString("missing 'Content-Type' header"));
// Extract the MIME type, ignoring charset and boundary directives
@@ -1165,7 +1160,7 @@ fn unpackResource(
}
// Next, the filename from 'content-disposition: attachment' takes precedence.
- if (req.response.content_disposition) |cd_header| {
+ if (head.content_disposition) |cd_header| {
break :ft FileType.fromContentDisposition(cd_header) orelse {
return f.fail(f.location_tok, try eb.printString(
"unsupported Content-Disposition header value: '{s}' for Content-Type=application/octet-stream",
@@ -1176,10 +1171,7 @@ fn unpackResource(
// Finally, the path from the URI is used.
break :ft FileType.fromPath(uri_path) orelse {
- return f.fail(f.location_tok, try eb.printString(
- "unknown file type: '{s}'",
- .{uri_path},
- ));
+ return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path}));
};
},
@@ -1187,10 +1179,9 @@ fn unpackResource(
.dir => |dir| {
f.recursiveDirectoryCopy(dir, tmp_directory.handle) catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "unable to copy directory '{s}': {s}",
- .{ uri_path, @errorName(err) },
- ));
+ return f.fail(f.location_tok, try eb.printString("unable to copy directory '{s}': {t}", .{
+ uri_path, err,
+ }));
};
return .{};
},
@@ -1198,27 +1189,17 @@ fn unpackResource(
switch (file_type) {
.tar => {
- var adapter_buffer: [1024]u8 = undefined;
- var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
- return unpackTarball(f, tmp_directory.handle, &adapter.new_interface);
+ return unpackTarball(f, tmp_directory.handle, resource.reader());
},
.@"tar.gz" => {
- var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
- var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
var flate_buffer: [std.compress.flate.max_window_len]u8 = undefined;
- var decompress: std.compress.flate.Decompress = .init(&adapter.new_interface, .gzip, &flate_buffer);
+ var decompress: std.compress.flate.Decompress = .init(resource.reader(), .gzip, &flate_buffer);
return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
},
.@"tar.xz" => {
const gpa = f.arena.child_allocator;
- const reader = resource.reader();
- var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
- var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| {
- return f.fail(f.location_tok, try eb.printString(
- "unable to decompress tarball: {s}",
- .{@errorName(err)},
- ));
- };
+ var dcp = std.compress.xz.decompress(gpa, resource.reader().adaptToOldInterface()) catch |err|
+ return f.fail(f.location_tok, try eb.printString("unable to decompress tarball: {t}", .{err}));
defer dcp.deinit();
var adapter_buffer: [1024]u8 = undefined;
var adapter = dcp.reader().adaptToNewApi(&adapter_buffer);
@@ -1227,9 +1208,7 @@ fn unpackResource(
.@"tar.zst" => {
const window_size = std.compress.zstd.default_window_len;
const window_buffer = try f.arena.allocator().create([window_size]u8);
- var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
- var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
- var decompress: std.compress.zstd.Decompress = .init(&adapter.new_interface, window_buffer, .{
+ var decompress: std.compress.zstd.Decompress = .init(resource.reader(), window_buffer, .{
.verify_checksum = false,
});
return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
@@ -1237,12 +1216,15 @@ fn unpackResource(
.git_pack => return unpackGitPack(f, tmp_directory.handle, &resource.git) catch |err| switch (err) {
error.FetchFailed => return error.FetchFailed,
error.OutOfMemory => return error.OutOfMemory,
- else => |e| return f.fail(f.location_tok, try eb.printString(
- "unable to unpack git files: {s}",
- .{@errorName(e)},
+ else => |e| return f.fail(f.location_tok, try eb.printString("unable to unpack git files: {t}", .{e})),
+ },
+ .zip => return unzip(f, tmp_directory.handle, resource.reader()) catch |err| switch (err) {
+ error.ReadFailed => return f.fail(f.location_tok, try eb.printString(
+ "failed reading resource: {t}",
+ .{err},
)),
+ else => |e| return e,
},
- .zip => return try unzip(f, tmp_directory.handle, resource.reader()),
}
}
@@ -1277,99 +1259,69 @@ fn unpackTarball(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) RunError!Un
return res;
}
-fn unzip(f: *Fetch, out_dir: fs.Dir, reader: anytype) RunError!UnpackResult {
+fn unzip(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) error{ ReadFailed, OutOfMemory, FetchFailed }!UnpackResult {
// We write the entire contents to a file first because zip files
// must be processed back to front and they could be too large to
// load into memory.
const cache_root = f.job_queue.global_cache;
-
- // TODO: the downside of this solution is if we get a failure/crash/oom/power out
- // during this process, we leave behind a zip file that would be
- // difficult to know if/when it can be cleaned up.
- // Might be worth it to use a mechanism that enables other processes
- // to see if the owning process of a file is still alive (on linux this
- // can be done with file locks).
- // Coupled with this mechansism, we could also use slots (i.e. zig-cache/tmp/0,
- // zig-cache/tmp/1, etc) which would mean that subsequent runs would
- // automatically clean up old dead files.
- // This could all be done with a simple TmpFile abstraction.
const prefix = "tmp/";
const suffix = ".zip";
-
- const random_bytes_count = 20;
- const random_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count);
- var zip_path: [prefix.len + random_path_len + suffix.len]u8 = undefined;
- @memcpy(zip_path[0..prefix.len], prefix);
- @memcpy(zip_path[prefix.len + random_path_len ..], suffix);
- {
- var random_bytes: [random_bytes_count]u8 = undefined;
- std.crypto.random.bytes(&random_bytes);
- _ = std.fs.base64_encoder.encode(
- zip_path[prefix.len..][0..random_path_len],
- &random_bytes,
- );
- }
-
- defer cache_root.handle.deleteFile(&zip_path) catch {};
-
const eb = &f.error_bundle;
-
- {
- var zip_file = cache_root.handle.createFile(
- &zip_path,
- .{},
- ) catch |err| return f.fail(f.location_tok, try eb.printString(
- "failed to create tmp zip file: {s}",
- .{@errorName(err)},
- ));
- defer zip_file.close();
- var buf: [4096]u8 = undefined;
- while (true) {
- const len = reader.readAll(&buf) catch |err| return f.fail(f.location_tok, try eb.printString(
- "read zip stream failed: {s}",
- .{@errorName(err)},
- ));
- if (len == 0) break;
- zip_file.deprecatedWriter().writeAll(buf[0..len]) catch |err| return f.fail(f.location_tok, try eb.printString(
- "write temporary zip file failed: {s}",
- .{@errorName(err)},
- ));
- }
- }
+ const random_len = @sizeOf(u64) * 2;
+
+ var zip_path: [prefix.len + random_len + suffix.len]u8 = undefined;
+ zip_path[0..prefix.len].* = prefix.*;
+ zip_path[prefix.len + random_len ..].* = suffix.*;
+
+ var zip_file = while (true) {
+ const random_integer = std.crypto.random.int(u64);
+ zip_path[prefix.len..][0..random_len].* = std.fmt.hex(random_integer);
+
+ break cache_root.handle.createFile(&zip_path, .{
+ .exclusive = true,
+ .read = true,
+ }) catch |err| switch (err) {
+ error.PathAlreadyExists => continue,
+ else => |e| return f.fail(
+ f.location_tok,
+ try eb.printString("failed to create temporary zip file: {t}", .{e}),
+ ),
+ };
+ };
+ defer zip_file.close();
+ var zip_file_buffer: [4096]u8 = undefined;
+ var zip_file_reader = b: {
+ var zip_file_writer = zip_file.writer(&zip_file_buffer);
+
+ _ = reader.streamRemaining(&zip_file_writer.interface) catch |err| switch (err) {
+ error.ReadFailed => return error.ReadFailed,
+ error.WriteFailed => return f.fail(
+ f.location_tok,
+ try eb.printString("failed writing temporary zip file: {t}", .{err}),
+ ),
+ };
+ zip_file_writer.interface.flush() catch |err| return f.fail(
+ f.location_tok,
+ try eb.printString("failed writing temporary zip file: {t}", .{err}),
+ );
+ break :b zip_file_writer.moveToReader();
+ };
var diagnostics: std.zip.Diagnostics = .{ .allocator = f.arena.allocator() };
// no need to deinit since we are using an arena allocator
- {
- var zip_file = cache_root.handle.openFile(
- &zip_path,
- .{},
- ) catch |err| return f.fail(f.location_tok, try eb.printString(
- "failed to open temporary zip file: {s}",
- .{@errorName(err)},
- ));
- defer zip_file.close();
-
- var zip_file_buffer: [1024]u8 = undefined;
- var zip_file_reader = zip_file.reader(&zip_file_buffer);
-
- std.zip.extract(out_dir, &zip_file_reader, .{
- .allow_backslashes = true,
- .diagnostics = &diagnostics,
- }) catch |err| return f.fail(f.location_tok, try eb.printString(
- "zip extract failed: {s}",
- .{@errorName(err)},
- ));
- }
+ zip_file_reader.seekTo(0) catch |err|
+ return f.fail(f.location_tok, try eb.printString("failed to seek temporary zip file: {t}", .{err}));
+ std.zip.extract(out_dir, &zip_file_reader, .{
+ .allow_backslashes = true,
+ .diagnostics = &diagnostics,
+ }) catch |err| return f.fail(f.location_tok, try eb.printString("zip extract failed: {t}", .{err}));
- cache_root.handle.deleteFile(&zip_path) catch |err| return f.fail(f.location_tok, try eb.printString(
- "delete temporary zip failed: {s}",
- .{@errorName(err)},
- ));
+ cache_root.handle.deleteFile(&zip_path) catch |err|
+ return f.fail(f.location_tok, try eb.printString("delete temporary zip failed: {t}", .{err}));
- const res: UnpackResult = .{ .root_dir = diagnostics.root_dir };
- return res;
+ return .{ .root_dir = diagnostics.root_dir };
}
fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!UnpackResult {
@@ -1387,10 +1339,13 @@ fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!U
var pack_file = try pack_dir.createFile("pkg.pack", .{ .read = true });
defer pack_file.close();
var pack_file_buffer: [4096]u8 = undefined;
- var fifo = std.fifo.LinearFifo(u8, .{ .Slice = {} }).init(&pack_file_buffer);
- try fifo.pump(resource.fetch_stream.reader(), pack_file.deprecatedWriter());
-
- var pack_file_reader = pack_file.reader(&pack_file_buffer);
+ var pack_file_reader = b: {
+ var pack_file_writer = pack_file.writer(&pack_file_buffer);
+ const fetch_reader = &resource.fetch_stream.reader;
+ _ = try fetch_reader.streamRemaining(&pack_file_writer.interface);
+ try pack_file_writer.interface.flush();
+ break :b pack_file_writer.moveToReader();
+ };
var index_file = try pack_dir.createFile("pkg.idx", .{ .read = true });
defer index_file.close();
diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig
@@ -585,17 +585,17 @@ const ObjectCache = struct {
/// [protocol-common](https://git-scm.com/docs/protocol-common). The special
/// meanings of the delimiter and response-end packets are documented in
/// [protocol-v2](https://git-scm.com/docs/protocol-v2).
-const Packet = union(enum) {
+pub const Packet = union(enum) {
flush,
delimiter,
response_end,
data: []const u8,
- const max_data_length = 65516;
+ pub const max_data_length = 65516;
/// Reads a packet in pkt-line format.
- fn read(reader: anytype, buf: *[max_data_length]u8) !Packet {
- const length = std.fmt.parseUnsigned(u16, &try reader.readBytesNoEof(4), 16) catch return error.InvalidPacket;
+ fn read(reader: *std.Io.Reader) !Packet {
+ const length = std.fmt.parseUnsigned(u16, try reader.take(4), 16) catch return error.InvalidPacket;
switch (length) {
0 => return .flush,
1 => return .delimiter,
@@ -603,13 +603,11 @@ const Packet = union(enum) {
3 => return error.InvalidPacket,
else => if (length - 4 > max_data_length) return error.InvalidPacket,
}
- const data = buf[0 .. length - 4];
- try reader.readNoEof(data);
- return .{ .data = data };
+ return .{ .data = try reader.take(length - 4) };
}
/// Writes a packet in pkt-line format.
- fn write(packet: Packet, writer: anytype) !void {
+ fn write(packet: Packet, writer: *std.Io.Writer) !void {
switch (packet) {
.flush => try writer.writeAll("0000"),
.delimiter => try writer.writeAll("0001"),
@@ -657,8 +655,10 @@ pub const Session = struct {
allocator: Allocator,
transport: *std.http.Client,
uri: std.Uri,
- http_headers_buffer: []u8,
+ /// Asserted to be at least `Packet.max_data_length`
+ response_buffer: []u8,
) !Session {
+ assert(response_buffer.len >= Packet.max_data_length);
var session: Session = .{
.transport = transport,
.location = try .init(allocator, uri),
@@ -668,7 +668,8 @@ pub const Session = struct {
.allocator = allocator,
};
errdefer session.deinit();
- var capability_iterator = try session.getCapabilities(http_headers_buffer);
+ var capability_iterator: CapabilityIterator = undefined;
+ try session.getCapabilities(&capability_iterator, response_buffer);
defer capability_iterator.deinit();
while (try capability_iterator.next()) |capability| {
if (mem.eql(u8, capability.key, "agent")) {
@@ -743,7 +744,8 @@ pub const Session = struct {
///
/// The `session.location` is updated if the server returns a redirect, so
/// that subsequent session functions do not need to handle redirects.
- fn getCapabilities(session: *Session, http_headers_buffer: []u8) !CapabilityIterator {
+ fn getCapabilities(session: *Session, it: *CapabilityIterator, response_buffer: []u8) !void {
+ assert(response_buffer.len >= Packet.max_data_length);
var info_refs_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -757,19 +759,22 @@ pub const Session = struct {
info_refs_uri.fragment = null;
const max_redirects = 3;
- var request = try session.transport.open(.GET, info_refs_uri, .{
- .redirect_behavior = @enumFromInt(max_redirects),
- .server_header_buffer = http_headers_buffer,
- .extra_headers = &.{
- .{ .name = "Git-Protocol", .value = "version=2" },
- },
- });
- errdefer request.deinit();
- try request.send();
- try request.finish();
+ it.* = .{
+ .request = try session.transport.request(.GET, info_refs_uri, .{
+ .redirect_behavior = .init(max_redirects),
+ .extra_headers = &.{
+ .{ .name = "Git-Protocol", .value = "version=2" },
+ },
+ }),
+ .reader = undefined,
+ };
+ errdefer it.deinit();
+ const request = &it.request;
+ try request.sendBodiless();
- try request.wait();
- if (request.response.status != .ok) return error.ProtocolError;
+ var redirect_buffer: [1024]u8 = undefined;
+ var response = try request.receiveHead(&redirect_buffer);
+ if (response.head.status != .ok) return error.ProtocolError;
const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects;
if (any_redirects_occurred) {
const request_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -784,8 +789,7 @@ pub const Session = struct {
session.location = new_location;
}
- const reader = request.reader();
- var buf: [Packet.max_data_length]u8 = undefined;
+ it.reader = response.reader(response_buffer);
var state: enum { response_start, response_content } = .response_start;
while (true) {
// Some Git servers (at least GitHub) include an additional
@@ -795,15 +799,15 @@ pub const Session = struct {
// Thus, we need to skip any such useless additional responses
// before we get the one we're actually looking for. The responses
// will be delimited by flush packets.
- const packet = Packet.read(reader, &buf) catch |e| switch (e) {
+ const packet = Packet.read(it.reader) catch |err| switch (err) {
error.EndOfStream => return error.UnsupportedProtocol, // 'version 2' packet not found
- else => |other| return other,
+ else => |e| return e,
};
switch (packet) {
.flush => state = .response_start,
.data => |data| switch (state) {
.response_start => if (mem.eql(u8, Packet.normalizeText(data), "version 2")) {
- return .{ .request = request };
+ return;
} else {
state = .response_content;
},
@@ -816,7 +820,7 @@ pub const Session = struct {
const CapabilityIterator = struct {
request: std.http.Client.Request,
- buf: [Packet.max_data_length]u8 = undefined,
+ reader: *std.Io.Reader,
const Capability = struct {
key: []const u8,
@@ -830,13 +834,13 @@ pub const Session = struct {
}
};
- fn deinit(iterator: *CapabilityIterator) void {
- iterator.request.deinit();
- iterator.* = undefined;
+ fn deinit(it: *CapabilityIterator) void {
+ it.request.deinit();
+ it.* = undefined;
}
- fn next(iterator: *CapabilityIterator) !?Capability {
- switch (try Packet.read(iterator.request.reader(), &iterator.buf)) {
+ fn next(it: *CapabilityIterator) !?Capability {
+ switch (try Packet.read(it.reader)) {
.flush => return null,
.data => |data| return Capability.parse(Packet.normalizeText(data)),
else => return error.UnexpectedPacket,
@@ -854,11 +858,13 @@ pub const Session = struct {
include_symrefs: bool = false,
/// Whether to include the peeled object ID for returned tag refs.
include_peeled: bool = false,
- server_header_buffer: []u8,
+ /// Asserted to be at least `Packet.max_data_length`.
+ buffer: []u8,
};
/// Returns an iterator over refs known to the server.
- pub fn listRefs(session: Session, options: ListRefsOptions) !RefIterator {
+ pub fn listRefs(session: Session, it: *RefIterator, options: ListRefsOptions) !void {
+ assert(options.buffer.len >= Packet.max_data_length);
var upload_pack_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -871,59 +877,56 @@ pub const Session = struct {
upload_pack_uri.query = null;
upload_pack_uri.fragment = null;
- var body: std.ArrayListUnmanaged(u8) = .empty;
- defer body.deinit(session.allocator);
- const body_writer = body.writer(session.allocator);
- try Packet.write(.{ .data = "command=ls-refs\n" }, body_writer);
+ var body: std.Io.Writer = .fixed(options.buffer);
+ try Packet.write(.{ .data = "command=ls-refs\n" }, &body);
if (session.supports_agent) {
- try Packet.write(.{ .data = agent_capability }, body_writer);
+ try Packet.write(.{ .data = agent_capability }, &body);
}
{
- const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)});
+ const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={t}\n", .{
+ session.object_format,
+ });
defer session.allocator.free(object_format_packet);
- try Packet.write(.{ .data = object_format_packet }, body_writer);
+ try Packet.write(.{ .data = object_format_packet }, &body);
}
- try Packet.write(.delimiter, body_writer);
+ try Packet.write(.delimiter, &body);
for (options.ref_prefixes) |ref_prefix| {
const ref_prefix_packet = try std.fmt.allocPrint(session.allocator, "ref-prefix {s}\n", .{ref_prefix});
defer session.allocator.free(ref_prefix_packet);
- try Packet.write(.{ .data = ref_prefix_packet }, body_writer);
+ try Packet.write(.{ .data = ref_prefix_packet }, &body);
}
if (options.include_symrefs) {
- try Packet.write(.{ .data = "symrefs\n" }, body_writer);
+ try Packet.write(.{ .data = "symrefs\n" }, &body);
}
if (options.include_peeled) {
- try Packet.write(.{ .data = "peel\n" }, body_writer);
+ try Packet.write(.{ .data = "peel\n" }, &body);
}
- try Packet.write(.flush, body_writer);
-
- var request = try session.transport.open(.POST, upload_pack_uri, .{
- .redirect_behavior = .unhandled,
- .server_header_buffer = options.server_header_buffer,
- .extra_headers = &.{
- .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
- .{ .name = "Git-Protocol", .value = "version=2" },
- },
- });
- errdefer request.deinit();
- request.transfer_encoding = .{ .content_length = body.items.len };
- try request.send();
- try request.writeAll(body.items);
- try request.finish();
-
- try request.wait();
- if (request.response.status != .ok) return error.ProtocolError;
-
- return .{
+ try Packet.write(.flush, &body);
+
+ it.* = .{
+ .request = try session.transport.request(.POST, upload_pack_uri, .{
+ .redirect_behavior = .unhandled,
+ .extra_headers = &.{
+ .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
+ .{ .name = "Git-Protocol", .value = "version=2" },
+ },
+ }),
+ .reader = undefined,
.format = session.object_format,
- .request = request,
};
+ const request = &it.request;
+ errdefer request.deinit();
+ try request.sendBodyComplete(body.buffered());
+
+ var response = try request.receiveHead(options.buffer);
+ if (response.head.status != .ok) return error.ProtocolError;
+ it.reader = response.reader(options.buffer);
}
pub const RefIterator = struct {
format: Oid.Format,
request: std.http.Client.Request,
- buf: [Packet.max_data_length]u8 = undefined,
+ reader: *std.Io.Reader,
pub const Ref = struct {
oid: Oid,
@@ -937,13 +940,13 @@ pub const Session = struct {
iterator.* = undefined;
}
- pub fn next(iterator: *RefIterator) !?Ref {
- switch (try Packet.read(iterator.request.reader(), &iterator.buf)) {
+ pub fn next(it: *RefIterator) !?Ref {
+ switch (try Packet.read(it.reader)) {
.flush => return null,
.data => |data| {
const ref_data = Packet.normalizeText(data);
const oid_sep_pos = mem.indexOfScalar(u8, ref_data, ' ') orelse return error.InvalidRefPacket;
- const oid = Oid.parse(iterator.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket;
+ const oid = Oid.parse(it.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket;
const name_sep_pos = mem.indexOfScalarPos(u8, ref_data, oid_sep_pos + 1, ' ') orelse ref_data.len;
const name = ref_data[oid_sep_pos + 1 .. name_sep_pos];
@@ -957,7 +960,7 @@ pub const Session = struct {
if (mem.startsWith(u8, attribute, "symref-target:")) {
symref_target = attribute["symref-target:".len..];
} else if (mem.startsWith(u8, attribute, "peeled:")) {
- peeled = Oid.parse(iterator.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket;
+ peeled = Oid.parse(it.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket;
}
last_sep_pos = next_sep_pos;
}
@@ -973,9 +976,12 @@ pub const Session = struct {
/// performed if the server supports it.
pub fn fetch(
session: Session,
+ fs: *FetchStream,
wants: []const []const u8,
- http_headers_buffer: []u8,
- ) !FetchStream {
+ /// Asserted to be at least `Packet.max_data_length`.
+ response_buffer: []u8,
+ ) !void {
+ assert(response_buffer.len >= Packet.max_data_length);
var upload_pack_uri = session.location.uri;
{
const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -988,63 +994,71 @@ pub const Session = struct {
upload_pack_uri.query = null;
upload_pack_uri.fragment = null;
- var body: std.ArrayListUnmanaged(u8) = .empty;
- defer body.deinit(session.allocator);
- const body_writer = body.writer(session.allocator);
- try Packet.write(.{ .data = "command=fetch\n" }, body_writer);
+ var body: std.Io.Writer = .fixed(response_buffer);
+ try Packet.write(.{ .data = "command=fetch\n" }, &body);
if (session.supports_agent) {
- try Packet.write(.{ .data = agent_capability }, body_writer);
+ try Packet.write(.{ .data = agent_capability }, &body);
}
{
const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)});
defer session.allocator.free(object_format_packet);
- try Packet.write(.{ .data = object_format_packet }, body_writer);
+ try Packet.write(.{ .data = object_format_packet }, &body);
}
- try Packet.write(.delimiter, body_writer);
+ try Packet.write(.delimiter, &body);
// Our packfile parser supports the OFS_DELTA object type
- try Packet.write(.{ .data = "ofs-delta\n" }, body_writer);
+ try Packet.write(.{ .data = "ofs-delta\n" }, &body);
// We do not currently convey server progress information to the user
- try Packet.write(.{ .data = "no-progress\n" }, body_writer);
+ try Packet.write(.{ .data = "no-progress\n" }, &body);
if (session.supports_shallow) {
- try Packet.write(.{ .data = "deepen 1\n" }, body_writer);
+ try Packet.write(.{ .data = "deepen 1\n" }, &body);
}
for (wants) |want| {
var buf: [Packet.max_data_length]u8 = undefined;
const arg = std.fmt.bufPrint(&buf, "want {s}\n", .{want}) catch unreachable;
- try Packet.write(.{ .data = arg }, body_writer);
+ try Packet.write(.{ .data = arg }, &body);
}
- try Packet.write(.{ .data = "done\n" }, body_writer);
- try Packet.write(.flush, body_writer);
-
- var request = try session.transport.open(.POST, upload_pack_uri, .{
- .redirect_behavior = .not_allowed,
- .server_header_buffer = http_headers_buffer,
- .extra_headers = &.{
- .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
- .{ .name = "Git-Protocol", .value = "version=2" },
- },
- });
+ try Packet.write(.{ .data = "done\n" }, &body);
+ try Packet.write(.flush, &body);
+
+ fs.* = .{
+ .request = try session.transport.request(.POST, upload_pack_uri, .{
+ .redirect_behavior = .not_allowed,
+ .extra_headers = &.{
+ .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
+ .{ .name = "Git-Protocol", .value = "version=2" },
+ },
+ }),
+ .input = undefined,
+ .reader = undefined,
+ .remaining_len = undefined,
+ };
+ const request = &fs.request;
errdefer request.deinit();
- request.transfer_encoding = .{ .content_length = body.items.len };
- try request.send();
- try request.writeAll(body.items);
- try request.finish();
- try request.wait();
- if (request.response.status != .ok) return error.ProtocolError;
+ try request.sendBodyComplete(body.buffered());
+
+ var response = try request.receiveHead(&.{});
+ if (response.head.status != .ok) return error.ProtocolError;
- const reader = request.reader();
+ const reader = response.reader(response_buffer);
// We are not interested in any of the sections of the returned fetch
// data other than the packfile section, since we aren't doing anything
// complex like ref negotiation (this is a fresh clone).
var state: enum { section_start, section_content } = .section_start;
while (true) {
- var buf: [Packet.max_data_length]u8 = undefined;
- const packet = try Packet.read(reader, &buf);
+ const packet = try Packet.read(reader);
switch (state) {
.section_start => switch (packet) {
.data => |data| if (mem.eql(u8, Packet.normalizeText(data), "packfile")) {
- return .{ .request = request };
+ fs.input = reader;
+ fs.reader = .{
+ .buffer = &.{},
+ .vtable = &.{ .stream = FetchStream.stream },
+ .seek = 0,
+ .end = 0,
+ };
+ fs.remaining_len = 0;
+ return;
} else {
state = .section_content;
},
@@ -1061,20 +1075,23 @@ pub const Session = struct {
pub const FetchStream = struct {
request: std.http.Client.Request,
- buf: [Packet.max_data_length]u8 = undefined,
- pos: usize = 0,
- len: usize = 0,
+ input: *std.Io.Reader,
+ reader: std.Io.Reader,
+ err: ?Error = null,
+ remaining_len: usize,
- pub fn deinit(stream: *FetchStream) void {
- stream.request.deinit();
+ pub fn deinit(fs: *FetchStream) void {
+ fs.request.deinit();
}
- pub const ReadError = std.http.Client.Request.ReadError || error{
+ pub const Error = error{
InvalidPacket,
ProtocolError,
UnexpectedPacket,
+ WriteFailed,
+ ReadFailed,
+ EndOfStream,
};
- pub const Reader = std.io.GenericReader(*FetchStream, ReadError, read);
const StreamCode = enum(u8) {
pack_data = 1,
@@ -1083,33 +1100,41 @@ pub const Session = struct {
_,
};
- pub fn reader(stream: *FetchStream) Reader {
- return .{ .context = stream };
- }
-
- pub fn read(stream: *FetchStream, buf: []u8) !usize {
- if (stream.pos == stream.len) {
+ pub fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+ const fs: *FetchStream = @alignCast(@fieldParentPtr("reader", r));
+ const input = fs.input;
+ if (fs.remaining_len == 0) {
while (true) {
- switch (try Packet.read(stream.request.reader(), &stream.buf)) {
- .flush => return 0,
+ switch (Packet.read(input) catch |err| {
+ fs.err = err;
+ return error.ReadFailed;
+ }) {
+ .flush => return error.EndOfStream,
.data => |data| if (data.len > 1) switch (@as(StreamCode, @enumFromInt(data[0]))) {
.pack_data => {
- stream.pos = 1;
- stream.len = data.len;
+ input.toss(1);
+ fs.remaining_len = data.len;
break;
},
- .fatal_error => return error.ProtocolError,
+ .fatal_error => {
+ fs.err = error.ProtocolError;
+ return error.ReadFailed;
+ },
else => {},
},
- else => return error.UnexpectedPacket,
+ else => {
+ fs.err = error.UnexpectedPacket;
+ return error.ReadFailed;
+ },
}
}
}
-
- const size = @min(buf.len, stream.len - stream.pos);
- @memcpy(buf[0..size], stream.buf[stream.pos .. stream.pos + size]);
- stream.pos += size;
- return size;
+ const buf = limit.slice(try w.writableSliceGreedy(1));
+ const n = @min(buf.len, fs.remaining_len);
+ @memcpy(buf[0..n], input.buffered()[0..n]);
+ input.toss(n);
+ fs.remaining_len -= n;
+ return n;
}
};
};