diff --git a/src/agents/coordinator.ts b/src/agents/coordinator.ts index 03c49b9e9..ccab3cf26 100644 --- a/src/agents/coordinator.ts +++ b/src/agents/coordinator.ts @@ -6,9 +6,10 @@ import { Coordinator__factory, } from '../../types/ethers-contracts'; import { BLS12381 } from '../../types/ethers-contracts/Coordinator'; +import { ChecksumAddress } from '../types'; import { fromHexString } from '../utils'; -import { getContract } from './contracts'; +import { DEFAULT_WAIT_N_CONFIRMATIONS, getContract } from './contracts'; export interface CoordinatorRitual { initiator: string; @@ -57,6 +58,20 @@ export class DkgCoordinatorAgent { }); } + public static async initializeRitual( + provider: ethers.providers.Web3Provider, + providers: ChecksumAddress[] + ): Promise { + const Coordinator = await this.connectReadWrite(provider); + const tx = await Coordinator.initiateRitual(providers); + const txReceipt = await tx.wait(DEFAULT_WAIT_N_CONFIRMATIONS); + const [ritualStartEvent] = txReceipt.events ?? []; + if (!ritualStartEvent) { + throw new Error('Ritual start event not found'); + } + return ritualStartEvent.args?.ritualId.toNumber(); + } + public static async getRitual( provider: ethers.providers.Web3Provider, ritualId: number @@ -73,12 +88,36 @@ export class DkgCoordinatorAgent { return await Coordinator.getRitualState(ritualId); } + public static async onRitualEndEvent( + provider: ethers.providers.Web3Provider, + ritualId: number, + callback: (successful: boolean) => void + ): Promise { + const Coordinator = await this.connectReadOnly(provider); + // We leave `initiator` undefined because we don't care who the initiator is + // We leave `successful` undefined because we don't care if the ritual was successful + const eventFilter = Coordinator.filters.EndRitual( + ritualId, + undefined, + undefined + ); + Coordinator.once(eventFilter, (_ritualId, _initiator, successful) => { + callback(successful); + }); + } + private static async connectReadOnly( provider: ethers.providers.Web3Provider ) { return await this.connect(provider); } + private static async connectReadWrite( + web3Provider: ethers.providers.Web3Provider + ) { + return await this.connect(web3Provider, web3Provider.getSigner()); + } + private static async connect( provider: ethers.providers.Web3Provider, signer?: ethers.providers.JsonRpcSigner diff --git a/src/characters/cbd-recipient.ts b/src/characters/cbd-recipient.ts index 8280eb0c8..89e96c271 100644 --- a/src/characters/cbd-recipient.ts +++ b/src/characters/cbd-recipient.ts @@ -48,7 +48,7 @@ export class CbdTDecDecrypter { return new CbdTDecDecrypter( new Porter(porterUri), dkgRitual.id, - dkgRitual.threshold + dkgRitual.dkgParams.threshold ); } diff --git a/src/dkg.ts b/src/dkg.ts index 9108d105b..5de7ee4cd 100644 --- a/src/dkg.ts +++ b/src/dkg.ts @@ -13,8 +13,9 @@ import { } from '@nucypher/nucypher-core'; import { ethers } from 'ethers'; -import { DkgCoordinatorAgent } from './agents/coordinator'; -import { bytesEquals, fromHexString } from './utils'; +import { DkgCoordinatorAgent, DkgRitualState } from './agents/coordinator'; +import { ChecksumAddress } from './types'; +import { bytesEquals, fromHexString, objectEquals } from './utils'; // TODO: Expose from @nucypher/nucypher-core export enum FerveoVariant { @@ -50,33 +51,47 @@ export function getCombineDecryptionSharesFunction( } } +export type DkgRitualParameters = { + sharesNum: number; + threshold: number; +}; + export interface DkgRitualJSON { id: number; dkgPublicKey: Uint8Array; - threshold: number; + dkgParams: DkgRitualParameters; + state: DkgRitualState; } export class DkgRitual { constructor( public readonly id: number, public readonly dkgPublicKey: DkgPublicKey, - public readonly threshold: number + public readonly dkgParams: DkgRitualParameters, + public readonly state: DkgRitualState ) {} public toObj(): DkgRitualJSON { return { id: this.id, dkgPublicKey: this.dkgPublicKey.toBytes(), - threshold: this.threshold, + dkgParams: this.dkgParams, + state: this.state, }; } public static fromObj({ id, dkgPublicKey, - threshold, + dkgParams, + state, }: DkgRitualJSON): DkgRitual { - return new DkgRitual(id, DkgPublicKey.fromBytes(dkgPublicKey), threshold); + return new DkgRitual( + id, + DkgPublicKey.fromBytes(dkgPublicKey), + dkgParams, + state + ); } public equals(other: DkgRitual): boolean { @@ -84,32 +99,85 @@ export class DkgRitual { this.id === other.id && // TODO: Replace with `equals` after https://github.com/nucypher/nucypher-core/issues/56 is fixed bytesEquals(this.dkgPublicKey.toBytes(), other.dkgPublicKey.toBytes()) && - this.threshold === other.threshold + objectEquals(this.dkgParams, other.dkgParams) && + this.state === other.state ); } } +// TODO: Currently, we're assuming that the threshold is always `floor(sharesNum / 2) + 1`. +// https://github.com/nucypher/nucypher/issues/3095 +const assumedThreshold = (sharesNum: number): number => + Math.floor(sharesNum / 2) + 1; + export class DkgClient { - // TODO: Update API: Replace with getExistingRitual and support ritualId in Strategy public static async initializeRitual( web3Provider: ethers.providers.Web3Provider, - ritualParams: { - shares: number; - threshold: number; + ursulas: ChecksumAddress[], + waitUntilEnd = false + ): Promise { + const ritualId = await DkgCoordinatorAgent.initializeRitual( + web3Provider, + ursulas + ); + + if (waitUntilEnd) { + const isSuccessful = await DkgClient.waitUntilRitualEnd( + web3Provider, + ritualId + ); + if (!isSuccessful) { + const ritualState = await DkgCoordinatorAgent.getRitualState( + web3Provider, + ritualId + ); + throw new Error( + `Ritual initialization failed. Ritual id ${ritualId} is in state ${ritualState}` + ); + } } + + return this.getExistingRitual(web3Provider, ritualId); + } + + private static waitUntilRitualEnd = async ( + web3Provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise => { + return new Promise((resolve, reject) => { + const callback = (successful: boolean) => { + if (successful) { + resolve(true); + } else { + reject(); + } + }; + DkgCoordinatorAgent.onRitualEndEvent(web3Provider, ritualId, callback); + }); + }; + + public static async getExistingRitual( + web3Provider: ethers.providers.Web3Provider, + ritualId: number ): Promise { - const ritualId = 2; + const ritualState = await DkgCoordinatorAgent.getRitualState( + web3Provider, + ritualId + ); const ritual = await DkgCoordinatorAgent.getRitual(web3Provider, ritualId); const dkgPkBytes = new Uint8Array([ ...fromHexString(ritual.publicKey.word0), ...fromHexString(ritual.publicKey.word1), ]); - - return { - id: ritualId, - dkgPublicKey: DkgPublicKey.fromBytes(dkgPkBytes), - threshold: ritualParams.threshold, - } as DkgRitual; + return new DkgRitual( + ritualId, + DkgPublicKey.fromBytes(dkgPkBytes), + { + sharesNum: ritual.dkgSize, + threshold: assumedThreshold(ritual.dkgSize), + }, + ritualState + ); } public static async verifyRitual( diff --git a/src/sdk/strategy/cbd-strategy.ts b/src/sdk/strategy/cbd-strategy.ts index 87d424348..2b6887371 100644 --- a/src/sdk/strategy/cbd-strategy.ts +++ b/src/sdk/strategy/cbd-strategy.ts @@ -32,13 +32,9 @@ export class CbdStrategy { public async deploy( web3Provider: ethers.providers.Web3Provider ): Promise { - const dkgRitualParams = { - threshold: this.cohort.configuration.threshold, - shares: this.cohort.configuration.shares, - }; const dkgRitual = await DkgClient.initializeRitual( web3Provider, - dkgRitualParams + this.cohort.ursulaAddresses ); return DeployedCbdStrategy.create(this.cohort, dkgRitual); } diff --git a/test/unit/cbd-strategy.test.ts b/test/unit/cbd-strategy.test.ts index 801379ae1..9723b328d 100644 --- a/test/unit/cbd-strategy.test.ts +++ b/test/unit/cbd-strategy.test.ts @@ -52,7 +52,7 @@ async function makeDeployedCbdStrategy() { const strategy = await makeCbdStrategy(); const mockedDkg = fakeDkgFlow(variant, 0, 4, 4); - const mockedDkgRitual = fakeDkgRitual(mockedDkg, mockedDkg.threshold); + const mockedDkgRitual = fakeDkgRitual(mockedDkg); const web3Provider = fakeWeb3Provider(aliceSecretKey.toBEBytes()); const getUrsulasSpy = mockGetUrsulas(ursulas); const initializeRitualSpy = mockInitializeRitual(mockedDkgRitual); diff --git a/test/utils.ts b/test/utils.ts index c84da1f18..4ffaa1ba2 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -1,14 +1,25 @@ -// Disabling some of the eslint rules for conveninence. +// Disabling some of the eslint rules for convenience. /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/no-unused-vars */ import { Block } from '@ethersproject/providers'; import { + AggregatedTranscript, Capsule, CapsuleFrag, + Ciphertext, + combineDecryptionSharesPrecomputed, + combineDecryptionSharesSimple, + DecryptionSharePrecomputed, + DecryptionShareSimple, + decryptWithSharedSecret, + Dkg, EncryptedThresholdDecryptionResponse, EncryptedTreasureMap, + EthereumAddress, ferveoEncrypt, FerveoPublicKey, + Keypair, PublicKey, reencrypt, SecretKey, @@ -16,23 +27,11 @@ import { SessionStaticKey, SessionStaticSecret, ThresholdDecryptionResponse, - VerifiedCapsuleFrag, - VerifiedKeyFrag, -} from '@nucypher/nucypher-core'; -import { - AggregatedTranscript, - Ciphertext, - combineDecryptionSharesPrecomputed, - combineDecryptionSharesSimple, - DecryptionSharePrecomputed, - DecryptionShareSimple, - decryptWithSharedSecret, - Dkg, - EthereumAddress, - Keypair, Transcript, Validator, ValidatorMessage, + VerifiedCapsuleFrag, + VerifiedKeyFrag, } from '@nucypher/nucypher-core'; import axios from 'axios'; import { ethers, providers, Wallet } from 'ethers'; @@ -501,19 +500,26 @@ export const mockRandomSessionStaticSecret = (secret: SessionStaticSecret) => { export const fakeRitualId = 0; -export const fakeDkgRitual = (ritual: { dkg: Dkg }, threshold: number) => { - return new DkgRitual(fakeRitualId, ritual.dkg.publicKey(), threshold); +export const fakeDkgRitual = (ritual: { + dkg: Dkg; + sharesNum: number; + threshold: number; +}) => { + return new DkgRitual( + fakeRitualId, + ritual.dkg.publicKey(), + { + sharesNum: ritual.sharesNum, + threshold: ritual.threshold, + }, + DkgRitualState.FINALIZED + ); }; -export const mockInitializeRitual = (fakeRitual: unknown) => { - return ( - jest - .spyOn(DkgClient, 'initializeRitual') - // eslint-disable-next-line @typescript-eslint/no-unused-vars - .mockImplementation((_web3Provider, _ritualParams) => { - return Promise.resolve(fakeRitual) as Promise; - }) - ); +export const mockInitializeRitual = (dkgRitual: DkgRitual) => { + return jest.spyOn(DkgClient, 'initializeRitual').mockImplementation(() => { + return Promise.resolve(dkgRitual); + }); }; export const makeCohort = async (ursulas: Ursula[]) => { @@ -529,24 +535,21 @@ export const makeCohort = async (ursulas: Ursula[]) => { }; export const mockGetRitualState = (state = DkgRitualState.FINALIZED) => { - return jest.spyOn(DkgCoordinatorAgent, 'getRitualState').mockImplementation( - // eslint-disable-next-line @typescript-eslint/no-unused-vars - (_provider, _ritualId) => Promise.resolve(state) - ); + return jest + .spyOn(DkgCoordinatorAgent, 'getRitualState') + .mockImplementation((_provider, _ritualId) => Promise.resolve(state)); }; export const mockVerifyRitual = (isValid = true) => { - return jest.spyOn(DkgClient, 'verifyRitual').mockImplementation( - // eslint-disable-next-line @typescript-eslint/no-unused-vars - (_provider, _ritualId) => Promise.resolve(isValid) - ); + return jest + .spyOn(DkgClient, 'verifyRitual') + .mockImplementation((_provider, _ritualId) => Promise.resolve(isValid)); }; export const mockGetParticipantPublicKey = (pk = fakeFerveoPublicKey()) => { - return jest.spyOn(DkgClient, 'getParticipantPublicKey').mockImplementation( - // eslint-disable-next-line @typescript-eslint/no-unused-vars - (_address) => pk - ); + return jest + .spyOn(DkgClient, 'getParticipantPublicKey') + .mockImplementation((_address) => pk); }; export const fakeFerveoPublicKey = (): FerveoPublicKey => {