Skip to content

Commit

Permalink
pippenger: use normalized points and fix everything to not jump over …
Browse files Browse the repository at this point in the history
…abstractions

Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com>
  • Loading branch information
jsign committed Sep 29, 2023
1 parent 42bcdc4 commit 4d17fe8
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 98 deletions.
26 changes: 2 additions & 24 deletions src/bandersnatch/bandersnatch.zig
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ test "neg" {

test "serialize gen" {
const gen = ExtendedPoint.generator();
const serialised_point = gen.to_bytes();
const serialised_point = gen.toBytes();

// test vector taken from the rust code (see spec reference)
const expected = "18ae52a26618e7e1658499ad22c0792bf342be7b77113774c5340b2ccc32c129";
Expand Down Expand Up @@ -78,7 +78,7 @@ test "scalar mul minus one" {
const result = gen.scalarMul(scalar);

const expected = "e951ad5d98e7181e99d76452e0e343281295e38d90c602bf824892fd86742c4a";
const actual = std.fmt.bytesToHex(result.to_bytes(), std.fmt.Case.lower);
const actual = std.fmt.bytesToHex(result.toBytes(), std.fmt.Case.lower);
try std.testing.expectEqualSlices(u8, expected, &actual);
}

Expand Down Expand Up @@ -193,25 +193,3 @@ test "batch inv with error" {
const out = Fp.batchInv(&got_invs, &fes);
try std.testing.expectError(error.CantInvertZeroElement, out);
}

test "ExtendedPoint -> ExtendedPointNormalized" {
const g = ExtendedPoint.generator();
const scalars = [_]Fr{ Fr.fromInteger(3213), Fr.fromInteger(1212), Fr.fromInteger(4433) };

var points: [scalars.len]ExtendedPoint = undefined;
for (0..scalars.len) |i| {
points[i] = g.scalarMul(scalars[i]);
}

var expected: [scalars.len]ExtendedPointNormalized = undefined;
for (0..scalars.len) |i| {
expected[i] = ExtendedPointNormalized.fromExtendedPoint(points[i]);
}

var got: [scalars.len]ExtendedPointNormalized = undefined;
try ExtendedPointNormalized.fromExtendedPoints(&got, &points);

for (0..expected.len) |i| {
try std.testing.expect(expected[i].equal(got[i]));
}
}
1 change: 1 addition & 0 deletions src/bandersnatch/points/affine.zig
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub fn scalarMul(point: AffinePoint, scalar: Fr) AffinePoint {
return result;
}

// TODO
pub fn to_bytes(self: AffinePoint) [32]u8 {
const mCompressedNegative = 0x80;
const mCompressedPositive = 0x00;
Expand Down
48 changes: 23 additions & 25 deletions src/bandersnatch/points/extended.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ pub const ExtendedPointNormalized = struct {
return comptime fromExtendedPoint(ExtendedPoint.identity());
}

pub fn generator() ExtendedPointNormalized {
return comptime fromExtendedPoint(ExtendedPoint.generator());
}

pub fn initUnsafe(x: Fp, y: Fp) ExtendedPointNormalized {
return ExtendedPointNormalized{
.x = x,
.y = y,
.t = x.mul(y),
};
}

pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointNormalized {
const z_inv = p.z.inv().?;
const x = p.x.mul(z_inv);
Expand All @@ -24,29 +36,6 @@ pub const ExtendedPointNormalized = struct {
};
}

pub fn fromExtendedPoints(result: []ExtendedPointNormalized, points: []const ExtendedPoint) !void {
var accumulator = Fp.one();

for (0..points.len) |i| {
result[i].x = accumulator;
accumulator = Fp.mul(accumulator, points[i].z);
}

var accInverse = accumulator.inv().?;

for (0..points.len) |i| {
result[result.len - 1 - i].x = Fp.mul(result[result.len - 1 - i].x, accInverse);
accInverse = Fp.mul(accInverse, points[points.len - 1 - i].z);
}

for (0..points.len) |i| {
const z_inv = result[i].x;
result[i].x = Fp.mul(points[i].x, z_inv);
result[i].y = Fp.mul(points[i].y, z_inv);
result[i].t = Fp.mul(result[i].x, result[i].y);
}
}

pub fn equal(self: ExtendedPointNormalized, other: ExtendedPointNormalized) bool {
return self.x.equal(other.x) and self.y.equal(other.y) and self.t.equal(other.t);
}
Expand Down Expand Up @@ -74,6 +63,15 @@ pub const ExtendedPoint = struct {
};
}

pub fn fromExtendedPointNormalized(e: ExtendedPointNormalized) ExtendedPoint {
return ExtendedPoint{
.x = e.x,
.y = e.y,
.t = e.t,
.z = Fp.one(),
};
}

pub fn identity() ExtendedPoint {
const iden = comptime AffinePoint.identity();
return comptime initUnsafe(iden.x, iden.y);
Expand Down Expand Up @@ -116,6 +114,7 @@ pub const ExtendedPoint = struct {
return (p.x.mul(q.z).equal(p.z.mul(q.x))) and (p.y.mul(q.z).equal(q.y.mul(p.z)));
}

// TODO: change api to result receiver.
pub fn add(p: ExtendedPoint, q: ExtendedPoint) ExtendedPoint {
// https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-add-2008-hwcd
const a = Fp.mul(p.x, q.x);
Expand Down Expand Up @@ -222,8 +221,7 @@ pub const ExtendedPoint = struct {
}
}

// # Only used for testing purposes.
pub fn to_bytes(self: ExtendedPoint) [32]u8 {
pub fn toBytes(self: ExtendedPoint) [32]u8 {
return self.toAffine().to_bytes();
}
};
99 changes: 97 additions & 2 deletions src/banderwagon/banderwagon.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ const Bandersnatch = @import("../bandersnatch/bandersnatch.zig");
const Fp = Bandersnatch.Fp;
const AffinePoint = Bandersnatch.AffinePoint;
const ExtendedPoint = Bandersnatch.ExtendedPoint;
const ArrayList = std.ArrayList;
const ExtendedPointNormalized = Bandersnatch.ExtendedPointNormalized;

// Fr is the scalar field of the Banderwgaon group, which matches with the
// scalar field size of the Bandersnatch primer-ordered subgroup.
Expand All @@ -19,6 +19,12 @@ pub const Element = struct {
return Element{ .point = ExtendedPoint.initUnsafe(bytes) };
}

pub fn fromElementNormalized(e: ElementNormalized) Element {
return Element{
.point = ExtendedPoint.fromExtendedPointNormalized(e.point),
};
}

// fromBytes deserializes an element from a byte array.
// The spec serialization is the X coordinate in big endian form.
pub fn fromBytes(bytes: [BytesSize]u8) !Element {
Expand Down Expand Up @@ -64,6 +70,12 @@ pub const Element = struct {
self.point = ExtendedPoint.add(p.point, q.point);
}

pub fn mixedAdd(a: Element, b: ElementNormalized) Element {
return Element{
.point = ExtendedPoint.mixedAdd(a.point, b.point),
};
}

// sub subtracts two elements of the Banderwagon group.
pub fn sub(self: *Element, p: Element, q: Element) void {
self.point = ExtendedPoint.sub(p.point, q.point);
Expand Down Expand Up @@ -134,10 +146,10 @@ pub const Element = struct {
};

// msm computes the multi-scalar multiplication of scalars and points.
// TODO: change to Pippenger calls, and make everything use this.
pub fn msm(points: []const Element, scalars: []const Fr) Element {
std.debug.assert(scalars.len == points.len);

// TODO: optimize!
var res = Element.identity();
for (scalars, points) |scalar, point| {
if (scalar.isZero()) {
Expand Down Expand Up @@ -203,3 +215,86 @@ test "two torsion" {

try std.testing.expect(result.equal(gen));
}

pub const ElementNormalized = struct {
point: ExtendedPointNormalized,

// fromBytes deserializes an element from a byte array.
// The spec serialization is the X coordinate in big endian form.
pub fn fromBytes(bytes: [Element.BytesSize]u8) !ElementNormalized {
var bytes_le = bytes;
std.mem.reverse(u8, &bytes_le);
const x = Fp.fromBytes(bytes_le); // TODO: reject if bytes are not canonical?

if (Element.subgroupCheck(x) != 1) {
return error.NotInSubgroup;
}
const y = try AffinePoint.getYCoordinate(x, true);

return ElementNormalized{ .point = ExtendedPointNormalized.initUnsafe(x, y) };
}

pub fn generator() ElementNormalized {
return ElementNormalized{ .point = ExtendedPointNormalized.generator() };
}

pub fn fromElement(p: Element) ElementNormalized {
return ElementNormalized{
.point = ExtendedPointNormalized.fromExtendedPoint(p.point),
};
}

pub fn equal(a: ElementNormalized, b: ElementNormalized) bool {
return ExtendedPointNormalized.equal(a.point, b.point);
}

pub fn toBytes(self: ElementNormalized) [Element.BytesSize]u8 {
return Element.fromElementNormalized(self).toBytes();
}

// TODO: move this.
pub fn fromElements(result: []ElementNormalized, points: []const Element) void {
var accumulator = Fp.one();

for (0..points.len) |i| {
result[i].point.x = accumulator;
accumulator = Fp.mul(accumulator, points[i].point.z);
}

var accInverse = accumulator.inv().?;

for (0..points.len) |i| {
result[result.len - 1 - i].point.x = Fp.mul(result[result.len - 1 - i].point.x, accInverse);
accInverse = Fp.mul(accInverse, points[points.len - 1 - i].point.z);
}

for (0..points.len) |i| {
const z_inv = result[i].point.x;
result[i].point.x = Fp.mul(points[i].point.x, z_inv);
result[i].point.y = Fp.mul(points[i].point.y, z_inv);
result[i].point.t = Fp.mul(result[i].point.x, result[i].point.y);
}
}
};

test "Element -> ElementNormalized" {
const g = Element.generator();
const scalars = [_]Fr{ Fr.fromInteger(3213), Fr.fromInteger(1212), Fr.fromInteger(4433) };

var points: [scalars.len]Element = undefined;
for (0..scalars.len) |i| {
points[i] = g.scalarMul(scalars[i]);
}

var expected: [scalars.len]ElementNormalized = undefined;
for (0..scalars.len) |i| {
expected[i] = ElementNormalized.fromElement(points[i]);
}

var got: [scalars.len]ElementNormalized = undefined;
ElementNormalized.fromElements(&got, &points);

for (0..expected.len) |i| {
try std.testing.expect(expected[i].equal(got[i]));
}
}
7 changes: 3 additions & 4 deletions src/bench.zig
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ fn benchMultiproofs() !void {
const LagrangeBasis = polynomials.LagrangeBasis(crs.DomainSize, crs.Domain);

std.debug.print("Setting up multiproofs benchmark...\n", .{});
const N = 1;
const openings = [_]u16{1000};
const N = 25;
const openings = [_]u16{ 100, 1_000, 5_000, 10_000 };

var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer {
Expand Down Expand Up @@ -229,11 +229,10 @@ fn benchMultiproofs() !void {
defer allocator.free(verifier_queries);
for (0..num_openings) |i| {
verifier_queries[i] = multiproof.VerifierQuery{
.C = vec_openings[i].C,
.C = banderwagon.ElementNormalized.fromElement(vec_openings[i].C),
.z = vec_openings[i].z,
.y = vec_openings[i].poly_evaluations[vec_openings[i].z],
};
verifier_queries[i].C.normalize();
}
start = std.time.milliTimestamp();
const ok = try mproof.verifyProof(allocator, &verifier_transcript, verifier_queries, proof);
Expand Down
21 changes: 13 additions & 8 deletions src/crs/crs.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const sha256 = std.crypto.hash.sha2.Sha256;
const banderwagon = @import("../banderwagon/banderwagon.zig");
const msm = @import("../msm/precomp.zig");
const Element = banderwagon.Element;
const ElementNormalized = banderwagon.ElementNormalized;
const Fr = banderwagon.Fr;

// DomainSize is the size of the domain.
Expand All @@ -24,15 +25,15 @@ pub const Domain: [DomainSize]Fr = domain_elements: {
pub const CRS = struct {
const PrecompMSM = msm.PrecompMSM(2, 8);

Gs: [DomainSize]Element,
Q: Element,
Gs: [DomainSize]ElementNormalized,
Q: ElementNormalized,
precomp: PrecompMSM,

pub fn init(allocator: Allocator) !CRS {
const points = deserialize_vkt_points();
return CRS{
.Gs = points,
.Q = Element.generator(),
.Q = ElementNormalized.generator(),
.precomp = try PrecompMSM.init(allocator, &points),
};
}
Expand All @@ -46,16 +47,20 @@ pub const CRS = struct {
}

pub fn commitSlow(self: CRS, values: [DomainSize]Fr) Element {
return banderwagon.msm(&self.Gs, &values);
var gs: [DomainSize]Element = undefined;
for (0..DomainSize) |i| {
gs[i] = Element.fromElementNormalized(self.Gs[i]);
}
return banderwagon.msm(&gs, &values);
}
};

fn deserialize_vkt_points() [DomainSize]Element {
var points: [vkt_crs_points.len]Element = undefined;
fn deserialize_vkt_points() [DomainSize]ElementNormalized {
var points: [vkt_crs_points.len]ElementNormalized = undefined;
for (vkt_crs_points, 0..) |serialized_point, i| {
var g_be_bytes: [32]u8 = undefined;
_ = std.fmt.hexToBytes(&g_be_bytes, serialized_point) catch unreachable;
points[i] = Element.fromBytes(g_be_bytes) catch unreachable;
points[i] = ElementNormalized.fromBytes(g_be_bytes) catch unreachable;
}
return points;
}
Expand Down Expand Up @@ -85,7 +90,7 @@ test "crs is consistent" {
test "Gs cannot contain the generator" {
const crs = try CRS.init(std.testing.allocator);
defer crs.deinit();
const generator = Element.generator();
const generator = ElementNormalized.generator();
for (crs.Gs) |point| {
try std.testing.expect(!generator.equal(point));
}
Expand Down
10 changes: 7 additions & 3 deletions src/ipa/ipa.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const std = @import("std");
const banderwagon = @import("../banderwagon/banderwagon.zig");
const Element = banderwagon.Element;
const ElementNormalized = banderwagon.ElementNormalized;
const Fr = banderwagon.Fr;
const crs = @import("../crs/crs.zig");
const Transcript = @import("transcript.zig");
Expand Down Expand Up @@ -61,11 +62,14 @@ pub fn IPA(comptime VectorLength: comptime_int) type {

// Rescale Q.
const w = transcript.challengeScalar("w");
const q = xcrs.Q.scalarMul(w);
const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w);

var L: [NUM_STEPS]Element = undefined;
var R: [NUM_STEPS]Element = undefined;
var _basis = xcrs.Gs;
var _basis: [crs.DomainSize]Element = undefined;
for (0..crs.DomainSize) |i| {
_basis[i] = Element.fromElementNormalized(xcrs.Gs[i]);
}
var basis: []Element = _basis[0..];

var step: usize = 0;
Expand Down Expand Up @@ -133,7 +137,7 @@ pub fn IPA(comptime VectorLength: comptime_int) type {
transcript.appendScalar(y, "output point");

const w = transcript.challengeScalar("w");
const q = xcrs.Q.scalarMul(w);
const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w);

var commitment: Element = undefined;
commitment.add(C, q.scalarMul(y));
Expand Down
Loading

0 comments on commit 4d17fe8

Please sign in to comment.