Skip to content

Commit

Permalink
msm: apply a trick I figured out for extra performance
Browse files Browse the repository at this point in the history
Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com>
  • Loading branch information
jsign committed Sep 29, 2023
1 parent e9b5637 commit 2bc6ebb
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 45 deletions.
6 changes: 3 additions & 3 deletions src/bandersnatch/bandersnatch.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const std = @import("std");
const BandersnatchFields = @import("../fields/fields.zig").BandersnatchFields;
const extended_points = @import("points/extended.zig");
const extendedpoints = @import("points/extended.zig");

// Bandersnatch base and scalar finite fields.
pub const Fp = BandersnatchFields.BaseField;
Expand All @@ -12,8 +12,8 @@ pub const D = Fp.fromInteger(138827208126141220649022263972958607803).div(Fp.fro

// Points.
pub const AffinePoint = @import("points/affine.zig");
pub const ExtendedPoint = extended_points.ExtendedPoint;
pub const ExtendedPointNormalized = extended_points.ExtendedPointNormalized;
pub const ExtendedPoint = extendedpoints.ExtendedPoint;
pub const ExtendedPointMSM = extendedpoints.ExtendedPointMSM;

// Errors
pub const CurveError = error{
Expand Down
37 changes: 14 additions & 23 deletions src/bandersnatch/points/extended.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,35 @@ const Fp = Bandersnatch.Fp;
const Fr = Bandersnatch.Fr;
const AffinePoint = Bandersnatch.AffinePoint;

pub const ExtendedPointNormalized = struct {
pub const ExtendedPointMSM = struct {
x: Fp,
y: Fp,
t: Fp,

pub fn identity() ExtendedPointNormalized {
pub fn identity() ExtendedPointMSM {
return comptime fromExtendedPoint(ExtendedPoint.identity());
}

pub fn generator() ExtendedPointNormalized {
pub fn generator() ExtendedPointMSM {
return comptime fromExtendedPoint(ExtendedPoint.generator());
}

pub fn initUnsafe(x: Fp, y: Fp) ExtendedPointNormalized {
return ExtendedPointNormalized{
pub fn initUnsafe(x: Fp, y: Fp) ExtendedPointMSM {
return ExtendedPointMSM{
.x = x,
.y = y,
.t = x.mul(y),
.t = x.mul(y).mul(Bandersnatch.D),
};
}

pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointNormalized {
pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointMSM {
const z_inv = p.z.inv().?;
const x = p.x.mul(z_inv);
const y = p.y.mul(z_inv);
return ExtendedPointNormalized{
.x = x,
.y = y,
.t = Fp.mul(x, y),
};
return initUnsafe(x, y);
}

pub fn equal(self: ExtendedPointNormalized, other: ExtendedPointNormalized) bool {
pub fn equal(self: ExtendedPointMSM, other: ExtendedPointMSM) bool {
return self.x.equal(other.x) and self.y.equal(other.y) and self.t.equal(other.t);
}
};
Expand All @@ -63,13 +59,8 @@ pub const ExtendedPoint = struct {
};
}

pub fn fromExtendedPointNormalized(e: ExtendedPointNormalized) ExtendedPoint {
return ExtendedPoint{
.x = e.x,
.y = e.y,
.t = e.t,
.z = Fp.one(),
};
pub fn fromExtendedPointMSM(e: ExtendedPointMSM) ExtendedPoint {
return initUnsafe(e.x, e.y);
}

pub fn identity() ExtendedPoint {
Expand Down Expand Up @@ -134,12 +125,12 @@ pub const ExtendedPoint = struct {
};
}

pub fn mixedAdd(p: ExtendedPoint, q: ExtendedPointNormalized) ExtendedPoint {
pub fn mixedMsmAdd(p: ExtendedPoint, q: ExtendedPointMSM) ExtendedPoint {
// https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-madd-2008-hwcd
const A = Fp.mul(p.x, q.x);
const B = Fp.mul(p.y, q.y);
const t0 = Fp.mul(Bandersnatch.D, q.t);
const C = Fp.mul(p.t, t0);
// const t0 = Fp.mul(Bandersnatch.D, q.t);
const C = Fp.mul(p.t, q.t);
const D = p.z;
const t1 = Fp.add(p.x, p.y);
const t2 = Fp.add(q.x, q.y);
Expand Down
24 changes: 12 additions & 12 deletions src/banderwagon/banderwagon.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ const Bandersnatch = @import("../bandersnatch/bandersnatch.zig");
const Fp = Bandersnatch.Fp;
const AffinePoint = Bandersnatch.AffinePoint;
const ExtendedPoint = Bandersnatch.ExtendedPoint;
const ExtendedPointNormalized = Bandersnatch.ExtendedPointNormalized;
const ExtendedPointMSM = Bandersnatch.ExtendedPointMSM;

// Fr is the scalar field of the Banderwgaon group, which matches with the
// scalar field size of the Bandersnatch primer-ordered subgroup.
Expand All @@ -21,7 +21,7 @@ pub const Element = struct {

pub fn fromElementNormalized(e: ElementNormalized) Element {
return Element{
.point = ExtendedPoint.fromExtendedPointNormalized(e.point),
.point = ExtendedPoint.fromExtendedPointMSM(e.point),
};
}

Expand Down Expand Up @@ -70,9 +70,9 @@ pub const Element = struct {
self.point = ExtendedPoint.add(p.point, q.point);
}

pub fn mixedAdd(a: Element, b: ElementNormalized) Element {
pub fn mixedMsmAdd(a: Element, b: ElementNormalized) Element {
return Element{
.point = ExtendedPoint.mixedAdd(a.point, b.point),
.point = ExtendedPoint.mixedMsmAdd(a.point, b.point),
};
}

Expand Down Expand Up @@ -217,7 +217,7 @@ test "two torsion" {
}

pub const ElementNormalized = struct {
point: ExtendedPointNormalized,
point: ExtendedPointMSM,

// fromBytes deserializes an element from a byte array.
// The spec serialization is the X coordinate in big endian form.
Expand All @@ -231,21 +231,21 @@ pub const ElementNormalized = struct {
}
const y = try AffinePoint.getYCoordinate(x, true);

return ElementNormalized{ .point = ExtendedPointNormalized.initUnsafe(x, y) };
return ElementNormalized{ .point = ExtendedPointMSM.initUnsafe(x, y) };
}

pub fn generator() ElementNormalized {
return ElementNormalized{ .point = ExtendedPointNormalized.generator() };
return ElementNormalized{ .point = ExtendedPointMSM.generator() };
}

pub fn fromElement(p: Element) ElementNormalized {
return ElementNormalized{
.point = ExtendedPointNormalized.fromExtendedPoint(p.point),
.point = ExtendedPointMSM.fromExtendedPoint(p.point),
};
}

pub fn equal(a: ElementNormalized, b: ElementNormalized) bool {
return ExtendedPointNormalized.equal(a.point, b.point);
return ExtendedPointMSM.equal(a.point, b.point);
}

pub fn toBytes(self: ElementNormalized) [Element.BytesSize]u8 {
Expand All @@ -270,9 +270,9 @@ pub const ElementNormalized = struct {

for (0..points.len) |i| {
const z_inv = result[i].point.x;
result[i].point.x = Fp.mul(points[i].point.x, z_inv);
result[i].point.y = Fp.mul(points[i].point.y, z_inv);
result[i].point.t = Fp.mul(result[i].point.x, result[i].point.y);
const x = Fp.mul(points[i].point.x, z_inv);
const y = Fp.mul(points[i].point.y, z_inv);
result[i].point = ExtendedPointMSM.initUnsafe(x, y);
}
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/crs/crs.zig
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ pub const CRS = struct {
const PrecompMSM = msm.PrecompMSM(2, 8);

Gs: [DomainSize]ElementNormalized,
Q: ElementNormalized,
Q: Element,
precomp: PrecompMSM,

pub fn init(allocator: Allocator) !CRS {
const points = deserialize_vkt_points();
return CRS{
.Gs = points,
.Q = ElementNormalized.generator(),
.Q = Element.generator(),
.precomp = try PrecompMSM.init(allocator, &points),
};
}
Expand Down
4 changes: 2 additions & 2 deletions src/ipa/ipa.zig
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub fn IPA(comptime VectorLength: comptime_int) type {

// Rescale Q.
const w = transcript.challengeScalar("w");
const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w);
const q = xcrs.Q.scalarMul(w);

var L: [NUM_STEPS]Element = undefined;
var R: [NUM_STEPS]Element = undefined;
Expand Down Expand Up @@ -137,7 +137,7 @@ pub fn IPA(comptime VectorLength: comptime_int) type {
transcript.appendScalar(y, "output point");

const w = transcript.challengeScalar("w");
const q = Element.fromElementNormalized(xcrs.Q).scalarMul(w);
const q = xcrs.Q.scalarMul(w);

var commitment: Element = undefined;
commitment.add(C, q.scalarMul(y));
Expand Down
2 changes: 1 addition & 1 deletion src/msm/pippenger.zig
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub fn Pippenger(comptime c: comptime_int) type {
if (buckets[scalar_windows[i] - 1] == null) {
buckets[scalar_windows[i] - 1] = Element.identity();
}
buckets[scalar_windows[i] - 1] = Element.mixedAdd(buckets[scalar_windows[i] - 1].?, basis[i]);
buckets[scalar_windows[i] - 1] = Element.mixedMsmAdd(buckets[scalar_windows[i] - 1].?, basis[i]);
}

// Aggregate buckets.
Expand Down
4 changes: 2 additions & 2 deletions src/msm/precomp.zig
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ pub fn PrecompMSM(

if (curr_window_b_idx == b) {
if (curr_window_scalar > 0) {
accum = Element.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]);
accum = Element.mixedMsmAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]);
}
curr_window_idx += 1;

Expand All @@ -108,7 +108,7 @@ pub fn PrecompMSM(
}
}
if (curr_window_scalar > 0) {
accum = Element.mixedAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]);
accum = Element.mixedMsmAdd(accum, self.table[curr_window_idx * window_size .. (curr_window_idx + 1) * window_size][curr_window_scalar]);
}
}

Expand Down

0 comments on commit 2bc6ebb

Please sign in to comment.