diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index babd8b465d..3944b7c974 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -237,6 +237,7 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type { client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, transcript_hash: Hash, + finished_digest: [Hash.digest_length]u8, }; } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 4afc1b7e17..c4ac6e508a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -257,6 +257,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { .client_handshake_iv = undefined, .server_handshake_iv = undefined, .transcript_hash = P.Hash.init(.{}), + .finished_digest = undefined, }); const p = &@field(cipher_params, @tagName(tag)); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 @@ -391,6 +392,11 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }, @enumToInt(HandshakeType.certificate_verify) => { std.debug.print("the certificate came with a fancy signature\n", .{}); + switch (cipher_params) { + inline else => |*p| { + p.finished_digest = p.transcript_hash.peek(); + }, + } }, @enumToInt(HandshakeType.finished) => { // This message is to trick buggy proxies into behaving correctly. @@ -403,7 +409,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { const app_cipher = switch (cipher_params) { inline else => |*p, tag| c: { const P = @TypeOf(p.*); - // TODO verify the server's data + const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key); + const actual_server_verify_data = cleartext[ct_i..][0..handshake_len]; + if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data)) + return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); const out_cleartext = [_]u8{ @@ -454,7 +463,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client { }; }, else => { - std.debug.print("handshake type: {d}\n", .{cleartext[0]}); return error.TlsUnexpectedMessage; }, }