Skip to content

Commit

Permalink
scaling operations
Browse files Browse the repository at this point in the history
  • Loading branch information
querolita committed Dec 12, 2024
1 parent 90154ee commit 1e70628
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 3 deletions.
251 changes: 249 additions & 2 deletions src/lib/provable/gadgets/eddsa.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import { arrayGet, assertNotVectorEquals } from './basic.js';
import { sliceField3 } from './bit-slices.js';
import { exists } from '../core/exists.js';
import { ProvableType } from '../types/provable-intf.js';
import { Point } from './elliptic-curve.js';
import { arrayGetGeneric, point, Point } from './elliptic-curve.js';

// external API
export { EllipticCurveTwisted, Point, Eddsa };
export { EllipticCurveTwisted, Eddsa };

const EllipticCurveTwisted = {
add,
Expand Down Expand Up @@ -226,3 +226,250 @@ function assertOnCurve(
}
ForeignField.assertMul(dTimesX2, y2, aTimesX2Minus1, f, message);
}

/**
* EC scalar multiplication, `scalar*point`
*
* The result is constrained to be not zero.
*/
function scale(
scalar: Field3,
point: Point,
Curve: CurveTwisted,
config: {
mode?: 'assert-nonzero' | 'assert-zero';
windowSize?: number;
multiples?: Point[];
} = { mode: 'assert-nonzero' }
) {
config.windowSize ??= Point.isConstant(point) ? 4 : 3;
return multiScalarMul([scalar], [point], Curve, [config], config.mode);
}

// check whether a point equals a constant point
// TODO implement the full case of two vars
function equals(p1: Point, p2: point, Curve: { modulus: bigint }) {
let xEquals = ForeignField.equals(p1.x, p2.x, Curve.modulus);
let yEquals = ForeignField.equals(p1.y, p2.y, Curve.modulus);
return xEquals.and(yEquals);
}

function multiScalarMulConstant(
scalars: Field3[],
points: Point[],
Curve: CurveTwisted,
mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero'
): Point {
let n = points.length;
assert(scalars.length === n, 'Points and scalars lengths must match');
assertPositiveInteger(n, 'Expected at least 1 point and scalar');

// TODO dedicated MSM
let s = scalars.map(Field3.toBigint);
let P = points.map(Point.toBigint);
let sum: GroupTwisted = Curve.zero;
for (let i = 0; i < n; i++) {
sum = Curve.add(sum, Curve.scale(P[i], s[i]));
}
if (mode === 'assert-zero') {
assert(sum.infinity, 'scalar multiplication: expected zero result');
return Point.from(Curve.zero);
}
assert(!sum.infinity, 'scalar multiplication: expected non-zero result');
return Point.from(sum);
}

/**
* Multi-scalar multiplication:
*
* s_0 * P_0 + ... + s_(n-1) * P_(n-1)
*
* where P_i are any points.
*
* By default, we prove that the result is not zero.
*
* If you set the `mode` parameter to `'assert-zero'`, on the other hand,
* we assert that the result is zero and just return the constant zero point.
*
* Implementation: We double all points together and leverage a precomputed table of size 2^c to avoid all but every cth addition.
*
* Note: this algorithm targets a small number of points
*
* TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch
* TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case
*/
function multiScalarMul(
scalars: Field3[],
points: Point[],
Curve: CurveTwisted,
tableConfigs: (
| { windowSize?: number; multiples?: Point[] }
| undefined
)[] = [],
mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero',
ia?: point
): Point {
let n = points.length;
assert(scalars.length === n, 'Points and scalars lengths must match');
assertPositiveInteger(n, 'Expected at least 1 point and scalar');
let useGlv = Curve.hasEndomorphism;

// constant case
if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) {
return multiScalarMulConstant(scalars, points, Curve, mode);
}

// parse or build point tables
let windowSizes = points.map((_, i) => tableConfigs[i]?.windowSize ?? 1);
let tables = points.map((P, i) =>
getPointTable(Curve, P, windowSizes[i], tableConfigs[i]?.multiples)
);

let maxBits = Curve.Scalar.sizeInBits;

// slice scalars
let scalarChunks = scalars.map((s, i) =>
sliceField3(s, { maxBits, chunkSize: windowSizes[i] })
);

// initialize sum to the initial aggregator, which is expected to be unrelated
// to any point that this gadget is used with
// note: this is a trick to ensure _completeness_ of the gadget
// soundness follows because add() and double() are sound, on all inputs that
// are valid non-zero curve points
ia ??= initialAggregator(Curve);
let sum = Point.from(ia);

