zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

blob 04bc6a82 (10315B) - Raw


      1 // SPDX-License-Identifier: MIT
      2 // Copyright (c) 2015-2020 Zig Contributors
      3 // This file is part of [zig](https://ziglang.org/), which is MIT licensed.
      4 // The MIT license requires this copyright notice to be included in all copies
      5 // and substantial portions of the software.
      6 //
      7 // Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin <pornin@bolet.org>
      8 
      9 const std = @import("../std.zig");
     10 const assert = std.debug.assert;
     11 const math = std.math;
     12 const mem = std.mem;
     13 
     14 /// GHASH is a universal hash function that features multiplication
     15 /// by a fixed parameter within a Galois field.
     16 ///
     17 /// It is not a general purpose hash function - The key must be secret, unpredictable and never reused.
     18 ///
     19 /// GHASH is typically used to compute the authentication tag in the AES-GCM construction.
     20 pub const Ghash = struct {
     21     pub const block_size: usize = 16;
     22     pub const mac_length = 16;
     23     pub const minimum_key_length = 16;
     24 
     25     y0: u64 = 0,
     26     y1: u64 = 0,
     27     h0: u64,
     28     h1: u64,
     29     h2: u64,
     30     h0r: u64,
     31     h1r: u64,
     32     h2r: u64,
     33 
     34     hh0: u64 = undefined,
     35     hh1: u64 = undefined,
     36     hh2: u64 = undefined,
     37     hh0r: u64 = undefined,
     38     hh1r: u64 = undefined,
     39     hh2r: u64 = undefined,
     40 
     41     leftover: usize = 0,
     42     buf: [block_size]u8 align(16) = undefined,
     43 
     44     pub fn init(key: *const [minimum_key_length]u8) Ghash {
     45         const h1 = mem.readIntBig(u64, key[0..8]);
     46         const h0 = mem.readIntBig(u64, key[8..16]);
     47         const h1r = @bitReverse(u64, h1);
     48         const h0r = @bitReverse(u64, h0);
     49         const h2 = h0 ^ h1;
     50         const h2r = h0r ^ h1r;
     51 
     52         if (std.builtin.mode == .ReleaseSmall) {
     53             return Ghash{
     54                 .h0 = h0,
     55                 .h1 = h1,
     56                 .h2 = h2,
     57                 .h0r = h0r,
     58                 .h1r = h1r,
     59                 .h2r = h2r,
     60             };
     61         } else {
     62             // Precompute H^2
     63             var hh = Ghash{
     64                 .h0 = h0,
     65                 .h1 = h1,
     66                 .h2 = h2,
     67                 .h0r = h0r,
     68                 .h1r = h1r,
     69                 .h2r = h2r,
     70             };
     71             hh.update(key);
     72             const hh1 = hh.y1;
     73             const hh0 = hh.y0;
     74             const hh1r = @bitReverse(u64, hh1);
     75             const hh0r = @bitReverse(u64, hh0);
     76             const hh2 = hh0 ^ hh1;
     77             const hh2r = hh0r ^ hh1r;
     78 
     79             return Ghash{
     80                 .h0 = h0,
     81                 .h1 = h1,
     82                 .h2 = h2,
     83                 .h0r = h0r,
     84                 .h1r = h1r,
     85                 .h2r = h2r,
     86 
     87                 .hh0 = hh0,
     88                 .hh1 = hh1,
     89                 .hh2 = hh2,
     90                 .hh0r = hh0r,
     91                 .hh1r = hh1r,
     92                 .hh2r = hh2r,
     93             };
     94         }
     95     }
     96 
     97     inline fn clmul_pclmul(x: u64, y: u64) u64 {
     98         const Vector = std.meta.Vector;
     99         const product = asm (
    100             \\ vpclmulqdq $0x00, %[x], %[y], %[out]
    101             : [out] "=x" (-> Vector(2, u64))
    102             : [x] "x" (@bitCast(Vector(2, u64), @as(u128, x))),
    103               [y] "x" (@bitCast(Vector(2, u64), @as(u128, y)))
    104         );
    105         return product[0];
    106     }
    107 
    108     fn clmul_soft(x: u64, y: u64) u64 {
    109         const x0 = x & 0x1111111111111111;
    110         const x1 = x & 0x2222222222222222;
    111         const x2 = x & 0x4444444444444444;
    112         const x3 = x & 0x8888888888888888;
    113         const y0 = y & 0x1111111111111111;
    114         const y1 = y & 0x2222222222222222;
    115         const y2 = y & 0x4444444444444444;
    116         const y3 = y & 0x8888888888888888;
    117         var z0 = (x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1);
    118         var z1 = (x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2);
    119         var z2 = (x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3);
    120         var z3 = (x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0);
    121         z0 &= 0x1111111111111111;
    122         z1 &= 0x2222222222222222;
    123         z2 &= 0x4444444444444444;
    124         z3 &= 0x8888888888888888;
    125         return z0 | z1 | z2 | z3;
    126     }
    127 
    128     const has_pclmul = comptime std.Target.x86.featureSetHas(std.Target.current.cpu.features, .pclmul);
    129     const has_avx = comptime std.Target.x86.featureSetHas(std.Target.current.cpu.features, .avx);
    130     const clmul = if (std.Target.current.cpu.arch == .x86_64 and has_pclmul and has_avx) clmul_pclmul else clmul_soft;
    131 
    132     fn blocks(st: *Ghash, msg: []const u8) void {
    133         assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
    134         var y1 = st.y1;
    135         var y0 = st.y0;
    136 
    137         var i: usize = 0;
    138 
    139         // 2-blocks aggregated reduction
    140         if (std.builtin.mode != .ReleaseSmall) {
    141             while (i + 32 <= msg.len) : (i += 32) {
    142                 // B0 * H^2 unreduced
    143                 y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
    144                 y0 ^= mem.readIntBig(u64, msg[i..][8..16]);
    145 
    146                 const y1r = @bitReverse(u64, y1);
    147                 const y0r = @bitReverse(u64, y0);
    148                 const y2 = y0 ^ y1;
    149                 const y2r = y0r ^ y1r;
    150 
    151                 var z0 = clmul(y0, st.hh0);
    152                 var z1 = clmul(y1, st.hh1);
    153                 var z2 = clmul(y2, st.hh2) ^ z0 ^ z1;
    154                 var z0h = clmul(y0r, st.hh0r);
    155                 var z1h = clmul(y1r, st.hh1r);
    156                 var z2h = clmul(y2r, st.hh2r) ^ z0h ^ z1h;
    157 
    158                 // B1 * H unreduced
    159                 const sy1 = mem.readIntBig(u64, msg[i..][16..24]);
    160                 const sy0 = mem.readIntBig(u64, msg[i..][24..32]);
    161 
    162                 const sy1r = @bitReverse(u64, sy1);
    163                 const sy0r = @bitReverse(u64, sy0);
    164                 const sy2 = sy0 ^ sy1;
    165                 const sy2r = sy0r ^ sy1r;
    166 
    167                 const sz0 = clmul(sy0, st.h0);
    168                 const sz1 = clmul(sy1, st.h1);
    169                 const sz2 = clmul(sy2, st.h2) ^ sz0 ^ sz1;
    170                 const sz0h = clmul(sy0r, st.h0r);
    171                 const sz1h = clmul(sy1r, st.h1r);
    172                 const sz2h = clmul(sy2r, st.h2r) ^ sz0h ^ sz1h;
    173 
    174                 // ((B0 * H^2) + B1 * H) (mod M)
    175                 z0 ^= sz0;
    176                 z1 ^= sz1;
    177                 z2 ^= sz2;
    178                 z0h ^= sz0h;
    179                 z1h ^= sz1h;
    180                 z2h ^= sz2h;
    181                 z0h = @bitReverse(u64, z0h) >> 1;
    182                 z1h = @bitReverse(u64, z1h) >> 1;
    183                 z2h = @bitReverse(u64, z2h) >> 1;
    184 
    185                 var v3 = z1h;
    186                 var v2 = z1 ^ z2h;
    187                 var v1 = z0h ^ z2;
    188                 var v0 = z0;
    189 
    190                 v3 = (v3 << 1) | (v2 >> 63);
    191                 v2 = (v2 << 1) | (v1 >> 63);
    192                 v1 = (v1 << 1) | (v0 >> 63);
    193                 v0 = (v0 << 1);
    194 
    195                 v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
    196                 v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
    197                 y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
    198                 y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
    199             }
    200         }
    201 
    202         // single block
    203         while (i + 16 <= msg.len) : (i += 16) {
    204             y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
    205             y0 ^= mem.readIntBig(u64, msg[i..][8..16]);
    206 
    207             const y1r = @bitReverse(u64, y1);
    208             const y0r = @bitReverse(u64, y0);
    209             const y2 = y0 ^ y1;
    210             const y2r = y0r ^ y1r;
    211 
    212             const z0 = clmul(y0, st.h0);
    213             const z1 = clmul(y1, st.h1);
    214             var z2 = clmul(y2, st.h2) ^ z0 ^ z1;
    215             var z0h = clmul(y0r, st.h0r);
    216             var z1h = clmul(y1r, st.h1r);
    217             var z2h = clmul(y2r, st.h2r) ^ z0h ^ z1h;
    218             z0h = @bitReverse(u64, z0h) >> 1;
    219             z1h = @bitReverse(u64, z1h) >> 1;
    220             z2h = @bitReverse(u64, z2h) >> 1;
    221 
    222             // shift & reduce
    223             var v3 = z1h;
    224             var v2 = z1 ^ z2h;
    225             var v1 = z0h ^ z2;
    226             var v0 = z0;
    227 
    228             v3 = (v3 << 1) | (v2 >> 63);
    229             v2 = (v2 << 1) | (v1 >> 63);
    230             v1 = (v1 << 1) | (v0 >> 63);
    231             v0 = (v0 << 1);
    232 
    233             v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
    234             v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
    235             y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
    236             y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
    237         }
    238         st.y1 = y1;
    239         st.y0 = y0;
    240     }
    241 
    242     pub fn update(st: *Ghash, m: []const u8) void {
    243         var mb = m;
    244 
    245         if (st.leftover > 0) {
    246             const want = math.min(block_size - st.leftover, mb.len);
    247             const mc = mb[0..want];
    248             for (mc) |x, i| {
    249                 st.buf[st.leftover + i] = x;
    250             }
    251             mb = mb[want..];
    252             st.leftover += want;
    253             if (st.leftover < block_size) {
    254                 return;
    255             }
    256             st.blocks(&st.buf);
    257             st.leftover = 0;
    258         }
    259         if (mb.len >= block_size) {
    260             const want = mb.len & ~(block_size - 1);
    261             st.blocks(mb[0..want]);
    262             mb = mb[want..];
    263         }
    264         if (mb.len > 0) {
    265             for (mb) |x, i| {
    266                 st.buf[st.leftover + i] = x;
    267             }
    268             st.leftover += mb.len;
    269         }
    270     }
    271 
    272     /// Zero-pad to align the next input to the first byte of a block
    273     pub fn pad(st: *Ghash) void {
    274         if (st.leftover == 0) {
    275             return;
    276         }
    277         var i = st.leftover;
    278         while (i < block_size) : (i += 1) {
    279             st.buf[i] = 0;
    280         }
    281         st.blocks(&st.buf);
    282         st.leftover = 0;
    283     }
    284 
    285     pub fn final(st: *Ghash, out: *[mac_length]u8) void {
    286         st.pad();
    287         mem.writeIntBig(u64, out[0..8], st.y1);
    288         mem.writeIntBig(u64, out[8..16], st.y0);
    289 
    290         mem.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]);
    291     }
    292 
    293     pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [minimum_key_length]u8) void {
    294         var st = Ghash.init(key);
    295         st.update(msg);
    296         st.final(out);
    297     }
    298 };
    299 
    300 const htest = @import("test.zig");
    301 
    302 test "ghash" {
    303     const key = [_]u8{0x42} ** 16;
    304     const m = [_]u8{0x69} ** 256;
    305 
    306     var st = Ghash.init(&key);
    307     st.update(&m);
    308     var out: [16]u8 = undefined;
    309     st.final(&out);
    310     htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
    311 
    312     st = Ghash.init(&key);
    313     st.update(m[0..100]);
    314     st.update(m[100..]);
    315     st.final(&out);
    316     htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
    317 }