Make Poly1305 faster by leveraging @addWithOverflow/@subWithOverflow (#15815)
These operations are constant-time on most, if not all currently supported architectures. However, even if they are not, this is not a big deal in the case on Poly1305, as the key is added at the end. The final addition remains protected. SalsaPoly and ChaChaPoly do encrypt-then-mac, so side channels would not leak anything about the plaintext anyway. * Apple Silicon (M1) Before: 2048 MiB/s After : 2823 MiB/s * AMD Ryzen 7 Before: 3165 MiB/s After : 4774 MiB/s
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
const std = @import("../std.zig");
|
||||
const utils = std.crypto.utils;
|
||||
const mem = std.mem;
|
||||
const mulWide = std.math.mulWide;
|
||||
|
||||
pub const Poly1305 = struct {
|
||||
pub const block_length: usize = 16;
|
||||
@@ -8,7 +9,7 @@ pub const Poly1305 = struct {
|
||||
pub const key_length = 32;
|
||||
|
||||
// constant multiplier (from the secret key)
|
||||
r: [3]u64,
|
||||
r: [2]u64,
|
||||
// accumulated hash
|
||||
h: [3]u64 = [_]u64{ 0, 0, 0 },
|
||||
// random number added at the end (from the secret key)
|
||||
@@ -19,13 +20,10 @@ pub const Poly1305 = struct {
|
||||
buf: [block_length]u8 align(16) = undefined,
|
||||
|
||||
pub fn init(key: *const [key_length]u8) Poly1305 {
|
||||
const t0 = mem.readIntLittle(u64, key[0..8]);
|
||||
const t1 = mem.readIntLittle(u64, key[8..16]);
|
||||
return Poly1305{
|
||||
.r = [_]u64{
|
||||
t0 & 0xffc0fffffff,
|
||||
((t0 >> 44) | (t1 << 20)) & 0xfffffc0ffff,
|
||||
((t1 >> 24)) & 0x00ffffffc0f,
|
||||
mem.readIntLittle(u64, key[0..8]) & 0x0ffffffc0fffffff,
|
||||
mem.readIntLittle(u64, key[8..16]) & 0x0ffffffc0ffffffc,
|
||||
},
|
||||
.pad = [_]u64{
|
||||
mem.readIntLittle(u64, key[16..24]),
|
||||
@@ -34,43 +32,77 @@ pub const Poly1305 = struct {
|
||||
};
|
||||
}
|
||||
|
||||
inline fn add(a: u64, b: u64, c: u1) struct { u64, u1 } {
|
||||
const v1 = @addWithOverflow(a, b);
|
||||
const v2 = @addWithOverflow(v1[0], c);
|
||||
return .{ v2[0], v1[1] | v2[1] };
|
||||
}
|
||||
|
||||
inline fn sub(a: u64, b: u64, c: u1) struct { u64, u1 } {
|
||||
const v1 = @subWithOverflow(a, b);
|
||||
const v2 = @subWithOverflow(v1[0], c);
|
||||
return .{ v2[0], v1[1] | v2[1] };
|
||||
}
|
||||
|
||||
fn blocks(st: *Poly1305, m: []const u8, comptime last: bool) void {
|
||||
const hibit: u64 = if (last) 0 else 1 << 40;
|
||||
const hibit: u64 = if (last) 0 else 1;
|
||||
const r0 = st.r[0];
|
||||
const r1 = st.r[1];
|
||||
const r2 = st.r[2];
|
||||
|
||||
var h0 = st.h[0];
|
||||
var h1 = st.h[1];
|
||||
var h2 = st.h[2];
|
||||
const s1 = r1 * (5 << 2);
|
||||
const s2 = r2 * (5 << 2);
|
||||
|
||||
var i: usize = 0;
|
||||
|
||||
while (i + block_length <= m.len) : (i += block_length) {
|
||||
// h += m[i]
|
||||
const t0 = mem.readIntLittle(u64, m[i..][0..8]);
|
||||
const t1 = mem.readIntLittle(u64, m[i + 8 ..][0..8]);
|
||||
h0 += @truncate(u44, t0);
|
||||
h1 += @truncate(u44, (t0 >> 44) | (t1 << 20));
|
||||
h2 += @truncate(u42, t1 >> 24) | hibit;
|
||||
const in0 = mem.readIntLittle(u64, m[i..][0..8]);
|
||||
const in1 = mem.readIntLittle(u64, m[i + 8 ..][0..8]);
|
||||
|
||||
// h *= r
|
||||
const d0 = @as(u128, h0) * r0 + @as(u128, h1) * s2 + @as(u128, h2) * s1;
|
||||
var d1 = @as(u128, h0) * r1 + @as(u128, h1) * r0 + @as(u128, h2) * s2;
|
||||
var d2 = @as(u128, h0) * r2 + @as(u128, h1) * r1 + @as(u128, h2) * r0;
|
||||
// Add the input message to H
|
||||
var v = @addWithOverflow(h0, in0);
|
||||
h0 = v[0];
|
||||
v = add(h1, in1, v[1]);
|
||||
h1 = v[0];
|
||||
h2 +%= v[1] +% hibit;
|
||||
|
||||
// partial reduction
|
||||
var carry = @intCast(u64, d0 >> 44);
|
||||
h0 = @truncate(u44, d0);
|
||||
d1 += carry;
|
||||
carry = @intCast(u64, d1 >> 44);
|
||||
h1 = @truncate(u44, d1);
|
||||
d2 += carry;
|
||||
carry = @intCast(u64, d2 >> 42);
|
||||
h2 = @truncate(u42, d2);
|
||||
h0 += @truncate(u64, carry) * 5;
|
||||
carry = h0 >> 44;
|
||||
h0 = @truncate(u44, h0);
|
||||
h1 += carry;
|
||||
// Compute H * R
|
||||
const m0 = mulWide(u64, h0, r0);
|
||||
const h1r0 = mulWide(u64, h1, r0);
|
||||
const h0r1 = mulWide(u64, h0, r1);
|
||||
const h2r0 = mulWide(u64, h2, r0);
|
||||
const h1r1 = mulWide(u64, h1, r1);
|
||||
const m3 = mulWide(u64, h2, r1);
|
||||
const m1 = h1r0 +% h0r1;
|
||||
const m2 = h2r0 +% h1r1;
|
||||
|
||||
const t0 = @truncate(u64, m0);
|
||||
v = @addWithOverflow(@truncate(u64, m1), @truncate(u64, m0 >> 64));
|
||||
const t1 = v[0];
|
||||
v = add(@truncate(u64, m2), @truncate(u64, m1 >> 64), v[1]);
|
||||
const t2 = v[0];
|
||||
v = add(@truncate(u64, m3), @truncate(u64, m2 >> 64), v[1]);
|
||||
const t3 = v[0];
|
||||
|
||||
// Partial reduction
|
||||
h0 = t0;
|
||||
h1 = t1;
|
||||
h2 = t2 & 3;
|
||||
|
||||
// Add c*(4+1)
|
||||
var cclo = t2 & ~@as(u64, 3);
|
||||
var cchi = t3;
|
||||
v = @addWithOverflow(h0, cclo);
|
||||
h0 = v[0];
|
||||
v = add(h1, cchi, v[1]);
|
||||
h1 = v[0];
|
||||
h2 +%= v[1];
|
||||
const cc = (cclo | (@as(u128, cchi) << 64)) >> 2;
|
||||
v = @addWithOverflow(h0, @truncate(u64, cc));
|
||||
h0 = v[0];
|
||||
v = add(h1, @truncate(u64, cc >> 64), v[1]);
|
||||
h1 = v[0];
|
||||
h2 +%= v[1];
|
||||
}
|
||||
st.h = [_]u64{ h0, h1, h2 };
|
||||
}
|
||||
@@ -115,10 +147,7 @@ pub const Poly1305 = struct {
|
||||
if (st.leftover == 0) {
|
||||
return;
|
||||
}
|
||||
var i = st.leftover;
|
||||
while (i < block_length) : (i += 1) {
|
||||
st.buf[i] = 0;
|
||||
}
|
||||
@memset(st.buf[st.leftover..], 0);
|
||||
st.blocks(&st.buf);
|
||||
st.leftover = 0;
|
||||
}
|
||||
@@ -128,65 +157,30 @@ pub const Poly1305 = struct {
|
||||
var i = st.leftover;
|
||||
st.buf[i] = 1;
|
||||
i += 1;
|
||||
while (i < block_length) : (i += 1) {
|
||||
st.buf[i] = 0;
|
||||
}
|
||||
@memset(st.buf[i..], 0);
|
||||
st.blocks(&st.buf, true);
|
||||
}
|
||||
// fully carry h
|
||||
var carry = st.h[1] >> 44;
|
||||
st.h[1] = @truncate(u44, st.h[1]);
|
||||
st.h[2] += carry;
|
||||
carry = st.h[2] >> 42;
|
||||
st.h[2] = @truncate(u42, st.h[2]);
|
||||
st.h[0] += carry * 5;
|
||||
carry = st.h[0] >> 44;
|
||||
st.h[0] = @truncate(u44, st.h[0]);
|
||||
st.h[1] += carry;
|
||||
carry = st.h[1] >> 44;
|
||||
st.h[1] = @truncate(u44, st.h[1]);
|
||||
st.h[2] += carry;
|
||||
carry = st.h[2] >> 42;
|
||||
st.h[2] = @truncate(u42, st.h[2]);
|
||||
st.h[0] += carry * 5;
|
||||
carry = st.h[0] >> 44;
|
||||
st.h[0] = @truncate(u44, st.h[0]);
|
||||
st.h[1] += carry;
|
||||
|
||||
// compute h + -p
|
||||
var g0 = st.h[0] + 5;
|
||||
carry = g0 >> 44;
|
||||
g0 = @truncate(u44, g0);
|
||||
var g1 = st.h[1] + carry;
|
||||
carry = g1 >> 44;
|
||||
g1 = @truncate(u44, g1);
|
||||
var g2 = st.h[2] + carry -% (1 << 42);
|
||||
var h0 = st.h[0];
|
||||
var h1 = st.h[1];
|
||||
var h2 = st.h[2];
|
||||
|
||||
// (hopefully) constant-time select h if h < p, or h + -p if h >= p
|
||||
const mask = (g2 >> 63) -% 1;
|
||||
g0 &= mask;
|
||||
g1 &= mask;
|
||||
g2 &= mask;
|
||||
const nmask = ~mask;
|
||||
st.h[0] = (st.h[0] & nmask) | g0;
|
||||
st.h[1] = (st.h[1] & nmask) | g1;
|
||||
st.h[2] = (st.h[2] & nmask) | g2;
|
||||
// H - (2^130 - 5)
|
||||
var v = sub(h0, 0xfffffffffffffffb, 0);
|
||||
const h_p0 = v[0];
|
||||
v = sub(h1, 0xffffffffffffffff, v[1]);
|
||||
const h_p1 = v[0];
|
||||
v = sub(h2, 0x0000000000000003, v[1]);
|
||||
|
||||
// h = (h + pad)
|
||||
const t0 = st.pad[0];
|
||||
const t1 = st.pad[1];
|
||||
st.h[0] += @truncate(u44, t0);
|
||||
carry = st.h[0] >> 44;
|
||||
st.h[0] = @truncate(u44, st.h[0]);
|
||||
st.h[1] += @truncate(u44, (t0 >> 44) | (t1 << 20)) + carry;
|
||||
carry = st.h[1] >> 44;
|
||||
st.h[1] = @truncate(u44, st.h[1]);
|
||||
st.h[2] += @truncate(u42, t1 >> 24) + carry;
|
||||
st.h[2] = @truncate(u42, st.h[2]);
|
||||
// Final reduction, subtract 2^130-5 from H if H >= 2^130-5
|
||||
const mask = v[1] -% 1;
|
||||
h0 ^= mask & (h0 ^ h_p0);
|
||||
h1 ^= mask & (h1 ^ h_p1);
|
||||
|
||||
// mac = h % (2^128)
|
||||
st.h[0] |= st.h[1] << 44;
|
||||
st.h[1] = (st.h[1] >> 20) | (st.h[2] << 24);
|
||||
// Add the first half of the key, we intentionally don't use @addWithOverflow() here.
|
||||
st.h[0] = h0 +% st.pad[0];
|
||||
const c = ((h0 & st.pad[0]) | ((h0 | st.pad[0]) & ~st.h[0])) >> 63;
|
||||
st.h[1] = h1 +% st.pad[1] +% c;
|
||||
|
||||
mem.writeIntLittle(u64, out[0..8], st.h[0]);
|
||||
mem.writeIntLittle(u64, out[8..16], st.h[1]);
|
||||
|
||||
Reference in New Issue
Block a user