From f7830b59a3fb9300c0468ec25530fff7c78a38d1 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Thu, 5 Oct 2023 15:38:06 -0300 Subject: [PATCH] fields: generalize bitsize calculation Signed-off-by: Ignacio Hagopian --- src/bandersnatch/bandersnatch.zig | 18 +++++------ src/banderwagon/banderwagon.zig | 6 ++-- src/fields/fields.zig | 42 +++++++++++++------------- src/fields/sqrt.zig | 2 +- src/ipa/transcript.zig | 2 +- src/polynomial/precomputed_weights.zig | 2 +- 6 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/bandersnatch/bandersnatch.zig b/src/bandersnatch/bandersnatch.zig index 3144505..c4d5784 100644 --- a/src/bandersnatch/bandersnatch.zig +++ b/src/bandersnatch/bandersnatch.zig @@ -7,7 +7,7 @@ pub const Fp = BandersnatchFields.BaseField; pub const Fr = BandersnatchFields.ScalarField; // Curve parameters. -pub const A = Fp.fromInteger(Fp.MODULO - 5); +pub const A = Fp.fromInteger(Fp.Modulo - 5); pub const D = Fp.fromInteger(138827208126141220649022263972958607803).div(Fp.fromInteger(171449701953573178309673572579671231137)) catch unreachable; // Points. @@ -72,7 +72,7 @@ test "scalar mul smoke" { test "scalar mul minus one" { const gen = ExtendedPoint.generator(); - const integer = Fr.MODULO - 1; + const integer = Fr.Modulo - 1; const scalar = Fr.fromInteger(integer); const result = gen.scalarMul(scalar); @@ -98,14 +98,14 @@ test "zero" { test "lexographically largest" { try std.testing.expect(!Fp.fromInteger(0).lexographicallyLargest()); - try std.testing.expect(!Fp.fromInteger(Fp.Q_MIN_ONE_DIV_2).lexographicallyLargest()); + try std.testing.expect(!Fp.fromInteger(Fp.QMinOneDiv2).lexographicallyLargest()); - try std.testing.expect(Fp.fromInteger(Fp.Q_MIN_ONE_DIV_2 + 1).lexographicallyLargest()); - try std.testing.expect(Fp.fromInteger(Fp.MODULO - 1).lexographicallyLargest()); + try std.testing.expect(Fp.fromInteger(Fp.QMinOneDiv2 + 1).lexographicallyLargest()); + try std.testing.expect(Fp.fromInteger(Fp.Modulo - 1).lexographicallyLargest()); } test "from and to bytes" { - const cases = [_]Fp{ Fp.fromInteger(0), Fp.fromInteger(1), Fp.fromInteger(Fp.Q_MIN_ONE_DIV_2), Fp.fromInteger(Fp.MODULO - 1) }; + const cases = [_]Fp{ Fp.fromInteger(0), Fp.fromInteger(1), Fp.fromInteger(Fp.QMinOneDiv2), Fp.fromInteger(Fp.Modulo - 1) }; for (cases) |fe| { const bytes = fe.toBytes(); @@ -124,12 +124,12 @@ test "to integer" { } test "add sub mul neg" { - const got = Fp.fromInteger(10).mul(Fp.fromInteger(20)).add(Fp.fromInteger(30)).sub(Fp.fromInteger(40)).add(Fp.fromInteger(Fp.MODULO)); + const got = Fp.fromInteger(10).mul(Fp.fromInteger(20)).add(Fp.fromInteger(30)).sub(Fp.fromInteger(40)).add(Fp.fromInteger(Fp.Modulo)); const want = Fp.fromInteger(190); try std.testing.expect(got.equal(want)); const gotneg = got.neg(); - const wantneg = Fp.fromInteger(Fp.MODULO - 190); + const wantneg = Fp.fromInteger(Fp.Modulo - 190); try std.testing.expect(gotneg.equal(wantneg)); } @@ -140,7 +140,7 @@ test "inv" { try std.testing.expect(T.fromInteger(0).inv() == null); const one = T.one(); - const cases = [_]T{ T.fromInteger(2), T.fromInteger(42), T.fromInteger(T.MODULO - 1) }; + const cases = [_]T{ T.fromInteger(2), T.fromInteger(42), T.fromInteger(T.Modulo - 1) }; for (cases) |fe| { try std.testing.expect(fe.mul(fe.inv().?).equal(one)); } diff --git a/src/banderwagon/banderwagon.zig b/src/banderwagon/banderwagon.zig index ce6ae90..d1edb65 100644 --- a/src/banderwagon/banderwagon.zig +++ b/src/banderwagon/banderwagon.zig @@ -221,7 +221,7 @@ pub const ElementMSM = struct { // The spec serialization is the X coordinate in big endian form. pub fn fromBytes(bytes: [Element.BytesSize]u8) !ElementMSM { const bi = std.mem.readIntSlice(u256, &bytes, std.builtin.Endian.Big); - if (bi >= Fp.MODULO) { + if (bi >= Fp.Modulo) { return error.BytesNotCanonical; } @@ -306,14 +306,14 @@ test "Element -> ElementNormalized" { } test "bytes canonical" { - const max_value_fp = Fp.MODULO - 1; + const max_value_fp = Fp.Modulo - 1; var bytes: [Fp.BytesSize]u8 = undefined; std.mem.writeInt(u256, &bytes, max_value_fp, std.builtin.Endian.Big); // Must succeed. _ = try ElementMSM.fromBytes(bytes); for (0..3) |i| { - const bigger_than_modulus = Fp.MODULO + i; + const bigger_than_modulus = Fp.Modulo + i; std.mem.writeInt(u256, &bytes, bigger_than_modulus, std.builtin.Endian.Big); const must_error = ElementMSM.fromBytes(bytes); try std.testing.expectError(error.BytesNotCanonical, must_error); diff --git a/src/fields/fields.zig b/src/fields/fields.zig index 462e333..1761062 100644 --- a/src/fields/fields.zig +++ b/src/fields/fields.zig @@ -12,13 +12,13 @@ pub const BandersnatchFields = struct { fn Field(comptime F: type, comptime mod: u256) type { return struct { - pub const BitSize = 253; // TODO - pub const BytesSize = 32; - pub const MODULO = mod; - pub const Q_MIN_ONE_DIV_2 = (MODULO - 1) / 2; + pub const BitSize = @bitSizeOf(u256) - @clz(mod); + pub const BytesSize = @sizeOf(u256); + pub const Modulo = mod; + pub const QMinOneDiv2 = (Modulo - 1) / 2; const Self = @This(); - const baseZero = val: { + const base_zero = val: { var bz: F.MontgomeryDomainFieldElement = undefined; F.fromBytes(&bz, [_]u8{0} ** BytesSize); break :val Self{ .fe = bz }; @@ -28,7 +28,7 @@ fn Field(comptime F: type, comptime mod: u256) type { pub fn fromInteger(num: u256) Self { var lbe: [BytesSize]u8 = [_]u8{0} ** BytesSize; - std.mem.writeInt(u256, lbe[0..], num % MODULO, std.builtin.Endian.Little); + std.mem.writeInt(u256, lbe[0..], num % Modulo, std.builtin.Endian.Little); var nonMont: F.NonMontgomeryDomainFieldElement = undefined; F.fromBytes(&nonMont, lbe); @@ -39,7 +39,7 @@ fn Field(comptime F: type, comptime mod: u256) type { } pub fn zero() Self { - return baseZero; + return base_zero; } pub fn one() Self { @@ -75,7 +75,7 @@ fn Field(comptime F: type, comptime mod: u256) type { pub fn lexographicallyLargest(self: Self) bool { const selfNonMont = self.toInteger(); - return selfNonMont > Q_MIN_ONE_DIV_2; + return selfNonMont > QMinOneDiv2; } pub fn fromMontgomery(self: Self) F.NonMontgomeryDomainFieldElement { @@ -111,12 +111,12 @@ fn Field(comptime F: type, comptime mod: u256) type { pub fn neg(self: Self) Self { var ret: F.MontgomeryDomainFieldElement = undefined; - F.sub(&ret, baseZero.fe, self.fe); + F.sub(&ret, base_zero.fe, self.fe); return Self{ .fe = ret }; } pub fn isZero(self: Self) bool { - return self.equal(baseZero); + return self.equal(base_zero); } pub fn isOne(self: Self) bool { @@ -165,7 +165,7 @@ fn Field(comptime F: type, comptime mod: u256) type { } pub fn inv(self: Self) ?Self { - var r: u256 = MODULO; + var r: u256 = Modulo; var t: i512 = 0; var newr: u256 = self.toInteger(); @@ -188,15 +188,15 @@ fn Field(comptime F: type, comptime mod: u256) type { } if (t < 0) { - t = t + MODULO; + t = t + Modulo; } return Self.fromInteger(@intCast(t)); } pub fn div(self: Self, den: Self) !Self { - const denInv = den.inv() orelse return error.DivisionByZero; - return self.mul(denInv); + const den_inv = den.inv() orelse return error.DivisionByZero; + return self.mul(den_inv); } pub fn equal(self: Self, other: Self) bool { @@ -218,13 +218,13 @@ fn Field(comptime F: type, comptime mod: u256) type { return null; } var candidate: Self = undefined; - var rootOfUnity: Self = undefined; - fastsqrt.sqrtAlg_ComputeRelevantPowers(x, &candidate, &rootOfUnity); - if (!fastsqrt.invSqrtEqDyadic(&rootOfUnity)) { + var root_of_unity: Self = undefined; + fastsqrt.sqrtAlg_ComputeRelevantPowers(x, &candidate, &root_of_unity); + if (!fastsqrt.invSqrtEqDyadic(&root_of_unity)) { return null; } - return mul(candidate, rootOfUnity); + return mul(candidate, root_of_unity); } pub fn legendre(a: Self) i2 { @@ -234,10 +234,10 @@ fn Field(comptime F: type, comptime mod: u256) type { // a, then a|p = 0) // Returns 1 if a has a square root modulo // p, -1 otherwise. - const ls = a.pow((MODULO - 1) / 2); + const ls = a.pow((Modulo - 1) / 2); - const moduloMinusOne = comptime fromInteger(MODULO - 1); - if (ls.equal(moduloMinusOne)) { + const modulo_minus_one = comptime fromInteger(Modulo - 1); + if (ls.equal(modulo_minus_one)) { return -1; } else if (ls.isZero()) { return 0; diff --git a/src/fields/sqrt.zig b/src/fields/sqrt.zig index 4204b62..4f4449d 100644 --- a/src/fields/sqrt.zig +++ b/src/fields/sqrt.zig @@ -26,7 +26,7 @@ const sqrtPrecomp_PrimitiveDyadicRoots: [BaseField2Adicity + 1]feType_SquareRoot ret[i] = Fp.square(ret[i - 1]); } - if (ret[BaseField2Adicity - 1].toInteger() != Fp.MODULO - 1) { + if (ret[BaseField2Adicity - 1].toInteger() != Fp.Modulo - 1) { @compileError("something is wrong with the dyadic roots of unity"); } diff --git a/src/ipa/transcript.zig b/src/ipa/transcript.zig index d31c159..848a937 100644 --- a/src/ipa/transcript.zig +++ b/src/ipa/transcript.zig @@ -139,7 +139,7 @@ test "test vector 2" { test "test vector 3" { // Test that domain separation is consistent across implementations var transcript = Transcript.init("simple_protocol"); - const minus_one = Fr.fromInteger(Fr.MODULO - 1); + const minus_one = Fr.fromInteger(Fr.Modulo - 1); const one = Fr.one(); transcript.appendScalar(minus_one, "-1"); transcript.domainSep("separate me"); diff --git a/src/polynomial/precomputed_weights.zig b/src/polynomial/precomputed_weights.zig index 66a0c6a..dc71193 100644 --- a/src/polynomial/precomputed_weights.zig +++ b/src/polynomial/precomputed_weights.zig @@ -45,7 +45,7 @@ pub fn PrecomputedWeights( inverses[0] = Fr.zero(); for (1..DomainSize) |d| { inverses[d] = Fr.inv(Fr.fromInteger(d)) orelse Fr.zero(); - inverses[inverses.len - d] = Fr.inv(Fr.fromInteger(Fr.MODULO - d)) orelse Fr.zero(); + inverses[inverses.len - d] = Fr.inv(Fr.fromInteger(Fr.Modulo - d)) orelse Fr.zero(); } return .{ .A = _A,