diff --git a/src/lib/provable/gadgets/eddsa.ts b/src/lib/provable/gadgets/eddsa.ts index 223e02f98..b3eb24301 100644 --- a/src/lib/provable/gadgets/eddsa.ts +++ b/src/lib/provable/gadgets/eddsa.ts @@ -22,10 +22,10 @@ import { arrayGet, assertNotVectorEquals } from './basic.js'; import { sliceField3 } from './bit-slices.js'; import { exists } from '../core/exists.js'; import { ProvableType } from '../types/provable-intf.js'; -import { Point } from './elliptic-curve.js'; +import { arrayGetGeneric, point, Point } from './elliptic-curve.js'; // external API -export { EllipticCurveTwisted, Point, Eddsa }; +export { EllipticCurveTwisted, Eddsa }; const EllipticCurveTwisted = { add, @@ -226,3 +226,250 @@ function assertOnCurve( } ForeignField.assertMul(dTimesX2, y2, aTimesX2Minus1, f, message); } + +/** + * EC scalar multiplication, `scalar*point` + * + * The result is constrained to be not zero. + */ +function scale( + scalar: Field3, + point: Point, + Curve: CurveTwisted, + config: { + mode?: 'assert-nonzero' | 'assert-zero'; + windowSize?: number; + multiples?: Point[]; + } = { mode: 'assert-nonzero' } +) { + config.windowSize ??= Point.isConstant(point) ? 4 : 3; + return multiScalarMul([scalar], [point], Curve, [config], config.mode); +} + +// check whether a point equals a constant point +// TODO implement the full case of two vars +function equals(p1: Point, p2: point, Curve: { modulus: bigint }) { + let xEquals = ForeignField.equals(p1.x, p2.x, Curve.modulus); + let yEquals = ForeignField.equals(p1.y, p2.y, Curve.modulus); + return xEquals.and(yEquals); +} + +function multiScalarMulConstant( + scalars: Field3[], + points: Point[], + Curve: CurveTwisted, + mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero' +): Point { + let n = points.length; + assert(scalars.length === n, 'Points and scalars lengths must match'); + assertPositiveInteger(n, 'Expected at least 1 point and scalar'); + + // TODO dedicated MSM + let s = scalars.map(Field3.toBigint); + let P = points.map(Point.toBigint); + let sum: GroupTwisted = Curve.zero; + for (let i = 0; i < n; i++) { + sum = Curve.add(sum, Curve.scale(P[i], s[i])); + } + if (mode === 'assert-zero') { + assert(sum.infinity, 'scalar multiplication: expected zero result'); + return Point.from(Curve.zero); + } + assert(!sum.infinity, 'scalar multiplication: expected non-zero result'); + return Point.from(sum); +} + +/** + * Multi-scalar multiplication: + * + * s_0 * P_0 + ... + s_(n-1) * P_(n-1) + * + * where P_i are any points. + * + * By default, we prove that the result is not zero. + * + * If you set the `mode` parameter to `'assert-zero'`, on the other hand, + * we assert that the result is zero and just return the constant zero point. + * + * Implementation: We double all points together and leverage a precomputed table of size 2^c to avoid all but every cth addition. + * + * Note: this algorithm targets a small number of points + * + * TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch + * TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case + */ +function multiScalarMul( + scalars: Field3[], + points: Point[], + Curve: CurveTwisted, + tableConfigs: ( + | { windowSize?: number; multiples?: Point[] } + | undefined + )[] = [], + mode: 'assert-nonzero' | 'assert-zero' = 'assert-nonzero', + ia?: point +): Point { + let n = points.length; + assert(scalars.length === n, 'Points and scalars lengths must match'); + assertPositiveInteger(n, 'Expected at least 1 point and scalar'); + let useGlv = Curve.hasEndomorphism; + + // constant case + if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) { + return multiScalarMulConstant(scalars, points, Curve, mode); + } + + // parse or build point tables + let windowSizes = points.map((_, i) => tableConfigs[i]?.windowSize ?? 1); + let tables = points.map((P, i) => + getPointTable(Curve, P, windowSizes[i], tableConfigs[i]?.multiples) + ); + + let maxBits = Curve.Scalar.sizeInBits; + + // slice scalars + let scalarChunks = scalars.map((s, i) => + sliceField3(s, { maxBits, chunkSize: windowSizes[i] }) + ); + + // initialize sum to the initial aggregator, which is expected to be unrelated + // to any point that this gadget is used with + // note: this is a trick to ensure _completeness_ of the gadget + // soundness follows because add() and double() are sound, on all inputs that + // are valid non-zero curve points + ia ??= initialAggregator(Curve); + let sum = Point.from(ia); + + for (let i = maxBits - 1; i >= 0; i--) { + // add in multiple of each point + for (let j = 0; j < n; j++) { + let windowSize = windowSizes[j]; + if (i % windowSize === 0) { + // pick point to add based on the scalar chunk + let sj = scalarChunks[j][i / windowSize]; + let sjP = + windowSize === 1 + ? points[j] + : arrayGetGeneric(Point.provable, tables[j], sj); + + // ec addition + let added = add(sum, sjP, Curve); + + // handle degenerate case + // (if sj = 0, Gj is all zeros and the add result is garbage) + sum = Provable.if(sj.equals(0), Point, sum, added); + } + } + + if (i === 0) break; + + // jointly double all points + // (note: the highest couple of bits will not create any constraints because + // sum is constant; no need to handle that explicitly) + sum = double(sum, Curve); + } + + // the sum is now 2^(b-1)*IA + sum_i s_i*P_i + // we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result + let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1)); + let isZero = equals(sum, iaFinal, Curve); + + if (mode === 'assert-nonzero') { + isZero.assertFalse(); + sum = add(sum, Point.from(Curve.negate(iaFinal)), Curve); + } else { + isZero.assertTrue(); + // for type consistency with the 'assert-nonzero' case + sum = Point.from(Curve.zero); + } + + return sum; +} + +/** + * Given a point P, create the list of multiples [0, P, 2P, 3P, ..., (2^windowSize-1) * P]. + * This method is provable, but won't create any constraints given a constant point. + */ +function getPointTable( + Curve: CurveTwisted, + P: Point, + windowSize: number, + table?: Point[] +): Point[] { + assertPositiveInteger(windowSize, 'invalid window size'); + let n = 1 << windowSize; // n >= 2 + + assert(table === undefined || table.length === n, 'invalid table'); + if (table !== undefined) return table; + + table = [Point.from(Curve.zero), P]; + if (n === 2) return table; + + let Pi = double(P, Curve); + table.push(Pi); + for (let i = 3; i < n; i++) { + Pi = add(Pi, P, Curve); + table.push(Pi); + } + return table; +} + +/** + * For EC scalar multiplication we use an initial point which is subtracted + * at the end, to avoid encountering the point at infinity. + * + * This is a simple hash-to-group algorithm which finds that initial point. + * It's important that this point has no known discrete logarithm so that nobody + * can create an invalid proof of EC scaling. + */ +function initialAggregator(Curve: CurveTwisted) { + // hash that identifies the curve + let h = sha256.create(); + h.update('initial-aggregator'); + h.update(bigIntToBytes(Curve.modulus)); + h.update(bigIntToBytes(Curve.order)); + h.update(bigIntToBytes(Curve.a)); + h.update(bigIntToBytes(Curve.d)); + let bytes = h.array(); + + // bytes represent a 256-bit number + // use that as x coordinate + const F = Curve.Field; + let x = F.mod(bytesToBigInt(bytes)); + return simpleMapToCurve(x, Curve); +} + +function random(Curve: CurveTwisted) { + let x = Curve.Field.random(); + return simpleMapToCurve(x, Curve); +} + +/** + * Given an x coordinate (base field element), increment it until we find one with + * a y coordinate that satisfies the curve equation, and return the point. + * + * If the curve has a cofactor, multiply by it to get a point in the correct subgroup. + */ +function simpleMapToCurve(x: bigint, Curve: CurveTwisted) { + const F = Curve.Field; + let y: bigint | undefined = undefined; + + // increment x until we find a y coordinate + while (y === undefined) { + x = F.add(x, 1n); + // solve y^2 = (1 - a * x^2)/(1 - d * x^2) + let x2 = F.square(x); + let num = F.sub(1n, F.mul(x2, Curve.a)); + let den = F.sub(1n, F.mul(x2, Curve.d)); + if (den == 0n) continue; + let y2 = F.div(num, den)!; // guaranteed that den has an inverse + y = F.sqrt(y2); + } + let p = { x, y, infinity: false }; + + // clear cofactor + if (Curve.hasCofactor) { + p = Curve.scale(p, Curve.cofactor!); + } + return p; +} diff --git a/src/lib/provable/gadgets/elliptic-curve.ts b/src/lib/provable/gadgets/elliptic-curve.ts index c0cff7da5..64c4da509 100644 --- a/src/lib/provable/gadgets/elliptic-curve.ts +++ b/src/lib/provable/gadgets/elliptic-curve.ts @@ -27,7 +27,13 @@ import { ProvableType } from '../types/provable-intf.js'; export { EllipticCurve, Point, Ecdsa }; // internal API -export { verifyEcdsaConstant, initialAggregator, simpleMapToCurve }; +export { + verifyEcdsaConstant, + initialAggregator, + simpleMapToCurve, + arrayGetGeneric, + point, +}; const EllipticCurve = { add,