zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit cb4e087fda07a3bec296d1e75da6416112fa2fd1 (tree)
parent bd24e66379612a34313206380df1999afcbe95b8
Author: Andrew Kelley <andrew@ziglang.org>
Date:   Mon, 11 Mar 2024 18:48:08 -0700

Merge pull request #19239 from jedisct1/ml-kem

std.crypto: add support for ML-KEM
Diffstat:
Mlib/std/crypto.zig | 3++-
Dlib/std/crypto/kyber_d00.zig | 1783-------------------------------------------------------------------------------
Alib/std/crypto/ml_kem.zig | 1830+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 1832 insertions(+), 1784 deletions(-)

diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig @@ -72,7 +72,8 @@ pub const dh = struct { /// Key Encapsulation Mechanisms. pub const kem = struct { - pub const kyber_d00 = @import("crypto/kyber_d00.zig"); + pub const kyber_d00 = @import("crypto/ml_kem.zig").kyber_d00; + pub const ml_kem_01 = @import("crypto/ml_kem.zig").ml_kem_01; }; /// Elliptic-curve arithmetic. diff --git a/lib/std/crypto/kyber_d00.zig b/lib/std/crypto/kyber_d00.zig @@ -1,1783 +0,0 @@ -//! Implementation of the IND-CCA2 post-quantum secure key encapsulation -//! mechanism (KEM) CRYSTALS-Kyber, as submitted to the third round of the NIST -//! Post-Quantum Cryptography (v3.02/"draft00"), and selected for standardisation. -//! -//! Kyber will likely change before final standardisation. -//! -//! The namespace suffix (currently `_d00`) refers to the version currently -//! implemented, in accordance with the draft. It may not be updated if new -//! versions of the draft only include editorial changes. -//! -//! The suffix will eventually be removed once Kyber is finalized. -//! -//! Quoting from the CFRG I-D: -//! -//! Kyber is not a Diffie-Hellman (DH) style non-interactive key -//! agreement, but instead, Kyber is a Key Encapsulation Method (KEM). -//! In essence, a KEM is a Public-Key Encryption (PKE) scheme where the -//! plaintext cannot be specified, but is generated as a random key as -//! part of the encryption. A KEM can be transformed into an unrestricted -//! PKE using HPKE (RFC9180). On its own, a KEM can be used as a key -//! agreement method in TLS. -//! -//! Kyber is an IND-CCA2 secure KEM. It is constructed by applying a -//! Fujisaki--Okamato style transformation on InnerPKE, which is the -//! underlying IND-CPA secure Public Key Encryption scheme. We cannot -//! use InnerPKE directly, as its ciphertexts are malleable. -//! -//! ``` -//! F.O. transform -//! InnerPKE ----------------------> Kyber -//! IND-CPA IND-CCA2 -//! ``` -//! -//! Kyber is a lattice-based scheme. More precisely, its security is -//! based on the learning-with-errors-and-rounding problem in module -//! lattices (MLWER). The underlying polynomial ring R (defined in -//! Section 5) is chosen such that multiplication is very fast using the -//! number theoretic transform (NTT, see Section 5.1.3). -//! -//! An InnerPKE private key is a vector _s_ over R of length k which is -//! _small_ in a particular way. Here k is a security parameter akin to -//! the size of a prime modulus. For Kyber512, which targets AES-128's -//! security level, the value of k is 2. -//! -//! The public key consists of two values: -//! -//! * _A_ a uniformly sampled k by k matrix over R _and_ -//! -//! * _t = A s + e_, where e is a suitably small masking vector. -//! -//! Distinguishing between such A s + e and a uniformly sampled t is the -//! module learning-with-errors (MLWE) problem. If that is hard, then it -//! is also hard to recover the private key from the public key as that -//! would allow you to distinguish between those two. -//! -//! To save space in the public key, A is recomputed deterministically -//! from a seed _rho_. -//! -//! A ciphertext for a message m under this public key is a pair (c_1, -//! c_2) computed roughly as follows: -//! -//! c_1 = Compress(A^T r + e_1, d_u) -//! c_2 = Compress(t^T r + e_2 + Decompress(m, 1), d_v) -//! -//! where -//! -//! * e_1, e_2 and r are small blinds; -//! -//! * Compress(-, d) removes some information, leaving d bits per -//! coefficient and Decompress is such that Compress after Decompress -//! does nothing and -//! -//! * d_u, d_v are scheme parameters. -//! -//! Distinguishing such a ciphertext and uniformly sampled (c_1, c_2) is -//! an example of the full MLWER problem, see section 4.4 of [KyberV302]. -//! -//! To decrypt the ciphertext, one computes -//! -//! m = Compress(Decompress(c_2, d_v) - s^T Decompress(c_1, d_u), 1). -//! -//! It it not straight-forward to see that this formula is correct. In -//! fact, there is negligible but non-zero probability that a ciphertext -//! does not decrypt correctly given by the DFP column in Table 4. This -//! failure probability can be computed by a careful automated analysis -//! of the probabilities involved, see kyber_failure.py of [SecEst]. -//! -//! [KyberV302](https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf) -//! [I-D](https://github.com/bwesterb/draft-schwabe-cfrg-kyber) -//! [SecEst](https://github.com/pq-crystals/security-estimates) - -// TODO -// -// - The bottleneck in Kyber are the various hash/xof calls: -// - Optimize Zig's keccak implementation. -// - Use SIMD to compute keccak in parallel. -// - Can we track bounds of coefficients using comptime types without -// duplicating code? -// - Would be neater to have tests closer to the thing under test. -// - When generating a keypair, we have a copy of the inner public key with -// its large matrix A in both the public key and the private key. In Go we -// can just have a pointer in the private key to the public key, but -// how do we do this elegantly in Zig? - -const std = @import("std"); -const builtin = @import("builtin"); - -const testing = std.testing; -const assert = std.debug.assert; -const crypto = std.crypto; -const math = std.math; -const mem = std.mem; -const RndGen = std.Random.DefaultPrng; -const sha3 = crypto.hash.sha3; - -// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1. -const Q: i16 = 3329; - -// Montgomery R -const R: i32 = 1 << 16; - -// Parameter n, degree of polynomials. -const N: usize = 256; - -// Size of "small" vectors used in encryption blinds. -const eta2: u8 = 2; - -const Params = struct { - name: []const u8, - - // Width and height of the matrix A. - k: u8, - - // Size of "small" vectors used in private key and encryption blinds. - eta1: u8, - - // How many bits to retain of u, the private-key independent part - // of the ciphertext. - du: u8, - - // How many bits to retain of v, the private-key dependent part - // of the ciphertext. - dv: u8, -}; - -pub const Kyber512 = Kyber(.{ - .name = "Kyber512", - .k = 2, - .eta1 = 3, - .du = 10, - .dv = 4, -}); - -pub const Kyber768 = Kyber(.{ - .name = "Kyber768", - .k = 3, - .eta1 = 2, - .du = 10, - .dv = 4, -}); - -pub const Kyber1024 = Kyber(.{ - .name = "Kyber1024", - .k = 4, - .eta1 = 2, - .du = 11, - .dv = 5, -}); - -const modes = [_]type{ Kyber512, Kyber768, Kyber1024 }; -const h_length: usize = 32; -const inner_seed_length: usize = 32; -const common_encaps_seed_length: usize = 32; -const common_shared_key_size: usize = 32; - -fn Kyber(comptime p: Params) type { - return struct { - // Size of a ciphertext, in bytes. - pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv); - - const Self = @This(); - const V = Vec(p.k); - const M = Mat(p.k); - - /// Length (in bytes) of a shared secret. - pub const shared_length = common_shared_key_size; - /// Length (in bytes) of a seed for deterministic encapsulation. - pub const encaps_seed_length = common_encaps_seed_length; - /// Length (in bytes) of a seed for key generation. - pub const seed_length: usize = inner_seed_length + shared_length; - /// Algorithm name. - pub const name = p.name; - - /// A shared secret, and an encapsulated (encrypted) representation of it. - pub const EncapsulatedSecret = struct { - shared_secret: [shared_length]u8, - ciphertext: [ciphertext_length]u8, - }; - - /// A Kyber public key. - pub const PublicKey = struct { - pk: InnerPk, - - // Cached - hpk: [h_length]u8, // H(pk) - - /// Size of a serialized representation of the key, in bytes. - pub const bytes_length = InnerPk.bytes_length; - - /// Generates a shared secret, and encapsulates it for the public key. - /// If `seed` is `null`, a random seed is used. This is recommended. - /// If `seed` is set, encapsulation is deterministic. - pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret { - const seed = seed_ orelse seed: { - var random_seed: [encaps_seed_length]u8 = undefined; - crypto.random.bytes(&random_seed); - break :seed random_seed; - }; - - var m: [inner_plaintext_length]u8 = undefined; - - // m = H(seed) - var h = sha3.Sha3_256.init(.{}); - h.update(&seed); - h.final(&m); - - // (K', r) = G(m ‖ H(pk)) - var kr: [inner_plaintext_length + h_length]u8 = undefined; - var g = sha3.Sha3_512.init(.{}); - g.update(&m); - g.update(&pk.hpk); - g.final(&kr); - - // c = innerEncrypy(pk, m, r) - const ct = pk.pk.encrypt(&m, kr[32..64]); - - // Compute H(c) and put in second slot of kr, which will be (K', H(c)). - h = sha3.Sha3_256.init(.{}); - h.update(&ct); - h.final(kr[32..64]); - - // K = KDF(K' ‖ H(c)) - var kdf = sha3.Shake256.init(.{}); - kdf.update(&kr); - var ss: [shared_length]u8 = undefined; - kdf.squeeze(&ss); - - return EncapsulatedSecret{ - .shared_secret = ss, - .ciphertext = ct, - }; - } - - /// Serializes the key into a byte array. - pub fn toBytes(pk: PublicKey) [bytes_length]u8 { - return pk.pk.toBytes(); - } - - /// Deserializes the key from a byte array. - pub fn fromBytes(buf: *const [bytes_length]u8) !PublicKey { - var ret: PublicKey = undefined; - ret.pk = InnerPk.fromBytes(buf[0..InnerPk.bytes_length]); - - var h = sha3.Sha3_256.init(.{}); - h.update(buf); - h.final(&ret.hpk); - return ret; - } - }; - - /// A Kyber secret key. - pub const SecretKey = struct { - sk: InnerSk, - pk: InnerPk, - hpk: [h_length]u8, // H(pk) - z: [shared_length]u8, - - /// Size of a serialized representation of the key, in bytes. - pub const bytes_length: usize = - InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length; - - /// Decapsulates the shared secret within ct using the private key. - pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 { - // m' = innerDec(ct) - const m2 = sk.sk.decrypt(ct); - - // (K'', r') = G(m' ‖ H(pk)) - var kr2: [64]u8 = undefined; - var g = sha3.Sha3_512.init(.{}); - g.update(&m2); - g.update(&sk.hpk); - g.final(&kr2); - - // ct' = innerEnc(pk, m', r') - const ct2 = sk.pk.encrypt(&m2, kr2[32..64]); - - // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)). - var h = sha3.Sha3_256.init(.{}); - h.update(ct); - h.final(kr2[32..64]); - - // Replace K'' by z in the first slot of kr2 if ct ≠ ct'. - cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2)); - - // K = KDF(K''/z, H(c)) - var kdf = sha3.Shake256.init(.{}); - var ss: [shared_length]u8 = undefined; - kdf.update(&kr2); - kdf.squeeze(&ss); - return ss; - } - - /// Serializes the key into a byte array. - pub fn toBytes(sk: SecretKey) [bytes_length]u8 { - return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z; - } - - /// Deserializes the key from a byte array. - pub fn fromBytes(buf: *const [bytes_length]u8) !SecretKey { - var ret: SecretKey = undefined; - comptime var s: usize = 0; - ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]); - s += InnerSk.bytes_length; - ret.pk = InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]); - s += InnerPk.bytes_length; - ret.hpk = buf[s..][0..h_length].*; - s += h_length; - ret.z = buf[s..][0..shared_length].*; - return ret; - } - }; - - /// A Kyber key pair. - pub const KeyPair = struct { - secret_key: SecretKey, - public_key: PublicKey, - - /// Create a new key pair. - /// If seed is null, a random seed will be generated. - /// If a seed is provided, the key pair will be determinsitic. - pub fn create(seed_: ?[seed_length]u8) !KeyPair { - const seed = seed_ orelse sk: { - var random_seed: [seed_length]u8 = undefined; - crypto.random.bytes(&random_seed); - break :sk random_seed; - }; - var ret: KeyPair = undefined; - ret.secret_key.z = seed[inner_seed_length..seed_length].*; - - // Generate inner key - innerKeyFromSeed( - seed[0..inner_seed_length].*, - &ret.public_key.pk, - &ret.secret_key.sk, - ); - ret.secret_key.pk = ret.public_key.pk; - - // Copy over z from seed. - ret.secret_key.z = seed[inner_seed_length..seed_length].*; - - // Compute H(pk) - var h = sha3.Sha3_256.init(.{}); - h.update(&ret.public_key.pk.toBytes()); - h.final(&ret.secret_key.hpk); - ret.public_key.hpk = ret.secret_key.hpk; - - return ret; - } - }; - - // Size of plaintexts of the in - const inner_plaintext_length: usize = Poly.compressedSize(1); - - const InnerPk = struct { - rho: [32]u8, // ρ, the seed for the matrix A - th: V, // NTT(t), normalized - - // Cached values - aT: M, - - const bytes_length = V.bytes_length + 32; - - fn encrypt( - pk: InnerPk, - pt: *const [inner_plaintext_length]u8, - seed: *const [32]u8, - ) [ciphertext_length]u8 { - // Sample r, e₁ and e₂ appropriately - const rh = V.noise(p.eta1, 0, seed).ntt().barrettReduce(); - const e1 = V.noise(eta2, p.k, seed); - const e2 = Poly.noise(eta2, 2 * p.k, seed); - - // Next we compute u = Aᵀ r + e₁. First Aᵀ. - var u: V = undefined; - for (0..p.k) |i| { - // Note that coefficients of r are bounded by q and those of Aᵀ - // are bounded by 4.5q and so their product is bounded by 2¹⁵q - // as required for multiplication. - u.ps[i] = pk.aT.vs[i].dotHat(rh); - } - - // Aᵀ and r were not in Montgomery form, so the Montgomery - // multiplications in the inner product added a factor R⁻¹ which - // the InvNTT cancels out. - u = u.barrettReduce().invNTT().add(e1).normalize(); - - // Next, compute v = <t, r> + e₂ + Decompress_q(m, 1) - const v = pk.th.dotHat(rh).barrettReduce().invNTT() - .add(Poly.decompress(1, pt)).add(e2).normalize(); - - return u.compress(p.du) ++ v.compress(p.dv); - } - - fn toBytes(pk: InnerPk) [bytes_length]u8 { - return pk.th.toBytes() ++ pk.rho; - } - - fn fromBytes(buf: *const [bytes_length]u8) InnerPk { - var ret: InnerPk = undefined; - ret.th = V.fromBytes(buf[0..V.bytes_length]).normalize(); - ret.rho = buf[V.bytes_length..bytes_length].*; - ret.aT = M.uniform(ret.rho, true); - return ret; - } - }; - - // Private key of the inner PKE - const InnerSk = struct { - sh: V, // NTT(s), normalized - const bytes_length = V.bytes_length; - - fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 { - const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]); - const v = Poly.decompress( - p.dv, - ct[comptime V.compressedSize(p.du)..ciphertext_length], - ); - - // Compute m = v - <s, u> - return v.sub(sk.sh.dotHat(u.ntt()).barrettReduce().invNTT()) - .normalize().compress(1); - } - - fn toBytes(sk: InnerSk) [bytes_length]u8 { - return sk.sh.toBytes(); - } - - fn fromBytes(buf: *const [bytes_length]u8) InnerSk { - var ret: InnerSk = undefined; - ret.sh = V.fromBytes(buf).normalize(); - return ret; - } - }; - - // Derives inner PKE keypair from given seed. - fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void { - var expanded_seed: [64]u8 = undefined; - - var h = sha3.Sha3_512.init(.{}); - h.update(&seed); - h.final(&expanded_seed); - pk.rho = expanded_seed[0..32].*; - const sigma = expanded_seed[32..64]; - pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on - - // Sample secret vector s. - sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize(); - - const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e. - var th: V = undefined; - - // Next, we compute t = A s + e. - for (0..p.k) |i| { - // Note that coefficients of s are bounded by q and those of A - // are bounded by 4.5q and so their product is bounded by 2¹⁵q - // as required for multiplication. - // A and s were not in Montgomery form, so the Montgomery - // multiplications in the inner product added a factor R⁻¹ which - // we'll cancel out with toMont(). This will also ensure the - // coefficients of th are bounded in absolute value by q. - th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont(); - } - - pk.th = th.add(eh).normalize(); // bounded by 8q - pk.aT = pk.aT.transpose(); - } - }; -} - -// R mod q -const r_mod_q: i32 = @rem(@as(i32, R), Q); - -// R² mod q -const r2_mod_q: i32 = @rem(r_mod_q * r_mod_q, Q); - -// ζ is the degree 256 primitive root of unity used for the NTT. -const zeta: i16 = 17; - -// (128)⁻¹ R². Used in inverse NTT. -const r2_over_128: i32 = @mod(invertMod(128, Q) * r2_mod_q, Q); - -// zetas lists precomputed powers of the primitive root of unity in -// Montgomery representation used for the NTT: -// -// zetas[i] = ζᵇʳᵛ⁽ⁱ⁾ R mod q -// -// where ζ = 17, brv(i) is the bitreversal of a 7-bit number and R=2¹⁶ mod q. -const zetas = computeZetas(); - -// invNTTReductions keeps track of which coefficients to apply Barrett -// reduction to in Poly.invNTT(). -// -// Generated lazily: once a butterfly is computed which is about to -// overflow the i16, the largest coefficient is reduced. If that is -// not enough, the other coefficient is reduced as well. -// -// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf -// TODO generate comptime? -const inv_ntt_reductions = [_]i16{ - -1, // after layer 1 - -1, // after layer 2 - 16, - 17, - 48, - 49, - 80, - 81, - 112, - 113, - 144, - 145, - 176, - 177, - 208, - 209, - 240, 241, -1, // after layer 3 - 0, 1, 32, - 33, 34, 35, - 64, 65, 96, - 97, 98, 99, - 128, 129, - 160, 161, 162, 163, 192, 193, 224, 225, 226, 227, -1, // after layer 4 - 2, 3, 66, 67, 68, 69, 70, 71, 130, 131, 194, - 195, 196, 197, - 198, 199, -1, // after layer 5 - 4, 5, 6, - 7, 132, 133, - 134, 135, 136, - 137, 138, 139, - 140, 141, - 142, 143, -1, // after layer 6 - -1, // after layer 7 -}; - -test "invNTTReductions bounds" { - // Checks whether the reductions proposed by invNTTReductions - // don't overflow during invNTT(). - var xs = [_]i32{1} ** 256; // start at |x| ≤ q - - var r: usize = 0; - var layer: math.Log2Int(usize) = 1; - while (layer < 8) : (layer += 1) { - const w = @as(usize, 1) << layer; - var i: usize = 0; - - while (i + w < 256) { - xs[i] = xs[i] + xs[i + w]; - try testing.expect(xs[i] <= 9); // we can't exceed 9q - xs[i + w] = 1; - i += 1; - if (@mod(i, w) == 0) { - i += w; - } - } - - while (true) { - const j = inv_ntt_reductions[r]; - r += 1; - if (j < 0) { - break; - } - xs[@as(usize, @intCast(j))] = 1; - } - } -} - -// Extended euclidean algorithm. -// -// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute -// modular inverse. -fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) { - if (a == 0) { - return .{ .gcd = b, .x = 0, .y = 1 }; - } - const r = eea(@rem(b, a), a); - return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x }; -} - -fn EeaResult(comptime T: type) type { - return struct { gcd: T, x: T, y: T }; -} - -// Returns least common multiple of a and b. -fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) { - const r = eea(a, b); - return a * b / r.gcd; -} - -// Invert modulo p. -fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) { - const r = eea(a, p); - assert(r.gcd == 1); - return r.x; -} - -// Reduce mod q for testing. -fn modQ32(x: i32) i16 { - var y = @as(i16, @intCast(@rem(x, @as(i32, Q)))); - if (y < 0) { - y += Q; - } - return y; -} - -// Given -2¹⁵ q ≤ x < 2¹⁵ q, returns -q < y < q with x 2⁻¹⁶ = y (mod q). -fn montReduce(x: i32) i16 { - const qInv = comptime invertMod(@as(i32, Q), R); - // This is Montgomery reduction with R=2¹⁶. - // - // Note gcd(2¹⁶, q) = 1 as q is prime. Write q' := 62209 = q⁻¹ mod R. - // First we compute - // - // m := ((x mod R) q') mod R - // = x q' mod R - // = int16(x q') - // = int16(int32(x) * int32(q')) - // - // Note that x q' might be as big as 2³² and could overflow the int32 - // multiplication in the last line. However for any int32s a and b, - // we have int32(int64(a)*int64(b)) = int32(a*b) and so the result is ok. - const m: i16 = @truncate(@as(i32, @truncate(x *% qInv))); - - // Note that x - m q is divisible by R; indeed modulo R we have - // - // x - m q ≡ x - x q' q ≡ x - x q⁻¹ q ≡ x - x = 0. - // - // We return y := (x - m q) / R. Note that y is indeed correct as - // modulo q we have - // - // y ≡ x R⁻¹ - m q R⁻¹ = x R⁻¹ - // - // and as both 2¹⁵ q ≤ m q, x < 2¹⁵ q, we have - // 2¹⁶ q ≤ x - m q < 2¹⁶ and so q ≤ (x - m q) / R < q as desired. - const yR = x - @as(i32, m) * @as(i32, Q); - return @bitCast(@as(u16, @truncate(@as(u32, @bitCast(yR)) >> 16))); -} - -test "Test montReduce" { - var rnd = RndGen.init(0); - for (0..1000) |_| { - const bound = comptime @as(i32, Q) * (1 << 15); - const x = rnd.random().intRangeLessThan(i32, -bound, bound); - const y = montReduce(x); - try testing.expect(-Q < y and y < Q); - try testing.expectEqual(modQ32(x), modQ32(@as(i32, y) * R)); - } -} - -// Given any x, return x R mod q where R=2¹⁶. -fn feToMont(x: i16) i16 { - // Note |1353 x| ≤ 1353 2¹⁵ ≤ 13318 q ≤ 2¹⁵ q and so we're within - // the bounds of montReduce. - return montReduce(@as(i32, x) * r2_mod_q); -} - -test "Test feToMont" { - var x: i32 = -(1 << 15); - while (x < 1 << 15) : (x += 1) { - const y = feToMont(@as(i16, @intCast(x))); - try testing.expectEqual(modQ32(@as(i32, y)), modQ32(x * r_mod_q)); - } -} - -// Given any x, compute 0 ≤ y ≤ q with x = y (mod q). -// -// Beware: we might have feBarrettReduce(x) = q ≠ 0 for some x. In fact, -// this happens if and only if x = -nq for some positive integer n. -fn feBarrettReduce(x: i16) i16 { - // This is standard Barrett reduction. - // - // For any x we have x mod q = x - ⌊x/q⌋ q. We will use 20159/2²⁶ as - // an approximation of 1/q. Note that 0 ≤ 20159/2²⁶ - 1/q ≤ 0.135/2²⁶ - // and so | x 20156/2²⁶ - x/q | ≤ 2⁻¹⁰ for |x| ≤ 2¹⁶. For all x - // not a multiple of q, the number x/q is further than 1/q from any integer - // and so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋. If x is a multiple of q and x is positive, - // then x 20156/2²⁶ is larger than x/q so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋ as well. - // Finally, if x is negative multiple of q, then ⌊x 20156/2²⁶⌋ = ⌊x/q⌋-1. - // Thus - // [ q if x=-nq for pos. integer n - // x - ⌊x 20156/2²⁶⌋ q = [ - // [ x mod q otherwise - // - // To actually compute this, note that - // - // ⌊x 20156/2²⁶⌋ = (20159 x) >> 26. - return x -% @as(i16, @intCast((@as(i32, x) * 20159) >> 26)) *% Q; -} - -test "Test Barrett reduction" { - var x: i32 = -(1 << 15); - while (x < 1 << 15) : (x += 1) { - var y1 = feBarrettReduce(@as(i16, @intCast(x))); - const y2 = @mod(@as(i16, @intCast(x)), Q); - if (x < 0 and @rem(-x, Q) == 0) { - y1 -= Q; - } - try testing.expectEqual(y1, y2); - } -} - -// Returns x if x < q and x - q otherwise. Assumes x ≥ -29439. -fn csubq(x: i16) i16 { - var r = x; - r -= Q; - r += (r >> 15) & Q; - return r; -} - -test "Test csubq" { - var x: i32 = -29439; - while (x < 1 << 15) : (x += 1) { - const y1 = csubq(@as(i16, @intCast(x))); - var y2 = @as(i16, @intCast(x)); - if (@as(i16, @intCast(x)) >= Q) { - y2 -= Q; - } - try testing.expectEqual(y1, y2); - } -} - -// Compute a^s mod p. -fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) { - var ret: @TypeOf(a) = 1; - var s2 = s; - var a2 = a; - - while (true) { - if (s2 & 1 == 1) { - ret = @mod(ret * a2, p); - } - s2 >>= 1; - if (s2 == 0) { - break; - } - a2 = @mod(a2 * a2, p); - } - return ret; -} - -// Computes zetas table used by ntt and invNTT. -fn computeZetas() [128]i16 { - @setEvalBranchQuota(10000); - var ret: [128]i16 = undefined; - for (&ret, 0..) |*r, i| { - const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q))); - r.* = csubq(feBarrettReduce(feToMont(t))); - } - return ret; -} - -// An element of our base ring R which are polynomials over ℤ_q -// modulo the equation Xᴺ = -1, where q=3329 and N=256. -// -// This type is also used to store NTT-transformed polynomials, -// see Poly.NTT(). -// -// Coefficients aren't always reduced. See Normalize(). -const Poly = struct { - cs: [N]i16, - - const bytes_length = N / 2 * 3; - const zero: Poly = .{ .cs = .{0} ** N }; - - fn add(a: Poly, b: Poly) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = a.cs[i] + b.cs[i]; - } - return ret; - } - - fn sub(a: Poly, b: Poly) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = a.cs[i] - b.cs[i]; - } - return ret; - } - - // For testing, generates a random polynomial with for each - // coefficient |x| ≤ q. - fn randAbsLeqQ(rnd: anytype) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q); - } - return ret; - } - - // For testing, generates a random normalized polynomial. - fn randNormalized(rnd: anytype) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q); - } - return ret; - } - - // Executes a forward "NTT" on p. - // - // Assumes the coefficients are in absolute value ≤q. The resulting - // coefficients are in absolute value ≤7q. If the input is in Montgomery - // form, then the result is in Montgomery form and so (by linearity of the NTT) - // if the input is in regular form, then the result is also in regular form. - fn ntt(a: Poly) Poly { - // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512 - // does not divide into q-1) and so we cannot do a regular NTT. ℤ_q - // does have a primitive 256ᵗʰ root of unity, the smallest of which - // is ζ := 17. - // - // Recall that our base ring R := ℤ_q[x] / (x²⁵⁶ + 1). The polynomial - // x²⁵⁶+1 will not split completely (as its roots would be 512ᵗʰ roots - // of unity.) However, it does split almost (using ζ¹²⁸ = -1): - // - // x²⁵⁶ + 1 = (x²)¹²⁸ - ζ¹²⁸ - // = ((x²)⁶⁴ - ζ⁶⁴)((x²)⁶⁴ + ζ⁶⁴) - // = ((x²)³² - ζ³²)((x²)³² + ζ³²)((x²)³² - ζ⁹⁶)((x²)³² + ζ⁹⁶) - // ⋮ - // = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷) - // - // Note that the powers of ζ that appear (from the second line down) are - // in binary - // - // 0100000 1100000 - // 0010000 1010000 0110000 1110000 - // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000 - // … - // - // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit - // bitreversal of x. These powers of ζ are given by the Zetas array. - // - // The polynomials x² ± ζⁱ are irreducible and coprime, hence by - // the Chinese Remainder Theorem we know - // - // ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷) - // - // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ ) - // is an isomorphism, which is the "NTT". It can be efficiently computed by - // - // - // a ↦ ( a mod (x²)⁶⁴ - ζ⁶⁴, a mod (x²)⁶⁴ + ζ⁶⁴ ) - // ↦ ( a mod (x²)³² - ζ³², a mod (x²)³² + ζ³², - // a mod (x²)⁹⁶ - ζ⁹⁶, a mod (x²)⁹⁶ + ζ⁹⁶ ) - // - // et cetera - // If N was 8 then this can be pictured in the following diagram: - // - // https://cnx.org/resources/17ee4dfe517a6adda05377b25a00bf6e6c93c334/File0026.png - // - // Each cross is a Cooley-Tukey butterfly: it's the map - // - // (a, b) ↦ (a + ζb, a - ζb) - // - // for the appropriate power ζ for that column and row group. - var p = a; - var k: usize = 0; // index into zetas - - var l = N >> 1; - while (l > 1) : (l >>= 1) { - // On the nᵗʰ iteration of the l-loop, the absolute value of the - // coefficients are bounded by nq. - - // offset effectively loops over the row groups in this column; it is - // the first row in the row group. - var offset: usize = 0; - while (offset < N - l) : (offset += 2 * l) { - k += 1; - const z = @as(i32, zetas[k]); - - // j loops over each butterfly in the row group. - for (offset..offset + l) |j| { - const t = montReduce(z * @as(i32, p.cs[j + l])); - p.cs[j + l] = p.cs[j] - t; - p.cs[j] += t; - } - } - } - - return p; - } - - // Executes an inverse "NTT" on p and multiply by the Montgomery factor R. - // - // Assumes the coefficients are in absolute value ≤q. The resulting - // coefficients are in absolute value ≤q. If the input is in Montgomery - // form, then the result is in Montgomery form and so (by linearity) - // if the input is in regular form, then the result is also in regular form. - fn invNTT(a: Poly) Poly { - var k: usize = 127; // index into zetas - var r: usize = 0; // index into invNTTReductions - var p = a; - - // We basically do the oppposite of NTT, but postpone dividing by 2 in the - // inverse of the Cooley-Tukey butterfly and accumulate that into a big - // division by 2⁷ at the end. See the comments in the ntt() function. - - var l: usize = 2; - while (l < N) : (l <<= 1) { - var offset: usize = 0; - while (offset < N - l) : (offset += 2 * l) { - // As we're inverting, we need powers of ζ⁻¹ (instead of ζ). - // To be precise, we need ζᵇʳᵛ⁽ᵏ⁾⁻¹²⁸. However, as ζ⁻¹²⁸ = -1, - // we can use the existing zetas table instead of - // keeping a separate invZetas table as in Dilithium. - - const minZeta = @as(i32, zetas[k]); - k -= 1; - - for (offset..offset + l) |j| { - // Gentleman-Sande butterfly: (a, b) ↦ (a + b, ζ(a-b)) - const t = p.cs[j + l] - p.cs[j]; - p.cs[j] += p.cs[j + l]; - p.cs[j + l] = montReduce(minZeta * @as(i32, t)); - - // Note that if we had |a| < αq and |b| < βq before the - // butterfly, then now we have |a| < (α+β)q and |b| < q. - } - } - - // We let the invNTTReductions instruct us which coefficients to - // Barrett reduce. - while (true) { - const i = inv_ntt_reductions[r]; - r += 1; - if (i < 0) { - break; - } - p.cs[@as(usize, @intCast(i))] = feBarrettReduce(p.cs[@as(usize, @intCast(i))]); - } - } - - for (0..N) |j| { - // Note 1441 = (128)⁻¹ R². The coefficients are bounded by 9q, so - // as 1441 * 9 ≈ 2¹⁴ < 2¹⁵, we're within the required bounds - // for montReduce(). - p.cs[j] = montReduce(r2_over_128 * @as(i32, p.cs[j])); - } - - return p; - } - - // Normalizes coefficients. - // - // Ensures each coefficient is in {0, …, q-1}. - fn normalize(a: Poly) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = csubq(feBarrettReduce(a.cs[i])); - } - return ret; - } - - // Put p in Montgomery form. - fn toMont(a: Poly) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = feToMont(a.cs[i]); - } - return ret; - } - - // Barret reduce coefficients. - // - // Beware, this does not fully normalize coefficients. - fn barrettReduce(a: Poly) Poly { - var ret: Poly = undefined; - for (0..N) |i| { - ret.cs[i] = feBarrettReduce(a.cs[i]); - } - return ret; - } - - fn compressedSize(comptime d: u8) usize { - return @divTrunc(N * d, 8); - } - - // Returns packed Compress_q(p, d). - // - // Assumes p is normalized. - fn compress(p: Poly, comptime d: u8) [compressedSize(d)]u8 { - @setEvalBranchQuota(10000); - const q_over_2: u32 = comptime @divTrunc(Q, 2); // (q-1)/2 - const two_d_min_1: u32 = comptime (1 << d) - 1; // 2ᵈ-1 - var in_off: usize = 0; - var out_off: usize = 0; - - const batch_size: usize = comptime lcm(@as(i16, d), 8); - const in_batch_size: usize = comptime batch_size / d; - const out_batch_size: usize = comptime batch_size / 8; - - const out_length: usize = comptime @divTrunc(N * d, 8); - comptime assert(out_length * 8 == d * N); - var out = [_]u8{0} ** out_length; - - while (in_off < N) { - // First we compress into in. - var in: [in_batch_size]u16 = undefined; - inline for (0..in_batch_size) |i| { - // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ - // = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ - // = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ - // = DIV((x << d) + q/2, q) & ((1<<d) - 1) - const t = @as(u24, @intCast(p.cs[in_off + i])) << d; - // Division by invariant multiplication, equivalent to DIV(t + q/2, q). - // A division may not be a constant-time operation, even with a constant denominator. - // Here, side channels would leak information about the shared secret, see https://kyberslash.cr.yp.to - // Multiplication, on the other hand, is a constant-time operation on the CPUs we currently support. - comptime assert(d <= 11); - comptime assert(((20642679 * @as(u64, Q)) >> 36) == 1); - const u: u32 = @intCast((@as(u64, t + q_over_2) * 20642679) >> 36); - in[i] = @intCast(u & two_d_min_1); - } - - // Now we pack the d-bit integers from `in' into out as bytes. - comptime var in_shift: usize = 0; - comptime var j: usize = 0; - comptime var i: usize = 0; - inline while (i < in_batch_size) : (j += 1) { - comptime var todo: usize = 8; - inline while (todo > 0) { - const out_shift = comptime 8 - todo; - out[out_off + j] |= @as(u8, @truncate((in[i] >> in_shift) << out_shift)); - - const done = comptime @min(@min(d, todo), d - in_shift); - todo -= done; - in_shift += done; - - if (in_shift == d) { - in_shift = 0; - i += 1; - } - } - } - - in_off += in_batch_size; - out_off += out_batch_size; - } - - return out; - } - - // Set p to Decompress_q(m, d). - fn decompress(comptime d: u8, in: *const [compressedSize(d)]u8) Poly { - @setEvalBranchQuota(10000); - const inLen = comptime @divTrunc(N * d, 8); - comptime assert(inLen * 8 == d * N); - var ret: Poly = undefined; - var in_off: usize = 0; - var out_off: usize = 0; - - const batch_size: usize = comptime lcm(@as(i16, d), 8); - const in_batch_size: usize = comptime batch_size / 8; - const out_batch_size: usize = comptime batch_size / d; - - while (out_off < N) { - comptime var in_shift: usize = 0; - comptime var j: usize = 0; - comptime var i: usize = 0; - inline while (i < out_batch_size) : (i += 1) { - // First, unpack next coefficient. - comptime var todo = d; - var out: u16 = 0; - - inline while (todo > 0) { - const out_shift = comptime d - todo; - const m = comptime (1 << d) - 1; - out |= (@as(u16, in[in_off + j] >> in_shift) << out_shift) & m; - - const done = comptime @min(@min(8, todo), 8 - in_shift); - todo -= done; - in_shift += done; - - if (in_shift == 8) { - in_shift = 0; - j += 1; - } - } - - // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋ - // = ⌊(q/2ᵈ)x+½⌋ - // = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋ - // = (qx + (1<<(d-1))) >> d - const qx = @as(u32, out) * @as(u32, Q); - ret.cs[out_off + i] = @as(i16, @intCast((qx + (1 << (d - 1))) >> d)); - } - - in_off += in_batch_size; - out_off += out_batch_size; - } - - return ret; - } - - // Returns the "pointwise" multiplication a o b. - // - // That is: invNTT(a o b) = invNTT(a) * invNTT(b). Assumes a and b are in - // Montgomery form. Products between coefficients of a and b must be strictly - // bounded in absolute value by 2¹⁵q. a o b will be in Montgomery form and - // bounded in absolute value by 2q. - fn mulHat(a: Poly, b: Poly) Poly { - // Recall from the discussion in ntt(), that a transformed polynomial is - // an element of ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷); - // that is: 128 degree-one polynomials instead of simply 256 elements - // from ℤ_q as in the regular NTT. So instead of pointwise multiplication, - // we multiply the 128 pairs of degree-one polynomials modulo the - // right equation: - // - // (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x, - // - // where ζ' is the appropriate power of ζ. - - var p: Poly = undefined; - var k: usize = 64; - var i: usize = 0; - while (i < N) : (i += 4) { - const z = @as(i32, zetas[k]); - k += 1; - - const a1b1 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i + 1])); - const a0b0 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i])); - const a1b0 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i])); - const a0b1 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i + 1])); - - p.cs[i] = montReduce(a1b1 * z) + a0b0; - p.cs[i + 1] = a0b1 + a1b0; - - const a3b3 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 3])); - const a2b2 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 2])); - const a3b2 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 2])); - const a2b3 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 3])); - - p.cs[i + 2] = a2b2 - montReduce(a3b3 * z); - p.cs[i + 3] = a2b3 + a3b2; - } - - return p; - } - - // Sample p from a centered binomial distribution with n=2η and p=½ - viz: - // coefficients are in {-η, …, η} with probabilities - // - // {ncr(0, 2η)/2^2η, ncr(1, 2η)/2^2η, …, ncr(2η,2η)/2^2η} - fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Poly { - var h = sha3.Shake256.init(.{}); - const suffix: [1]u8 = .{nonce}; - h.update(seed); - h.update(&suffix); - - // The distribution at hand is exactly the same as that - // of (a₁ + a₂ + … + a_η) - (b₁ + … + b_η) where a_i,b_i~U(1). - // Thus we need 2η bits per coefficient. - const buf_len = comptime 2 * eta * N / 8; - var buf: [buf_len]u8 = undefined; - h.squeeze(&buf); - - // buf is interpreted as a₁…a_ηb₁…b_ηa₁…a_ηb₁…b_η…. We process - // multiple coefficients in one batch. - - const T = switch (builtin.target.cpu.arch) { - .x86_64, .x86 => u32, // Generates better code on Intel CPUs - else => u64, // u128 might be faster on some other CPUs. - }; - - comptime var batch_count: usize = undefined; - comptime var batch_bytes: usize = undefined; - comptime var mask: T = 0; - comptime { - batch_count = @bitSizeOf(T) / @as(usize, 2 * eta); - while (@rem(N, batch_count) != 0 and batch_count > 0) : (batch_count -= 1) {} - assert(batch_count > 0); - assert(@rem(2 * eta * batch_count, 8) == 0); - batch_bytes = 2 * eta * batch_count / 8; - - for (0..2 * eta * batch_count) |_| { - mask <<= eta; - mask |= 1; - } - } - - var ret: Poly = undefined; - for (0..comptime N / batch_count) |i| { - // Read coefficients into t. In the case of η=3, - // we have t = a₁ + 2a₂ + 4a₃ + 8b₁ + 16b₂ + … - var t: T = 0; - inline for (0..batch_bytes) |j| { - t |= @as(T, buf[batch_bytes * i + j]) << (8 * j); - } - - // Accumelate `a's and `b's together by masking them out, shifting - // and adding. For η=3, we have d = a₁ + a₂ + a₃ + 8(b₁ + b₂ + b₃) + … - var d: T = 0; - inline for (0..eta) |j| { - d += (t >> j) & mask; - } - - // Extract each a and b separately and set coefficient in polynomial. - inline for (0..batch_count) |j| { - const mask2 = comptime (1 << eta) - 1; - const a = @as(i16, @intCast((d >> (comptime (2 * j * eta))) & mask2)); - const b = @as(i16, @intCast((d >> (comptime ((2 * j + 1) * eta))) & mask2)); - ret.cs[batch_count * i + j] = a - b; - } - } - - return ret; - } - - // Sample p uniformly from the given seed and x and y coordinates. - fn uniform(seed: [32]u8, x: u8, y: u8) Poly { - var h = sha3.Shake128.init(.{}); - const suffix: [2]u8 = .{ x, y }; - h.update(&seed); - h.update(&suffix); - - const buf_len = sha3.Shake128.block_length; // rate SHAKE-128 - var buf: [buf_len]u8 = undefined; - - var ret: Poly = undefined; - var i: usize = 0; // index into ret.cs - outer: while (true) { - h.squeeze(&buf); - - var j: usize = 0; // index into buf - while (j < buf_len) : (j += 3) { - const b0 = @as(u16, buf[j]); - const b1 = @as(u16, buf[j + 1]); - const b2 = @as(u16, buf[j + 2]); - - const ts: [2]u16 = .{ - b0 | ((b1 & 0xf) << 8), - (b1 >> 4) | (b2 << 4), - }; - - inline for (ts) |t| { - if (t < Q) { - ret.cs[i] = @as(i16, @intCast(t)); - i += 1; - - if (i == N) { - break :outer; - } - } - } - } - } - - return ret; - } - - // Packs p. - // - // Assumes p is normalized (and not just Barrett reduced). - fn toBytes(p: Poly) [bytes_length]u8 { - var ret: [bytes_length]u8 = undefined; - for (0..comptime N / 2) |i| { - const t0 = @as(u16, @intCast(p.cs[2 * i])); - const t1 = @as(u16, @intCast(p.cs[2 * i + 1])); - ret[3 * i] = @as(u8, @truncate(t0)); - ret[3 * i + 1] = @as(u8, @truncate((t0 >> 8) | (t1 << 4))); - ret[3 * i + 2] = @as(u8, @truncate(t1 >> 4)); - } - return ret; - } - - // Unpacks a Poly from buf. - // - // p will not be normalized; instead 0 ≤ p[i] < 4096. - fn fromBytes(buf: *const [bytes_length]u8) Poly { - var ret: Poly = undefined; - for (0..comptime N / 2) |i| { - const b0 = @as(i16, buf[3 * i]); - const b1 = @as(i16, buf[3 * i + 1]); - const b2 = @as(i16, buf[3 * i + 2]); - ret.cs[2 * i] = b0 | ((b1 & 0xf) << 8); - ret.cs[2 * i + 1] = (b1 >> 4) | b2 << 4; - } - return ret; - } -}; - -// A vector of K polynomials. -fn Vec(comptime K: u8) type { - return struct { - ps: [K]Poly, - - const Self = @This(); - const bytes_length = K * Poly.bytes_length; - - fn compressedSize(comptime d: u8) usize { - return Poly.compressedSize(d) * K; - } - - fn ntt(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].ntt(); - } - return ret; - } - - fn invNTT(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].invNTT(); - } - return ret; - } - - fn normalize(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].normalize(); - } - return ret; - } - - fn barrettReduce(a: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].barrettReduce(); - } - return ret; - } - - fn add(a: Self, b: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].add(b.ps[i]); - } - return ret; - } - - fn sub(a: Self, b: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = a.ps[i].sub(b.ps[i]); - } - return ret; - } - - // Samples v[i] from centered binomial distribution with the given η, - // seed and nonce+i. - fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self { - var ret: Self = undefined; - for (0..K) |i| { - ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed); - } - return ret; - } - - // Sets p to the inner product of a and b using "pointwise" multiplication. - // - // See MulHat() and NTT() for a description of the multiplication. - // Assumes a and b are in Montgomery form. p will be in Montgomery form, - // and its coefficients will be bounded in absolute value by 2kq. - // If a and b are not in Montgomery form, then the action is the same - // as "pointwise" multiplication followed by multiplying by R⁻¹, the inverse - // of the Montgomery factor. - fn dotHat(a: Self, b: Self) Poly { - var ret: Poly = Poly.zero; - for (0..K) |i| { - ret = ret.add(a.ps[i].mulHat(b.ps[i])); - } - return ret; - } - - fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 { - const cs = comptime Poly.compressedSize(d); - var ret: [compressedSize(d)]u8 = undefined; - inline for (0..K) |i| { - ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d); - } - return ret; - } - - fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self { - const cs = comptime Poly.compressedSize(d); - var ret: Self = undefined; - inline for (0..K) |i| { - ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]); - } - return ret; - } - - /// Serializes the key into a byte array. - fn toBytes(v: Self) [bytes_length]u8 { - var ret: [bytes_length]u8 = undefined; - inline for (0..K) |i| { - ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes(); - } - return ret; - } - - /// Deserializes the key from a byte array. - fn fromBytes(buf: *const [bytes_length]u8) Self { - var ret: Self = undefined; - inline for (0..K) |i| { - ret.ps[i] = Poly.fromBytes( - buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length], - ); - } - return ret; - } - }; -} - -// A matrix of K vectors -fn Mat(comptime K: u8) type { - return struct { - const Self = @This(); - vs: [K]Vec(K), - - fn uniform(seed: [32]u8, comptime transposed: bool) Self { - var ret: Self = undefined; - var i: u8 = 0; - while (i < K) : (i += 1) { - var j: u8 = 0; - while (j < K) : (j += 1) { - ret.vs[i].ps[j] = Poly.uniform( - seed, - if (transposed) i else j, - if (transposed) j else i, - ); - } - } - return ret; - } - - // Returns transpose of A - fn transpose(m: Self) Self { - var ret: Self = undefined; - for (0..K) |i| { - for (0..K) |j| { - ret.vs[i].ps[j] = m.vs[j].ps[i]; - } - } - return ret; - } - }; -} - -// Returns `true` if a ≠ b. -fn ctneq(comptime len: usize, a: [len]u8, b: [len]u8) u1 { - return 1 - @intFromBool(crypto.utils.timingSafeEql([len]u8, a, b)); -} - -// Copy src into dst given b = 1. -fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void { - const mask = @as(u8, 0) -% b; - for (0..len) |i| { - dst[i] ^= mask & (dst[i] ^ src[i]); - } -} - -test "MulHat" { - var rnd = RndGen.init(0); - - for (0..100) |_| { - const a = Poly.randAbsLeqQ(&rnd); - const b = Poly.randAbsLeqQ(&rnd); - - const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize(); - var p: Poly = undefined; - - @memset(&p.cs, 0); - - for (0..N) |i| { - for (0..N) |j| { - var v = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[j])); - var k = i + j; - if (k >= N) { - // Recall Xᴺ = -1. - k -= N; - v = -v; - } - p.cs[k] = feBarrettReduce(v + p.cs[k]); - } - } - - p = p.toMont().normalize(); - - try testing.expectEqual(p, p2); - } -} - -test "NTT" { - var rnd = RndGen.init(0); - - for (0..1000) |_| { - var p = Poly.randAbsLeqQ(&rnd); - const q = p.toMont().normalize(); - p = p.ntt(); - - for (0..N) |i| { - try testing.expect(p.cs[i] <= 7 * Q and -7 * Q <= p.cs[i]); - } - - p = p.normalize().invNTT(); - for (0..N) |i| { - try testing.expect(p.cs[i] <= Q and -Q <= p.cs[i]); - } - - p = p.normalize(); - - try testing.expectEqual(p, q); - } -} - -test "Compression" { - var rnd = RndGen.init(0); - inline for (.{ 1, 4, 5, 10, 11 }) |d| { - for (0..1000) |_| { - const p = Poly.randNormalized(&rnd); - const pp = p.compress(d); - const pq = Poly.decompress(d, &pp).compress(d); - try testing.expectEqual(pp, pq); - } - } -} - -test "noise" { - var seed: [32]u8 = undefined; - for (&seed, 0..) |*s, i| { - s.* = @as(u8, @intCast(i)); - } - try testing.expectEqual(Poly.noise(3, 37, &seed).cs, .{ - 0, 0, 1, -1, 0, 2, 0, -1, -1, 3, 0, 1, -2, -2, 0, 1, -2, - 1, 0, -2, 3, 0, 0, 0, 1, 3, 1, 1, 2, 1, -1, -1, -1, 0, - 1, 0, 1, 0, 2, 0, 1, -2, 0, -1, -1, -2, 1, -1, -1, 2, -1, - 1, 1, 2, -3, -1, -1, 0, 0, 0, 0, 1, -1, -2, -2, 0, -2, 0, - 0, 0, 1, 0, -1, -1, 1, -2, 2, 0, 0, 2, -2, 0, 1, 0, 1, - 1, 1, 0, 1, -2, -1, -2, -1, 1, 0, 0, 0, 0, 0, 1, 0, -1, - -1, 0, -1, 1, 0, 1, 0, -1, -1, 0, -2, 2, 0, -2, 1, -1, 0, - 1, -1, -1, 2, 1, 0, 0, -2, -1, 2, 0, 0, 0, -1, -1, 3, 1, - 0, 1, 0, 1, 0, 2, 1, 0, 0, 1, 0, 1, 0, 0, -1, -1, -1, - 0, 1, 3, 1, 0, 1, 0, 1, -1, -1, -1, -1, 0, 0, -2, -1, -1, - 2, 0, 1, 0, 1, 0, 2, -2, 0, 1, 1, -3, -1, -2, -1, 0, 1, - 0, 1, -2, 2, 2, 1, 1, 0, -1, 0, -1, -1, 1, 0, -1, 2, 1, - -1, 1, 2, -2, 1, 2, 0, 1, 2, 1, 0, 0, 2, 1, 2, 1, 0, - 2, 1, 0, 0, -1, -1, 1, -1, 0, 1, -1, 2, 2, 0, 0, -1, 1, - 1, 1, 1, 0, 0, -2, 0, -1, 1, 2, 0, 0, 1, 1, -1, 1, 0, - 1, - }); - try testing.expectEqual(Poly.noise(2, 37, &seed).cs, .{ - 1, 0, 1, -1, -1, -2, -1, -1, 2, 0, -1, 0, 0, -1, - 1, 1, -1, 1, 0, 2, -2, 0, 1, 2, 0, 0, -1, 1, - 0, -1, 1, -1, 1, 2, 1, 1, 0, -1, 1, -1, -2, -1, - 1, -1, -1, -1, 2, -1, -1, 0, 0, 1, 1, -1, 1, 1, - 1, 1, -1, -2, 0, 1, 0, 0, 2, 1, -1, 2, 0, 0, - 1, 1, 0, -1, 0, 0, -1, -1, 2, 0, 1, -1, 2, -1, - -1, -1, -1, 0, -2, 0, 2, 1, 0, 0, 0, -1, 0, 0, - 0, -1, -1, 0, -1, -1, 0, -1, 0, 0, -2, 1, 1, 0, - 1, 0, 1, 0, 1, 1, -1, 2, 0, 1, -1, 1, 2, 0, - 0, 0, 0, -1, -1, -1, 0, 1, 0, -1, 2, 0, 0, 1, - 1, 1, 0, 1, -1, 1, 2, 1, 0, 2, -1, 1, -1, -2, - -1, -2, -1, 1, 0, -2, -2, -1, 1, 0, 0, 0, 0, 1, - 0, 0, 0, 2, 2, 0, 1, 0, -1, -1, 0, 2, 0, 0, - -2, 1, 0, 2, 1, -1, -2, 0, 0, -1, 1, 1, 0, 0, - 2, 0, 1, 1, -2, 1, -2, 1, 1, 0, 2, 0, -1, 0, - -1, 0, 1, 2, 0, 1, 0, -2, 1, -2, -2, 1, -1, 0, - -1, 1, 1, 0, 0, 0, 1, 0, -1, 1, 1, 0, 0, 0, - 0, 1, 0, 1, -1, 0, 1, -1, -1, 2, 0, 0, 1, -1, - 0, 1, -1, 0, - }); -} - -test "uniform sampling" { - var seed: [32]u8 = undefined; - for (&seed, 0..) |*s, i| { - s.* = @as(u8, @intCast(i)); - } - try testing.expectEqual(Poly.uniform(seed, 1, 0).cs, .{ - 797, 993, 161, 6, 2608, 2385, 2096, 2661, 1676, 247, 2440, - 342, 634, 194, 1570, 2848, 986, 684, 3148, 3208, 2018, 351, - 2288, 612, 1394, 170, 1521, 3119, 58, 596, 2093, 1549, 409, - 2156, 1934, 1730, 1324, 388, 446, 418, 1719, 2202, 1812, 98, - 1019, 2369, 214, 2699, 28, 1523, 2824, 273, 402, 2899, 246, - 210, 1288, 863, 2708, 177, 3076, 349, 44, 949, 854, 1371, - 957, 292, 2502, 1617, 1501, 254, 7, 1761, 2581, 2206, 2655, - 1211, 629, 1274, 2358, 816, 2766, 2115, 2985, 1006, 2433, 856, - 2596, 3192, 1, 1378, 2345, 707, 1891, 1669, 536, 1221, 710, - 2511, 120, 1176, 322, 1897, 2309, 595, 2950, 1171, 801, 1848, - 695, 2912, 1396, 1931, 1775, 2904, 893, 2507, 1810, 2873, 253, - 1529, 1047, 2615, 1687, 831, 1414, 965, 3169, 1887, 753, 3246, - 1937, 115, 2953, 586, 545, 1621, 1667, 3187, 1654, 1988, 1857, - 512, 1239, 1219, 898, 3106, 391, 1331, 2228, 3169, 586, 2412, - 845, 768, 156, 662, 478, 1693, 2632, 573, 2434, 1671, 173, - 969, 364, 1663, 2701, 2169, 813, 1000, 1471, 720, 2431, 2530, - 3161, 733, 1691, 527, 2634, 335, 26, 2377, 1707, 767, 3020, - 950, 502, 426, 1138, 3208, 2607, 2389, 44, 1358, 1392, 2334, - 875, 2097, 173, 1697, 2578, 942, 1817, 974, 1165, 2853, 1958, - 2973, 3282, 271, 1236, 1677, 2230, 673, 1554, 96, 242, 1729, - 2518, 1884, 2272, 71, 1382, 924, 1807, 1610, 456, 1148, 2479, - 2152, 238, 2208, 2329, 713, 1175, 1196, 757, 1078, 3190, 3169, - 708, 3117, 154, 1751, 3225, 1364, 154, 23, 2842, 1105, 1419, - 79, 5, 2013, - }); -} - -test "Polynomial packing" { - var rnd = RndGen.init(0); - - for (0..1000) |_| { - const p = Poly.randNormalized(&rnd); - try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p); - } -} - -test "Test inner PKE" { - var seed: [32]u8 = undefined; - var pt: [32]u8 = undefined; - for (&seed, &pt, 0..) |*s, *p, i| { - s.* = @as(u8, @intCast(i)); - p.* = @as(u8, @intCast(i + 32)); - } - inline for (modes) |mode| { - for (0..100) |i| { - var pk: mode.InnerPk = undefined; - var sk: mode.InnerSk = undefined; - seed[0] = @as(u8, @intCast(i)); - mode.innerKeyFromSeed(seed, &pk, &sk); - for (0..10) |j| { - seed[1] = @as(u8, @intCast(j)); - try testing.expectEqual(sk.decrypt(&pk.encrypt(&pt, &seed)), pt); - } - } - } -} - -test "Test happy flow" { - var seed: [64]u8 = undefined; - for (&seed, 0..) |*s, i| { - s.* = @as(u8, @intCast(i)); - } - inline for (modes) |mode| { - for (0..100) |i| { - seed[0] = @as(u8, @intCast(i)); - const kp = try mode.KeyPair.create(seed); - const sk = try mode.SecretKey.fromBytes(&kp.secret_key.toBytes()); - try testing.expectEqual(sk, kp.secret_key); - const pk = try mode.PublicKey.fromBytes(&kp.public_key.toBytes()); - try testing.expectEqual(pk, kp.public_key); - for (0..10) |j| { - seed[1] = @as(u8, @intCast(j)); - const e = pk.encaps(seed[0..32].*); - try testing.expectEqual(e.shared_secret, try sk.decaps(&e.ciphertext)); - } - } - } -} - -// Code to test NIST Known Answer Tests (KAT), see PQCgenKAT.c. - -const sha2 = crypto.hash.sha2; - -test "NIST KAT test" { - inline for (.{ - .{ Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" }, - .{ Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" }, - .{ Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" }, - }) |modeHash| { - const mode = modeHash[0]; - var seed: [48]u8 = undefined; - for (&seed, 0..) |*s, i| { - s.* = @as(u8, @intCast(i)); - } - var f = sha2.Sha256.init(.{}); - const fw = f.writer(); - var g = NistDRBG.init(seed); - try std.fmt.format(fw, "# {s}\n\n", .{mode.name}); - for (0..100) |i| { - g.fill(&seed); - try std.fmt.format(fw, "count = {}\n", .{i}); - try std.fmt.format(fw, "seed = {s}\n", .{std.fmt.fmtSliceHexUpper(&seed)}); - var g2 = NistDRBG.init(seed); - - // This is not equivalent to g2.fill(kseed[:]). As the reference - // implementation calls randombytes twice generating the keypair, - // we have to do that as well. - var kseed: [64]u8 = undefined; - var eseed: [32]u8 = undefined; - g2.fill(kseed[0..32]); - g2.fill(kseed[32..64]); - g2.fill(&eseed); - const kp = try mode.KeyPair.create(kseed); - const e = kp.public_key.encaps(eseed); - const ss2 = try kp.secret_key.decaps(&e.ciphertext); - try testing.expectEqual(ss2, e.shared_secret); - try std.fmt.format(fw, "pk = {s}\n", .{std.fmt.fmtSliceHexUpper(&kp.public_key.toBytes())}); - try std.fmt.format(fw, "sk = {s}\n", .{std.fmt.fmtSliceHexUpper(&kp.secret_key.toBytes())}); - try std.fmt.format(fw, "ct = {s}\n", .{std.fmt.fmtSliceHexUpper(&e.ciphertext)}); - try std.fmt.format(fw, "ss = {s}\n\n", .{std.fmt.fmtSliceHexUpper(&e.shared_secret)}); - } - - var out: [32]u8 = undefined; - f.final(&out); - var outHex: [64]u8 = undefined; - _ = try std.fmt.bufPrint(&outHex, "{s}", .{std.fmt.fmtSliceHexLower(&out)}); - try testing.expectEqual(outHex, modeHash[1].*); - } -} - -const NistDRBG = struct { - key: [32]u8, - v: [16]u8, - - fn incV(g: *NistDRBG) void { - var j: usize = 15; - while (j >= 0) : (j -= 1) { - if (g.v[j] == 255) { - g.v[j] = 0; - } else { - g.v[j] += 1; - break; - } - } - } - - // AES256_CTR_DRBG_Update(pd, &g.key, &g.v). - fn update(g: *NistDRBG, pd: ?[48]u8) void { - var buf: [48]u8 = undefined; - const ctx = crypto.core.aes.Aes256.initEnc(g.key); - var i: usize = 0; - while (i < 3) : (i += 1) { - g.incV(); - var block: [16]u8 = undefined; - ctx.encrypt(&block, &g.v); - buf[i * 16 ..][0..16].* = block; - } - if (pd) |p| { - for (&buf, p) |*b, x| { - b.* ^= x; - } - } - g.key = buf[0..32].*; - g.v = buf[32..48].*; - } - - // randombytes. - fn fill(g: *NistDRBG, out: []u8) void { - var block: [16]u8 = undefined; - var dst = out; - - const ctx = crypto.core.aes.Aes256.initEnc(g.key); - while (dst.len > 0) { - g.incV(); - ctx.encrypt(&block, &g.v); - if (dst.len < 16) { - @memcpy(dst, block[0..dst.len]); - break; - } - dst[0..block.len].* = block; - dst = dst[16..dst.len]; - } - g.update(null); - } - - fn init(seed: [48]u8) NistDRBG { - var ret: NistDRBG = .{ .key = .{0} ** 32, .v = .{0} ** 16 }; - ret.update(seed); - return ret; - } -}; diff --git a/lib/std/crypto/ml_kem.zig b/lib/std/crypto/ml_kem.zig @@ -0,0 +1,1830 @@ +//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM) +//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft). +//! +//! The schemes are not finalized yet, and are still subject to breaking changes. +//! +//! The Kyber namespace suffix (currently `_d00`) refers to the version currently +//! implemented, in accordance with the draft. +//! The ML-KEM namespace suffix (currently `_01`) refers to the NIST FIPS-203 draft +//! published on August 24, 2023, with the unintentional transposition of  having been reverted. +//! +//! Suffixes may not be updated if new versions of the documents only include editorial changes. +//! The suffixes will be removed once the schemes are finalized. +//! +//! Quoting from the CFRG I-D: +//! +//! Kyber is not a Diffie-Hellman (DH) style non-interactive key +//! agreement, but instead, Kyber is a Key Encapsulation Method (KEM). +//! In essence, a KEM is a Public-Key Encryption (PKE) scheme where the +//! plaintext cannot be specified, but is generated as a random key as +//! part of the encryption. A KEM can be transformed into an unrestricted +//! PKE using HPKE (RFC9180). On its own, a KEM can be used as a key +//! agreement method in TLS. +//! +//! Kyber is an IND-CCA2 secure KEM. It is constructed by applying a +//! Fujisaki--Okamato style transformation on InnerPKE, which is the +//! underlying IND-CPA secure Public Key Encryption scheme. We cannot +//! use InnerPKE directly, as its ciphertexts are malleable. +//! +//! ``` +//! F.O. transform +//! InnerPKE ----------------------> Kyber +//! IND-CPA IND-CCA2 +//! ``` +//! +//! Kyber is a lattice-based scheme. More precisely, its security is +//! based on the learning-with-errors-and-rounding problem in module +//! lattices (MLWER). The underlying polynomial ring R (defined in +//! Section 5) is chosen such that multiplication is very fast using the +//! number theoretic transform (NTT, see Section 5.1.3). +//! +//! An InnerPKE private key is a vector _s_ over R of length k which is +//! _small_ in a particular way. Here k is a security parameter akin to +//! the size of a prime modulus. For Kyber512, which targets AES-128's +//! security level, the value of k is 2. +//! +//! The public key consists of two values: +//! +//! * _A_ a uniformly sampled k by k matrix over R _and_ +//! +//! * _t = A s + e_, where e is a suitably small masking vector. +//! +//! Distinguishing between such A s + e and a uniformly sampled t is the +//! module learning-with-errors (MLWE) problem. If that is hard, then it +//! is also hard to recover the private key from the public key as that +//! would allow you to distinguish between those two. +//! +//! To save space in the public key, A is recomputed deterministically +//! from a seed _rho_. +//! +//! A ciphertext for a message m under this public key is a pair (c_1, +//! c_2) computed roughly as follows: +//! +//! c_1 = Compress(A^T r + e_1, d_u) +//! c_2 = Compress(t^T r + e_2 + Decompress(m, 1), d_v) +//! +//! where +//! +//! * e_1, e_2 and r are small blinds; +//! +//! * Compress(-, d) removes some information, leaving d bits per +//! coefficient and Decompress is such that Compress after Decompress +//! does nothing and +//! +//! * d_u, d_v are scheme parameters. +//! +//! Distinguishing such a ciphertext and uniformly sampled (c_1, c_2) is +//! an example of the full MLWER problem, see section 4.4 of [KyberV302]. +//! +//! To decrypt the ciphertext, one computes +//! +//! m = Compress(Decompress(c_2, d_v) - s^T Decompress(c_1, d_u), 1). +//! +//! It it not straight-forward to see that this formula is correct. In +//! fact, there is negligible but non-zero probability that a ciphertext +//! does not decrypt correctly given by the DFP column in Table 4. This +//! failure probability can be computed by a careful automated analysis +//! of the probabilities involved, see kyber_failure.py of [SecEst]. +//! +//! [KyberV302](https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf) +//! [I-D](https://github.com/bwesterb/draft-schwabe-cfrg-kyber) +//! [SecEst](https://github.com/pq-crystals/security-estimates) + +// TODO +// +// - The bottleneck in Kyber are the various hash/xof calls: +// - Optimize Zig's keccak implementation. +// - Use SIMD to compute keccak in parallel. +// - Can we track bounds of coefficients using comptime types without +// duplicating code? +// - Would be neater to have tests closer to the thing under test. +// - When generating a keypair, we have a copy of the inner public key with +// its large matrix A in both the public key and the private key. In Go we +// can just have a pointer in the private key to the public key, but +// how do we do this elegantly in Zig? + +const std = @import("std"); +const builtin = @import("builtin"); + +const testing = std.testing; +const assert = std.debug.assert; +const crypto = std.crypto; +const errors = std.crypto.errors; +const math = std.math; +const mem = std.mem; +const RndGen = std.Random.DefaultPrng; +const sha3 = crypto.hash.sha3; + +// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1. +const Q: i16 = 3329; + +// Montgomery R +const R: i32 = 1 << 16; + +// Parameter n, degree of polynomials. +const N: usize = 256; + +// Size of "small" vectors used in encryption blinds. +const eta2: u8 = 2; + +const Params = struct { + name: []const u8, + + // NIST ML-KEM variant instead of Kyber as originally submitted. + ml_kem: bool = false, + + // Width and height of the matrix A. + k: u8, + + // Size of "small" vectors used in private key and encryption blinds. + eta1: u8, + + // How many bits to retain of u, the private-key independent part + // of the ciphertext. + du: u8, + + // How many bits to retain of v, the private-key dependent part + // of the ciphertext. + dv: u8, +}; + +pub const kyber_d00 = struct { + pub const Kyber512 = Kyber(.{ + .name = "Kyber512", + .k = 2, + .eta1 = 3, + .du = 10, + .dv = 4, + }); + + pub const Kyber768 = Kyber(.{ + .name = "Kyber768", + .k = 3, + .eta1 = 2, + .du = 10, + .dv = 4, + }); + + pub const Kyber1024 = Kyber(.{ + .name = "Kyber1024", + .k = 4, + .eta1 = 2, + .du = 11, + .dv = 5, + }); +}; + +pub const ml_kem_01 = struct { + pub const MLKem512 = Kyber(.{ + .name = "ML-KEM-512", + .ml_kem = true, + .k = 2, + .eta1 = 3, + .du = 10, + .dv = 4, + }); + + pub const MLKem768 = Kyber(.{ + .name = "ML-KEM-768", + .ml_kem = true, + .k = 3, + .eta1 = 2, + .du = 10, + .dv = 4, + }); + + pub const MLKem1024 = Kyber(.{ + .name = "ML-KEM-1024", + .ml_kem = true, + .k = 4, + .eta1 = 2, + .du = 11, + .dv = 5, + }); +}; + +const modes = [_]type{ + kyber_d00.Kyber512, + kyber_d00.Kyber768, + kyber_d00.Kyber1024, + ml_kem_01.MLKem512, + ml_kem_01.MLKem768, + ml_kem_01.MLKem1024, +}; +const h_length: usize = 32; +const inner_seed_length: usize = 32; +const common_encaps_seed_length: usize = 32; +const common_shared_key_size: usize = 32; + +fn Kyber(comptime p: Params) type { + return struct { + // Size of a ciphertext, in bytes. + pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv); + + const Self = @This(); + const V = Vec(p.k); + const M = Mat(p.k); + + /// Length (in bytes) of a shared secret. + pub const shared_length = common_shared_key_size; + /// Length (in bytes) of a seed for deterministic encapsulation. + pub const encaps_seed_length = common_encaps_seed_length; + /// Length (in bytes) of a seed for key generation. + pub const seed_length: usize = inner_seed_length + shared_length; + /// Algorithm name. + pub const name = p.name; + + /// A shared secret, and an encapsulated (encrypted) representation of it. + pub const EncapsulatedSecret = struct { + shared_secret: [shared_length]u8, + ciphertext: [ciphertext_length]u8, + }; + + /// A Kyber public key. + pub const PublicKey = struct { + pk: InnerPk, + + // Cached + hpk: [h_length]u8, // H(pk) + + /// Size of a serialized representation of the key, in bytes. + pub const bytes_length = InnerPk.bytes_length; + + /// Generates a shared secret, and encapsulates it for the public key. + /// If `seed` is `null`, a random seed is used. This is recommended. + /// If `seed` is set, encapsulation is deterministic. + pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret { + var m: [inner_plaintext_length]u8 = undefined; + + if (seed_) |seed| { + if (p.ml_kem) { + @memcpy(&m, &seed); + } else { + // m = H(seed) + sha3.Sha3_256.hash(&seed, &m, .{}); + } + } else { + crypto.random.bytes(&m); + } + + // (K', r) = G(m ‖ H(pk)) + var kr: [inner_plaintext_length + h_length]u8 = undefined; + var g = sha3.Sha3_512.init(.{}); + g.update(&m); + g.update(&pk.hpk); + g.final(&kr); + + // c = innerEncrypt(pk, m, r) + const ct = pk.pk.encrypt(&m, kr[32..64]); + + if (p.ml_kem) { + return EncapsulatedSecret{ + .shared_secret = kr[0..shared_length].*, // ML-KEM: K = K' + .ciphertext = ct, + }; + } else { + // Compute H(c) and put in second slot of kr, which will be (K', H(c)). + sha3.Sha3_256.hash(&ct, kr[32..], .{}); + + var ss: [shared_length]u8 = undefined; + sha3.Shake256.hash(&kr, &ss, .{}); + return EncapsulatedSecret{ + .shared_secret = ss, // Kyber: K = KDF(K' ‖ H(c)) + .ciphertext = ct, + }; + } + } + + /// Serializes the key into a byte array. + pub fn toBytes(pk: PublicKey) [bytes_length]u8 { + return pk.pk.toBytes(); + } + + /// Deserializes the key from a byte array. + pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey { + var ret: PublicKey = undefined; + ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]); + sha3.Sha3_256.hash(buf, &ret.hpk, .{}); + return ret; + } + }; + + /// A Kyber secret key. + pub const SecretKey = struct { + sk: InnerSk, + pk: InnerPk, + hpk: [h_length]u8, // H(pk) + z: [shared_length]u8, + + /// Size of a serialized representation of the key, in bytes. + pub const bytes_length: usize = + InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length; + + /// Decapsulates the shared secret within ct using the private key. + pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 { + // m' = innerDec(ct) + const m2 = sk.sk.decrypt(ct); + + // (K'', r') = G(m' ‖ H(pk)) + var kr2: [64]u8 = undefined; + var g = sha3.Sha3_512.init(.{}); + g.update(&m2); + g.update(&sk.hpk); + g.final(&kr2); + + // ct' = innerEnc(pk, m', r') + const ct2 = sk.pk.encrypt(&m2, kr2[32..64]); + + // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)). + sha3.Sha3_256.hash(ct, kr2[32..], .{}); + + // Replace K'' by z in the first slot of kr2 if ct ≠ ct'. + cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2)); + + if (p.ml_kem) { + // ML-KEM: K = K''/z + return kr2[0..shared_length].*; + } else { + // Kyber: K = KDF(K''/z ‖ H(c)) + var ss: [shared_length]u8 = undefined; + sha3.Shake256.hash(&kr2, &ss, .{}); + return ss; + } + } + + /// Serializes the key into a byte array. + pub fn toBytes(sk: SecretKey) [bytes_length]u8 { + return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z; + } + + /// Deserializes the key from a byte array. + pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey { + var ret: SecretKey = undefined; + comptime var s: usize = 0; + ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]); + s += InnerSk.bytes_length; + ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]); + s += InnerPk.bytes_length; + ret.hpk = buf[s..][0..h_length].*; + s += h_length; + ret.z = buf[s..][0..shared_length].*; + return ret; + } + }; + + /// A Kyber key pair. + pub const KeyPair = struct { + secret_key: SecretKey, + public_key: PublicKey, + + /// Create a new key pair. + /// If seed is null, a random seed will be generated. + /// If a seed is provided, the key pair will be determinsitic. + pub fn create(seed_: ?[seed_length]u8) !KeyPair { + const seed = seed_ orelse sk: { + var random_seed: [seed_length]u8 = undefined; + crypto.random.bytes(&random_seed); + break :sk random_seed; + }; + var ret: KeyPair = undefined; + ret.secret_key.z = seed[inner_seed_length..seed_length].*; + + // Generate inner key + innerKeyFromSeed( + seed[0..inner_seed_length].*, + &ret.public_key.pk, + &ret.secret_key.sk, + ); + ret.secret_key.pk = ret.public_key.pk; + + // Copy over z from seed. + ret.secret_key.z = seed[inner_seed_length..seed_length].*; + + // Compute H(pk) + sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{}); + ret.public_key.hpk = ret.secret_key.hpk; + + return ret; + } + }; + + // Size of plaintexts of the in + const inner_plaintext_length: usize = Poly.compressedSize(1); + + const InnerPk = struct { + rho: [32]u8, // ρ, the seed for the matrix A + th: V, // NTT(t), normalized + + // Cached values + aT: M, + + const bytes_length = V.bytes_length + 32; + + fn encrypt( + pk: InnerPk, + pt: *const [inner_plaintext_length]u8, + seed: *const [32]u8, + ) [ciphertext_length]u8 { + // Sample r, e₁ and e₂ appropriately + const rh = V.noise(p.eta1, 0, seed).ntt().barrettReduce(); + const e1 = V.noise(eta2, p.k, seed); + const e2 = Poly.noise(eta2, 2 * p.k, seed); + + // Next we compute u = Aᵀ r + e₁. First Aᵀ. + var u: V = undefined; + for (0..p.k) |i| { + // Note that coefficients of r are bounded by q and those of Aᵀ + // are bounded by 4.5q and so their product is bounded by 2¹⁵q + // as required for multiplication. + u.ps[i] = pk.aT.vs[i].dotHat(rh); + } + + // Aᵀ and r were not in Montgomery form, so the Montgomery + // multiplications in the inner product added a factor R⁻¹ which + // the InvNTT cancels out. + u = u.barrettReduce().invNTT().add(e1).normalize(); + + // Next, compute v = <t, r> + e₂ + Decompress_q(m, 1) + const v = pk.th.dotHat(rh).barrettReduce().invNTT() + .add(Poly.decompress(1, pt)).add(e2).normalize(); + + return u.compress(p.du) ++ v.compress(p.dv); + } + + fn toBytes(pk: InnerPk) [bytes_length]u8 { + return pk.th.toBytes() ++ pk.rho; + } + + fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk { + var ret: InnerPk = undefined; + + const th_bytes = buf[0..V.bytes_length]; + ret.th = V.fromBytes(th_bytes).normalize(); + + if (p.ml_kem) { + // Verify that the coefficients used a canonical representation. + if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) { + return error.NonCanonical; + } + } + + ret.rho = buf[V.bytes_length..bytes_length].*; + ret.aT = M.uniform(ret.rho, true); + return ret; + } + }; + + // Private key of the inner PKE + const InnerSk = struct { + sh: V, // NTT(s), normalized + const bytes_length = V.bytes_length; + + fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 { + const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]); + const v = Poly.decompress( + p.dv, + ct[comptime V.compressedSize(p.du)..ciphertext_length], + ); + + // Compute m = v - <s, u> + return v.sub(sk.sh.dotHat(u.ntt()).barrettReduce().invNTT()) + .normalize().compress(1); + } + + fn toBytes(sk: InnerSk) [bytes_length]u8 { + return sk.sh.toBytes(); + } + + fn fromBytes(buf: *const [bytes_length]u8) InnerSk { + var ret: InnerSk = undefined; + ret.sh = V.fromBytes(buf).normalize(); + return ret; + } + }; + + // Derives inner PKE keypair from given seed. + fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void { + var expanded_seed: [64]u8 = undefined; + sha3.Sha3_512.hash(&seed, &expanded_seed, .{}); + pk.rho = expanded_seed[0..32].*; + const sigma = expanded_seed[32..64]; + pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on + + // Sample secret vector s. + sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize(); + + const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e. + var th: V = undefined; + + // Next, we compute t = A s + e. + for (0..p.k) |i| { + // Note that coefficients of s are bounded by q and those of A + // are bounded by 4.5q and so their product is bounded by 2¹⁵q + // as required for multiplication. + // A and s were not in Montgomery form, so the Montgomery + // multiplications in the inner product added a factor R⁻¹ which + // we'll cancel out with toMont(). This will also ensure the + // coefficients of th are bounded in absolute value by q. + th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont(); + } + + pk.th = th.add(eh).normalize(); // bounded by 8q + pk.aT = pk.aT.transpose(); + } + }; +} + +// R mod q +const r_mod_q: i32 = @rem(@as(i32, R), Q); + +// R² mod q +const r2_mod_q: i32 = @rem(r_mod_q * r_mod_q, Q); + +// ζ is the degree 256 primitive root of unity used for the NTT. +const zeta: i16 = 17; + +// (128)⁻¹ R². Used in inverse NTT. +const r2_over_128: i32 = @mod(invertMod(128, Q) * r2_mod_q, Q); + +// zetas lists precomputed powers of the primitive root of unity in +// Montgomery representation used for the NTT: +// +// zetas[i] = ζᵇʳᵛ⁽ⁱ⁾ R mod q +// +// where ζ = 17, brv(i) is the bitreversal of a 7-bit number and R=2¹⁶ mod q. +const zetas = computeZetas(); + +// invNTTReductions keeps track of which coefficients to apply Barrett +// reduction to in Poly.invNTT(). +// +// Generated lazily: once a butterfly is computed which is about to +// overflow the i16, the largest coefficient is reduced. If that is +// not enough, the other coefficient is reduced as well. +// +// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf +// TODO generate comptime? +const inv_ntt_reductions = [_]i16{ + -1, // after layer 1 + -1, // after layer 2 + 16, + 17, + 48, + 49, + 80, + 81, + 112, + 113, + 144, + 145, + 176, + 177, + 208, + 209, + 240, 241, -1, // after layer 3 + 0, 1, 32, + 33, 34, 35, + 64, 65, 96, + 97, 98, 99, + 128, 129, + 160, 161, 162, 163, 192, 193, 224, 225, 226, 227, -1, // after layer 4 + 2, 3, 66, 67, 68, 69, 70, 71, 130, 131, 194, + 195, 196, 197, + 198, 199, -1, // after layer 5 + 4, 5, 6, + 7, 132, 133, + 134, 135, 136, + 137, 138, 139, + 140, 141, + 142, 143, -1, // after layer 6 + -1, // after layer 7 +}; + +test "invNTTReductions bounds" { + // Checks whether the reductions proposed by invNTTReductions + // don't overflow during invNTT(). + var xs = [_]i32{1} ** 256; // start at |x| ≤ q + + var r: usize = 0; + var layer: math.Log2Int(usize) = 1; + while (layer < 8) : (layer += 1) { + const w = @as(usize, 1) << layer; + var i: usize = 0; + + while (i + w < 256) { + xs[i] = xs[i] + xs[i + w]; + try testing.expect(xs[i] <= 9); // we can't exceed 9q + xs[i + w] = 1; + i += 1; + if (@mod(i, w) == 0) { + i += w; + } + } + + while (true) { + const j = inv_ntt_reductions[r]; + r += 1; + if (j < 0) { + break; + } + xs[@as(usize, @intCast(j))] = 1; + } + } +} + +// Extended euclidean algorithm. +// +// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute +// modular inverse. +fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) { + if (a == 0) { + return .{ .gcd = b, .x = 0, .y = 1 }; + } + const r = eea(@rem(b, a), a); + return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x }; +} + +fn EeaResult(comptime T: type) type { + return struct { gcd: T, x: T, y: T }; +} + +// Returns least common multiple of a and b. +fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) { + const r = eea(a, b); + return a * b / r.gcd; +} + +// Invert modulo p. +fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) { + const r = eea(a, p); + assert(r.gcd == 1); + return r.x; +} + +// Reduce mod q for testing. +fn modQ32(x: i32) i16 { + var y = @as(i16, @intCast(@rem(x, @as(i32, Q)))); + if (y < 0) { + y += Q; + } + return y; +} + +// Given -2¹⁵ q ≤ x < 2¹⁵ q, returns -q < y < q with x 2⁻¹⁶ = y (mod q). +fn montReduce(x: i32) i16 { + const qInv = comptime invertMod(@as(i32, Q), R); + // This is Montgomery reduction with R=2¹⁶. + // + // Note gcd(2¹⁶, q) = 1 as q is prime. Write q' := 62209 = q⁻¹ mod R. + // First we compute + // + // m := ((x mod R) q') mod R + // = x q' mod R + // = int16(x q') + // = int16(int32(x) * int32(q')) + // + // Note that x q' might be as big as 2³² and could overflow the int32 + // multiplication in the last line. However for any int32s a and b, + // we have int32(int64(a)*int64(b)) = int32(a*b) and so the result is ok. + const m: i16 = @truncate(@as(i32, @truncate(x *% qInv))); + + // Note that x - m q is divisible by R; indeed modulo R we have + // + // x - m q ≡ x - x q' q ≡ x - x q⁻¹ q ≡ x - x = 0. + // + // We return y := (x - m q) / R. Note that y is indeed correct as + // modulo q we have + // + // y ≡ x R⁻¹ - m q R⁻¹ = x R⁻¹ + // + // and as both 2¹⁵ q ≤ m q, x < 2¹⁵ q, we have + // 2¹⁶ q ≤ x - m q < 2¹⁶ and so q ≤ (x - m q) / R < q as desired. + const yR = x - @as(i32, m) * @as(i32, Q); + return @bitCast(@as(u16, @truncate(@as(u32, @bitCast(yR)) >> 16))); +} + +test "Test montReduce" { + var rnd = RndGen.init(0); + for (0..1000) |_| { + const bound = comptime @as(i32, Q) * (1 << 15); + const x = rnd.random().intRangeLessThan(i32, -bound, bound); + const y = montReduce(x); + try testing.expect(-Q < y and y < Q); + try testing.expectEqual(modQ32(x), modQ32(@as(i32, y) * R)); + } +} + +// Given any x, return x R mod q where R=2¹⁶. +fn feToMont(x: i16) i16 { + // Note |1353 x| ≤ 1353 2¹⁵ ≤ 13318 q ≤ 2¹⁵ q and so we're within + // the bounds of montReduce. + return montReduce(@as(i32, x) * r2_mod_q); +} + +test "Test feToMont" { + var x: i32 = -(1 << 15); + while (x < 1 << 15) : (x += 1) { + const y = feToMont(@as(i16, @intCast(x))); + try testing.expectEqual(modQ32(@as(i32, y)), modQ32(x * r_mod_q)); + } +} + +// Given any x, compute 0 ≤ y ≤ q with x = y (mod q). +// +// Beware: we might have feBarrettReduce(x) = q ≠ 0 for some x. In fact, +// this happens if and only if x = -nq for some positive integer n. +fn feBarrettReduce(x: i16) i16 { + // This is standard Barrett reduction. + // + // For any x we have x mod q = x - ⌊x/q⌋ q. We will use 20159/2²⁶ as + // an approximation of 1/q. Note that 0 ≤ 20159/2²⁶ - 1/q ≤ 0.135/2²⁶ + // and so | x 20156/2²⁶ - x/q | ≤ 2⁻¹⁰ for |x| ≤ 2¹⁶. For all x + // not a multiple of q, the number x/q is further than 1/q from any integer + // and so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋. If x is a multiple of q and x is positive, + // then x 20156/2²⁶ is larger than x/q so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋ as well. + // Finally, if x is negative multiple of q, then ⌊x 20156/2²⁶⌋ = ⌊x/q⌋-1. + // Thus + // [ q if x=-nq for pos. integer n + // x - ⌊x 20156/2²⁶⌋ q = [ + // [ x mod q otherwise + // + // To actually compute this, note that + // + // ⌊x 20156/2²⁶⌋ = (20159 x) >> 26. + return x -% @as(i16, @intCast((@as(i32, x) * 20159) >> 26)) *% Q; +} + +test "Test Barrett reduction" { + var x: i32 = -(1 << 15); + while (x < 1 << 15) : (x += 1) { + var y1 = feBarrettReduce(@as(i16, @intCast(x))); + const y2 = @mod(@as(i16, @intCast(x)), Q); + if (x < 0 and @rem(-x, Q) == 0) { + y1 -= Q; + } + try testing.expectEqual(y1, y2); + } +} + +// Returns x if x < q and x - q otherwise. Assumes x ≥ -29439. +fn csubq(x: i16) i16 { + var r = x; + r -= Q; + r += (r >> 15) & Q; + return r; +} + +test "Test csubq" { + var x: i32 = -29439; + while (x < 1 << 15) : (x += 1) { + const y1 = csubq(@as(i16, @intCast(x))); + var y2 = @as(i16, @intCast(x)); + if (@as(i16, @intCast(x)) >= Q) { + y2 -= Q; + } + try testing.expectEqual(y1, y2); + } +} + +// Compute a^s mod p. +fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) { + var ret: @TypeOf(a) = 1; + var s2 = s; + var a2 = a; + + while (true) { + if (s2 & 1 == 1) { + ret = @mod(ret * a2, p); + } + s2 >>= 1; + if (s2 == 0) { + break; + } + a2 = @mod(a2 * a2, p); + } + return ret; +} + +// Computes zetas table used by ntt and invNTT. +fn computeZetas() [128]i16 { + @setEvalBranchQuota(10000); + var ret: [128]i16 = undefined; + for (&ret, 0..) |*r, i| { + const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q))); + r.* = csubq(feBarrettReduce(feToMont(t))); + } + return ret; +} + +// An element of our base ring R which are polynomials over ℤ_q +// modulo the equation Xᴺ = -1, where q=3329 and N=256. +// +// This type is also used to store NTT-transformed polynomials, +// see Poly.NTT(). +// +// Coefficients aren't always reduced. See Normalize(). +const Poly = struct { + cs: [N]i16, + + const bytes_length = N / 2 * 3; + const zero: Poly = .{ .cs = .{0} ** N }; + + fn add(a: Poly, b: Poly) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = a.cs[i] + b.cs[i]; + } + return ret; + } + + fn sub(a: Poly, b: Poly) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = a.cs[i] - b.cs[i]; + } + return ret; + } + + // For testing, generates a random polynomial with for each + // coefficient |x| ≤ q. + fn randAbsLeqQ(rnd: anytype) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q); + } + return ret; + } + + // For testing, generates a random normalized polynomial. + fn randNormalized(rnd: anytype) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q); + } + return ret; + } + + // Executes a forward "NTT" on p. + // + // Assumes the coefficients are in absolute value ≤q. The resulting + // coefficients are in absolute value ≤7q. If the input is in Montgomery + // form, then the result is in Montgomery form and so (by linearity of the NTT) + // if the input is in regular form, then the result is also in regular form. + fn ntt(a: Poly) Poly { + // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512 + // does not divide into q-1) and so we cannot do a regular NTT. ℤ_q + // does have a primitive 256ᵗʰ root of unity, the smallest of which + // is ζ := 17. + // + // Recall that our base ring R := ℤ_q[x] / (x²⁵⁶ + 1). The polynomial + // x²⁵⁶+1 will not split completely (as its roots would be 512ᵗʰ roots + // of unity.) However, it does split almost (using ζ¹²⁸ = -1): + // + // x²⁵⁶ + 1 = (x²)¹²⁸ - ζ¹²⁸ + // = ((x²)⁶⁴ - ζ⁶⁴)((x²)⁶⁴ + ζ⁶⁴) + // = ((x²)³² - ζ³²)((x²)³² + ζ³²)((x²)³² - ζ⁹⁶)((x²)³² + ζ⁹⁶) + // ⋮ + // = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷) + // + // Note that the powers of ζ that appear (from the second line down) are + // in binary + // + // 0100000 1100000 + // 0010000 1010000 0110000 1110000 + // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000 + // … + // + // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit + // bitreversal of x. These powers of ζ are given by the Zetas array. + // + // The polynomials x² ± ζⁱ are irreducible and coprime, hence by + // the Chinese Remainder Theorem we know + // + // ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷) + // + // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ ) + // is an isomorphism, which is the "NTT". It can be efficiently computed by + // + // + // a ↦ ( a mod (x²)⁶⁴ - ζ⁶⁴, a mod (x²)⁶⁴ + ζ⁶⁴ ) + // ↦ ( a mod (x²)³² - ζ³², a mod (x²)³² + ζ³², + // a mod (x²)⁹⁶ - ζ⁹⁶, a mod (x²)⁹⁶ + ζ⁹⁶ ) + // + // et cetera + // If N was 8 then this can be pictured in the following diagram: + // + // https://cnx.org/resources/17ee4dfe517a6adda05377b25a00bf6e6c93c334/File0026.png + // + // Each cross is a Cooley-Tukey butterfly: it's the map + // + // (a, b) ↦ (a + ζb, a - ζb) + // + // for the appropriate power ζ for that column and row group. + var p = a; + var k: usize = 0; // index into zetas + + var l = N >> 1; + while (l > 1) : (l >>= 1) { + // On the nᵗʰ iteration of the l-loop, the absolute value of the + // coefficients are bounded by nq. + + // offset effectively loops over the row groups in this column; it is + // the first row in the row group. + var offset: usize = 0; + while (offset < N - l) : (offset += 2 * l) { + k += 1; + const z = @as(i32, zetas[k]); + + // j loops over each butterfly in the row group. + for (offset..offset + l) |j| { + const t = montReduce(z * @as(i32, p.cs[j + l])); + p.cs[j + l] = p.cs[j] - t; + p.cs[j] += t; + } + } + } + + return p; + } + + // Executes an inverse "NTT" on p and multiply by the Montgomery factor R. + // + // Assumes the coefficients are in absolute value ≤q. The resulting + // coefficients are in absolute value ≤q. If the input is in Montgomery + // form, then the result is in Montgomery form and so (by linearity) + // if the input is in regular form, then the result is also in regular form. + fn invNTT(a: Poly) Poly { + var k: usize = 127; // index into zetas + var r: usize = 0; // index into invNTTReductions + var p = a; + + // We basically do the oppposite of NTT, but postpone dividing by 2 in the + // inverse of the Cooley-Tukey butterfly and accumulate that into a big + // division by 2⁷ at the end. See the comments in the ntt() function. + + var l: usize = 2; + while (l < N) : (l <<= 1) { + var offset: usize = 0; + while (offset < N - l) : (offset += 2 * l) { + // As we're inverting, we need powers of ζ⁻¹ (instead of ζ). + // To be precise, we need ζᵇʳᵛ⁽ᵏ⁾⁻¹²⁸. However, as ζ⁻¹²⁸ = -1, + // we can use the existing zetas table instead of + // keeping a separate invZetas table as in Dilithium. + + const minZeta = @as(i32, zetas[k]); + k -= 1; + + for (offset..offset + l) |j| { + // Gentleman-Sande butterfly: (a, b) ↦ (a + b, ζ(a-b)) + const t = p.cs[j + l] - p.cs[j]; + p.cs[j] += p.cs[j + l]; + p.cs[j + l] = montReduce(minZeta * @as(i32, t)); + + // Note that if we had |a| < αq and |b| < βq before the + // butterfly, then now we have |a| < (α+β)q and |b| < q. + } + } + + // We let the invNTTReductions instruct us which coefficients to + // Barrett reduce. + while (true) { + const i = inv_ntt_reductions[r]; + r += 1; + if (i < 0) { + break; + } + p.cs[@as(usize, @intCast(i))] = feBarrettReduce(p.cs[@as(usize, @intCast(i))]); + } + } + + for (0..N) |j| { + // Note 1441 = (128)⁻¹ R². The coefficients are bounded by 9q, so + // as 1441 * 9 ≈ 2¹⁴ < 2¹⁵, we're within the required bounds + // for montReduce(). + p.cs[j] = montReduce(r2_over_128 * @as(i32, p.cs[j])); + } + + return p; + } + + // Normalizes coefficients. + // + // Ensures each coefficient is in {0, …, q-1}. + fn normalize(a: Poly) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = csubq(feBarrettReduce(a.cs[i])); + } + return ret; + } + + // Put p in Montgomery form. + fn toMont(a: Poly) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = feToMont(a.cs[i]); + } + return ret; + } + + // Barret reduce coefficients. + // + // Beware, this does not fully normalize coefficients. + fn barrettReduce(a: Poly) Poly { + var ret: Poly = undefined; + for (0..N) |i| { + ret.cs[i] = feBarrettReduce(a.cs[i]); + } + return ret; + } + + fn compressedSize(comptime d: u8) usize { + return @divTrunc(N * d, 8); + } + + // Returns packed Compress_q(p, d). + // + // Assumes p is normalized. + fn compress(p: Poly, comptime d: u8) [compressedSize(d)]u8 { + @setEvalBranchQuota(10000); + const q_over_2: u32 = comptime @divTrunc(Q, 2); // (q-1)/2 + const two_d_min_1: u32 = comptime (1 << d) - 1; // 2ᵈ-1 + var in_off: usize = 0; + var out_off: usize = 0; + + const batch_size: usize = comptime lcm(@as(i16, d), 8); + const in_batch_size: usize = comptime batch_size / d; + const out_batch_size: usize = comptime batch_size / 8; + + const out_length: usize = comptime @divTrunc(N * d, 8); + comptime assert(out_length * 8 == d * N); + var out = [_]u8{0} ** out_length; + + while (in_off < N) { + // First we compress into in. + var in: [in_batch_size]u16 = undefined; + inline for (0..in_batch_size) |i| { + // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ + // = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ + // = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ + // = DIV((x << d) + q/2, q) & ((1<<d) - 1) + const t = @as(u24, @intCast(p.cs[in_off + i])) << d; + // Division by invariant multiplication, equivalent to DIV(t + q/2, q). + // A division may not be a constant-time operation, even with a constant denominator. + // Here, side channels would leak information about the shared secret, see https://kyberslash.cr.yp.to + // Multiplication, on the other hand, is a constant-time operation on the CPUs we currently support. + comptime assert(d <= 11); + comptime assert(((20642679 * @as(u64, Q)) >> 36) == 1); + const u: u32 = @intCast((@as(u64, t + q_over_2) * 20642679) >> 36); + in[i] = @intCast(u & two_d_min_1); + } + + // Now we pack the d-bit integers from `in' into out as bytes. + comptime var in_shift: usize = 0; + comptime var j: usize = 0; + comptime var i: usize = 0; + inline while (i < in_batch_size) : (j += 1) { + comptime var todo: usize = 8; + inline while (todo > 0) { + const out_shift = comptime 8 - todo; + out[out_off + j] |= @as(u8, @truncate((in[i] >> in_shift) << out_shift)); + + const done = comptime @min(@min(d, todo), d - in_shift); + todo -= done; + in_shift += done; + + if (in_shift == d) { + in_shift = 0; + i += 1; + } + } + } + + in_off += in_batch_size; + out_off += out_batch_size; + } + + return out; + } + + // Set p to Decompress_q(m, d). + fn decompress(comptime d: u8, in: *const [compressedSize(d)]u8) Poly { + @setEvalBranchQuota(10000); + const inLen = comptime @divTrunc(N * d, 8); + comptime assert(inLen * 8 == d * N); + var ret: Poly = undefined; + var in_off: usize = 0; + var out_off: usize = 0; + + const batch_size: usize = comptime lcm(@as(i16, d), 8); + const in_batch_size: usize = comptime batch_size / 8; + const out_batch_size: usize = comptime batch_size / d; + + while (out_off < N) { + comptime var in_shift: usize = 0; + comptime var j: usize = 0; + comptime var i: usize = 0; + inline while (i < out_batch_size) : (i += 1) { + // First, unpack next coefficient. + comptime var todo = d; + var out: u16 = 0; + + inline while (todo > 0) { + const out_shift = comptime d - todo; + const m = comptime (1 << d) - 1; + out |= (@as(u16, in[in_off + j] >> in_shift) << out_shift) & m; + + const done = comptime @min(@min(8, todo), 8 - in_shift); + todo -= done; + in_shift += done; + + if (in_shift == 8) { + in_shift = 0; + j += 1; + } + } + + // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋ + // = ⌊(q/2ᵈ)x+½⌋ + // = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋ + // = (qx + (1<<(d-1))) >> d + const qx = @as(u32, out) * @as(u32, Q); + ret.cs[out_off + i] = @as(i16, @intCast((qx + (1 << (d - 1))) >> d)); + } + + in_off += in_batch_size; + out_off += out_batch_size; + } + + return ret; + } + + // Returns the "pointwise" multiplication a o b. + // + // That is: invNTT(a o b) = invNTT(a) * invNTT(b). Assumes a and b are in + // Montgomery form. Products between coefficients of a and b must be strictly + // bounded in absolute value by 2¹⁵q. a o b will be in Montgomery form and + // bounded in absolute value by 2q. + fn mulHat(a: Poly, b: Poly) Poly { + // Recall from the discussion in ntt(), that a transformed polynomial is + // an element of ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷); + // that is: 128 degree-one polynomials instead of simply 256 elements + // from ℤ_q as in the regular NTT. So instead of pointwise multiplication, + // we multiply the 128 pairs of degree-one polynomials modulo the + // right equation: + // + // (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x, + // + // where ζ' is the appropriate power of ζ. + + var p: Poly = undefined; + var k: usize = 64; + var i: usize = 0; + while (i < N) : (i += 4) { + const z = @as(i32, zetas[k]); + k += 1; + + const a1b1 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i + 1])); + const a0b0 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i])); + const a1b0 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i])); + const a0b1 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i + 1])); + + p.cs[i] = montReduce(a1b1 * z) + a0b0; + p.cs[i + 1] = a0b1 + a1b0; + + const a3b3 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 3])); + const a2b2 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 2])); + const a3b2 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 2])); + const a2b3 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 3])); + + p.cs[i + 2] = a2b2 - montReduce(a3b3 * z); + p.cs[i + 3] = a2b3 + a3b2; + } + + return p; + } + + // Sample p from a centered binomial distribution with n=2η and p=½ - viz: + // coefficients are in {-η, …, η} with probabilities + // + // {ncr(0, 2η)/2^2η, ncr(1, 2η)/2^2η, …, ncr(2η,2η)/2^2η} + fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Poly { + var h = sha3.Shake256.init(.{}); + const suffix: [1]u8 = .{nonce}; + h.update(seed); + h.update(&suffix); + + // The distribution at hand is exactly the same as that + // of (a₁ + a₂ + … + a_η) - (b₁ + … + b_η) where a_i,b_i~U(1). + // Thus we need 2η bits per coefficient. + const buf_len = comptime 2 * eta * N / 8; + var buf: [buf_len]u8 = undefined; + h.squeeze(&buf); + + // buf is interpreted as a₁…a_ηb₁…b_ηa₁…a_ηb₁…b_η…. We process + // multiple coefficients in one batch. + + const T = switch (builtin.target.cpu.arch) { + .x86_64, .x86 => u32, // Generates better code on Intel CPUs + else => u64, // u128 might be faster on some other CPUs. + }; + + comptime var batch_count: usize = undefined; + comptime var batch_bytes: usize = undefined; + comptime var mask: T = 0; + comptime { + batch_count = @bitSizeOf(T) / @as(usize, 2 * eta); + while (@rem(N, batch_count) != 0 and batch_count > 0) : (batch_count -= 1) {} + assert(batch_count > 0); + assert(@rem(2 * eta * batch_count, 8) == 0); + batch_bytes = 2 * eta * batch_count / 8; + + for (0..2 * eta * batch_count) |_| { + mask <<= eta; + mask |= 1; + } + } + + var ret: Poly = undefined; + for (0..comptime N / batch_count) |i| { + // Read coefficients into t. In the case of η=3, + // we have t = a₁ + 2a₂ + 4a₃ + 8b₁ + 16b₂ + … + var t: T = 0; + inline for (0..batch_bytes) |j| { + t |= @as(T, buf[batch_bytes * i + j]) << (8 * j); + } + + // Accumelate `a's and `b's together by masking them out, shifting + // and adding. For η=3, we have d = a₁ + a₂ + a₃ + 8(b₁ + b₂ + b₃) + … + var d: T = 0; + inline for (0..eta) |j| { + d += (t >> j) & mask; + } + + // Extract each a and b separately and set coefficient in polynomial. + inline for (0..batch_count) |j| { + const mask2 = comptime (1 << eta) - 1; + const a = @as(i16, @intCast((d >> (comptime (2 * j * eta))) & mask2)); + const b = @as(i16, @intCast((d >> (comptime ((2 * j + 1) * eta))) & mask2)); + ret.cs[batch_count * i + j] = a - b; + } + } + + return ret; + } + + // Sample p uniformly from the given seed and x and y coordinates. + fn uniform(seed: [32]u8, x: u8, y: u8) Poly { + var h = sha3.Shake128.init(.{}); + const suffix: [2]u8 = .{ x, y }; + h.update(&seed); + h.update(&suffix); + + const buf_len = sha3.Shake128.block_length; // rate SHAKE-128 + var buf: [buf_len]u8 = undefined; + + var ret: Poly = undefined; + var i: usize = 0; // index into ret.cs + outer: while (true) { + h.squeeze(&buf); + + var j: usize = 0; // index into buf + while (j < buf_len) : (j += 3) { + const b0 = @as(u16, buf[j]); + const b1 = @as(u16, buf[j + 1]); + const b2 = @as(u16, buf[j + 2]); + + const ts: [2]u16 = .{ + b0 | ((b1 & 0xf) << 8), + (b1 >> 4) | (b2 << 4), + }; + + inline for (ts) |t| { + if (t < Q) { + ret.cs[i] = @as(i16, @intCast(t)); + i += 1; + + if (i == N) { + break :outer; + } + } + } + } + } + + return ret; + } + + // Packs p. + // + // Assumes p is normalized (and not just Barrett reduced). + fn toBytes(p: Poly) [bytes_length]u8 { + var ret: [bytes_length]u8 = undefined; + for (0..comptime N / 2) |i| { + const t0 = @as(u16, @intCast(p.cs[2 * i])); + const t1 = @as(u16, @intCast(p.cs[2 * i + 1])); + ret[3 * i] = @as(u8, @truncate(t0)); + ret[3 * i + 1] = @as(u8, @truncate((t0 >> 8) | (t1 << 4))); + ret[3 * i + 2] = @as(u8, @truncate(t1 >> 4)); + } + return ret; + } + + // Unpacks a Poly from buf. + // + // p will not be normalized; instead 0 ≤ p[i] < 4096. + fn fromBytes(buf: *const [bytes_length]u8) Poly { + var ret: Poly = undefined; + for (0..comptime N / 2) |i| { + const b0 = @as(i16, buf[3 * i]); + const b1 = @as(i16, buf[3 * i + 1]); + const b2 = @as(i16, buf[3 * i + 2]); + ret.cs[2 * i] = b0 | ((b1 & 0xf) << 8); + ret.cs[2 * i + 1] = (b1 >> 4) | b2 << 4; + } + return ret; + } +}; + +// A vector of K polynomials. +fn Vec(comptime K: u8) type { + return struct { + ps: [K]Poly, + + const Self = @This(); + const bytes_length = K * Poly.bytes_length; + + fn compressedSize(comptime d: u8) usize { + return Poly.compressedSize(d) * K; + } + + fn ntt(a: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].ntt(); + } + return ret; + } + + fn invNTT(a: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].invNTT(); + } + return ret; + } + + fn normalize(a: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].normalize(); + } + return ret; + } + + fn barrettReduce(a: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].barrettReduce(); + } + return ret; + } + + fn add(a: Self, b: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].add(b.ps[i]); + } + return ret; + } + + fn sub(a: Self, b: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = a.ps[i].sub(b.ps[i]); + } + return ret; + } + + // Samples v[i] from centered binomial distribution with the given η, + // seed and nonce+i. + fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self { + var ret: Self = undefined; + for (0..K) |i| { + ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed); + } + return ret; + } + + // Sets p to the inner product of a and b using "pointwise" multiplication. + // + // See MulHat() and NTT() for a description of the multiplication. + // Assumes a and b are in Montgomery form. p will be in Montgomery form, + // and its coefficients will be bounded in absolute value by 2kq. + // If a and b are not in Montgomery form, then the action is the same + // as "pointwise" multiplication followed by multiplying by R⁻¹, the inverse + // of the Montgomery factor. + fn dotHat(a: Self, b: Self) Poly { + var ret: Poly = Poly.zero; + for (0..K) |i| { + ret = ret.add(a.ps[i].mulHat(b.ps[i])); + } + return ret; + } + + fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 { + const cs = comptime Poly.compressedSize(d); + var ret: [compressedSize(d)]u8 = undefined; + inline for (0..K) |i| { + ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d); + } + return ret; + } + + fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self { + const cs = comptime Poly.compressedSize(d); + var ret: Self = undefined; + inline for (0..K) |i| { + ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]); + } + return ret; + } + + /// Serializes the key into a byte array. + fn toBytes(v: Self) [bytes_length]u8 { + var ret: [bytes_length]u8 = undefined; + inline for (0..K) |i| { + ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes(); + } + return ret; + } + + /// Deserializes the key from a byte array. + fn fromBytes(buf: *const [bytes_length]u8) Self { + var ret: Self = undefined; + inline for (0..K) |i| { + ret.ps[i] = Poly.fromBytes( + buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length], + ); + } + return ret; + } + }; +} + +// A matrix of K vectors +fn Mat(comptime K: u8) type { + return struct { + const Self = @This(); + vs: [K]Vec(K), + + fn uniform(seed: [32]u8, comptime transposed: bool) Self { + var ret: Self = undefined; + var i: u8 = 0; + while (i < K) : (i += 1) { + var j: u8 = 0; + while (j < K) : (j += 1) { + ret.vs[i].ps[j] = Poly.uniform( + seed, + if (transposed) i else j, + if (transposed) j else i, + ); + } + } + return ret; + } + + // Returns transpose of A + fn transpose(m: Self) Self { + var ret: Self = undefined; + for (0..K) |i| { + for (0..K) |j| { + ret.vs[i].ps[j] = m.vs[j].ps[i]; + } + } + return ret; + } + }; +} + +// Returns `true` if a ≠ b. +fn ctneq(comptime len: usize, a: [len]u8, b: [len]u8) u1 { + return 1 - @intFromBool(crypto.utils.timingSafeEql([len]u8, a, b)); +} + +// Copy src into dst given b = 1. +fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void { + const mask = @as(u8, 0) -% b; + for (0..len) |i| { + dst[i] ^= mask & (dst[i] ^ src[i]); + } +} + +test "MulHat" { + var rnd = RndGen.init(0); + + for (0..100) |_| { + const a = Poly.randAbsLeqQ(&rnd); + const b = Poly.randAbsLeqQ(&rnd); + + const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize(); + var p: Poly = undefined; + + @memset(&p.cs, 0); + + for (0..N) |i| { + for (0..N) |j| { + var v = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[j])); + var k = i + j; + if (k >= N) { + // Recall Xᴺ = -1. + k -= N; + v = -v; + } + p.cs[k] = feBarrettReduce(v + p.cs[k]); + } + } + + p = p.toMont().normalize(); + + try testing.expectEqual(p, p2); + } +} + +test "NTT" { + var rnd = RndGen.init(0); + + for (0..1000) |_| { + var p = Poly.randAbsLeqQ(&rnd); + const q = p.toMont().normalize(); + p = p.ntt(); + + for (0..N) |i| { + try testing.expect(p.cs[i] <= 7 * Q and -7 * Q <= p.cs[i]); + } + + p = p.normalize().invNTT(); + for (0..N) |i| { + try testing.expect(p.cs[i] <= Q and -Q <= p.cs[i]); + } + + p = p.normalize(); + + try testing.expectEqual(p, q); + } +} + +test "Compression" { + var rnd = RndGen.init(0); + inline for (.{ 1, 4, 5, 10, 11 }) |d| { + for (0..1000) |_| { + const p = Poly.randNormalized(&rnd); + const pp = p.compress(d); + const pq = Poly.decompress(d, &pp).compress(d); + try testing.expectEqual(pp, pq); + } + } +} + +test "noise" { + var seed: [32]u8 = undefined; + for (&seed, 0..) |*s, i| { + s.* = @as(u8, @intCast(i)); + } + try testing.expectEqual(Poly.noise(3, 37, &seed).cs, .{ + 0, 0, 1, -1, 0, 2, 0, -1, -1, 3, 0, 1, -2, -2, 0, 1, -2, + 1, 0, -2, 3, 0, 0, 0, 1, 3, 1, 1, 2, 1, -1, -1, -1, 0, + 1, 0, 1, 0, 2, 0, 1, -2, 0, -1, -1, -2, 1, -1, -1, 2, -1, + 1, 1, 2, -3, -1, -1, 0, 0, 0, 0, 1, -1, -2, -2, 0, -2, 0, + 0, 0, 1, 0, -1, -1, 1, -2, 2, 0, 0, 2, -2, 0, 1, 0, 1, + 1, 1, 0, 1, -2, -1, -2, -1, 1, 0, 0, 0, 0, 0, 1, 0, -1, + -1, 0, -1, 1, 0, 1, 0, -1, -1, 0, -2, 2, 0, -2, 1, -1, 0, + 1, -1, -1, 2, 1, 0, 0, -2, -1, 2, 0, 0, 0, -1, -1, 3, 1, + 0, 1, 0, 1, 0, 2, 1, 0, 0, 1, 0, 1, 0, 0, -1, -1, -1, + 0, 1, 3, 1, 0, 1, 0, 1, -1, -1, -1, -1, 0, 0, -2, -1, -1, + 2, 0, 1, 0, 1, 0, 2, -2, 0, 1, 1, -3, -1, -2, -1, 0, 1, + 0, 1, -2, 2, 2, 1, 1, 0, -1, 0, -1, -1, 1, 0, -1, 2, 1, + -1, 1, 2, -2, 1, 2, 0, 1, 2, 1, 0, 0, 2, 1, 2, 1, 0, + 2, 1, 0, 0, -1, -1, 1, -1, 0, 1, -1, 2, 2, 0, 0, -1, 1, + 1, 1, 1, 0, 0, -2, 0, -1, 1, 2, 0, 0, 1, 1, -1, 1, 0, + 1, + }); + try testing.expectEqual(Poly.noise(2, 37, &seed).cs, .{ + 1, 0, 1, -1, -1, -2, -1, -1, 2, 0, -1, 0, 0, -1, + 1, 1, -1, 1, 0, 2, -2, 0, 1, 2, 0, 0, -1, 1, + 0, -1, 1, -1, 1, 2, 1, 1, 0, -1, 1, -1, -2, -1, + 1, -1, -1, -1, 2, -1, -1, 0, 0, 1, 1, -1, 1, 1, + 1, 1, -1, -2, 0, 1, 0, 0, 2, 1, -1, 2, 0, 0, + 1, 1, 0, -1, 0, 0, -1, -1, 2, 0, 1, -1, 2, -1, + -1, -1, -1, 0, -2, 0, 2, 1, 0, 0, 0, -1, 0, 0, + 0, -1, -1, 0, -1, -1, 0, -1, 0, 0, -2, 1, 1, 0, + 1, 0, 1, 0, 1, 1, -1, 2, 0, 1, -1, 1, 2, 0, + 0, 0, 0, -1, -1, -1, 0, 1, 0, -1, 2, 0, 0, 1, + 1, 1, 0, 1, -1, 1, 2, 1, 0, 2, -1, 1, -1, -2, + -1, -2, -1, 1, 0, -2, -2, -1, 1, 0, 0, 0, 0, 1, + 0, 0, 0, 2, 2, 0, 1, 0, -1, -1, 0, 2, 0, 0, + -2, 1, 0, 2, 1, -1, -2, 0, 0, -1, 1, 1, 0, 0, + 2, 0, 1, 1, -2, 1, -2, 1, 1, 0, 2, 0, -1, 0, + -1, 0, 1, 2, 0, 1, 0, -2, 1, -2, -2, 1, -1, 0, + -1, 1, 1, 0, 0, 0, 1, 0, -1, 1, 1, 0, 0, 0, + 0, 1, 0, 1, -1, 0, 1, -1, -1, 2, 0, 0, 1, -1, + 0, 1, -1, 0, + }); +} + +test "uniform sampling" { + var seed: [32]u8 = undefined; + for (&seed, 0..) |*s, i| { + s.* = @as(u8, @intCast(i)); + } + try testing.expectEqual(Poly.uniform(seed, 1, 0).cs, .{ + 797, 993, 161, 6, 2608, 2385, 2096, 2661, 1676, 247, 2440, + 342, 634, 194, 1570, 2848, 986, 684, 3148, 3208, 2018, 351, + 2288, 612, 1394, 170, 1521, 3119, 58, 596, 2093, 1549, 409, + 2156, 1934, 1730, 1324, 388, 446, 418, 1719, 2202, 1812, 98, + 1019, 2369, 214, 2699, 28, 1523, 2824, 273, 402, 2899, 246, + 210, 1288, 863, 2708, 177, 3076, 349, 44, 949, 854, 1371, + 957, 292, 2502, 1617, 1501, 254, 7, 1761, 2581, 2206, 2655, + 1211, 629, 1274, 2358, 816, 2766, 2115, 2985, 1006, 2433, 856, + 2596, 3192, 1, 1378, 2345, 707, 1891, 1669, 536, 1221, 710, + 2511, 120, 1176, 322, 1897, 2309, 595, 2950, 1171, 801, 1848, + 695, 2912, 1396, 1931, 1775, 2904, 893, 2507, 1810, 2873, 253, + 1529, 1047, 2615, 1687, 831, 1414, 965, 3169, 1887, 753, 3246, + 1937, 115, 2953, 586, 545, 1621, 1667, 3187, 1654, 1988, 1857, + 512, 1239, 1219, 898, 3106, 391, 1331, 2228, 3169, 586, 2412, + 845, 768, 156, 662, 478, 1693, 2632, 573, 2434, 1671, 173, + 969, 364, 1663, 2701, 2169, 813, 1000, 1471, 720, 2431, 2530, + 3161, 733, 1691, 527, 2634, 335, 26, 2377, 1707, 767, 3020, + 950, 502, 426, 1138, 3208, 2607, 2389, 44, 1358, 1392, 2334, + 875, 2097, 173, 1697, 2578, 942, 1817, 974, 1165, 2853, 1958, + 2973, 3282, 271, 1236, 1677, 2230, 673, 1554, 96, 242, 1729, + 2518, 1884, 2272, 71, 1382, 924, 1807, 1610, 456, 1148, 2479, + 2152, 238, 2208, 2329, 713, 1175, 1196, 757, 1078, 3190, 3169, + 708, 3117, 154, 1751, 3225, 1364, 154, 23, 2842, 1105, 1419, + 79, 5, 2013, + }); +} + +test "Polynomial packing" { + var rnd = RndGen.init(0); + + for (0..1000) |_| { + const p = Poly.randNormalized(&rnd); + try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p); + } +} + +test "Test inner PKE" { + var seed: [32]u8 = undefined; + var pt: [32]u8 = undefined; + for (&seed, &pt, 0..) |*s, *p, i| { + s.* = @as(u8, @intCast(i)); + p.* = @as(u8, @intCast(i + 32)); + } + inline for (modes) |mode| { + for (0..100) |i| { + var pk: mode.InnerPk = undefined; + var sk: mode.InnerSk = undefined; + seed[0] = @as(u8, @intCast(i)); + mode.innerKeyFromSeed(seed, &pk, &sk); + for (0..10) |j| { + seed[1] = @as(u8, @intCast(j)); + try testing.expectEqual(sk.decrypt(&pk.encrypt(&pt, &seed)), pt); + } + } + } +} + +test "Test happy flow" { + var seed: [64]u8 = undefined; + for (&seed, 0..) |*s, i| { + s.* = @as(u8, @intCast(i)); + } + inline for (modes) |mode| { + for (0..100) |i| { + seed[0] = @as(u8, @intCast(i)); + const kp = try mode.KeyPair.create(seed); + const sk = try mode.SecretKey.fromBytes(&kp.secret_key.toBytes()); + try testing.expectEqual(sk, kp.secret_key); + const pk = try mode.PublicKey.fromBytes(&kp.public_key.toBytes()); + try testing.expectEqual(pk, kp.public_key); + for (0..10) |j| { + seed[1] = @as(u8, @intCast(j)); + const e = pk.encaps(seed[0..32].*); + try testing.expectEqual(e.shared_secret, try sk.decaps(&e.ciphertext)); + } + } + } +} + +// Code to test NIST Known Answer Tests (KAT), see PQCgenKAT.c. + +const sha2 = crypto.hash.sha2; + +test "NIST KAT test" { + inline for (.{ + .{ kyber_d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" }, + .{ kyber_d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" }, + .{ kyber_d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" }, + }) |modeHash| { + const mode = modeHash[0]; + var seed: [48]u8 = undefined; + for (&seed, 0..) |*s, i| { + s.* = @as(u8, @intCast(i)); + } + var f = sha2.Sha256.init(.{}); + const fw = f.writer(); + var g = NistDRBG.init(seed); + try std.fmt.format(fw, "# {s}\n\n", .{mode.name}); + for (0..100) |i| { + g.fill(&seed); + try std.fmt.format(fw, "count = {}\n", .{i}); + try std.fmt.format(fw, "seed = {s}\n", .{std.fmt.fmtSliceHexUpper(&seed)}); + var g2 = NistDRBG.init(seed); + + // This is not equivalent to g2.fill(kseed[:]). As the reference + // implementation calls randombytes twice generating the keypair, + // we have to do that as well. + var kseed: [64]u8 = undefined; + var eseed: [32]u8 = undefined; + g2.fill(kseed[0..32]); + g2.fill(kseed[32..64]); + g2.fill(&eseed); + const kp = try mode.KeyPair.create(kseed); + const e = kp.public_key.encaps(eseed); + const ss2 = try kp.secret_key.decaps(&e.ciphertext); + try testing.expectEqual(ss2, e.shared_secret); + try std.fmt.format(fw, "pk = {s}\n", .{std.fmt.fmtSliceHexUpper(&kp.public_key.toBytes())}); + try std.fmt.format(fw, "sk = {s}\n", .{std.fmt.fmtSliceHexUpper(&kp.secret_key.toBytes())}); + try std.fmt.format(fw, "ct = {s}\n", .{std.fmt.fmtSliceHexUpper(&e.ciphertext)}); + try std.fmt.format(fw, "ss = {s}\n\n", .{std.fmt.fmtSliceHexUpper(&e.shared_secret)}); + } + + var out: [32]u8 = undefined; + f.final(&out); + var outHex: [64]u8 = undefined; + _ = try std.fmt.bufPrint(&outHex, "{s}", .{std.fmt.fmtSliceHexLower(&out)}); + try testing.expectEqual(outHex, modeHash[1].*); + } +} + +const NistDRBG = struct { + key: [32]u8, + v: [16]u8, + + fn incV(g: *NistDRBG) void { + var j: usize = 15; + while (j >= 0) : (j -= 1) { + if (g.v[j] == 255) { + g.v[j] = 0; + } else { + g.v[j] += 1; + break; + } + } + } + + // AES256_CTR_DRBG_Update(pd, &g.key, &g.v). + fn update(g: *NistDRBG, pd: ?[48]u8) void { + var buf: [48]u8 = undefined; + const ctx = crypto.core.aes.Aes256.initEnc(g.key); + var i: usize = 0; + while (i < 3) : (i += 1) { + g.incV(); + var block: [16]u8 = undefined; + ctx.encrypt(&block, &g.v); + buf[i * 16 ..][0..16].* = block; + } + if (pd) |p| { + for (&buf, p) |*b, x| { + b.* ^= x; + } + } + g.key = buf[0..32].*; + g.v = buf[32..48].*; + } + + // randombytes. + fn fill(g: *NistDRBG, out: []u8) void { + var block: [16]u8 = undefined; + var dst = out; + + const ctx = crypto.core.aes.Aes256.initEnc(g.key); + while (dst.len > 0) { + g.incV(); + ctx.encrypt(&block, &g.v); + if (dst.len < 16) { + @memcpy(dst, block[0..dst.len]); + break; + } + dst[0..block.len].* = block; + dst = dst[16..dst.len]; + } + g.update(null); + } + + fn init(seed: [48]u8) NistDRBG { + var ret: NistDRBG = .{ .key = .{0} ** 32, .v = .{0} ** 16 }; + ret.update(seed); + return ret; + } +};