zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit 8709f53d440ed8479f711d871a2d6c2c35dc1014 (tree)
parent 99ec1ee3536b577bd1d14facde523c503108886d
Author: Frank Denis <jedisct1@noreply.codeberg.org>
Date:   Sun, 25 Jan 2026 17:42:01 +0100

crypto.ff: allow seamless chaining regardless of representation (#30913)

Finite field elements can be in regular or Montgomery form, and
chaining different operations use to require manual and error-prone
conversions.

Now:

- `add`, `sub` and `mul` convert the second operand to match the
first operand's form
- `sq` and `pow` preserve the input's Montgomery form
- `toPrimitive` and `toBytes` return `UnexpectedRepresentation` if
the element is in Montgomery form, preventing incorrect serialization

This is fully backwards compatible and allows seamless chaining of
operations regardless of their representation.

Diffstat:
Mlib/std/crypto/ff.zig | 166+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
1 file changed, 134 insertions(+), 32 deletions(-)

diff --git a/lib/std/crypto/ff.zig b/lib/std/crypto/ff.zig @@ -329,7 +329,11 @@ fn Fe_(comptime bits: comptime_int) type { /// Converts the field element to a primitive. /// This function may not run in constant time. - pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T { + /// Returns an error if the element is in Montgomery form. + pub fn toPrimitive(self: Self, comptime T: type) (OverflowError || RepresentationError)!T { + if (self.montgomery) { + return error.UnexpectedRepresentation; + } return self.v.toPrimitive(T); } @@ -343,7 +347,11 @@ fn Fe_(comptime bits: comptime_int) type { } /// Converts the field element to a byte string. - pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void { + /// Returns an error if the element is in Montgomery form. + pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) (OverflowError || RepresentationError)!void { + if (self.montgomery) { + return error.UnexpectedRepresentation; + } return self.v.toBytes(bytes, endian); } @@ -530,19 +538,46 @@ pub fn Modulus(comptime max_bits: comptime_int) type { /// Adds two field elements (mod m). pub fn add(self: Self, x: Fe, y: Fe) Fe { var out = x; - const overflow = out.v.addWithOverflow(y.v); - const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v)); - const need_sub = ct.eql(overflow, underflow); - _ = out.v.conditionalSubWithOverflow(need_sub, self.v); - return out; + if (x.montgomery == y.montgomery) { + @branchHint(.likely); + const overflow = out.v.addWithOverflow(y.v); + const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v)); + const need_sub = ct.eql(overflow, underflow); + _ = out.v.conditionalSubWithOverflow(need_sub, self.v); + return out; + } else { + var y_ = y; + if (y.montgomery) { + self.fromMontgomery(&y_) catch unreachable; + } else { + self.toMontgomery(&y_) catch unreachable; + } + const overflow = out.v.addWithOverflow(y_.v); + const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v)); + const need_sub = ct.eql(overflow, underflow); + _ = out.v.conditionalSubWithOverflow(need_sub, self.v); + return out; + } } /// Subtracts two field elements (mod m). pub fn sub(self: Self, x: Fe, y: Fe) Fe { var out = x; - const underflow: bool = @bitCast(out.v.subWithOverflow(y.v)); - _ = out.v.conditionalAddWithOverflow(underflow, self.v); - return out; + if (x.montgomery == y.montgomery) { + const underflow: bool = @bitCast(out.v.subWithOverflow(y.v)); + _ = out.v.conditionalAddWithOverflow(underflow, self.v); + return out; + } else { + var y_ = y; + if (y.montgomery) { + self.fromMontgomery(&y_) catch unreachable; + } else { + self.toMontgomery(&y_) catch unreachable; + } + const underflow: bool = @bitCast(out.v.subWithOverflow(y_.v)); + _ = out.v.conditionalAddWithOverflow(underflow, self.v); + return out; + } } /// Converts a field element to the Montgomery form. @@ -663,13 +698,15 @@ pub fn Modulus(comptime max_bits: comptime_int) type { for (e) |b| acc |= b; if (acc == 0) return error.NullExponent; + const was_montgomery = x.montgomery; + var out = self.one(); self.toMontgomery(&out) catch unreachable; if (public and e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111)) { // Do not use a precomputation table for short, public exponents var x_m = x; - if (x.montgomery == false) { + if (!x.montgomery) { self.toMontgomery(&x_m) catch unreachable; } var s = switch (endian) { @@ -702,7 +739,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type { } else { // Use a precomputation table for large exponents var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14; - if (x.montgomery == false) { + if (!x.montgomery) { self.toMontgomery(&pc[0]) catch unreachable; } for (1..pc.len) |i| { @@ -747,38 +784,55 @@ pub fn Modulus(comptime max_bits: comptime_int) type { } } } - self.fromMontgomery(&out) catch unreachable; + if (!was_montgomery) { + self.fromMontgomery(&out) catch unreachable; + } return out; } /// Multiplies two field elements. + /// Result preserves the first operand's form. pub fn mul(self: Self, x: Fe, y: Fe) Fe { - if (x.montgomery != y.montgomery) { - return self.montgomeryMul(x, y); - } - var a_ = x; - if (x.montgomery == false) { - self.toMontgomery(&a_) catch unreachable; + if (x.montgomery) { + const y_ = if (!y.montgomery) blk: { + var yy = y; + self.toMontgomery(&yy) catch unreachable; + break :blk yy; + } else y; + return self.montgomeryMul(x, y_); } else { - self.fromMontgomery(&a_) catch unreachable; + var x_m = x; + var y_m = if (y.montgomery) blk: { + var yy = y; + self.fromMontgomery(&yy) catch unreachable; + break :blk yy; + } else y; + self.toMontgomery(&x_m) catch unreachable; + self.toMontgomery(&y_m) catch unreachable; + var out = self.montgomeryMul(x_m, y_m); + self.fromMontgomery(&out) catch unreachable; + return out; } - return self.montgomeryMul(a_, y); } /// Squares a field element. pub fn sq(self: Self, x: Fe) Fe { - var out = x; - if (x.montgomery == true) { + if (x.montgomery) { + return self.montgomerySq(x); + } else { + var out = x; + self.toMontgomery(&out) catch unreachable; + out = self.montgomerySq(out); self.fromMontgomery(&out) catch unreachable; + return out; } - out = self.montgomerySq(out); - out.montgomery = false; - self.toMontgomery(&out) catch unreachable; - return out; } /// Returns x^e (mod m) in constant time. - pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe { + pub fn pow(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe { + if (e.montgomery) { + return error.UnexpectedRepresentation; + } var buf: [Fe.encoded_bytes]u8 = undefined; e.toBytes(&buf, native_endian) catch unreachable; return self.powWithEncodedExponent(x, &buf, native_endian); @@ -786,7 +840,10 @@ pub fn Modulus(comptime max_bits: comptime_int) type { /// Returns x^e (mod m), assuming that the exponent is public. /// The function remains constant time with respect to `x`. - pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe { + pub fn powPublic(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe { + if (e.montgomery) { + return error.UnexpectedRepresentation; + } var e_normalized = Fe{ .v = e.v.normalize() }; var buf_: [Fe.encoded_bytes]u8 = undefined; var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_len * t_bits, 8) catch unreachable]; @@ -927,6 +984,8 @@ test "finite field arithmetic" { try m.toMontgomery(&x); x_y = m.mul(x, y); + try testing.expect(x_y.montgomery); // result preserves first operand's form + try m.fromMontgomery(&x_y); try testing.expectEqual(x_y.toPrimitive(u256), 1666576607955767413750776202132407807424848069716933450241); try m.fromMontgomery(&x); @@ -941,8 +1000,11 @@ test "finite field arithmetic" { const x_pow_y = try m.powPublic(x, y); try testing.expectEqual(x_pow_y.toPrimitive(u256), 1631933139300737762906024873185789093007782131928298618473); + try testing.expect(!x_pow_y.montgomery); try m.toMontgomery(&x); - const x_pow_y2 = try m.powPublic(x, y); + var x_pow_y2 = try m.powPublic(x, y); + try testing.expect(x_pow_y2.montgomery); + try m.fromMontgomery(&x_pow_y2); try m.fromMontgomery(&x); try testing.expect(x_pow_y2.eql(x_pow_y)); try testing.expectError(error.NullExponent, m.powPublic(x, m.zero)); @@ -953,13 +1015,53 @@ test "finite field arithmetic" { const x_sq = m.sq(x); const x_sq2 = m.mul(x, x); + try testing.expect(!x_sq.montgomery); + try testing.expect(!x_sq2.montgomery); try testing.expect(x_sq.eql(x_sq2)); try m.toMontgomery(&x); - const x_sq3 = m.sq(x); - const x_sq4 = m.mul(x, x); + var x_sq3 = m.sq(x); + var x_sq4 = m.mul(x, x); + try testing.expect(x_sq3.montgomery); + try testing.expect(x_sq4.montgomery); + try m.fromMontgomery(&x_sq3); + try m.fromMontgomery(&x_sq4); try testing.expect(x_sq.eql(x_sq3)); try testing.expect(x_sq3.eql(x_sq4)); try m.fromMontgomery(&x); + + var x_mont = x; + try m.toMontgomery(&x_mont); + + // Non-montgomery + montgomery + const add_nm_m = m.add(x, x_mont); + try testing.expect(!add_nm_m.montgomery); + var add_m_nm = m.add(x_mont, x); + try testing.expect(add_m_nm.montgomery); + try m.fromMontgomery(&add_m_nm); + try testing.expect(add_nm_m.eql(add_m_nm)); + + // Non-montgomery - montgomery + const sub_nm_m = m.sub(x, y); + try testing.expect(!sub_nm_m.montgomery); + var y_mont = y; + try m.toMontgomery(&y_mont); + var sub_m_nm = m.sub(x_mont, y); + try testing.expect(sub_m_nm.montgomery); + try m.fromMontgomery(&sub_m_nm); + try testing.expect(sub_nm_m.eql(sub_m_nm)); + + // mul: preserves first operand's form + const mul_nm_m = m.mul(x, x_mont); + try testing.expect(!mul_nm_m.montgomery); + const mul_nm_nm = m.mul(x, x); + try testing.expect(mul_nm_m.eql(mul_nm_nm)); + var mul_m_nm = m.mul(x_mont, x); + try testing.expect(mul_m_nm.montgomery); + try m.fromMontgomery(&mul_m_nm); + try testing.expect(mul_m_nm.eql(mul_nm_nm)); + + try testing.expectEqual(x.toPrimitive(u256), 80169837251094269539116136208111827396136208141182357733); + try testing.expectError(error.UnexpectedRepresentation, x_mont.toPrimitive(u256)); } fn testCt(ct_: anytype) !void {