From 22e2aaa283646858502ac1075c9657383366005d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 29 Dec 2022 17:56:46 -0700 Subject: [PATCH] crypto.tls: support rsa_pss_rsae_sha256 and fixes * fix eof logic * fix read logic * fix VecPut logic * add some debug prints to remove later --- lib/std/crypto/Certificate.zig | 198 ++++++++++++++++++++++++++++++--- lib/std/crypto/tls/Client.zig | 78 +++++++++---- 2 files changed, 239 insertions(+), 37 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index a8511d4d9e..cce0193cf0 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -474,19 +474,9 @@ fn verifyRsa( pub_key: []const u8, ) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pub_key_seq = try der.Element.parse(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); - if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); - if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - // Skip over meaningless zeroes in the modulus. - const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; - const modulus_offset = for (modulus_raw) |byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - const modulus = modulus_raw[modulus_offset..]; - const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + const pk_components = try rsa.PublicKey.parseDer(pub_key); + const exponent = pk_components.exponent; + const modulus = pk_components.modulus; if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; @@ -688,10 +678,154 @@ test { /// which is licensed under the Apache License Version 2.0, January 2004 /// http://www.apache.org/licenses/ /// The code has been modified. -const rsa = struct { +pub const rsa = struct { const BigInt = std.math.big.int.Managed; - const PublicKey = struct { + pub const PSSSignature = struct { + pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { + var result = [1]u8{0} ** modulus_len; + std.mem.copy(u8, &result, msg); + return result; + } + + pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void { + const mod_bits = try countBits(public_key.n.toConst(), allocator); + const em_dec = try encrypt(modulus_len, sig, public_key, allocator); + + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator); + } + + fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void { + // TODO + // 1. If the length of M is greater than the input limitation for + // the hash function (2^61 - 1 octets for SHA-1), output + // "inconsistent" and stop. + + // emLen = \ceil(emBits/8) + const emLen = ((emBit - 1) / 8) + 1; + std.debug.assert(emLen == em.len); + + // 2. Let mHash = Hash(M), an octet string of length hLen. + var mHash: [Hash.digest_length]u8 = undefined; + Hash.hash(msg, &mHash, .{}); + + // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. + if (emLen < Hash.digest_length + sLen + 2) { + return error.InvalidSignature; + } + + // 4. If the rightmost octet of EM does not have hexadecimal value + // 0xbc, output "inconsistent" and stop. + if (em[em.len - 1] != 0xbc) { + return error.InvalidSignature; + } + + // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, + // and let H be the next hLen octets. + const maskedDB = em[0..(emLen - Hash.digest_length - 1)]; + const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)]; + + // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in + // maskedDB are not all equal to zero, output "inconsistent" and + // stop. + const zero_bits = emLen * 8 - emBit; + var mask: u8 = maskedDB[0]; + var i: usize = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask >> 1; + } + if (mask != 0) { + return error.InvalidSignature; + } + + // 7. Let dbMask = MGF(H, emLen - hLen - 1). + const mgf_len = emLen - Hash.digest_length - 1; + var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length); + defer allocator.free(mgf_out); + var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator); + + // 8. Let DB = maskedDB \xor dbMask. + i = 0; + while (i < dbMask.len) : (i += 1) { + dbMask[i] = maskedDB[i] ^ dbMask[i]; + } + + // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet + // in DB to zero. + i = 0; + mask = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask << 1; + mask += 1; + } + dbMask[0] = dbMask[0] & mask; + + // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not + // zero or if the octet at position emLen - hLen - sLen - 1 (the + // leftmost position is "position 1") does not have hexadecimal + // value 0x01, output "inconsistent" and stop. + if (dbMask[mgf_len - sLen - 2] != 0x00) { + return error.InvalidSignature; + } + + if (dbMask[mgf_len - sLen - 1] != 0x01) { + return error.InvalidSignature; + } + + // 11. Let salt be the last sLen octets of DB. + const salt = dbMask[(mgf_len - sLen)..]; + + // 12. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen); + defer allocator.free(m_p); + std.mem.copy(u8, m_p, &([_]u8{0} ** 8)); + std.mem.copy(u8, m_p[8..], &mHash); + std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt); + + // 13. Let H' = Hash(M'), an octet string of length hLen. + var h_p: [Hash.digest_length]u8 = undefined; + Hash.hash(m_p, &h_p, .{}); + + // 14. If H = H', output "consistent". Otherwise, output + // "inconsistent". + if (!std.mem.eql(u8, h, &h_p)) { + return error.InvalidSignature; + } + } + + fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 { + var counter: usize = 0; + var idx: usize = 0; + var c: [4]u8 = undefined; + + var hash = try allocator.alloc(u8, seed.len + c.len); + defer allocator.free(hash); + std.mem.copy(u8, hash, seed); + var hashed: [Hash.digest_length]u8 = undefined; + + while (idx < len) { + c[0] = @intCast(u8, (counter >> 24) & 0xFF); + c[1] = @intCast(u8, (counter >> 16) & 0xFF); + c[2] = @intCast(u8, (counter >> 8) & 0xFF); + c[3] = @intCast(u8, counter & 0xFF); + + std.mem.copy(u8, hash[seed.len..], &c); + Hash.hash(hash, &hashed, .{}); + + std.mem.copy(u8, out[idx..], &hashed); + idx += hashed.len; + + counter += 1; + } + + return out[0..len]; + } + }; + + pub const PublicKey = struct { n: BigInt, e: BigInt, @@ -714,6 +848,24 @@ const rsa = struct { .e = _e, }; } + + pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + const pub_key_seq = try der.Element.parse(pub_key, 0); + if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); + if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); + if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = for (modulus_raw) |byte, i| { + if (byte != 0) break i; + } else modulus_raw.len; + return .{ + .modulus = modulus_raw[modulus_offset..], + .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end], + }; + } }; fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 { @@ -812,6 +964,20 @@ const rsa = struct { try BigInt.divFloor(&q, rem, a, n); } + fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize { + var i: usize = 0; + var a_copy = try BigInt.init(allocator); + defer a_copy.deinit(); + try a_copy.copy(a); + + while (!a_copy.eqZero()) { + try a_copy.shiftRight(&a_copy, 1); + i += 1; + } + + return i; + } + // TODO: flush the toilet - const poop = std.heap.page_allocator; + pub const poop = std.heap.page_allocator; }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0e23101ee3..2eb5923187 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -536,7 +536,24 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) try sig.verify(verify_bytes, key); }, .rsa_pss_rsae_sha256 => { - @panic("TODO signature scheme: rsa_pss_rsae_sha256"); + if (main_cert_pub_key_algo != .rsaEncryption) + return error.TlsBadSignatureScheme; + + const Hash = crypto.hash.sha2.Sha256; + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop); + const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop); + }, + else => { + return error.TlsBadRsaSignatureBitCount; + }, + } }, else => { //std.debug.print("signature scheme: {any}\n", .{ @@ -737,7 +754,7 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } pub fn eof(c: Client) bool { - return c.received_close_notify and c.partial_ciphertext_end == 0; + return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end; } /// Returns the number of bytes read, calling the underlying read function the @@ -822,6 +839,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = 0; c.partial_ciphertext_end = 0; + } else { + std.debug.print("finished giving partial cleartext. {d} bytes ciphertext remain\n", .{ + c.partial_ciphertext_end - c.partial_ciphertext_idx, + }); } } @@ -866,8 +887,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // There might be more bytes inside `in_stack_buffer` that need to be processed, // but at least frag0 will have one complete ciphertext record. - const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)]; - var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len]; + const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); + const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; + var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; // We need to decipher frag0 and frag1 but there may be a ciphertext record // straddling the boundary. We can handle this with two memcpy() calls to // assemble the straddling record in between handling the two sides. @@ -900,12 +922,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const record_len = (record_len_byte_0 << 8) | record_len_byte_1; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - const second_len = record_len + tls.ciphertext_record_header_len - first.len; + const full_record_len = record_len + tls.ciphertext_record_header_len; + const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); mem.copy(u8, frag[0..in], first); mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; @@ -914,23 +938,35 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; - _ = legacy_version; + //_ = legacy_version; const record_len = mem.readIntBig(u16, frag[in..][0..2]); + std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{ + ct, legacy_version, record_len, + }); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; const end = in + record_len; if (end > frag.len) { + // We need the record header on the next iteration of the loop. + in -= tls.ciphertext_record_header_len; + if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. const first = frag[in..]; - const second_len = record_len + tls.ciphertext_record_header_len - first.len; - if (frag1.len < second_len) + const full_record_len = record_len + tls.ciphertext_record_header_len; + const second_len = full_record_len - first.len; + if (frag1.len < second_len) { + std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{ + end, frag.len, + }); return finishRead2(c, first, frag1, vp.total); + } mem.copy(u8, frag[0..in], first); mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; @@ -991,9 +1027,11 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { + std.debug.print("new_session_ticket\n", .{}); // This client implementation ignores new session tickets. }, .key_update => { + std.debug.print("key_update\n", .{}); switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); @@ -1042,10 +1080,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; mem.copy(u8, dest, msg); c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); + std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len}); } else { const amt = vp.put(msg); + std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len}); if (amt < msg.len) { const rest = msg[amt..]; + std.debug.print(" {d} bytes to partial buffer\n", .{rest.len}); c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); mem.copy(u8, &c.partially_read_buffer, rest); @@ -1055,6 +1096,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove // Output buffer was used directly which means no // memory copying needs to occur, and we can move // on to the next ciphertext record. + std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1}); vp.next(cleartext.len - 1); } }, @@ -1166,10 +1208,6 @@ const VecPut = struct { const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; mem.copy(u8, dest, src); bytes_i += src.len; - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } vp.off += src.len; if (vp.off >= v.iov_len) { vp.off = 0; @@ -1179,6 +1217,10 @@ const VecPut = struct { return bytes_i; } } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } } } @@ -1201,17 +1243,11 @@ const VecPut = struct { } fn freeSize(vp: VecPut) usize { + if (vp.idx >= vp.iovecs.len) return 0; var total: usize = 0; - total += vp.iovecs[vp.idx].iov_len - vp.off; - - if (vp.idx + 1 >= vp.iovecs.len) - return total; - - for (vp.iovecs[vp.idx + 1 ..]) |v| { - total += v.iov_len; - } - + if (vp.idx + 1 >= vp.iovecs.len) return total; + for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len; return total; } };