commit c9ef277fa7e43f119a7f2896635b4fdf9c97edbe (tree)
parent 8bd734d60cd55d65ea52a051ccdc35939edeb99c
Author: Andrew Kelley <andrew@ziglang.org>
Date: Tue, 3 Jan 2023 02:43:50 -0500
Merge pull request #13980 from ziglang/std.net
networking: delete std.x; add std.crypto.tls and std.http.Client
Diffstat:
38 files changed, 3917 insertions(+), 3707 deletions(-)
diff --git a/lib/std/Url.zig b/lib/std/Url.zig
@@ -0,0 +1,98 @@
+scheme: []const u8,
+host: []const u8,
+path: []const u8,
+port: ?u16,
+
+/// TODO: redo this implementation according to RFC 1738. This code is only a
+/// placeholder for now.
+pub fn parse(s: []const u8) !Url {
+ var scheme_end: usize = 0;
+ var host_start: usize = 0;
+ var host_end: usize = 0;
+ var path_start: usize = 0;
+ var port_start: usize = 0;
+ var port_end: usize = 0;
+ var state: enum {
+ scheme,
+ scheme_slash1,
+ scheme_slash2,
+ host,
+ port,
+ path,
+ } = .scheme;
+
+ for (s) |b, i| switch (state) {
+ .scheme => switch (b) {
+ ':' => {
+ state = .scheme_slash1;
+ scheme_end = i;
+ },
+ else => {},
+ },
+ .scheme_slash1 => switch (b) {
+ '/' => {
+ state = .scheme_slash2;
+ },
+ else => return error.InvalidUrl,
+ },
+ .scheme_slash2 => switch (b) {
+ '/' => {
+ state = .host;
+ host_start = i + 1;
+ },
+ else => return error.InvalidUrl,
+ },
+ .host => switch (b) {
+ ':' => {
+ state = .port;
+ host_end = i;
+ port_start = i + 1;
+ },
+ '/' => {
+ state = .path;
+ host_end = i;
+ path_start = i;
+ },
+ else => {},
+ },
+ .port => switch (b) {
+ '/' => {
+ port_end = i;
+ state = .path;
+ path_start = i;
+ },
+ else => {},
+ },
+ .path => {},
+ };
+
+ const port_slice = s[port_start..port_end];
+ const port = if (port_slice.len == 0) null else try std.fmt.parseInt(u16, port_slice, 10);
+
+ return .{
+ .scheme = s[0..scheme_end],
+ .host = s[host_start..host_end],
+ .path = s[path_start..],
+ .port = port,
+ };
+}
+
+const Url = @This();
+const std = @import("std.zig");
+const testing = std.testing;
+
+test "basic" {
+ const parsed = try parse("https://ziglang.org/download");
+ try testing.expectEqualStrings("https", parsed.scheme);
+ try testing.expectEqualStrings("ziglang.org", parsed.host);
+ try testing.expectEqualStrings("/download", parsed.path);
+ try testing.expectEqual(@as(?u16, null), parsed.port);
+}
+
+test "with port" {
+ const parsed = try parse("http://example:1337/");
+ try testing.expectEqualStrings("http", parsed.scheme);
+ try testing.expectEqualStrings("example", parsed.host);
+ try testing.expectEqualStrings("/", parsed.path);
+ try testing.expectEqual(@as(?u16, 1337), parsed.port);
+}
diff --git a/lib/std/c.zig b/lib/std/c.zig
@@ -206,7 +206,7 @@ pub extern "c" fn sendto(
dest_addr: ?*const c.sockaddr,
addrlen: c.socklen_t,
) isize;
-pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const std.x.os.Socket.Message, flags: c_int) isize;
+pub extern "c" fn sendmsg(sockfd: c.fd_t, msg: *const c.msghdr_const, flags: u32) isize;
pub extern "c" fn recv(sockfd: c.fd_t, arg1: ?*anyopaque, arg2: usize, arg3: c_int) isize;
pub extern "c" fn recvfrom(
@@ -217,7 +217,7 @@ pub extern "c" fn recvfrom(
noalias src_addr: ?*c.sockaddr,
noalias addrlen: ?*c.socklen_t,
) isize;
-pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *std.x.os.Socket.Message, flags: c_int) isize;
+pub extern "c" fn recvmsg(sockfd: c.fd_t, msg: *c.msghdr, flags: u32) isize;
pub extern "c" fn kill(pid: c.pid_t, sig: c_int) c_int;
pub extern "c" fn getdirentries(fd: c.fd_t, buf_ptr: [*]u8, nbytes: usize, basep: *i64) isize;
diff --git a/lib/std/c/darwin.zig b/lib/std/c/darwin.zig
@@ -1007,7 +1007,16 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [126]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
family: sa_family_t = AF.INET,
diff --git a/lib/std/c/dragonfly.zig b/lib/std/c/dragonfly.zig
@@ -1,5 +1,6 @@
const builtin = @import("builtin");
const std = @import("../std.zig");
+const assert = std.debug.assert;
const maxInt = std.math.maxInt;
const iovec = std.os.iovec;
@@ -478,11 +479,20 @@ pub const CLOCK = struct {
pub const sockaddr = extern struct {
len: u8,
- family: u8,
+ family: sa_family_t,
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [126]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
diff --git a/lib/std/c/freebsd.zig b/lib/std/c/freebsd.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
+const assert = std.debug.assert;
const builtin = @import("builtin");
const maxInt = std.math.maxInt;
const iovec = std.os.iovec;
@@ -404,7 +405,16 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [126]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
diff --git a/lib/std/c/haiku.zig b/lib/std/c/haiku.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
+const assert = std.debug.assert;
const builtin = @import("builtin");
const maxInt = std.math.maxInt;
const iovec = std.os.iovec;
@@ -339,7 +340,16 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [126]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
diff --git a/lib/std/c/netbsd.zig b/lib/std/c/netbsd.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
+const assert = std.debug.assert;
const builtin = @import("builtin");
const maxInt = std.math.maxInt;
const iovec = std.os.iovec;
@@ -481,7 +482,16 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [126]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
diff --git a/lib/std/c/openbsd.zig b/lib/std/c/openbsd.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
+const assert = std.debug.assert;
const maxInt = std.math.maxInt;
const builtin = @import("builtin");
const iovec = std.os.iovec;
@@ -372,7 +373,16 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 256;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ len: u8 align(8),
+ family: sa_family_t,
+ padding: [254]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
len: u8 = @sizeOf(in),
diff --git a/lib/std/c/solaris.zig b/lib/std/c/solaris.zig
@@ -1,4 +1,5 @@
const std = @import("../std.zig");
+const assert = std.debug.assert;
const builtin = @import("builtin");
const maxInt = std.math.maxInt;
const iovec = std.os.iovec;
@@ -435,7 +436,15 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 256;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ family: sa_family_t align(8),
+ padding: [254]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
pub const in = extern struct {
family: sa_family_t = AF.INET,
diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig
@@ -176,6 +176,9 @@ const std = @import("std.zig");
pub const errors = @import("crypto/errors.zig");
+pub const tls = @import("crypto/tls.zig");
+pub const Certificate = @import("crypto/Certificate.zig");
+
test {
_ = aead.aegis.Aegis128L;
_ = aead.aegis.Aegis256;
@@ -264,6 +267,8 @@ test {
_ = utils;
_ = random;
_ = errors;
+ _ = tls;
+ _ = Certificate;
}
test "CSPRNG" {
diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig
@@ -0,0 +1,1115 @@
+buffer: []const u8,
+index: u32,
+
+pub const Bundle = @import("Certificate/Bundle.zig");
+
+pub const Algorithm = enum {
+ sha1WithRSAEncryption,
+ sha224WithRSAEncryption,
+ sha256WithRSAEncryption,
+ sha384WithRSAEncryption,
+ sha512WithRSAEncryption,
+ ecdsa_with_SHA224,
+ ecdsa_with_SHA256,
+ ecdsa_with_SHA384,
+ ecdsa_with_SHA512,
+
+ pub const map = std.ComptimeStringMap(Algorithm, .{
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 },
+ });
+
+ pub fn Hash(comptime algorithm: Algorithm) type {
+ return switch (algorithm) {
+ .sha1WithRSAEncryption => crypto.hash.Sha1,
+ .ecdsa_with_SHA224, .sha224WithRSAEncryption => crypto.hash.sha2.Sha224,
+ .ecdsa_with_SHA256, .sha256WithRSAEncryption => crypto.hash.sha2.Sha256,
+ .ecdsa_with_SHA384, .sha384WithRSAEncryption => crypto.hash.sha2.Sha384,
+ .ecdsa_with_SHA512, .sha512WithRSAEncryption => crypto.hash.sha2.Sha512,
+ };
+ }
+};
+
+pub const AlgorithmCategory = enum {
+ rsaEncryption,
+ X9_62_id_ecPublicKey,
+
+ pub const map = std.ComptimeStringMap(AlgorithmCategory, .{
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey },
+ });
+};
+
+pub const Attribute = enum {
+ commonName,
+ serialNumber,
+ countryName,
+ localityName,
+ stateOrProvinceName,
+ organizationName,
+ organizationalUnitName,
+ organizationIdentifier,
+ pkcs9_emailAddress,
+
+ pub const map = std.ComptimeStringMap(Attribute, .{
+ .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName },
+ .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber },
+ .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName },
+ .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName },
+ .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName },
+ .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName },
+ .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName },
+ .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress },
+ });
+};
+
+pub const NamedCurve = enum {
+ secp384r1,
+ X9_62_prime256v1,
+
+ pub const map = std.ComptimeStringMap(NamedCurve, .{
+ .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 },
+ .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 },
+ });
+};
+
+pub const ExtensionId = enum {
+ subject_key_identifier,
+ key_usage,
+ private_key_usage_period,
+ subject_alt_name,
+ issuer_alt_name,
+ basic_constraints,
+ crl_number,
+ certificate_policies,
+ authority_key_identifier,
+
+ pub const map = std.ComptimeStringMap(ExtensionId, .{
+ .{ &[_]u8{ 0x55, 0x1D, 0x0E }, .subject_key_identifier },
+ .{ &[_]u8{ 0x55, 0x1D, 0x0F }, .key_usage },
+ .{ &[_]u8{ 0x55, 0x1D, 0x10 }, .private_key_usage_period },
+ .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name },
+ .{ &[_]u8{ 0x55, 0x1D, 0x12 }, .issuer_alt_name },
+ .{ &[_]u8{ 0x55, 0x1D, 0x13 }, .basic_constraints },
+ .{ &[_]u8{ 0x55, 0x1D, 0x14 }, .crl_number },
+ .{ &[_]u8{ 0x55, 0x1D, 0x20 }, .certificate_policies },
+ .{ &[_]u8{ 0x55, 0x1D, 0x23 }, .authority_key_identifier },
+ });
+};
+
+pub const GeneralNameTag = enum(u5) {
+ otherName = 0,
+ rfc822Name = 1,
+ dNSName = 2,
+ x400Address = 3,
+ directoryName = 4,
+ ediPartyName = 5,
+ uniformResourceIdentifier = 6,
+ iPAddress = 7,
+ registeredID = 8,
+ _,
+};
+
+pub const Parsed = struct {
+ certificate: Certificate,
+ issuer_slice: Slice,
+ subject_slice: Slice,
+ common_name_slice: Slice,
+ signature_slice: Slice,
+ signature_algorithm: Algorithm,
+ pub_key_algo: PubKeyAlgo,
+ pub_key_slice: Slice,
+ message_slice: Slice,
+ subject_alt_name_slice: Slice,
+ validity: Validity,
+
+ pub const PubKeyAlgo = union(AlgorithmCategory) {
+ rsaEncryption: void,
+ X9_62_id_ecPublicKey: NamedCurve,
+ };
+
+ pub const Validity = struct {
+ not_before: u64,
+ not_after: u64,
+ };
+
+ pub const Slice = der.Element.Slice;
+
+ pub fn slice(p: Parsed, s: Slice) []const u8 {
+ return p.certificate.buffer[s.start..s.end];
+ }
+
+ pub fn issuer(p: Parsed) []const u8 {
+ return p.slice(p.issuer_slice);
+ }
+
+ pub fn subject(p: Parsed) []const u8 {
+ return p.slice(p.subject_slice);
+ }
+
+ pub fn commonName(p: Parsed) []const u8 {
+ return p.slice(p.common_name_slice);
+ }
+
+ pub fn signature(p: Parsed) []const u8 {
+ return p.slice(p.signature_slice);
+ }
+
+ pub fn pubKey(p: Parsed) []const u8 {
+ return p.slice(p.pub_key_slice);
+ }
+
+ pub fn pubKeySigAlgo(p: Parsed) []const u8 {
+ return p.slice(p.pub_key_signature_algorithm_slice);
+ }
+
+ pub fn message(p: Parsed) []const u8 {
+ return p.slice(p.message_slice);
+ }
+
+ pub fn subjectAltName(p: Parsed) []const u8 {
+ return p.slice(p.subject_alt_name_slice);
+ }
+
+ pub const VerifyError = error{
+ CertificateIssuerMismatch,
+ CertificateNotYetValid,
+ CertificateExpired,
+ CertificateSignatureAlgorithmUnsupported,
+ CertificateSignatureAlgorithmMismatch,
+ CertificateFieldHasInvalidLength,
+ CertificateFieldHasWrongDataType,
+ CertificatePublicKeyInvalid,
+ CertificateSignatureInvalidLength,
+ CertificateSignatureInvalid,
+ CertificateSignatureUnsupportedBitCount,
+ CertificateSignatureNamedCurveUnsupported,
+ };
+
+ /// This function verifies:
+ /// * That the subject's issuer is indeed the provided issuer.
+ /// * The time validity of the subject.
+ /// * The signature.
+ pub fn verify(parsed_subject: Parsed, parsed_issuer: Parsed, now_sec: i64) VerifyError!void {
+ // Check that the subject's issuer name matches the issuer's
+ // subject name.
+ if (!mem.eql(u8, parsed_subject.issuer(), parsed_issuer.subject())) {
+ return error.CertificateIssuerMismatch;
+ }
+
+ if (now_sec < parsed_subject.validity.not_before)
+ return error.CertificateNotYetValid;
+ if (now_sec > parsed_subject.validity.not_after)
+ return error.CertificateExpired;
+
+ switch (parsed_subject.signature_algorithm) {
+ inline .sha1WithRSAEncryption,
+ .sha224WithRSAEncryption,
+ .sha256WithRSAEncryption,
+ .sha384WithRSAEncryption,
+ .sha512WithRSAEncryption,
+ => |algorithm| return verifyRsa(
+ algorithm.Hash(),
+ parsed_subject.message(),
+ parsed_subject.signature(),
+ parsed_issuer.pub_key_algo,
+ parsed_issuer.pubKey(),
+ ),
+
+ inline .ecdsa_with_SHA224,
+ .ecdsa_with_SHA256,
+ .ecdsa_with_SHA384,
+ .ecdsa_with_SHA512,
+ => |algorithm| return verify_ecdsa(
+ algorithm.Hash(),
+ parsed_subject.message(),
+ parsed_subject.signature(),
+ parsed_issuer.pub_key_algo,
+ parsed_issuer.pubKey(),
+ ),
+ }
+ }
+
+ pub const VerifyHostNameError = error{
+ CertificateHostMismatch,
+ CertificateFieldHasInvalidLength,
+ };
+
+ pub fn verifyHostName(parsed_subject: Parsed, host_name: []const u8) VerifyHostNameError!void {
+ // If the Subject Alternative Names extension is present, this is
+ // what to check. Otherwise, only the common name is checked.
+ const subject_alt_name = parsed_subject.subjectAltName();
+ if (subject_alt_name.len == 0) {
+ if (checkHostName(host_name, parsed_subject.commonName())) {
+ return;
+ } else {
+ return error.CertificateHostMismatch;
+ }
+ }
+
+ const general_names = try der.Element.parse(subject_alt_name, 0);
+ var name_i = general_names.slice.start;
+ while (name_i < general_names.slice.end) {
+ const general_name = try der.Element.parse(subject_alt_name, name_i);
+ name_i = general_name.slice.end;
+ switch (@intToEnum(GeneralNameTag, @enumToInt(general_name.identifier.tag))) {
+ .dNSName => {
+ const dns_name = subject_alt_name[general_name.slice.start..general_name.slice.end];
+ if (checkHostName(host_name, dns_name)) return;
+ },
+ else => {},
+ }
+ }
+
+ return error.CertificateHostMismatch;
+ }
+
+ fn checkHostName(host_name: []const u8, dns_name: []const u8) bool {
+ if (mem.eql(u8, dns_name, host_name)) {
+ return true; // exact match
+ }
+
+ if (mem.startsWith(u8, dns_name, "*.")) {
+ // wildcard certificate, matches any subdomain
+ // TODO: I think wildcards are not supposed to match any prefix but
+ // only match exactly one subdomain.
+ if (mem.endsWith(u8, host_name, dns_name[1..])) {
+ // The host_name has a subdomain, but the important part matches.
+ return true;
+ }
+ if (mem.eql(u8, dns_name[2..], host_name)) {
+ // The host_name has no subdomain and matches exactly.
+ return true;
+ }
+ }
+
+ return false;
+ }
+};
+
+pub fn parse(cert: Certificate) !Parsed {
+ const cert_bytes = cert.buffer;
+ const certificate = try der.Element.parse(cert_bytes, cert.index);
+ const tbs_certificate = try der.Element.parse(cert_bytes, certificate.slice.start);
+ const version = try der.Element.parse(cert_bytes, tbs_certificate.slice.start);
+ try checkVersion(cert_bytes, version);
+ const serial_number = try der.Element.parse(cert_bytes, version.slice.end);
+ // RFC 5280, section 4.1.2.3:
+ // "This field MUST contain the same algorithm identifier as
+ // the signatureAlgorithm field in the sequence Certificate."
+ const tbs_signature = try der.Element.parse(cert_bytes, serial_number.slice.end);
+ const issuer = try der.Element.parse(cert_bytes, tbs_signature.slice.end);
+ const validity = try der.Element.parse(cert_bytes, issuer.slice.end);
+ const not_before = try der.Element.parse(cert_bytes, validity.slice.start);
+ const not_before_utc = try parseTime(cert, not_before);
+ const not_after = try der.Element.parse(cert_bytes, not_before.slice.end);
+ const not_after_utc = try parseTime(cert, not_after);
+ const subject = try der.Element.parse(cert_bytes, validity.slice.end);
+
+ const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end);
+ const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start);
+ const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start);
+ const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem);
+ var pub_key_algo: Parsed.PubKeyAlgo = undefined;
+ switch (pub_key_algo_tag) {
+ .rsaEncryption => {
+ pub_key_algo = .{ .rsaEncryption = {} };
+ },
+ .X9_62_id_ecPublicKey => {
+ // RFC 5480 Section 2.1.1.1 Named Curve
+ // ECParameters ::= CHOICE {
+ // namedCurve OBJECT IDENTIFIER
+ // -- implicitCurve NULL
+ // -- specifiedCurve SpecifiedECDomain
+ // }
+ const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end);
+ const named_curve = try parseNamedCurve(cert_bytes, params_elem);
+ pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve };
+ },
+ }
+ const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end);
+ const pub_key = try parseBitString(cert, pub_key_elem);
+
+ var common_name = der.Element.Slice.empty;
+ var name_i = subject.slice.start;
+ while (name_i < subject.slice.end) {
+ const rdn = try der.Element.parse(cert_bytes, name_i);
+ var rdn_i = rdn.slice.start;
+ while (rdn_i < rdn.slice.end) {
+ const atav = try der.Element.parse(cert_bytes, rdn_i);
+ var atav_i = atav.slice.start;
+ while (atav_i < atav.slice.end) {
+ const ty_elem = try der.Element.parse(cert_bytes, atav_i);
+ const ty = try parseAttribute(cert_bytes, ty_elem);
+ const val = try der.Element.parse(cert_bytes, ty_elem.slice.end);
+ switch (ty) {
+ .commonName => common_name = val.slice,
+ else => {},
+ }
+ atav_i = val.slice.end;
+ }
+ rdn_i = atav.slice.end;
+ }
+ name_i = rdn.slice.end;
+ }
+
+ const sig_algo = try der.Element.parse(cert_bytes, tbs_certificate.slice.end);
+ const algo_elem = try der.Element.parse(cert_bytes, sig_algo.slice.start);
+ const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem);
+ const sig_elem = try der.Element.parse(cert_bytes, sig_algo.slice.end);
+ const signature = try parseBitString(cert, sig_elem);
+
+ // Extensions
+ var subject_alt_name_slice = der.Element.Slice.empty;
+ ext: {
+ if (pub_key_info.slice.end >= tbs_certificate.slice.end)
+ break :ext;
+
+ const outer_extensions = try der.Element.parse(cert_bytes, pub_key_info.slice.end);
+ if (outer_extensions.identifier.tag != .bitstring)
+ break :ext;
+
+ const extensions = try der.Element.parse(cert_bytes, outer_extensions.slice.start);
+
+ var ext_i = extensions.slice.start;
+ while (ext_i < extensions.slice.end) {
+ const extension = try der.Element.parse(cert_bytes, ext_i);
+ ext_i = extension.slice.end;
+ const oid_elem = try der.Element.parse(cert_bytes, extension.slice.start);
+ const ext_id = parseExtensionId(cert_bytes, oid_elem) catch |err| switch (err) {
+ error.CertificateHasUnrecognizedObjectId => continue,
+ else => |e| return e,
+ };
+ const critical_elem = try der.Element.parse(cert_bytes, oid_elem.slice.end);
+ const ext_bytes_elem = if (critical_elem.identifier.tag != .boolean)
+ critical_elem
+ else
+ try der.Element.parse(cert_bytes, critical_elem.slice.end);
+ switch (ext_id) {
+ .subject_alt_name => subject_alt_name_slice = ext_bytes_elem.slice,
+ else => continue,
+ }
+ }
+ }
+
+ return .{
+ .certificate = cert,
+ .common_name_slice = common_name,
+ .issuer_slice = issuer.slice,
+ .subject_slice = subject.slice,
+ .signature_slice = signature,
+ .signature_algorithm = signature_algorithm,
+ .message_slice = .{ .start = certificate.slice.start, .end = tbs_certificate.slice.end },
+ .pub_key_algo = pub_key_algo,
+ .pub_key_slice = pub_key,
+ .validity = .{
+ .not_before = not_before_utc,
+ .not_after = not_after_utc,
+ },
+ .subject_alt_name_slice = subject_alt_name_slice,
+ };
+}
+
+pub fn verify(subject: Certificate, issuer: Certificate, now_sec: i64) !void {
+ const parsed_subject = try subject.parse();
+ const parsed_issuer = try issuer.parse();
+ return parsed_subject.verify(parsed_issuer, now_sec);
+}
+
+pub fn contents(cert: Certificate, elem: der.Element) []const u8 {
+ return cert.buffer[elem.slice.start..elem.slice.end];
+}
+
+pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice {
+ if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType;
+ if (cert.buffer[elem.slice.start] != 0) return error.CertificateHasInvalidBitString;
+ return .{ .start = elem.slice.start + 1, .end = elem.slice.end };
+}
+
+/// Returns number of seconds since epoch.
+pub fn parseTime(cert: Certificate, elem: der.Element) !u64 {
+ const bytes = cert.contents(elem);
+ switch (elem.identifier.tag) {
+ .utc_time => {
+ // Example: "YYMMDD000000Z"
+ if (bytes.len != 13)
+ return error.CertificateTimeInvalid;
+ if (bytes[12] != 'Z')
+ return error.CertificateTimeInvalid;
+
+ return Date.toSeconds(.{
+ .year = @as(u16, 2000) + try parseTimeDigits(bytes[0..2].*, 0, 99),
+ .month = try parseTimeDigits(bytes[2..4].*, 1, 12),
+ .day = try parseTimeDigits(bytes[4..6].*, 1, 31),
+ .hour = try parseTimeDigits(bytes[6..8].*, 0, 23),
+ .minute = try parseTimeDigits(bytes[8..10].*, 0, 59),
+ .second = try parseTimeDigits(bytes[10..12].*, 0, 59),
+ });
+ },
+ .generalized_time => {
+ // Examples:
+ // "19920521000000Z"
+ // "19920622123421Z"
+ // "19920722132100.3Z"
+ if (bytes.len < 15)
+ return error.CertificateTimeInvalid;
+ return Date.toSeconds(.{
+ .year = try parseYear4(bytes[0..4]),
+ .month = try parseTimeDigits(bytes[4..6].*, 1, 12),
+ .day = try parseTimeDigits(bytes[6..8].*, 1, 31),
+ .hour = try parseTimeDigits(bytes[8..10].*, 0, 23),
+ .minute = try parseTimeDigits(bytes[10..12].*, 0, 59),
+ .second = try parseTimeDigits(bytes[12..14].*, 0, 59),
+ });
+ },
+ else => return error.CertificateFieldHasWrongDataType,
+ }
+}
+
+const Date = struct {
+ /// example: 1999
+ year: u16,
+ /// range: 1 to 12
+ month: u8,
+ /// range: 1 to 31
+ day: u8,
+ /// range: 0 to 59
+ hour: u8,
+ /// range: 0 to 59
+ minute: u8,
+ /// range: 0 to 59
+ second: u8,
+
+ /// Convert to number of seconds since epoch.
+ pub fn toSeconds(date: Date) u64 {
+ var sec: u64 = 0;
+
+ {
+ var year: u16 = 1970;
+ while (year < date.year) : (year += 1) {
+ const days: u64 = std.time.epoch.getDaysInYear(year);
+ sec += days * std.time.epoch.secs_per_day;
+ }
+ }
+
+ {
+ const is_leap = std.time.epoch.isLeapYear(date.year);
+ var month: u4 = 1;
+ while (month < date.month) : (month += 1) {
+ const days: u64 = std.time.epoch.getDaysInMonth(
+ @intToEnum(std.time.epoch.YearLeapKind, @boolToInt(is_leap)),
+ @intToEnum(std.time.epoch.Month, month),
+ );
+ sec += days * std.time.epoch.secs_per_day;
+ }
+ }
+
+ sec += (date.day - 1) * @as(u64, std.time.epoch.secs_per_day);
+ sec += date.hour * @as(u64, 60 * 60);
+ sec += date.minute * @as(u64, 60);
+ sec += date.second;
+
+ return sec;
+ }
+};
+
+pub fn parseTimeDigits(nn: @Vector(2, u8), min: u8, max: u8) !u8 {
+ const zero: @Vector(2, u8) = .{ '0', '0' };
+ const mm: @Vector(2, u8) = .{ 10, 1 };
+ const result = @reduce(.Add, (nn -% zero) *% mm);
+ if (result < min) return error.CertificateTimeInvalid;
+ if (result > max) return error.CertificateTimeInvalid;
+ return result;
+}
+
+test parseTimeDigits {
+ const expectEqual = std.testing.expectEqual;
+ try expectEqual(@as(u8, 0), try parseTimeDigits("00".*, 0, 99));
+ try expectEqual(@as(u8, 99), try parseTimeDigits("99".*, 0, 99));
+ try expectEqual(@as(u8, 42), try parseTimeDigits("42".*, 0, 99));
+
+ const expectError = std.testing.expectError;
+ try expectError(error.CertificateTimeInvalid, parseTimeDigits("13".*, 1, 12));
+ try expectError(error.CertificateTimeInvalid, parseTimeDigits("00".*, 1, 12));
+}
+
+pub fn parseYear4(text: *const [4]u8) !u16 {
+ const nnnn: @Vector(4, u16) = .{ text[0], text[1], text[2], text[3] };
+ const zero: @Vector(4, u16) = .{ '0', '0', '0', '0' };
+ const mmmm: @Vector(4, u16) = .{ 1000, 100, 10, 1 };
+ const result = @reduce(.Add, (nnnn -% zero) *% mmmm);
+ if (result > 9999) return error.CertificateTimeInvalid;
+ return result;
+}
+
+test parseYear4 {
+ const expectEqual = std.testing.expectEqual;
+ try expectEqual(@as(u16, 0), try parseYear4("0000"));
+ try expectEqual(@as(u16, 9999), try parseYear4("9999"));
+ try expectEqual(@as(u16, 1988), try parseYear4("1988"));
+
+ const expectError = std.testing.expectError;
+ try expectError(error.CertificateTimeInvalid, parseYear4("999b"));
+ try expectError(error.CertificateTimeInvalid, parseYear4("crap"));
+}
+
+pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm {
+ return parseEnum(Algorithm, bytes, element);
+}
+
+pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory {
+ return parseEnum(AlgorithmCategory, bytes, element);
+}
+
+pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute {
+ return parseEnum(Attribute, bytes, element);
+}
+
+pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve {
+ return parseEnum(NamedCurve, bytes, element);
+}
+
+pub fn parseExtensionId(bytes: []const u8, element: der.Element) !ExtensionId {
+ return parseEnum(ExtensionId, bytes, element);
+}
+
+fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E {
+ if (element.identifier.tag != .object_identifier)
+ return error.CertificateFieldHasWrongDataType;
+ const oid_bytes = bytes[element.slice.start..element.slice.end];
+ return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId;
+}
+
+pub fn checkVersion(bytes: []const u8, version: der.Element) !void {
+ if (@bitCast(u8, version.identifier) != 0xa0 or
+ !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02"))
+ {
+ return error.UnsupportedCertificateVersion;
+ }
+}
+
+fn verifyRsa(
+ comptime Hash: type,
+ message: []const u8,
+ sig: []const u8,
+ pub_key_algo: Parsed.PubKeyAlgo,
+ pub_key: []const u8,
+) !void {
+ if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
+ const pk_components = try rsa.PublicKey.parseDer(pub_key);
+ const exponent = pk_components.exponent;
+ const modulus = pk_components.modulus;
+ if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
+ if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
+
+ const hash_der = switch (Hash) {
+ crypto.hash.Sha1 => [_]u8{
+ 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e,
+ 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14,
+ },
+ crypto.hash.sha2.Sha224 => [_]u8{
+ 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+ 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05,
+ 0x00, 0x04, 0x1c,
+ },
+ crypto.hash.sha2.Sha256 => [_]u8{
+ 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+ 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
+ 0x00, 0x04, 0x20,
+ },
+ crypto.hash.sha2.Sha384 => [_]u8{
+ 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+ 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
+ 0x00, 0x04, 0x30,
+ },
+ crypto.hash.sha2.Sha512 => [_]u8{
+ 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+ 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
+ 0x00, 0x04, 0x40,
+ },
+ else => @compileError("unreachable"),
+ };
+
+ var msg_hashed: [Hash.digest_length]u8 = undefined;
+ Hash.hash(message, &msg_hashed, .{});
+
+ var rsa_mem_buf: [512 * 64]u8 = undefined;
+ var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
+ const ally = fba.allocator();
+
+ switch (modulus.len) {
+ inline 128, 256, 512 => |modulus_len| {
+ const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3;
+ const em: [modulus_len]u8 =
+ [2]u8{ 0, 1 } ++
+ ([1]u8{0xff} ** ps_len) ++
+ [1]u8{0} ++
+ hash_der ++
+ msg_hashed;
+
+ const public_key = rsa.PublicKey.fromBytes(exponent, modulus, ally) catch |err| switch (err) {
+ error.OutOfMemory => unreachable, // rsa_mem_buf is big enough
+ };
+ const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, ally) catch |err| switch (err) {
+ error.OutOfMemory => unreachable, // rsa_mem_buf is big enough
+
+ error.MessageTooLong => unreachable,
+ error.NegativeIntoUnsigned => @panic("TODO make RSA not emit this error"),
+ error.TargetTooSmall => @panic("TODO make RSA not emit this error"),
+ error.BufferTooSmall => @panic("TODO make RSA not emit this error"),
+ };
+
+ if (!mem.eql(u8, &em, &em_dec)) {
+ return error.CertificateSignatureInvalid;
+ }
+ },
+ else => {
+ return error.CertificateSignatureUnsupportedBitCount;
+ },
+ }
+}
+
+fn verify_ecdsa(
+ comptime Hash: type,
+ message: []const u8,
+ encoded_sig: []const u8,
+ pub_key_algo: Parsed.PubKeyAlgo,
+ sec1_pub_key: []const u8,
+) !void {
+ const sig_named_curve = switch (pub_key_algo) {
+ .X9_62_id_ecPublicKey => |named_curve| named_curve,
+ else => return error.CertificateSignatureAlgorithmMismatch,
+ };
+
+ switch (sig_named_curve) {
+ .secp384r1 => {
+ const P = crypto.ecc.P384;
+ const Ecdsa = crypto.sign.ecdsa.Ecdsa(P, Hash);
+ const sig = Ecdsa.Signature.fromDer(encoded_sig) catch |err| switch (err) {
+ error.InvalidEncoding => return error.CertificateSignatureInvalid,
+ };
+ const pub_key = Ecdsa.PublicKey.fromSec1(sec1_pub_key) catch |err| switch (err) {
+ error.InvalidEncoding => return error.CertificateSignatureInvalid,
+ error.NonCanonical => return error.CertificateSignatureInvalid,
+ error.NotSquare => return error.CertificateSignatureInvalid,
+ };
+ sig.verify(message, pub_key) catch |err| switch (err) {
+ error.IdentityElement => return error.CertificateSignatureInvalid,
+ error.NonCanonical => return error.CertificateSignatureInvalid,
+ error.SignatureVerificationFailed => return error.CertificateSignatureInvalid,
+ };
+ },
+ .X9_62_prime256v1 => {
+ return error.CertificateSignatureNamedCurveUnsupported;
+ },
+ }
+}
+
+const std = @import("../std.zig");
+const crypto = std.crypto;
+const mem = std.mem;
+const Certificate = @This();
+
+pub const der = struct {
+ pub const Class = enum(u2) {
+ universal,
+ application,
+ context_specific,
+ private,
+ };
+
+ pub const PC = enum(u1) {
+ primitive,
+ constructed,
+ };
+
+ pub const Identifier = packed struct(u8) {
+ tag: Tag,
+ pc: PC,
+ class: Class,
+ };
+
+ pub const Tag = enum(u5) {
+ boolean = 1,
+ integer = 2,
+ bitstring = 3,
+ octetstring = 4,
+ null = 5,
+ object_identifier = 6,
+ sequence = 16,
+ sequence_of = 17,
+ utc_time = 23,
+ generalized_time = 24,
+ _,
+ };
+
+ pub const Element = struct {
+ identifier: Identifier,
+ slice: Slice,
+
+ pub const Slice = struct {
+ start: u32,
+ end: u32,
+
+ pub const empty: Slice = .{ .start = 0, .end = 0 };
+ };
+
+ pub const ParseError = error{CertificateFieldHasInvalidLength};
+
+ pub fn parse(bytes: []const u8, index: u32) ParseError!Element {
+ var i = index;
+ const identifier = @bitCast(Identifier, bytes[i]);
+ i += 1;
+ const size_byte = bytes[i];
+ i += 1;
+ if ((size_byte >> 7) == 0) {
+ return .{
+ .identifier = identifier,
+ .slice = .{
+ .start = i,
+ .end = i + size_byte,
+ },
+ };
+ }
+
+ const len_size = @truncate(u7, size_byte);
+ if (len_size > @sizeOf(u32)) {
+ return error.CertificateFieldHasInvalidLength;
+ }
+
+ const end_i = i + len_size;
+ var long_form_size: u32 = 0;
+ while (i < end_i) : (i += 1) {
+ long_form_size = (long_form_size << 8) | bytes[i];
+ }
+
+ return .{
+ .identifier = identifier,
+ .slice = .{
+ .start = i,
+ .end = i + long_form_size,
+ },
+ };
+ }
+ };
+};
+
+test {
+ _ = Bundle;
+}
+
+/// TODO: replace this with Frank's upcoming RSA implementation. the verify
+/// function won't have the possibility of failure - it will either identify a
+/// valid signature or an invalid signature.
+/// This code is borrowed from https://github.com/shiguredo/tls13-zig
+/// which is licensed under the Apache License Version 2.0, January 2004
+/// http://www.apache.org/licenses/
+/// The code has been modified.
+pub const rsa = struct {
+ const BigInt = std.math.big.int.Managed;
+
+ pub const PSSSignature = struct {
+ pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
+ var result = [1]u8{0} ** modulus_len;
+ std.mem.copy(u8, &result, msg);
+ return result;
+ }
+
+ pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void {
+ const mod_bits = try countBits(public_key.n.toConst(), allocator);
+ const em_dec = try encrypt(modulus_len, sig, public_key, allocator);
+
+ try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator);
+ }
+
+ fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void {
+ // TODO
+ // 1. If the length of M is greater than the input limitation for
+ // the hash function (2^61 - 1 octets for SHA-1), output
+ // "inconsistent" and stop.
+
+ // emLen = \ceil(emBits/8)
+ const emLen = ((emBit - 1) / 8) + 1;
+ std.debug.assert(emLen == em.len);
+
+ // 2. Let mHash = Hash(M), an octet string of length hLen.
+ var mHash: [Hash.digest_length]u8 = undefined;
+ Hash.hash(msg, &mHash, .{});
+
+ // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
+ if (emLen < Hash.digest_length + sLen + 2) {
+ return error.InvalidSignature;
+ }
+
+ // 4. If the rightmost octet of EM does not have hexadecimal value
+ // 0xbc, output "inconsistent" and stop.
+ if (em[em.len - 1] != 0xbc) {
+ return error.InvalidSignature;
+ }
+
+ // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
+ // and let H be the next hLen octets.
+ const maskedDB = em[0..(emLen - Hash.digest_length - 1)];
+ const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)];
+
+ // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in
+ // maskedDB are not all equal to zero, output "inconsistent" and
+ // stop.
+ const zero_bits = emLen * 8 - emBit;
+ var mask: u8 = maskedDB[0];
+ var i: usize = 0;
+ while (i < 8 - zero_bits) : (i += 1) {
+ mask = mask >> 1;
+ }
+ if (mask != 0) {
+ return error.InvalidSignature;
+ }
+
+ // 7. Let dbMask = MGF(H, emLen - hLen - 1).
+ const mgf_len = emLen - Hash.digest_length - 1;
+ var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length);
+ defer allocator.free(mgf_out);
+ var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator);
+
+ // 8. Let DB = maskedDB \xor dbMask.
+ i = 0;
+ while (i < dbMask.len) : (i += 1) {
+ dbMask[i] = maskedDB[i] ^ dbMask[i];
+ }
+
+ // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet
+ // in DB to zero.
+ i = 0;
+ mask = 0;
+ while (i < 8 - zero_bits) : (i += 1) {
+ mask = mask << 1;
+ mask += 1;
+ }
+ dbMask[0] = dbMask[0] & mask;
+
+ // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not
+ // zero or if the octet at position emLen - hLen - sLen - 1 (the
+ // leftmost position is "position 1") does not have hexadecimal
+ // value 0x01, output "inconsistent" and stop.
+ if (dbMask[mgf_len - sLen - 2] != 0x00) {
+ return error.InvalidSignature;
+ }
+
+ if (dbMask[mgf_len - sLen - 1] != 0x01) {
+ return error.InvalidSignature;
+ }
+
+ // 11. Let salt be the last sLen octets of DB.
+ const salt = dbMask[(mgf_len - sLen)..];
+
+ // 12. Let
+ // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
+ // M' is an octet string of length 8 + hLen + sLen with eight
+ // initial zero octets.
+ var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen);
+ defer allocator.free(m_p);
+ std.mem.copy(u8, m_p, &([_]u8{0} ** 8));
+ std.mem.copy(u8, m_p[8..], &mHash);
+ std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt);
+
+ // 13. Let H' = Hash(M'), an octet string of length hLen.
+ var h_p: [Hash.digest_length]u8 = undefined;
+ Hash.hash(m_p, &h_p, .{});
+
+ // 14. If H = H', output "consistent". Otherwise, output
+ // "inconsistent".
+ if (!std.mem.eql(u8, h, &h_p)) {
+ return error.InvalidSignature;
+ }
+ }
+
+ fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 {
+ var counter: usize = 0;
+ var idx: usize = 0;
+ var c: [4]u8 = undefined;
+
+ var hash = try allocator.alloc(u8, seed.len + c.len);
+ defer allocator.free(hash);
+ std.mem.copy(u8, hash, seed);
+ var hashed: [Hash.digest_length]u8 = undefined;
+
+ while (idx < len) {
+ c[0] = @intCast(u8, (counter >> 24) & 0xFF);
+ c[1] = @intCast(u8, (counter >> 16) & 0xFF);
+ c[2] = @intCast(u8, (counter >> 8) & 0xFF);
+ c[3] = @intCast(u8, counter & 0xFF);
+
+ std.mem.copy(u8, hash[seed.len..], &c);
+ Hash.hash(hash, &hashed, .{});
+
+ std.mem.copy(u8, out[idx..], &hashed);
+ idx += hashed.len;
+
+ counter += 1;
+ }
+
+ return out[0..len];
+ }
+ };
+
+ pub const PublicKey = struct {
+ n: BigInt,
+ e: BigInt,
+
+ pub fn deinit(self: *PublicKey) void {
+ self.n.deinit();
+ self.e.deinit();
+ }
+
+ pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey {
+ var _n = try BigInt.init(allocator);
+ errdefer _n.deinit();
+ try setBytes(&_n, modulus_bytes, allocator);
+
+ var _e = try BigInt.init(allocator);
+ errdefer _e.deinit();
+ try setBytes(&_e, pub_bytes, allocator);
+
+ return .{
+ .n = _n,
+ .e = _e,
+ };
+ }
+
+ pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } {
+ const pub_key_seq = try der.Element.parse(pub_key, 0);
+ if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
+ const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
+ if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+ const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
+ if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+ // Skip over meaningless zeroes in the modulus.
+ const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
+ const modulus_offset = for (modulus_raw) |byte, i| {
+ if (byte != 0) break i;
+ } else modulus_raw.len;
+ return .{
+ .modulus = modulus_raw[modulus_offset..],
+ .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
+ };
+ }
+ };
+
+ fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
+ var m = try BigInt.init(allocator);
+ defer m.deinit();
+
+ try setBytes(&m, &msg, allocator);
+
+ if (m.order(public_key.n) != .lt) {
+ return error.MessageTooLong;
+ }
+
+ var e = try BigInt.init(allocator);
+ defer e.deinit();
+
+ try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator);
+
+ var res: [modulus_len]u8 = undefined;
+
+ try toBytes(&res, &e, allocator);
+
+ return res;
+ }
+
+ fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void {
+ try r.set(0);
+ var tmp = try BigInt.init(allcator);
+ defer tmp.deinit();
+ for (bytes) |b| {
+ try r.shiftLeft(r, 8);
+ try tmp.set(b);
+ try r.add(r, &tmp);
+ }
+ }
+
+ fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
+ var bin_raw: [512]u8 = undefined;
+ try toBytes(&bin_raw, x, allocator);
+
+ var i: usize = 0;
+ while (bin_raw[i] == 0x00) : (i += 1) {}
+ const bin = bin_raw[i..];
+
+ try r.set(1);
+ var r1 = try BigInt.init(allocator);
+ defer r1.deinit();
+ try BigInt.copy(&r1, a.toConst());
+ i = 0;
+ while (i < bin.len * 8) : (i += 1) {
+ if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) {
+ try BigInt.mul(&r1, r, &r1);
+ try mod(&r1, &r1, n, allocator);
+ try BigInt.sqr(r, r);
+ try mod(r, r, n, allocator);
+ } else {
+ try BigInt.mul(r, r, &r1);
+ try mod(r, r, n, allocator);
+ try BigInt.sqr(&r1, &r1);
+ try mod(&r1, &r1, n, allocator);
+ }
+ }
+ }
+
+ fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void {
+ const Error = error{
+ BufferTooSmall,
+ };
+
+ var mask = try BigInt.initSet(allocator, 0xFF);
+ defer mask.deinit();
+ var tmp = try BigInt.init(allocator);
+ defer tmp.deinit();
+
+ var a_copy = try BigInt.init(allocator);
+ defer a_copy.deinit();
+ try a_copy.copy(a.toConst());
+
+ // Encoding into big-endian bytes
+ var i: usize = 0;
+ while (i < out.len) : (i += 1) {
+ try tmp.bitAnd(&a_copy, &mask);
+ const b = try tmp.to(u8);
+ out[out.len - i - 1] = b;
+ try a_copy.shiftRight(&a_copy, 8);
+ }
+
+ if (!a_copy.eqZero()) {
+ return Error.BufferTooSmall;
+ }
+ }
+
+ fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
+ var q = try BigInt.init(allocator);
+ defer q.deinit();
+
+ try BigInt.divFloor(&q, rem, a, n);
+ }
+
+ fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize {
+ var i: usize = 0;
+ var a_copy = try BigInt.init(allocator);
+ defer a_copy.deinit();
+ try a_copy.copy(a);
+
+ while (!a_copy.eqZero()) {
+ try a_copy.shiftRight(&a_copy, 1);
+ i += 1;
+ }
+
+ return i;
+ }
+};
diff --git a/lib/std/crypto/Certificate/Bundle.zig b/lib/std/crypto/Certificate/Bundle.zig
@@ -0,0 +1,189 @@
+//! A set of certificates. Typically pre-installed on every operating system,
+//! these are "Certificate Authorities" used to validate SSL certificates.
+//! This data structure stores certificates in DER-encoded form, all of them
+//! concatenated together in the `bytes` array. The `map` field contains an
+//! index from the DER-encoded subject name to the index of the containing
+//! certificate within `bytes`.
+
+/// The key is the contents slice of the subject.
+map: std.HashMapUnmanaged(der.Element.Slice, u32, MapContext, std.hash_map.default_max_load_percentage) = .{},
+bytes: std.ArrayListUnmanaged(u8) = .{},
+
+pub const VerifyError = Certificate.Parsed.VerifyError || error{
+ CertificateIssuerNotFound,
+};
+
+pub fn verify(cb: Bundle, subject: Certificate.Parsed, now_sec: i64) VerifyError!void {
+ const bytes_index = cb.find(subject.issuer()) orelse return error.CertificateIssuerNotFound;
+ const issuer_cert: Certificate = .{
+ .buffer = cb.bytes.items,
+ .index = bytes_index,
+ };
+ // Every certificate in the bundle is pre-parsed before adding it, ensuring
+ // that parsing will succeed here.
+ const issuer = issuer_cert.parse() catch unreachable;
+ try subject.verify(issuer, now_sec);
+}
+
+/// The returned bytes become invalid after calling any of the rescan functions
+/// or add functions.
+pub fn find(cb: Bundle, subject_name: []const u8) ?u32 {
+ const Adapter = struct {
+ cb: Bundle,
+
+ pub fn hash(ctx: @This(), k: []const u8) u64 {
+ _ = ctx;
+ return std.hash_map.hashString(k);
+ }
+
+ pub fn eql(ctx: @This(), a: []const u8, b_key: der.Element.Slice) bool {
+ const b = ctx.cb.bytes.items[b_key.start..b_key.end];
+ return mem.eql(u8, a, b);
+ }
+ };
+ return cb.map.getAdapted(subject_name, Adapter{ .cb = cb });
+}
+
+pub fn deinit(cb: *Bundle, gpa: Allocator) void {
+ cb.map.deinit(gpa);
+ cb.bytes.deinit(gpa);
+ cb.* = undefined;
+}
+
+/// Clears the set of certificates and then scans the host operating system
+/// file system standard locations for certificates.
+/// For operating systems that do not have standard CA installations to be
+/// found, this function clears the set of certificates.
+pub fn rescan(cb: *Bundle, gpa: Allocator) !void {
+ switch (builtin.os.tag) {
+ .linux => return rescanLinux(cb, gpa),
+ .windows => {
+ // TODO
+ },
+ .macos => {
+ // TODO
+ },
+ else => {},
+ }
+}
+
+pub fn rescanLinux(cb: *Bundle, gpa: Allocator) !void {
+ var dir = fs.openIterableDirAbsolute("/etc/ssl/certs", .{}) catch |err| switch (err) {
+ error.FileNotFound => return,
+ else => |e| return e,
+ };
+ defer dir.close();
+
+ cb.bytes.clearRetainingCapacity();
+ cb.map.clearRetainingCapacity();
+
+ var it = dir.iterate();
+ while (try it.next()) |entry| {
+ switch (entry.kind) {
+ .File, .SymLink => {},
+ else => continue,
+ }
+
+ try addCertsFromFile(cb, gpa, dir.dir, entry.name);
+ }
+
+ cb.bytes.shrinkAndFree(gpa, cb.bytes.items.len);
+}
+
+pub fn addCertsFromFile(
+ cb: *Bundle,
+ gpa: Allocator,
+ dir: fs.Dir,
+ sub_file_path: []const u8,
+) !void {
+ var file = try dir.openFile(sub_file_path, .{});
+ defer file.close();
+
+ const size = try file.getEndPos();
+
+ // We borrow `bytes` as a temporary buffer for the base64-encoded data.
+ // This is possible by computing the decoded length and reserving the space
+ // for the decoded bytes first.
+ const decoded_size_upper_bound = size / 4 * 3;
+ const needed_capacity = std.math.cast(u32, decoded_size_upper_bound + size) orelse
+ return error.CertificateAuthorityBundleTooBig;
+ try cb.bytes.ensureUnusedCapacity(gpa, needed_capacity);
+ const end_reserved = @intCast(u32, cb.bytes.items.len + decoded_size_upper_bound);
+ const buffer = cb.bytes.allocatedSlice()[end_reserved..];
+ const end_index = try file.readAll(buffer);
+ const encoded_bytes = buffer[0..end_index];
+
+ const begin_marker = "-----BEGIN CERTIFICATE-----";
+ const end_marker = "-----END CERTIFICATE-----";
+
+ const now_sec = std.time.timestamp();
+
+ var start_index: usize = 0;
+ while (mem.indexOfPos(u8, encoded_bytes, start_index, begin_marker)) |begin_marker_start| {
+ const cert_start = begin_marker_start + begin_marker.len;
+ const cert_end = mem.indexOfPos(u8, encoded_bytes, cert_start, end_marker) orelse
+ return error.MissingEndCertificateMarker;
+ start_index = cert_end + end_marker.len;
+ const encoded_cert = mem.trim(u8, encoded_bytes[cert_start..cert_end], " \t\r\n");
+ const decoded_start = @intCast(u32, cb.bytes.items.len);
+ const dest_buf = cb.bytes.allocatedSlice()[decoded_start..];
+ cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert);
+ // Even though we could only partially parse the certificate to find
+ // the subject name, we pre-parse all of them to make sure and only
+ // include in the bundle ones that we know will parse. This way we can
+ // use `catch unreachable` later.
+ const parsed_cert = try Certificate.parse(.{
+ .buffer = cb.bytes.items,
+ .index = decoded_start,
+ });
+ if (now_sec > parsed_cert.validity.not_after) {
+ // Ignore expired cert.
+ cb.bytes.items.len = decoded_start;
+ continue;
+ }
+ const gop = try cb.map.getOrPutContext(gpa, parsed_cert.subject_slice, .{ .cb = cb });
+ if (gop.found_existing) {
+ cb.bytes.items.len = decoded_start;
+ } else {
+ gop.value_ptr.* = decoded_start;
+ }
+ }
+}
+
+const builtin = @import("builtin");
+const std = @import("../../std.zig");
+const fs = std.fs;
+const mem = std.mem;
+const crypto = std.crypto;
+const Allocator = std.mem.Allocator;
+const Certificate = std.crypto.Certificate;
+const der = Certificate.der;
+const Bundle = @This();
+
+const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n");
+
+const MapContext = struct {
+ cb: *const Bundle,
+
+ pub fn hash(ctx: MapContext, k: der.Element.Slice) u64 {
+ return std.hash_map.hashString(ctx.cb.bytes.items[k.start..k.end]);
+ }
+
+ pub fn eql(ctx: MapContext, a: der.Element.Slice, b: der.Element.Slice) bool {
+ const bytes = ctx.cb.bytes.items;
+ return mem.eql(
+ u8,
+ bytes[a.start..a.end],
+ bytes[b.start..b.end],
+ );
+ }
+};
+
+test "scan for OS-provided certificates" {
+ if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
+ var bundle: Bundle = .{};
+ defer bundle.deinit(std.testing.allocator);
+
+ try bundle.rescan(std.testing.allocator);
+}
diff --git a/lib/std/crypto/aegis.zig b/lib/std/crypto/aegis.zig
@@ -174,7 +174,7 @@ pub const Aegis128L = struct {
acc |= (computed_tag[j] ^ tag[j]);
}
if (acc != 0) {
- mem.set(u8, m, 0xaa);
+ @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed;
}
}
@@ -343,7 +343,7 @@ pub const Aegis256 = struct {
acc |= (computed_tag[j] ^ tag[j]);
}
if (acc != 0) {
- mem.set(u8, m, 0xaa);
+ @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed;
}
}
diff --git a/lib/std/crypto/aes_gcm.zig b/lib/std/crypto/aes_gcm.zig
@@ -91,7 +91,7 @@ fn AesGcm(comptime Aes: anytype) type {
acc |= (computed_tag[p] ^ tag[p]);
}
if (acc != 0) {
- mem.set(u8, m, 0xaa);
+ @memset(m.ptr, undefined, m.len);
return error.AuthenticationFailed;
}
diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig
@@ -142,6 +142,11 @@ fn Sha2x32(comptime params: Sha2Params32) type {
d.total_len += b.len;
}
+ pub fn peek(d: Self) [digest_length]u8 {
+ var copy = d;
+ return copy.finalResult();
+ }
+
pub fn final(d: *Self, out: *[digest_length]u8) void {
// The buffer here will never be completely full.
mem.set(u8, d.buf[d.buf_len..], 0);
@@ -175,6 +180,12 @@ fn Sha2x32(comptime params: Sha2Params32) type {
}
}
+ pub fn finalResult(d: *Self) [digest_length]u8 {
+ var result: [digest_length]u8 = undefined;
+ d.final(&result);
+ return result;
+ }
+
const W = [64]u32{
0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
@@ -621,6 +632,11 @@ fn Sha2x64(comptime params: Sha2Params64) type {
d.total_len += b.len;
}
+ pub fn peek(d: Self) [digest_length]u8 {
+ var copy = d;
+ return copy.finalResult();
+ }
+
pub fn final(d: *Self, out: *[digest_length]u8) void {
// The buffer here will never be completely full.
mem.set(u8, d.buf[d.buf_len..], 0);
@@ -654,6 +670,12 @@ fn Sha2x64(comptime params: Sha2Params64) type {
}
}
+ pub fn finalResult(d: *Self) [digest_length]u8 {
+ var result: [digest_length]u8 = undefined;
+ d.final(&result);
+ return result;
+ }
+
fn round(d: *Self, b: *const [128]u8) void {
var s: [80]u64 = undefined;
diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig
@@ -0,0 +1,494 @@
+//! Plaintext:
+//! * type: ContentType
+//! * legacy_record_version: u16 = 0x0303,
+//! * length: u16,
+//! - The length (in bytes) of the following TLSPlaintext.fragment. The
+//! length MUST NOT exceed 2^14 bytes.
+//! * fragment: opaque
+//! - the data being transmitted
+//!
+//! Ciphertext
+//! * ContentType opaque_type = application_data; /* 23 */
+//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
+//! * uint16 length;
+//! * opaque encrypted_record[TLSCiphertext.length];
+//!
+//! Handshake:
+//! * type: HandshakeType
+//! * length: u24
+//! * data: opaque
+//!
+//! ServerHello:
+//! * ProtocolVersion legacy_version = 0x0303;
+//! * Random random;
+//! * opaque legacy_session_id_echo<0..32>;
+//! * CipherSuite cipher_suite;
+//! * uint8 legacy_compression_method = 0;
+//! * Extension extensions<6..2^16-1>;
+//!
+//! Extension:
+//! * ExtensionType extension_type;
+//! * opaque extension_data<0..2^16-1>;
+
+const std = @import("../std.zig");
+const Tls = @This();
+const net = std.net;
+const mem = std.mem;
+const crypto = std.crypto;
+const assert = std.debug.assert;
+
+pub const Client = @import("tls/Client.zig");
+
+pub const record_header_len = 5;
+pub const max_ciphertext_len = (1 << 14) + 256;
+pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len;
+pub const hello_retry_request_sequence = [32]u8{
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
+ 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
+};
+
+pub const close_notify_alert = [_]u8{
+ @enumToInt(AlertLevel.warning),
+ @enumToInt(AlertDescription.close_notify),
+};
+
+pub const ProtocolVersion = enum(u16) {
+ tls_1_2 = 0x0303,
+ tls_1_3 = 0x0304,
+ _,
+};
+
+pub const ContentType = enum(u8) {
+ invalid = 0,
+ change_cipher_spec = 20,
+ alert = 21,
+ handshake = 22,
+ application_data = 23,
+ _,
+};
+
+pub const HandshakeType = enum(u8) {
+ client_hello = 1,
+ server_hello = 2,
+ new_session_ticket = 4,
+ end_of_early_data = 5,
+ encrypted_extensions = 8,
+ certificate = 11,
+ certificate_request = 13,
+ certificate_verify = 15,
+ finished = 20,
+ key_update = 24,
+ message_hash = 254,
+ _,
+};
+
+pub const ExtensionType = enum(u16) {
+ /// RFC 6066
+ server_name = 0,
+ /// RFC 6066
+ max_fragment_length = 1,
+ /// RFC 6066
+ status_request = 5,
+ /// RFC 8422, 7919
+ supported_groups = 10,
+ /// RFC 8446
+ signature_algorithms = 13,
+ /// RFC 5764
+ use_srtp = 14,
+ /// RFC 6520
+ heartbeat = 15,
+ /// RFC 7301
+ application_layer_protocol_negotiation = 16,
+ /// RFC 6962
+ signed_certificate_timestamp = 18,
+ /// RFC 7250
+ client_certificate_type = 19,
+ /// RFC 7250
+ server_certificate_type = 20,
+ /// RFC 7685
+ padding = 21,
+ /// RFC 8446
+ pre_shared_key = 41,
+ /// RFC 8446
+ early_data = 42,
+ /// RFC 8446
+ supported_versions = 43,
+ /// RFC 8446
+ cookie = 44,
+ /// RFC 8446
+ psk_key_exchange_modes = 45,
+ /// RFC 8446
+ certificate_authorities = 47,
+ /// RFC 8446
+ oid_filters = 48,
+ /// RFC 8446
+ post_handshake_auth = 49,
+ /// RFC 8446
+ signature_algorithms_cert = 50,
+ /// RFC 8446
+ key_share = 51,
+
+ _,
+};
+
+pub const AlertLevel = enum(u8) {
+ warning = 1,
+ fatal = 2,
+ _,
+};
+
+pub const AlertDescription = enum(u8) {
+ close_notify = 0,
+ unexpected_message = 10,
+ bad_record_mac = 20,
+ record_overflow = 22,
+ handshake_failure = 40,
+ bad_certificate = 42,
+ unsupported_certificate = 43,
+ certificate_revoked = 44,
+ certificate_expired = 45,
+ certificate_unknown = 46,
+ illegal_parameter = 47,
+ unknown_ca = 48,
+ access_denied = 49,
+ decode_error = 50,
+ decrypt_error = 51,
+ protocol_version = 70,
+ insufficient_security = 71,
+ internal_error = 80,
+ inappropriate_fallback = 86,
+ user_canceled = 90,
+ missing_extension = 109,
+ unsupported_extension = 110,
+ unrecognized_name = 112,
+ bad_certificate_status_response = 113,
+ unknown_psk_identity = 115,
+ certificate_required = 116,
+ no_application_protocol = 120,
+ _,
+};
+
+pub const SignatureScheme = enum(u16) {
+ // RSASSA-PKCS1-v1_5 algorithms
+ rsa_pkcs1_sha256 = 0x0401,
+ rsa_pkcs1_sha384 = 0x0501,
+ rsa_pkcs1_sha512 = 0x0601,
+
+ // ECDSA algorithms
+ ecdsa_secp256r1_sha256 = 0x0403,
+ ecdsa_secp384r1_sha384 = 0x0503,
+ ecdsa_secp521r1_sha512 = 0x0603,
+
+ // RSASSA-PSS algorithms with public key OID rsaEncryption
+ rsa_pss_rsae_sha256 = 0x0804,
+ rsa_pss_rsae_sha384 = 0x0805,
+ rsa_pss_rsae_sha512 = 0x0806,
+
+ // EdDSA algorithms
+ ed25519 = 0x0807,
+ ed448 = 0x0808,
+
+ // RSASSA-PSS algorithms with public key OID RSASSA-PSS
+ rsa_pss_pss_sha256 = 0x0809,
+ rsa_pss_pss_sha384 = 0x080a,
+ rsa_pss_pss_sha512 = 0x080b,
+
+ // Legacy algorithms
+ rsa_pkcs1_sha1 = 0x0201,
+ ecdsa_sha1 = 0x0203,
+
+ _,
+};
+
+pub const NamedGroup = enum(u16) {
+ // Elliptic Curve Groups (ECDHE)
+ secp256r1 = 0x0017,
+ secp384r1 = 0x0018,
+ secp521r1 = 0x0019,
+ x25519 = 0x001D,
+ x448 = 0x001E,
+
+ // Finite Field Groups (DHE)
+ ffdhe2048 = 0x0100,
+ ffdhe3072 = 0x0101,
+ ffdhe4096 = 0x0102,
+ ffdhe6144 = 0x0103,
+ ffdhe8192 = 0x0104,
+
+ _,
+};
+
+pub const CipherSuite = enum(u16) {
+ AES_128_GCM_SHA256 = 0x1301,
+ AES_256_GCM_SHA384 = 0x1302,
+ CHACHA20_POLY1305_SHA256 = 0x1303,
+ AES_128_CCM_SHA256 = 0x1304,
+ AES_128_CCM_8_SHA256 = 0x1305,
+ AEGIS_256_SHA384 = 0x1306,
+ AEGIS_128L_SHA256 = 0x1307,
+ _,
+};
+
+pub const CertificateType = enum(u8) {
+ X509 = 0,
+ RawPublicKey = 2,
+ _,
+};
+
+pub const KeyUpdateRequest = enum(u8) {
+ update_not_requested = 0,
+ update_requested = 1,
+ _,
+};
+
+pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type {
+ return struct {
+ pub const AEAD = AeadType;
+ pub const Hash = HashType;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ handshake_secret: [Hkdf.prk_length]u8,
+ master_secret: [Hkdf.prk_length]u8,
+ client_handshake_key: [AEAD.key_length]u8,
+ server_handshake_key: [AEAD.key_length]u8,
+ client_finished_key: [Hmac.key_length]u8,
+ server_finished_key: [Hmac.key_length]u8,
+ client_handshake_iv: [AEAD.nonce_length]u8,
+ server_handshake_iv: [AEAD.nonce_length]u8,
+ transcript_hash: Hash,
+ };
+}
+
+pub const HandshakeCipher = union(enum) {
+ AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
+ AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
+ CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
+ AEGIS_256_SHA384: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
+ AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
+};
+
+pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type {
+ return struct {
+ pub const AEAD = AeadType;
+ pub const Hash = HashType;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ client_secret: [Hash.digest_length]u8,
+ server_secret: [Hash.digest_length]u8,
+ client_key: [AEAD.key_length]u8,
+ server_key: [AEAD.key_length]u8,
+ client_iv: [AEAD.nonce_length]u8,
+ server_iv: [AEAD.nonce_length]u8,
+ };
+}
+
+/// Encryption parameters for application traffic.
+pub const ApplicationCipher = union(enum) {
+ AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
+ AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
+ CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
+ AEGIS_256_SHA384: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
+ AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
+};
+
+pub fn hkdfExpandLabel(
+ comptime Hkdf: type,
+ key: [Hkdf.prk_length]u8,
+ label: []const u8,
+ context: []const u8,
+ comptime len: usize,
+) [len]u8 {
+ const max_label_len = 255;
+ const max_context_len = 255;
+ const tls13 = "tls13 ";
+ var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
+ mem.writeIntBig(u16, buf[0..2], len);
+ buf[2] = @intCast(u8, tls13.len + label.len);
+ buf[3..][0..tls13.len].* = tls13.*;
+ var i: usize = 3 + tls13.len;
+ mem.copy(u8, buf[i..], label);
+ i += label.len;
+ buf[i] = @intCast(u8, context.len);
+ i += 1;
+ mem.copy(u8, buf[i..], context);
+ i += context.len;
+
+ var result: [len]u8 = undefined;
+ Hkdf.expand(&result, buf[0..i], key);
+ return result;
+}
+
+pub fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
+ var result: [Hash.digest_length]u8 = undefined;
+ Hash.hash(&.{}, &result, .{});
+ return result;
+}
+
+pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 {
+ var result: [Hmac.mac_length]u8 = undefined;
+ Hmac.create(&result, message, &key);
+ return result;
+}
+
+pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 {
+ return int2(@enumToInt(et)) ++ array(1, bytes);
+}
+
+pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 {
+ comptime assert(bytes.len % elem_size == 0);
+ return int2(bytes.len) ++ bytes;
+}
+
+pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 {
+ assert(@sizeOf(E) == 2);
+ var result: [tags.len * 2]u8 = undefined;
+ for (tags) |elem, i| {
+ result[i * 2] = @truncate(u8, @enumToInt(elem) >> 8);
+ result[i * 2 + 1] = @truncate(u8, @enumToInt(elem));
+ }
+ return array(2, result);
+}
+
+pub inline fn int2(x: u16) [2]u8 {
+ return .{
+ @truncate(u8, x >> 8),
+ @truncate(u8, x),
+ };
+}
+
+pub inline fn int3(x: u24) [3]u8 {
+ return .{
+ @truncate(u8, x >> 16),
+ @truncate(u8, x >> 8),
+ @truncate(u8, x),
+ };
+}
+
+/// An abstraction to ensure that protocol-parsing code does not perform an
+/// out-of-bounds read.
+pub const Decoder = struct {
+ buf: []u8,
+ /// Points to the next byte in buffer that will be decoded.
+ idx: usize = 0,
+ /// Up to this point in `buf` we have already checked that `cap` is greater than it.
+ our_end: usize = 0,
+ /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we
+ /// requested with `readAtLeast`.
+ their_end: usize = 0,
+ /// Points to the end within buffer that has been filled. Beyond this point
+ /// in buf is undefined bytes.
+ cap: usize = 0,
+ /// Debug helper to prevent illegal calls to read functions.
+ disable_reads: bool = false,
+
+ pub fn fromTheirSlice(buf: []u8) Decoder {
+ return .{
+ .buf = buf,
+ .their_end = buf.len,
+ .cap = buf.len,
+ .disable_reads = true,
+ };
+ }
+
+ /// Use this function to increase `their_end`.
+ pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+ assert(!d.disable_reads);
+ const existing_amt = d.cap - d.idx;
+ d.their_end = d.idx + their_amt;
+ if (their_amt <= existing_amt) return;
+ const request_amt = their_amt - existing_amt;
+ const dest = d.buf[d.cap..];
+ if (request_amt > dest.len) return error.TlsRecordOverflow;
+ const actual_amt = try stream.readAtLeast(dest, request_amt);
+ if (actual_amt < request_amt) return error.TlsConnectionTruncated;
+ d.cap += actual_amt;
+ }
+
+ /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
+ /// Use when `our_amt` is calculated by us, not by them.
+ pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+ assert(!d.disable_reads);
+ try readAtLeast(d, stream, our_amt);
+ d.our_end = d.idx + our_amt;
+ }
+
+ /// Use this function to increase `our_end`.
+ /// This should always be called with an amount provided by us, not them.
+ pub fn ensure(d: *Decoder, amt: usize) !void {
+ d.our_end = @max(d.idx + amt, d.our_end);
+ if (d.our_end > d.their_end) return error.TlsDecodeError;
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn decode(d: *Decoder, comptime T: type) T {
+ switch (@typeInfo(T)) {
+ .Int => |info| switch (info.bits) {
+ 8 => {
+ skip(d, 1);
+ return d.buf[d.idx - 1];
+ },
+ 16 => {
+ skip(d, 2);
+ const b0: u16 = d.buf[d.idx - 2];
+ const b1: u16 = d.buf[d.idx - 1];
+ return (b0 << 8) | b1;
+ },
+ 24 => {
+ skip(d, 3);
+ const b0: u24 = d.buf[d.idx - 3];
+ const b1: u24 = d.buf[d.idx - 2];
+ const b2: u24 = d.buf[d.idx - 1];
+ return (b0 << 16) | (b1 << 8) | b2;
+ },
+ else => @compileError("unsupported int type: " ++ @typeName(T)),
+ },
+ .Enum => |info| {
+ const int = d.decode(info.tag_type);
+ if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
+ return @intToEnum(T, int);
+ },
+ else => @compileError("unsupported type: " ++ @typeName(T)),
+ }
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn array(d: *Decoder, comptime len: usize) *[len]u8 {
+ skip(d, len);
+ return d.buf[d.idx - len ..][0..len];
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn slice(d: *Decoder, len: usize) []u8 {
+ skip(d, len);
+ return d.buf[d.idx - len ..][0..len];
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn skip(d: *Decoder, amt: usize) void {
+ d.idx += amt;
+ assert(d.idx <= d.our_end); // insufficient ensured bytes
+ }
+
+ pub fn eof(d: Decoder) bool {
+ assert(d.our_end <= d.their_end);
+ assert(d.idx <= d.our_end);
+ return d.idx == d.their_end;
+ }
+
+ /// Provide the length they claim, and receive a sub-decoder specific to that slice.
+ /// The parent decoder is advanced to the end.
+ pub fn sub(d: *Decoder, their_len: usize) !Decoder {
+ const end = d.idx + their_len;
+ if (end > d.their_end) return error.TlsDecodeError;
+ const sub_buf = d.buf[d.idx..end];
+ d.idx = end;
+ d.our_end = end;
+ return fromTheirSlice(sub_buf);
+ }
+
+ pub fn rest(d: Decoder) []u8 {
+ return d.buf[d.idx..d.cap];
+ }
+};
diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
@@ -0,0 +1,1308 @@
+const std = @import("../../std.zig");
+const tls = std.crypto.tls;
+const Client = @This();
+const net = std.net;
+const mem = std.mem;
+const crypto = std.crypto;
+const assert = std.debug.assert;
+const Certificate = std.crypto.Certificate;
+
+const max_ciphertext_len = tls.max_ciphertext_len;
+const hkdfExpandLabel = tls.hkdfExpandLabel;
+const int2 = tls.int2;
+const int3 = tls.int3;
+const array = tls.array;
+const enum_array = tls.enum_array;
+
+read_seq: u64,
+write_seq: u64,
+/// The starting index of cleartext bytes inside `partially_read_buffer`.
+partial_cleartext_idx: u15,
+/// The ending index of cleartext bytes inside `partially_read_buffer` as well
+/// as the starting index of ciphertext bytes.
+partial_ciphertext_idx: u15,
+/// The ending index of ciphertext bytes inside `partially_read_buffer`.
+partial_ciphertext_end: u15,
+/// When this is true, the stream may still not be at the end because there
+/// may be data in `partially_read_buffer`.
+received_close_notify: bool,
+/// By default, reaching the end-of-stream when reading from the server will
+/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
+/// message has been received. By setting this flag to `true`, instead, the
+/// end-of-stream will be forwarded to the application layer above TLS.
+/// This makes the application vulnerable to truncation attacks unless the
+/// application layer itself verifies that the amount of data received equals
+/// the amount of data expected, such as HTTP with the Content-Length header.
+allow_truncation_attacks: bool = false,
+application_cipher: tls.ApplicationCipher,
+/// The size is enough to contain exactly one TLSCiphertext record.
+/// This buffer is segmented into four parts:
+/// 0. unused
+/// 1. cleartext
+/// 2. ciphertext
+/// 3. unused
+/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
+/// `partial_ciphertext_end` describe the span of the segments.
+partially_read_buffer: [tls.max_ciphertext_record_len]u8,
+
+/// This is an example of the type that is needed by the read and write
+/// functions. It can have any fields but it must at least have these
+/// functions.
+///
+/// Note that `std.net.Stream` conforms to this interface.
+///
+/// This declaration serves as documentation only.
+pub const StreamInterface = struct {
+ /// Can be any error set.
+ pub const ReadError = error{};
+
+ /// Returns the number of bytes read. The number read may be less than the
+ /// buffer space provided. End-of-stream is indicated by a return value of 0.
+ ///
+ /// The `iovecs` parameter is mutable because so that function may to
+ /// mutate the fields in order to handle partial reads from the underlying
+ /// stream layer.
+ pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+
+ /// Can be any error set.
+ pub const WriteError = error{};
+
+ /// Returns the number of bytes read, which may be less than the buffer
+ /// space provided. A short read does not indicate end-of-stream.
+ pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+
+ /// Returns the number of bytes read, which may be less than the buffer
+ /// space provided, indicating end-of-stream.
+ /// The `iovecs` parameter is mutable in case this function needs to mutate
+ /// the fields in order to handle partial writes from the underlying layer.
+ pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize {
+ // This can be implemented in terms of writev, or specialized if desired.
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+};
+
+/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which
+/// must conform to `StreamInterface`.
+///
+/// `host` is only borrowed during this function call.
+pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
+ const host_len = @intCast(u16, host.len);
+
+ var random_buffer: [128]u8 = undefined;
+ crypto.random.bytes(&random_buffer);
+ const hello_rand = random_buffer[0..32].*;
+ const legacy_session_id = random_buffer[32..64].*;
+ const x25519_kp_seed = random_buffer[64..96].*;
+ const secp256r1_kp_seed = random_buffer[96..128].*;
+
+ const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) {
+ // Only possible to happen if the private key is all zeroes.
+ error.IdentityElement => return error.InsufficientEntropy,
+ };
+ const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) {
+ // Only possible to happen if the private key is all zeroes.
+ error.IdentityElement => return error.InsufficientEntropy,
+ };
+
+ const extensions_payload =
+ tls.extension(.supported_versions, [_]u8{
+ 0x02, // byte length of supported versions
+ 0x03, 0x04, // TLS 1.3
+ }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
+ .ecdsa_secp256r1_sha256,
+ .ecdsa_secp384r1_sha384,
+ .ecdsa_secp521r1_sha512,
+ .rsa_pss_rsae_sha256,
+ .rsa_pss_rsae_sha384,
+ .rsa_pss_rsae_sha512,
+ .rsa_pkcs1_sha256,
+ .rsa_pkcs1_sha384,
+ .rsa_pkcs1_sha512,
+ .ed25519,
+ })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
+ .secp256r1,
+ .x25519,
+ })) ++ tls.extension(
+ .key_share,
+ array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++
+ array(1, x25519_kp.public_key) ++
+ int2(@enumToInt(tls.NamedGroup.secp256r1)) ++
+ array(1, secp256r1_kp.public_key.toUncompressedSec1())),
+ ) ++
+ int2(@enumToInt(tls.ExtensionType.server_name)) ++
+ int2(host_len + 5) ++ // byte length of this extension payload
+ int2(host_len + 3) ++ // server_name_list byte count
+ [1]u8{0x00} ++ // name_type
+ int2(host_len);
+
+ const extensions_header =
+ int2(@intCast(u16, extensions_payload.len + host_len)) ++
+ extensions_payload;
+
+ const legacy_compression_methods = 0x0100;
+
+ const client_hello =
+ int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
+ hello_rand ++
+ [1]u8{32} ++ legacy_session_id ++
+ cipher_suites ++
+ int2(legacy_compression_methods) ++
+ extensions_header;
+
+ const out_handshake =
+ [_]u8{@enumToInt(tls.HandshakeType.client_hello)} ++
+ int3(@intCast(u24, client_hello.len + host_len)) ++
+ client_hello;
+
+ const plaintext_header = [_]u8{
+ @enumToInt(tls.ContentType.handshake),
+ 0x03, 0x01, // legacy_record_version
+ } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
+
+ {
+ var iovecs = [_]std.os.iovec_const{
+ .{
+ .iov_base = &plaintext_header,
+ .iov_len = plaintext_header.len,
+ },
+ .{
+ .iov_base = host.ptr,
+ .iov_len = host.len,
+ },
+ };
+ try stream.writevAll(&iovecs);
+ }
+
+ const client_hello_bytes1 = plaintext_header[5..];
+
+ var handshake_cipher: tls.HandshakeCipher = undefined;
+ var handshake_buffer: [8000]u8 = undefined;
+ var d: tls.Decoder = .{ .buf = &handshake_buffer };
+ {
+ try d.readAtLeastOurAmt(stream, tls.record_header_len);
+ const ct = d.decode(tls.ContentType);
+ d.skip(2); // legacy_record_version
+ const record_len = d.decode(u16);
+ try d.readAtLeast(stream, record_len);
+ const server_hello_fragment = d.buf[d.idx..][0..record_len];
+ var ptd = try d.sub(record_len);
+ switch (ct) {
+ .alert => {
+ try ptd.ensure(2);
+ const level = ptd.decode(tls.AlertLevel);
+ const desc = ptd.decode(tls.AlertDescription);
+ _ = level;
+ _ = desc;
+ return error.TlsAlert;
+ },
+ .handshake => {
+ try ptd.ensure(4);
+ const handshake_type = ptd.decode(tls.HandshakeType);
+ if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
+ const length = ptd.decode(u24);
+ var hsd = try ptd.sub(length);
+ try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
+ const legacy_version = hsd.decode(u16);
+ const random = hsd.array(32);
+ if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
+ // This is a HelloRetryRequest message. This client implementation
+ // does not expect to get one.
+ return error.TlsUnexpectedMessage;
+ }
+ const legacy_session_id_echo_len = hsd.decode(u8);
+ if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
+ const legacy_session_id_echo = hsd.array(32);
+ if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
+ return error.TlsIllegalParameter;
+ const cipher_suite_tag = hsd.decode(tls.CipherSuite);
+ hsd.skip(1); // legacy_compression_method
+ const extensions_size = hsd.decode(u16);
+ var all_extd = try hsd.sub(extensions_size);
+ var supported_version: u16 = 0;
+ var shared_key: [32]u8 = undefined;
+ var have_shared_key = false;
+ while (!all_extd.eof()) {
+ try all_extd.ensure(2 + 2);
+ const et = all_extd.decode(tls.ExtensionType);
+ const ext_size = all_extd.decode(u16);
+ var extd = try all_extd.sub(ext_size);
+ switch (et) {
+ .supported_versions => {
+ if (supported_version != 0) return error.TlsIllegalParameter;
+ try extd.ensure(2);
+ supported_version = extd.decode(u16);
+ },
+ .key_share => {
+ if (have_shared_key) return error.TlsIllegalParameter;
+ have_shared_key = true;
+ try extd.ensure(4);
+ const named_group = extd.decode(tls.NamedGroup);
+ const key_size = extd.decode(u16);
+ try extd.ensure(key_size);
+ switch (named_group) {
+ .x25519 => {
+ if (key_size != 32) return error.TlsIllegalParameter;
+ const server_pub_key = extd.array(32);
+
+ shared_key = crypto.dh.X25519.scalarmult(
+ x25519_kp.secret_key,
+ server_pub_key.*,
+ ) catch return error.TlsDecryptFailure;
+ },
+ .secp256r1 => {
+ const server_pub_key = extd.slice(key_size);
+
+ const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
+ const pk = PublicKey.fromSec1(server_pub_key) catch {
+ return error.TlsDecryptFailure;
+ };
+ const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch {
+ return error.TlsDecryptFailure;
+ };
+ shared_key = mul.affineCoordinates().x.toBytes(.Big);
+ },
+ else => {
+ return error.TlsIllegalParameter;
+ },
+ }
+ },
+ else => {},
+ }
+ }
+ if (!have_shared_key) return error.TlsIllegalParameter;
+
+ const tls_version = if (supported_version == 0) legacy_version else supported_version;
+ if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3))
+ return error.TlsIllegalParameter;
+
+ switch (cipher_suite_tag) {
+ inline .AES_128_GCM_SHA256,
+ .AES_256_GCM_SHA384,
+ .CHACHA20_POLY1305_SHA256,
+ .AEGIS_256_SHA384,
+ .AEGIS_128L_SHA256,
+ => |tag| {
+ const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag));
+ handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{
+ .handshake_secret = undefined,
+ .master_secret = undefined,
+ .client_handshake_key = undefined,
+ .server_handshake_key = undefined,
+ .client_finished_key = undefined,
+ .server_finished_key = undefined,
+ .client_handshake_iv = undefined,
+ .server_handshake_iv = undefined,
+ .transcript_hash = P.Hash.init(.{}),
+ });
+ const p = &@field(handshake_cipher, @tagName(tag));
+ p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
+ p.transcript_hash.update(host); // Client Hello part 2
+ p.transcript_hash.update(server_hello_fragment);
+ const hello_hash = p.transcript_hash.peek();
+ const zeroes = [1]u8{0} ** P.Hash.digest_length;
+ const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
+ const empty_hash = tls.emptyHash(P.Hash);
+ const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
+ p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key);
+ const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
+ p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
+ p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
+ p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
+ p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+ p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+ p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+ p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ else => {
+ return error.TlsIllegalParameter;
+ },
+ }
+ },
+ else => return error.TlsUnexpectedMessage,
+ }
+ }
+
+ // This is used for two purposes:
+ // * Detect whether a certificate is the first one presented, in which case
+ // we need to verify the host name.
+ // * Flip back and forth between the two cleartext buffers in order to keep
+ // the previous certificate in memory so that it can be verified by the
+ // next one.
+ var cert_index: usize = 0;
+ var read_seq: u64 = 0;
+ var prev_cert: Certificate.Parsed = undefined;
+ // Set to true once a trust chain has been established from the first
+ // certificate to a root CA.
+ const HandshakeState = enum {
+ /// In this state we expect only an encrypted_extensions message.
+ encrypted_extensions,
+ /// In this state we expect certificate messages.
+ certificate,
+ /// In this state we expect certificate or certificate_verify messages.
+ /// certificate messages are ignored since the trust chain is already
+ /// established.
+ trust_chain_established,
+ /// In this state, we expect only the finished message.
+ finished,
+ };
+ var handshake_state: HandshakeState = .encrypted_extensions;
+ var cleartext_bufs: [2][8000]u8 = undefined;
+ var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined;
+ var main_cert_pub_key_buf: [300]u8 = undefined;
+ var main_cert_pub_key_len: u16 = undefined;
+ const now_sec = std.time.timestamp();
+
+ while (true) {
+ try d.readAtLeastOurAmt(stream, tls.record_header_len);
+ const record_header = d.buf[d.idx..][0..5];
+ const ct = d.decode(tls.ContentType);
+ d.skip(2); // legacy_version
+ const record_len = d.decode(u16);
+ try d.readAtLeast(stream, record_len);
+ var record_decoder = try d.sub(record_len);
+ switch (ct) {
+ .change_cipher_spec => {
+ try record_decoder.ensure(1);
+ if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
+ },
+ .application_data => {
+ const cleartext_buf = &cleartext_bufs[cert_index % 2];
+
+ const cleartext = switch (handshake_cipher) {
+ inline else => |*p| c: {
+ const P = @TypeOf(p.*);
+ const ciphertext_len = record_len - P.AEAD.tag_length;
+ try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
+ const ciphertext = record_decoder.slice(ciphertext_len);
+ if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_buf[0..ciphertext.len];
+ const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+ read_seq += 1;
+ const nonce = @as(V, p.server_handshake_iv) ^ operand;
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
+ return error.TlsBadRecordMac;
+ break :c cleartext;
+ },
+ };
+
+ const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
+ if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
+
+ var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
+ while (true) {
+ try ctd.ensure(4);
+ const handshake_type = ctd.decode(tls.HandshakeType);
+ const handshake_len = ctd.decode(u24);
+ var hsd = try ctd.sub(handshake_len);
+ const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
+ const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
+ switch (handshake_type) {
+ .encrypted_extensions => {
+ if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
+ handshake_state = .certificate;
+ switch (handshake_cipher) {
+ inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+ }
+ try hsd.ensure(2);
+ const total_ext_size = hsd.decode(u16);
+ var all_extd = try hsd.sub(total_ext_size);
+ while (!all_extd.eof()) {
+ try all_extd.ensure(4);
+ const et = all_extd.decode(tls.ExtensionType);
+ const ext_size = all_extd.decode(u16);
+ var extd = try all_extd.sub(ext_size);
+ _ = extd;
+ switch (et) {
+ .server_name => {},
+ else => {},
+ }
+ }
+ },
+ .certificate => cert: {
+ switch (handshake_cipher) {
+ inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+ }
+ switch (handshake_state) {
+ .certificate => {},
+ .trust_chain_established => break :cert,
+ else => return error.TlsUnexpectedMessage,
+ }
+ try hsd.ensure(1 + 4);
+ const cert_req_ctx_len = hsd.decode(u8);
+ if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
+ const certs_size = hsd.decode(u24);
+ var certs_decoder = try hsd.sub(certs_size);
+ while (!certs_decoder.eof()) {
+ try certs_decoder.ensure(3);
+ const cert_size = certs_decoder.decode(u24);
+ var certd = try certs_decoder.sub(cert_size);
+
+ const subject_cert: Certificate = .{
+ .buffer = certd.buf,
+ .index = @intCast(u32, certd.idx),
+ };
+ const subject = try subject_cert.parse();
+ if (cert_index == 0) {
+ // Verify the host on the first certificate.
+ try subject.verifyHostName(host);
+
+ // Keep track of the public key for the
+ // certificate_verify message later.
+ main_cert_pub_key_algo = subject.pub_key_algo;
+ const pub_key = subject.pubKey();
+ if (pub_key.len > main_cert_pub_key_buf.len)
+ return error.CertificatePublicKeyInvalid;
+ @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
+ main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
+ } else {
+ try prev_cert.verify(subject, now_sec);
+ }
+
+ if (ca_bundle.verify(subject, now_sec)) |_| {
+ handshake_state = .trust_chain_established;
+ break :cert;
+ } else |err| switch (err) {
+ error.CertificateIssuerNotFound => {},
+ else => |e| return e,
+ }
+
+ prev_cert = subject;
+ cert_index += 1;
+
+ try certs_decoder.ensure(2);
+ const total_ext_size = certs_decoder.decode(u16);
+ var all_extd = try certs_decoder.sub(total_ext_size);
+ _ = all_extd;
+ }
+ },
+ .certificate_verify => {
+ switch (handshake_state) {
+ .trust_chain_established => handshake_state = .finished,
+ .certificate => return error.TlsCertificateNotVerified,
+ else => return error.TlsUnexpectedMessage,
+ }
+
+ try hsd.ensure(4);
+ const scheme = hsd.decode(tls.SignatureScheme);
+ const sig_len = hsd.decode(u16);
+ try hsd.ensure(sig_len);
+ const encoded_sig = hsd.slice(sig_len);
+ const max_digest_len = 64;
+ var verify_buffer =
+ ([1]u8{0x20} ** 64) ++
+ "TLS 1.3, server CertificateVerify\x00".* ++
+ @as([max_digest_len]u8, undefined);
+
+ const verify_bytes = switch (handshake_cipher) {
+ inline else => |*p| v: {
+ const transcript_digest = p.transcript_hash.peek();
+ verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
+ p.transcript_hash.update(wrapped_handshake);
+ break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
+ },
+ };
+ const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
+
+ switch (scheme) {
+ inline .ecdsa_secp256r1_sha256,
+ .ecdsa_secp384r1_sha384,
+ => |comptime_scheme| {
+ if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
+ return error.TlsBadSignatureScheme;
+ const Ecdsa = SchemeEcdsa(comptime_scheme);
+ const sig = try Ecdsa.Signature.fromDer(encoded_sig);
+ const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
+ try sig.verify(verify_bytes, key);
+ },
+ .rsa_pss_rsae_sha256 => {
+ if (main_cert_pub_key_algo != .rsaEncryption)
+ return error.TlsBadSignatureScheme;
+
+ const Hash = crypto.hash.sha2.Sha256;
+ const rsa = Certificate.rsa;
+ const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
+ const exponent = components.exponent;
+ const modulus = components.modulus;
+ var rsa_mem_buf: [512 * 32]u8 = undefined;
+ var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
+ const ally = fba.allocator();
+ switch (modulus.len) {
+ inline 128, 256, 512 => |modulus_len| {
+ const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
+ const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
+ try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
+ },
+ else => {
+ return error.TlsBadRsaSignatureBitCount;
+ },
+ }
+ },
+ else => {
+ return error.TlsBadSignatureScheme;
+ },
+ }
+ },
+ .finished => {
+ if (handshake_state != .finished) return error.TlsUnexpectedMessage;
+ // This message is to trick buggy proxies into behaving correctly.
+ const client_change_cipher_spec_msg = [_]u8{
+ @enumToInt(tls.ContentType.change_cipher_spec),
+ 0x03, 0x03, // legacy protocol version
+ 0x00, 0x01, // length
+ 0x01,
+ };
+ const app_cipher = switch (handshake_cipher) {
+ inline else => |*p, tag| c: {
+ const P = @TypeOf(p.*);
+ const finished_digest = p.transcript_hash.peek();
+ p.transcript_hash.update(wrapped_handshake);
+ const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
+ if (!mem.eql(u8, &expected_server_verify_data, handshake))
+ return error.TlsDecryptError;
+ const handshake_hash = p.transcript_hash.finalResult();
+ const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
+ const out_cleartext = [_]u8{
+ @enumToInt(tls.HandshakeType.finished),
+ 0, 0, verify_data.len, // length
+ } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
+
+ const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
+
+ var finished_msg = [_]u8{
+ @enumToInt(tls.ContentType.application_data),
+ 0x03, 0x03, // legacy protocol version
+ 0, wrapped_len, // byte length of encrypted record
+ } ++ @as([wrapped_len]u8, undefined);
+
+ const ad = finished_msg[0..5];
+ const ciphertext = finished_msg[5..][0..out_cleartext.len];
+ const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
+ const nonce = p.client_handshake_iv;
+ P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
+
+ const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
+ try stream.writeAll(&both_msgs);
+
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+ break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
+ .client_secret = client_secret,
+ .server_secret = server_secret,
+ .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
+ .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
+ .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
+ .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
+ });
+ },
+ };
+ const leftover = d.rest();
+ var client: Client = .{
+ .read_seq = 0,
+ .write_seq = 0,
+ .partial_cleartext_idx = 0,
+ .partial_ciphertext_idx = 0,
+ .partial_ciphertext_end = @intCast(u15, leftover.len),
+ .received_close_notify = false,
+ .application_cipher = app_cipher,
+ .partially_read_buffer = undefined,
+ };
+ mem.copy(u8, &client.partially_read_buffer, leftover);
+ return client;
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ if (ctd.eof()) break;
+ }
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
+ return writeEnd(c, stream, bytes, false);
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try c.write(stream, bytes[index..]);
+ }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try c.writeEnd(stream, bytes[index..], end);
+ }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
+ var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
+ var iovecs_buf: [6]std.os.iovec_const = undefined;
+ var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
+ if (end) {
+ prepared.iovec_end += prepareCiphertextRecord(
+ c,
+ iovecs_buf[prepared.iovec_end..],
+ ciphertext_buf[prepared.ciphertext_end..],
+ &tls.close_notify_alert,
+ .alert,
+ ).iovec_end;
+ }
+
+ const iovec_end = prepared.iovec_end;
+ const overhead_len = prepared.overhead_len;
+
+ // Ideally we would call writev exactly once here, however, we must ensure
+ // that we don't return with a record partially written.
+ var i: usize = 0;
+ var total_amt: usize = 0;
+ while (true) {
+ var amt = try stream.writev(iovecs_buf[i..iovec_end]);
+ while (amt >= iovecs_buf[i].iov_len) {
+ const encrypted_amt = iovecs_buf[i].iov_len;
+ total_amt += encrypted_amt - overhead_len;
+ amt -= encrypted_amt;
+ i += 1;
+ // Rely on the property that iovecs delineate records, meaning that
+ // if amt equals zero here, we have fortunately found ourselves
+ // with a short read that aligns at the record boundary.
+ if (i >= iovec_end) return total_amt;
+ // We also cannot return on a vector boundary if the final close_notify is
+ // not sent; otherwise the caller would not know to retry the call.
+ if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
+ }
+ iovecs_buf[i].iov_base += amt;
+ iovecs_buf[i].iov_len -= amt;
+ }
+}
+
+fn prepareCiphertextRecord(
+ c: *Client,
+ iovecs: []std.os.iovec_const,
+ ciphertext_buf: []u8,
+ bytes: []const u8,
+ inner_content_type: tls.ContentType,
+) struct {
+ iovec_end: usize,
+ ciphertext_end: usize,
+ /// How many bytes are taken up by overhead per record.
+ overhead_len: usize,
+} {
+ // Due to the trailing inner content type byte in the ciphertext, we need
+ // an additional buffer for storing the cleartext into before encrypting.
+ var cleartext_buf: [max_ciphertext_len]u8 = undefined;
+ var ciphertext_end: usize = 0;
+ var iovec_end: usize = 0;
+ var bytes_i: usize = 0;
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const P = @TypeOf(p.*);
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
+ const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
+ while (true) {
+ const encrypted_content_len = @intCast(u16, @min(
+ @min(bytes.len - bytes_i, max_ciphertext_len - 1),
+ ciphertext_buf.len - close_notify_alert_reserved -
+ overhead_len - ciphertext_end,
+ ));
+ if (encrypted_content_len == 0) return .{
+ .iovec_end = iovec_end,
+ .ciphertext_end = ciphertext_end,
+ .overhead_len = overhead_len,
+ };
+
+ mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
+ cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type);
+ bytes_i += encrypted_content_len;
+ const ciphertext_len = encrypted_content_len + 1;
+ const cleartext = cleartext_buf[0..ciphertext_len];
+
+ const record_start = ciphertext_end;
+ const ad = ciphertext_buf[ciphertext_end..][0..5];
+ ad.* =
+ [_]u8{@enumToInt(tls.ContentType.application_data)} ++
+ int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
+ int2(ciphertext_len + P.AEAD.tag_length);
+ ciphertext_end += ad.len;
+ const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
+ ciphertext_end += ciphertext_len;
+ const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
+ ciphertext_end += auth_tag.len;
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq));
+ c.write_seq += 1; // TODO send key_update on overflow
+ const nonce = @as(V, p.client_iv) ^ operand;
+ P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
+
+ const record = ciphertext_buf[record_start..ciphertext_end];
+ iovecs[iovec_end] = .{
+ .iov_base = record.ptr,
+ .iov_len = record.len,
+ };
+ iovec_end += 1;
+ }
+ },
+ }
+}
+
+pub fn eof(c: Client) bool {
+ return c.received_close_notify and
+ c.partial_cleartext_idx >= c.partial_ciphertext_idx and
+ c.partial_ciphertext_idx >= c.partial_ciphertext_end;
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+/// Returns the number of bytes read, calling the underlying read function the
+/// minimal number of times until the buffer has at least `len` bytes filled.
+/// If the number read is less than `len` it means the stream reached the end.
+/// Reaching the end of the stream is not an error condition.
+pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
+ var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }};
+ return readvAtLeast(c, stream, &iovecs, len);
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
+ return readAtLeast(c, stream, buffer, 1);
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+/// Returns the number of bytes read. If the number read is smaller than
+/// `buffer.len`, it means the stream reached the end. Reaching the end of the
+/// stream is not an error condition.
+pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
+ return readAtLeast(c, stream, buffer, buffer.len);
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+/// Returns the number of bytes read. If the number read is less than the space
+/// provided it means the stream reached the end. Reaching the end of the
+/// stream is not an error condition.
+/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
+/// order to handle partial reads from the underlying stream layer.
+pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
+ return readvAtLeast(c, stream, iovecs);
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+/// Returns the number of bytes read, calling the underlying read function the
+/// minimal number of times until the iovecs have at least `len` bytes filled.
+/// If the number read is less than `len` it means the stream reached the end.
+/// Reaching the end of the stream is not an error condition.
+/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
+/// order to handle partial reads from the underlying stream layer.
+pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize {
+ if (c.eof()) return 0;
+
+ var off_i: usize = 0;
+ var vec_i: usize = 0;
+ while (true) {
+ var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
+ off_i += amt;
+ if (c.eof() or off_i >= len) return off_i;
+ while (amt >= iovecs[vec_i].iov_len) {
+ amt -= iovecs[vec_i].iov_len;
+ vec_i += 1;
+ }
+ iovecs[vec_i].iov_base += amt;
+ iovecs[vec_i].iov_len -= amt;
+ }
+}
+
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
+/// Returns number of bytes that have been read, populated inside `iovecs`. A
+/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
+/// for the end of stream. The `eof()` may be true after any call to
+/// `read`, including when greater than zero bytes are returned, and this
+/// function asserts that `eof()` is `false`.
+/// See `readv` for a higher level function that has the same, familiar API as
+/// other read functions, such as `std.fs.File.read`.
+pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize {
+ var vp: VecPut = .{ .iovecs = iovecs };
+
+ // Give away the buffered cleartext we have, if any.
+ const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
+ if (partial_cleartext.len > 0) {
+ const amt = @intCast(u15, vp.put(partial_cleartext));
+ c.partial_cleartext_idx += amt;
+ if (amt < partial_cleartext.len) {
+ // We still have cleartext left so we cannot issue another read() call yet.
+ assert(vp.total == amt);
+ return amt;
+ }
+ if (c.received_close_notify) {
+ c.partial_ciphertext_end = 0;
+ assert(vp.total == amt);
+ return amt;
+ }
+ if (c.partial_ciphertext_end == c.partial_ciphertext_idx) {
+ c.partial_cleartext_idx = 0;
+ c.partial_ciphertext_idx = 0;
+ c.partial_ciphertext_end = 0;
+ }
+ }
+
+ assert(!c.received_close_notify);
+
+ // Ideally, this buffer would never be used. It is needed when `iovecs` are
+ // too small to fit the cleartext, which may be as large as `max_ciphertext_len`.
+ var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
+ // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`.
+ var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
+ // How many bytes left in the user's buffer.
+ const free_size = vp.freeSize();
+ // The amount of the user's buffer that we need to repurpose for storing
+ // ciphertext. The end of the buffer will be used for such purposes.
+ const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len;
+ // The amount of the user's buffer that will be used to give cleartext. The
+ // beginning of the buffer will be used for such purposes.
+ const cleartext_buf_len = free_size - ciphertext_buf_len;
+ const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
+
+ var ask_iovecs_buf: [2]std.os.iovec = .{
+ .{
+ .iov_base = first_iov.ptr,
+ .iov_len = first_iov.len,
+ },
+ .{
+ .iov_base = &in_stack_buffer,
+ .iov_len = in_stack_buffer.len,
+ },
+ };
+
+ // Cleartext capacity of output buffer, in records, rounded up.
+ const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
+ const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
+ const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
+ const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
+ const actual_read_len = try stream.readv(ask_iovecs);
+ if (actual_read_len == 0) {
+ // This is either a truncation attack, a bug in the server, or an
+ // intentional omission of the close_notify message due to truncation
+ // detection handled above the TLS layer.
+ if (c.allow_truncation_attacks) {
+ c.received_close_notify = true;
+ } else {
+ return error.TlsConnectionTruncated;
+ }
+ }
+
+ // There might be more bytes inside `in_stack_buffer` that need to be processed,
+ // but at least frag0 will have one complete ciphertext record.
+ const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
+ const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
+ var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
+ // We need to decipher frag0 and frag1 but there may be a ciphertext record
+ // straddling the boundary. We can handle this with two memcpy() calls to
+ // assemble the straddling record in between handling the two sides.
+ var frag = frag0;
+ var in: usize = 0;
+ while (true) {
+ if (in == frag.len) {
+ // Perfect split.
+ if (frag.ptr == frag1.ptr) {
+ c.partial_ciphertext_end = c.partial_ciphertext_idx;
+ return vp.total;
+ }
+ frag = frag1;
+ in = 0;
+ continue;
+ }
+
+ if (in + tls.record_header_len > frag.len) {
+ if (frag.ptr == frag1.ptr)
+ return finishRead(c, frag, in, vp.total);
+
+ const first = frag[in..];
+
+ if (frag1.len < tls.record_header_len)
+ return finishRead2(c, first, frag1, vp.total);
+
+ // A record straddles the two fragments. Copy into the now-empty first fragment.
+ const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
+ const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
+ const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
+ if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
+
+ const full_record_len = record_len + tls.record_header_len;
+ const second_len = full_record_len - first.len;
+ if (frag1.len < second_len)
+ return finishRead2(c, first, frag1, vp.total);
+
+ mem.copy(u8, frag[0..in], first);
+ mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+ frag = frag[0..full_record_len];
+ frag1 = frag1[second_len..];
+ in = 0;
+ continue;
+ }
+ const ct = @intToEnum(tls.ContentType, frag[in]);
+ in += 1;
+ const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
+ in += 2;
+ _ = legacy_version;
+ const record_len = mem.readIntBig(u16, frag[in..][0..2]);
+ if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
+ in += 2;
+ const end = in + record_len;
+ if (end > frag.len) {
+ // We need the record header on the next iteration of the loop.
+ in -= tls.record_header_len;
+
+ if (frag.ptr == frag1.ptr)
+ return finishRead(c, frag, in, vp.total);
+
+ // A record straddles the two fragments. Copy into the now-empty first fragment.
+ const first = frag[in..];
+ const full_record_len = record_len + tls.record_header_len;
+ const second_len = full_record_len - first.len;
+ if (frag1.len < second_len)
+ return finishRead2(c, first, frag1, vp.total);
+
+ mem.copy(u8, frag[0..in], first);
+ mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+ frag = frag[0..full_record_len];
+ frag1 = frag1[second_len..];
+ in = 0;
+ continue;
+ }
+ switch (ct) {
+ .alert => {
+ if (in + 2 > frag.len) return error.TlsDecodeError;
+ const level = @intToEnum(tls.AlertLevel, frag[in]);
+ const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
+ _ = level;
+ _ = desc;
+ return error.TlsAlert;
+ },
+ .application_data => {
+ const cleartext = switch (c.application_cipher) {
+ inline else => |*p| c: {
+ const P = @TypeOf(p.*);
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const ad = frag[in - 5 ..][0..5];
+ const ciphertext_len = record_len - P.AEAD.tag_length;
+ const ciphertext = frag[in..][0..ciphertext_len];
+ in += ciphertext_len;
+ const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq));
+ const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
+ const out_buf = vp.peek();
+ const cleartext_buf = if (ciphertext.len <= out_buf.len)
+ out_buf
+ else
+ &cleartext_stack_buffer;
+ const cleartext = cleartext_buf[0..ciphertext.len];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
+ return error.TlsBadRecordMac;
+ break :c cleartext;
+ },
+ };
+
+ c.read_seq = try std.math.add(u64, c.read_seq, 1);
+
+ const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
+ switch (inner_ct) {
+ .alert => {
+ const level = @intToEnum(tls.AlertLevel, cleartext[0]);
+ const desc = @intToEnum(tls.AlertDescription, cleartext[1]);
+ if (desc == .close_notify) {
+ c.received_close_notify = true;
+ c.partial_ciphertext_end = c.partial_ciphertext_idx;
+ return vp.total;
+ }
+ _ = level;
+ return error.TlsAlert;
+ },
+ .handshake => {
+ var ct_i: usize = 0;
+ while (true) {
+ const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
+ ct_i += 1;
+ const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
+ ct_i += 3;
+ const next_handshake_i = ct_i + handshake_len;
+ if (next_handshake_i > cleartext.len - 1)
+ return error.TlsBadLength;
+ const handshake = cleartext[ct_i..next_handshake_i];
+ switch (handshake_type) {
+ .new_session_ticket => {
+ // This client implementation ignores new session tickets.
+ },
+ .key_update => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const P = @TypeOf(p.*);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length);
+ p.server_secret = server_secret;
+ p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+ p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.read_seq = 0;
+
+ switch (@intToEnum(tls.KeyUpdateRequest, handshake[0])) {
+ .update_requested => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const P = @TypeOf(p.*);
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length);
+ p.client_secret = client_secret;
+ p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+ p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.write_seq = 0;
+ },
+ .update_not_requested => {},
+ _ => return error.TlsIllegalParameter,
+ }
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ ct_i = next_handshake_i;
+ if (ct_i >= cleartext.len - 1) break;
+ }
+ },
+ .application_data => {
+ // Determine whether the output buffer or a stack
+ // buffer was used for storing the cleartext.
+ if (cleartext.ptr == &cleartext_stack_buffer) {
+ // Stack buffer was used, so we must copy to the output buffer.
+ const msg = cleartext[0 .. cleartext.len - 1];
+ if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
+ // We have already run out of room in iovecs. Continue
+ // appending to `partially_read_buffer`.
+ const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
+ mem.copy(u8, dest, msg);
+ c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
+ } else {
+ const amt = vp.put(msg);
+ if (amt < msg.len) {
+ const rest = msg[amt..];
+ c.partial_cleartext_idx = 0;
+ c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len);
+ mem.copy(u8, &c.partially_read_buffer, rest);
+ }
+ }
+ } else {
+ // Output buffer was used directly which means no
+ // memory copying needs to occur, and we can move
+ // on to the next ciphertext record.
+ vp.next(cleartext.len - 1);
+ }
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ in = end;
+ }
+}
+
+fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
+ const saved_buf = frag[in..];
+ if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
+ // There is cleartext at the beginning already which we need to preserve.
+ c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + saved_buf.len);
+ mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], saved_buf);
+ } else {
+ c.partial_cleartext_idx = 0;
+ c.partial_ciphertext_idx = 0;
+ c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), saved_buf.len);
+ mem.copy(u8, &c.partially_read_buffer, saved_buf);
+ }
+ return out;
+}
+
+fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
+ if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
+ // There is cleartext at the beginning already which we need to preserve.
+ c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len);
+ mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], first);
+ mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..], frag1);
+ } else {
+ c.partial_cleartext_idx = 0;
+ c.partial_ciphertext_idx = 0;
+ c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len);
+ mem.copy(u8, &c.partially_read_buffer, first);
+ mem.copy(u8, c.partially_read_buffer[first.len..], frag1);
+ }
+ return out;
+}
+
+fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
+ if (index < s1.len) {
+ return s1[index];
+ } else {
+ return s2[index - s1.len];
+ }
+}
+
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
+inline fn big(x: anytype) @TypeOf(x) {
+ return switch (native_endian) {
+ .Big => x,
+ .Little => @byteSwap(x),
+ };
+}
+
+fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type {
+ return switch (scheme) {
+ .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256,
+ .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384,
+ .ecdsa_secp521r1_sha512 => crypto.sign.ecdsa.EcdsaP512Sha512,
+ else => @compileError("bad scheme"),
+ };
+}
+
+/// Abstraction for sending multiple byte buffers to a slice of iovecs.
+const VecPut = struct {
+ iovecs: []const std.os.iovec,
+ idx: usize = 0,
+ off: usize = 0,
+ total: usize = 0,
+
+ /// Returns the amount actually put which is always equal to bytes.len
+ /// unless the vectors ran out of space.
+ fn put(vp: *VecPut, bytes: []const u8) usize {
+ var bytes_i: usize = 0;
+ while (true) {
+ const v = vp.iovecs[vp.idx];
+ const dest = v.iov_base[vp.off..v.iov_len];
+ const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
+ mem.copy(u8, dest, src);
+ bytes_i += src.len;
+ vp.off += src.len;
+ if (vp.off >= v.iov_len) {
+ vp.off = 0;
+ vp.idx += 1;
+ if (vp.idx >= vp.iovecs.len) {
+ vp.total += bytes_i;
+ return bytes_i;
+ }
+ }
+ if (bytes_i >= bytes.len) {
+ vp.total += bytes_i;
+ return bytes_i;
+ }
+ }
+ }
+
+ /// Returns the next buffer that consecutive bytes can go into.
+ fn peek(vp: VecPut) []u8 {
+ if (vp.idx >= vp.iovecs.len) return &.{};
+ const v = vp.iovecs[vp.idx];
+ return v.iov_base[vp.off..v.iov_len];
+ }
+
+ // After writing to the result of peek(), one can call next() to
+ // advance the cursor.
+ fn next(vp: *VecPut, len: usize) void {
+ vp.total += len;
+ vp.off += len;
+ if (vp.off >= vp.iovecs[vp.idx].iov_len) {
+ vp.off = 0;
+ vp.idx += 1;
+ }
+ }
+
+ fn freeSize(vp: VecPut) usize {
+ if (vp.idx >= vp.iovecs.len) return 0;
+ var total: usize = 0;
+ total += vp.iovecs[vp.idx].iov_len - vp.off;
+ if (vp.idx + 1 >= vp.iovecs.len) return total;
+ for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len;
+ return total;
+ }
+};
+
+/// Limit iovecs to a specific byte size.
+fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec {
+ var vec_i: usize = 0;
+ var bytes_left: usize = len;
+ while (true) {
+ if (bytes_left >= iovecs[vec_i].iov_len) {
+ bytes_left -= iovecs[vec_i].iov_len;
+ vec_i += 1;
+ if (vec_i == iovecs.len or bytes_left == 0) return iovecs[0..vec_i];
+ continue;
+ }
+ iovecs[vec_i].iov_len = bytes_left;
+ return iovecs[0..vec_i];
+ }
+}
+
+/// The priority order here is chosen based on what crypto algorithms Zig has
+/// available in the standard library as well as what is faster. Following are
+/// a few data points on the relative performance of these algorithms.
+///
+/// Measurement taken with 0.11.0-dev.810+c2f5848fe
+/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz:
+/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast
+/// aegis-128l: 15382 MiB/s
+/// aegis-256: 9553 MiB/s
+/// aes128-gcm: 3721 MiB/s
+/// aes256-gcm: 3010 MiB/s
+/// chacha20Poly1305: 597 MiB/s
+///
+/// Measurement taken with 0.11.0-dev.810+c2f5848fe
+/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz:
+/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline
+/// aegis-128l: 629 MiB/s
+/// chacha20Poly1305: 529 MiB/s
+/// aegis-256: 461 MiB/s
+/// aes128-gcm: 138 MiB/s
+/// aes256-gcm: 120 MiB/s
+const cipher_suites = enum_array(tls.CipherSuite, &.{
+ .AEGIS_128L_SHA256,
+ .AEGIS_256_SHA384,
+ .AES_128_GCM_SHA256,
+ .AES_256_GCM_SHA384,
+ .CHACHA20_POLY1305_SHA256,
+});
+
+test {
+ _ = StreamInterface;
+}
diff --git a/lib/std/http.zig b/lib/std/http.zig
@@ -1,8 +1,301 @@
-const std = @import("std.zig");
+pub const Client = @import("http/Client.zig");
+
+/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
+/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton
+/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
+pub const Method = enum {
+ GET,
+ HEAD,
+ POST,
+ PUT,
+ DELETE,
+ CONNECT,
+ OPTIONS,
+ TRACE,
+ PATCH,
+
+ /// Returns true if a request of this method is allowed to have a body
+ /// Actual behavior from servers may vary and should still be checked
+ pub fn requestHasBody(self: Method) bool {
+ return switch (self) {
+ .POST, .PUT, .PATCH => true,
+ .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
+ };
+ }
+
+ /// Returns true if a response to this method is allowed to have a body
+ /// Actual behavior from clients may vary and should still be checked
+ pub fn responseHasBody(self: Method) bool {
+ return switch (self) {
+ .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
+ .HEAD, .PUT, .TRACE => false,
+ };
+ }
+
+ /// An HTTP method is safe if it doesn't alter the state of the server.
+ /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
+ /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
+ pub fn safe(self: Method) bool {
+ return switch (self) {
+ .GET, .HEAD, .OPTIONS, .TRACE => true,
+ .POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
+ };
+ }
+
+ /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state.
+ /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
+ /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
+ pub fn idempotent(self: Method) bool {
+ return switch (self) {
+ .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
+ .CONNECT, .POST, .PATCH => false,
+ };
+ }
+
+ /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server.
+ /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
+ /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
+ pub fn cacheable(self: Method) bool {
+ return switch (self) {
+ .GET, .HEAD => true,
+ .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
+ };
+ }
+};
+
+/// https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
+pub const Status = enum(u10) {
+ @"continue" = 100, // RFC7231, Section 6.2.1
+ switching_protocols = 101, // RFC7231, Section 6.2.2
+ processing = 102, // RFC2518
+ early_hints = 103, // RFC8297
+
+ ok = 200, // RFC7231, Section 6.3.1
+ created = 201, // RFC7231, Section 6.3.2
+ accepted = 202, // RFC7231, Section 6.3.3
+ non_authoritative_info = 203, // RFC7231, Section 6.3.4
+ no_content = 204, // RFC7231, Section 6.3.5
+ reset_content = 205, // RFC7231, Section 6.3.6
+ partial_content = 206, // RFC7233, Section 4.1
+ multi_status = 207, // RFC4918
+ already_reported = 208, // RFC5842
+ im_used = 226, // RFC3229
+
+ multiple_choice = 300, // RFC7231, Section 6.4.1
+ moved_permanently = 301, // RFC7231, Section 6.4.2
+ found = 302, // RFC7231, Section 6.4.3
+ see_other = 303, // RFC7231, Section 6.4.4
+ not_modified = 304, // RFC7232, Section 4.1
+ use_proxy = 305, // RFC7231, Section 6.4.5
+ temporary_redirect = 307, // RFC7231, Section 6.4.7
+ permanent_redirect = 308, // RFC7538
+
+ bad_request = 400, // RFC7231, Section 6.5.1
+ unauthorized = 401, // RFC7235, Section 3.1
+ payment_required = 402, // RFC7231, Section 6.5.2
+ forbidden = 403, // RFC7231, Section 6.5.3
+ not_found = 404, // RFC7231, Section 6.5.4
+ method_not_allowed = 405, // RFC7231, Section 6.5.5
+ not_acceptable = 406, // RFC7231, Section 6.5.6
+ proxy_auth_required = 407, // RFC7235, Section 3.2
+ request_timeout = 408, // RFC7231, Section 6.5.7
+ conflict = 409, // RFC7231, Section 6.5.8
+ gone = 410, // RFC7231, Section 6.5.9
+ length_required = 411, // RFC7231, Section 6.5.10
+ precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2
+ payload_too_large = 413, // RFC7231, Section 6.5.11
+ uri_too_long = 414, // RFC7231, Section 6.5.12
+ unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3
+ range_not_satisfiable = 416, // RFC7233, Section 4.4
+ expectation_failed = 417, // RFC7231, Section 6.5.14
+ teapot = 418, // RFC 7168, 2.3.3
+ misdirected_request = 421, // RFC7540, Section 9.1.2
+ unprocessable_entity = 422, // RFC4918
+ locked = 423, // RFC4918
+ failed_dependency = 424, // RFC4918
+ too_early = 425, // RFC8470
+ upgrade_required = 426, // RFC7231, Section 6.5.15
+ precondition_required = 428, // RFC6585
+ too_many_requests = 429, // RFC6585
+ header_fields_too_large = 431, // RFC6585
+ unavailable_for_legal_reasons = 451, // RFC7725
+
+ internal_server_error = 500, // RFC7231, Section 6.6.1
+ not_implemented = 501, // RFC7231, Section 6.6.2
+ bad_gateway = 502, // RFC7231, Section 6.6.3
+ service_unavailable = 503, // RFC7231, Section 6.6.4
+ gateway_timeout = 504, // RFC7231, Section 6.6.5
+ http_version_not_supported = 505, // RFC7231, Section 6.6.6
+ variant_also_negotiates = 506, // RFC2295
+ insufficient_storage = 507, // RFC4918
+ loop_detected = 508, // RFC5842
+ not_extended = 510, // RFC2774
+ network_authentication_required = 511, // RFC6585
+
+ _,
+
+ pub fn phrase(self: Status) ?[]const u8 {
+ return switch (self) {
+ // 1xx statuses
+ .@"continue" => "Continue",
+ .switching_protocols => "Switching Protocols",
+ .processing => "Processing",
+ .early_hints => "Early Hints",
-pub const Method = @import("http/method.zig").Method;
-pub const Status = @import("http/status.zig").Status;
+ // 2xx statuses
+ .ok => "OK",
+ .created => "Created",
+ .accepted => "Accepted",
+ .non_authoritative_info => "Non-Authoritative Information",
+ .no_content => "No Content",
+ .reset_content => "Reset Content",
+ .partial_content => "Partial Content",
+ .multi_status => "Multi-Status",
+ .already_reported => "Already Reported",
+ .im_used => "IM Used",
+
+ // 3xx statuses
+ .multiple_choice => "Multiple Choice",
+ .moved_permanently => "Moved Permanently",
+ .found => "Found",
+ .see_other => "See Other",
+ .not_modified => "Not Modified",
+ .use_proxy => "Use Proxy",
+ .temporary_redirect => "Temporary Redirect",
+ .permanent_redirect => "Permanent Redirect",
+
+ // 4xx statuses
+ .bad_request => "Bad Request",
+ .unauthorized => "Unauthorized",
+ .payment_required => "Payment Required",
+ .forbidden => "Forbidden",
+ .not_found => "Not Found",
+ .method_not_allowed => "Method Not Allowed",
+ .not_acceptable => "Not Acceptable",
+ .proxy_auth_required => "Proxy Authentication Required",
+ .request_timeout => "Request Timeout",
+ .conflict => "Conflict",
+ .gone => "Gone",
+ .length_required => "Length Required",
+ .precondition_failed => "Precondition Failed",
+ .payload_too_large => "Payload Too Large",
+ .uri_too_long => "URI Too Long",
+ .unsupported_media_type => "Unsupported Media Type",
+ .range_not_satisfiable => "Range Not Satisfiable",
+ .expectation_failed => "Expectation Failed",
+ .teapot => "I'm a teapot",
+ .misdirected_request => "Misdirected Request",
+ .unprocessable_entity => "Unprocessable Entity",
+ .locked => "Locked",
+ .failed_dependency => "Failed Dependency",
+ .too_early => "Too Early",
+ .upgrade_required => "Upgrade Required",
+ .precondition_required => "Precondition Required",
+ .too_many_requests => "Too Many Requests",
+ .header_fields_too_large => "Request Header Fields Too Large",
+ .unavailable_for_legal_reasons => "Unavailable For Legal Reasons",
+
+ // 5xx statuses
+ .internal_server_error => "Internal Server Error",
+ .not_implemented => "Not Implemented",
+ .bad_gateway => "Bad Gateway",
+ .service_unavailable => "Service Unavailable",
+ .gateway_timeout => "Gateway Timeout",
+ .http_version_not_supported => "HTTP Version Not Supported",
+ .variant_also_negotiates => "Variant Also Negotiates",
+ .insufficient_storage => "Insufficient Storage",
+ .loop_detected => "Loop Detected",
+ .not_extended => "Not Extended",
+ .network_authentication_required => "Network Authentication Required",
+
+ else => return null,
+ };
+ }
+
+ pub const Class = enum {
+ informational,
+ success,
+ redirect,
+ client_error,
+ server_error,
+ };
+
+ pub fn class(self: Status) ?Class {
+ return switch (@enumToInt(self)) {
+ 100...199 => .informational,
+ 200...299 => .success,
+ 300...399 => .redirect,
+ 400...499 => .client_error,
+ 500...599 => .server_error,
+ else => null,
+ };
+ }
+
+ test {
+ try std.testing.expectEqualStrings("OK", Status.ok.phrase().?);
+ try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?);
+ }
+
+ test {
+ try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class());
+ try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class());
+ }
+};
+
+pub const Headers = struct {
+ state: State = .start,
+ invalid_index: u32 = undefined,
+
+ pub const State = enum { invalid, start, line, nl_r, nl_n, nl2_r, finished };
+
+ /// Returns how many bytes are processed into headers. Always less than or
+ /// equal to bytes.len. If the amount returned is less than bytes.len, it
+ /// means the headers ended and the first byte after the double \r\n\r\n is
+ /// located at `bytes[result]`.
+ pub fn feed(h: *Headers, bytes: []const u8) usize {
+ for (bytes) |b, i| {
+ switch (h.state) {
+ .start => switch (b) {
+ '\r' => h.state = .nl_r,
+ '\n' => return invalid(h, i),
+ else => {},
+ },
+ .nl_r => switch (b) {
+ '\n' => h.state = .nl_n,
+ else => return invalid(h, i),
+ },
+ .nl_n => switch (b) {
+ '\r' => h.state = .nl2_r,
+ else => h.state = .line,
+ },
+ .nl2_r => switch (b) {
+ '\n' => h.state = .finished,
+ else => return invalid(h, i),
+ },
+ .line => switch (b) {
+ '\r' => h.state = .nl_r,
+ '\n' => return invalid(h, i),
+ else => {},
+ },
+ .invalid => return i,
+ .finished => return i,
+ }
+ }
+ return bytes.len;
+ }
+
+ fn invalid(h: *Headers, i: usize) usize {
+ h.invalid_index = @intCast(u32, i);
+ h.state = .invalid;
+ return i;
+ }
+};
+
+const std = @import("std.zig");
test {
- std.testing.refAllDecls(@This());
+ _ = Client;
+ _ = Method;
+ _ = Status;
+ _ = Headers;
}
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
@@ -0,0 +1,181 @@
+//! This API is a barely-touched, barely-functional http client, just the
+//! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear
+//! with me and I promise the API will become useful and streamlined.
+
+const std = @import("../std.zig");
+const assert = std.debug.assert;
+const http = std.http;
+const net = std.net;
+const Client = @This();
+const Url = std.Url;
+
+allocator: std.mem.Allocator,
+headers: std.ArrayListUnmanaged(u8) = .{},
+active_requests: usize = 0,
+ca_bundle: std.crypto.Certificate.Bundle = .{},
+
+/// TODO: emit error.UnexpectedEndOfStream or something like that when the read
+/// data does not match the content length. This is necessary since HTTPS disables
+/// close_notify protection on underlying TLS streams.
+pub const Request = struct {
+ client: *Client,
+ stream: net.Stream,
+ headers: std.ArrayListUnmanaged(u8) = .{},
+ tls_client: std.crypto.tls.Client,
+ protocol: Protocol,
+ response_headers: http.Headers = .{},
+
+ pub const Protocol = enum { http, https };
+
+ pub const Options = struct {
+ method: http.Method = .GET,
+ };
+
+ pub fn deinit(req: *Request) void {
+ req.client.active_requests -= 1;
+ req.headers.deinit(req.client.allocator);
+ req.* = undefined;
+ }
+
+ pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void {
+ const gpa = req.client.allocator;
+ // Ensure an extra +2 for the \r\n in end()
+ try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6);
+ req.headers.appendSliceAssumeCapacity(name);
+ req.headers.appendSliceAssumeCapacity(": ");
+ req.headers.appendSliceAssumeCapacity(value);
+ req.headers.appendSliceAssumeCapacity("\r\n");
+ }
+
+ pub fn end(req: *Request) !void {
+ req.headers.appendSliceAssumeCapacity("\r\n");
+ switch (req.protocol) {
+ .http => {
+ try req.stream.writeAll(req.headers.items);
+ },
+ .https => {
+ try req.tls_client.writeAll(req.stream, req.headers.items);
+ },
+ }
+ }
+
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ return readAtLeast(req, buffer, buffer.len);
+ }
+
+ pub fn read(req: *Request, buffer: []u8) !usize {
+ return readAtLeast(req, buffer, 1);
+ }
+
+ pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
+ assert(len <= buffer.len);
+ var index: usize = 0;
+ while (index < len) {
+ const headers_finished = req.response_headers.state == .finished;
+ const amt = try readAdvanced(req, buffer[index..]);
+ if (amt == 0 and headers_finished) break;
+ index += amt;
+ }
+ return index;
+ }
+
+ /// This one can return 0 without meaning EOF.
+ /// TODO change to readvAdvanced
+ pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
+ if (req.response_headers.state == .finished) return readRaw(req, buffer);
+
+ const amt = try readRaw(req, buffer);
+ const data = buffer[0..amt];
+ const i = req.response_headers.feed(data);
+ if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders;
+ if (i < data.len) {
+ const rest = data[i..];
+ std.mem.copy(u8, buffer, rest);
+ return rest.len;
+ }
+ return 0;
+ }
+
+ /// Only abstracts over http/https.
+ fn readRaw(req: *Request, buffer: []u8) !usize {
+ switch (req.protocol) {
+ .http => return req.stream.read(buffer),
+ .https => return req.tls_client.read(req.stream, buffer),
+ }
+ }
+
+ /// Only abstracts over http/https.
+ fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize {
+ switch (req.protocol) {
+ .http => return req.stream.readAtLeast(buffer, len),
+ .https => return req.tls_client.readAtLeast(req.stream, buffer, len),
+ }
+ }
+};
+
+pub fn deinit(client: *Client) void {
+ assert(client.active_requests == 0);
+ client.headers.deinit(client.allocator);
+ client.* = undefined;
+}
+
+pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
+ const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse
+ return error.UnsupportedUrlScheme;
+ const port: u16 = url.port orelse switch (protocol) {
+ .http => 80,
+ .https => 443,
+ };
+
+ var req: Request = .{
+ .client = client,
+ .stream = try net.tcpConnectToHost(client.allocator, url.host, port),
+ .protocol = protocol,
+ .tls_client = undefined,
+ };
+ client.active_requests += 1;
+ errdefer req.deinit();
+
+ switch (protocol) {
+ .http => {},
+ .https => {
+ req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host);
+ // This is appropriate for HTTPS because the HTTP headers contain
+ // the content length which is used to detect truncation attacks.
+ req.tls_client.allow_truncation_attacks = true;
+ },
+ }
+
+ try req.headers.ensureUnusedCapacity(
+ client.allocator,
+ @tagName(options.method).len +
+ 1 +
+ url.path.len +
+ " HTTP/1.1\r\nHost: ".len +
+ url.host.len +
+ "\r\nUpgrade-Insecure-Requests: 1\r\n".len +
+ client.headers.items.len +
+ 2, // for the \r\n at the end of headers
+ );
+ req.headers.appendSliceAssumeCapacity(@tagName(options.method));
+ req.headers.appendSliceAssumeCapacity(" ");
+ req.headers.appendSliceAssumeCapacity(url.path);
+ req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
+ req.headers.appendSliceAssumeCapacity(url.host);
+ switch (protocol) {
+ .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),
+ .http => req.headers.appendSliceAssumeCapacity("\r\n"),
+ }
+ req.headers.appendSliceAssumeCapacity(client.headers.items);
+
+ return req;
+}
+
+pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void {
+ const gpa = client.allocator;
+ try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4);
+ client.headers.appendSliceAssumeCapacity(name);
+ client.headers.appendSliceAssumeCapacity(": ");
+ client.headers.appendSliceAssumeCapacity(value);
+ client.headers.appendSliceAssumeCapacity("\r\n");
+}
diff --git a/lib/std/http/method.zig b/lib/std/http/method.zig
@@ -1,65 +0,0 @@
-//! HTTP Methods
-//! https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
-
-// Style guide is violated here so that @tagName can be used effectively
-/// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definiton
-/// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH
-pub const Method = enum {
- GET,
- HEAD,
- POST,
- PUT,
- DELETE,
- CONNECT,
- OPTIONS,
- TRACE,
- PATCH,
-
- /// Returns true if a request of this method is allowed to have a body
- /// Actual behavior from servers may vary and should still be checked
- pub fn requestHasBody(self: Method) bool {
- return switch (self) {
- .POST, .PUT, .PATCH => true,
- .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false,
- };
- }
-
- /// Returns true if a response to this method is allowed to have a body
- /// Actual behavior from clients may vary and should still be checked
- pub fn responseHasBody(self: Method) bool {
- return switch (self) {
- .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true,
- .HEAD, .PUT, .TRACE => false,
- };
- }
-
- /// An HTTP method is safe if it doesn't alter the state of the server.
- /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP
- /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1
- pub fn safe(self: Method) bool {
- return switch (self) {
- .GET, .HEAD, .OPTIONS, .TRACE => true,
- .POST, .PUT, .DELETE, .CONNECT, .PATCH => false,
- };
- }
-
- /// An HTTP method is idempotent if an identical request can be made once or several times in a row with the same effect while leaving the server in the same state.
- /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent
- /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2
- pub fn idempotent(self: Method) bool {
- return switch (self) {
- .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true,
- .CONNECT, .POST, .PATCH => false,
- };
- }
-
- /// A cacheable response is an HTTP response that can be cached, that is stored to be retrieved and used later, saving a new request to the server.
- /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable
- /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3
- pub fn cacheable(self: Method) bool {
- return switch (self) {
- .GET, .HEAD => true,
- .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false,
- };
- }
-};
diff --git a/lib/std/http/status.zig b/lib/std/http/status.zig
@@ -1,182 +0,0 @@
-//! HTTP Status
-//! https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
-
-const std = @import("../std.zig");
-
-pub const Status = enum(u10) {
- @"continue" = 100, // RFC7231, Section 6.2.1
- switching_protocols = 101, // RFC7231, Section 6.2.2
- processing = 102, // RFC2518
- early_hints = 103, // RFC8297
-
- ok = 200, // RFC7231, Section 6.3.1
- created = 201, // RFC7231, Section 6.3.2
- accepted = 202, // RFC7231, Section 6.3.3
- non_authoritative_info = 203, // RFC7231, Section 6.3.4
- no_content = 204, // RFC7231, Section 6.3.5
- reset_content = 205, // RFC7231, Section 6.3.6
- partial_content = 206, // RFC7233, Section 4.1
- multi_status = 207, // RFC4918
- already_reported = 208, // RFC5842
- im_used = 226, // RFC3229
-
- multiple_choice = 300, // RFC7231, Section 6.4.1
- moved_permanently = 301, // RFC7231, Section 6.4.2
- found = 302, // RFC7231, Section 6.4.3
- see_other = 303, // RFC7231, Section 6.4.4
- not_modified = 304, // RFC7232, Section 4.1
- use_proxy = 305, // RFC7231, Section 6.4.5
- temporary_redirect = 307, // RFC7231, Section 6.4.7
- permanent_redirect = 308, // RFC7538
-
- bad_request = 400, // RFC7231, Section 6.5.1
- unauthorized = 401, // RFC7235, Section 3.1
- payment_required = 402, // RFC7231, Section 6.5.2
- forbidden = 403, // RFC7231, Section 6.5.3
- not_found = 404, // RFC7231, Section 6.5.4
- method_not_allowed = 405, // RFC7231, Section 6.5.5
- not_acceptable = 406, // RFC7231, Section 6.5.6
- proxy_auth_required = 407, // RFC7235, Section 3.2
- request_timeout = 408, // RFC7231, Section 6.5.7
- conflict = 409, // RFC7231, Section 6.5.8
- gone = 410, // RFC7231, Section 6.5.9
- length_required = 411, // RFC7231, Section 6.5.10
- precondition_failed = 412, // RFC7232, Section 4.2][RFC8144, Section 3.2
- payload_too_large = 413, // RFC7231, Section 6.5.11
- uri_too_long = 414, // RFC7231, Section 6.5.12
- unsupported_media_type = 415, // RFC7231, Section 6.5.13][RFC7694, Section 3
- range_not_satisfiable = 416, // RFC7233, Section 4.4
- expectation_failed = 417, // RFC7231, Section 6.5.14
- teapot = 418, // RFC 7168, 2.3.3
- misdirected_request = 421, // RFC7540, Section 9.1.2
- unprocessable_entity = 422, // RFC4918
- locked = 423, // RFC4918
- failed_dependency = 424, // RFC4918
- too_early = 425, // RFC8470
- upgrade_required = 426, // RFC7231, Section 6.5.15
- precondition_required = 428, // RFC6585
- too_many_requests = 429, // RFC6585
- header_fields_too_large = 431, // RFC6585
- unavailable_for_legal_reasons = 451, // RFC7725
-
- internal_server_error = 500, // RFC7231, Section 6.6.1
- not_implemented = 501, // RFC7231, Section 6.6.2
- bad_gateway = 502, // RFC7231, Section 6.6.3
- service_unavailable = 503, // RFC7231, Section 6.6.4
- gateway_timeout = 504, // RFC7231, Section 6.6.5
- http_version_not_supported = 505, // RFC7231, Section 6.6.6
- variant_also_negotiates = 506, // RFC2295
- insufficient_storage = 507, // RFC4918
- loop_detected = 508, // RFC5842
- not_extended = 510, // RFC2774
- network_authentication_required = 511, // RFC6585
-
- _,
-
- pub fn phrase(self: Status) ?[]const u8 {
- return switch (self) {
- // 1xx statuses
- .@"continue" => "Continue",
- .switching_protocols => "Switching Protocols",
- .processing => "Processing",
- .early_hints => "Early Hints",
-
- // 2xx statuses
- .ok => "OK",
- .created => "Created",
- .accepted => "Accepted",
- .non_authoritative_info => "Non-Authoritative Information",
- .no_content => "No Content",
- .reset_content => "Reset Content",
- .partial_content => "Partial Content",
- .multi_status => "Multi-Status",
- .already_reported => "Already Reported",
- .im_used => "IM Used",
-
- // 3xx statuses
- .multiple_choice => "Multiple Choice",
- .moved_permanently => "Moved Permanently",
- .found => "Found",
- .see_other => "See Other",
- .not_modified => "Not Modified",
- .use_proxy => "Use Proxy",
- .temporary_redirect => "Temporary Redirect",
- .permanent_redirect => "Permanent Redirect",
-
- // 4xx statuses
- .bad_request => "Bad Request",
- .unauthorized => "Unauthorized",
- .payment_required => "Payment Required",
- .forbidden => "Forbidden",
- .not_found => "Not Found",
- .method_not_allowed => "Method Not Allowed",
- .not_acceptable => "Not Acceptable",
- .proxy_auth_required => "Proxy Authentication Required",
- .request_timeout => "Request Timeout",
- .conflict => "Conflict",
- .gone => "Gone",
- .length_required => "Length Required",
- .precondition_failed => "Precondition Failed",
- .payload_too_large => "Payload Too Large",
- .uri_too_long => "URI Too Long",
- .unsupported_media_type => "Unsupported Media Type",
- .range_not_satisfiable => "Range Not Satisfiable",
- .expectation_failed => "Expectation Failed",
- .teapot => "I'm a teapot",
- .misdirected_request => "Misdirected Request",
- .unprocessable_entity => "Unprocessable Entity",
- .locked => "Locked",
- .failed_dependency => "Failed Dependency",
- .too_early => "Too Early",
- .upgrade_required => "Upgrade Required",
- .precondition_required => "Precondition Required",
- .too_many_requests => "Too Many Requests",
- .header_fields_too_large => "Request Header Fields Too Large",
- .unavailable_for_legal_reasons => "Unavailable For Legal Reasons",
-
- // 5xx statuses
- .internal_server_error => "Internal Server Error",
- .not_implemented => "Not Implemented",
- .bad_gateway => "Bad Gateway",
- .service_unavailable => "Service Unavailable",
- .gateway_timeout => "Gateway Timeout",
- .http_version_not_supported => "HTTP Version Not Supported",
- .variant_also_negotiates => "Variant Also Negotiates",
- .insufficient_storage => "Insufficient Storage",
- .loop_detected => "Loop Detected",
- .not_extended => "Not Extended",
- .network_authentication_required => "Network Authentication Required",
-
- else => return null,
- };
- }
-
- pub const Class = enum {
- informational,
- success,
- redirect,
- client_error,
- server_error,
- };
-
- pub fn class(self: Status) ?Class {
- return switch (@enumToInt(self)) {
- 100...199 => .informational,
- 200...299 => .success,
- 300...399 => .redirect,
- 400...499 => .client_error,
- 500...599 => .server_error,
- else => null,
- };
- }
-};
-
-test {
- try std.testing.expectEqualStrings("OK", Status.ok.phrase().?);
- try std.testing.expectEqualStrings("Not Found", Status.not_found.phrase().?);
-}
-
-test {
- try std.testing.expectEqual(@as(?Status.Class, Status.Class.success), Status.ok.class());
- try std.testing.expectEqual(@as(?Status.Class, Status.Class.client_error), Status.not_found.class());
-}
diff --git a/lib/std/meta.zig b/lib/std/meta.zig
@@ -810,21 +810,25 @@ test "std.meta.activeTag" {
const TagPayloadType = TagPayload;
-///Given a tagged union type, and an enum, return the type of the union
-/// field corresponding to the enum tag.
-pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
+pub fn TagPayloadByName(comptime U: type, comptime tag_name: []const u8) type {
comptime debug.assert(trait.is(.Union)(U));
const info = @typeInfo(U).Union;
inline for (info.fields) |field_info| {
- if (comptime mem.eql(u8, field_info.name, @tagName(tag)))
+ if (comptime mem.eql(u8, field_info.name, tag_name))
return field_info.type;
}
unreachable;
}
+/// Given a tagged union type, and an enum, return the type of the union field
+/// corresponding to the enum tag.
+pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
+ return TagPayloadByName(U, @tagName(tag));
+}
+
test "std.meta.TagPayload" {
const Event = union(enum) {
Moved: struct {
diff --git a/lib/std/net.zig b/lib/std/net.zig
@@ -1672,6 +1672,40 @@ pub const Stream = struct {
}
}
+ pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize {
+ if (builtin.os.tag == .windows) {
+ // TODO improve this to use ReadFileScatter
+ if (iovecs.len == 0) return @as(usize, 0);
+ const first = iovecs[0];
+ return os.windows.ReadFile(s.handle, first.iov_base[0..first.iov_len], null, io.default_mode);
+ }
+
+ return os.readv(s.handle, iovecs);
+ }
+
+ /// Returns the number of bytes read. If the number read is smaller than
+ /// `buffer.len`, it means the stream reached the end. Reaching the end of
+ /// a stream is not an error condition.
+ pub fn readAll(s: Stream, buffer: []u8) ReadError!usize {
+ return readAtLeast(s, buffer, buffer.len);
+ }
+
+ /// Returns the number of bytes read, calling the underlying read function
+ /// the minimal number of times until the buffer has at least `len` bytes
+ /// filled. If the number read is less than `len` it means the stream
+ /// reached the end. Reaching the end of the stream is not an error
+ /// condition.
+ pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
+ assert(len <= buffer.len);
+ var index: usize = 0;
+ while (index < len) {
+ const amt = try s.read(buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
+
/// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
/// file system thread instead of non-blocking. It needs to be reworked to properly
/// use non-blocking I/O.
@@ -1687,6 +1721,13 @@ pub const Stream = struct {
}
}
+ pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try self.write(bytes[index..]);
+ }
+ }
+
/// See https://github.com/ziglang/zig/issues/7699
/// See equivalent function: `std.fs.File.writev`.
pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize {
diff --git a/lib/std/os.zig b/lib/std/os.zig
@@ -767,6 +767,7 @@ pub fn readv(fd: fd_t, iov: []const iovec) ReadError!usize {
.ISDIR => return error.IsDir,
.NOBUFS => return error.SystemResources,
.NOMEM => return error.SystemResources,
+ .CONNRESET => return error.ConnectionResetByPeer,
else => |err| return unexpectedErrno(err),
}
}
@@ -5685,11 +5686,11 @@ pub fn sendmsg(
/// The file descriptor of the sending socket.
sockfd: socket_t,
/// Message header and iovecs
- msg: msghdr_const,
+ msg: *const msghdr_const,
flags: u32,
) SendMsgError!usize {
while (true) {
- const rc = system.sendmsg(sockfd, @ptrCast(*const std.x.os.Socket.Message, &msg), @intCast(c_int, flags));
+ const rc = system.sendmsg(sockfd, msg, flags);
if (builtin.os.tag == .windows) {
if (rc == windows.ws2_32.SOCKET_ERROR) {
switch (windows.ws2_32.WSAGetLastError()) {
diff --git a/lib/std/os/linux.zig b/lib/std/os/linux.zig
@@ -1226,11 +1226,14 @@ pub fn getsockopt(fd: i32, level: u32, optname: u32, noalias optval: [*]u8, noal
return syscall5(.getsockopt, @bitCast(usize, @as(isize, fd)), level, optname, @ptrToInt(optval), @ptrToInt(optlen));
}
-pub fn sendmsg(fd: i32, msg: *const std.x.os.Socket.Message, flags: c_int) usize {
+pub fn sendmsg(fd: i32, msg: *const msghdr_const, flags: u32) usize {
+ const fd_usize = @bitCast(usize, @as(isize, fd));
+ const msg_usize = @ptrToInt(msg);
if (native_arch == .x86) {
- return socketcall(SC.sendmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) });
+ return socketcall(SC.sendmsg, &[3]usize{ fd_usize, msg_usize, flags });
+ } else {
+ return syscall3(.sendmsg, fd_usize, msg_usize, flags);
}
- return syscall3(.sendmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)));
}
pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize {
@@ -1274,24 +1277,42 @@ pub fn sendmmsg(fd: i32, msgvec: [*]mmsghdr_const, vlen: u32, flags: u32) usize
}
pub fn connect(fd: i32, addr: *const anyopaque, len: socklen_t) usize {
+ const fd_usize = @bitCast(usize, @as(isize, fd));
+ const addr_usize = @ptrToInt(addr);
if (native_arch == .x86) {
- return socketcall(SC.connect, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len });
+ return socketcall(SC.connect, &[3]usize{ fd_usize, addr_usize, len });
+ } else {
+ return syscall3(.connect, fd_usize, addr_usize, len);
}
- return syscall3(.connect, @bitCast(usize, @as(isize, fd)), @ptrToInt(addr), len);
}
-pub fn recvmsg(fd: i32, msg: *std.x.os.Socket.Message, flags: c_int) usize {
+pub fn recvmsg(fd: i32, msg: *msghdr, flags: u32) usize {
+ const fd_usize = @bitCast(usize, @as(isize, fd));
+ const msg_usize = @ptrToInt(msg);
if (native_arch == .x86) {
- return socketcall(SC.recvmsg, &[3]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)) });
+ return socketcall(SC.recvmsg, &[3]usize{ fd_usize, msg_usize, flags });
+ } else {
+ return syscall3(.recvmsg, fd_usize, msg_usize, flags);
}
- return syscall3(.recvmsg, @bitCast(usize, @as(isize, fd)), @ptrToInt(msg), @bitCast(usize, @as(isize, flags)));
}
-pub fn recvfrom(fd: i32, noalias buf: [*]u8, len: usize, flags: u32, noalias addr: ?*sockaddr, noalias alen: ?*socklen_t) usize {
+pub fn recvfrom(
+ fd: i32,
+ noalias buf: [*]u8,
+ len: usize,
+ flags: u32,
+ noalias addr: ?*sockaddr,
+ noalias alen: ?*socklen_t,
+) usize {
+ const fd_usize = @bitCast(usize, @as(isize, fd));
+ const buf_usize = @ptrToInt(buf);
+ const addr_usize = @ptrToInt(addr);
+ const alen_usize = @ptrToInt(alen);
if (native_arch == .x86) {
- return socketcall(SC.recvfrom, &[6]usize{ @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen) });
+ return socketcall(SC.recvfrom, &[6]usize{ fd_usize, buf_usize, len, flags, addr_usize, alen_usize });
+ } else {
+ return syscall6(.recvfrom, fd_usize, buf_usize, len, flags, addr_usize, alen_usize);
}
- return syscall6(.recvfrom, @bitCast(usize, @as(isize, fd)), @ptrToInt(buf), len, flags, @ptrToInt(addr), @ptrToInt(alen));
}
pub fn shutdown(fd: i32, how: i32) usize {
@@ -3219,7 +3240,15 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ family: sa_family_t align(8),
+ padding: [SS_MAXSIZE - @sizeOf(sa_family_t)]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
/// IPv4 socket address
pub const in = extern struct {
diff --git a/lib/std/os/linux/seccomp.zig b/lib/std/os/linux/seccomp.zig
@@ -6,16 +6,14 @@
//! isn't that useful for general-purpose applications, and so a mode that
//! utilizes user-supplied filters mode was added.
//!
-//! Seccomp filters are classic BPF programs, which means that all the
-//! information under `std.x.net.bpf` applies here as well. Conceptually, a
-//! seccomp program is attached to the kernel and is executed on each syscall.
-//! The "packet" being validated is the `data` structure, and the verdict is an
-//! action that the kernel performs on the calling process. The actions are
-//! variations on a "pass" or "fail" result, where a pass allows the syscall to
-//! continue and a fail blocks the syscall and returns some sort of error value.
-//! See the full list of actions under ::RET for more information. Finally, only
-//! word-sized, absolute loads (`ld [k]`) are supported to read from the `data`
-//! structure.
+//! Seccomp filters are classic BPF programs. Conceptually, a seccomp program
+//! is attached to the kernel and is executed on each syscall. The "packet"
+//! being validated is the `data` structure, and the verdict is an action that
+//! the kernel performs on the calling process. The actions are variations on a
+//! "pass" or "fail" result, where a pass allows the syscall to continue and a
+//! fail blocks the syscall and returns some sort of error value. See the full
+//! list of actions under ::RET for more information. Finally, only word-sized,
+//! absolute loads (`ld [k]`) are supported to read from the `data` structure.
//!
//! There are some issues with the filter API that have traditionally made
//! writing them a pain:
diff --git a/lib/std/os/windows/ws2_32.zig b/lib/std/os/windows/ws2_32.zig
@@ -1,4 +1,5 @@
const std = @import("../../std.zig");
+const assert = std.debug.assert;
const windows = std.os.windows;
const WINAPI = windows.WINAPI;
@@ -1106,7 +1107,15 @@ pub const sockaddr = extern struct {
data: [14]u8,
pub const SS_MAXSIZE = 128;
- pub const storage = std.x.os.Socket.Address.Native.Storage;
+ pub const storage = extern struct {
+ family: ADDRESS_FAMILY align(8),
+ padding: [SS_MAXSIZE - @sizeOf(ADDRESS_FAMILY)]u8 = undefined,
+
+ comptime {
+ assert(@sizeOf(storage) == SS_MAXSIZE);
+ assert(@alignOf(storage) == 8);
+ }
+ };
/// IPv4 socket address
pub const in = extern struct {
@@ -1207,7 +1216,7 @@ pub const LPFN_GETACCEPTEXSOCKADDRS = *const fn (
pub const LPFN_WSASENDMSG = *const fn (
s: SOCKET,
- lpMsg: *const std.x.os.Socket.Message,
+ lpMsg: *const WSAMSG_const,
dwFlags: u32,
lpNumberOfBytesSent: ?*u32,
lpOverlapped: ?*OVERLAPPED,
@@ -1216,7 +1225,7 @@ pub const LPFN_WSASENDMSG = *const fn (
pub const LPFN_WSARECVMSG = *const fn (
s: SOCKET,
- lpMsg: *std.x.os.Socket.Message,
+ lpMsg: *WSAMSG,
lpdwNumberOfBytesRecv: ?*u32,
lpOverlapped: ?*OVERLAPPED,
lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE,
@@ -2090,7 +2099,7 @@ pub extern "ws2_32" fn WSASend(
pub extern "ws2_32" fn WSASendMsg(
s: SOCKET,
- lpMsg: *const std.x.os.Socket.Message,
+ lpMsg: *WSAMSG_const,
dwFlags: u32,
lpNumberOfBytesSent: ?*u32,
lpOverlapped: ?*OVERLAPPED,
@@ -2099,7 +2108,7 @@ pub extern "ws2_32" fn WSASendMsg(
pub extern "ws2_32" fn WSARecvMsg(
s: SOCKET,
- lpMsg: *std.x.os.Socket.Message,
+ lpMsg: *WSAMSG,
lpdwNumberOfBytesRecv: ?*u32,
lpOverlapped: ?*OVERLAPPED,
lpCompletionRoutine: ?LPWSAOVERLAPPED_COMPLETION_ROUTINE,
diff --git a/lib/std/std.zig b/lib/std/std.zig
@@ -42,6 +42,7 @@ pub const Target = @import("target.zig").Target;
pub const Thread = @import("Thread.zig");
pub const Treap = @import("treap.zig").Treap;
pub const Tz = tz.Tz;
+pub const Url = @import("Url.zig");
pub const array_hash_map = @import("array_hash_map.zig");
pub const atomic = @import("atomic.zig");
@@ -90,7 +91,6 @@ pub const tz = @import("tz.zig");
pub const unicode = @import("unicode.zig");
pub const valgrind = @import("valgrind.zig");
pub const wasm = @import("wasm.zig");
-pub const x = @import("x.zig");
pub const zig = @import("zig.zig");
pub const start = @import("start.zig");
diff --git a/lib/std/x.zig b/lib/std/x.zig
@@ -1,19 +0,0 @@
-const std = @import("std.zig");
-
-pub const os = struct {
- pub const Socket = @import("x/os/socket.zig").Socket;
- pub usingnamespace @import("x/os/io.zig");
- pub usingnamespace @import("x/os/net.zig");
-};
-
-pub const net = struct {
- pub const ip = @import("x/net/ip.zig");
- pub const tcp = @import("x/net/tcp.zig");
- pub const bpf = @import("x/net/bpf.zig");
-};
-
-test {
- inline for (.{ os, net }) |module| {
- std.testing.refAllDecls(module);
- }
-}
diff --git a/lib/std/x/net/bpf.zig b/lib/std/x/net/bpf.zig
@@ -1,1003 +0,0 @@
-//! This package provides instrumentation for creating Berkeley Packet Filter[1]
-//! (BPF) programs, along with a simulator for running them.
-//!
-//! BPF is a mechanism for cheap, in-kernel packet filtering. Programs are
-//! attached to a network device and executed for every packet that flows
-//! through it. The program must then return a verdict: the amount of packet
-//! bytes that the kernel should copy into userspace. Execution speed is
-//! achieved by having programs run in a limited virtual machine, which has the
-//! added benefit of graceful failure in the face of buggy programs.
-//!
-//! The BPF virtual machine has a 32-bit word length and a small number of
-//! word-sized registers:
-//!
-//! - The accumulator, `a`: The source/destination of arithmetic and logic
-//! operations.
-//! - The index register, `x`: Used as an offset for indirect memory access and
-//! as a comparison value for conditional jumps.
-//! - The scratch memory store, `M[0]..M[15]`: Used for saving the value of a/x
-//! for later use.
-//!
-//! The packet being examined is an array of bytes, and is addressed using plain
-//! array subscript notation, e.g. [10] for the byte at offset 10. An implicit
-//! program counter, `pc`, is intialized to zero and incremented for each instruction.
-//!
-//! The machine has a fixed instruction set with the following form, where the
-//! numbers represent bit length:
-//!
-//! ```
-//! ┌───────────┬──────┬──────┐
-//! │ opcode:16 │ jt:8 │ jt:8 │
-//! ├───────────┴──────┴──────┤
-//! │ k:32 │
-//! └─────────────────────────┘
-//! ```
-//!
-//! The `opcode` indicates the instruction class and its addressing mode.
-//! Opcodes are generated by performing binary addition on the 8-bit class and
-//! mode constants. For example, the opcode for loading a byte from the packet
-//! at X + 2, (`ldb [x + 2]`), is:
-//!
-//! ```
-//! LD | IND | B = 0x00 | 0x40 | 0x20
-//! = 0x60
-//! ```
-//!
-//! `jt` is an offset used for conditional jumps, and increments the program
-//! counter by its amount if the comparison was true. Conversely, `jf`
-//! increments the counter if it was false. These fields are ignored in all
-//! other cases. `k` is a generic variable used for various purposes, most
-//! commonly as some sort of constant.
-//!
-//! This package contains opcode extensions used by different implementations,
-//! where "extension" is anything outside of the original that was imported into
-//! 4.4BSD[2]. These are marked with "EXTENSION", along with a list of
-//! implementations that use them.
-//!
-//! Most of the doc-comments use the BPF assembly syntax as described in the
-//! original paper[1]. For the sake of completeness, here is the complete
-//! instruction set, along with the extensions:
-//!
-//!```
-//! opcode addressing modes
-//! ld #k #len M[k] [k] [x + k]
-//! ldh [k] [x + k]
-//! ldb [k] [x + k]
-//! ldx #k #len M[k] 4 * ([k] & 0xf) arc4random()
-//! st M[k]
-//! stx M[k]
-//! jmp L
-//! jeq #k, Lt, Lf
-//! jgt #k, Lt, Lf
-//! jge #k, Lt, Lf
-//! jset #k, Lt, Lf
-//! add #k x
-//! sub #k x
-//! mul #k x
-//! div #k x
-//! or #k x
-//! and #k x
-//! lsh #k x
-//! rsh #k x
-//! neg #k x
-//! mod #k x
-//! xor #k x
-//! ret #k a
-//! tax
-//! txa
-//! ```
-//!
-//! Finally, a note on program design. The lack of backwards jumps leads to a
-//! "return early, return often" control flow. Take for example the program
-//! generated from the tcpdump filter `ip`:
-//!
-//! ```
-//! (000) ldh [12] ; Ethernet Packet Type
-//! (001) jeq #0x86dd, 2, 7 ; ETHERTYPE_IPV6
-//! (002) ldb [20] ; IPv6 Next Header
-//! (003) jeq #0x6, 10, 4 ; TCP
-//! (004) jeq #0x2c, 5, 11 ; IPv6 Fragment Header
-//! (005) ldb [54] ; TCP Source Port
-//! (006) jeq #0x6, 10, 11 ; IPPROTO_TCP
-//! (007) jeq #0x800, 8, 11 ; ETHERTYPE_IP
-//! (008) ldb [23] ; IPv4 Protocol
-//! (009) jeq #0x6, 10, 11 ; IPPROTO_TCP
-//! (010) ret #262144 ; copy 0x40000
-//! (011) ret #0 ; skip packet
-//! ```
-//!
-//! Here we can make a few observations:
-//!
-//! - The problem "filter only tcp packets" has essentially been transformed
-//! into a series of layer checks.
-//! - There are two distinct branches in the code, one for validating IPv4
-//! headers and one for IPv6 headers.
-//! - Most conditional jumps in these branches lead directly to the last two
-//! instructions, a pass or fail. Thus the goal of a program is to find the
-//! fastest route to a pass/fail comparison.
-//!
-//! [1]: S. McCanne and V. Jacobson, "The BSD Packet Filter: A New Architecture
-//! for User-level Packet Capture", Proceedings of the 1993 Winter USENIX.
-//! [2]: https://minnie.tuhs.org/cgi-bin/utree.pl?file=4.4BSD/usr/src/sys/net/bpf.h
-const std = @import("std");
-const builtin = @import("builtin");
-const native_endian = builtin.target.cpu.arch.endian();
-const mem = std.mem;
-const math = std.math;
-const random = std.crypto.random;
-const assert = std.debug.assert;
-const expectEqual = std.testing.expectEqual;
-const expectError = std.testing.expectError;
-const expect = std.testing.expect;
-
-// instruction classes
-/// ld, ldh, ldb: Load data into a.
-pub const LD = 0x00;
-/// ldx: Load data into x.
-pub const LDX = 0x01;
-/// st: Store into scratch memory the value of a.
-pub const ST = 0x02;
-/// st: Store into scratch memory the value of x.
-pub const STX = 0x03;
-/// alu: Wrapping arithmetic/bitwise operations on a using the value of k/x.
-pub const ALU = 0x04;
-/// jmp, jeq, jgt, je, jset: Increment the program counter based on a comparison
-/// between k/x and the accumulator.
-pub const JMP = 0x05;
-/// ret: Return a verdict using the value of k/the accumulator.
-pub const RET = 0x06;
-/// tax, txa: Register value copying between X and a.
-pub const MISC = 0x07;
-
-// Size of data to be loaded from the packet.
-/// ld: 32-bit full word.
-pub const W = 0x00;
-/// ldh: 16-bit half word.
-pub const H = 0x08;
-/// ldb: Single byte.
-pub const B = 0x10;
-
-// Addressing modes used for loads to a/x.
-/// #k: The immediate value stored in k.
-pub const IMM = 0x00;
-/// [k]: The value at offset k in the packet.
-pub const ABS = 0x20;
-/// [x + k]: The value at offset x + k in the packet.
-pub const IND = 0x40;
-/// M[k]: The value of the k'th scratch memory register.
-pub const MEM = 0x60;
-/// #len: The size of the packet.
-pub const LEN = 0x80;
-/// 4 * ([k] & 0xf): Four times the low four bits of the byte at offset k in the
-/// packet. This is used for efficiently loading the header length of an IP
-/// packet.
-pub const MSH = 0xa0;
-/// arc4random: 32-bit integer generated from a CPRNG (see arc4random(3)) loaded into a.
-/// EXTENSION. Defined for:
-/// - OpenBSD.
-pub const RND = 0xc0;
-
-// Modifiers for different instruction classes.
-/// Use the value of k for alu operations (add #k).
-/// Compare against the value of k for jumps (jeq #k, Lt, Lf).
-/// Return the value of k for returns (ret #k).
-pub const K = 0x00;
-/// Use the value of x for alu operations (add x).
-/// Compare against the value of X for jumps (jeq x, Lt, Lf).
-pub const X = 0x08;
-/// Return the value of a for returns (ret a).
-pub const A = 0x10;
-
-// ALU Operations on a using the value of k/x.
-// All arithmetic operations are defined to overflow the value of a.
-/// add: a = a + k
-/// a = a + x.
-pub const ADD = 0x00;
-/// sub: a = a - k
-/// a = a - x.
-pub const SUB = 0x10;
-/// mul: a = a * k
-/// a = a * x.
-pub const MUL = 0x20;
-/// div: a = a / k
-/// a = a / x.
-/// Truncated division.
-pub const DIV = 0x30;
-/// or: a = a | k
-/// a = a | x.
-pub const OR = 0x40;
-/// and: a = a & k
-/// a = a & x.
-pub const AND = 0x50;
-/// lsh: a = a << k
-/// a = a << x.
-/// a = a << k, a = a << x.
-pub const LSH = 0x60;
-/// rsh: a = a >> k
-/// a = a >> x.
-pub const RSH = 0x70;
-/// neg: a = -a.
-/// Note that this isn't a binary negation, rather the value of `~a + 1`.
-pub const NEG = 0x80;
-/// mod: a = a % k
-/// a = a % x.
-/// EXTENSION. Defined for:
-/// - Linux.
-/// - NetBSD + Minix 3.
-/// - FreeBSD and derivitives.
-pub const MOD = 0x90;
-/// xor: a = a ^ k
-/// a = a ^ x.
-/// EXTENSION. Defined for:
-/// - Linux.
-/// - NetBSD + Minix 3.
-/// - FreeBSD and derivitives.
-pub const XOR = 0xa0;
-
-// Jump operations using a comparison between a and x/k.
-/// jmp L: pc += k.
-/// No comparison done here.
-pub const JA = 0x00;
-/// jeq #k, Lt, Lf: pc += (a == k) ? jt : jf.
-/// jeq x, Lt, Lf: pc += (a == x) ? jt : jf.
-pub const JEQ = 0x10;
-/// jgt #k, Lt, Lf: pc += (a > k) ? jt : jf.
-/// jgt x, Lt, Lf: pc += (a > x) ? jt : jf.
-pub const JGT = 0x20;
-/// jge #k, Lt, Lf: pc += (a >= k) ? jt : jf.
-/// jge x, Lt, Lf: pc += (a >= x) ? jt : jf.
-pub const JGE = 0x30;
-/// jset #k, Lt, Lf: pc += (a & k > 0) ? jt : jf.
-/// jset x, Lt, Lf: pc += (a & x > 0) ? jt : jf.
-pub const JSET = 0x40;
-
-// Miscellaneous operations/register copy.
-/// tax: x = a.
-pub const TAX = 0x00;
-/// txa: a = x.
-pub const TXA = 0x80;
-
-/// The 16 registers in the scratch memory store as named enums.
-pub const Scratch = enum(u4) { m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15 };
-pub const MEMWORDS = 16;
-pub const MAXINSNS = switch (builtin.os.tag) {
- .linux => 4096,
- else => 512,
-};
-pub const MINBUFSIZE = 32;
-pub const MAXBUFSIZE = 1 << 21;
-
-pub const Insn = extern struct {
- opcode: u16,
- jt: u8,
- jf: u8,
- k: u32,
-
- /// Implements the `std.fmt.format` API.
- /// The formatting is similar to the output of tcpdump -dd.
- pub fn format(
- self: Insn,
- comptime layout: []const u8,
- opts: std.fmt.FormatOptions,
- writer: anytype,
- ) !void {
- _ = opts;
- if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
-
- try std.fmt.format(
- writer,
- "Insn{{ 0x{X:0<2}, {d}, {d}, 0x{X:0<8} }}",
- .{ self.opcode, self.jt, self.jf, self.k },
- );
- }
-
- const Size = enum(u8) {
- word = W,
- half_word = H,
- byte = B,
- };
-
- fn stmt(opcode: u16, k: u32) Insn {
- return .{
- .opcode = opcode,
- .jt = 0,
- .jf = 0,
- .k = k,
- };
- }
-
- pub fn ld_imm(value: u32) Insn {
- return stmt(LD | IMM, value);
- }
-
- pub fn ld_abs(size: Size, offset: u32) Insn {
- return stmt(LD | ABS | @enumToInt(size), offset);
- }
-
- pub fn ld_ind(size: Size, offset: u32) Insn {
- return stmt(LD | IND | @enumToInt(size), offset);
- }
-
- pub fn ld_mem(reg: Scratch) Insn {
- return stmt(LD | MEM, @enumToInt(reg));
- }
-
- pub fn ld_len() Insn {
- return stmt(LD | LEN | W, 0);
- }
-
- pub fn ld_rnd() Insn {
- return stmt(LD | RND | W, 0);
- }
-
- pub fn ldx_imm(value: u32) Insn {
- return stmt(LDX | IMM, value);
- }
-
- pub fn ldx_mem(reg: Scratch) Insn {
- return stmt(LDX | MEM, @enumToInt(reg));
- }
-
- pub fn ldx_len() Insn {
- return stmt(LDX | LEN | W, 0);
- }
-
- pub fn ldx_msh(offset: u32) Insn {
- return stmt(LDX | MSH | B, offset);
- }
-
- pub fn st(reg: Scratch) Insn {
- return stmt(ST, @enumToInt(reg));
- }
- pub fn stx(reg: Scratch) Insn {
- return stmt(STX, @enumToInt(reg));
- }
-
- const AluOp = enum(u16) {
- add = ADD,
- sub = SUB,
- mul = MUL,
- div = DIV,
- @"or" = OR,
- @"and" = AND,
- lsh = LSH,
- rsh = RSH,
- mod = MOD,
- xor = XOR,
- };
-
- const Source = enum(u16) {
- k = K,
- x = X,
- };
- const KOrX = union(Source) {
- k: u32,
- x: void,
- };
-
- pub fn alu_neg() Insn {
- return stmt(ALU | NEG, 0);
- }
-
- pub fn alu(op: AluOp, source: KOrX) Insn {
- return stmt(
- ALU | @enumToInt(op) | @enumToInt(source),
- if (source == .k) source.k else 0,
- );
- }
-
- const JmpOp = enum(u16) {
- jeq = JEQ,
- jgt = JGT,
- jge = JGE,
- jset = JSET,
- };
-
- pub fn jmp_ja(location: u32) Insn {
- return stmt(JMP | JA, location);
- }
-
- pub fn jmp(op: JmpOp, source: KOrX, jt: u8, jf: u8) Insn {
- return Insn{
- .opcode = JMP | @enumToInt(op) | @enumToInt(source),
- .jt = jt,
- .jf = jf,
- .k = if (source == .k) source.k else 0,
- };
- }
-
- const Verdict = enum(u16) {
- k = K,
- a = A,
- };
- const KOrA = union(Verdict) {
- k: u32,
- a: void,
- };
-
- pub fn ret(verdict: KOrA) Insn {
- return stmt(
- RET | @enumToInt(verdict),
- if (verdict == .k) verdict.k else 0,
- );
- }
-
- pub fn tax() Insn {
- return stmt(MISC | TAX, 0);
- }
-
- pub fn txa() Insn {
- return stmt(MISC | TXA, 0);
- }
-};
-
-fn opcodeEqual(opcode: u16, insn: Insn) !void {
- try expectEqual(opcode, insn.opcode);
-}
-
-test "opcodes" {
- try opcodeEqual(0x00, Insn.ld_imm(0));
- try opcodeEqual(0x20, Insn.ld_abs(.word, 0));
- try opcodeEqual(0x28, Insn.ld_abs(.half_word, 0));
- try opcodeEqual(0x30, Insn.ld_abs(.byte, 0));
- try opcodeEqual(0x40, Insn.ld_ind(.word, 0));
- try opcodeEqual(0x48, Insn.ld_ind(.half_word, 0));
- try opcodeEqual(0x50, Insn.ld_ind(.byte, 0));
- try opcodeEqual(0x60, Insn.ld_mem(.m0));
- try opcodeEqual(0x80, Insn.ld_len());
- try opcodeEqual(0xc0, Insn.ld_rnd());
-
- try opcodeEqual(0x01, Insn.ldx_imm(0));
- try opcodeEqual(0x61, Insn.ldx_mem(.m0));
- try opcodeEqual(0x81, Insn.ldx_len());
- try opcodeEqual(0xb1, Insn.ldx_msh(0));
-
- try opcodeEqual(0x02, Insn.st(.m0));
- try opcodeEqual(0x03, Insn.stx(.m0));
-
- try opcodeEqual(0x04, Insn.alu(.add, .{ .k = 0 }));
- try opcodeEqual(0x14, Insn.alu(.sub, .{ .k = 0 }));
- try opcodeEqual(0x24, Insn.alu(.mul, .{ .k = 0 }));
- try opcodeEqual(0x34, Insn.alu(.div, .{ .k = 0 }));
- try opcodeEqual(0x44, Insn.alu(.@"or", .{ .k = 0 }));
- try opcodeEqual(0x54, Insn.alu(.@"and", .{ .k = 0 }));
- try opcodeEqual(0x64, Insn.alu(.lsh, .{ .k = 0 }));
- try opcodeEqual(0x74, Insn.alu(.rsh, .{ .k = 0 }));
- try opcodeEqual(0x94, Insn.alu(.mod, .{ .k = 0 }));
- try opcodeEqual(0xa4, Insn.alu(.xor, .{ .k = 0 }));
- try opcodeEqual(0x84, Insn.alu_neg());
- try opcodeEqual(0x0c, Insn.alu(.add, .x));
- try opcodeEqual(0x1c, Insn.alu(.sub, .x));
- try opcodeEqual(0x2c, Insn.alu(.mul, .x));
- try opcodeEqual(0x3c, Insn.alu(.div, .x));
- try opcodeEqual(0x4c, Insn.alu(.@"or", .x));
- try opcodeEqual(0x5c, Insn.alu(.@"and", .x));
- try opcodeEqual(0x6c, Insn.alu(.lsh, .x));
- try opcodeEqual(0x7c, Insn.alu(.rsh, .x));
- try opcodeEqual(0x9c, Insn.alu(.mod, .x));
- try opcodeEqual(0xac, Insn.alu(.xor, .x));
-
- try opcodeEqual(0x05, Insn.jmp_ja(0));
- try opcodeEqual(0x15, Insn.jmp(.jeq, .{ .k = 0 }, 0, 0));
- try opcodeEqual(0x25, Insn.jmp(.jgt, .{ .k = 0 }, 0, 0));
- try opcodeEqual(0x35, Insn.jmp(.jge, .{ .k = 0 }, 0, 0));
- try opcodeEqual(0x45, Insn.jmp(.jset, .{ .k = 0 }, 0, 0));
- try opcodeEqual(0x1d, Insn.jmp(.jeq, .x, 0, 0));
- try opcodeEqual(0x2d, Insn.jmp(.jgt, .x, 0, 0));
- try opcodeEqual(0x3d, Insn.jmp(.jge, .x, 0, 0));
- try opcodeEqual(0x4d, Insn.jmp(.jset, .x, 0, 0));
-
- try opcodeEqual(0x06, Insn.ret(.{ .k = 0 }));
- try opcodeEqual(0x16, Insn.ret(.a));
-
- try opcodeEqual(0x07, Insn.tax());
- try opcodeEqual(0x87, Insn.txa());
-}
-
-pub const Error = error{
- InvalidOpcode,
- InvalidOffset,
- InvalidLocation,
- DivisionByZero,
- NoReturn,
-};
-
-/// A simple implementation of the BPF virtual-machine.
-/// Use this to run/debug programs.
-pub fn simulate(
- packet: []const u8,
- filter: []const Insn,
- byte_order: std.builtin.Endian,
-) Error!u32 {
- assert(filter.len > 0 and filter.len < MAXINSNS);
- assert(packet.len < MAXBUFSIZE);
- const len = @intCast(u32, packet.len);
-
- var a: u32 = 0;
- var x: u32 = 0;
- var m = mem.zeroes([MEMWORDS]u32);
- var pc: usize = 0;
-
- while (pc < filter.len) : (pc += 1) {
- const i = filter[pc];
- // Cast to a wider type to protect against overflow.
- const k = @as(u64, i.k);
- const remaining = filter.len - (pc + 1);
-
- // Do validation/error checking here to compress the second switch.
- switch (i.opcode) {
- LD | ABS | W => if (k + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset,
- LD | ABS | H => if (k + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset,
- LD | ABS | B => if (k >= packet.len) return error.InvalidOffset,
- LD | IND | W => if (k + x + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset,
- LD | IND | H => if (k + x + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset,
- LD | IND | B => if (k + x >= packet.len) return error.InvalidOffset,
-
- LDX | MSH | B => if (k >= packet.len) return error.InvalidOffset,
- ST, STX, LD | MEM, LDX | MEM => if (i.k >= MEMWORDS) return error.InvalidOffset,
-
- JMP | JA => if (remaining <= i.k) return error.InvalidOffset,
- JMP | JEQ | K,
- JMP | JGT | K,
- JMP | JGE | K,
- JMP | JSET | K,
- JMP | JEQ | X,
- JMP | JGT | X,
- JMP | JGE | X,
- JMP | JSET | X,
- => if (remaining <= i.jt or remaining <= i.jf) return error.InvalidLocation,
- else => {},
- }
- switch (i.opcode) {
- LD | IMM => a = i.k,
- LD | MEM => a = m[i.k],
- LD | LEN | W => a = len,
- LD | RND | W => a = random.int(u32),
- LD | ABS | W => a = mem.readInt(u32, packet[i.k..][0..@sizeOf(u32)], byte_order),
- LD | ABS | H => a = mem.readInt(u16, packet[i.k..][0..@sizeOf(u16)], byte_order),
- LD | ABS | B => a = packet[i.k],
- LD | IND | W => a = mem.readInt(u32, packet[i.k + x ..][0..@sizeOf(u32)], byte_order),
- LD | IND | H => a = mem.readInt(u16, packet[i.k + x ..][0..@sizeOf(u16)], byte_order),
- LD | IND | B => a = packet[i.k + x],
-
- LDX | IMM => x = i.k,
- LDX | MEM => x = m[i.k],
- LDX | LEN | W => x = len,
- LDX | MSH | B => x = @as(u32, @truncate(u4, packet[i.k])) << 2,
-
- ST => m[i.k] = a,
- STX => m[i.k] = x,
-
- ALU | ADD | K => a +%= i.k,
- ALU | SUB | K => a -%= i.k,
- ALU | MUL | K => a *%= i.k,
- ALU | DIV | K => a = try math.divTrunc(u32, a, i.k),
- ALU | OR | K => a |= i.k,
- ALU | AND | K => a &= i.k,
- ALU | LSH | K => a = math.shl(u32, a, i.k),
- ALU | RSH | K => a = math.shr(u32, a, i.k),
- ALU | MOD | K => a = try math.mod(u32, a, i.k),
- ALU | XOR | K => a ^= i.k,
- ALU | ADD | X => a +%= x,
- ALU | SUB | X => a -%= x,
- ALU | MUL | X => a *%= x,
- ALU | DIV | X => a = try math.divTrunc(u32, a, x),
- ALU | OR | X => a |= x,
- ALU | AND | X => a &= x,
- ALU | LSH | X => a = math.shl(u32, a, x),
- ALU | RSH | X => a = math.shr(u32, a, x),
- ALU | MOD | X => a = try math.mod(u32, a, x),
- ALU | XOR | X => a ^= x,
- ALU | NEG => a = @bitCast(u32, -%@bitCast(i32, a)),
-
- JMP | JA => pc += i.k,
- JMP | JEQ | K => pc += if (a == i.k) i.jt else i.jf,
- JMP | JGT | K => pc += if (a > i.k) i.jt else i.jf,
- JMP | JGE | K => pc += if (a >= i.k) i.jt else i.jf,
- JMP | JSET | K => pc += if (a & i.k > 0) i.jt else i.jf,
- JMP | JEQ | X => pc += if (a == x) i.jt else i.jf,
- JMP | JGT | X => pc += if (a > x) i.jt else i.jf,
- JMP | JGE | X => pc += if (a >= x) i.jt else i.jf,
- JMP | JSET | X => pc += if (a & x > 0) i.jt else i.jf,
-
- RET | K => return i.k,
- RET | A => return a,
-
- MISC | TAX => x = a,
- MISC | TXA => a = x,
- else => return error.InvalidOpcode,
- }
- }
-
- return error.NoReturn;
-}
-
-// This program is the BPF form of the tcpdump filter:
-//
-// tcpdump -dd 'ip host mirror.internode.on.net and tcp port ftp-data'
-//
-// As of January 2022, mirror.internode.on.net resolves to 150.101.135.3
-//
-// For reference, here's what it looks like in BPF assembler.
-// Note that the jumps are used for TCP/IP layer checks.
-//
-// ```
-// ldh [12] (#proto)
-// jeq #0x0800 (ETHERTYPE_IP), L1, fail
-// L1: ld [26]
-// jeq #150.101.135.3, L2, dest
-// dest: ld [30]
-// jeq #150.101.135.3, L2, fail
-// L2: ldb [23]
-// jeq #0x6 (IPPROTO_TCP), L3, fail
-// L3: ldh [20]
-// jset #0x1fff, fail, plen
-// plen: ldx 4 * ([14] & 0xf)
-// ldh [x + 14]
-// jeq #0x14 (FTP), pass, dstp
-// dstp: ldh [x + 16]
-// jeq #0x14 (FTP), pass, fail
-// pass: ret #0x40000
-// fail: ret #0
-// ```
-const tcpdump_filter = [_]Insn{
- Insn.ld_abs(.half_word, 12),
- Insn.jmp(.jeq, .{ .k = 0x800 }, 0, 14),
- Insn.ld_abs(.word, 26),
- Insn.jmp(.jeq, .{ .k = 0x96658703 }, 2, 0),
- Insn.ld_abs(.word, 30),
- Insn.jmp(.jeq, .{ .k = 0x96658703 }, 0, 10),
- Insn.ld_abs(.byte, 23),
- Insn.jmp(.jeq, .{ .k = 0x6 }, 0, 8),
- Insn.ld_abs(.half_word, 20),
- Insn.jmp(.jset, .{ .k = 0x1fff }, 6, 0),
- Insn.ldx_msh(14),
- Insn.ld_ind(.half_word, 14),
- Insn.jmp(.jeq, .{ .k = 0x14 }, 2, 0),
- Insn.ld_ind(.half_word, 16),
- Insn.jmp(.jeq, .{ .k = 0x14 }, 0, 1),
- Insn.ret(.{ .k = 0x40000 }),
- Insn.ret(.{ .k = 0 }),
-};
-
-// This packet is the output of `ls` on mirror.internode.on.net:/, captured
-// using the filter above.
-//
-// zig fmt: off
-const ftp_data = [_]u8{
- // ethernet - 14 bytes: IPv4(0x0800) from a4:71:74:ad:4b:f0 -> de:ad:be:ef:f0:0f
- 0xde, 0xad, 0xbe, 0xef, 0xf0, 0x0f, 0xa4, 0x71, 0x74, 0xad, 0x4b, 0xf0, 0x08, 0x00,
- // IPv4 - 20 bytes: TCP data from 150.101.135.3 -> 192.168.1.3
- 0x45, 0x00, 0x01, 0xf2, 0x70, 0x3b, 0x40, 0x00, 0x37, 0x06, 0xf2, 0xb6,
- 0x96, 0x65, 0x87, 0x03, 0xc0, 0xa8, 0x01, 0x03,
- // TCP - 32 bytes: Source port: 20 (FTP). Payload = 446 bytes
- 0x00, 0x14, 0x80, 0x6d, 0x35, 0x81, 0x2d, 0x40, 0x4f, 0x8a, 0x29, 0x9e, 0x80, 0x18, 0x00, 0x2e,
- 0x88, 0x8d, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x0b, 0x59, 0x5d, 0x09, 0x32, 0x8b, 0x51, 0xa0
-} ++
- // Raw line-based FTP data - 446 bytes
- "lrwxrwxrwx 1 root root 12 Feb 14 2012 debian -> .pub2/debian\r\n" ++
- "lrwxrwxrwx 1 root root 15 Feb 14 2012 debian-cd -> .pub2/debian-cd\r\n" ++
- "lrwxrwxrwx 1 root root 9 Mar 9 2018 linux -> pub/linux\r\n" ++
- "drwxr-xr-X 3 mirror mirror 4096 Sep 20 08:10 pub\r\n" ++
- "lrwxrwxrwx 1 root root 12 Feb 14 2012 ubuntu -> .pub2/ubuntu\r\n" ++
- "-rw-r--r-- 1 root root 1044 Jan 20 2015 welcome.msg\r\n";
-// zig fmt: on
-
-test "tcpdump filter" {
- try expectEqual(
- @as(u32, 0x40000),
- try simulate(ftp_data, &tcpdump_filter, .Big),
- );
-}
-
-fn expectPass(data: anytype, filter: []const Insn) !void {
- try expectEqual(
- @as(u32, 0),
- try simulate(mem.asBytes(data), filter, .Big),
- );
-}
-
-fn expectFail(expected_error: anyerror, data: anytype, filter: []const Insn) !void {
- try expectError(
- expected_error,
- simulate(mem.asBytes(data), filter, native_endian),
- );
-}
-
-test "simulator coverage" {
- const some_data = [_]u8{
- 0xaa, 0xbb, 0xcc, 0xdd, 0x7f,
- };
-
- try expectPass(&some_data, &.{
- // ld #10
- // ldx #1
- // st M[0]
- // stx M[1]
- // fail if A != 10
- Insn.ld_imm(10),
- Insn.ldx_imm(1),
- Insn.st(.m0),
- Insn.stx(.m1),
- Insn.jmp(.jeq, .{ .k = 10 }, 1, 0),
- Insn.ret(.{ .k = 1 }),
- // ld [0]
- // fail if A != 0xaabbccdd
- Insn.ld_abs(.word, 0),
- Insn.jmp(.jeq, .{ .k = 0xaabbccdd }, 1, 0),
- Insn.ret(.{ .k = 2 }),
- // ldh [0]
- // fail if A != 0xaabb
- Insn.ld_abs(.half_word, 0),
- Insn.jmp(.jeq, .{ .k = 0xaabb }, 1, 0),
- Insn.ret(.{ .k = 3 }),
- // ldb [0]
- // fail if A != 0xaa
- Insn.ld_abs(.byte, 0),
- Insn.jmp(.jeq, .{ .k = 0xaa }, 1, 0),
- Insn.ret(.{ .k = 4 }),
- // ld [x + 0]
- // fail if A != 0xbbccdd7f
- Insn.ld_ind(.word, 0),
- Insn.jmp(.jeq, .{ .k = 0xbbccdd7f }, 1, 0),
- Insn.ret(.{ .k = 5 }),
- // ldh [x + 0]
- // fail if A != 0xbbcc
- Insn.ld_ind(.half_word, 0),
- Insn.jmp(.jeq, .{ .k = 0xbbcc }, 1, 0),
- Insn.ret(.{ .k = 6 }),
- // ldb [x + 0]
- // fail if A != 0xbb
- Insn.ld_ind(.byte, 0),
- Insn.jmp(.jeq, .{ .k = 0xbb }, 1, 0),
- Insn.ret(.{ .k = 7 }),
- // ld M[0]
- // fail if A != 10
- Insn.ld_mem(.m0),
- Insn.jmp(.jeq, .{ .k = 10 }, 1, 0),
- Insn.ret(.{ .k = 8 }),
- // ld #len
- // fail if A != 5
- Insn.ld_len(),
- Insn.jmp(.jeq, .{ .k = some_data.len }, 1, 0),
- Insn.ret(.{ .k = 9 }),
- // ld #0
- // ld arc4random()
- // fail if A == 0
- Insn.ld_imm(0),
- Insn.ld_rnd(),
- Insn.jmp(.jgt, .{ .k = 0 }, 1, 0),
- Insn.ret(.{ .k = 10 }),
- // ld #3
- // ldx #10
- // st M[2]
- // txa
- // fail if a != x
- Insn.ld_imm(3),
- Insn.ldx_imm(10),
- Insn.st(.m2),
- Insn.txa(),
- Insn.jmp(.jeq, .x, 1, 0),
- Insn.ret(.{ .k = 11 }),
- // ldx M[2]
- // fail if A <= X
- Insn.ldx_mem(.m2),
- Insn.jmp(.jgt, .x, 1, 0),
- Insn.ret(.{ .k = 12 }),
- // ldx #len
- // fail if a <= x
- Insn.ldx_len(),
- Insn.jmp(.jgt, .x, 1, 0),
- Insn.ret(.{ .k = 13 }),
- // a = 4 * (0x7f & 0xf)
- // x = 4 * ([4] & 0xf)
- // fail if a != x
- Insn.ld_imm(4 * (0x7f & 0xf)),
- Insn.ldx_msh(4),
- Insn.jmp(.jeq, .x, 1, 0),
- Insn.ret(.{ .k = 14 }),
- // ld #(u32)-1
- // ldx #2
- // add #1
- // fail if a != 0
- Insn.ld_imm(0xffffffff),
- Insn.ldx_imm(2),
- Insn.alu(.add, .{ .k = 1 }),
- Insn.jmp(.jeq, .{ .k = 0 }, 1, 0),
- Insn.ret(.{ .k = 15 }),
- // sub #1
- // fail if a != (u32)-1
- Insn.alu(.sub, .{ .k = 1 }),
- Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0),
- Insn.ret(.{ .k = 16 }),
- // add x
- // fail if a != 1
- Insn.alu(.add, .x),
- Insn.jmp(.jeq, .{ .k = 1 }, 1, 0),
- Insn.ret(.{ .k = 17 }),
- // sub x
- // fail if a != (u32)-1
- Insn.alu(.sub, .x),
- Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0),
- Insn.ret(.{ .k = 18 }),
- // ld #16
- // mul #2
- // fail if a != 32
- Insn.ld_imm(16),
- Insn.alu(.mul, .{ .k = 2 }),
- Insn.jmp(.jeq, .{ .k = 32 }, 1, 0),
- Insn.ret(.{ .k = 19 }),
- // mul x
- // fail if a != 64
- Insn.alu(.mul, .x),
- Insn.jmp(.jeq, .{ .k = 64 }, 1, 0),
- Insn.ret(.{ .k = 20 }),
- // div #2
- // fail if a != 32
- Insn.alu(.div, .{ .k = 2 }),
- Insn.jmp(.jeq, .{ .k = 32 }, 1, 0),
- Insn.ret(.{ .k = 21 }),
- // div x
- // fail if a != 16
- Insn.alu(.div, .x),
- Insn.jmp(.jeq, .{ .k = 16 }, 1, 0),
- Insn.ret(.{ .k = 22 }),
- // or #4
- // fail if a != 20
- Insn.alu(.@"or", .{ .k = 4 }),
- Insn.jmp(.jeq, .{ .k = 20 }, 1, 0),
- Insn.ret(.{ .k = 23 }),
- // or x
- // fail if a != 22
- Insn.alu(.@"or", .x),
- Insn.jmp(.jeq, .{ .k = 22 }, 1, 0),
- Insn.ret(.{ .k = 24 }),
- // and #6
- // fail if a != 6
- Insn.alu(.@"and", .{ .k = 0b110 }),
- Insn.jmp(.jeq, .{ .k = 6 }, 1, 0),
- Insn.ret(.{ .k = 25 }),
- // and x
- // fail if a != 2
- Insn.alu(.@"and", .x),
- Insn.jmp(.jeq, .x, 1, 0),
- Insn.ret(.{ .k = 26 }),
- // xor #15
- // fail if a != 13
- Insn.alu(.xor, .{ .k = 0b1111 }),
- Insn.jmp(.jeq, .{ .k = 0b1101 }, 1, 0),
- Insn.ret(.{ .k = 27 }),
- // xor x
- // fail if a != 15
- Insn.alu(.xor, .x),
- Insn.jmp(.jeq, .{ .k = 0b1111 }, 1, 0),
- Insn.ret(.{ .k = 28 }),
- // rsh #1
- // fail if a != 7
- Insn.alu(.rsh, .{ .k = 1 }),
- Insn.jmp(.jeq, .{ .k = 0b0111 }, 1, 0),
- Insn.ret(.{ .k = 29 }),
- // rsh x
- // fail if a != 1
- Insn.alu(.rsh, .x),
- Insn.jmp(.jeq, .{ .k = 0b0001 }, 1, 0),
- Insn.ret(.{ .k = 30 }),
- // lsh #1
- // fail if a != 2
- Insn.alu(.lsh, .{ .k = 1 }),
- Insn.jmp(.jeq, .{ .k = 0b0010 }, 1, 0),
- Insn.ret(.{ .k = 31 }),
- // lsh x
- // fail if a != 8
- Insn.alu(.lsh, .x),
- Insn.jmp(.jeq, .{ .k = 0b1000 }, 1, 0),
- Insn.ret(.{ .k = 32 }),
- // mod 6
- // fail if a != 2
- Insn.alu(.mod, .{ .k = 6 }),
- Insn.jmp(.jeq, .{ .k = 2 }, 1, 0),
- Insn.ret(.{ .k = 33 }),
- // mod x
- // fail if a != 0
- Insn.alu(.mod, .x),
- Insn.jmp(.jeq, .{ .k = 0 }, 1, 0),
- Insn.ret(.{ .k = 34 }),
- // tax
- // neg
- // fail if a != (u32)-2
- Insn.txa(),
- Insn.alu_neg(),
- Insn.jmp(.jeq, .{ .k = ~@as(u32, 2) + 1 }, 1, 0),
- Insn.ret(.{ .k = 35 }),
- // ja #1 (skip the next instruction)
- Insn.jmp_ja(1),
- Insn.ret(.{ .k = 36 }),
- // ld #20
- // tax
- // fail if a != 20
- // fail if a != x
- Insn.ld_imm(20),
- Insn.tax(),
- Insn.jmp(.jeq, .{ .k = 20 }, 1, 0),
- Insn.ret(.{ .k = 37 }),
- Insn.jmp(.jeq, .x, 1, 0),
- Insn.ret(.{ .k = 38 }),
- // ld #19
- // fail if a == 20
- // fail if a == x
- // fail if a >= 20
- // fail if a >= X
- Insn.ld_imm(19),
- Insn.jmp(.jeq, .{ .k = 20 }, 0, 1),
- Insn.ret(.{ .k = 39 }),
- Insn.jmp(.jeq, .x, 0, 1),
- Insn.ret(.{ .k = 40 }),
- Insn.jmp(.jgt, .{ .k = 20 }, 0, 1),
- Insn.ret(.{ .k = 41 }),
- Insn.jmp(.jgt, .x, 0, 1),
- Insn.ret(.{ .k = 42 }),
- // ld #21
- // fail if a < 20
- // fail if a < x
- Insn.ld_imm(21),
- Insn.jmp(.jgt, .{ .k = 20 }, 1, 0),
- Insn.ret(.{ .k = 43 }),
- Insn.jmp(.jgt, .x, 1, 0),
- Insn.ret(.{ .k = 44 }),
- // ldx #22
- // fail if a < 22
- // fail if a < x
- Insn.ldx_imm(22),
- Insn.jmp(.jge, .{ .k = 22 }, 0, 1),
- Insn.ret(.{ .k = 45 }),
- Insn.jmp(.jge, .x, 0, 1),
- Insn.ret(.{ .k = 46 }),
- // ld #23
- // fail if a >= 22
- // fail if a >= x
- Insn.ld_imm(23),
- Insn.jmp(.jge, .{ .k = 22 }, 1, 0),
- Insn.ret(.{ .k = 47 }),
- Insn.jmp(.jge, .x, 1, 0),
- Insn.ret(.{ .k = 48 }),
- // ldx #0b10100
- // fail if a & 0b10100 == 0
- // fail if a & x == 0
- Insn.ldx_imm(0b10100),
- Insn.jmp(.jset, .{ .k = 0b10100 }, 1, 0),
- Insn.ret(.{ .k = 47 }),
- Insn.jmp(.jset, .x, 1, 0),
- Insn.ret(.{ .k = 48 }),
- // ldx #0
- // fail if a & 0 > 0
- // fail if a & x > 0
- Insn.ldx_imm(0),
- Insn.jmp(.jset, .{ .k = 0 }, 0, 1),
- Insn.ret(.{ .k = 49 }),
- Insn.jmp(.jset, .x, 0, 1),
- Insn.ret(.{ .k = 50 }),
- Insn.ret(.{ .k = 0 }),
- });
- try expectPass(&some_data, &.{
- Insn.ld_imm(35),
- Insn.ld_imm(0),
- Insn.ret(.a),
- });
-
- // Errors
- try expectFail(error.NoReturn, &some_data, &.{
- Insn.ld_imm(10),
- });
- try expectFail(error.InvalidOpcode, &some_data, &.{
- Insn.stmt(0x7f, 0xdeadbeef),
- });
- try expectFail(error.InvalidOffset, &some_data, &.{
- Insn.stmt(LD | ABS | W, 10),
- });
- try expectFail(error.InvalidLocation, &some_data, &.{
- Insn.jmp(.jeq, .{ .k = 0 }, 10, 0),
- });
- try expectFail(error.InvalidLocation, &some_data, &.{
- Insn.jmp(.jeq, .{ .k = 0 }, 0, 10),
- });
-}
diff --git a/lib/std/x/net/ip.zig b/lib/std/x/net/ip.zig
@@ -1,57 +0,0 @@
-const std = @import("../../std.zig");
-
-const fmt = std.fmt;
-
-const IPv4 = std.x.os.IPv4;
-const IPv6 = std.x.os.IPv6;
-const Socket = std.x.os.Socket;
-
-/// A generic IP abstraction.
-const ip = @This();
-
-/// A union of all eligible types of IP addresses.
-pub const Address = union(enum) {
- ipv4: IPv4.Address,
- ipv6: IPv6.Address,
-
- /// Instantiate a new address with a IPv4 host and port.
- pub fn initIPv4(host: IPv4, port: u16) Address {
- return .{ .ipv4 = .{ .host = host, .port = port } };
- }
-
- /// Instantiate a new address with a IPv6 host and port.
- pub fn initIPv6(host: IPv6, port: u16) Address {
- return .{ .ipv6 = .{ .host = host, .port = port } };
- }
-
- /// Re-interpret a generic socket address into an IP address.
- pub fn from(address: Socket.Address) ip.Address {
- return switch (address) {
- .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address },
- .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address },
- };
- }
-
- /// Re-interpret an IP address into a generic socket address.
- pub fn into(self: ip.Address) Socket.Address {
- return switch (self) {
- .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address },
- .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address },
- };
- }
-
- /// Implements the `std.fmt.format` API.
- pub fn format(
- self: ip.Address,
- comptime layout: []const u8,
- opts: fmt.FormatOptions,
- writer: anytype,
- ) !void {
- if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
- _ = opts;
- switch (self) {
- .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
- .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
- }
- }
-};
diff --git a/lib/std/x/net/tcp.zig b/lib/std/x/net/tcp.zig
@@ -1,447 +0,0 @@
-const std = @import("../../std.zig");
-const builtin = @import("builtin");
-
-const io = std.io;
-const os = std.os;
-const ip = std.x.net.ip;
-
-const fmt = std.fmt;
-const mem = std.mem;
-const testing = std.testing;
-const native_os = builtin.os;
-
-const IPv4 = std.x.os.IPv4;
-const IPv6 = std.x.os.IPv6;
-const Socket = std.x.os.Socket;
-const Buffer = std.x.os.Buffer;
-
-/// A generic TCP socket abstraction.
-const tcp = @This();
-
-/// A TCP client-address pair.
-pub const Connection = struct {
- client: tcp.Client,
- address: ip.Address,
-
- /// Enclose a TCP client and address into a client-address pair.
- pub fn from(conn: Socket.Connection) tcp.Connection {
- return .{
- .client = tcp.Client.from(conn.socket),
- .address = ip.Address.from(conn.address),
- };
- }
-
- /// Unravel a TCP client-address pair into a socket-address pair.
- pub fn into(self: tcp.Connection) Socket.Connection {
- return .{
- .socket = self.client.socket,
- .address = self.address.into(),
- };
- }
-
- /// Closes the underlying client of the connection.
- pub fn deinit(self: tcp.Connection) void {
- self.client.deinit();
- }
-};
-
-/// Possible domains that a TCP client/listener may operate over.
-pub const Domain = enum(u16) {
- ip = os.AF.INET,
- ipv6 = os.AF.INET6,
-};
-
-/// A TCP client.
-pub const Client = struct {
- socket: Socket,
-
- /// Implements `std.io.Reader`.
- pub const Reader = struct {
- client: Client,
- flags: u32,
-
- /// Implements `readFn` for `std.io.Reader`.
- pub fn read(self: Client.Reader, buffer: []u8) !usize {
- return self.client.read(buffer, self.flags);
- }
- };
-
- /// Implements `std.io.Writer`.
- pub const Writer = struct {
- client: Client,
- flags: u32,
-
- /// Implements `writeFn` for `std.io.Writer`.
- pub fn write(self: Client.Writer, buffer: []const u8) !usize {
- return self.client.write(buffer, self.flags);
- }
- };
-
- /// Opens a new client.
- pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client {
- return Client{
- .socket = try Socket.init(
- @enumToInt(domain),
- os.SOCK.STREAM,
- os.IPPROTO.TCP,
- flags,
- ),
- };
- }
-
- /// Enclose a TCP client over an existing socket.
- pub fn from(socket: Socket) Client {
- return Client{ .socket = socket };
- }
-
- /// Closes the client.
- pub fn deinit(self: Client) void {
- self.socket.deinit();
- }
-
- /// Shutdown either the read side, write side, or all sides of the client's underlying socket.
- pub fn shutdown(self: Client, how: os.ShutdownHow) !void {
- return self.socket.shutdown(how);
- }
-
- /// Have the client attempt to the connect to an address.
- pub fn connect(self: Client, address: ip.Address) !void {
- return self.socket.connect(address.into());
- }
-
- /// Extracts the error set of a function.
- /// TODO: remove after Socket.{read, write} error unions are well-defined across different platforms
- fn ErrorSetOf(comptime Function: anytype) type {
- return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set;
- }
-
- /// Wrap `tcp.Client` into `std.io.Reader`.
- pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) {
- return .{ .context = .{ .client = self, .flags = flags } };
- }
-
- /// Wrap `tcp.Client` into `std.io.Writer`.
- pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) {
- return .{ .context = .{ .client = self, .flags = flags } };
- }
-
- /// Read data from the socket into the buffer provided with a set of flags
- /// specified. It returns the number of bytes read into the buffer provided.
- pub fn read(self: Client, buf: []u8, flags: u32) !usize {
- return self.socket.read(buf, flags);
- }
-
- /// Write a buffer of data provided to the socket with a set of flags specified.
- /// It returns the number of bytes that are written to the socket.
- pub fn write(self: Client, buf: []const u8, flags: u32) !usize {
- return self.socket.write(buf, flags);
- }
-
- /// Writes multiple I/O vectors with a prepended message header to the socket
- /// with a set of flags specified. It returns the number of bytes that are
- /// written to the socket.
- pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize {
- return self.socket.writeMessage(msg, flags);
- }
-
- /// Read multiple I/O vectors with a prepended message header from the socket
- /// with a set of flags specified. It returns the number of bytes that were
- /// read into the buffer provided.
- pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize {
- return self.socket.readMessage(msg, flags);
- }
-
- /// Query and return the latest cached error on the client's underlying socket.
- pub fn getError(self: Client) !void {
- return self.socket.getError();
- }
-
- /// Query the read buffer size of the client's underlying socket.
- pub fn getReadBufferSize(self: Client) !u32 {
- return self.socket.getReadBufferSize();
- }
-
- /// Query the write buffer size of the client's underlying socket.
- pub fn getWriteBufferSize(self: Client) !u32 {
- return self.socket.getWriteBufferSize();
- }
-
- /// Query the address that the client's socket is locally bounded to.
- pub fn getLocalAddress(self: Client) !ip.Address {
- return ip.Address.from(try self.socket.getLocalAddress());
- }
-
- /// Query the address that the socket is connected to.
- pub fn getRemoteAddress(self: Client) !ip.Address {
- return ip.Address.from(try self.socket.getRemoteAddress());
- }
-
- /// Have close() or shutdown() syscalls block until all queued messages in the client have been successfully
- /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
- /// if the host does not support the option for a socket to linger around up until a timeout specified in
- /// seconds.
- pub fn setLinger(self: Client, timeout_seconds: ?u16) !void {
- return self.socket.setLinger(timeout_seconds);
- }
-
- /// Have keep-alive messages be sent periodically. The timing in which keep-alive messages are sent are
- /// dependant on operating system settings. It returns `error.UnsupportedSocketOption` if the host does
- /// not support periodically sending keep-alive messages on connection-oriented sockets.
- pub fn setKeepAlive(self: Client, enabled: bool) !void {
- return self.socket.setKeepAlive(enabled);
- }
-
- /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if
- /// the host does not support sockets disabling Nagle's algorithm.
- pub fn setNoDelay(self: Client, enabled: bool) !void {
- if (@hasDecl(os.TCP, "NODELAY")) {
- const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled)));
- return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes);
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns
- /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK.
- pub fn setQuickACK(self: Client, enabled: bool) !void {
- if (@hasDecl(os.TCP, "QUICKACK")) {
- return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Set the write buffer size of the socket.
- pub fn setWriteBufferSize(self: Client, size: u32) !void {
- return self.socket.setWriteBufferSize(size);
- }
-
- /// Set the read buffer size of the socket.
- pub fn setReadBufferSize(self: Client, size: u32) !void {
- return self.socket.setReadBufferSize(size);
- }
-
- /// Set a timeout on the socket that is to occur if no messages are successfully written
- /// to its bound destination after a specified number of milliseconds. A subsequent write
- /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
- pub fn setWriteTimeout(self: Client, milliseconds: u32) !void {
- return self.socket.setWriteTimeout(milliseconds);
- }
-
- /// Set a timeout on the socket that is to occur if no messages are successfully read
- /// from its bound destination after a specified number of milliseconds. A subsequent
- /// read from the socket will thereafter return `error.WouldBlock` should the timeout be
- /// exceeded.
- pub fn setReadTimeout(self: Client, milliseconds: u32) !void {
- return self.socket.setReadTimeout(milliseconds);
- }
-};
-
-/// A TCP listener.
-pub const Listener = struct {
- socket: Socket,
-
- /// Opens a new listener.
- pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener {
- return Listener{
- .socket = try Socket.init(
- @enumToInt(domain),
- os.SOCK.STREAM,
- os.IPPROTO.TCP,
- flags,
- ),
- };
- }
-
- /// Closes the listener.
- pub fn deinit(self: Listener) void {
- self.socket.deinit();
- }
-
- /// Shuts down the underlying listener's socket. The next subsequent call, or
- /// a current pending call to accept() after shutdown is called will return
- /// an error.
- pub fn shutdown(self: Listener) !void {
- return self.socket.shutdown(.recv);
- }
-
- /// Binds the listener's socket to an address.
- pub fn bind(self: Listener, address: ip.Address) !void {
- return self.socket.bind(address.into());
- }
-
- /// Start listening for incoming connections.
- pub fn listen(self: Listener, max_backlog_size: u31) !void {
- return self.socket.listen(max_backlog_size);
- }
-
- /// Accept a pending incoming connection queued to the kernel backlog
- /// of the listener's socket.
- pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection {
- return tcp.Connection.from(try self.socket.accept(flags));
- }
-
- /// Query and return the latest cached error on the listener's underlying socket.
- pub fn getError(self: Client) !void {
- return self.socket.getError();
- }
-
- /// Query the address that the listener's socket is locally bounded to.
- pub fn getLocalAddress(self: Listener) !ip.Address {
- return ip.Address.from(try self.socket.getLocalAddress());
- }
-
- /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
- /// the host does not support sockets listening the same address.
- pub fn setReuseAddress(self: Listener, enabled: bool) !void {
- return self.socket.setReuseAddress(enabled);
- }
-
- /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
- /// the host does not supports sockets listening on the same port.
- pub fn setReusePort(self: Listener, enabled: bool) !void {
- return self.socket.setReusePort(enabled);
- }
-
- /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not
- /// support TCP Fast Open.
- pub fn setFastOpen(self: Listener, enabled: bool) !void {
- if (@hasDecl(os.TCP, "FASTOPEN")) {
- return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Set a timeout on the listener that is to occur if no new incoming connections come in
- /// after a specified number of milliseconds. A subsequent accept call to the listener
- /// will thereafter return `error.WouldBlock` should the timeout be exceeded.
- pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void {
- return self.socket.setReadTimeout(milliseconds);
- }
-};
-
-test "tcp: create client/listener pair" {
- if (native_os.tag == .wasi) return error.SkipZigTest;
-
- const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
- defer listener.deinit();
-
- try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
- try listener.listen(128);
-
- var binded_address = try listener.getLocalAddress();
- switch (binded_address) {
- .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
- .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
- }
-
- const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
- defer client.deinit();
-
- try client.connect(binded_address);
-
- const conn = try listener.accept(.{ .close_on_exec = true });
- defer conn.deinit();
-}
-
-test "tcp/client: 1ms read timeout" {
- if (native_os.tag == .wasi) return error.SkipZigTest;
-
- const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
- defer listener.deinit();
-
- try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
- try listener.listen(128);
-
- var binded_address = try listener.getLocalAddress();
- switch (binded_address) {
- .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
- .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
- }
-
- const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
- defer client.deinit();
-
- try client.connect(binded_address);
- try client.setReadTimeout(1);
-
- const conn = try listener.accept(.{ .close_on_exec = true });
- defer conn.deinit();
-
- var buf: [1]u8 = undefined;
- try testing.expectError(error.WouldBlock, client.reader(0).read(&buf));
-}
-
-test "tcp/client: read and write multiple vectors" {
- if (native_os.tag == .wasi) return error.SkipZigTest;
-
- if (builtin.os.tag == .windows) {
- // https://github.com/ziglang/zig/issues/13893
- return error.SkipZigTest;
- }
-
- const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
- defer listener.deinit();
-
- try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
- try listener.listen(128);
-
- var binded_address = try listener.getLocalAddress();
- switch (binded_address) {
- .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
- .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
- }
-
- const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
- defer client.deinit();
-
- try client.connect(binded_address);
-
- const conn = try listener.accept(.{ .close_on_exec = true });
- defer conn.deinit();
-
- const message = "hello world";
- _ = try conn.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{
- Buffer.from(message[0 .. message.len / 2]),
- Buffer.from(message[message.len / 2 ..]),
- }), 0);
-
- var buf: [message.len + 1]u8 = undefined;
- var msg = Socket.Message.fromBuffers(&[_]Buffer{
- Buffer.from(buf[0 .. message.len / 2]),
- Buffer.from(buf[message.len / 2 ..]),
- });
- _ = try client.readMessage(&msg, 0);
-
- try testing.expectEqualStrings(message, buf[0..message.len]);
-}
-
-test "tcp/listener: bind to unspecified ipv4 address" {
- if (native_os.tag == .wasi) return error.SkipZigTest;
-
- const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
- defer listener.deinit();
-
- try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
- try listener.listen(128);
-
- const address = try listener.getLocalAddress();
- try testing.expect(address == .ipv4);
-}
-
-test "tcp/listener: bind to unspecified ipv6 address" {
- if (native_os.tag == .wasi) return error.SkipZigTest;
-
- if (builtin.os.tag == .windows) {
- // https://github.com/ziglang/zig/issues/13893
- return error.SkipZigTest;
- }
-
- const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true });
- defer listener.deinit();
-
- try listener.bind(ip.Address.initIPv6(IPv6.unspecified, 0));
- try listener.listen(128);
-
- const address = try listener.getLocalAddress();
- try testing.expect(address == .ipv6);
-}
diff --git a/lib/std/x/os/io.zig b/lib/std/x/os/io.zig
@@ -1,224 +0,0 @@
-const std = @import("../../std.zig");
-const builtin = @import("builtin");
-
-const os = std.os;
-const mem = std.mem;
-const testing = std.testing;
-const native_os = builtin.os;
-const linux = std.os.linux;
-
-/// POSIX `iovec`, or Windows `WSABUF`. The difference between the two are the ordering
-/// of fields, alongside the length being represented as either a ULONG or a size_t.
-pub const Buffer = if (native_os.tag == .windows)
- extern struct {
- len: c_ulong,
- ptr: usize,
-
- pub fn from(slice: []const u8) Buffer {
- return .{ .len = @intCast(c_ulong, slice.len), .ptr = @ptrToInt(slice.ptr) };
- }
-
- pub fn into(self: Buffer) []const u8 {
- return @intToPtr([*]const u8, self.ptr)[0..self.len];
- }
-
- pub fn intoMutable(self: Buffer) []u8 {
- return @intToPtr([*]u8, self.ptr)[0..self.len];
- }
- }
-else
- extern struct {
- ptr: usize,
- len: usize,
-
- pub fn from(slice: []const u8) Buffer {
- return .{ .ptr = @ptrToInt(slice.ptr), .len = slice.len };
- }
-
- pub fn into(self: Buffer) []const u8 {
- return @intToPtr([*]const u8, self.ptr)[0..self.len];
- }
-
- pub fn intoMutable(self: Buffer) []u8 {
- return @intToPtr([*]u8, self.ptr)[0..self.len];
- }
- };
-
-pub const Reactor = struct {
- pub const InitFlags = enum {
- close_on_exec,
- };
-
- pub const Event = struct {
- data: usize,
- is_error: bool,
- is_hup: bool,
- is_readable: bool,
- is_writable: bool,
- };
-
- pub const Interest = struct {
- hup: bool = false,
- oneshot: bool = false,
- readable: bool = false,
- writable: bool = false,
- };
-
- fd: os.fd_t,
-
- pub fn init(flags: std.enums.EnumFieldStruct(Reactor.InitFlags, bool, false)) !Reactor {
- var raw_flags: u32 = 0;
- const set = std.EnumSet(Reactor.InitFlags).init(flags);
- if (set.contains(.close_on_exec)) raw_flags |= linux.EPOLL.CLOEXEC;
- return Reactor{ .fd = try os.epoll_create1(raw_flags) };
- }
-
- pub fn deinit(self: Reactor) void {
- os.close(self.fd);
- }
-
- pub fn update(self: Reactor, fd: os.fd_t, identifier: usize, interest: Reactor.Interest) !void {
- var flags: u32 = 0;
- flags |= if (interest.oneshot) linux.EPOLL.ONESHOT else linux.EPOLL.ET;
- if (interest.hup) flags |= linux.EPOLL.RDHUP;
- if (interest.readable) flags |= linux.EPOLL.IN;
- if (interest.writable) flags |= linux.EPOLL.OUT;
-
- const event = &linux.epoll_event{
- .events = flags,
- .data = .{ .ptr = identifier },
- };
-
- os.epoll_ctl(self.fd, linux.EPOLL.CTL_MOD, fd, event) catch |err| switch (err) {
- error.FileDescriptorNotRegistered => try os.epoll_ctl(self.fd, linux.EPOLL.CTL_ADD, fd, event),
- else => return err,
- };
- }
-
- pub fn remove(self: Reactor, fd: os.fd_t) !void {
- // directly from man epoll_ctl BUGS section
- // In kernel versions before 2.6.9, the EPOLL_CTL_DEL operation re‐
- // quired a non-null pointer in event, even though this argument is
- // ignored. Since Linux 2.6.9, event can be specified as NULL when
- // using EPOLL_CTL_DEL. Applications that need to be portable to
- // kernels before 2.6.9 should specify a non-null pointer in event.
- var event = linux.epoll_event{
- .events = 0,
- .data = .{ .ptr = 0 },
- };
-
- return os.epoll_ctl(self.fd, linux.EPOLL.CTL_DEL, fd, &event);
- }
-
- pub fn poll(self: Reactor, comptime max_num_events: comptime_int, closure: anytype, timeout_milliseconds: ?u64) !void {
- var events: [max_num_events]linux.epoll_event = undefined;
-
- const num_events = os.epoll_wait(self.fd, &events, if (timeout_milliseconds) |ms| @intCast(i32, ms) else -1);
- for (events[0..num_events]) |ev| {
- const is_error = ev.events & linux.EPOLL.ERR != 0;
- const is_hup = ev.events & (linux.EPOLL.HUP | linux.EPOLL.RDHUP) != 0;
- const is_readable = ev.events & linux.EPOLL.IN != 0;
- const is_writable = ev.events & linux.EPOLL.OUT != 0;
-
- try closure.call(Reactor.Event{
- .data = ev.data.ptr,
- .is_error = is_error,
- .is_hup = is_hup,
- .is_readable = is_readable,
- .is_writable = is_writable,
- });
- }
- }
-};
-
-test "reactor/linux: drive async tcp client/listener pair" {
- if (native_os.tag != .linux) return error.SkipZigTest;
-
- const ip = std.x.net.ip;
- const tcp = std.x.net.tcp;
-
- const IPv4 = std.x.os.IPv4;
- const IPv6 = std.x.os.IPv6;
-
- const reactor = try Reactor.init(.{ .close_on_exec = true });
- defer reactor.deinit();
-
- const listener = try tcp.Listener.init(.ip, .{
- .close_on_exec = true,
- .nonblocking = true,
- });
- defer listener.deinit();
-
- try reactor.update(listener.socket.fd, 0, .{ .readable = true });
- try reactor.poll(1, struct {
- fn call(event: Reactor.Event) !void {
- try testing.expectEqual(Reactor.Event{
- .data = 0,
- .is_error = false,
- .is_hup = true,
- .is_readable = false,
- .is_writable = false,
- }, event);
- }
- }, null);
-
- try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
- try listener.listen(128);
-
- var binded_address = try listener.getLocalAddress();
- switch (binded_address) {
- .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
- .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
- }
-
- const client = try tcp.Client.init(.ip, .{
- .close_on_exec = true,
- .nonblocking = true,
- });
- defer client.deinit();
-
- try reactor.update(client.socket.fd, 1, .{ .readable = true, .writable = true });
- try reactor.poll(1, struct {
- fn call(event: Reactor.Event) !void {
- try testing.expectEqual(Reactor.Event{
- .data = 1,
- .is_error = false,
- .is_hup = true,
- .is_readable = false,
- .is_writable = true,
- }, event);
- }
- }, null);
-
- client.connect(binded_address) catch |err| switch (err) {
- error.WouldBlock => {},
- else => return err,
- };
-
- try reactor.poll(1, struct {
- fn call(event: Reactor.Event) !void {
- try testing.expectEqual(Reactor.Event{
- .data = 1,
- .is_error = false,
- .is_hup = false,
- .is_readable = false,
- .is_writable = true,
- }, event);
- }
- }, null);
-
- try reactor.poll(1, struct {
- fn call(event: Reactor.Event) !void {
- try testing.expectEqual(Reactor.Event{
- .data = 0,
- .is_error = false,
- .is_hup = false,
- .is_readable = true,
- .is_writable = false,
- }, event);
- }
- }, null);
-
- try reactor.remove(client.socket.fd);
- try reactor.remove(listener.socket.fd);
-}
diff --git a/lib/std/x/os/net.zig b/lib/std/x/os/net.zig
@@ -1,605 +0,0 @@
-const std = @import("../../std.zig");
-const builtin = @import("builtin");
-
-const os = std.os;
-const fmt = std.fmt;
-const mem = std.mem;
-const math = std.math;
-const testing = std.testing;
-const native_os = builtin.os;
-const have_ifnamesize = @hasDecl(os.system, "IFNAMESIZE");
-
-pub const ResolveScopeIdError = error{
- NameTooLong,
- PermissionDenied,
- AddressFamilyNotSupported,
- ProtocolFamilyNotAvailable,
- ProcessFdQuotaExceeded,
- SystemFdQuotaExceeded,
- SystemResources,
- ProtocolNotSupported,
- SocketTypeNotSupported,
- InterfaceNotFound,
- FileSystem,
- Unexpected,
-};
-
-/// Resolves a network interface name into a scope/zone ID. It returns
-/// an error if either resolution fails, or if the interface name is
-/// too long.
-pub fn resolveScopeId(name: []const u8) ResolveScopeIdError!u32 {
- if (have_ifnamesize) {
- if (name.len >= os.IFNAMESIZE) return error.NameTooLong;
-
- if (native_os.tag == .windows or comptime native_os.tag.isDarwin()) {
- var interface_name: [os.IFNAMESIZE:0]u8 = undefined;
- mem.copy(u8, &interface_name, name);
- interface_name[name.len] = 0;
-
- const rc = blk: {
- if (native_os.tag == .windows) {
- break :blk os.windows.ws2_32.if_nametoindex(@ptrCast([*:0]const u8, &interface_name));
- } else {
- const index = os.system.if_nametoindex(@ptrCast([*:0]const u8, &interface_name));
- break :blk @bitCast(u32, index);
- }
- };
- if (rc == 0) {
- return error.InterfaceNotFound;
- }
- return rc;
- }
-
- if (native_os.tag == .linux) {
- const fd = try os.socket(os.AF.INET, os.SOCK.DGRAM, 0);
- defer os.closeSocket(fd);
-
- var f: os.ifreq = undefined;
- mem.copy(u8, &f.ifrn.name, name);
- f.ifrn.name[name.len] = 0;
-
- try os.ioctl_SIOCGIFINDEX(fd, &f);
-
- return @bitCast(u32, f.ifru.ivalue);
- }
- }
-
- return error.InterfaceNotFound;
-}
-
-/// An IPv4 address comprised of 4 bytes.
-pub const IPv4 = extern struct {
- /// A IPv4 host-port pair.
- pub const Address = extern struct {
- host: IPv4,
- port: u16,
- };
-
- /// Octets of a IPv4 address designating the local host.
- pub const localhost_octets = [_]u8{ 127, 0, 0, 1 };
-
- /// The IPv4 address of the local host.
- pub const localhost: IPv4 = .{ .octets = localhost_octets };
-
- /// Octets of an unspecified IPv4 address.
- pub const unspecified_octets = [_]u8{0} ** 4;
-
- /// An unspecified IPv4 address.
- pub const unspecified: IPv4 = .{ .octets = unspecified_octets };
-
- /// Octets of a broadcast IPv4 address.
- pub const broadcast_octets = [_]u8{255} ** 4;
-
- /// An IPv4 broadcast address.
- pub const broadcast: IPv4 = .{ .octets = broadcast_octets };
-
- /// The prefix octet pattern of a link-local IPv4 address.
- pub const link_local_prefix = [_]u8{ 169, 254 };
-
- /// The prefix octet patterns of IPv4 addresses intended for
- /// documentation.
- pub const documentation_prefixes = [_][]const u8{
- &[_]u8{ 192, 0, 2 },
- &[_]u8{ 198, 51, 100 },
- &[_]u8{ 203, 0, 113 },
- };
-
- octets: [4]u8,
-
- /// Returns whether or not the two addresses are equal to, less than, or
- /// greater than each other.
- pub fn cmp(self: IPv4, other: IPv4) math.Order {
- return mem.order(u8, &self.octets, &other.octets);
- }
-
- /// Returns true if both addresses are semantically equivalent.
- pub fn eql(self: IPv4, other: IPv4) bool {
- return mem.eql(u8, &self.octets, &other.octets);
- }
-
- /// Returns true if the address is a loopback address.
- pub fn isLoopback(self: IPv4) bool {
- return self.octets[0] == 127;
- }
-
- /// Returns true if the address is an unspecified IPv4 address.
- pub fn isUnspecified(self: IPv4) bool {
- return mem.eql(u8, &self.octets, &unspecified_octets);
- }
-
- /// Returns true if the address is a private IPv4 address.
- pub fn isPrivate(self: IPv4) bool {
- return self.octets[0] == 10 or
- (self.octets[0] == 172 and self.octets[1] >= 16 and self.octets[1] <= 31) or
- (self.octets[0] == 192 and self.octets[1] == 168);
- }
-
- /// Returns true if the address is a link-local IPv4 address.
- pub fn isLinkLocal(self: IPv4) bool {
- return mem.startsWith(u8, &self.octets, &link_local_prefix);
- }
-
- /// Returns true if the address is a multicast IPv4 address.
- pub fn isMulticast(self: IPv4) bool {
- return self.octets[0] >= 224 and self.octets[0] <= 239;
- }
-
- /// Returns true if the address is a IPv4 broadcast address.
- pub fn isBroadcast(self: IPv4) bool {
- return mem.eql(u8, &self.octets, &broadcast_octets);
- }
-
- /// Returns true if the address is in a range designated for documentation. Refer
- /// to IETF RFC 5737 for more details.
- pub fn isDocumentation(self: IPv4) bool {
- inline for (documentation_prefixes) |prefix| {
- if (mem.startsWith(u8, &self.octets, prefix)) {
- return true;
- }
- }
- return false;
- }
-
- /// Implements the `std.fmt.format` API.
- pub fn format(
- self: IPv4,
- comptime layout: []const u8,
- opts: fmt.FormatOptions,
- writer: anytype,
- ) !void {
- _ = opts;
- if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
-
- try fmt.format(writer, "{}.{}.{}.{}", .{
- self.octets[0],
- self.octets[1],
- self.octets[2],
- self.octets[3],
- });
- }
-
- /// Set of possible errors that may encountered when parsing an IPv4
- /// address.
- pub const ParseError = error{
- UnexpectedEndOfOctet,
- TooManyOctets,
- OctetOverflow,
- UnexpectedToken,
- IncompleteAddress,
- };
-
- /// Parses an arbitrary IPv4 address.
- pub fn parse(buf: []const u8) ParseError!IPv4 {
- var octets: [4]u8 = undefined;
- var octet: u8 = 0;
-
- var index: u8 = 0;
- var saw_any_digits: bool = false;
-
- for (buf) |c| {
- switch (c) {
- '.' => {
- if (!saw_any_digits) return error.UnexpectedEndOfOctet;
- if (index == 3) return error.TooManyOctets;
- octets[index] = octet;
- index += 1;
- octet = 0;
- saw_any_digits = false;
- },
- '0'...'9' => {
- saw_any_digits = true;
- octet = math.mul(u8, octet, 10) catch return error.OctetOverflow;
- octet = math.add(u8, octet, c - '0') catch return error.OctetOverflow;
- },
- else => return error.UnexpectedToken,
- }
- }
-
- if (index == 3 and saw_any_digits) {
- octets[index] = octet;
- return IPv4{ .octets = octets };
- }
-
- return error.IncompleteAddress;
- }
-
- /// Maps the address to its IPv6 equivalent. In most cases, you would
- /// want to map the address to its IPv6 equivalent rather than directly
- /// re-interpreting the address.
- pub fn mapToIPv6(self: IPv4) IPv6 {
- var octets: [16]u8 = undefined;
- mem.copy(u8, octets[0..12], &IPv6.v4_mapped_prefix);
- mem.copy(u8, octets[12..], &self.octets);
- return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id };
- }
-
- /// Directly re-interprets the address to its IPv6 equivalent. In most
- /// cases, you would want to map the address to its IPv6 equivalent rather
- /// than directly re-interpreting the address.
- pub fn toIPv6(self: IPv4) IPv6 {
- var octets: [16]u8 = undefined;
- mem.set(u8, octets[0..12], 0);
- mem.copy(u8, octets[12..], &self.octets);
- return IPv6{ .octets = octets, .scope_id = IPv6.no_scope_id };
- }
-};
-
-/// An IPv6 address comprised of 16 bytes for an address, and 4 bytes
-/// for a scope ID; cumulatively summing to 20 bytes in total.
-pub const IPv6 = extern struct {
- /// A IPv6 host-port pair.
- pub const Address = extern struct {
- host: IPv6,
- port: u16,
- };
-
- /// Octets of a IPv6 address designating the local host.
- pub const localhost_octets = [_]u8{0} ** 15 ++ [_]u8{0x01};
-
- /// The IPv6 address of the local host.
- pub const localhost: IPv6 = .{
- .octets = localhost_octets,
- .scope_id = no_scope_id,
- };
-
- /// Octets of an unspecified IPv6 address.
- pub const unspecified_octets = [_]u8{0} ** 16;
-
- /// An unspecified IPv6 address.
- pub const unspecified: IPv6 = .{
- .octets = unspecified_octets,
- .scope_id = no_scope_id,
- };
-
- /// The prefix of a IPv6 address that is mapped to a IPv4 address.
- pub const v4_mapped_prefix = [_]u8{0} ** 10 ++ [_]u8{0xFF} ** 2;
-
- /// A marker value used to designate an IPv6 address with no
- /// associated scope ID.
- pub const no_scope_id = math.maxInt(u32);
-
- octets: [16]u8,
- scope_id: u32,
-
- /// Returns whether or not the two addresses are equal to, less than, or
- /// greater than each other.
- pub fn cmp(self: IPv6, other: IPv6) math.Order {
- return switch (mem.order(u8, self.octets, other.octets)) {
- .eq => math.order(self.scope_id, other.scope_id),
- else => |order| order,
- };
- }
-
- /// Returns true if both addresses are semantically equivalent.
- pub fn eql(self: IPv6, other: IPv6) bool {
- return self.scope_id == other.scope_id and mem.eql(u8, &self.octets, &other.octets);
- }
-
- /// Returns true if the address is an unspecified IPv6 address.
- pub fn isUnspecified(self: IPv6) bool {
- return mem.eql(u8, &self.octets, &unspecified_octets);
- }
-
- /// Returns true if the address is a loopback address.
- pub fn isLoopback(self: IPv6) bool {
- return mem.eql(u8, self.octets[0..3], &[_]u8{ 0, 0, 0 }) and
- mem.eql(u8, self.octets[12..], &[_]u8{ 0, 0, 0, 1 });
- }
-
- /// Returns true if the address maps to an IPv4 address.
- pub fn mapsToIPv4(self: IPv6) bool {
- return mem.startsWith(u8, &self.octets, &v4_mapped_prefix);
- }
-
- /// Returns an IPv4 address representative of the address should
- /// it the address be mapped to an IPv4 address. It returns null
- /// otherwise.
- pub fn toIPv4(self: IPv6) ?IPv4 {
- if (!self.mapsToIPv4()) return null;
- return IPv4{ .octets = self.octets[12..][0..4].* };
- }
-
- /// Returns true if the address is a multicast IPv6 address.
- pub fn isMulticast(self: IPv6) bool {
- return self.octets[0] == 0xFF;
- }
-
- /// Returns true if the address is a unicast link local IPv6 address.
- pub fn isLinkLocal(self: IPv6) bool {
- return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0x80;
- }
-
- /// Returns true if the address is a deprecated unicast site local
- /// IPv6 address. Refer to IETF RFC 3879 for more details as to
- /// why they are deprecated.
- pub fn isSiteLocal(self: IPv6) bool {
- return self.octets[0] == 0xFE and self.octets[1] & 0xC0 == 0xC0;
- }
-
- /// IPv6 multicast address scopes.
- pub const Scope = enum(u8) {
- interface = 1,
- link = 2,
- realm = 3,
- admin = 4,
- site = 5,
- organization = 8,
- global = 14,
- unknown = 0xFF,
- };
-
- /// Returns the multicast scope of the address.
- pub fn scope(self: IPv6) Scope {
- if (!self.isMulticast()) return .unknown;
-
- return switch (self.octets[0] & 0x0F) {
- 1 => .interface,
- 2 => .link,
- 3 => .realm,
- 4 => .admin,
- 5 => .site,
- 8 => .organization,
- 14 => .global,
- else => .unknown,
- };
- }
-
- /// Implements the `std.fmt.format` API. Specifying 'x' or 's' formats the
- /// address lower-cased octets, while specifying 'X' or 'S' formats the
- /// address using upper-cased ASCII octets.
- ///
- /// The default specifier is 'x'.
- pub fn format(
- self: IPv6,
- comptime layout: []const u8,
- opts: fmt.FormatOptions,
- writer: anytype,
- ) !void {
- _ = opts;
- const specifier = comptime &[_]u8{if (layout.len == 0) 'x' else switch (layout[0]) {
- 'x', 'X' => |specifier| specifier,
- 's' => 'x',
- 'S' => 'X',
- else => std.fmt.invalidFmtError(layout, self),
- }};
-
- if (mem.startsWith(u8, &self.octets, &v4_mapped_prefix)) {
- return fmt.format(writer, "::{" ++ specifier ++ "}{" ++ specifier ++ "}:{}.{}.{}.{}", .{
- 0xFF,
- 0xFF,
- self.octets[12],
- self.octets[13],
- self.octets[14],
- self.octets[15],
- });
- }
-
- const zero_span: struct { from: usize, to: usize } = span: {
- var i: usize = 0;
- while (i < self.octets.len) : (i += 2) {
- if (self.octets[i] == 0 and self.octets[i + 1] == 0) break;
- } else break :span .{ .from = 0, .to = 0 };
-
- const from = i;
-
- while (i < self.octets.len) : (i += 2) {
- if (self.octets[i] != 0 or self.octets[i + 1] != 0) break;
- }
-
- break :span .{ .from = from, .to = i };
- };
-
- var i: usize = 0;
- while (i != 16) : (i += 2) {
- if (zero_span.from != zero_span.to and i == zero_span.from) {
- try writer.writeAll("::");
- } else if (i >= zero_span.from and i < zero_span.to) {} else {
- if (i != 0 and i != zero_span.to) try writer.writeAll(":");
-
- const val = @as(u16, self.octets[i]) << 8 | self.octets[i + 1];
- try fmt.formatIntValue(val, specifier, .{}, writer);
- }
- }
-
- if (self.scope_id != no_scope_id and self.scope_id != 0) {
- try fmt.format(writer, "%{d}", .{self.scope_id});
- }
- }
-
- /// Set of possible errors that may encountered when parsing an IPv6
- /// address.
- pub const ParseError = error{
- MalformedV4Mapping,
- InterfaceNotFound,
- UnknownScopeId,
- } || IPv4.ParseError;
-
- /// Parses an arbitrary IPv6 address, including link-local addresses.
- pub fn parse(buf: []const u8) ParseError!IPv6 {
- if (mem.lastIndexOfScalar(u8, buf, '%')) |index| {
- const ip_slice = buf[0..index];
- const scope_id_slice = buf[index + 1 ..];
-
- if (scope_id_slice.len == 0) return error.UnknownScopeId;
-
- const scope_id: u32 = switch (scope_id_slice[0]) {
- '0'...'9' => fmt.parseInt(u32, scope_id_slice, 10),
- else => resolveScopeId(scope_id_slice) catch |err| switch (err) {
- error.InterfaceNotFound => return error.InterfaceNotFound,
- else => err,
- },
- } catch return error.UnknownScopeId;
-
- return parseWithScopeID(ip_slice, scope_id);
- }
-
- return parseWithScopeID(buf, no_scope_id);
- }
-
- /// Parses an IPv6 address with a pre-specified scope ID. Presumes
- /// that the address is not a link-local address.
- pub fn parseWithScopeID(buf: []const u8, scope_id: u32) ParseError!IPv6 {
- var octets: [16]u8 = undefined;
- var octet: u16 = 0;
- var tail: [16]u8 = undefined;
-
- var out: []u8 = &octets;
- var index: u8 = 0;
-
- var saw_any_digits: bool = false;
- var abbrv: bool = false;
-
- for (buf) |c, i| {
- switch (c) {
- ':' => {
- if (!saw_any_digits) {
- if (abbrv) return error.UnexpectedToken;
- if (i != 0) abbrv = true;
- mem.set(u8, out[index..], 0);
- out = &tail;
- index = 0;
- continue;
- }
- if (index == 14) return error.TooManyOctets;
-
- out[index] = @truncate(u8, octet >> 8);
- index += 1;
- out[index] = @truncate(u8, octet);
- index += 1;
-
- octet = 0;
- saw_any_digits = false;
- },
- '.' => {
- if (!abbrv or out[0] != 0xFF and out[1] != 0xFF) {
- return error.MalformedV4Mapping;
- }
- const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1;
- const v4 = try IPv4.parse(buf[start_index..]);
- octets[10] = 0xFF;
- octets[11] = 0xFF;
- mem.copy(u8, octets[12..], &v4.octets);
-
- return IPv6{ .octets = octets, .scope_id = scope_id };
- },
- else => {
- saw_any_digits = true;
- const digit = fmt.charToDigit(c, 16) catch return error.UnexpectedToken;
- octet = math.mul(u16, octet, 16) catch return error.OctetOverflow;
- octet = math.add(u16, octet, digit) catch return error.OctetOverflow;
- },
- }
- }
-
- if (!saw_any_digits and !abbrv) {
- return error.IncompleteAddress;
- }
-
- if (index == 14) {
- out[14] = @truncate(u8, octet >> 8);
- out[15] = @truncate(u8, octet);
- } else {
- out[index] = @truncate(u8, octet >> 8);
- index += 1;
- out[index] = @truncate(u8, octet);
- index += 1;
- mem.copy(u8, octets[16 - index ..], out[0..index]);
- }
-
- return IPv6{ .octets = octets, .scope_id = scope_id };
- }
-};
-
-test {
- testing.refAllDecls(@This());
-}
-
-test "ip: convert to and from ipv6" {
- try testing.expectFmt("::7f00:1", "{}", .{IPv4.localhost.toIPv6()});
- try testing.expect(!IPv4.localhost.toIPv6().mapsToIPv4());
-
- try testing.expectFmt("::ffff:127.0.0.1", "{}", .{IPv4.localhost.mapToIPv6()});
- try testing.expect(IPv4.localhost.mapToIPv6().mapsToIPv4());
-
- try testing.expect(IPv4.localhost.toIPv6().toIPv4() == null);
- try testing.expectFmt("127.0.0.1", "{?}", .{IPv4.localhost.mapToIPv6().toIPv4()});
-}
-
-test "ipv4: parse & format" {
- const cases = [_][]const u8{
- "0.0.0.0",
- "255.255.255.255",
- "1.2.3.4",
- "123.255.0.91",
- "127.0.0.1",
- };
-
- for (cases) |case| {
- try testing.expectFmt(case, "{}", .{try IPv4.parse(case)});
- }
-}
-
-test "ipv6: parse & format" {
- const inputs = [_][]const u8{
- "FF01:0:0:0:0:0:0:FB",
- "FF01::Fb",
- "::1",
- "::",
- "2001:db8::",
- "::1234:5678",
- "2001:db8::1234:5678",
- "::ffff:123.5.123.5",
- };
-
- const outputs = [_][]const u8{
- "ff01::fb",
- "ff01::fb",
- "::1",
- "::",
- "2001:db8::",
- "::1234:5678",
- "2001:db8::1234:5678",
- "::ffff:123.5.123.5",
- };
-
- for (inputs) |input, i| {
- try testing.expectFmt(outputs[i], "{}", .{try IPv6.parse(input)});
- }
-}
-
-test "ipv6: parse & format addresses with scope ids" {
- if (!have_ifnamesize) return error.SkipZigTest;
- const iface = if (native_os.tag == .linux)
- "lo"
- else
- "lo0";
- const input = "FF01::FB%" ++ iface;
- const output = "ff01::fb%1";
-
- const parsed = IPv6.parse(input) catch |err| switch (err) {
- error.InterfaceNotFound => return,
- else => return err,
- };
-
- try testing.expectFmt(output, "{}", .{parsed});
-}
diff --git a/lib/std/x/os/socket.zig b/lib/std/x/os/socket.zig
@@ -1,320 +0,0 @@
-const std = @import("../../std.zig");
-const builtin = @import("builtin");
-const net = @import("net.zig");
-
-const os = std.os;
-const fmt = std.fmt;
-const mem = std.mem;
-const time = std.time;
-const meta = std.meta;
-const native_os = builtin.os;
-const native_endian = builtin.cpu.arch.endian();
-
-const Buffer = std.x.os.Buffer;
-
-const assert = std.debug.assert;
-
-/// A generic, cross-platform socket abstraction.
-pub const Socket = struct {
- /// A socket-address pair.
- pub const Connection = struct {
- socket: Socket,
- address: Socket.Address,
-
- /// Enclose a socket and address into a socket-address pair.
- pub fn from(socket: Socket, address: Socket.Address) Socket.Connection {
- return .{ .socket = socket, .address = address };
- }
- };
-
- /// A generic socket address abstraction. It is safe to directly access and modify
- /// the fields of a `Socket.Address`.
- pub const Address = union(enum) {
- pub const Native = struct {
- pub const requires_prepended_length = native_os.getVersionRange() == .semver;
- pub const Length = if (requires_prepended_length) u8 else [0]u8;
-
- pub const Family = if (requires_prepended_length) u8 else c_ushort;
-
- /// POSIX `sockaddr.storage`. The expected size and alignment is specified in IETF RFC 2553.
- pub const Storage = extern struct {
- pub const expected_size = os.sockaddr.SS_MAXSIZE;
- pub const expected_alignment = 8;
-
- pub const padding_size = expected_size -
- mem.alignForward(@sizeOf(Address.Native.Length), expected_alignment) -
- mem.alignForward(@sizeOf(Address.Native.Family), expected_alignment);
-
- len: Address.Native.Length align(expected_alignment) = undefined,
- family: Address.Native.Family align(expected_alignment) = undefined,
- padding: [padding_size]u8 align(expected_alignment) = undefined,
-
- comptime {
- assert(@sizeOf(Storage) == Storage.expected_size);
- assert(@alignOf(Storage) == Storage.expected_alignment);
- }
- };
- };
-
- ipv4: net.IPv4.Address,
- ipv6: net.IPv6.Address,
-
- /// Instantiate a new address with a IPv4 host and port.
- pub fn initIPv4(host: net.IPv4, port: u16) Socket.Address {
- return .{ .ipv4 = .{ .host = host, .port = port } };
- }
-
- /// Instantiate a new address with a IPv6 host and port.
- pub fn initIPv6(host: net.IPv6, port: u16) Socket.Address {
- return .{ .ipv6 = .{ .host = host, .port = port } };
- }
-
- /// Parses a `sockaddr` into a generic socket address.
- pub fn fromNative(address: *align(4) const os.sockaddr) Socket.Address {
- switch (address.family) {
- os.AF.INET => {
- const info = @ptrCast(*const os.sockaddr.in, address);
- const host = net.IPv4{ .octets = @bitCast([4]u8, info.addr) };
- const port = mem.bigToNative(u16, info.port);
- return Socket.Address.initIPv4(host, port);
- },
- os.AF.INET6 => {
- const info = @ptrCast(*const os.sockaddr.in6, address);
- const host = net.IPv6{ .octets = info.addr, .scope_id = info.scope_id };
- const port = mem.bigToNative(u16, info.port);
- return Socket.Address.initIPv6(host, port);
- },
- else => unreachable,
- }
- }
-
- /// Encodes a generic socket address into an extern union that may be reliably
- /// casted into a `sockaddr` which may be passed into socket syscalls.
- pub fn toNative(self: Socket.Address) extern union {
- ipv4: os.sockaddr.in,
- ipv6: os.sockaddr.in6,
- } {
- return switch (self) {
- .ipv4 => |address| .{
- .ipv4 = .{
- .addr = @bitCast(u32, address.host.octets),
- .port = mem.nativeToBig(u16, address.port),
- },
- },
- .ipv6 => |address| .{
- .ipv6 = .{
- .addr = address.host.octets,
- .port = mem.nativeToBig(u16, address.port),
- .scope_id = address.host.scope_id,
- .flowinfo = 0,
- },
- },
- };
- }
-
- /// Returns the number of bytes that make up the `sockaddr` equivalent to the address.
- pub fn getNativeSize(self: Socket.Address) u32 {
- return switch (self) {
- .ipv4 => @sizeOf(os.sockaddr.in),
- .ipv6 => @sizeOf(os.sockaddr.in6),
- };
- }
-
- /// Implements the `std.fmt.format` API.
- pub fn format(
- self: Socket.Address,
- comptime layout: []const u8,
- opts: fmt.FormatOptions,
- writer: anytype,
- ) !void {
- if (layout.len != 0) std.fmt.invalidFmtError(layout, self);
- _ = opts;
- switch (self) {
- .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
- .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }),
- }
- }
- };
-
- /// POSIX `msghdr`. Denotes a destination address, set of buffers, control data, and flags. Ported
- /// directly from musl.
- pub const Message = if (native_os.isAtLeast(.windows, .vista) != null and native_os.isAtLeast(.windows, .vista).?)
- extern struct {
- name: usize = @ptrToInt(@as(?[*]u8, null)),
- name_len: c_int = 0,
-
- buffers: usize = undefined,
- buffers_len: c_ulong = undefined,
-
- control: Buffer = .{
- .ptr = @ptrToInt(@as(?[*]u8, null)),
- .len = 0,
- },
- flags: c_ulong = 0,
-
- pub usingnamespace MessageMixin(Message);
- }
- else if (native_os.tag == .windows)
- extern struct {
- name: usize = @ptrToInt(@as(?[*]u8, null)),
- name_len: c_int = 0,
-
- buffers: usize = undefined,
- buffers_len: u32 = undefined,
-
- control: Buffer = .{
- .ptr = @ptrToInt(@as(?[*]u8, null)),
- .len = 0,
- },
- flags: u32 = 0,
-
- pub usingnamespace MessageMixin(Message);
- }
- else if (@sizeOf(usize) > 4 and native_endian == .Big)
- extern struct {
- name: usize = @ptrToInt(@as(?[*]u8, null)),
- name_len: c_uint = 0,
-
- buffers: usize = undefined,
- _pad_1: c_int = 0,
- buffers_len: c_int = undefined,
-
- control: usize = @ptrToInt(@as(?[*]u8, null)),
- _pad_2: c_int = 0,
- control_len: c_uint = 0,
-
- flags: c_int = 0,
-
- pub usingnamespace MessageMixin(Message);
- }
- else if (@sizeOf(usize) > 4 and native_endian == .Little)
- extern struct {
- name: usize = @ptrToInt(@as(?[*]u8, null)),
- name_len: c_uint = 0,
-
- buffers: usize = undefined,
- buffers_len: c_int = undefined,
- _pad_1: c_int = 0,
-
- control: usize = @ptrToInt(@as(?[*]u8, null)),
- control_len: c_uint = 0,
- _pad_2: c_int = 0,
-
- flags: c_int = 0,
-
- pub usingnamespace MessageMixin(Message);
- }
- else
- extern struct {
- name: usize = @ptrToInt(@as(?[*]u8, null)),
- name_len: c_uint = 0,
-
- buffers: usize = undefined,
- buffers_len: c_int = undefined,
-
- control: usize = @ptrToInt(@as(?[*]u8, null)),
- control_len: c_uint = 0,
-
- flags: c_int = 0,
-
- pub usingnamespace MessageMixin(Message);
- };
-
- fn MessageMixin(comptime Self: type) type {
- return struct {
- pub fn fromBuffers(buffers: []const Buffer) Self {
- var self: Self = .{};
- self.setBuffers(buffers);
- return self;
- }
-
- pub fn setName(self: *Self, name: []const u8) void {
- self.name = @ptrToInt(name.ptr);
- self.name_len = @intCast(meta.fieldInfo(Self, .name_len).type, name.len);
- }
-
- pub fn setBuffers(self: *Self, buffers: []const Buffer) void {
- self.buffers = @ptrToInt(buffers.ptr);
- self.buffers_len = @intCast(meta.fieldInfo(Self, .buffers_len).type, buffers.len);
- }
-
- pub fn setControl(self: *Self, control: []const u8) void {
- if (native_os.tag == .windows) {
- self.control = Buffer.from(control);
- } else {
- self.control = @ptrToInt(control.ptr);
- self.control_len = @intCast(meta.fieldInfo(Self, .control_len).type, control.len);
- }
- }
-
- pub fn setFlags(self: *Self, flags: u32) void {
- self.flags = @intCast(meta.fieldInfo(Self, .flags).type, flags);
- }
-
- pub fn getName(self: Self) []const u8 {
- return @intToPtr([*]const u8, self.name)[0..@intCast(usize, self.name_len)];
- }
-
- pub fn getBuffers(self: Self) []const Buffer {
- return @intToPtr([*]const Buffer, self.buffers)[0..@intCast(usize, self.buffers_len)];
- }
-
- pub fn getControl(self: Self) []const u8 {
- if (native_os.tag == .windows) {
- return self.control.into();
- } else {
- return @intToPtr([*]const u8, self.control)[0..@intCast(usize, self.control_len)];
- }
- }
-
- pub fn getFlags(self: Self) u32 {
- return @intCast(u32, self.flags);
- }
- };
- }
-
- /// POSIX `linger`, denoting the linger settings of a socket.
- ///
- /// Microsoft's documentation and glibc denote the fields to be unsigned
- /// short's on Windows, whereas glibc and musl denote the fields to be
- /// int's on every other platform.
- pub const Linger = extern struct {
- pub const Field = switch (native_os.tag) {
- .windows => c_ushort,
- else => c_int,
- };
-
- enabled: Field,
- timeout_seconds: Field,
-
- pub fn init(timeout_seconds: ?u16) Socket.Linger {
- return .{
- .enabled = @intCast(Socket.Linger.Field, @boolToInt(timeout_seconds != null)),
- .timeout_seconds = if (timeout_seconds) |seconds| @intCast(Socket.Linger.Field, seconds) else 0,
- };
- }
- };
-
- /// Possible set of flags to initialize a socket with.
- pub const InitFlags = enum {
- // Initialize a socket to be non-blocking.
- nonblocking,
-
- // Have a socket close itself on exec syscalls.
- close_on_exec,
- };
-
- /// The underlying handle of a socket.
- fd: os.socket_t,
-
- /// Enclose a socket abstraction over an existing socket file descriptor.
- pub fn from(fd: os.socket_t) Socket {
- return Socket{ .fd = fd };
- }
-
- /// Mix in socket syscalls depending on the platform we are compiling against.
- pub usingnamespace switch (native_os.tag) {
- .windows => @import("socket_windows.zig"),
- else => @import("socket_posix.zig"),
- }.Mixin(Socket);
-};
diff --git a/lib/std/x/os/socket_posix.zig b/lib/std/x/os/socket_posix.zig
@@ -1,275 +0,0 @@
-const std = @import("../../std.zig");
-
-const os = std.os;
-const mem = std.mem;
-const time = std.time;
-
-pub fn Mixin(comptime Socket: type) type {
- return struct {
- /// Open a new socket.
- pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket {
- var raw_flags: u32 = socket_type;
- const set = std.EnumSet(Socket.InitFlags).init(flags);
- if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC;
- if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK;
- return Socket{ .fd = try os.socket(domain, raw_flags, protocol) };
- }
-
- /// Closes the socket.
- pub fn deinit(self: Socket) void {
- os.closeSocket(self.fd);
- }
-
- /// Shutdown either the read side, write side, or all side of the socket.
- pub fn shutdown(self: Socket, how: os.ShutdownHow) !void {
- return os.shutdown(self.fd, how);
- }
-
- /// Binds the socket to an address.
- pub fn bind(self: Socket, address: Socket.Address) !void {
- return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize());
- }
-
- /// Start listening for incoming connections on the socket.
- pub fn listen(self: Socket, max_backlog_size: u31) !void {
- return os.listen(self.fd, max_backlog_size);
- }
-
- /// Have the socket attempt to the connect to an address.
- pub fn connect(self: Socket, address: Socket.Address) !void {
- return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize());
- }
-
- /// Accept a pending incoming connection queued to the kernel backlog
- /// of the socket.
- pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
-
- var raw_flags: u32 = 0;
- const set = std.EnumSet(Socket.InitFlags).init(flags);
- if (set.contains(.close_on_exec)) raw_flags |= os.SOCK.CLOEXEC;
- if (set.contains(.nonblocking)) raw_flags |= os.SOCK.NONBLOCK;
-
- const socket = Socket{ .fd = try os.accept(self.fd, @ptrCast(*os.sockaddr, &address), &address_len, raw_flags) };
- const socket_address = Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
-
- return Socket.Connection.from(socket, socket_address);
- }
-
- /// Read data from the socket into the buffer provided with a set of flags
- /// specified. It returns the number of bytes read into the buffer provided.
- pub fn read(self: Socket, buf: []u8, flags: u32) !usize {
- return os.recv(self.fd, buf, flags);
- }
-
- /// Write a buffer of data provided to the socket with a set of flags specified.
- /// It returns the number of bytes that are written to the socket.
- pub fn write(self: Socket, buf: []const u8, flags: u32) !usize {
- return os.send(self.fd, buf, flags);
- }
-
- /// Writes multiple I/O vectors with a prepended message header to the socket
- /// with a set of flags specified. It returns the number of bytes that are
- /// written to the socket.
- pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize {
- while (true) {
- const rc = os.system.sendmsg(self.fd, &msg, @intCast(c_int, flags));
- return switch (os.errno(rc)) {
- .SUCCESS => return @intCast(usize, rc),
- .ACCES => error.AccessDenied,
- .AGAIN => error.WouldBlock,
- .ALREADY => error.FastOpenAlreadyInProgress,
- .BADF => unreachable, // always a race condition
- .CONNRESET => error.ConnectionResetByPeer,
- .DESTADDRREQ => unreachable, // The socket is not connection-mode, and no peer address is set.
- .FAULT => unreachable, // An invalid user space address was specified for an argument.
- .INTR => continue,
- .INVAL => unreachable, // Invalid argument passed.
- .ISCONN => unreachable, // connection-mode socket was connected already but a recipient was specified
- .MSGSIZE => error.MessageTooBig,
- .NOBUFS => error.SystemResources,
- .NOMEM => error.SystemResources,
- .NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket.
- .OPNOTSUPP => unreachable, // Some bit in the flags argument is inappropriate for the socket type.
- .PIPE => error.BrokenPipe,
- .AFNOSUPPORT => error.AddressFamilyNotSupported,
- .LOOP => error.SymLinkLoop,
- .NAMETOOLONG => error.NameTooLong,
- .NOENT => error.FileNotFound,
- .NOTDIR => error.NotDir,
- .HOSTUNREACH => error.NetworkUnreachable,
- .NETUNREACH => error.NetworkUnreachable,
- .NOTCONN => error.SocketNotConnected,
- .NETDOWN => error.NetworkSubsystemFailed,
- else => |err| os.unexpectedErrno(err),
- };
- }
- }
-
- /// Read multiple I/O vectors with a prepended message header from the socket
- /// with a set of flags specified. It returns the number of bytes that were
- /// read into the buffer provided.
- pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize {
- while (true) {
- const rc = os.system.recvmsg(self.fd, msg, @intCast(c_int, flags));
- return switch (os.errno(rc)) {
- .SUCCESS => @intCast(usize, rc),
- .BADF => unreachable, // always a race condition
- .FAULT => unreachable,
- .INVAL => unreachable,
- .NOTCONN => unreachable,
- .NOTSOCK => unreachable,
- .INTR => continue,
- .AGAIN => error.WouldBlock,
- .NOMEM => error.SystemResources,
- .CONNREFUSED => error.ConnectionRefused,
- .CONNRESET => error.ConnectionResetByPeer,
- else => |err| os.unexpectedErrno(err),
- };
- }
- }
-
- /// Query the address that the socket is locally bounded to.
- pub fn getLocalAddress(self: Socket) !Socket.Address {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
- try os.getsockname(self.fd, @ptrCast(*os.sockaddr, &address), &address_len);
- return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
- }
-
- /// Query the address that the socket is connected to.
- pub fn getRemoteAddress(self: Socket) !Socket.Address {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: u32 = @sizeOf(Socket.Address.Native.Storage);
- try os.getpeername(self.fd, @ptrCast(*os.sockaddr, &address), &address_len);
- return Socket.Address.fromNative(@ptrCast(*os.sockaddr, &address));
- }
-
- /// Query and return the latest cached error on the socket.
- pub fn getError(self: Socket) !void {
- return os.getsockoptError(self.fd);
- }
-
- /// Query the read buffer size of the socket.
- pub fn getReadBufferSize(self: Socket) !u32 {
- var value: u32 = undefined;
- var value_len: u32 = @sizeOf(u32);
-
- const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&value), &value_len);
- return switch (os.errno(rc)) {
- .SUCCESS => value,
- .BADF => error.BadFileDescriptor,
- .FAULT => error.InvalidAddressSpace,
- .INVAL => error.InvalidSocketOption,
- .NOPROTOOPT => error.UnknownSocketOption,
- .NOTSOCK => error.NotASocket,
- else => |err| os.unexpectedErrno(err),
- };
- }
-
- /// Query the write buffer size of the socket.
- pub fn getWriteBufferSize(self: Socket) !u32 {
- var value: u32 = undefined;
- var value_len: u32 = @sizeOf(u32);
-
- const rc = os.system.getsockopt(self.fd, os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&value), &value_len);
- return switch (os.errno(rc)) {
- .SUCCESS => value,
- .BADF => error.BadFileDescriptor,
- .FAULT => error.InvalidAddressSpace,
- .INVAL => error.InvalidSocketOption,
- .NOPROTOOPT => error.UnknownSocketOption,
- .NOTSOCK => error.NotASocket,
- else => |err| os.unexpectedErrno(err),
- };
- }
-
- /// Set a socket option.
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void {
- return os.setsockopt(self.fd, level, code, value);
- }
-
- /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully
- /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
- /// if the host does not support the option for a socket to linger around up until a timeout specified in
- /// seconds.
- pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void {
- if (@hasDecl(os.SO, "LINGER")) {
- const settings = Socket.Linger.init(timeout_seconds);
- return self.setOption(os.SOL.SOCKET, os.SO.LINGER, mem.asBytes(&settings));
- }
-
- return error.UnsupportedSocketOption;
- }
-
- /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive
- /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if
- /// the host does not support periodically sending keep-alive messages on connection-oriented sockets.
- pub fn setKeepAlive(self: Socket, enabled: bool) !void {
- if (@hasDecl(os.SO, "KEEPALIVE")) {
- return self.setOption(os.SOL.SOCKET, os.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
- /// the host does not support sockets listening the same address.
- pub fn setReuseAddress(self: Socket, enabled: bool) !void {
- if (@hasDecl(os.SO, "REUSEADDR")) {
- return self.setOption(os.SOL.SOCKET, os.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
- /// the host does not supports sockets listening on the same port.
- pub fn setReusePort(self: Socket, enabled: bool) !void {
- if (@hasDecl(os.SO, "REUSEPORT")) {
- return self.setOption(os.SOL.SOCKET, os.SO.REUSEPORT, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
- return error.UnsupportedSocketOption;
- }
-
- /// Set the write buffer size of the socket.
- pub fn setWriteBufferSize(self: Socket, size: u32) !void {
- return self.setOption(os.SOL.SOCKET, os.SO.SNDBUF, mem.asBytes(&size));
- }
-
- /// Set the read buffer size of the socket.
- pub fn setReadBufferSize(self: Socket, size: u32) !void {
- return self.setOption(os.SOL.SOCKET, os.SO.RCVBUF, mem.asBytes(&size));
- }
-
- /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
- /// set on a non-blocking socket.
- ///
- /// Set a timeout on the socket that is to occur if no messages are successfully written
- /// to its bound destination after a specified number of milliseconds. A subsequent write
- /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
- pub fn setWriteTimeout(self: Socket, milliseconds: usize) !void {
- const timeout = os.timeval{
- .tv_sec = @intCast(i32, milliseconds / time.ms_per_s),
- .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms),
- };
-
- return self.setOption(os.SOL.SOCKET, os.SO.SNDTIMEO, mem.asBytes(&timeout));
- }
-
- /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
- /// set on a non-blocking socket.
- ///
- /// Set a timeout on the socket that is to occur if no messages are successfully read
- /// from its bound destination after a specified number of milliseconds. A subsequent
- /// read from the socket will thereafter return `error.WouldBlock` should the timeout be
- /// exceeded.
- pub fn setReadTimeout(self: Socket, milliseconds: usize) !void {
- const timeout = os.timeval{
- .tv_sec = @intCast(i32, milliseconds / time.ms_per_s),
- .tv_usec = @intCast(i32, (milliseconds % time.ms_per_s) * time.us_per_ms),
- };
-
- return self.setOption(os.SOL.SOCKET, os.SO.RCVTIMEO, mem.asBytes(&timeout));
- }
- };
-}
diff --git a/lib/std/x/os/socket_windows.zig b/lib/std/x/os/socket_windows.zig
@@ -1,458 +0,0 @@
-const std = @import("../../std.zig");
-const net = @import("net.zig");
-
-const os = std.os;
-const mem = std.mem;
-
-const windows = std.os.windows;
-const ws2_32 = windows.ws2_32;
-
-pub fn Mixin(comptime Socket: type) type {
- return struct {
- /// Open a new socket.
- pub fn init(domain: u32, socket_type: u32, protocol: u32, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket {
- var raw_flags: u32 = ws2_32.WSA_FLAG_OVERLAPPED;
- const set = std.EnumSet(Socket.InitFlags).init(flags);
- if (set.contains(.close_on_exec)) raw_flags |= ws2_32.WSA_FLAG_NO_HANDLE_INHERIT;
-
- const fd = ws2_32.WSASocketW(
- @intCast(i32, domain),
- @intCast(i32, socket_type),
- @intCast(i32, protocol),
- null,
- 0,
- raw_flags,
- );
- if (fd == ws2_32.INVALID_SOCKET) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSANOTINITIALISED => {
- _ = try windows.WSAStartup(2, 2);
- return init(domain, socket_type, protocol, flags);
- },
- .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported,
- .WSAEMFILE => error.ProcessFdQuotaExceeded,
- .WSAENOBUFS => error.SystemResources,
- .WSAEPROTONOSUPPORT => error.ProtocolNotSupported,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- if (set.contains(.nonblocking)) {
- var enabled: c_ulong = 1;
- const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled);
- if (rc == ws2_32.SOCKET_ERROR) {
- return windows.unexpectedWSAError(ws2_32.WSAGetLastError());
- }
- }
-
- return Socket{ .fd = fd };
- }
-
- /// Closes the socket.
- pub fn deinit(self: Socket) void {
- _ = ws2_32.closesocket(self.fd);
- }
-
- /// Shutdown either the read side, write side, or all side of the socket.
- pub fn shutdown(self: Socket, how: os.ShutdownHow) !void {
- const rc = ws2_32.shutdown(self.fd, switch (how) {
- .recv => ws2_32.SD_RECEIVE,
- .send => ws2_32.SD_SEND,
- .both => ws2_32.SD_BOTH,
- });
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAECONNABORTED => return error.ConnectionAborted,
- .WSAECONNRESET => return error.ConnectionResetByPeer,
- .WSAEINPROGRESS => return error.BlockingOperationInProgress,
- .WSAEINVAL => unreachable,
- .WSAENETDOWN => return error.NetworkSubsystemFailed,
- .WSAENOTCONN => return error.SocketNotConnected,
- .WSAENOTSOCK => unreachable,
- .WSANOTINITIALISED => unreachable,
- else => |err| return windows.unexpectedWSAError(err),
- };
- }
- }
-
- /// Binds the socket to an address.
- pub fn bind(self: Socket, address: Socket.Address) !void {
- const rc = ws2_32.bind(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize()));
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAEACCES => error.AccessDenied,
- .WSAEADDRINUSE => error.AddressInUse,
- .WSAEADDRNOTAVAIL => error.AddressNotAvailable,
- .WSAEFAULT => error.BadAddress,
- .WSAEINPROGRESS => error.WouldBlock,
- .WSAEINVAL => error.AlreadyBound,
- .WSAENOBUFS => error.NoEphemeralPortsAvailable,
- .WSAENOTSOCK => error.NotASocket,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
- }
-
- /// Start listening for incoming connections on the socket.
- pub fn listen(self: Socket, max_backlog_size: u31) !void {
- const rc = ws2_32.listen(self.fd, max_backlog_size);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAEADDRINUSE => error.AddressInUse,
- .WSAEISCONN => error.AlreadyConnected,
- .WSAEINVAL => error.SocketNotBound,
- .WSAEMFILE, .WSAENOBUFS => error.SystemResources,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAEINPROGRESS => error.WouldBlock,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
- }
-
- /// Have the socket attempt to the connect to an address.
- pub fn connect(self: Socket, address: Socket.Address) !void {
- const rc = ws2_32.connect(self.fd, @ptrCast(*const ws2_32.sockaddr, &address.toNative()), @intCast(c_int, address.getNativeSize()));
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAEADDRINUSE => error.AddressInUse,
- .WSAEADDRNOTAVAIL => error.AddressNotAvailable,
- .WSAECONNREFUSED => error.ConnectionRefused,
- .WSAETIMEDOUT => error.ConnectionTimedOut,
- .WSAEFAULT => error.BadAddress,
- .WSAEINVAL => error.ListeningSocket,
- .WSAEISCONN => error.AlreadyConnected,
- .WSAENOTSOCK => error.NotASocket,
- .WSAEACCES => error.BroadcastNotEnabled,
- .WSAENOBUFS => error.SystemResources,
- .WSAEAFNOSUPPORT => error.AddressFamilyNotSupported,
- .WSAEINPROGRESS, .WSAEWOULDBLOCK => error.WouldBlock,
- .WSAEHOSTUNREACH, .WSAENETUNREACH => error.NetworkUnreachable,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
- }
-
- /// Accept a pending incoming connection queued to the kernel backlog
- /// of the socket.
- pub fn accept(self: Socket, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Socket.Connection {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
-
- const fd = ws2_32.accept(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
- if (fd == ws2_32.INVALID_SOCKET) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSANOTINITIALISED => unreachable,
- .WSAECONNRESET => error.ConnectionResetByPeer,
- .WSAEFAULT => unreachable,
- .WSAEINVAL => error.SocketNotListening,
- .WSAEMFILE => error.ProcessFdQuotaExceeded,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENOBUFS => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAEWOULDBLOCK => error.WouldBlock,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- const socket = Socket.from(fd);
- errdefer socket.deinit();
-
- const socket_address = Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
-
- const set = std.EnumSet(Socket.InitFlags).init(flags);
- if (set.contains(.nonblocking)) {
- var enabled: c_ulong = 1;
- const rc = ws2_32.ioctlsocket(fd, ws2_32.FIONBIO, &enabled);
- if (rc == ws2_32.SOCKET_ERROR) {
- return windows.unexpectedWSAError(ws2_32.WSAGetLastError());
- }
- }
-
- return Socket.Connection.from(socket, socket_address);
- }
-
- /// Read data from the socket into the buffer provided with a set of flags
- /// specified. It returns the number of bytes read into the buffer provided.
- pub fn read(self: Socket, buf: []u8, flags: u32) !usize {
- var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = buf.ptr }};
- var num_bytes: u32 = undefined;
- var flags_ = flags;
-
- const rc = ws2_32.WSARecv(self.fd, bufs, 1, &num_bytes, &flags_, null, null);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAECONNABORTED => error.ConnectionAborted,
- .WSAECONNRESET => error.ConnectionResetByPeer,
- .WSAEDISCON => error.ConnectionClosedByPeer,
- .WSAEFAULT => error.BadBuffer,
- .WSAEINPROGRESS,
- .WSAEWOULDBLOCK,
- .WSA_IO_PENDING,
- .WSAETIMEDOUT,
- => error.WouldBlock,
- .WSAEINTR => error.Cancelled,
- .WSAEINVAL => error.SocketNotBound,
- .WSAEMSGSIZE => error.MessageTooLarge,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENETRESET => error.NetworkReset,
- .WSAENOTCONN => error.SocketNotConnected,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAESHUTDOWN => error.AlreadyShutdown,
- .WSA_OPERATION_ABORTED => error.OperationAborted,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return @intCast(usize, num_bytes);
- }
-
- /// Write a buffer of data provided to the socket with a set of flags specified.
- /// It returns the number of bytes that are written to the socket.
- pub fn write(self: Socket, buf: []const u8, flags: u32) !usize {
- var bufs = &[_]ws2_32.WSABUF{.{ .len = @intCast(u32, buf.len), .buf = @intToPtr([*]u8, @ptrToInt(buf.ptr)) }};
- var num_bytes: u32 = undefined;
-
- const rc = ws2_32.WSASend(self.fd, bufs, 1, &num_bytes, flags, null, null);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAECONNABORTED => error.ConnectionAborted,
- .WSAECONNRESET => error.ConnectionResetByPeer,
- .WSAEFAULT => error.BadBuffer,
- .WSAEINPROGRESS,
- .WSAEWOULDBLOCK,
- .WSA_IO_PENDING,
- .WSAETIMEDOUT,
- => error.WouldBlock,
- .WSAEINTR => error.Cancelled,
- .WSAEINVAL => error.SocketNotBound,
- .WSAEMSGSIZE => error.MessageTooLarge,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENETRESET => error.NetworkReset,
- .WSAENOBUFS => error.BufferDeadlock,
- .WSAENOTCONN => error.SocketNotConnected,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAESHUTDOWN => error.AlreadyShutdown,
- .WSA_OPERATION_ABORTED => error.OperationAborted,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return @intCast(usize, num_bytes);
- }
-
- /// Writes multiple I/O vectors with a prepended message header to the socket
- /// with a set of flags specified. It returns the number of bytes that are
- /// written to the socket.
- pub fn writeMessage(self: Socket, msg: Socket.Message, flags: u32) !usize {
- const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSASENDMSG, self.fd, ws2_32.WSAID_WSASENDMSG);
-
- var num_bytes: u32 = undefined;
-
- const rc = call(self.fd, &msg, flags, &num_bytes, null, null);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAECONNABORTED => error.ConnectionAborted,
- .WSAECONNRESET => error.ConnectionResetByPeer,
- .WSAEFAULT => error.BadBuffer,
- .WSAEINPROGRESS,
- .WSAEWOULDBLOCK,
- .WSA_IO_PENDING,
- .WSAETIMEDOUT,
- => error.WouldBlock,
- .WSAEINTR => error.Cancelled,
- .WSAEINVAL => error.SocketNotBound,
- .WSAEMSGSIZE => error.MessageTooLarge,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENETRESET => error.NetworkReset,
- .WSAENOBUFS => error.BufferDeadlock,
- .WSAENOTCONN => error.SocketNotConnected,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAESHUTDOWN => error.AlreadyShutdown,
- .WSA_OPERATION_ABORTED => error.OperationAborted,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return @intCast(usize, num_bytes);
- }
-
- /// Read multiple I/O vectors with a prepended message header from the socket
- /// with a set of flags specified. It returns the number of bytes that were
- /// read into the buffer provided.
- pub fn readMessage(self: Socket, msg: *Socket.Message, flags: u32) !usize {
- _ = flags;
- const call = try windows.loadWinsockExtensionFunction(ws2_32.LPFN_WSARECVMSG, self.fd, ws2_32.WSAID_WSARECVMSG);
-
- var num_bytes: u32 = undefined;
-
- const rc = call(self.fd, msg, &num_bytes, null, null);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSAECONNABORTED => error.ConnectionAborted,
- .WSAECONNRESET => error.ConnectionResetByPeer,
- .WSAEDISCON => error.ConnectionClosedByPeer,
- .WSAEFAULT => error.BadBuffer,
- .WSAEINPROGRESS,
- .WSAEWOULDBLOCK,
- .WSA_IO_PENDING,
- .WSAETIMEDOUT,
- => error.WouldBlock,
- .WSAEINTR => error.Cancelled,
- .WSAEINVAL => error.SocketNotBound,
- .WSAEMSGSIZE => error.MessageTooLarge,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENETRESET => error.NetworkReset,
- .WSAENOTCONN => error.SocketNotConnected,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEOPNOTSUPP => error.OperationNotSupported,
- .WSAESHUTDOWN => error.AlreadyShutdown,
- .WSA_OPERATION_ABORTED => error.OperationAborted,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return @intCast(usize, num_bytes);
- }
-
- /// Query the address that the socket is locally bounded to.
- pub fn getLocalAddress(self: Socket) !Socket.Address {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
-
- const rc = ws2_32.getsockname(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSANOTINITIALISED => unreachable,
- .WSAEFAULT => unreachable,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEINVAL => error.SocketNotBound,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
- }
-
- /// Query the address that the socket is connected to.
- pub fn getRemoteAddress(self: Socket) !Socket.Address {
- var address: Socket.Address.Native.Storage = undefined;
- var address_len: c_int = @sizeOf(Socket.Address.Native.Storage);
-
- const rc = ws2_32.getpeername(self.fd, @ptrCast(*ws2_32.sockaddr, &address), &address_len);
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSANOTINITIALISED => unreachable,
- .WSAEFAULT => unreachable,
- .WSAENETDOWN => error.NetworkSubsystemFailed,
- .WSAENOTSOCK => error.FileDescriptorNotASocket,
- .WSAEINVAL => error.SocketNotBound,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
-
- return Socket.Address.fromNative(@ptrCast(*ws2_32.sockaddr, &address));
- }
-
- /// Query and return the latest cached error on the socket.
- pub fn getError(self: Socket) !void {
- _ = self;
- return {};
- }
-
- /// Query the read buffer size of the socket.
- pub fn getReadBufferSize(self: Socket) !u32 {
- _ = self;
- return 0;
- }
-
- /// Query the write buffer size of the socket.
- pub fn getWriteBufferSize(self: Socket) !u32 {
- _ = self;
- return 0;
- }
-
- /// Set a socket option.
- pub fn setOption(self: Socket, level: u32, code: u32, value: []const u8) !void {
- const rc = ws2_32.setsockopt(self.fd, @intCast(i32, level), @intCast(i32, code), value.ptr, @intCast(i32, value.len));
- if (rc == ws2_32.SOCKET_ERROR) {
- return switch (ws2_32.WSAGetLastError()) {
- .WSANOTINITIALISED => unreachable,
- .WSAENETDOWN => return error.NetworkSubsystemFailed,
- .WSAEFAULT => unreachable,
- .WSAENOTSOCK => return error.FileDescriptorNotASocket,
- .WSAEINVAL => return error.SocketNotBound,
- else => |err| windows.unexpectedWSAError(err),
- };
- }
- }
-
- /// Have close() or shutdown() syscalls block until all queued messages in the socket have been successfully
- /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
- /// if the host does not support the option for a socket to linger around up until a timeout specified in
- /// seconds.
- pub fn setLinger(self: Socket, timeout_seconds: ?u16) !void {
- const settings = Socket.Linger.init(timeout_seconds);
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.LINGER, mem.asBytes(&settings));
- }
-
- /// On connection-oriented sockets, have keep-alive messages be sent periodically. The timing in which keep-alive
- /// messages are sent are dependant on operating system settings. It returns `error.UnsupportedSocketOption` if
- /// the host does not support periodically sending keep-alive messages on connection-oriented sockets.
- pub fn setKeepAlive(self: Socket, enabled: bool) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.KEEPALIVE, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
-
- /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
- /// the host does not support sockets listening the same address.
- pub fn setReuseAddress(self: Socket, enabled: bool) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.REUSEADDR, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- }
-
- /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
- /// the host does not supports sockets listening on the same port.
- ///
- /// TODO: verify if this truly mimicks SO.REUSEPORT behavior, or if SO.REUSE_UNICASTPORT provides the correct behavior
- pub fn setReusePort(self: Socket, enabled: bool) !void {
- try self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.BROADCAST, mem.asBytes(&@as(u32, @boolToInt(enabled))));
- try self.setReuseAddress(enabled);
- }
-
- /// Set the write buffer size of the socket.
- pub fn setWriteBufferSize(self: Socket, size: u32) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDBUF, mem.asBytes(&size));
- }
-
- /// Set the read buffer size of the socket.
- pub fn setReadBufferSize(self: Socket, size: u32) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVBUF, mem.asBytes(&size));
- }
-
- /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
- /// set on a non-blocking socket.
- ///
- /// Set a timeout on the socket that is to occur if no messages are successfully written
- /// to its bound destination after a specified number of milliseconds. A subsequent write
- /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
- pub fn setWriteTimeout(self: Socket, milliseconds: u32) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.SNDTIMEO, mem.asBytes(&milliseconds));
- }
-
- /// WARNING: Timeouts only affect blocking sockets. It is undefined behavior if a timeout is
- /// set on a non-blocking socket.
- ///
- /// Set a timeout on the socket that is to occur if no messages are successfully read
- /// from its bound destination after a specified number of milliseconds. A subsequent
- /// read from the socket will thereafter return `error.WouldBlock` should the timeout be
- /// exceeded.
- pub fn setReadTimeout(self: Socket, milliseconds: u32) !void {
- return self.setOption(ws2_32.SOL.SOCKET, ws2_32.SO.RCVTIMEO, mem.asBytes(&milliseconds));
- }
- };
-}
diff --git a/src/Compilation.zig b/src/Compilation.zig
@@ -584,7 +584,17 @@ pub const AllErrors = struct {
Message.HashContext,
std.hash_map.default_max_load_percentage,
).init(allocator);
- const err_source = try module_err_msg.src_loc.file_scope.getSource(module.gpa);
+ const err_source = module_err_msg.src_loc.file_scope.getSource(module.gpa) catch |err| {
+ const file_path = try module_err_msg.src_loc.file_scope.fullPath(allocator);
+ try errors.append(.{
+ .plain = .{
+ .msg = try std.fmt.allocPrint(allocator, "unable to load '{s}': {s}", .{
+ file_path, @errorName(err),
+ }),
+ },
+ });
+ return;
+ };
const err_span = try module_err_msg.src_loc.span(module.gpa);
const err_loc = std.zig.findLineColumn(err_source.bytes, err_span.main);