From c11b3f1d7455823b986a0857cd289c190b58c8b9 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Fri, 6 Oct 2023 16:18:31 -0300 Subject: [PATCH] msm/pippenger: add automatic window calculation Signed-off-by: Ignacio Hagopian --- src/bench.zig | 99 +++++++++------ src/msm/pippenger.zig | 228 ++++++++++++++++++---------------- src/multiproof/multiproof.zig | 3 +- 3 files changed, 182 insertions(+), 148 deletions(-) diff --git a/src/bench.zig b/src/bench.zig index a53665c..3398c06 100644 --- a/src/bench.zig +++ b/src/bench.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const Allocator = std.mem.Allocator; const fields = @import("fields/fields.zig"); const crs = @import("crs/crs.zig"); const Fp = fields.BandersnatchFields.BaseField; @@ -8,15 +9,24 @@ const multiproof = @import("multiproof/multiproof.zig"); const polynomials = @import("polynomial/lagrange_basis.zig"); const ipa = @import("ipa/ipa.zig"); const Transcript = @import("ipa/transcript.zig"); -const msm = @import("msm/precomp.zig"); +const msmprecomp = @import("msm/precomp.zig"); +const pippenger = @import("msm/pippenger.zig"); pub fn main() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer { + const deinit_status = gpa.deinit(); + if (deinit_status == .leak) std.testing.expect(false) catch @panic("memory leak"); + } + var allocator = gpa.allocator(); + try benchFields(); - try benchPedersenHash(); - try benchIPAs(); - try benchMultiproofs(); + try benchPedersenHash(allocator); + try benchIPAs(allocator); + try benchMultiproofs(allocator); - try analyzePedersenHashConfigs(); + try analyzePedersenHashConfigs(allocator); + try analyzePippengerWindowSize(allocator); } fn benchFields() !void { @@ -74,17 +84,10 @@ fn benchFields() !void { std.debug.print("\n", .{}); } -fn benchPedersenHash() !void { +fn benchPedersenHash(allocator: Allocator) !void { std.debug.print("Benchmarking Pedersen hashing...\n", .{}); const N = 5000; - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer { - 
const deinit_status = gpa.deinit(); - if (deinit_status == .leak) std.testing.expect(false) catch @panic("memory leak"); - } - var allocator = gpa.allocator(); - var xcrs = try crs.CRS.init(allocator); defer xcrs.deinit(); @@ -112,16 +115,9 @@ fn benchPedersenHash() !void { std.debug.print("\n", .{}); } -fn benchIPAs() !void { +fn benchIPAs(allocator: Allocator) !void { const N = 100; - std.debug.print("Setting up IPA benchmark...\n", .{}); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer { - const deinit_status = gpa.deinit(); - if (deinit_status == .leak) std.testing.expect(false) catch @panic("memory leak"); - } - var allocator = gpa.allocator(); var xcrs = try crs.CRS.init(allocator); defer xcrs.deinit(); @@ -172,20 +168,13 @@ fn benchIPAs() !void { std.debug.print("\n", .{}); } -fn benchMultiproofs() !void { +fn benchMultiproofs(allocator: Allocator) !void { const LagrangeBasis = polynomials.LagrangeBasis(crs.DomainSize, crs.Domain); std.debug.print("Setting up multiproofs benchmark...\n", .{}); const N = 25; const openings = [_]u16{ 100, 1_000, 5_000, 10_000 }; - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer { - const deinit_status = gpa.deinit(); - if (deinit_status == .leak) std.testing.expect(false) catch @panic("memory leak"); - } - var allocator = gpa.allocator(); - var xcrs = try crs.CRS.init(allocator); defer xcrs.deinit(); @@ -279,14 +268,7 @@ fn genBaseFieldElements(comptime N: usize) [N]Fp { return fps; } -fn analyzePedersenHashConfigs() !void { - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer { - const deinit_status = gpa.deinit(); - if (deinit_status == .leak) std.testing.expect(false) catch @panic("memory leak"); - } - var allocator = gpa.allocator(); - +fn analyzePedersenHashConfigs(allocator: Allocator) !void { const N = 1_000; var scalars: [crs.DomainSize]Fr = undefined; @@ -304,7 +286,7 @@ fn analyzePedersenHashConfigs() !void { inline for (ts) |t| { inline for (bs) |b| { - var precomp = try 
msm.PrecompMSM(t, b).init(allocator, &xcrs.Gs); + var precomp = try msmprecomp.PrecompMSM(t, b).init(allocator, &xcrs.Gs); defer precomp.deinit(); const table_size = precomp.table.len * @sizeOf(banderwagon.ElementMSM) >> 20; @@ -331,7 +313,7 @@ fn analyzePedersenHashConfigs() !void { inline for (ts) |t| { inline for (bs) |b| { - var hybprecomp = try msm.HybridPrecompMSM(cutoff, t, b, 4, 8).init(allocator, &xcrs.Gs); + var hybprecomp = try msmprecomp.HybridPrecompMSM(cutoff, t, b, 4, 8).init(allocator, &xcrs.Gs); defer hybprecomp.deinit(); const table_size1 = hybprecomp.precomp1.table.len * @sizeOf(banderwagon.ElementMSM) >> 20; @@ -352,3 +334,42 @@ fn analyzePedersenHashConfigs() !void { } std.debug.print("\n", .{}); } + +fn analyzePippengerWindowSize(allocator: Allocator) !void { + const N = 1_000; + const msm_lengths = [_]usize{ 3, 10, 100, 500, 1000, 5000, 10000 }; + + var tmp_basis: [msm_lengths[msm_lengths.len - 1]]banderwagon.Element = undefined; + tmp_basis[0] = banderwagon.Element.generator(); + for (1..tmp_basis.len) |i| { // Fill every slot: the whole array is handed to fromElements below. + tmp_basis[i].double(tmp_basis[i - 1]); + } + var basis: [msm_lengths[msm_lengths.len - 1]]banderwagon.ElementMSM = undefined; + banderwagon.ElementMSM.fromElements(&basis, &tmp_basis); + + var scalars: [msm_lengths[msm_lengths.len - 1]]Fr = undefined; + for (0..scalars.len) |i| { + scalars[i] = Fr.fromInteger(i + 0x424242); + } + var optimal: [msm_lengths.len]usize = std.mem.zeroes([msm_lengths.len]usize); + inline for (0..msm_lengths.len) |i| { + var lowest_duration: i64 = std.math.maxInt(i64); + std.debug.print("MSM length {}:\n", .{msm_lengths[i]}); + inline for (2..10) |window_size| { + const start = std.time.microTimestamp(); + for (0..N) |_| { + _ = try pippenger.msmWithWindowSize(allocator, window_size, basis[0..msm_lengths[i]], scalars[0..msm_lengths[i]]); + } + const duration: i64 = @divTrunc((std.time.microTimestamp() - start), N); + if (duration < lowest_duration) { + optimal[i] = window_size; + lowest_duration = duration; 
+ } + std.debug.print("\tw={}:{}µs\n", .{ window_size, duration }); + } + } + std.debug.print("\nOptimal window sizes:\n", .{}); + for (0..msm_lengths.len) |i| { + std.debug.print("\tmsm_length={} w={}\n", .{ msm_lengths[i], optimal[i] }); + } +} diff --git a/src/msm/pippenger.zig b/src/msm/pippenger.zig index 0a988e7..8ff5f53 100644 --- a/src/msm/pippenger.zig +++ b/src/msm/pippenger.zig @@ -7,104 +7,122 @@ const Fr = banderwagon.Fr; // This is an implementation of "Notes on MSMs with Precomputation" by Gottfried Herold. -pub fn Pippenger(comptime c: comptime_int) type { - return struct { - const num_windows = std.math.divCeil(u8, Fr.BitSize, c) catch unreachable; - const num_buckets = 1 << (c - 1); - - pub fn msm(base_allocator: Allocator, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element { - std.debug.assert(basis.len >= scalars_mont.len); - - var arena = std.heap.ArenaAllocator.init(base_allocator); - defer arena.deinit(); - var allocator = arena.allocator(); - - var scalars_windows = try signedDigitDecomposition(allocator, scalars_mont); - - var result: ?Element = null; - var buckets: [num_buckets]?Element = std.mem.zeroes([num_buckets]?Element); - for (0..num_windows) |w| { - // Accumulate in buckets. - for (0..buckets.len) |i| { - buckets[i] = null; - } - for (0..scalars_mont.len) |i| { - var scalar_window = scalars_windows[i + w * scalars_mont.len]; - if (scalar_window == 0) { - continue; - } - - var adj_basis: ElementNormalized = basis[i]; - if (scalar_window < 0) { - adj_basis = ElementNormalized.neg(basis[i]); - scalar_window = -scalar_window; - } - const bucket_idx = @as(usize, @intCast(scalar_window)) - 1; - if (buckets[bucket_idx] == null) { - buckets[bucket_idx] = Element.identity(); - } - buckets[bucket_idx] = Element.mixedMsmAdd(buckets[bucket_idx].?, adj_basis); - } - - // Aggregate buckets. 
- var window_aggr: ?Element = null; - var sum: ?Element = null; - for (0..buckets.len) |i| { - if (window_aggr == null and buckets[buckets.len - 1 - i] == null) { - continue; - } - if (window_aggr == null) { - window_aggr = buckets[buckets.len - 1 - i]; - sum = buckets[buckets.len - 1 - i]; - continue; - } - if (buckets[buckets.len - 1 - i] != null) { - sum.?.add(sum.?, buckets[buckets.len - 1 - i].?); - } - window_aggr.?.add(window_aggr.?, sum.?); - } - - // Aggregate into the final result. - if (result != null) { - for (0..c) |_| { - result.?.double(result.?); - } - } - if (window_aggr != null) { - if (result == null) { - result = window_aggr.?; - } else { - result.?.add(result.?, window_aggr.?); - } - } +const optimals: [3]struct { length: u64, value: u4 } = .{ + .{ .length = 10, .value = 4 }, + .{ .length = 100, .value = 6 }, + .{ .length = std.math.maxInt(u64), .value = 8 }, +}; + +// msm computes the multi-scalar multiplication of scalars_mont and basis. It automatically +// selects the window size from `optimals` based on the number of scalars (basis may be longer). +pub fn msm(base_allocator: Allocator, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element { + const c: u4 = inline for (optimals) |optimal| { + if (scalars_mont.len <= optimal.length) { + break optimal.value; + } + }; + + return msmWithWindowSize(base_allocator, c, basis, scalars_mont); +} + +// msmWithWindowSize computes the multi-scalar multiplication of scalars_mont and basis using a specific window size. +// Usually clients should be using `msm` function instead to calculate this automatically. 
+pub fn msmWithWindowSize(base_allocator: Allocator, c: u4, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element { + const num_windows = std.math.divCeil(u8, Fr.BitSize, c) catch unreachable; + const num_buckets = @as(u16, 1) << (c - 1); + + std.debug.assert(basis.len >= scalars_mont.len); + + var arena = std.heap.ArenaAllocator.init(base_allocator); + defer arena.deinit(); + var allocator = arena.allocator(); + + var scalars_windows = try signedDigitDecomposition(allocator, c, num_windows, scalars_mont); + + var result: ?Element = null; + var buckets = try allocator.alloc(?Element, num_buckets); + @memset(buckets, null); + + for (0..num_windows) |w| { + // Accumulate in buckets. + for (0..buckets.len) |i| { + buckets[i] = null; + } + for (0..scalars_mont.len) |i| { + var scalar_window = scalars_windows[i + w * scalars_mont.len]; + if (scalar_window == 0) { + continue; } - return result orelse Element.identity(); + var adj_basis: ElementNormalized = basis[i]; + if (scalar_window < 0) { + adj_basis = ElementNormalized.neg(basis[i]); + scalar_window = -scalar_window; + } + const bucket_idx = @as(usize, @intCast(scalar_window)) - 1; + if (buckets[bucket_idx] == null) { + buckets[bucket_idx] = Element.identity(); + } + buckets[bucket_idx] = Element.mixedMsmAdd(buckets[bucket_idx].?, adj_basis); } - fn signedDigitDecomposition(arena: Allocator, scalars_mont: []const Fr) ![]i16 { - const window_mask = (1 << c) - 1; - var scalars_windows = try arena.alloc(i16, scalars_mont.len * num_windows); - - for (0..scalars_mont.len) |i| { - const scalar = scalars_mont[i].toInteger(); - var carry: u1 = 0; - for (0..num_windows) |j| { - const curr_window = @as(u16, @intCast((scalar >> @as(u8, @intCast(j * c))) & window_mask)) + carry; - carry = 0; - if (curr_window >= 1 << (c - 1)) { - std.debug.assert(j != num_windows - 1); - scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)) - (1 << c); - carry = 1; - } else { - 
scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)); - } - } + // Aggregate buckets. + var window_aggr: ?Element = null; + var sum: ?Element = null; + for (0..buckets.len) |i| { + if (window_aggr == null and buckets[buckets.len - 1 - i] == null) { + continue; + } + if (window_aggr == null) { + window_aggr = buckets[buckets.len - 1 - i]; + sum = buckets[buckets.len - 1 - i]; + continue; } + if (buckets[buckets.len - 1 - i] != null) { + sum.?.add(sum.?, buckets[buckets.len - 1 - i].?); + } + window_aggr.?.add(window_aggr.?, sum.?); + } - return scalars_windows; + // Aggregate into the final result. + if (result != null) { + for (0..c) |_| { + result.?.double(result.?); + } } - }; + if (window_aggr != null) { + if (result == null) { + result = window_aggr.?; + } else { + result.?.add(result.?, window_aggr.?); + } + } + } + + return result orelse Element.identity(); +} + +fn signedDigitDecomposition(arena: Allocator, c: u4, num_windows: u8, scalars_mont: []const Fr) ![]i16 { + const window_mask = (@as(u16, 1) << c) - 1; + var scalars_windows = try arena.alloc(i16, scalars_mont.len * num_windows); + + for (0..scalars_mont.len) |i| { + const scalar = scalars_mont[i].toInteger(); + var carry: u1 = 0; + for (0..num_windows) |j| { + const curr_window = @as(u16, @intCast((scalar >> @as(u8, @intCast(j * c))) & window_mask)) + carry; + carry = 0; + if (curr_window >= @as(u16, 1) << (c - 1)) { + std.debug.assert(j != num_windows - 1); + scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)) - (@as(i16, 1) << c); + carry = 1; + } else { + scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)); + } + } + } + + return scalars_windows; } test "correctness" { @@ -117,24 +135,20 @@ test "correctness" { scalars[i] = Fr.fromInteger((i + 0x93434) *% 0x424242); } - inline for (3..8) |c| { - const pippenger = Pippenger(c); + for (1..crs.DomainSize) |msm_length| { + 
const msm_scalars = scalars[0..msm_length]; - for (1..crs.DomainSize) |msm_length| { - const msm_scalars = scalars[0..msm_length]; - - var full_scalars: [crs.DomainSize]Fr = undefined; - for (0..full_scalars.len) |i| { - if (i < msm_length) { - full_scalars[i] = msm_scalars[i]; - continue; - } - full_scalars[i] = Fr.zero(); + var full_scalars: [crs.DomainSize]Fr = undefined; + for (0..full_scalars.len) |i| { + if (i < msm_length) { + full_scalars[i] = msm_scalars[i]; + continue; } - const exp = xcrs.commitSlow(full_scalars); - const got = try pippenger.msm(std.testing.allocator, xcrs.Gs[0..msm_length], msm_scalars); - - try std.testing.expect(Element.equal(exp, got)); + full_scalars[i] = Fr.zero(); } + const exp = xcrs.commitSlow(full_scalars); + const got = try msm(std.testing.allocator, xcrs.Gs[0..msm_length], msm_scalars); + + try std.testing.expect(Element.equal(exp, got)); } } diff --git a/src/multiproof/multiproof.zig b/src/multiproof/multiproof.zig index 2866ac2..bc3ac69 100644 --- a/src/multiproof/multiproof.zig +++ b/src/multiproof/multiproof.zig @@ -205,8 +205,7 @@ pub const MultiProof = struct { Cs[i] = query.C; E_coefficients[i] = Fr.mul(powers_of_r[i], helper_scalar_den[queries[i].z]); } - // TODO: make the window size be dynamically calculated. - const E = try pippenger.Pippenger(10).msm(allocator, Cs, E_coefficients); + const E = try pippenger.msm(allocator, Cs, E_coefficients); transcript.appendPoint(E, "E"); // Check IPA proof.