blob 04bc6a82 (10315B) - Raw
// SPDX-License-Identifier: MIT
// Copyright (c) 2015-2020 Zig Contributors
// This file is part of [zig](https://ziglang.org/), which is MIT licensed.
// The MIT license requires this copyright notice to be included in all copies
// and substantial portions of the software.
//
// Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin <pornin@bolet.org>

const std = @import("../std.zig");
const assert = std.debug.assert;
const math = std.math;
const mem = std.mem;

/// GHASH is a universal hash function that features multiplication
/// by a fixed parameter within a Galois field.
///
/// It is not a general purpose hash function - the key must be secret, unpredictable and never reused.
///
/// GHASH is typically used to compute the authentication tag in the AES-GCM construction.
pub const Ghash = struct {
    pub const block_size: usize = 16;
    pub const mac_length = 16;
    pub const minimum_key_length = 16;

    // Running accumulator, high (y1) and low (y0) 64-bit halves; starts at zero.
    y0: u64 = 0,
    y1: u64 = 0,
    // Key H split into big-endian halves, their XOR (Karatsuba middle operand),
    // and the bit-reversed forms used to obtain the upper halves of products.
    h0: u64,
    h1: u64,
    h2: u64,
    h0r: u64,
    h1r: u64,
    h2r: u64,

    // Same layout for H^2. Only populated by `init`, and only read by `blocks`,
    // when the build mode is not ReleaseSmall.
    hh0: u64 = undefined,
    hh1: u64 = undefined,
    hh2: u64 = undefined,
    hh0r: u64 = undefined,
    hh1r: u64 = undefined,
    hh2r: u64 = undefined,

    // Number of bytes in `buf` that have not yet formed a complete block.
    leftover: usize = 0,
    buf: [block_size]u8 align(16) = undefined,

    /// Create a GHASH state keyed with the 16-byte `key` (read as two big-endian halves).
    /// Outside ReleaseSmall, H^2 is also precomputed so `blocks` can absorb two blocks per reduction.
    pub fn init(key: *const [minimum_key_length]u8) Ghash {
        const h1 = mem.readIntBig(u64, key[0..8]);
        const h0 = mem.readIntBig(u64, key[8..16]);
        const h1r = @bitReverse(u64, h1);
        const h0r = @bitReverse(u64, h0);
        const h2 = h0 ^ h1;
        const h2r = h0r ^ h1r;

        if (std.builtin.mode == .ReleaseSmall) {
            return Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,
            };
        } else {
            // Precompute H^2: absorbing the key itself into a zero accumulator yields H * H.
            var hh = Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,
            };
            hh.update(key);
            const hh1 = hh.y1;
            const hh0 = hh.y0;
            // The temporary state holds key-derived secrets; wipe it, as `final` does for the real state.
            mem.secureZero(u8, @ptrCast([*]u8, &hh)[0..@sizeOf(Ghash)]);

            const hh1r = @bitReverse(u64, hh1);
            const hh0r = @bitReverse(u64, hh0);
            const hh2 = hh0 ^ hh1;
            const hh2r = hh0r ^ hh1r;

            return Ghash{
                .h0 = h0,
                .h1 = h1,
                .h2 = h2,
                .h0r = h0r,
                .h1r = h1r,
                .h2r = h2r,

                .hh0 = hh0,
                .hh1 = hh1,
                .hh2 = hh2,
                .hh0r = hh0r,
                .hh1r = hh1r,
                .hh2r = hh2r,
            };
        }
    }

    /// Low 64 bits of the carry-less product of `x` and `y`, via the VEX-encoded
    /// PCLMULQDQ instruction (selector 0x00 multiplies the low lanes).
    inline fn clmul_pclmul(x: u64, y: u64) u64 {
        const Vector = std.meta.Vector;
        const product = asm (
            \\ vpclmulqdq $0x00, %[x], %[y], %[out]
            : [out] "=x" (-> Vector(2, u64))
            : [x] "x" (@bitCast(Vector(2, u64), @as(u128, x))),
              [y] "x" (@bitCast(Vector(2, u64), @as(u128, y)))
        );
        return product[0];
    }

    /// Portable low-64-bit carry-less multiply. Each operand is split into four
    /// interleaved bit lanes so that ordinary integer multiplication cannot carry
    /// into the lanes that are kept; no branches or table lookups depend on the data.
    fn clmul_soft(x: u64, y: u64) u64 {
        const x0 = x & 0x1111111111111111;
        const x1 = x & 0x2222222222222222;
        const x2 = x & 0x4444444444444444;
        const x3 = x & 0x8888888888888888;
        const y0 = y & 0x1111111111111111;
        const y1 = y & 0x2222222222222222;
        const y2 = y & 0x4444444444444444;
        const y3 = y & 0x8888888888888888;
        var z0 = (x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1);
        var z1 = (x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2);
        var z2 = (x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3);
        var z3 = (x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0);
        z0 &= 0x1111111111111111;
        z1 &= 0x2222222222222222;
        z2 &= 0x4444444444444444;
        z3 &= 0x8888888888888888;
        return z0 | z1 | z2 | z3;
    }

    // Use the hardware multiply only on x86_64 targets that have both PCLMUL and
    // AVX (the `vpclmulqdq` mnemonic above is the AVX/VEX encoding).
    const has_pclmul = comptime std.Target.x86.featureSetHas(std.Target.current.cpu.features, .pclmul);
    const has_avx = comptime std.Target.x86.featureSetHas(std.Target.current.cpu.features, .avx);
    const clmul = if (std.Target.current.cpu.arch == .x86_64 and has_pclmul and has_avx) clmul_pclmul else clmul_soft;

    /// Combine the three Karatsuba partial products into the 256-bit product,
    /// shift it left by one bit, and fold the low 128 bits into the high 128 bits.
    /// `z0`/`z1`/`z2` are low product halves; `z*hr` are the products of the
    /// bit-reversed operands, still bit-reversed. Returns `.{ y1, y0 }`.
    inline fn reduce(z0: u64, z1: u64, z2: u64, z0hr: u64, z1hr: u64, z2hr: u64) [2]u64 {
        // Reversing the product of reversed operands and shifting right by one
        // yields the upper 64 bits of each carry-less product.
        const z0h = @bitReverse(u64, z0hr) >> 1;
        const z1h = @bitReverse(u64, z1hr) >> 1;
        const z2h = @bitReverse(u64, z2hr) >> 1;

        // 256-bit product as v3:v2:v1:v0.
        var v3 = z1h;
        var v2 = z1 ^ z2h;
        var v1 = z0h ^ z2;
        var v0 = z0;

        // Shift the whole product left by one bit.
        v3 = (v3 << 1) | (v2 >> 63);
        v2 = (v2 << 1) | (v1 >> 63);
        v1 = (v1 << 1) | (v0 >> 63);
        v0 = (v0 << 1);

        // Reduction: the 1/2/7 and 63/62/57 shift amounts fold v1:v0 into v3:v2.
        v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
        v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
        return [2]u64{
            v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7),
            v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57),
        };
    }

    /// Absorb `msg` into the accumulator. `msg.len` must be a multiple of 16.
    fn blocks(st: *Ghash, msg: []const u8) void {
        assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
        var y1 = st.y1;
        var y0 = st.y0;

        var i: usize = 0;

        // 2-blocks aggregated reduction: ((Y ^ B0) * H^2) ^ (B1 * H), reduced once.
        if (std.builtin.mode != .ReleaseSmall) {
            while (i + 32 <= msg.len) : (i += 32) {
                // B0 * H^2 unreduced
                y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
                y0 ^= mem.readIntBig(u64, msg[i..][8..16]);

                const y1r = @bitReverse(u64, y1);
                const y0r = @bitReverse(u64, y0);
                const y2 = y0 ^ y1;
                const y2r = y0r ^ y1r;

                var z0 = clmul(y0, st.hh0);
                var z1 = clmul(y1, st.hh1);
                var z2 = clmul(y2, st.hh2) ^ z0 ^ z1;
                var z0h = clmul(y0r, st.hh0r);
                var z1h = clmul(y1r, st.hh1r);
                var z2h = clmul(y2r, st.hh2r) ^ z0h ^ z1h;

                // B1 * H unreduced
                const sy1 = mem.readIntBig(u64, msg[i..][16..24]);
                const sy0 = mem.readIntBig(u64, msg[i..][24..32]);

                const sy1r = @bitReverse(u64, sy1);
                const sy0r = @bitReverse(u64, sy0);
                const sy2 = sy0 ^ sy1;
                const sy2r = sy0r ^ sy1r;

                const sz0 = clmul(sy0, st.h0);
                const sz1 = clmul(sy1, st.h1);
                const sz2 = clmul(sy2, st.h2) ^ sz0 ^ sz1;
                const sz0h = clmul(sy0r, st.h0r);
                const sz1h = clmul(sy1r, st.h1r);
                const sz2h = clmul(sy2r, st.h2r) ^ sz0h ^ sz1h;

                // ((B0 * H^2) + B1 * H) (mod M)
                z0 ^= sz0;
                z1 ^= sz1;
                z2 ^= sz2;
                z0h ^= sz0h;
                z1h ^= sz1h;
                z2h ^= sz2h;

                const r = reduce(z0, z1, z2, z0h, z1h, z2h);
                y1 = r[0];
                y0 = r[1];
            }
        }

        // single block: (Y ^ B) * H
        while (i + 16 <= msg.len) : (i += 16) {
            y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
            y0 ^= mem.readIntBig(u64, msg[i..][8..16]);

            const y1r = @bitReverse(u64, y1);
            const y0r = @bitReverse(u64, y0);
            const y2 = y0 ^ y1;
            const y2r = y0r ^ y1r;

            const z0 = clmul(y0, st.h0);
            const z1 = clmul(y1, st.h1);
            const z2 = clmul(y2, st.h2) ^ z0 ^ z1;
            const z0h = clmul(y0r, st.h0r);
            const z1h = clmul(y1r, st.h1r);
            const z2h = clmul(y2r, st.h2r) ^ z0h ^ z1h;

            const r = reduce(z0, z1, z2, z0h, z1h, z2h);
            y1 = r[0];
            y0 = r[1];
        }
        st.y1 = y1;
        st.y0 = y0;
    }

    /// Absorb `m`. Whole blocks are processed immediately; a trailing partial
    /// block is buffered until more input arrives or `pad`/`final` is called.
    pub fn update(st: *Ghash, m: []const u8) void {
        var mb = m;

        // First top up a partial block left over from a previous call.
        if (st.leftover > 0) {
            const want = math.min(block_size - st.leftover, mb.len);
            mem.copy(u8, st.buf[st.leftover..], mb[0..want]);
            mb = mb[want..];
            st.leftover += want;
            if (st.leftover < block_size) {
                return;
            }
            st.blocks(&st.buf);
            st.leftover = 0;
        }
        // Process every complete block straight from the input (block_size is a power of two).
        if (mb.len >= block_size) {
            const want = mb.len & ~(block_size - 1);
            st.blocks(mb[0..want]);
            mb = mb[want..];
        }
        // Buffer the remaining tail, which is shorter than one block.
        if (mb.len > 0) {
            mem.copy(u8, st.buf[st.leftover..], mb);
            st.leftover += mb.len;
        }
    }

    /// Zero-pad to align the next input to the first byte of a block.
    /// Does nothing if no partial block is buffered.
    pub fn pad(st: *Ghash) void {
        if (st.leftover == 0) {
            return;
        }
        mem.set(u8, st.buf[st.leftover..], 0);
        st.blocks(&st.buf);
        st.leftover = 0;
    }

    /// Pad any buffered input and write the 16-byte tag (y1 then y0, big-endian) to `out`.
    /// The whole state is securely zeroed afterwards, so it must not be reused.
    pub fn final(st: *Ghash, out: *[mac_length]u8) void {
        st.pad();
        mem.writeIntBig(u64, out[0..8], st.y1);
        mem.writeIntBig(u64, out[8..16], st.y0);

        mem.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]);
    }

    /// One-shot helper: compute the GHASH of `msg` under `key` into `out`.
    pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [minimum_key_length]u8) void {
        var st = Ghash.init(key);
        st.update(msg);
        st.final(out);
    }
};

const htest = @import("test.zig");

test "ghash" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    var st = Ghash.init(&key);
    st.update(&m);
    var out: [16]u8 = undefined;
    st.final(&out);
    htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);

    st = Ghash.init(&key);
    st.update(m[0..100]);
    st.update(m[100..]);
    st.final(&out);
    htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
}

test "ghash: uneven chunks straddling block boundaries" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    // Chunk sizes 1, 2, 3, ... exercise the leftover-buffer paths in `update`.
    var st = Ghash.init(&key);
    var i: usize = 0;
    var chunk: usize = 1;
    while (i < m.len) : (chunk += 1) {
        const end = math.min(i + chunk, m.len);
        st.update(m[i..end]);
        i = end;
    }
    var out: [16]u8 = undefined;
    st.final(&out);
    htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
}