From 9d179a98f69dbab393cbb3fc5dd4b64c553a721b Mon Sep 17 00:00:00 2001 From: Frank Denis <124872+jedisct1@users.noreply.github.com> Date: Tue, 23 May 2023 09:55:45 +0200 Subject: [PATCH] Make Poly1305 faster by leveraging @addWithOverflow/@subWithOverflow (#15815) These operations are constant-time on most, if not all currently supported architectures. However, even if they are not, this is not a big deal in the case on Poly1305, as the key is added at the end. The final addition remains protected. SalsaPoly and ChaChaPoly do encrypt-then-mac, so side channels would not leak anything about the plaintext anyway. * Apple Silicon (M1) Before: 2048 MiB/s After : 2823 MiB/s * AMD Ryzen 7 Before: 3165 MiB/s After : 4774 MiB/s --- lib/std/crypto/poly1305.zig | 174 +++++++++++++++++------------------- 1 file changed, 84 insertions(+), 90 deletions(-) diff --git a/lib/std/crypto/poly1305.zig b/lib/std/crypto/poly1305.zig index e99cf144d5..014cf651eb 100644 --- a/lib/std/crypto/poly1305.zig +++ b/lib/std/crypto/poly1305.zig @@ -1,6 +1,7 @@ const std = @import("../std.zig"); const utils = std.crypto.utils; const mem = std.mem; +const mulWide = std.math.mulWide; pub const Poly1305 = struct { pub const block_length: usize = 16; @@ -8,7 +9,7 @@ pub const Poly1305 = struct { pub const key_length = 32; // constant multiplier (from the secret key) - r: [3]u64, + r: [2]u64, // accumulated hash h: [3]u64 = [_]u64{ 0, 0, 0 }, // random number added at the end (from the secret key) @@ -19,13 +20,10 @@ pub const Poly1305 = struct { buf: [block_length]u8 align(16) = undefined, pub fn init(key: *const [key_length]u8) Poly1305 { - const t0 = mem.readIntLittle(u64, key[0..8]); - const t1 = mem.readIntLittle(u64, key[8..16]); return Poly1305{ .r = [_]u64{ - t0 & 0xffc0fffffff, - ((t0 >> 44) | (t1 << 20)) & 0xfffffc0ffff, - ((t1 >> 24)) & 0x00ffffffc0f, + mem.readIntLittle(u64, key[0..8]) & 0x0ffffffc0fffffff, + mem.readIntLittle(u64, key[8..16]) & 0x0ffffffc0ffffffc, }, .pad = [_]u64{ mem.readIntLittle(u64, key[16..24]), @@ -34,43 +32,77 @@ pub const Poly1305 = struct { }; } + inline fn add(a: u64, b: u64, c: u1) struct { u64, u1 } { + const v1 = @addWithOverflow(a, b); + const v2 = @addWithOverflow(v1[0], c); + return .{ v2[0], v1[1] | v2[1] }; + } + + inline fn sub(a: u64, b: u64, c: u1) struct { u64, u1 } { + const v1 = @subWithOverflow(a, b); + const v2 = @subWithOverflow(v1[0], c); + return .{ v2[0], v1[1] | v2[1] }; + } + fn blocks(st: *Poly1305, m: []const u8, comptime last: bool) void { - const hibit: u64 = if (last) 0 else 1 << 40; + const hibit: u64 = if (last) 0 else 1; const r0 = st.r[0]; const r1 = st.r[1]; - const r2 = st.r[2]; + var h0 = st.h[0]; var h1 = st.h[1]; var h2 = st.h[2]; - const s1 = r1 * (5 << 2); - const s2 = r2 * (5 << 2); + var i: usize = 0; + while (i + block_length <= m.len) : (i += block_length) { - // h += m[i] - const t0 = mem.readIntLittle(u64, m[i..][0..8]); - const t1 = mem.readIntLittle(u64, m[i + 8 ..][0..8]); - h0 += @truncate(u44, t0); - h1 += @truncate(u44, (t0 >> 44) | (t1 << 20)); - h2 += @truncate(u42, t1 >> 24) | hibit; + const in0 = mem.readIntLittle(u64, m[i..][0..8]); + const in1 = mem.readIntLittle(u64, m[i + 8 ..][0..8]); - // h *= r - const d0 = @as(u128, h0) * r0 + @as(u128, h1) * s2 + @as(u128, h2) * s1; - var d1 = @as(u128, h0) * r1 + @as(u128, h1) * r0 + @as(u128, h2) * s2; - var d2 = @as(u128, h0) * r2 + @as(u128, h1) * r1 + @as(u128, h2) * r0; + // Add the input message to H + var v = @addWithOverflow(h0, in0); + h0 = v[0]; + v = add(h1, in1, v[1]); + h1 = v[0]; + h2 +%= v[1] +% hibit; - // partial reduction - var carry = @intCast(u64, d0 >> 44); - h0 = @truncate(u44, d0); - d1 += carry; - carry = @intCast(u64, d1 >> 44); - h1 = @truncate(u44, d1); - d2 += carry; - carry = @intCast(u64, d2 >> 42); - h2 = @truncate(u42, d2); - h0 += @truncate(u64, carry) * 5; - carry = h0 >> 44; - h0 = @truncate(u44, h0); - h1 += carry; + // Compute H * R + const m0 = mulWide(u64, h0, r0); + const h1r0 = mulWide(u64, h1, r0); + const h0r1 = mulWide(u64, h0, r1); + const h2r0 = mulWide(u64, h2, r0); + const h1r1 = mulWide(u64, h1, r1); + const m3 = mulWide(u64, h2, r1); + const m1 = h1r0 +% h0r1; + const m2 = h2r0 +% h1r1; + + const t0 = @truncate(u64, m0); + v = @addWithOverflow(@truncate(u64, m1), @truncate(u64, m0 >> 64)); + const t1 = v[0]; + v = add(@truncate(u64, m2), @truncate(u64, m1 >> 64), v[1]); + const t2 = v[0]; + v = add(@truncate(u64, m3), @truncate(u64, m2 >> 64), v[1]); + const t3 = v[0]; + + // Partial reduction + h0 = t0; + h1 = t1; + h2 = t2 & 3; + + // Add c*(4+1) + var cclo = t2 & ~@as(u64, 3); + var cchi = t3; + v = @addWithOverflow(h0, cclo); + h0 = v[0]; + v = add(h1, cchi, v[1]); + h1 = v[0]; + h2 +%= v[1]; + const cc = (cclo | (@as(u128, cchi) << 64)) >> 2; + v = @addWithOverflow(h0, @truncate(u64, cc)); + h0 = v[0]; + v = add(h1, @truncate(u64, cc >> 64), v[1]); + h1 = v[0]; + h2 +%= v[1]; } st.h = [_]u64{ h0, h1, h2 }; } @@ -115,10 +147,7 @@ pub const Poly1305 = struct { if (st.leftover == 0) { return; } - var i = st.leftover; - while (i < block_length) : (i += 1) { - st.buf[i] = 0; - } + @memset(st.buf[st.leftover..], 0); st.blocks(&st.buf); st.leftover = 0; } @@ -128,65 +157,30 @@ pub const Poly1305 = struct { var i = st.leftover; st.buf[i] = 1; i += 1; - while (i < block_length) : (i += 1) { - st.buf[i] = 0; - } + @memset(st.buf[i..], 0); st.blocks(&st.buf, true); } - // fully carry h - var carry = st.h[1] >> 44; - st.h[1] = @truncate(u44, st.h[1]); - st.h[2] += carry; - carry = st.h[2] >> 42; - st.h[2] = @truncate(u42, st.h[2]); - st.h[0] += carry * 5; - carry = st.h[0] >> 44; - st.h[0] = @truncate(u44, st.h[0]); - st.h[1] += carry; - carry = st.h[1] >> 44; - st.h[1] = @truncate(u44, st.h[1]); - st.h[2] += carry; - carry = st.h[2] >> 42; - st.h[2] = @truncate(u42, st.h[2]); - st.h[0] += carry * 5; - carry = st.h[0] >> 44; - st.h[0] = @truncate(u44, st.h[0]); - st.h[1] += carry; - // compute h + -p - var g0 = st.h[0] + 5; - carry = g0 >> 44; - g0 = @truncate(u44, g0); - var g1 = st.h[1] + carry; - carry = g1 >> 44; - g1 = @truncate(u44, g1); - var g2 = st.h[2] + carry -% (1 << 42); + var h0 = st.h[0]; + var h1 = st.h[1]; + var h2 = st.h[2]; - // (hopefully) constant-time select h if h < p, or h + -p if h >= p - const mask = (g2 >> 63) -% 1; - g0 &= mask; - g1 &= mask; - g2 &= mask; - const nmask = ~mask; - st.h[0] = (st.h[0] & nmask) | g0; - st.h[1] = (st.h[1] & nmask) | g1; - st.h[2] = (st.h[2] & nmask) | g2; + // H - (2^130 - 5) + var v = sub(h0, 0xfffffffffffffffb, 0); + const h_p0 = v[0]; + v = sub(h1, 0xffffffffffffffff, v[1]); + const h_p1 = v[0]; + v = sub(h2, 0x0000000000000003, v[1]); - // h = (h + pad) - const t0 = st.pad[0]; - const t1 = st.pad[1]; - st.h[0] += @truncate(u44, t0); - carry = st.h[0] >> 44; - st.h[0] = @truncate(u44, st.h[0]); - st.h[1] += @truncate(u44, (t0 >> 44) | (t1 << 20)) + carry; - carry = st.h[1] >> 44; - st.h[1] = @truncate(u44, st.h[1]); - st.h[2] += @truncate(u42, t1 >> 24) + carry; - st.h[2] = @truncate(u42, st.h[2]); + // Final reduction, subtract 2^130-5 from H if H >= 2^130-5 + const mask = v[1] -% 1; + h0 ^= mask & (h0 ^ h_p0); + h1 ^= mask & (h1 ^ h_p1); - // mac = h % (2^128) - st.h[0] |= st.h[1] << 44; - st.h[1] = (st.h[1] >> 20) | (st.h[2] << 24); + // Add the first half of the key, we intentionally don't use @addWithOverflow() here. + st.h[0] = h0 +% st.pad[0]; + const c = ((h0 & st.pad[0]) | ((h0 | st.pad[0]) & ~st.h[0])) >> 63; + st.h[1] = h1 +% st.pad[1] +% c; mem.writeIntLittle(u64, out[0..8], st.h[0]); mem.writeIntLittle(u64, out[8..16], st.h[1]);