commit 8709f53d440ed8479f711d871a2d6c2c35dc1014 (tree)
parent 99ec1ee3536b577bd1d14facde523c503108886d
Author: Frank Denis <jedisct1@noreply.codeberg.org>
Date: Sun, 25 Jan 2026 17:42:01 +0100
crypto.ff: allow seamless chaining regardless of representation (#30913)
Finite field elements can be in regular or Montgomery form, and
chaining different operations use to require manual and error-prone
conversions.
Now:
- `add`, `sub` and `mul` convert the second operand to match the
first operand's form
- `sq` and `pow` preserve the input's Montgomery form
- `toPrimitive` and `toBytes` return `UnexpectedRepresentation` if
the element is in Montgomery form, preventing incorrect serialization
This is fully backwards compatible and allows seamless chaining of
operations regardless of their representation.
Diffstat:
| M | lib/std/crypto/ff.zig | | | 166 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------- |
1 file changed, 134 insertions(+), 32 deletions(-)
diff --git a/lib/std/crypto/ff.zig b/lib/std/crypto/ff.zig
@@ -329,7 +329,11 @@ fn Fe_(comptime bits: comptime_int) type {
/// Converts the field element to a primitive.
/// This function may not run in constant time.
- pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
+ /// Returns an error if the element is in Montgomery form.
+ pub fn toPrimitive(self: Self, comptime T: type) (OverflowError || RepresentationError)!T {
+ if (self.montgomery) {
+ return error.UnexpectedRepresentation;
+ }
return self.v.toPrimitive(T);
}
@@ -343,7 +347,11 @@ fn Fe_(comptime bits: comptime_int) type {
}
/// Converts the field element to a byte string.
- pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
+ /// Returns an error if the element is in Montgomery form.
+ pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) (OverflowError || RepresentationError)!void {
+ if (self.montgomery) {
+ return error.UnexpectedRepresentation;
+ }
return self.v.toBytes(bytes, endian);
}
@@ -530,19 +538,46 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
/// Adds two field elements (mod m).
pub fn add(self: Self, x: Fe, y: Fe) Fe {
var out = x;
- const overflow = out.v.addWithOverflow(y.v);
- const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
- const need_sub = ct.eql(overflow, underflow);
- _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
- return out;
+ if (x.montgomery == y.montgomery) {
+ @branchHint(.likely);
+ const overflow = out.v.addWithOverflow(y.v);
+ const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
+ const need_sub = ct.eql(overflow, underflow);
+ _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
+ return out;
+ } else {
+ var y_ = y;
+ if (y.montgomery) {
+ self.fromMontgomery(&y_) catch unreachable;
+ } else {
+ self.toMontgomery(&y_) catch unreachable;
+ }
+ const overflow = out.v.addWithOverflow(y_.v);
+ const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
+ const need_sub = ct.eql(overflow, underflow);
+ _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
+ return out;
+ }
}
/// Subtracts two field elements (mod m).
pub fn sub(self: Self, x: Fe, y: Fe) Fe {
var out = x;
- const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
- _ = out.v.conditionalAddWithOverflow(underflow, self.v);
- return out;
+ if (x.montgomery == y.montgomery) {
+ const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
+ _ = out.v.conditionalAddWithOverflow(underflow, self.v);
+ return out;
+ } else {
+ var y_ = y;
+ if (y.montgomery) {
+ self.fromMontgomery(&y_) catch unreachable;
+ } else {
+ self.toMontgomery(&y_) catch unreachable;
+ }
+ const underflow: bool = @bitCast(out.v.subWithOverflow(y_.v));
+ _ = out.v.conditionalAddWithOverflow(underflow, self.v);
+ return out;
+ }
}
/// Converts a field element to the Montgomery form.
@@ -663,13 +698,15 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
for (e) |b| acc |= b;
if (acc == 0) return error.NullExponent;
+ const was_montgomery = x.montgomery;
+
var out = self.one();
self.toMontgomery(&out) catch unreachable;
if (public and e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111)) {
// Do not use a precomputation table for short, public exponents
var x_m = x;
- if (x.montgomery == false) {
+ if (!x.montgomery) {
self.toMontgomery(&x_m) catch unreachable;
}
var s = switch (endian) {
@@ -702,7 +739,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
} else {
// Use a precomputation table for large exponents
var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
- if (x.montgomery == false) {
+ if (!x.montgomery) {
self.toMontgomery(&pc[0]) catch unreachable;
}
for (1..pc.len) |i| {
@@ -747,38 +784,55 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
}
}
}
- self.fromMontgomery(&out) catch unreachable;
+ if (!was_montgomery) {
+ self.fromMontgomery(&out) catch unreachable;
+ }
return out;
}
/// Multiplies two field elements.
+ /// Result preserves the first operand's form.
pub fn mul(self: Self, x: Fe, y: Fe) Fe {
- if (x.montgomery != y.montgomery) {
- return self.montgomeryMul(x, y);
- }
- var a_ = x;
- if (x.montgomery == false) {
- self.toMontgomery(&a_) catch unreachable;
+ if (x.montgomery) {
+ const y_ = if (!y.montgomery) blk: {
+ var yy = y;
+ self.toMontgomery(&yy) catch unreachable;
+ break :blk yy;
+ } else y;
+ return self.montgomeryMul(x, y_);
} else {
- self.fromMontgomery(&a_) catch unreachable;
+ var x_m = x;
+ var y_m = if (y.montgomery) blk: {
+ var yy = y;
+ self.fromMontgomery(&yy) catch unreachable;
+ break :blk yy;
+ } else y;
+ self.toMontgomery(&x_m) catch unreachable;
+ self.toMontgomery(&y_m) catch unreachable;
+ var out = self.montgomeryMul(x_m, y_m);
+ self.fromMontgomery(&out) catch unreachable;
+ return out;
}
- return self.montgomeryMul(a_, y);
}
/// Squares a field element.
pub fn sq(self: Self, x: Fe) Fe {
- var out = x;
- if (x.montgomery == true) {
+ if (x.montgomery) {
+ return self.montgomerySq(x);
+ } else {
+ var out = x;
+ self.toMontgomery(&out) catch unreachable;
+ out = self.montgomerySq(out);
self.fromMontgomery(&out) catch unreachable;
+ return out;
}
- out = self.montgomerySq(out);
- out.montgomery = false;
- self.toMontgomery(&out) catch unreachable;
- return out;
}
/// Returns x^e (mod m) in constant time.
- pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
+ pub fn pow(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe {
+ if (e.montgomery) {
+ return error.UnexpectedRepresentation;
+ }
var buf: [Fe.encoded_bytes]u8 = undefined;
e.toBytes(&buf, native_endian) catch unreachable;
return self.powWithEncodedExponent(x, &buf, native_endian);
@@ -786,7 +840,10 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
/// Returns x^e (mod m), assuming that the exponent is public.
/// The function remains constant time with respect to `x`.
- pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
+ pub fn powPublic(self: Self, x: Fe, e: Fe) (NullExponentError || RepresentationError)!Fe {
+ if (e.montgomery) {
+ return error.UnexpectedRepresentation;
+ }
var e_normalized = Fe{ .v = e.v.normalize() };
var buf_: [Fe.encoded_bytes]u8 = undefined;
var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_len * t_bits, 8) catch unreachable];
@@ -927,6 +984,8 @@ test "finite field arithmetic" {
try m.toMontgomery(&x);
x_y = m.mul(x, y);
+ try testing.expect(x_y.montgomery); // result preserves first operand's form
+ try m.fromMontgomery(&x_y);
try testing.expectEqual(x_y.toPrimitive(u256), 1666576607955767413750776202132407807424848069716933450241);
try m.fromMontgomery(&x);
@@ -941,8 +1000,11 @@ test "finite field arithmetic" {
const x_pow_y = try m.powPublic(x, y);
try testing.expectEqual(x_pow_y.toPrimitive(u256), 1631933139300737762906024873185789093007782131928298618473);
+ try testing.expect(!x_pow_y.montgomery);
try m.toMontgomery(&x);
- const x_pow_y2 = try m.powPublic(x, y);
+ var x_pow_y2 = try m.powPublic(x, y);
+ try testing.expect(x_pow_y2.montgomery);
+ try m.fromMontgomery(&x_pow_y2);
try m.fromMontgomery(&x);
try testing.expect(x_pow_y2.eql(x_pow_y));
try testing.expectError(error.NullExponent, m.powPublic(x, m.zero));
@@ -953,13 +1015,53 @@ test "finite field arithmetic" {
const x_sq = m.sq(x);
const x_sq2 = m.mul(x, x);
+ try testing.expect(!x_sq.montgomery);
+ try testing.expect(!x_sq2.montgomery);
try testing.expect(x_sq.eql(x_sq2));
try m.toMontgomery(&x);
- const x_sq3 = m.sq(x);
- const x_sq4 = m.mul(x, x);
+ var x_sq3 = m.sq(x);
+ var x_sq4 = m.mul(x, x);
+ try testing.expect(x_sq3.montgomery);
+ try testing.expect(x_sq4.montgomery);
+ try m.fromMontgomery(&x_sq3);
+ try m.fromMontgomery(&x_sq4);
try testing.expect(x_sq.eql(x_sq3));
try testing.expect(x_sq3.eql(x_sq4));
try m.fromMontgomery(&x);
+
+ var x_mont = x;
+ try m.toMontgomery(&x_mont);
+
+ // Non-montgomery + montgomery
+ const add_nm_m = m.add(x, x_mont);
+ try testing.expect(!add_nm_m.montgomery);
+ var add_m_nm = m.add(x_mont, x);
+ try testing.expect(add_m_nm.montgomery);
+ try m.fromMontgomery(&add_m_nm);
+ try testing.expect(add_nm_m.eql(add_m_nm));
+
+ // Non-montgomery - montgomery
+ const sub_nm_m = m.sub(x, y);
+ try testing.expect(!sub_nm_m.montgomery);
+ var y_mont = y;
+ try m.toMontgomery(&y_mont);
+ var sub_m_nm = m.sub(x_mont, y);
+ try testing.expect(sub_m_nm.montgomery);
+ try m.fromMontgomery(&sub_m_nm);
+ try testing.expect(sub_nm_m.eql(sub_m_nm));
+
+ // mul: preserves first operand's form
+ const mul_nm_m = m.mul(x, x_mont);
+ try testing.expect(!mul_nm_m.montgomery);
+ const mul_nm_nm = m.mul(x, x);
+ try testing.expect(mul_nm_m.eql(mul_nm_nm));
+ var mul_m_nm = m.mul(x_mont, x);
+ try testing.expect(mul_m_nm.montgomery);
+ try m.fromMontgomery(&mul_m_nm);
+ try testing.expect(mul_m_nm.eql(mul_nm_nm));
+
+ try testing.expectEqual(x.toPrimitive(u256), 80169837251094269539116136208111827396136208141182357733);
+ try testing.expectError(error.UnexpectedRepresentation, x_mont.toPrimitive(u256));
}
fn testCt(ct_: anytype) !void {