diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index fc2523f02a..8ef4d9bfad 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -39,9 +39,9 @@ const assert = std.debug.assert; pub const Client = @import("tls/Client.zig"); -pub const ciphertext_record_header_len = 5; +pub const record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; -pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; +pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len; pub const hello_retry_request_sequence = [32]u8{ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, @@ -360,3 +360,130 @@ pub inline fn int3(x: u24) [3]u8 { @truncate(u8, x), }; } + +/// An abstraction to ensure that protocol-parsing code does not perform an +/// out-of-bounds read. +pub const Decoder = struct { + buf: []u8, + /// Points to the next byte in buffer that will be decoded. + idx: usize = 0, + /// Up to this point in `buf` we have already checked that `cap` is greater than it. + our_end: usize = 0, + /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we + /// requested with `readAtLeast`. + their_end: usize = 0, + /// Points to the end within buffer that has been filled. Beyond this point + /// in buf is undefined bytes. + cap: usize = 0, + /// Debug helper to prevent illegal calls to read functions. + disable_reads: bool = false, + + pub fn fromTheirSlice(buf: []u8) Decoder { + return .{ + .buf = buf, + .their_end = buf.len, + .cap = buf.len, + .disable_reads = true, + }; + } + + /// Use this function to increase `their_end`. + pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + assert(!d.disable_reads); + const existing_amt = d.cap - d.idx; + d.their_end = d.idx + their_amt; + if (their_amt <= existing_amt) return; + const request_amt = their_amt - existing_amt; + const dest = d.buf[d.cap..]; + if (request_amt > dest.len) return error.TlsRecordOverflow; + const actual_amt = try stream.readAtLeast(dest, request_amt); + if (actual_amt < request_amt) return error.TlsConnectionTruncated; + d.cap += actual_amt; + } + + /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. + /// Use when `our_amt` is calculated by us, not by them. + pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + assert(!d.disable_reads); + try readAtLeast(d, stream, our_amt); + d.our_end = d.idx + our_amt; + } + + /// Use this function to increase `our_end`. + /// This should always be called with an amount provided by us, not them. + pub fn ensure(d: *Decoder, amt: usize) !void { + d.our_end = @max(d.idx + amt, d.our_end); + if (d.our_end > d.their_end) return error.TlsDecodeError; + } + + /// Use this function to increase `idx`. + pub fn decode(d: *Decoder, comptime T: type) T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + skip(d, 1); + return d.buf[d.idx - 1]; + }, + 16 => { + skip(d, 2); + const b0: u16 = d.buf[d.idx - 2]; + const b1: u16 = d.buf[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + skip(d, 3); + const b0: u24 = d.buf[d.idx - 3]; + const b1: u24 = d.buf[d.idx - 2]; + const b2: u24 = d.buf[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @intToEnum(T, int); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + /// Use this function to increase `idx`. + pub fn array(d: *Decoder, comptime len: usize) *[len]u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn slice(d: *Decoder, len: usize) []u8 { + skip(d, len); + return d.buf[d.idx - len ..][0..len]; + } + + /// Use this function to increase `idx`. + pub fn skip(d: *Decoder, amt: usize) void { + d.idx += amt; + assert(d.idx <= d.our_end); // insufficient ensured bytes + } + + pub fn eof(d: Decoder) bool { + assert(d.our_end <= d.their_end); + assert(d.idx <= d.our_end); + return d.idx == d.their_end; + } + + /// Provide the length they claim, and receive a sub-decoder specific to that slice. + /// The parent decoder is advanced to the end. + pub fn sub(d: *Decoder, their_len: usize) !Decoder { + const end = d.idx + their_len; + if (end > d.their_end) return error.TlsDecodeError; + const sub_buf = d.buf[d.idx..end]; + d.idx = end; + d.our_end = end; + return fromTheirSlice(sub_buf); + } + + pub fn rest(d: Decoder) []u8 { + return d.buf[d.idx..d.cap]; + } +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index ec6f00ad8a..bca05a3ffd 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -126,88 +126,73 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const client_hello_bytes1 = plaintext_header[5..]; var handshake_cipher: tls.HandshakeCipher = undefined; - - var handshake_buf: [8000]u8 = undefined; - var len: usize = 0; - var i: usize = i: { - const plaintext = handshake_buf[0..5]; - len = try stream.readAtLeast(&handshake_buf, plaintext.len); - if (len < plaintext.len) return error.EndOfStream; - const ct = @intToEnum(tls.ContentType, plaintext[0]); - const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); - const end = plaintext.len + frag_len; - if (end > handshake_buf.len) return error.TlsRecordOverflow; - if (end > len) { - len += try stream.readAtLeast(handshake_buf[len..], end - len); - if (end > len) return error.EndOfStream; - } - const frag = handshake_buf[plaintext.len..end]; - + var handshake_buffer: [8000]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; + { + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_record_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + const server_hello_fragment = d.buf[d.idx..][0..record_len]; + var ptd = try d.sub(record_len); switch (ct) { .alert => { - const level = @intToEnum(tls.AlertLevel, frag[0]); - const desc = @intToEnum(tls.AlertDescription, frag[1]); + try ptd.ensure(2); + const level = ptd.decode(tls.AlertLevel); + const desc = ptd.decode(tls.AlertDescription); _ = level; _ = desc; return error.TlsAlert; }, .handshake => { - if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) { + try ptd.ensure(4); + const handshake_type = ptd.decode(tls.HandshakeType); + if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; + const length = ptd.decode(u24); + var hsd = try ptd.sub(length); + try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); + const legacy_version = hsd.decode(u16); + const random = hsd.array(32); + if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. return error.TlsUnexpectedMessage; } - const length = mem.readIntBig(u24, frag[1..4]); - if (4 + length != frag.len) return error.TlsBadLength; - var i: usize = 4; - const legacy_version = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const random = frag[i..][0..32].*; - i += 32; - if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) { - @panic("TODO handle HelloRetryRequest"); - } - const legacy_session_id_echo_len = frag[i]; - i += 1; + const legacy_session_id_echo_len = hsd.decode(u8); if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = frag[i..][0..32]; + const legacy_session_id_echo = hsd.array(32); if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter; - i += 32; - const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int); - const legacy_compression_method = frag[i]; - i += 1; - _ = legacy_compression_method; - const extensions_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - if (i + extensions_size != frag.len) return error.TlsBadLength; + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); var supported_version: u16 = 0; var shared_key: [32]u8 = undefined; var have_shared_key = false; - while (i < frag.len) { - const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2])); - i += 2; - const ext_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - const next_i = i + ext_size; - if (next_i > frag.len) return error.TlsBadLength; + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); switch (et) { .supported_versions => { if (supported_version != 0) return error.TlsIllegalParameter; - supported_version = mem.readIntBig(u16, frag[i..][0..2]); + try extd.ensure(2); + supported_version = extd.decode(u16); }, .key_share => { if (have_shared_key) return error.TlsIllegalParameter; have_shared_key = true; - const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2])); - i += 2; - const key_size = mem.readIntBig(u16, frag[i..][0..2]); - i += 2; - + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); switch (named_group) { .x25519 => { - if (key_size != 32) return error.TlsBadLength; - const server_pub_key = frag[i..][0..32]; + if (key_size != 32) return error.TlsIllegalParameter; + const server_pub_key = extd.array(32); shared_key = crypto.dh.X25519.scalarmult( x25519_kp.secret_key, @@ -215,7 +200,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) ) catch return error.TlsDecryptFailure; }, .secp256r1 => { - const server_pub_key = frag[i..][0..key_size]; + const server_pub_key = extd.slice(key_size); const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; const pk = PublicKey.fromSec1(server_pub_key) catch { @@ -233,14 +218,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, else => {}, } - i = next_i; } if (!have_shared_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; - switch (tls_version) { - @enumToInt(tls.ProtocolVersion.tls_1_3) => {}, - else => return error.TlsIllegalParameter, - } + if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3)) + return error.TlsIllegalParameter; switch (cipher_suite_tag) { inline .AES_128_GCM_SHA256, @@ -264,7 +247,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const p = &@field(handshake_cipher, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(frag); // Server Hello + p.transcript_hash.update(server_hello_fragment); const hello_hash = p.transcript_hash.peek(); const zeroes = [1]u8{0} ** P.Hash.digest_length; const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); @@ -289,8 +272,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, else => return error.TlsUnexpectedMessage, } - break :i end; - }; + } // This is used for two purposes: // * Detect whether a certificate is the first one presented, in which case @@ -322,29 +304,17 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) var main_cert_pub_key_len: u16 = undefined; while (true) { - const end_hdr = i + 5; - if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; - if (end_hdr > len) { - len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); - if (end_hdr > len) return error.EndOfStream; - } - const ct = @intToEnum(tls.ContentType, handshake_buf[i]); - i += 1; - const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]); - i += 2; - _ = legacy_version; - const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); - i += 2; - const end = i + record_size; - if (end > handshake_buf.len) return error.TlsRecordOverflow; - if (end > len) { - len += try stream.readAtLeast(handshake_buf[len..], end - len); - if (end > len) return error.EndOfStream; - } + try d.readAtLeastOurAmt(stream, tls.record_header_len); + const record_header = d.buf[d.idx..][0..5]; + const ct = d.decode(tls.ContentType); + d.skip(2); // legacy_version + const record_len = d.decode(u16); + try d.readAtLeast(stream, record_len); + var record_decoder = try d.sub(record_len); switch (ct) { .change_cipher_spec => { - if (record_size != 1) return error.TlsUnexpectedMessage; - if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; + try record_decoder.ensure(1); + if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; }, .application_data => { const cleartext_buf = &cleartext_bufs[cert_index % 2]; @@ -352,276 +322,261 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) const cleartext = switch (handshake_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); - const ciphertext_len = record_size - P.AEAD.tag_length; - const ciphertext = handshake_buf[i..][0..ciphertext_len]; - i += ciphertext.len; + const ciphertext_len = record_len - P.AEAD.tag_length; + try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); + const ciphertext = record_decoder.slice(ciphertext_len); if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*; + const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(read_seq)); read_seq += 1; const nonce = @as(V, p.server_handshake_iv) ^ operand; - const ad = handshake_buf[end_hdr - 5 ..][0..5]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch return error.TlsBadRecordMac; break :c cleartext; }, }; const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); - switch (inner_ct) { - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i]; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - const total_ext_size = mem.readIntBig(u16, handshake[0..2]); - var hs_i: usize = 2; - const end_ext_i = 2 + total_ext_size; - while (hs_i < end_ext_i) { - const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2])); - hs_i += 2; - const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); - hs_i += 2; - const next_ext_i = hs_i + ext_size; - switch (et) { - .server_name => {}, - else => {}, - } - hs_i = next_ext_i; + if (inner_ct != .handshake) return error.TlsUnexpectedMessage; + + var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); + while (true) { + try ctd.ensure(4); + const handshake_type = ctd.decode(tls.HandshakeType); + const handshake_len = ctd.decode(u24); + var hsd = try ctd.sub(handshake_len); + const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; + switch (handshake_type) { + .encrypted_extensions => { + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + handshake_state = .certificate; + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(2); + const total_ext_size = hsd.decode(u16); + var all_extd = try hsd.sub(total_ext_size); + while (!all_extd.eof()) { + try all_extd.ensure(4); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + _ = extd; + switch (et) { + .server_name => {}, + else => {}, + } + } + }, + .certificate => cert: { + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } + try hsd.ensure(1 + 4); + const cert_req_ctx_len = hsd.decode(u8); + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + const certs_size = hsd.decode(u24); + var certs_decoder = try hsd.sub(certs_size); + while (!certs_decoder.eof()) { + try certs_decoder.ensure(3); + const cert_size = certs_decoder.decode(u24); + var certd = try certs_decoder.sub(cert_size); + + const subject_cert: Certificate = .{ + .buffer = certd.buf, + .index = @intCast(u32, certd.idx), + }; + const subject = try subject_cert.parse(); + if (cert_index == 0) { + // Verify the host on the first certificate. + if (!hostMatchesCommonName(host, subject.commonName())) { + return error.TlsCertificateHostMismatch; } + + // Keep track of the public key for the + // certificate_verify message later. + main_cert_pub_key_algo = subject.pub_key_algo; + const pub_key = subject.pubKey(); + if (pub_key.len > main_cert_pub_key_buf.len) + return error.CertificatePublicKeyInvalid; + @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); + main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); + } else { + try prev_cert.verify(subject); + } + + if (ca_bundle.verify(subject)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, + } + + prev_cert = subject; + cert_index += 1; + + try certs_decoder.ensure(2); + const total_ext_size = certs_decoder.decode(u16); + var all_extd = try certs_decoder.sub(total_ext_size); + _ = all_extd; + } + }, + .certificate_verify => { + switch (handshake_state) { + .trust_chain_established => handshake_state = .finished, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + + try hsd.ensure(4); + const scheme = hsd.decode(tls.SignatureScheme); + const sig_len = hsd.decode(u16); + try hsd.ensure(sig_len); + const encoded_sig = hsd.slice(sig_len); + const max_digest_len = 64; + var verify_buffer = + ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + @as([max_digest_len]u8, undefined); + + const verify_bytes = switch (handshake_cipher) { + inline else => |*p| v: { + const transcript_digest = p.transcript_hash.peek(); + verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; + p.transcript_hash.update(wrapped_handshake); + break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - var hs_i: u32 = 0; - const cert_req_ctx_len = handshake[hs_i]; - hs_i += 1; - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); - hs_i += 3; - const end_certs = hs_i + certs_size; - while (hs_i < end_certs) { - const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]); - hs_i += 3; - const end_cert = hs_i + cert_size; + }; + const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - const subject_cert: Certificate = .{ - .buffer = handshake, - .index = hs_i, - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - if (!hostMatchesCommonName(host, subject.commonName())) { - return error.TlsCertificateHostMismatch; - } - - // Keep track of the public key for - // the certificate_verify message - // later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len); - main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len); - } else { - try prev_cert.verify(subject); - } - - if (ca_bundle.verify(subject)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; - - hs_i = end_cert; - const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]); - hs_i += 2; - hs_i += total_ext_size; - } + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) + return error.TlsBadSignatureScheme; + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); + try sig.verify(verify_bytes, key); }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } + .rsa_pss_rsae_sha256 => { + if (main_cert_pub_key_algo != .rsaEncryption) + return error.TlsBadSignatureScheme; - const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2])); - const sig_len = mem.readIntBig(u16, handshake[2..4]); - if (4 + sig_len > handshake.len) return error.TlsBadLength; - const encoded_sig = handshake[4..][0..sig_len]; - const max_digest_len = 64; - var verify_buffer = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - .rsa_pss_rsae_sha256 => { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = crypto.hash.sha2.Sha256; - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - var rsa_mem_buf: [512 * 32]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); - const ally = fba.allocator(); - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } + const Hash = crypto.hash.sha2.Sha256; + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + var rsa_mem_buf: [512 * 32]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf); + const ally = fba.allocator(); + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally); + const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally); }, else => { - return error.TlsBadSignatureScheme; + return error.TlsBadRsaSignatureBitCount; }, } }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @enumToInt(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); - const finished_digest = p.transcript_hash.peek(); - p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @enumToInt(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; - - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - var finished_msg = [_]u8{ - @enumToInt(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); - - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - try stream.writeAll(&both_msgs); - - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - .client_secret = client_secret, - .server_secret = server_secret, - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - }; - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(u15, len - end), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]); - return client; - }, else => { - return error.TlsUnexpectedMessage; + return error.TlsBadSignatureScheme; }, } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - else => { - return error.TlsUnexpectedMessage; - }, + }, + .finished => { + if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = [_]u8{ + @enumToInt(tls.ContentType.change_cipher_spec), + 0x03, 0x03, // legacy protocol version + 0x00, 0x01, // length + 0x01, + }; + const app_cipher = switch (handshake_cipher) { + inline else => |*p, tag| c: { + const P = @TypeOf(p.*); + const finished_digest = p.transcript_hash.peek(); + p.transcript_hash.update(wrapped_handshake); + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, handshake)) + return error.TlsDecryptError; + const handshake_hash = p.transcript_hash.finalResult(); + const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const out_cleartext = [_]u8{ + @enumToInt(tls.HandshakeType.finished), + 0, 0, verify_data.len, // length + } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)}; + + const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + var finished_msg = [_]u8{ + @enumToInt(tls.ContentType.application_data), + 0x03, 0x03, // legacy protocol version + 0, wrapped_len, // byte length of encrypted record + } ++ @as([wrapped_len]u8, undefined); + + const ad = finished_msg[0..5]; + const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + const nonce = p.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + try stream.writeAll(&both_msgs); + + const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + .client_secret = client_secret, + .server_secret = server_secret, + .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + }); + }, + }; + const leftover = d.rest(); + var client: Client = .{ + .read_seq = 0, + .write_seq = 0, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(u15, leftover.len), + .received_close_notify = false, + .application_cipher = app_cipher, + .partially_read_buffer = undefined, + }; + mem.copy(u8, &client.partially_read_buffer, leftover); + return client; + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + if (ctd.eof()) break; } }, else => { return error.TlsUnexpectedMessage; }, } - i = end; } - - return error.TlsHandshakeFailure; } pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { @@ -638,12 +593,12 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize { inline else => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); - const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1; + const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; while (true) { const encrypted_content_len = @intCast(u16, @min( @min(bytes.len - bytes_i, max_ciphertext_len - 1), ciphertext_buf.len - - tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, + tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1, )); if (encrypted_content_len == 0) break :l overhead_len; @@ -829,7 +784,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // Cleartext capacity of output buffer, in records, rounded up. const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); + const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); const actual_read_len = try stream.readv(ask_iovecs); @@ -860,13 +815,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove continue; } - if (in + tls.ciphertext_record_header_len > frag.len) { + if (in + tls.record_header_len > frag.len) { if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); const first = frag[in..]; - if (frag1.len < tls.ciphertext_record_header_len) + if (frag1.len < tls.record_header_len) return finishRead2(c, first, frag1, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. @@ -875,7 +830,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const record_len = (record_len_byte_0 << 8) | record_len_byte_1; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - const full_record_len = record_len + tls.ciphertext_record_header_len; + const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); @@ -898,14 +853,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const end = in + record_len; if (end > frag.len) { // We need the record header on the next iteration of the loop. - in -= tls.ciphertext_record_header_len; + in -= tls.record_header_len; if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. const first = frag[in..]; - const full_record_len = record_len + tls.ciphertext_record_header_len; + const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); @@ -919,7 +874,12 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove } switch (ct) { .alert => { - @panic("TODO handle an alert here"); + if (in + 2 > frag.len) return error.TlsDecodeError; + const level = @intToEnum(tls.AlertLevel, frag[in]); + const desc = @intToEnum(tls.AlertDescription, frag[in + 1]); + _ = level; + _ = desc; + return error.TlsAlert; }, .application_data => { const cleartext = switch (c.application_cipher) {