From 4d17fe8870d76735a45ebf703a42901b49e820c8 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Fri, 29 Sep 2023 09:37:33 -0300 Subject: [PATCH] pippenger: use normalized points and fix everything to not jump over abstractions Signed-off-by: Ignacio Hagopian --- src/bandersnatch/bandersnatch.zig | 26 +------ src/bandersnatch/points/affine.zig | 1 + src/bandersnatch/points/extended.zig | 48 ++++++------- src/banderwagon/banderwagon.zig | 99 +++++++++++++++++++++++++- src/bench.zig | 7 +- src/crs/crs.zig | 21 +++--- src/ipa/ipa.zig | 10 ++- src/ipa/transcript.zig | 6 ++ src/main.zig | 4 +- src/msm/pippenger.zig | 12 ++-- src/msm/precomp.zig | 36 +++++----- src/multiproof/multiproof.zig | 11 +-- src/polynomial/precomputed_weights.zig | 1 - 13 files changed, 184 insertions(+), 98 deletions(-) diff --git a/src/bandersnatch/bandersnatch.zig b/src/bandersnatch/bandersnatch.zig index dd30d7b..7fda404 100644 --- a/src/bandersnatch/bandersnatch.zig +++ b/src/bandersnatch/bandersnatch.zig @@ -50,7 +50,7 @@ test "neg" { test "serialize gen" { const gen = ExtendedPoint.generator(); - const serialised_point = gen.to_bytes(); + const serialised_point = gen.toBytes(); // test vector taken from the rust code (see spec reference) const expected = "18ae52a26618e7e1658499ad22c0792bf342be7b77113774c5340b2ccc32c129"; @@ -78,7 +78,7 @@ test "scalar mul minus one" { const result = gen.scalarMul(scalar); const expected = "e951ad5d98e7181e99d76452e0e343281295e38d90c602bf824892fd86742c4a"; - const actual = std.fmt.bytesToHex(result.to_bytes(), std.fmt.Case.lower); + const actual = std.fmt.bytesToHex(result.toBytes(), std.fmt.Case.lower); try std.testing.expectEqualSlices(u8, expected, &actual); } @@ -193,25 +193,3 @@ test "batch inv with error" { const out = Fp.batchInv(&got_invs, &fes); try std.testing.expectError(error.CantInvertZeroElement, out); } - -test "ExtendedPoint -> ExtendedPointNormalized" { - const g = ExtendedPoint.generator(); - const scalars = [_]Fr{ Fr.fromInteger(3213), Fr.fromInteger(1212), Fr.fromInteger(4433) }; - - var points: [scalars.len]ExtendedPoint = undefined; - for (0..scalars.len) |i| { - points[i] = g.scalarMul(scalars[i]); - } - - var expected: [scalars.len]ExtendedPointNormalized = undefined; - for (0..scalars.len) |i| { - expected[i] = ExtendedPointNormalized.fromExtendedPoint(points[i]); - } - - var got: [scalars.len]ExtendedPointNormalized = undefined; - try ExtendedPointNormalized.fromExtendedPoints(&got, &points); - - for (0..expected.len) |i| { - try std.testing.expect(expected[i].equal(got[i])); - } -} diff --git a/src/bandersnatch/points/affine.zig b/src/bandersnatch/points/affine.zig index 09b8a16..16f85ce 100644 --- a/src/bandersnatch/points/affine.zig +++ b/src/bandersnatch/points/affine.zig @@ -97,6 +97,7 @@ pub fn scalarMul(point: AffinePoint, scalar: Fr) AffinePoint { return result; } +// TODO pub fn to_bytes(self: AffinePoint) [32]u8 { const mCompressedNegative = 0x80; const mCompressedPositive = 0x00; diff --git a/src/bandersnatch/points/extended.zig b/src/bandersnatch/points/extended.zig index 1ff70e2..cc70bea 100644 --- a/src/bandersnatch/points/extended.zig +++ b/src/bandersnatch/points/extended.zig @@ -13,6 +13,18 @@ pub const ExtendedPointNormalized = struct { return comptime fromExtendedPoint(ExtendedPoint.identity()); } + pub fn generator() ExtendedPointNormalized { + return comptime fromExtendedPoint(ExtendedPoint.generator()); + } + + pub fn initUnsafe(x: Fp, y: Fp) ExtendedPointNormalized { + return ExtendedPointNormalized{ + .x = x, + .y = y, + .t = x.mul(y), + }; + } + pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointNormalized { const z_inv = p.z.inv().?; const x = p.x.mul(z_inv); @@ -24,29 +36,6 @@ pub const ExtendedPointNormalized = struct { }; } - pub fn fromExtendedPoints(result: []ExtendedPointNormalized, points: []const ExtendedPoint) !void { - var accumulator = Fp.one(); - - for (0..points.len) |i| { - result[i].x = accumulator; - accumulator = Fp.mul(accumulator, points[i].z); - } - - var accInverse = accumulator.inv().?; - - for (0..points.len) |i| { - result[result.len - 1 - i].x = Fp.mul(result[result.len - 1 - i].x, accInverse); - accInverse = Fp.mul(accInverse, points[points.len - 1 - i].z); - } - - for (0..points.len) |i| { - const z_inv = result[i].x; - result[i].x = Fp.mul(points[i].x, z_inv); - result[i].y = Fp.mul(points[i].y, z_inv); - result[i].t = Fp.mul(result[i].x, result[i].y); - } - } - pub fn equal(self: ExtendedPointNormalized, other: ExtendedPointNormalized) bool { return self.x.equal(other.x) and self.y.equal(other.y) and self.t.equal(other.t); } @@ -74,6 +63,15 @@ pub const ExtendedPoint = struct { }; } + pub fn fromExtendedPointNormalized(e: ExtendedPointNormalized) ExtendedPoint { + return ExtendedPoint{ + .x = e.x, + .y = e.y, + .t = e.t, + .z = Fp.one(), + }; + } + pub fn identity() ExtendedPoint { const iden = comptime AffinePoint.identity(); return comptime initUnsafe(iden.x, iden.y); @@ -116,6 +114,7 @@ pub const ExtendedPoint = struct { return (p.x.mul(q.z).equal(p.z.mul(q.x))) and (p.y.mul(q.z).equal(q.y.mul(p.z))); } + // TODO: change api to result receiver. pub fn add(p: ExtendedPoint, q: ExtendedPoint) ExtendedPoint { // https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-add-2008-hwcd const a = Fp.mul(p.x, q.x); @@ -222,8 +221,7 @@ pub const ExtendedPoint = struct { } } - // # Only used for testing purposes. - pub fn to_bytes(self: ExtendedPoint) [32]u8 { + pub fn toBytes(self: ExtendedPoint) [32]u8 { return self.toAffine().to_bytes(); } }; diff --git a/src/banderwagon/banderwagon.zig b/src/banderwagon/banderwagon.zig index 7ffc333..82f5bc3 100644 --- a/src/banderwagon/banderwagon.zig +++ b/src/banderwagon/banderwagon.zig @@ -3,7 +3,7 @@ const Bandersnatch = @import("../bandersnatch/bandersnatch.zig"); const Fp = Bandersnatch.Fp; const AffinePoint = Bandersnatch.AffinePoint; const ExtendedPoint = Bandersnatch.ExtendedPoint; -const ArrayList = std.ArrayList; +const ExtendedPointNormalized = Bandersnatch.ExtendedPointNormalized; // Fr is the scalar field of the Banderwgaon group, which matches with the // scalar field size of the Bandersnatch primer-ordered subgroup. @@ -19,6 +19,12 @@ pub const Element = struct { return Element{ .point = ExtendedPoint.initUnsafe(bytes) }; } + pub fn fromElementNormalized(e: ElementNormalized) Element { + return Element{ + .point = ExtendedPoint.fromExtendedPointNormalized(e.point), + }; + } + // fromBytes deserializes an element from a byte array. // The spec serialization is the X coordinate in big endian form. pub fn fromBytes(bytes: [BytesSize]u8) !Element { @@ -64,6 +70,12 @@ pub const Element = struct { self.point = ExtendedPoint.add(p.point, q.point); } + pub fn mixedAdd(a: Element, b: ElementNormalized) Element { + return Element{ + .point = ExtendedPoint.mixedAdd(a.point, b.point), + }; + } + // sub subtracts two elements of the Banderwagon group. pub fn sub(self: *Element, p: Element, q: Element) void { self.point = ExtendedPoint.sub(p.point, q.point); @@ -134,10 +146,10 @@ pub const Element = struct { }; // msm computes the multi-scalar multiplication of scalars and points. +// TODO: change to Pippenger calls, and make everything use this. pub fn msm(points: []const Element, scalars: []const Fr) Element { std.debug.assert(scalars.len == points.len); - // TODO: optimize! var res = Element.identity(); for (scalars, points) |scalar, point| { if (scalar.isZero()) { @@ -203,3 +215,86 @@ test "two torsion" { try std.testing.expect(result.equal(gen)); } + +pub const ElementNormalized = struct { + point: ExtendedPointNormalized, + + // fromBytes deserializes an element from a byte array. + // The spec serialization is the X coordinate in big endian form. + pub fn fromBytes(bytes: [Element.BytesSize]u8) !ElementNormalized { + var bytes_le = bytes; + std.mem.reverse(u8, &bytes_le); + const x = Fp.fromBytes(bytes_le); // TODO: reject if bytes are not canonical? + + if (Element.subgroupCheck(x) != 1) { + return error.NotInSubgroup; + } + const y = try AffinePoint.getYCoordinate(x, true); + + return ElementNormalized{ .point = ExtendedPointNormalized.initUnsafe(x, y) }; + } + + pub fn generator() ElementNormalized { + return ElementNormalized{ .point = ExtendedPointNormalized.generator() }; + } + + pub fn fromElement(p: Element) ElementNormalized { + return ElementNormalized{ + .point = ExtendedPointNormalized.fromExtendedPoint(p.point), + }; + } + + pub fn equal(a: ElementNormalized, b: ElementNormalized) bool { + return ExtendedPointNormalized.equal(a.point, b.point); + } + + pub fn toBytes(self: ElementNormalized) [Element.BytesSize]u8 { + return Element.fromElementNormalized(self).toBytes(); + } + + // TODO: move this. + pub fn fromElements(result: []ElementNormalized, points: []const Element) void { + var accumulator = Fp.one(); + + for (0..points.len) |i| { + result[i].point.x = accumulator; + accumulator = Fp.mul(accumulator, points[i].point.z); + } + + var accInverse = accumulator.inv().?; + + for (0..points.len) |i| { + result[result.len - 1 - i].point.x = Fp.mul(result[result.len - 1 - i].point.x, accInverse); + accInverse = Fp.mul(accInverse, points[points.len - 1 - i].point.z); + } + + for (0..points.len) |i| { + const z_inv = result[i].point.x; + result[i].point.x = Fp.mul(points[i].point.x, z_inv); + result[i].point.y = Fp.mul(points[i].point.y, z_inv); + result[i].point.t = Fp.mul(result[i].point.x, result[i].point.y); + } + } +}; + +test "Element -> ElementNormalized" { + const g = Element.generator(); + const scalars = [_]Fr{ Fr.fromInteger(3213), Fr.fromInteger(1212), Fr.fromInteger(4433) }; + + var points: [scalars.len]Element = undefined; + for (0..scalars.len) |i| { + points[i] = g.scalarMul(scalars[i]); + } + + var expected: [scalars.len]ElementNormalized = undefined; + for (0..scalars.len) |i| { + expected[i] = ElementNormalized.fromElement(points[i]); + } + + var got: [scalars.len]ElementNormalized = undefined; + ElementNormalized.fromElements(&got, &points); + + for (0..expected.len) |i| { + try std.testing.expect(expected[i].equal(got[i])); + } +} diff --git a/src/bench.zig b/src/bench.zig index 6c042b2..794f6b7 100644 --- a/src/bench.zig +++ b/src/bench.zig @@ -169,8 +169,8 @@ fn benchMultiproofs() !void { const LagrangeBasis = polynomials.LagrangeBasis(crs.DomainSize, crs.Domain); std.debug.print("Setting up multiproofs benchmark...\n", .{}); - const N = 1; - const openings = [_]u16{1000}; + const N = 25; + const openings = [_]u16{ 100, 1_000, 5_000, 10_000 }; var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer { @@ -229,11 +229,10 @@ fn benchMultiproofs() !void { defer allocator.free(verifier_queries); for (0..num_openings) |i| { verifier_queries[i] = multiproof.VerifierQuery{ - .C = vec_openings[i].C, + .C = banderwagon.ElementNormalized.fromElement(vec_openings[i].C), .z = vec_openings[i].z, .y = vec_openings[i].poly_evaluations[vec_openings[i].z], }; - verifier_queries[i].C.normalize(); } start = std.time.milliTimestamp(); const ok = try mproof.verifyProof(allocator, &verifier_transcript, verifier_queries, proof); diff --git a/src/crs/crs.zig b/src/crs/crs.zig index 20de044..014cfbb 100644 --- a/src/crs/crs.zig +++ b/src/crs/crs.zig @@ -4,6 +4,7 @@ const sha256 = std.crypto.hash.sha2.Sha256; const banderwagon = @import("../banderwagon/banderwagon.zig"); const msm = @import("../msm/precomp.zig"); const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; // DomainSize is the size of the domain. @@ -24,15 +25,15 @@ pub const Domain: [DomainSize]Fr = domain_elements: { pub const CRS = struct { const PrecompMSM = msm.PrecompMSM(2, 8); - Gs: [DomainSize]Element, - Q: Element, + Gs: [DomainSize]ElementNormalized, + Q: ElementNormalized, precomp: PrecompMSM, pub fn init(allocator: Allocator) !CRS { const points = deserialize_vkt_points(); return CRS{ .Gs = points, - .Q = Element.generator(), + .Q = ElementNormalized.generator(), .precomp = try PrecompMSM.init(allocator, &points), }; } @@ -46,16 +47,20 @@ pub const CRS = struct { } pub fn commitSlow(self: CRS, values: [DomainSize]Fr) Element { - return banderwagon.msm(&self.Gs, &values); + var gs: [DomainSize]Element = undefined; + for (0..DomainSize) |i| { + gs[i] = Element.fromElementNormalized(self.Gs[i]); + } + return banderwagon.msm(&gs, &values); } }; -fn deserialize_vkt_points() [DomainSize]Element { - var points: [vkt_crs_points.len]Element = undefined; +fn deserialize_vkt_points() [DomainSize]ElementNormalized { + var points: [vkt_crs_points.len]ElementNormalized = undefined; for (vkt_crs_points, 0..) |serialized_point, i| { var g_be_bytes: [32]u8 = undefined; _ = std.fmt.hexToBytes(&g_be_bytes, serialized_point) catch unreachable; - points[i] = Element.fromBytes(g_be_bytes) catch unreachable; + points[i] = ElementNormalized.fromBytes(g_be_bytes) catch unreachable; } return points; } @@ -85,7 +90,7 @@ test "crs is consistent" { test "Gs cannot contain the generator" { const crs = try CRS.init(std.testing.allocator); defer crs.deinit(); - const generator = Element.generator(); + const generator = ElementNormalized.generator(); for (crs.Gs) |point| { try std.testing.expect(!generator.equal(point)); } diff --git a/src/ipa/ipa.zig b/src/ipa/ipa.zig index a785eb6..04c4532 100644 --- a/src/ipa/ipa.zig +++ b/src/ipa/ipa.zig @@ -1,6 +1,7 @@ const std = @import("std"); const banderwagon = @import("../banderwagon/banderwagon.zig"); const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; const crs = @import("../crs/crs.zig"); const Transcript = @import("transcript.zig"); @@ -61,11 +62,14 @@ pub fn IPA(comptime VectorLength: comptime_int) type { // Rescale Q. const w = transcript.challengeScalar("w"); - const q = xcrs.Q.scalarMul(w); + const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w); var L: [NUM_STEPS]Element = undefined; var R: [NUM_STEPS]Element = undefined; - var _basis = xcrs.Gs; + var _basis: [crs.DomainSize]Element = undefined; + for (0..crs.DomainSize) |i| { + _basis[i] = Element.fromElementNormalized(xcrs.Gs[i]); + } var basis: []Element = _basis[0..]; var step: usize = 0; @@ -133,7 +137,7 @@ pub fn IPA(comptime VectorLength: comptime_int) type { transcript.appendScalar(y, "output point"); const w = transcript.challengeScalar("w"); - const q = xcrs.Q.scalarMul(w); + const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w); var commitment: Element = undefined; commitment.add(C, q.scalarMul(y)); diff --git a/src/ipa/transcript.zig b/src/ipa/transcript.zig index c87c231..2baaabd 100644 --- a/src/ipa/transcript.zig +++ b/src/ipa/transcript.zig @@ -3,6 +3,7 @@ const sha256 = std.crypto.hash.sha2.Sha256; const banderwagon = @import("../banderwagon/banderwagon.zig"); const Fr = banderwagon.Fr; const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; state: sha256, @@ -36,6 +37,11 @@ pub fn appendPoint(self: *Transcript, point: Element, label: []const u8) void { self.appendBytes(&point_as_bytes, label); } +pub fn appendPointNormalized(self: *Transcript, point: ElementNormalized, label: []const u8) void { + const point_as_bytes = point.toBytes(); + self.appendBytes(&point_as_bytes, label); +} + pub fn challengeScalar(self: *Transcript, label: []const u8) Fr { self.domainSep(label); diff --git a/src/main.zig b/src/main.zig index b46cc68..6e25b37 100644 --- a/src/main.zig +++ b/src/main.zig @@ -4,7 +4,6 @@ pub fn main() !void {} test "bandersnatch" { _ = @import("bandersnatch/bandersnatch.zig"); - // std.testing.refAllDeclsRecursive(@This()); } test "banderwagon" { @@ -16,7 +15,8 @@ test "crs" { } test "msm" { - _ = @import("msm/precomp.zig"); + // _ = @import("msm/precomp.zig"); + _ = @import("msm/pippenger.zig"); } test "fields" { diff --git a/src/msm/pippenger.zig b/src/msm/pippenger.zig index a3f64e2..8c24a94 100644 --- a/src/msm/pippenger.zig +++ b/src/msm/pippenger.zig @@ -2,6 +2,7 @@ const std = @import("std"); const Allocator = std.mem.Allocator; const banderwagon = @import("../banderwagon/banderwagon.zig"); const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; pub fn Pippenger(comptime c: comptime_int) type { @@ -10,7 +11,7 @@ pub fn Pippenger(comptime c: comptime_int) type { const num_windows = std.math.divCeil(u8, Fr.BitSize, c) catch unreachable; const num_buckets = (1 << c) - 1; - pub fn msm(base_allocator: Allocator, basis: []const Element, scalars_mont: []const Fr) !Element { + pub fn msm(base_allocator: Allocator, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element { std.debug.assert(basis.len >= scalars_mont.len); var arena = std.heap.ArenaAllocator.init(base_allocator); @@ -41,10 +42,9 @@ pub fn Pippenger(comptime c: comptime_int) type { continue; } if (buckets[scalar_windows[i] - 1] == null) { - buckets[scalar_windows[i] - 1] = basis[i]; - continue; + buckets[scalar_windows[i] - 1] = Element.identity(); } - buckets[scalar_windows[i] - 1].?.add(buckets[scalar_windows[i] - 1].?, basis[i]); + buckets[scalar_windows[i] - 1] = Element.mixedAdd(buckets[scalar_windows[i] - 1].?, basis[i]); } // Aggregate buckets. @@ -59,7 +59,9 @@ pub fn Pippenger(comptime c: comptime_int) type { sum = buckets[buckets.len - 1 - i]; continue; } - sum.?.add(sum.?, buckets[buckets.len - 1 - i] orelse Element.identity()); + if (buckets[buckets.len - 1 - i] != null) { + sum.?.add(sum.?, buckets[buckets.len - 1 - i].?); + } window_aggr.?.add(window_aggr.?, sum.?); } diff --git a/src/msm/precomp.zig b/src/msm/precomp.zig index 8fd09aa..4dabd41 100644 --- a/src/msm/precomp.zig +++ b/src/msm/precomp.zig @@ -2,10 +2,8 @@ const std = @import("std"); const Allocator = std.mem.Allocator; const banderwagon = @import("../banderwagon/banderwagon.zig"); const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; -const bandersnatch = @import("../bandersnatch/bandersnatch.zig"); -const ExtendedPoint = bandersnatch.ExtendedPoint; -const ExtendedPointNormalized = bandersnatch.ExtendedPointNormalized; pub fn PrecompMSM( comptime _t: comptime_int, @@ -20,29 +18,29 @@ pub fn PrecompMSM( const points_per_column = (Fr.BitSize + t - 1) / t; allocator: Allocator, - table: []const ExtendedPointNormalized, + table: []const ElementNormalized, num_windows: usize, basis_len: usize, - pub fn init(allocator: Allocator, basis: []const Element) !Self { + pub fn init(allocator: Allocator, basis: []const ElementNormalized) !Self { const num_windows = (points_per_column * basis.len + b - 1) / b; - var table_basis = try allocator.alloc(ExtendedPoint, points_per_column * basis.len); + var table_basis = try allocator.alloc(Element, points_per_column * basis.len); defer allocator.free(table_basis); var idx: usize = 0; for (0..basis.len) |hi| { - table_basis[idx] = basis[hi].point; + table_basis[idx] = Element.fromElementNormalized(basis[hi]); idx += 1; for (1..points_per_column) |_| { table_basis[idx] = table_basis[idx - 1]; for (0..t) |_| { - table_basis[idx] = ExtendedPoint.double(table_basis[idx]); + table_basis[idx].double(table_basis[idx]); } idx += 1; } } - var nn_table = try allocator.alloc(ExtendedPoint, window_size * num_windows); + var nn_table = try allocator.alloc(Element, window_size * num_windows); defer allocator.free(nn_table); for (0..num_windows) |w| { const start = w * b; @@ -54,8 +52,8 @@ pub fn PrecompMSM( fillWindow(window_basis, nn_table[w * window_size .. (w + 1) * window_size]); } - var table = try allocator.alloc(ExtendedPointNormalized, window_size * num_windows); - try ExtendedPointNormalized.fromExtendedPoints(table, nn_table); + var table = try allocator.alloc(ElementNormalized, window_size * num_windows); + ElementNormalized.fromElements(table, nn_table); return Self{ .allocator = allocator, @@ -80,10 +78,10 @@ pub fn PrecompMSM( scalars[i] = mont_scalars[i].toInteger(); } - var accum = bandersnatch.ExtendedPoint.identity(); + var accum = Element.identity(); for (0..t) |t_i| { if (t_i > 0) { - accum = bandersnatch.ExtendedPoint.double(accum); + accum.double(accum); } var curr_window_idx: usize = 0; @@ -100,7 +98,7 @@ pub fn PrecompMSM( if (curr_window_b_idx == b) { if (curr_window_scalar > 0) { - accum = bandersnatch.ExtendedPoint.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]); + accum = Element.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]); } curr_window_idx += 1; @@ -110,23 +108,23 @@ pub fn PrecompMSM( } } if (curr_window_scalar > 0) { - accum = bandersnatch.ExtendedPoint.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]); + accum = Element.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]); } } - return Element{ .point = accum }; + return accum; } - fn fillWindow(basis: []const ExtendedPoint, table: []ExtendedPoint) void { + fn fillWindow(basis: []const Element, table: []Element) void { if (basis.len == 0) { for (0..table.len) |i| { - table[i] = ExtendedPoint.identity(); + table[i] = Element.identity(); } return; } fillWindow(basis[1..], table[0 .. table.len / 2]); for (0..table.len / 2) |i| { - table[table.len / 2 + i] = ExtendedPoint.add(table[i], basis[0]); + table[table.len / 2 + i].add(table[i], basis[0]); } } }; diff --git a/src/multiproof/multiproof.zig b/src/multiproof/multiproof.zig index d2528c3..7b56585 100644 --- a/src/multiproof/multiproof.zig +++ b/src/multiproof/multiproof.zig @@ -2,6 +2,7 @@ const std = @import("std"); const Allocator = std.mem.Allocator; const banderwagon = @import("../banderwagon/banderwagon.zig"); const Element = banderwagon.Element; +const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; const lagrange_basis = @import("../polynomial/lagrange_basis.zig"); const Transcript = @import("../ipa/transcript.zig"); @@ -24,7 +25,7 @@ pub const ProverQuery = struct { }; pub const VerifierQuery = struct { - C: Element, + C: ElementNormalized, z: u8, y: Fr, }; @@ -150,7 +151,7 @@ pub const MultiProof = struct { const C_i = query.C; const z_i = query.z; const y_i = query.y; - transcript.appendPoint(C_i, "C"); + transcript.appendPointNormalized(C_i, "C"); transcript.appendScalar(Fr.fromInteger(z_i), "z"); transcript.appendScalar(y_i, "y"); } @@ -199,7 +200,7 @@ pub const MultiProof = struct { // Compute E = sum(C_i * r^i/(t-z_i)) var E_coefficients = try allocator.alloc(Fr, queries.len); - var Cs = try allocator.alloc(Element, queries.len); + var Cs = try allocator.alloc(ElementNormalized, queries.len); for (queries, 0..) |query, i| { Cs[i] = query.C; E_coefficients[i] = Fr.mul(powers_of_r[i], helper_scalar_den[queries[i].z]); @@ -366,8 +367,8 @@ test "basic" { ); var verifier_transcript = Transcript.init("test"); - var vquery_a = VerifierQuery{ .C = Cs[0], .z = zs[0], .y = ys[0] }; - var vquery_b = VerifierQuery{ .C = Cs[1], .z = zs[1], .y = ys[1] }; + var vquery_a = VerifierQuery{ .C = ElementNormalized.fromElement(Cs[0]), .z = zs[0], .y = ys[0] }; + var vquery_b = VerifierQuery{ .C = ElementNormalized.fromElement(Cs[1]), .z = zs[1], .y = ys[1] }; const ok = try multiproof.verifyProof(allocator, &verifier_transcript, &[_]VerifierQuery{ vquery_a, vquery_b }, proof); try std.testing.expect(ok); diff --git a/src/polynomial/precomputed_weights.zig b/src/polynomial/precomputed_weights.zig index 14ad5b9..66a0c6a 100644 --- a/src/polynomial/precomputed_weights.zig +++ b/src/polynomial/precomputed_weights.zig @@ -58,7 +58,6 @@ pub fn PrecomputedWeights( // barycentricFormularConstants returns a slice with the constants to be used when evaluating a polynomial at z. // b_i = A(z) / A'(DOMAIN[i]) * 1 / (z - DOMAIN[i]) - // The caller is responsible for freeing the returned slice. pub fn barycentricFormulaConstants(self: Self, z: Fr) ![DomainSize]Fr { std.debug.assert(z.toInteger() >= DomainSize);