for (let i = maxBits - 1; i >= 0; i--) {
// add in multiple of each point
for (let j = 0; j < n; j++) {
let windowSize = windowSizes[j];
if (i % windowSize === 0) {
// pick point to add based on the scalar chunk
let sj = scalarChunks[j][i / windowSize];
let sjP =
windowSize === 1
? points[j]
: arrayGetGeneric(Point.provable, tables[j], sj);

// ec addition
let added = add(sum, sjP, Curve);

// handle degenerate case
// (if sj = 0, Gj is all zeros and the add result is garbage)
sum = Provable.if(sj.equals(0), Point, sum, added);
}
}

if (i === 0) break;

// jointly double all points
// (note: the highest couple of bits will not create any constraints because
// sum is constant; no need to handle that explicitly)
sum = double(sum, Curve);
}

// the sum is now 2^(b-1)*IA + sum_i s_i*P_i
// we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result
let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1));
let isZero = equals(sum, iaFinal, Curve);

if (mode === 'assert-nonzero') {
isZero.assertFalse();
sum = add(sum, Point.from(Curve.negate(iaFinal)), Curve);
} else {
isZero.assertTrue();
// for type consistency with the 'assert-nonzero' case
sum = Point.from(Curve.zero);
}

return sum;
}

/**
* Given a point P, create the list of multiples [0, P, 2P, 3P, ..., (2^windowSize-1) * P].
* This method is provable, but won't create any constraints given a constant point.
*/
function getPointTable(
Curve: CurveTwisted,
P: Point,
windowSize: number,
table?: Point[]
): Point[] {
assertPositiveInteger(windowSize, 'invalid window size');
let n = 1 << windowSize; // n >= 2

assert(table === undefined || table.length === n, 'invalid table');
if (table !== undefined) return table;

table = [Point.from(Curve.zero), P];
if (n === 2) return table;

let Pi = double(P, Curve);
table.push(Pi);
for (let i = 3; i < n; i++) {
Pi = add(Pi, P, Curve);
table.push(Pi);
}
return table;
}

/**
* For EC scalar multiplication we use an initial point which is subtracted
* at the end, to avoid encountering the point at infinity.
*
* This is a simple hash-to-group algorithm which finds that initial point.
* It's important that this point has no known discrete logarithm so that nobody
* can create an invalid proof of EC scaling.
*/
function initialAggregator(Curve: CurveTwisted) {
// hash that identifies the curve
let h = sha256.create();
h.update('initial-aggregator');
h.update(bigIntToBytes(Curve.modulus));
h.update(bigIntToBytes(Curve.order));
h.update(bigIntToBytes(Curve.a));
h.update(bigIntToBytes(Curve.d));
let bytes = h.array();

// bytes represent a 256-bit number
// use that as x coordinate
const F = Curve.Field;
let x = F.mod(bytesToBigInt(bytes));
return simpleMapToCurve(x, Curve);
}

function random(Curve: CurveTwisted) {
let x = Curve.Field.random();
return simpleMapToCurve(x, Curve);
}

/**
* Given an x coordinate (base field element), increment it until we find one with
* a y coordinate that satisfies the curve equation, and return the point.
*
* If the curve has a cofactor, multiply by it to get a point in the correct subgroup.
*/
function simpleMapToCurve(x: bigint, Curve: CurveTwisted) {
const F = Curve.Field;
let y: bigint | undefined = undefined;

// increment x until we find a y coordinate
while (y === undefined) {
x = F.add(x, 1n);
// solve y^2 = (1 - a * x^2)/(1 - d * x^2)
let x2 = F.square(x);
let num = F.sub(1n, F.mul(x2, Curve.a));
let den = F.sub(1n, F.mul(x2, Curve.d));
if (den == 0n) continue;
let y2 = F.div(num, den)!; // guaranteed that den has an inverse
y = F.sqrt(y2);
}
let p = { x, y, infinity: false };

// clear cofactor
if (Curve.hasCofactor) {
p = Curve.scale(p, Curve.cofactor!);
}
return p;
}
8 changes: 7 additions & 1 deletion src/lib/provable/gadgets/elliptic-curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ import { ProvableType } from '../types/provable-intf.js';
export { EllipticCurve, Point, Ecdsa };

// internal API
export { verifyEcdsaConstant, initialAggregator, simpleMapToCurve };
export {
verifyEcdsaConstant,
initialAggregator,
simpleMapToCurve,
arrayGetGeneric,
point,
};

const EllipticCurve = {
add,
Expand Down

0 comments on commit 1e70628

Please sign in to comment.