commit 93ab8be8d8464452af6d2e686e91be2c1da98979 (tree)
parent 02c33d02e05f3dd067bc5492d2617b7805ef897d
Author: Andrew Kelley <andrew@ziglang.org>
Date: Fri, 16 Dec 2022 14:08:45 -0700
extract std.crypto.tls.Client into separate namespace
Diffstat:
5 files changed, 1083 insertions(+), 1066 deletions(-)
diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig
@@ -176,7 +176,7 @@ const std = @import("std.zig");
pub const errors = @import("crypto/errors.zig");
-pub const Tls = @import("crypto/Tls.zig");
+pub const tls = @import("crypto/tls.zig");
test {
_ = aead.aegis.Aegis128L;
diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig
@@ -1,1059 +0,0 @@
-const std = @import("../std.zig");
-const Tls = @This();
-const net = std.net;
-const mem = std.mem;
-const crypto = std.crypto;
-const assert = std.debug.assert;
-
-application_cipher: ApplicationCipher,
-read_seq: u64,
-write_seq: u64,
-/// The size is enough to contain exactly one TLSCiphertext record.
-partially_read_buffer: [max_ciphertext_record_len]u8,
-/// The number of partially read bytes inside `partiall_read_buffer`.
-partially_read_len: u15,
-eof: bool,
-
-pub const ciphertext_record_header_len = 5;
-pub const max_ciphertext_len = (1 << 14) + 256;
-pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
-pub const hello_retry_request_sequence = [32]u8{
- 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
- 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
-};
-
-pub const ProtocolVersion = enum(u16) {
- tls_1_2 = 0x0303,
- tls_1_3 = 0x0304,
- _,
-};
-
-pub const ContentType = enum(u8) {
- invalid = 0,
- change_cipher_spec = 20,
- alert = 21,
- handshake = 22,
- application_data = 23,
- _,
-};
-
-pub const HandshakeType = enum(u8) {
- client_hello = 1,
- server_hello = 2,
- new_session_ticket = 4,
- end_of_early_data = 5,
- encrypted_extensions = 8,
- certificate = 11,
- certificate_request = 13,
- certificate_verify = 15,
- finished = 20,
- key_update = 24,
- message_hash = 254,
-};
-
-pub const ExtensionType = enum(u16) {
- /// RFC 6066
- server_name = 0,
- /// RFC 6066
- max_fragment_length = 1,
- /// RFC 6066
- status_request = 5,
- /// RFC 8422, 7919
- supported_groups = 10,
- /// RFC 8446
- signature_algorithms = 13,
- /// RFC 5764
- use_srtp = 14,
- /// RFC 6520
- heartbeat = 15,
- /// RFC 7301
- application_layer_protocol_negotiation = 16,
- /// RFC 6962
- signed_certificate_timestamp = 18,
- /// RFC 7250
- client_certificate_type = 19,
- /// RFC 7250
- server_certificate_type = 20,
- /// RFC 7685
- padding = 21,
- /// RFC 8446
- pre_shared_key = 41,
- /// RFC 8446
- early_data = 42,
- /// RFC 8446
- supported_versions = 43,
- /// RFC 8446
- cookie = 44,
- /// RFC 8446
- psk_key_exchange_modes = 45,
- /// RFC 8446
- certificate_authorities = 47,
- /// RFC 8446
- oid_filters = 48,
- /// RFC 8446
- post_handshake_auth = 49,
- /// RFC 8446
- signature_algorithms_cert = 50,
- /// RFC 8446
- key_share = 51,
-};
-
-pub const AlertLevel = enum(u8) {
- warning = 1,
- fatal = 2,
- _,
-};
-
-pub const AlertDescription = enum(u8) {
- close_notify = 0,
- unexpected_message = 10,
- bad_record_mac = 20,
- record_overflow = 22,
- handshake_failure = 40,
- bad_certificate = 42,
- unsupported_certificate = 43,
- certificate_revoked = 44,
- certificate_expired = 45,
- certificate_unknown = 46,
- illegal_parameter = 47,
- unknown_ca = 48,
- access_denied = 49,
- decode_error = 50,
- decrypt_error = 51,
- protocol_version = 70,
- insufficient_security = 71,
- internal_error = 80,
- inappropriate_fallback = 86,
- user_canceled = 90,
- missing_extension = 109,
- unsupported_extension = 110,
- unrecognized_name = 112,
- bad_certificate_status_response = 113,
- unknown_psk_identity = 115,
- certificate_required = 116,
- no_application_protocol = 120,
- _,
-};
-
-pub const SignatureScheme = enum(u16) {
- // RSASSA-PKCS1-v1_5 algorithms
- rsa_pkcs1_sha256 = 0x0401,
- rsa_pkcs1_sha384 = 0x0501,
- rsa_pkcs1_sha512 = 0x0601,
-
- // ECDSA algorithms
- ecdsa_secp256r1_sha256 = 0x0403,
- ecdsa_secp384r1_sha384 = 0x0503,
- ecdsa_secp521r1_sha512 = 0x0603,
-
- // RSASSA-PSS algorithms with public key OID rsaEncryption
- rsa_pss_rsae_sha256 = 0x0804,
- rsa_pss_rsae_sha384 = 0x0805,
- rsa_pss_rsae_sha512 = 0x0806,
-
- // EdDSA algorithms
- ed25519 = 0x0807,
- ed448 = 0x0808,
-
- // RSASSA-PSS algorithms with public key OID RSASSA-PSS
- rsa_pss_pss_sha256 = 0x0809,
- rsa_pss_pss_sha384 = 0x080a,
- rsa_pss_pss_sha512 = 0x080b,
-
- // Legacy algorithms
- rsa_pkcs1_sha1 = 0x0201,
- ecdsa_sha1 = 0x0203,
-
- _,
-};
-
-pub const NamedGroup = enum(u16) {
- // Elliptic Curve Groups (ECDHE)
- secp256r1 = 0x0017,
- secp384r1 = 0x0018,
- secp521r1 = 0x0019,
- x25519 = 0x001D,
- x448 = 0x001E,
-
- // Finite Field Groups (DHE)
- ffdhe2048 = 0x0100,
- ffdhe3072 = 0x0101,
- ffdhe4096 = 0x0102,
- ffdhe6144 = 0x0103,
- ffdhe8192 = 0x0104,
-
- _,
-};
-
-// Plaintext:
-// * type: ContentType
-// * legacy_record_version: u16 = 0x0303,
-// * length: u16,
-// - The length (in bytes) of the following TLSPlaintext.fragment. The
-// length MUST NOT exceed 2^14 bytes.
-// * fragment: opaque
-// - the data being transmitted
-
-// Ciphertext
-// * ContentType opaque_type = application_data; /* 23 */
-// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
-// * uint16 length;
-// * opaque encrypted_record[TLSCiphertext.length];
-
-// Handshake:
-// * type: HandshakeType
-// * length: u24
-// * data: opaque
-
-// ServerHello:
-// * ProtocolVersion legacy_version = 0x0303;
-// * Random random;
-// * opaque legacy_session_id_echo<0..32>;
-// * CipherSuite cipher_suite;
-// * uint8 legacy_compression_method = 0;
-// * Extension extensions<6..2^16-1>;
-
-// Extension:
-// * ExtensionType extension_type;
-// * opaque extension_data<0..2^16-1>;
-
-pub const CipherSuite = enum(u16) {
- TLS_AES_128_GCM_SHA256 = 0x1301,
- TLS_AES_256_GCM_SHA384 = 0x1302,
- TLS_CHACHA20_POLY1305_SHA256 = 0x1303,
- TLS_AES_128_CCM_SHA256 = 0x1304,
- TLS_AES_128_CCM_8_SHA256 = 0x1305,
-};
-
-pub const CipherParams = union(CipherSuite) {
- TLS_AES_128_GCM_SHA256: struct {
- const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
- const Hash = crypto.hash.sha2.Sha256;
- const Hmac = crypto.auth.hmac.Hmac(Hash);
- const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
-
- handshake_secret: [Hkdf.key_len]u8,
- master_secret: [Hkdf.key_len]u8,
- client_handshake_key: [AEAD.key_length]u8,
- server_handshake_key: [AEAD.key_length]u8,
- client_finished_key: [Hmac.key_length]u8,
- server_finished_key: [Hmac.key_length]u8,
- client_handshake_iv: [AEAD.nonce_length]u8,
- server_handshake_iv: [AEAD.nonce_length]u8,
- transcript_hash: Hash,
- },
- TLS_AES_256_GCM_SHA384: struct {
- const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
- const Hash = crypto.hash.sha2.Sha384;
- const Hmac = crypto.auth.hmac.Hmac(Hash);
- const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
-
- handshake_secret: [Hkdf.key_len]u8,
- master_secret: [Hkdf.key_len]u8,
- client_handshake_key: [AEAD.key_length]u8,
- server_handshake_key: [AEAD.key_length]u8,
- client_finished_key: [Hmac.key_length]u8,
- server_finished_key: [Hmac.key_length]u8,
- client_handshake_iv: [AEAD.nonce_length]u8,
- server_handshake_iv: [AEAD.nonce_length]u8,
- transcript_hash: Hash,
- },
- TLS_CHACHA20_POLY1305_SHA256: void,
- TLS_AES_128_CCM_SHA256: void,
- TLS_AES_128_CCM_8_SHA256: void,
-};
-
-/// Encryption parameters for application traffic.
-pub const ApplicationCipher = union(CipherSuite) {
- TLS_AES_128_GCM_SHA256: struct {
- const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
- const Hash = crypto.hash.sha2.Sha256;
- const Hmac = crypto.auth.hmac.Hmac(Hash);
- const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
-
- client_key: [AEAD.key_length]u8,
- server_key: [AEAD.key_length]u8,
- client_iv: [AEAD.nonce_length]u8,
- server_iv: [AEAD.nonce_length]u8,
- },
- TLS_AES_256_GCM_SHA384: struct {
- const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
- const Hash = crypto.hash.sha2.Sha384;
- const Hmac = crypto.auth.hmac.Hmac(Hash);
- const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
-
- client_key: [AEAD.key_length]u8,
- server_key: [AEAD.key_length]u8,
- client_iv: [AEAD.nonce_length]u8,
- server_iv: [AEAD.nonce_length]u8,
- },
- TLS_CHACHA20_POLY1305_SHA256: void,
- TLS_AES_128_CCM_SHA256: void,
- TLS_AES_128_CCM_8_SHA256: void,
-};
-
-const cipher_suites = blk: {
- const fields = @typeInfo(CipherSuite).Enum.fields;
- var result: [(fields.len + 1) * 2]u8 = undefined;
- mem.writeIntBig(u16, result[0..2], result.len - 2);
- for (fields) |field, i| {
- const int = @enumToInt(@field(CipherSuite, field.name));
- result[(i + 1) * 2] = @truncate(u8, int >> 8);
- result[(i + 1) * 2 + 1] = @truncate(u8, int);
- }
- break :blk result;
-};
-
-/// `host` is only borrowed during this function call.
-pub fn init(stream: net.Stream, host: []const u8) !Tls {
- var x25519_priv_key: [32]u8 = undefined;
- crypto.random.bytes(&x25519_priv_key);
- const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| {
- switch (err) {
- // Only possible to happen if the private key is all zeroes.
- error.IdentityElement => return error.InsufficientEntropy,
- }
- };
-
- // random (u32)
- var rand_buf: [32]u8 = undefined;
- crypto.random.bytes(&rand_buf);
-
- const extensions_header = [_]u8{
- // Extensions byte length
- undefined, undefined,
-
- // Extension: supported_versions (only TLS 1.3)
- 0, 43, // ExtensionType.supported_versions
- 0x00, 0x05, // byte length of this extension payload
- 0x04, // byte length of supported versions
- 0x03, 0x04, // TLS 1.3
- 0x03, 0x03, // TLS 1.2
-
- // Extension: signature_algorithms
- 0, 13, // ExtensionType.signature_algorithms
- 0x00, 0x22, // byte length of this extension payload
- 0x00, 0x20, // byte length of signature algorithms list
- 0x04, 0x01, // rsa_pkcs1_sha256
- 0x05, 0x01, // rsa_pkcs1_sha384
- 0x06, 0x01, // rsa_pkcs1_sha512
- 0x04, 0x03, // ecdsa_secp256r1_sha256
- 0x05, 0x03, // ecdsa_secp384r1_sha384
- 0x06, 0x03, // ecdsa_secp521r1_sha512
- 0x08, 0x04, // rsa_pss_rsae_sha256
- 0x08, 0x05, // rsa_pss_rsae_sha384
- 0x08, 0x06, // rsa_pss_rsae_sha512
- 0x08, 0x07, // ed25519
- 0x08, 0x08, // ed448
- 0x08, 0x09, // rsa_pss_pss_sha256
- 0x08, 0x0a, // rsa_pss_pss_sha384
- 0x08, 0x0b, // rsa_pss_pss_sha512
- 0x02, 0x01, // rsa_pkcs1_sha1
- 0x02, 0x03, // ecdsa_sha1
-
- // Extension: supported_groups
- 0, 10, // ExtensionType.supported_groups
- 0x00, 0x0c, // byte length of this extension payload
- 0x00, 0x0a, // byte length of supported groups list
- 0x00, 0x17, // secp256r1
- 0x00, 0x18, // secp384r1
- 0x00, 0x19, // secp521r1
- 0x00, 0x1D, // x25519
- 0x00, 0x1E, // x448
-
- // Extension: key_share
- 0, 51, // ExtensionType.key_share
- 0, 38, // byte length of this extension payload
- 0, 36, // byte length of client_shares
- 0x00, 0x1D, // NamedGroup.x25519
- 0, 32, // byte length of key_exchange
- } ++ x25519_pub_key ++ [_]u8{
-
- // Extension: server_name
- 0, 0, // ExtensionType.server_name
- undefined, undefined, // byte length of this extension payload
- undefined, undefined, // server_name_list byte count
- 0x00, // name_type
- undefined, undefined, // host name len
- };
-
- var hello_header = [_]u8{
- // Plaintext header
- @enumToInt(ContentType.handshake),
- 0x03, 0x01, // legacy_record_version
- undefined, undefined, // Plaintext fragment length (u16)
-
- // Handshake header
- @enumToInt(HandshakeType.client_hello),
- undefined, undefined, undefined, // handshake length (u24)
-
- // ClientHello
- 0x03, 0x03, // legacy_version
- } ++ rand_buf ++ [1]u8{0} ++ cipher_suites ++ [_]u8{
- 0x01, 0x00, // legacy_compression_methods
- } ++ extensions_header;
-
- mem.writeIntBig(u16, hello_header[3..][0..2], @intCast(u16, hello_header.len - 5 + host.len));
- mem.writeIntBig(u24, hello_header[6..][0..3], @intCast(u24, hello_header.len - 9 + host.len));
- mem.writeIntBig(
- u16,
- hello_header[hello_header.len - extensions_header.len ..][0..2],
- @intCast(u16, extensions_header.len - 2 + host.len),
- );
- mem.writeIntBig(u16, hello_header[hello_header.len - 7 ..][0..2], @intCast(u16, 5 + host.len));
- mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len));
- mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len));
-
- {
- var iovecs = [_]std.os.iovec_const{
- .{
- .iov_base = &hello_header,
- .iov_len = hello_header.len,
- },
- .{
- .iov_base = host.ptr,
- .iov_len = host.len,
- },
- };
- try stream.writevAll(&iovecs);
- }
-
- const client_hello_bytes1 = hello_header[5..];
-
- var cipher_params: CipherParams = undefined;
-
- var handshake_buf: [8000]u8 = undefined;
- var len: usize = 0;
- var i: usize = i: {
- const plaintext = handshake_buf[0..5];
- len = try stream.readAtLeast(&handshake_buf, plaintext.len);
- if (len < plaintext.len) return error.EndOfStream;
- const ct = @intToEnum(ContentType, plaintext[0]);
- const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
- const end = plaintext.len + frag_len;
- if (end > handshake_buf.len) return error.TlsRecordOverflow;
- if (end > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end - len);
- if (end > len) return error.EndOfStream;
- }
- const frag = handshake_buf[plaintext.len..end];
-
- switch (ct) {
- .alert => {
- const level = @intToEnum(AlertLevel, frag[0]);
- const desc = @intToEnum(AlertDescription, frag[1]);
- std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
- return error.TlsAlert;
- },
- .handshake => {
- if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
- return error.TlsUnexpectedMessage;
- }
- const length = mem.readIntBig(u24, frag[1..4]);
- if (4 + length != frag.len) return error.TlsBadLength;
- const hello = frag[4..];
- const legacy_version = mem.readIntBig(u16, hello[0..2]);
- const random = hello[2..34].*;
- if (mem.eql(u8, &random, &hello_retry_request_sequence)) {
- @panic("TODO handle HelloRetryRequest");
- }
- const legacy_session_id_echo_len = hello[34];
- if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
- const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
- const cipher_suite_tag = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
- return error.TlsIllegalParameter;
- std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)});
- const legacy_compression_method = hello[37];
- _ = legacy_compression_method;
- const extensions_size = mem.readIntBig(u16, hello[38..40]);
- if (40 + extensions_size != hello.len) return error.TlsBadLength;
- var i: usize = 40;
- var supported_version: u16 = 0;
- var opt_x25519_server_pub_key: ?*[32]u8 = null;
- while (i < hello.len) {
- const et = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- const next_i = i + ext_size;
- if (next_i > hello.len) return error.TlsBadLength;
- switch (et) {
- @enumToInt(ExtensionType.supported_versions) => {
- if (supported_version != 0) return error.TlsIllegalParameter;
- supported_version = mem.readIntBig(u16, hello[i..][0..2]);
- },
- @enumToInt(ExtensionType.key_share) => {
- if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter;
- const named_group = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- switch (named_group) {
- @enumToInt(NamedGroup.x25519) => {
- const key_size = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- if (key_size != 32) return error.TlsBadLength;
- opt_x25519_server_pub_key = hello[i..][0..32];
- },
- else => {
- std.debug.print("named group: {x}\n", .{named_group});
- return error.TlsIllegalParameter;
- },
- }
- },
- else => {
- std.debug.print("unexpected extension: {x}\n", .{et});
- },
- }
- i = next_i;
- }
- const x25519_server_pub_key = opt_x25519_server_pub_key orelse
- return error.TlsIllegalParameter;
- const tls_version = if (supported_version == 0) legacy_version else supported_version;
- switch (tls_version) {
- @enumToInt(ProtocolVersion.tls_1_2) => {
- std.debug.print("server wants TLS v1.2\n", .{});
- },
- @enumToInt(ProtocolVersion.tls_1_3) => {
- std.debug.print("server wants TLS v1.3\n", .{});
- },
- else => return error.TlsIllegalParameter,
- }
-
- const shared_key = crypto.dh.X25519.scalarmult(
- x25519_priv_key,
- x25519_server_pub_key.*,
- ) catch return error.TlsDecryptFailure;
-
- switch (cipher_suite_tag) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| {
- const P = std.meta.TagPayload(CipherParams, tag);
- cipher_params = @unionInit(CipherParams, @tagName(tag), .{
- .handshake_secret = undefined,
- .master_secret = undefined,
- .client_handshake_key = undefined,
- .server_handshake_key = undefined,
- .client_finished_key = undefined,
- .server_finished_key = undefined,
- .client_handshake_iv = undefined,
- .server_handshake_iv = undefined,
- .transcript_hash = P.Hash.init(.{}),
- });
- const p = &@field(cipher_params, @tagName(tag));
- p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
- p.transcript_hash.update(host); // Client Hello part 2
- p.transcript_hash.update(frag); // Server Hello
- const hello_hash = p.transcript_hash.peek();
- const zeroes = [1]u8{0} ** P.Hash.digest_length;
- const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
- const empty_hash = emptyHash(P.Hash);
- const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
- p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key);
- const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
- p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
- const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
- const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
- p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
- p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
- p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
- p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
- p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
- p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
- //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\nclient_handshake_iv: {}\nserver_handshake_iv: {}\n", .{
- // std.fmt.fmtSliceHexLower(&shared_key),
- // std.fmt.fmtSliceHexLower(&hello_hash),
- // std.fmt.fmtSliceHexLower(&early_secret),
- // std.fmt.fmtSliceHexLower(&empty_hash),
- // std.fmt.fmtSliceHexLower(&hs_derived_secret),
- // std.fmt.fmtSliceHexLower(&p.handshake_secret),
- // std.fmt.fmtSliceHexLower(&client_secret),
- // std.fmt.fmtSliceHexLower(&server_secret),
- // std.fmt.fmtSliceHexLower(&p.client_handshake_iv),
- // std.fmt.fmtSliceHexLower(&p.server_handshake_iv),
- //});
- },
- .TLS_CHACHA20_POLY1305_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_8_SHA256 => {
- @panic("TODO");
- },
- }
- },
- else => return error.TlsUnexpectedMessage,
- }
- break :i end;
- };
-
- var read_seq: u64 = 0;
-
- while (true) {
- const end_hdr = i + 5;
- if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
- if (end_hdr > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
- if (end_hdr > len) return error.EndOfStream;
- }
- const ct = @intToEnum(ContentType, handshake_buf[i]);
- i += 1;
- const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
- i += 2;
- _ = legacy_version;
- const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
- i += 2;
- const end = i + record_size;
- if (end > handshake_buf.len) return error.TlsRecordOverflow;
- if (end > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end - len);
- if (end > len) return error.EndOfStream;
- }
- switch (ct) {
- .change_cipher_spec => {
- if (record_size != 1) return error.TlsUnexpectedMessage;
- if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
- },
- .application_data => {
- var cleartext_buf: [8000]u8 = undefined;
- const cleartext = switch (cipher_params) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
- const P = @TypeOf(p.*);
- const ciphertext_len = record_size - P.AEAD.tag_length;
- const ciphertext = handshake_buf[i..][0..ciphertext_len];
- i += ciphertext.len;
- if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
- const cleartext = cleartext_buf[0..ciphertext.len];
- const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*;
- const V = @Vector(P.AEAD.nonce_length, u8);
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
- read_seq += 1;
- const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_handshake_iv) ^ operand;
- const ad = handshake_buf[end_hdr - 5 ..][0..5];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
- return error.TlsBadRecordMac;
- p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]);
- break :c cleartext;
- },
- .TLS_CHACHA20_POLY1305_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_8_SHA256 => {
- @panic("TODO");
- },
- };
-
- const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
- switch (inner_ct) {
- .handshake => {
- var ct_i: usize = 0;
- while (true) {
- const handshake_type = cleartext[ct_i];
- ct_i += 1;
- const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
- ct_i += 3;
- const next_handshake_i = ct_i + handshake_len;
- if (next_handshake_i > cleartext.len - 1)
- return error.TlsBadLength;
- switch (handshake_type) {
- @enumToInt(HandshakeType.encrypted_extensions) => {
- const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
- ct_i += 2;
- const end_ext_i = ct_i + total_ext_size;
- while (ct_i < end_ext_i) {
- const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
- ct_i += 2;
- const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
- ct_i += 2;
- const next_ext_i = ct_i + ext_size;
- switch (et) {
- @enumToInt(ExtensionType.server_name) => {},
- else => {
- std.debug.print("encrypted extension: {any}\n", .{
- et,
- });
- },
- }
- ct_i = next_ext_i;
- }
- },
- @enumToInt(HandshakeType.certificate) => {
- std.debug.print("cool certificate bro\n", .{});
- },
- @enumToInt(HandshakeType.certificate_verify) => {
- std.debug.print("the certificate came with a fancy signature\n", .{});
- },
- @enumToInt(HandshakeType.finished) => {
- // This message is to trick buggy proxies into behaving correctly.
- const client_change_cipher_spec_msg = [_]u8{
- @enumToInt(ContentType.change_cipher_spec),
- 0x03, 0x03, // legacy protocol version
- 0x00, 0x01, // length
- 0x01,
- };
- const app_cipher = switch (cipher_params) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: {
- const P = @TypeOf(p.*);
- // TODO verify the server's data
- const handshake_hash = p.transcript_hash.finalResult();
- const verify_data = hmac(P.Hmac, &handshake_hash, p.client_finished_key);
- const out_cleartext = [_]u8{
- @enumToInt(HandshakeType.finished),
- 0, 0, verify_data.len, // length
- } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
-
- const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
-
- var finished_msg = [_]u8{
- @enumToInt(ContentType.application_data),
- 0x03, 0x03, // legacy protocol version
- 0, wrapped_len, // byte length of encrypted record
- } ++ ([1]u8{undefined} ** wrapped_len);
-
- const ad = finished_msg[0..5];
- const ciphertext = finished_msg[5..][0..out_cleartext.len];
- const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
- const nonce = p.client_handshake_iv;
- P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
-
- const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
- try stream.writeAll(&both_msgs);
-
- const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
- const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
- //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
- // std.fmt.fmtSliceHexLower(&p.master_secret),
- // std.fmt.fmtSliceHexLower(&client_secret),
- // std.fmt.fmtSliceHexLower(&server_secret),
- //});
- break :c @unionInit(ApplicationCipher, @tagName(tag), .{
- .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
- .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
- .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
- .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
- });
- },
- .TLS_CHACHA20_POLY1305_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_8_SHA256 => {
- @panic("TODO");
- },
- };
- std.debug.print("remaining bytes: {d}\n", .{len - end});
- return .{
- .application_cipher = app_cipher,
- .read_seq = 0,
- .write_seq = 0,
- .partially_read_buffer = undefined,
- .partially_read_len = 0,
- .eof = false,
- };
- },
- else => {
- std.debug.print("handshake type: {d}\n", .{cleartext[0]});
- return error.TlsUnexpectedMessage;
- },
- }
- ct_i = next_handshake_i;
- if (ct_i >= cleartext.len - 1) break;
- }
- },
- else => {
- std.debug.print("inner content type: {any}\n", .{inner_ct});
- return error.TlsUnexpectedMessage;
- },
- }
- },
- else => {
- std.debug.print("content type: {s}\n", .{@tagName(ct)});
- return error.TlsUnexpectedMessage;
- },
- }
- i = end;
- }
-
- return error.TlsHandshakeFailure;
-}
-
-pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
- var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined;
- // Due to the trailing inner content type byte in the ciphertext, we need
- // an additional buffer for storing the cleartext into before encrypting.
- var cleartext_buf: [max_ciphertext_len]u8 = undefined;
- var iovecs_buf: [5]std.os.iovec_const = undefined;
- var ciphertext_end: usize = 0;
- var iovec_end: usize = 0;
- var bytes_i: usize = 0;
- // How many bytes are taken up by overhead per record.
- const overhead_len: usize = switch (tls.application_cipher) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
- const P = @TypeOf(p.*);
- const V = @Vector(P.AEAD.nonce_length, u8);
- const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1;
- while (true) {
- const encrypted_content_len = @intCast(u16, @min(
- @min(bytes.len - bytes_i, max_ciphertext_len - 1),
- ciphertext_buf.len -
- ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
- ));
- if (encrypted_content_len == 0) break :l overhead_len;
-
- mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
- cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
- bytes_i += encrypted_content_len;
- const ciphertext_len = encrypted_content_len + 1;
- const cleartext = cleartext_buf[0..ciphertext_len];
-
- const record_start = ciphertext_end;
- const ad = ciphertext_buf[ciphertext_end..][0..5];
- ad.* =
- [_]u8{@enumToInt(ContentType.application_data)} ++
- int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
- int2(ciphertext_len + P.AEAD.tag_length);
- ciphertext_end += ad.len;
- const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
- ciphertext_end += ciphertext_len;
- const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
- ciphertext_end += auth_tag.len;
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq));
- tls.write_seq += 1;
- const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
- P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
- //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{
- // tls.write_seq - 1,
- // std.fmt.fmtSliceHexLower(&nonce),
- // std.fmt.fmtSliceHexLower(&p.client_key),
- // std.fmt.fmtSliceHexLower(&p.client_iv),
- // std.fmt.fmtSliceHexLower(ad),
- // std.fmt.fmtSliceHexLower(auth_tag),
- // std.fmt.fmtSliceHexLower(&p.server_key),
- // std.fmt.fmtSliceHexLower(&p.server_iv),
- //});
-
- const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs_buf[iovec_end] = .{
- .iov_base = record.ptr,
- .iov_len = record.len,
- };
- iovec_end += 1;
- }
- },
- .TLS_CHACHA20_POLY1305_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_8_SHA256 => {
- @panic("TODO");
- },
- };
-
- // Ideally we would call writev exactly once here, however, we must ensure
- // that we don't return with a record partially written.
- var i: usize = 0;
- var total_amt: usize = 0;
- while (true) {
- var amt = try stream.writev(iovecs_buf[i..iovec_end]);
- while (amt >= iovecs_buf[i].iov_len) {
- const encrypted_amt = iovecs_buf[i].iov_len;
- total_amt += encrypted_amt - overhead_len;
- amt -= encrypted_amt;
- i += 1;
- // Rely on the property that iovecs delineate records, meaning that
- // if amt equals zero here, we have fortunately found ourselves
- // with a short read that aligns at the record boundary.
- if (i >= iovec_end or amt == 0) return total_amt;
- }
- iovecs_buf[i].iov_base += amt;
- iovecs_buf[i].iov_len -= amt;
- }
-}
-
-pub fn writeAll(tls: *Tls, stream: net.Stream, bytes: []const u8) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try tls.write(stream, bytes[index..]);
- }
-}
-
-/// Returns number of bytes that have been read, which are now populated inside
-/// `buffer`. A return value of zero bytes does not necessarily mean end of
-/// stream.
-pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
- const prev_len = tls.partially_read_len;
- var in_buf: [max_ciphertext_len * 4]u8 = undefined;
- mem.copy(u8, &in_buf, tls.partially_read_buffer[0..prev_len]);
-
- // Capacity of output buffer, in records, rounded up.
- const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
- const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len);
- const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)];
- const actual_read_len = try stream.read(ask_slice);
- const frag = in_buf[0 .. prev_len + actual_read_len];
- if (frag.len == 0) {
- tls.eof = true;
- return 0;
- }
- var in: usize = 0;
- var out: usize = 0;
-
- while (true) {
- if (in + ciphertext_record_header_len > frag.len) {
- return finishRead(tls, frag, in, out);
- }
- const ct = @intToEnum(ContentType, frag[in]);
- in += 1;
- const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
- in += 2;
- _ = legacy_version;
- const record_size = mem.readIntBig(u16, frag[in..][0..2]);
- in += 2;
- const end = in + record_size;
- if (end > frag.len) {
- if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
- return finishRead(tls, frag, in, out);
- }
- switch (ct) {
- .alert => {
- @panic("TODO handle an alert here");
- },
- .application_data => {
- const cleartext_len = switch (tls.application_cipher) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
- const P = @TypeOf(p.*);
- const V = @Vector(P.AEAD.nonce_length, u8);
- const ad = frag[in - 5 ..][0..5];
- const ciphertext_len = record_size - P.AEAD.tag_length;
- const ciphertext = frag[in..][0..ciphertext_len];
- in += ciphertext_len;
- const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
- const cleartext = buffer[out..][0..ciphertext_len];
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq));
- tls.read_seq += 1;
- const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
- //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
- // tls.read_seq - 1,
- // std.fmt.fmtSliceHexLower(&nonce),
- // std.fmt.fmtSliceHexLower(&p.server_key),
- // std.fmt.fmtSliceHexLower(&p.server_iv),
- //});
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
- return error.TlsBadRecordMac;
- break :c cleartext.len;
- },
- .TLS_CHACHA20_POLY1305_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_SHA256 => {
- @panic("TODO");
- },
- .TLS_AES_128_CCM_8_SHA256 => {
- @panic("TODO");
- },
- };
-
- const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
- switch (inner_ct) {
- .alert => {
- const level = @intToEnum(AlertLevel, buffer[out]);
- const desc = @intToEnum(AlertDescription, buffer[out + 1]);
- if (desc == .close_notify) {
- tls.eof = true;
- return out;
- }
- std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
- return error.TlsAlert;
- },
- .handshake => {
- std.debug.print("the server wants to keep shaking hands\n", .{});
- },
- .application_data => {
- out += cleartext_len - 1;
- },
- else => {
- std.debug.print("inner content type: {d}\n", .{inner_ct});
- return error.TlsUnexpectedMessage;
- },
- }
- },
- else => {
- std.debug.print("unexpected ct: {any}\n", .{ct});
- return error.TlsUnexpectedMessage;
- },
- }
- in = end;
- }
-}
-
-fn finishRead(tls: *Tls, frag: []const u8, in: usize, out: usize) usize {
- const saved_buf = frag[in..];
- mem.copy(u8, &tls.partially_read_buffer, saved_buf);
- tls.partially_read_len = @intCast(u15, saved_buf.len);
- return out;
-}
-
-fn hkdfExpandLabel(
- comptime Hkdf: type,
- key: [Hkdf.prk_length]u8,
- label: []const u8,
- context: []const u8,
- comptime len: usize,
-) [len]u8 {
- const max_label_len = 255;
- const max_context_len = 255;
- const tls13 = "tls13 ";
- var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
- mem.writeIntBig(u16, buf[0..2], len);
- buf[2] = @intCast(u8, tls13.len + label.len);
- buf[3..][0..tls13.len].* = tls13.*;
- var i: usize = 3 + tls13.len;
- mem.copy(u8, buf[i..], label);
- i += label.len;
- buf[i] = @intCast(u8, context.len);
- i += 1;
- mem.copy(u8, buf[i..], context);
- i += context.len;
-
- var result: [len]u8 = undefined;
- Hkdf.expand(&result, buf[0..i], key);
- return result;
-}
-
-fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
- var result: [Hash.digest_length]u8 = undefined;
- Hash.hash(&.{}, &result, .{});
- return result;
-}
-
-fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 {
- var result: [Hmac.mac_length]u8 = undefined;
- Hmac.create(&result, message, &key);
- return result;
-}
-
-const builtin = @import("builtin");
-const native_endian = builtin.cpu.arch.endian();
-
-inline fn big(x: anytype) @TypeOf(x) {
- return switch (native_endian) {
- .Big => x,
- .Little => @byteSwap(x),
- };
-}
-
-inline fn int2(x: u16) [2]u8 {
- return .{
- @truncate(u8, x >> 8),
- @truncate(u8, x),
- };
-}
diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig
@@ -0,0 +1,325 @@
+//! Plaintext:
+//! * type: ContentType
+//! * legacy_record_version: u16 = 0x0303,
+//! * length: u16,
+//! - The length (in bytes) of the following TLSPlaintext.fragment. The
+//! length MUST NOT exceed 2^14 bytes.
+//! * fragment: opaque
+//! - the data being transmitted
+//!
+//! Ciphertext
+//! * ContentType opaque_type = application_data; /* 23 */
+//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
+//! * uint16 length;
+//! * opaque encrypted_record[TLSCiphertext.length];
+//!
+//! Handshake:
+//! * type: HandshakeType
+//! * length: u24
+//! * data: opaque
+//!
+//! ServerHello:
+//! * ProtocolVersion legacy_version = 0x0303;
+//! * Random random;
+//! * opaque legacy_session_id_echo<0..32>;
+//! * CipherSuite cipher_suite;
+//! * uint8 legacy_compression_method = 0;
+//! * Extension extensions<6..2^16-1>;
+//!
+//! Extension:
+//! * ExtensionType extension_type;
+//! * opaque extension_data<0..2^16-1>;
+
+const std = @import("../std.zig");
+const Tls = @This();
+const net = std.net;
+const mem = std.mem;
+const crypto = std.crypto;
+const assert = std.debug.assert;
+
+pub const Client = @import("tls/Client.zig");
+
+pub const ciphertext_record_header_len = 5;
+pub const max_ciphertext_len = (1 << 14) + 256;
+pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
+pub const hello_retry_request_sequence = [32]u8{
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
+ 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
+};
+
+pub const ProtocolVersion = enum(u16) {
+ tls_1_2 = 0x0303,
+ tls_1_3 = 0x0304,
+ _,
+};
+
+pub const ContentType = enum(u8) {
+ invalid = 0,
+ change_cipher_spec = 20,
+ alert = 21,
+ handshake = 22,
+ application_data = 23,
+ _,
+};
+
+pub const HandshakeType = enum(u8) {
+ client_hello = 1,
+ server_hello = 2,
+ new_session_ticket = 4,
+ end_of_early_data = 5,
+ encrypted_extensions = 8,
+ certificate = 11,
+ certificate_request = 13,
+ certificate_verify = 15,
+ finished = 20,
+ key_update = 24,
+ message_hash = 254,
+};
+
+pub const ExtensionType = enum(u16) {
+ /// RFC 6066
+ server_name = 0,
+ /// RFC 6066
+ max_fragment_length = 1,
+ /// RFC 6066
+ status_request = 5,
+ /// RFC 8422, 7919
+ supported_groups = 10,
+ /// RFC 8446
+ signature_algorithms = 13,
+ /// RFC 5764
+ use_srtp = 14,
+ /// RFC 6520
+ heartbeat = 15,
+ /// RFC 7301
+ application_layer_protocol_negotiation = 16,
+ /// RFC 6962
+ signed_certificate_timestamp = 18,
+ /// RFC 7250
+ client_certificate_type = 19,
+ /// RFC 7250
+ server_certificate_type = 20,
+ /// RFC 7685
+ padding = 21,
+ /// RFC 8446
+ pre_shared_key = 41,
+ /// RFC 8446
+ early_data = 42,
+ /// RFC 8446
+ supported_versions = 43,
+ /// RFC 8446
+ cookie = 44,
+ /// RFC 8446
+ psk_key_exchange_modes = 45,
+ /// RFC 8446
+ certificate_authorities = 47,
+ /// RFC 8446
+ oid_filters = 48,
+ /// RFC 8446
+ post_handshake_auth = 49,
+ /// RFC 8446
+ signature_algorithms_cert = 50,
+ /// RFC 8446
+ key_share = 51,
+};
+
+pub const AlertLevel = enum(u8) {
+ warning = 1,
+ fatal = 2,
+ _,
+};
+
+pub const AlertDescription = enum(u8) {
+ close_notify = 0,
+ unexpected_message = 10,
+ bad_record_mac = 20,
+ record_overflow = 22,
+ handshake_failure = 40,
+ bad_certificate = 42,
+ unsupported_certificate = 43,
+ certificate_revoked = 44,
+ certificate_expired = 45,
+ certificate_unknown = 46,
+ illegal_parameter = 47,
+ unknown_ca = 48,
+ access_denied = 49,
+ decode_error = 50,
+ decrypt_error = 51,
+ protocol_version = 70,
+ insufficient_security = 71,
+ internal_error = 80,
+ inappropriate_fallback = 86,
+ user_canceled = 90,
+ missing_extension = 109,
+ unsupported_extension = 110,
+ unrecognized_name = 112,
+ bad_certificate_status_response = 113,
+ unknown_psk_identity = 115,
+ certificate_required = 116,
+ no_application_protocol = 120,
+ _,
+};
+
+pub const SignatureScheme = enum(u16) {
+ // RSASSA-PKCS1-v1_5 algorithms
+ rsa_pkcs1_sha256 = 0x0401,
+ rsa_pkcs1_sha384 = 0x0501,
+ rsa_pkcs1_sha512 = 0x0601,
+
+ // ECDSA algorithms
+ ecdsa_secp256r1_sha256 = 0x0403,
+ ecdsa_secp384r1_sha384 = 0x0503,
+ ecdsa_secp521r1_sha512 = 0x0603,
+
+ // RSASSA-PSS algorithms with public key OID rsaEncryption
+ rsa_pss_rsae_sha256 = 0x0804,
+ rsa_pss_rsae_sha384 = 0x0805,
+ rsa_pss_rsae_sha512 = 0x0806,
+
+ // EdDSA algorithms
+ ed25519 = 0x0807,
+ ed448 = 0x0808,
+
+ // RSASSA-PSS algorithms with public key OID RSASSA-PSS
+ rsa_pss_pss_sha256 = 0x0809,
+ rsa_pss_pss_sha384 = 0x080a,
+ rsa_pss_pss_sha512 = 0x080b,
+
+ // Legacy algorithms
+ rsa_pkcs1_sha1 = 0x0201,
+ ecdsa_sha1 = 0x0203,
+
+ _,
+};
+
+pub const NamedGroup = enum(u16) {
+ // Elliptic Curve Groups (ECDHE)
+ secp256r1 = 0x0017,
+ secp384r1 = 0x0018,
+ secp521r1 = 0x0019,
+ x25519 = 0x001D,
+ x448 = 0x001E,
+
+ // Finite Field Groups (DHE)
+ ffdhe2048 = 0x0100,
+ ffdhe3072 = 0x0101,
+ ffdhe4096 = 0x0102,
+ ffdhe6144 = 0x0103,
+ ffdhe8192 = 0x0104,
+
+ _,
+};
+
+pub const CipherSuite = enum(u16) {
+ TLS_AES_128_GCM_SHA256 = 0x1301,
+ TLS_AES_256_GCM_SHA384 = 0x1302,
+ TLS_CHACHA20_POLY1305_SHA256 = 0x1303,
+ TLS_AES_128_CCM_SHA256 = 0x1304,
+ TLS_AES_128_CCM_8_SHA256 = 0x1305,
+};
+
+pub const CipherParams = union(CipherSuite) {
+ TLS_AES_128_GCM_SHA256: struct {
+ pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+ pub const Hash = crypto.hash.sha2.Sha256;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ handshake_secret: [Hkdf.prk_length]u8,
+ master_secret: [Hkdf.prk_length]u8,
+ client_handshake_key: [AEAD.key_length]u8,
+ server_handshake_key: [AEAD.key_length]u8,
+ client_finished_key: [Hmac.key_length]u8,
+ server_finished_key: [Hmac.key_length]u8,
+ client_handshake_iv: [AEAD.nonce_length]u8,
+ server_handshake_iv: [AEAD.nonce_length]u8,
+ transcript_hash: Hash,
+ },
+ TLS_AES_256_GCM_SHA384: struct {
+ pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+ pub const Hash = crypto.hash.sha2.Sha384;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ handshake_secret: [Hkdf.prk_length]u8,
+ master_secret: [Hkdf.prk_length]u8,
+ client_handshake_key: [AEAD.key_length]u8,
+ server_handshake_key: [AEAD.key_length]u8,
+ client_finished_key: [Hmac.key_length]u8,
+ server_finished_key: [Hmac.key_length]u8,
+ client_handshake_iv: [AEAD.nonce_length]u8,
+ server_handshake_iv: [AEAD.nonce_length]u8,
+ transcript_hash: Hash,
+ },
+ TLS_CHACHA20_POLY1305_SHA256: void,
+ TLS_AES_128_CCM_SHA256: void,
+ TLS_AES_128_CCM_8_SHA256: void,
+};
+
+/// Encryption parameters for application traffic.
+pub const ApplicationCipher = union(CipherSuite) {
+ TLS_AES_128_GCM_SHA256: struct {
+ pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+ pub const Hash = crypto.hash.sha2.Sha256;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ client_key: [AEAD.key_length]u8,
+ server_key: [AEAD.key_length]u8,
+ client_iv: [AEAD.nonce_length]u8,
+ server_iv: [AEAD.nonce_length]u8,
+ },
+ TLS_AES_256_GCM_SHA384: struct {
+ pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+ pub const Hash = crypto.hash.sha2.Sha384;
+ pub const Hmac = crypto.auth.hmac.Hmac(Hash);
+ pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ client_key: [AEAD.key_length]u8,
+ server_key: [AEAD.key_length]u8,
+ client_iv: [AEAD.nonce_length]u8,
+ server_iv: [AEAD.nonce_length]u8,
+ },
+ TLS_CHACHA20_POLY1305_SHA256: void,
+ TLS_AES_128_CCM_SHA256: void,
+ TLS_AES_128_CCM_8_SHA256: void,
+};
+
+pub fn hkdfExpandLabel(
+ comptime Hkdf: type,
+ key: [Hkdf.prk_length]u8,
+ label: []const u8,
+ context: []const u8,
+ comptime len: usize,
+) [len]u8 {
+ const max_label_len = 255;
+ const max_context_len = 255;
+ const tls13 = "tls13 ";
+ var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
+ mem.writeIntBig(u16, buf[0..2], len);
+ buf[2] = @intCast(u8, tls13.len + label.len);
+ buf[3..][0..tls13.len].* = tls13.*;
+ var i: usize = 3 + tls13.len;
+ mem.copy(u8, buf[i..], label);
+ i += label.len;
+ buf[i] = @intCast(u8, context.len);
+ i += 1;
+ mem.copy(u8, buf[i..], context);
+ i += context.len;
+
+ var result: [len]u8 = undefined;
+ Hkdf.expand(&result, buf[0..i], key);
+ return result;
+}
+
+pub fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
+ var result: [Hash.digest_length]u8 = undefined;
+ Hash.hash(&.{}, &result, .{});
+ return result;
+}
+
+pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) [Hmac.mac_length]u8 {
+ var result: [Hmac.mac_length]u8 = undefined;
+ Hmac.create(&result, message, &key);
+ return result;
+}
diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig
@@ -0,0 +1,751 @@
+const std = @import("../../std.zig");
+const tls = std.crypto.tls;
+const Client = @This();
+const net = std.net;
+const mem = std.mem;
+const crypto = std.crypto;
+const assert = std.debug.assert;
+
+const ApplicationCipher = tls.ApplicationCipher;
+const CipherSuite = tls.CipherSuite;
+const ContentType = tls.ContentType;
+const HandshakeType = tls.HandshakeType;
+const CipherParams = tls.CipherParams;
+const max_ciphertext_len = tls.max_ciphertext_len;
+const hkdfExpandLabel = tls.hkdfExpandLabel;
+
+application_cipher: ApplicationCipher,
+read_seq: u64,
+write_seq: u64,
+/// The size is enough to contain exactly one TLSCiphertext record.
+partially_read_buffer: [tls.max_ciphertext_record_len]u8,
+/// The number of partially read bytes inside `partiall_read_buffer`.
+partially_read_len: u15,
+eof: bool,
+
+const cipher_suites = blk: {
+ const fields = @typeInfo(CipherSuite).Enum.fields;
+ var result: [(fields.len + 1) * 2]u8 = undefined;
+ mem.writeIntBig(u16, result[0..2], result.len - 2);
+ for (fields) |field, i| {
+ const int = @enumToInt(@field(CipherSuite, field.name));
+ result[(i + 1) * 2] = @truncate(u8, int >> 8);
+ result[(i + 1) * 2 + 1] = @truncate(u8, int);
+ }
+ break :blk result;
+};
+
+/// `host` is only borrowed during this function call.
+pub fn init(stream: net.Stream, host: []const u8) !Client {
+ var x25519_priv_key: [32]u8 = undefined;
+ crypto.random.bytes(&x25519_priv_key);
+ const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| {
+ switch (err) {
+ // Only possible to happen if the private key is all zeroes.
+ error.IdentityElement => return error.InsufficientEntropy,
+ }
+ };
+
+ // random (u32)
+ var rand_buf: [32]u8 = undefined;
+ crypto.random.bytes(&rand_buf);
+
+ const extensions_header = [_]u8{
+ // Extensions byte length
+ undefined, undefined,
+
+ // Extension: supported_versions (only TLS 1.3)
+ 0, 43, // ExtensionType.supported_versions
+ 0x00, 0x05, // byte length of this extension payload
+ 0x04, // byte length of supported versions
+ 0x03, 0x04, // TLS 1.3
+ 0x03, 0x03, // TLS 1.2
+
+ // Extension: signature_algorithms
+ 0, 13, // ExtensionType.signature_algorithms
+ 0x00, 0x22, // byte length of this extension payload
+ 0x00, 0x20, // byte length of signature algorithms list
+ 0x04, 0x01, // rsa_pkcs1_sha256
+ 0x05, 0x01, // rsa_pkcs1_sha384
+ 0x06, 0x01, // rsa_pkcs1_sha512
+ 0x04, 0x03, // ecdsa_secp256r1_sha256
+ 0x05, 0x03, // ecdsa_secp384r1_sha384
+ 0x06, 0x03, // ecdsa_secp521r1_sha512
+ 0x08, 0x04, // rsa_pss_rsae_sha256
+ 0x08, 0x05, // rsa_pss_rsae_sha384
+ 0x08, 0x06, // rsa_pss_rsae_sha512
+ 0x08, 0x07, // ed25519
+ 0x08, 0x08, // ed448
+ 0x08, 0x09, // rsa_pss_pss_sha256
+ 0x08, 0x0a, // rsa_pss_pss_sha384
+ 0x08, 0x0b, // rsa_pss_pss_sha512
+ 0x02, 0x01, // rsa_pkcs1_sha1
+ 0x02, 0x03, // ecdsa_sha1
+
+ // Extension: supported_groups
+ 0, 10, // ExtensionType.supported_groups
+ 0x00, 0x0c, // byte length of this extension payload
+ 0x00, 0x0a, // byte length of supported groups list
+ 0x00, 0x17, // secp256r1
+ 0x00, 0x18, // secp384r1
+ 0x00, 0x19, // secp521r1
+ 0x00, 0x1D, // x25519
+ 0x00, 0x1E, // x448
+
+ // Extension: key_share
+ 0, 51, // ExtensionType.key_share
+ 0, 38, // byte length of this extension payload
+ 0, 36, // byte length of client_shares
+ 0x00, 0x1D, // NamedGroup.x25519
+ 0, 32, // byte length of key_exchange
+ } ++ x25519_pub_key ++ [_]u8{
+
+ // Extension: server_name
+ 0, 0, // ExtensionType.server_name
+ undefined, undefined, // byte length of this extension payload
+ undefined, undefined, // server_name_list byte count
+ 0x00, // name_type
+ undefined, undefined, // host name len
+ };
+
+ var hello_header = [_]u8{
+ // Plaintext header
+ @enumToInt(ContentType.handshake),
+ 0x03, 0x01, // legacy_record_version
+ undefined, undefined, // Plaintext fragment length (u16)
+
+ // Handshake header
+ @enumToInt(HandshakeType.client_hello),
+ undefined, undefined, undefined, // handshake length (u24)
+
+ // ClientHello
+ 0x03, 0x03, // legacy_version
+ } ++ rand_buf ++ [1]u8{0} ++ cipher_suites ++ [_]u8{
+ 0x01, 0x00, // legacy_compression_methods
+ } ++ extensions_header;
+
+ mem.writeIntBig(u16, hello_header[3..][0..2], @intCast(u16, hello_header.len - 5 + host.len));
+ mem.writeIntBig(u24, hello_header[6..][0..3], @intCast(u24, hello_header.len - 9 + host.len));
+ mem.writeIntBig(
+ u16,
+ hello_header[hello_header.len - extensions_header.len ..][0..2],
+ @intCast(u16, extensions_header.len - 2 + host.len),
+ );
+ mem.writeIntBig(u16, hello_header[hello_header.len - 7 ..][0..2], @intCast(u16, 5 + host.len));
+ mem.writeIntBig(u16, hello_header[hello_header.len - 5 ..][0..2], @intCast(u16, 3 + host.len));
+ mem.writeIntBig(u16, hello_header[hello_header.len - 2 ..][0..2], @intCast(u16, 0 + host.len));
+
+ {
+ var iovecs = [_]std.os.iovec_const{
+ .{
+ .iov_base = &hello_header,
+ .iov_len = hello_header.len,
+ },
+ .{
+ .iov_base = host.ptr,
+ .iov_len = host.len,
+ },
+ };
+ try stream.writevAll(&iovecs);
+ }
+
+ const client_hello_bytes1 = hello_header[5..];
+
+ var cipher_params: CipherParams = undefined;
+
+ var handshake_buf: [8000]u8 = undefined;
+ var len: usize = 0;
+ var i: usize = i: {
+ const plaintext = handshake_buf[0..5];
+ len = try stream.readAtLeast(&handshake_buf, plaintext.len);
+ if (len < plaintext.len) return error.EndOfStream;
+ const ct = @intToEnum(ContentType, plaintext[0]);
+ const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
+ const end = plaintext.len + frag_len;
+ if (end > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end - len);
+ if (end > len) return error.EndOfStream;
+ }
+ const frag = handshake_buf[plaintext.len..end];
+
+ switch (ct) {
+ .alert => {
+ const level = @intToEnum(tls.AlertLevel, frag[0]);
+ const desc = @intToEnum(tls.AlertDescription, frag[1]);
+ std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+ return error.TlsAlert;
+ },
+ .handshake => {
+ if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+ return error.TlsUnexpectedMessage;
+ }
+ const length = mem.readIntBig(u24, frag[1..4]);
+ if (4 + length != frag.len) return error.TlsBadLength;
+ const hello = frag[4..];
+ const legacy_version = mem.readIntBig(u16, hello[0..2]);
+ const random = hello[2..34].*;
+ if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) {
+ @panic("TODO handle HelloRetryRequest");
+ }
+ const legacy_session_id_echo_len = hello[34];
+ if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
+ const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
+ const cipher_suite_tag = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+ return error.TlsIllegalParameter;
+ std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)});
+ const legacy_compression_method = hello[37];
+ _ = legacy_compression_method;
+ const extensions_size = mem.readIntBig(u16, hello[38..40]);
+ if (40 + extensions_size != hello.len) return error.TlsBadLength;
+ var i: usize = 40;
+ var supported_version: u16 = 0;
+ var opt_x25519_server_pub_key: ?*[32]u8 = null;
+ while (i < hello.len) {
+ const et = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const next_i = i + ext_size;
+ if (next_i > hello.len) return error.TlsBadLength;
+ switch (et) {
+ @enumToInt(tls.ExtensionType.supported_versions) => {
+ if (supported_version != 0) return error.TlsIllegalParameter;
+ supported_version = mem.readIntBig(u16, hello[i..][0..2]);
+ },
+ @enumToInt(tls.ExtensionType.key_share) => {
+ if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter;
+ const named_group = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ switch (named_group) {
+ @enumToInt(tls.NamedGroup.x25519) => {
+ const key_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ if (key_size != 32) return error.TlsBadLength;
+ opt_x25519_server_pub_key = hello[i..][0..32];
+ },
+ else => {
+ std.debug.print("named group: {x}\n", .{named_group});
+ return error.TlsIllegalParameter;
+ },
+ }
+ },
+ else => {
+ std.debug.print("unexpected extension: {x}\n", .{et});
+ },
+ }
+ i = next_i;
+ }
+ const x25519_server_pub_key = opt_x25519_server_pub_key orelse
+ return error.TlsIllegalParameter;
+ const tls_version = if (supported_version == 0) legacy_version else supported_version;
+ switch (tls_version) {
+ @enumToInt(tls.ProtocolVersion.tls_1_2) => {
+ std.debug.print("server wants TLS v1.2\n", .{});
+ },
+ @enumToInt(tls.ProtocolVersion.tls_1_3) => {
+ std.debug.print("server wants TLS v1.3\n", .{});
+ },
+ else => return error.TlsIllegalParameter,
+ }
+
+ const shared_key = crypto.dh.X25519.scalarmult(
+ x25519_priv_key,
+ x25519_server_pub_key.*,
+ ) catch return error.TlsDecryptFailure;
+
+ switch (cipher_suite_tag) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| {
+ const P = std.meta.TagPayload(CipherParams, tag);
+ cipher_params = @unionInit(CipherParams, @tagName(tag), .{
+ .handshake_secret = undefined,
+ .master_secret = undefined,
+ .client_handshake_key = undefined,
+ .server_handshake_key = undefined,
+ .client_finished_key = undefined,
+ .server_finished_key = undefined,
+ .client_handshake_iv = undefined,
+ .server_handshake_iv = undefined,
+ .transcript_hash = P.Hash.init(.{}),
+ });
+ const p = &@field(cipher_params, @tagName(tag));
+ p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
+ p.transcript_hash.update(host); // Client Hello part 2
+ p.transcript_hash.update(frag); // Server Hello
+ const hello_hash = p.transcript_hash.peek();
+ const zeroes = [1]u8{0} ** P.Hash.digest_length;
+ const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
+ const empty_hash = tls.emptyHash(P.Hash);
+ const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
+ p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, &shared_key);
+ const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
+ p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
+ p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
+ p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
+ p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+ p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+ p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+ p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+ //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\nclient_handshake_iv: {}\nserver_handshake_iv: {}\n", .{
+ // std.fmt.fmtSliceHexLower(&shared_key),
+ // std.fmt.fmtSliceHexLower(&hello_hash),
+ // std.fmt.fmtSliceHexLower(&early_secret),
+ // std.fmt.fmtSliceHexLower(&empty_hash),
+ // std.fmt.fmtSliceHexLower(&hs_derived_secret),
+ // std.fmt.fmtSliceHexLower(&p.handshake_secret),
+ // std.fmt.fmtSliceHexLower(&client_secret),
+ // std.fmt.fmtSliceHexLower(&server_secret),
+ // std.fmt.fmtSliceHexLower(&p.client_handshake_iv),
+ // std.fmt.fmtSliceHexLower(&p.server_handshake_iv),
+ //});
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ }
+ },
+ else => return error.TlsUnexpectedMessage,
+ }
+ break :i end;
+ };
+
+ var read_seq: u64 = 0;
+
+ while (true) {
+ const end_hdr = i + 5;
+ if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end_hdr > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
+ if (end_hdr > len) return error.EndOfStream;
+ }
+ const ct = @intToEnum(ContentType, handshake_buf[i]);
+ i += 1;
+ const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+ i += 2;
+ _ = legacy_version;
+ const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+ i += 2;
+ const end = i + record_size;
+ if (end > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end - len);
+ if (end > len) return error.EndOfStream;
+ }
+ switch (ct) {
+ .change_cipher_spec => {
+ if (record_size != 1) return error.TlsUnexpectedMessage;
+ if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
+ },
+ .application_data => {
+ var cleartext_buf: [8000]u8 = undefined;
+ const cleartext = switch (cipher_params) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
+ const P = @TypeOf(p.*);
+ const ciphertext_len = record_size - P.AEAD.tag_length;
+ const ciphertext = handshake_buf[i..][0..ciphertext_len];
+ i += ciphertext.len;
+ if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_buf[0..ciphertext.len];
+ const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*;
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+ read_seq += 1;
+ const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_handshake_iv) ^ operand;
+ const ad = handshake_buf[end_hdr - 5 ..][0..5];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
+ return error.TlsBadRecordMac;
+ p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]);
+ break :c cleartext;
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ };
+
+ const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
+ switch (inner_ct) {
+ .handshake => {
+ var ct_i: usize = 0;
+ while (true) {
+ const handshake_type = cleartext[ct_i];
+ ct_i += 1;
+ const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
+ ct_i += 3;
+ const next_handshake_i = ct_i + handshake_len;
+ if (next_handshake_i > cleartext.len - 1)
+ return error.TlsBadLength;
+ switch (handshake_type) {
+ @enumToInt(HandshakeType.encrypted_extensions) => {
+ const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
+ ct_i += 2;
+ const end_ext_i = ct_i + total_ext_size;
+ while (ct_i < end_ext_i) {
+ const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
+ ct_i += 2;
+ const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
+ ct_i += 2;
+ const next_ext_i = ct_i + ext_size;
+ switch (et) {
+ @enumToInt(tls.ExtensionType.server_name) => {},
+ else => {
+ std.debug.print("encrypted extension: {any}\n", .{
+ et,
+ });
+ },
+ }
+ ct_i = next_ext_i;
+ }
+ },
+ @enumToInt(HandshakeType.certificate) => {
+ std.debug.print("cool certificate bro\n", .{});
+ },
+ @enumToInt(HandshakeType.certificate_verify) => {
+ std.debug.print("the certificate came with a fancy signature\n", .{});
+ },
+ @enumToInt(HandshakeType.finished) => {
+ // This message is to trick buggy proxies into behaving correctly.
+ const client_change_cipher_spec_msg = [_]u8{
+ @enumToInt(ContentType.change_cipher_spec),
+ 0x03, 0x03, // legacy protocol version
+ 0x00, 0x01, // length
+ 0x01,
+ };
+ const app_cipher = switch (cipher_params) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: {
+ const P = @TypeOf(p.*);
+ // TODO verify the server's data
+ const handshake_hash = p.transcript_hash.finalResult();
+ const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
+ const out_cleartext = [_]u8{
+ @enumToInt(HandshakeType.finished),
+ 0, 0, verify_data.len, // length
+ } ++ verify_data ++ [1]u8{@enumToInt(ContentType.handshake)};
+
+ const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
+
+ var finished_msg = [_]u8{
+ @enumToInt(ContentType.application_data),
+ 0x03, 0x03, // legacy protocol version
+ 0, wrapped_len, // byte length of encrypted record
+ } ++ ([1]u8{undefined} ** wrapped_len);
+
+ const ad = finished_msg[0..5];
+ const ciphertext = finished_msg[5..][0..out_cleartext.len];
+ const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
+ const nonce = p.client_handshake_iv;
+ P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
+
+ const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
+ try stream.writeAll(&both_msgs);
+
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+ //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
+ // std.fmt.fmtSliceHexLower(&p.master_secret),
+ // std.fmt.fmtSliceHexLower(&client_secret),
+ // std.fmt.fmtSliceHexLower(&server_secret),
+ //});
+ break :c @unionInit(ApplicationCipher, @tagName(tag), .{
+ .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
+ .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
+ .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
+ .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
+ });
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ };
+ std.debug.print("remaining bytes: {d}\n", .{len - end});
+ return .{
+ .application_cipher = app_cipher,
+ .read_seq = 0,
+ .write_seq = 0,
+ .partially_read_buffer = undefined,
+ .partially_read_len = 0,
+ .eof = false,
+ };
+ },
+ else => {
+ std.debug.print("handshake type: {d}\n", .{cleartext[0]});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ ct_i = next_handshake_i;
+ if (ct_i >= cleartext.len - 1) break;
+ }
+ },
+ else => {
+ std.debug.print("inner content type: {any}\n", .{inner_ct});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ },
+ else => {
+ std.debug.print("content type: {s}\n", .{@tagName(ct)});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ i = end;
+ }
+
+ return error.TlsHandshakeFailure;
+}
+
+pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
+ var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
+ // Due to the trailing inner content type byte in the ciphertext, we need
+ // an additional buffer for storing the cleartext into before encrypting.
+ var cleartext_buf: [max_ciphertext_len]u8 = undefined;
+ var iovecs_buf: [5]std.os.iovec_const = undefined;
+ var ciphertext_end: usize = 0;
+ var iovec_end: usize = 0;
+ var bytes_i: usize = 0;
+ // How many bytes are taken up by overhead per record.
+ const overhead_len: usize = switch (c.application_cipher) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
+ const P = @TypeOf(p.*);
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1;
+ while (true) {
+ const encrypted_content_len = @intCast(u16, @min(
+ @min(bytes.len - bytes_i, max_ciphertext_len - 1),
+ ciphertext_buf.len -
+ tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
+ ));
+ if (encrypted_content_len == 0) break :l overhead_len;
+
+ mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
+ cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
+ bytes_i += encrypted_content_len;
+ const ciphertext_len = encrypted_content_len + 1;
+ const cleartext = cleartext_buf[0..ciphertext_len];
+
+ const record_start = ciphertext_end;
+ const ad = ciphertext_buf[ciphertext_end..][0..5];
+ ad.* =
+ [_]u8{@enumToInt(ContentType.application_data)} ++
+ int2(@enumToInt(tls.ProtocolVersion.tls_1_2)) ++
+ int2(ciphertext_len + P.AEAD.tag_length);
+ ciphertext_end += ad.len;
+ const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
+ ciphertext_end += ciphertext_len;
+ const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
+ ciphertext_end += auth_tag.len;
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq));
+ c.write_seq += 1;
+ const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
+ P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
+ //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{
+ // c.write_seq - 1,
+ // std.fmt.fmtSliceHexLower(&nonce),
+ // std.fmt.fmtSliceHexLower(&p.client_key),
+ // std.fmt.fmtSliceHexLower(&p.client_iv),
+ // std.fmt.fmtSliceHexLower(ad),
+ // std.fmt.fmtSliceHexLower(auth_tag),
+ // std.fmt.fmtSliceHexLower(&p.server_key),
+ // std.fmt.fmtSliceHexLower(&p.server_iv),
+ //});
+
+ const record = ciphertext_buf[record_start..ciphertext_end];
+ iovecs_buf[iovec_end] = .{
+ .iov_base = record.ptr,
+ .iov_len = record.len,
+ };
+ iovec_end += 1;
+ }
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ };
+
+ // Ideally we would call writev exactly once here, however, we must ensure
+ // that we don't return with a record partially written.
+ var i: usize = 0;
+ var total_amt: usize = 0;
+ while (true) {
+ var amt = try stream.writev(iovecs_buf[i..iovec_end]);
+ while (amt >= iovecs_buf[i].iov_len) {
+ const encrypted_amt = iovecs_buf[i].iov_len;
+ total_amt += encrypted_amt - overhead_len;
+ amt -= encrypted_amt;
+ i += 1;
+ // Rely on the property that iovecs delineate records, meaning that
+ // if amt equals zero here, we have fortunately found ourselves
+ // with a short read that aligns at the record boundary.
+ if (i >= iovec_end or amt == 0) return total_amt;
+ }
+ iovecs_buf[i].iov_base += amt;
+ iovecs_buf[i].iov_len -= amt;
+ }
+}
+
+pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try c.write(stream, bytes[index..]);
+ }
+}
+
+/// Returns number of bytes that have been read, which are now populated inside
+/// `buffer`. A return value of zero bytes does not necessarily mean end of
+/// stream.
+pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
+ const prev_len = c.partially_read_len;
+ var in_buf: [max_ciphertext_len * 4]u8 = undefined;
+ mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]);
+
+ // Capacity of output buffer, in records, rounded up.
+ const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
+ const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len);
+ const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)];
+ const actual_read_len = try stream.read(ask_slice);
+ const frag = in_buf[0 .. prev_len + actual_read_len];
+ if (frag.len == 0) {
+ c.eof = true;
+ return 0;
+ }
+ var in: usize = 0;
+ var out: usize = 0;
+
+ while (true) {
+ if (in + tls.ciphertext_record_header_len > frag.len) {
+ return finishRead(c, frag, in, out);
+ }
+ const ct = @intToEnum(ContentType, frag[in]);
+ in += 1;
+ const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
+ in += 2;
+ _ = legacy_version;
+ const record_size = mem.readIntBig(u16, frag[in..][0..2]);
+ in += 2;
+ const end = in + record_size;
+ if (end > frag.len) {
+ if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
+ return finishRead(c, frag, in, out);
+ }
+ switch (ct) {
+ .alert => {
+ @panic("TODO handle an alert here");
+ },
+ .application_data => {
+ const cleartext_len = switch (c.application_cipher) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
+ const P = @TypeOf(p.*);
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const ad = frag[in - 5 ..][0..5];
+ const ciphertext_len = record_size - P.AEAD.tag_length;
+ const ciphertext = frag[in..][0..ciphertext_len];
+ in += ciphertext_len;
+ const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
+ const cleartext = buffer[out..][0..ciphertext_len];
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq));
+ c.read_seq += 1;
+ const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
+ //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
+ // c.read_seq - 1,
+ // std.fmt.fmtSliceHexLower(&nonce),
+ // std.fmt.fmtSliceHexLower(&p.server_key),
+ // std.fmt.fmtSliceHexLower(&p.server_iv),
+ //});
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
+ return error.TlsBadRecordMac;
+ break :c cleartext.len;
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ };
+
+ const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
+ switch (inner_ct) {
+ .alert => {
+ const level = @intToEnum(tls.AlertLevel, buffer[out]);
+ const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]);
+ if (desc == .close_notify) {
+ c.eof = true;
+ return out;
+ }
+ std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+ return error.TlsAlert;
+ },
+ .handshake => {
+ std.debug.print("the server wants to keep shaking hands\n", .{});
+ },
+ .application_data => {
+ out += cleartext_len - 1;
+ },
+ else => {
+ std.debug.print("inner content type: {d}\n", .{inner_ct});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ },
+ else => {
+ std.debug.print("unexpected ct: {any}\n", .{ct});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ in = end;
+ }
+}
+
+fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
+ const saved_buf = frag[in..];
+ mem.copy(u8, &c.partially_read_buffer, saved_buf);
+ c.partially_read_len = @intCast(u15, saved_buf.len);
+ return out;
+}
+
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
+inline fn big(x: anytype) @TypeOf(x) {
+ return switch (native_endian) {
+ .Big => x,
+ .Little => @byteSwap(x),
+ };
+}
+
+inline fn int2(x: u16) [2]u8 {
+ return .{
+ @truncate(u8, x >> 8),
+ @truncate(u8, x),
+ };
+}
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
@@ -12,7 +12,7 @@ pub const Request = struct {
client: *Client,
stream: net.Stream,
headers: std.ArrayListUnmanaged(u8) = .{},
- tls: std.crypto.Tls,
+ tls_client: std.crypto.tls.Client,
protocol: Protocol,
pub const Protocol = enum { http, https };
@@ -51,7 +51,7 @@ pub const Request = struct {
try req.stream.writeAll(req.headers.items);
},
.https => {
- try req.tls.writeAll(req.stream, req.headers.items);
+ try req.tls_client.writeAll(req.stream, req.headers.items);
},
}
}
@@ -59,7 +59,7 @@ pub const Request = struct {
pub fn read(req: *Request, buffer: []u8) !usize {
switch (req.protocol) {
.http => return req.stream.read(buffer),
- .https => return req.tls.read(req.stream, buffer),
+ .https => return req.tls_client.read(req.stream, buffer),
}
}
@@ -74,7 +74,7 @@ pub const Request = struct {
if (amt == 0) {
switch (req.protocol) {
.http => break,
- .https => if (req.tls.eof) break,
+ .https => if (req.tls_client.eof) break,
}
}
index += amt;
@@ -94,7 +94,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
.client = client,
.stream = try net.tcpConnectToHost(client.allocator, options.host, options.port),
.protocol = options.protocol,
- .tls = undefined,
+ .tls_client = undefined,
};
client.active_requests += 1;
errdefer req.deinit();
@@ -102,7 +102,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
switch (options.protocol) {
.http => {},
.https => {
- req.tls = try std.crypto.Tls.init(req.stream, options.host);
+ req.tls_client = try std.crypto.tls.Client.init(req.stream, options.host);
},
}