diff --git a/src/agents/coordinator.ts b/src/agents/coordinator.ts index 264e4ea50..198550c4b 100644 --- a/src/agents/coordinator.ts +++ b/src/agents/coordinator.ts @@ -1,4 +1,4 @@ -import { SessionStaticKey } from '@nucypher/nucypher-core'; +import { SessionStaticKey, Transcript } from '@nucypher/nucypher-core'; import { ethers } from 'ethers'; import { @@ -6,9 +6,10 @@ import { Coordinator__factory, } from '../../types/ethers-contracts'; import { BLS12381 } from '../../types/ethers-contracts/Coordinator'; +import { ChecksumAddress } from '../types'; import { fromHexString } from '../utils'; -import { getContract } from './contracts'; +import { DEFAULT_WAIT_N_CONFIRMATIONS, getContract } from './contracts'; export interface CoordinatorRitual { initiator: string; @@ -24,12 +25,22 @@ export interface CoordinatorRitual { export type DkgParticipant = { provider: string; aggregated: boolean; + transcript: Transcript; decryptionRequestStaticKey: SessionStaticKey; }; +export enum DkgRitualState { + NON_INITIATED, + AWAITING_TRANSCRIPTS, + AWAITING_AGGREGATIONS, + TIMEOUT, + INVALID, + FINALIZED, +} + export class DkgCoordinatorAgent { public static async getParticipants( - provider: ethers.providers.Provider, + provider: ethers.providers.Web3Provider, ritualId: number ): Promise { const Coordinator = await this.connectReadOnly(provider); @@ -39,6 +50,7 @@ export class DkgCoordinatorAgent { return { provider: participant.provider, aggregated: participant.aggregated, + transcript: Transcript.fromBytes(fromHexString(participant.transcript)), decryptionRequestStaticKey: SessionStaticKey.fromBytes( fromHexString(participant.decryptionRequestStaticKey) ), @@ -46,20 +58,85 @@ export class DkgCoordinatorAgent { }); } + public static async initializeRitual( + provider: ethers.providers.Web3Provider, + providers: ChecksumAddress[] + ): Promise { + const Coordinator = await this.connectReadWrite(provider); + const tx = await Coordinator.initiateRitual(providers); + const txReceipt = await tx.wait(DEFAULT_WAIT_N_CONFIRMATIONS); + const [ritualStartEvent] = txReceipt.events ?? []; + if (!ritualStartEvent) { + throw new Error('Ritual start event not found'); + } + return ritualStartEvent.args?.ritualId; + } + public static async getRitual( - provider: ethers.providers.Provider, + provider: ethers.providers.Web3Provider, ritualId: number ): Promise { const Coordinator = await this.connectReadOnly(provider); return Coordinator.rituals(ritualId); } - private static async connectReadOnly(provider: ethers.providers.Provider) { + public static async getTimeout( + provider: ethers.providers.Web3Provider + ): Promise { + const Coordinator = await this.connectReadOnly(provider); + const timeout = await Coordinator.timeout(); + return timeout; + } + + public static async getRitualState( + provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise { + const Coordinator = await this.connectReadOnly(provider); + return await Coordinator.getRitualState(ritualId); + } + + public static async getRitualInitTime( + provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise { + const Coordinator = await this.connectReadOnly(provider); + const ritual = await Coordinator.rituals(ritualId); + return ritual[2]; + } + + public static async onRitualEndEvent( + provider: ethers.providers.Web3Provider, + ritualId: number, + callback: (successful: boolean) => void + ): Promise { + const Coordinator = await this.connectReadOnly(provider); + // We leave `initiator` undefined because we don't care who the initiator is + // We leave `successful` undefined because we don't care if the ritual was successful + const eventFilter = Coordinator.filters.EndRitual( + ritualId, + undefined, + undefined + ); + Coordinator.once(eventFilter, (_ritualId, _initiator, successful) => { + callback(successful); + }); + } + + private static async connectReadOnly( + provider: ethers.providers.Web3Provider + ) { return await this.connect(provider); } + private static async connectReadWrite( + web3Provider: ethers.providers.Web3Provider + ) { + return await this.connect(web3Provider, web3Provider.getSigner()); + } + private static async connect( - provider: ethers.providers.Provider, + provider: ethers.providers.Web3Provider, signer?: ethers.providers.JsonRpcSigner ): Promise { const network = await provider.getNetwork(); diff --git a/src/agents/subscription-manager.ts b/src/agents/subscription-manager.ts index ac8a5a3c5..f997508c8 100644 --- a/src/agents/subscription-manager.ts +++ b/src/agents/subscription-manager.ts @@ -48,7 +48,7 @@ export class PreSubscriptionManagerAgent { } public static async getPolicyCost( - provider: ethers.providers.Provider, + provider: ethers.providers.Web3Provider, size: number, startTimestamp: number, endTimestamp: number @@ -61,7 +61,9 @@ export class PreSubscriptionManagerAgent { ); } - private static async connectReadOnly(provider: ethers.providers.Provider) { + private static async connectReadOnly( + provider: ethers.providers.Web3Provider + ) { return await this.connect(provider); } @@ -72,7 +74,7 @@ export class PreSubscriptionManagerAgent { } private static async connect( - provider: ethers.providers.Provider, + provider: ethers.providers.Web3Provider, signer?: ethers.providers.JsonRpcSigner ): Promise { const network = await provider.getNetwork(); diff --git a/src/characters/cbd-recipient.ts b/src/characters/cbd-recipient.ts index cef8d360c..89e96c271 100644 --- a/src/characters/cbd-recipient.ts +++ b/src/characters/cbd-recipient.ts @@ -12,9 +12,14 @@ import { } from '@nucypher/nucypher-core'; import { ethers } from 'ethers'; -import { DkgCoordinatorAgent, DkgParticipant } from '../agents/coordinator'; +import { + DkgCoordinatorAgent, + DkgParticipant, + DkgRitualState, +} from '../agents/coordinator'; import { ConditionExpression } from '../conditions'; import { + DkgClient, DkgRitual, FerveoVariant, getCombineDecryptionSharesFunction, @@ -43,7 +48,7 @@ export class CbdTDecDecrypter { return new CbdTDecDecrypter( new Porter(porterUri), dkgRitual.id, - dkgRitual.threshold + dkgRitual.dkgParams.threshold ); } @@ -52,13 +57,15 @@ export class CbdTDecDecrypter { provider: ethers.providers.Web3Provider, conditionExpr: ConditionExpression, variant: FerveoVariant, - ciphertext: Ciphertext + ciphertext: Ciphertext, + verifyRitual = true ): Promise { const decryptionShares = await this.retrieve( provider, conditionExpr, variant, - ciphertext + ciphertext, + verifyRitual ); const combineDecryptionSharesFn = @@ -73,16 +80,39 @@ export class CbdTDecDecrypter { // Retrieve decryption shares public async retrieve( - provider: ethers.providers.Web3Provider, + web3Provider: ethers.providers.Web3Provider, conditionExpr: ConditionExpression, variant: number, - ciphertext: Ciphertext + ciphertext: Ciphertext, + verifyRitual = true ): Promise { + const ritualState = await DkgCoordinatorAgent.getRitualState( + web3Provider, + this.ritualId + ); + if (ritualState !== DkgRitualState.FINALIZED) { + throw new Error( + `Ritual with id ${this.ritualId} is not finalized. Ritual state is ${ritualState}.` + ); + } + + if (verifyRitual) { + const isLocallyVerified = await DkgClient.verifyRitual( + web3Provider, + this.ritualId + ); + if (!isLocallyVerified) { + throw new Error( + `Ritual with id ${this.ritualId} has failed local verification.` + ); + } + } + const dkgParticipants = await DkgCoordinatorAgent.getParticipants( - provider, + web3Provider, this.ritualId ); - const contextStr = await conditionExpr.buildContext(provider).toJson(); + const contextStr = await conditionExpr.buildContext(web3Provider).toJson(); const { sharedSecrets, encryptedRequests } = this.makeDecryptionRequests( this.ritualId, variant, diff --git a/src/dkg.ts b/src/dkg.ts index f0ee81e41..89eb44f1e 100644 --- a/src/dkg.ts +++ b/src/dkg.ts @@ -1,15 +1,21 @@ import { + AggregatedTranscript, combineDecryptionSharesPrecomputed, combineDecryptionSharesSimple, DecryptionSharePrecomputed, DecryptionShareSimple, DkgPublicKey, + EthereumAddress, + FerveoPublicKey, SharedSecret, + Validator, + ValidatorMessage, } from '@nucypher/nucypher-core'; import { ethers } from 'ethers'; -import { DkgCoordinatorAgent } from './agents/coordinator'; -import { bytesEquals, fromHexString } from './utils'; +import { DkgCoordinatorAgent, DkgRitualState } from './agents/coordinator'; +import { ChecksumAddress } from './types'; +import { bytesEquals, fromHexString, objectEquals } from './utils'; // TODO: Expose from @nucypher/nucypher-core export enum FerveoVariant { @@ -45,33 +51,47 @@ export function getCombineDecryptionSharesFunction( } } +export type DkgRitualParameters = { + sharesNum: number; + threshold: number; +}; + export interface DkgRitualJSON { id: number; dkgPublicKey: Uint8Array; - threshold: number; + dkgParams: DkgRitualParameters; + state: DkgRitualState; } export class DkgRitual { constructor( public readonly id: number, public readonly dkgPublicKey: DkgPublicKey, - public readonly threshold: number + public readonly dkgParams: DkgRitualParameters, + public readonly state: DkgRitualState ) {} public toObj(): DkgRitualJSON { return { id: this.id, dkgPublicKey: this.dkgPublicKey.toBytes(), - threshold: this.threshold, + dkgParams: this.dkgParams, + state: this.state, }; } public static fromObj({ id, dkgPublicKey, - threshold, + dkgParams, + state, }: DkgRitualJSON): DkgRitual { - return new DkgRitual(id, DkgPublicKey.fromBytes(dkgPublicKey), threshold); + return new DkgRitual( + id, + DkgPublicKey.fromBytes(dkgPublicKey), + dkgParams, + state + ); } public equals(other: DkgRitual): boolean { @@ -79,51 +99,171 @@ export class DkgRitual { this.id === other.id && // TODO: Replace with `equals` after https://github.com/nucypher/nucypher-core/issues/56 is fixed bytesEquals(this.dkgPublicKey.toBytes(), other.dkgPublicKey.toBytes()) && - this.threshold === other.threshold + objectEquals(this.dkgParams, other.dkgParams) && + this.state === other.state ); } } +// TODO: Currently, we're assuming that the threshold is always `floor(sharesNum / 2) + 1`. +// https://github.com/nucypher/nucypher/issues/3095 +const assumedThreshold = (sharesNum: number): number => + Math.floor(sharesNum / 2) + 1; + export class DkgClient { - constructor(private readonly provider: ethers.providers.Web3Provider) {} - - // TODO: Update API: Replace with getExistingRitual and support ritualId in Strategy - public async initializeRitual(ritualParams: { - shares: number; - threshold: number; - }): Promise { - const ritualId = 2; - const ritual = await DkgCoordinatorAgent.getRitual(this.provider, ritualId); + public static async initializeRitual( + web3Provider: ethers.providers.Web3Provider, + ursulas: ChecksumAddress[], + waitUntilEnd = false + ): Promise { + const ritualId = await DkgCoordinatorAgent.initializeRitual( + web3Provider, + ursulas.sort() + ); + + if (waitUntilEnd) { + const initTimestamp = await DkgCoordinatorAgent.getRitualInitTime( + web3Provider, + ritualId + ); + const timeout = await DkgCoordinatorAgent.getTimeout(web3Provider); + const endTimestamp = initTimestamp + timeout; + + // Wait until the current time is past the endTime + while (Math.floor(Date.now() / 1000) < endTimestamp) { + await new Promise((resolve) => setTimeout(resolve, 1000)); // Wait for 1 second before checking again + } + + await this.waitForBlockTime(web3Provider, endTimestamp); + + try { + this.performRitual(web3Provider, ritualId); + } catch (error) { + const ritualState = await DkgCoordinatorAgent.getRitualState( + web3Provider, + ritualId + ); + + throw new Error( + `Ritual initialization failed. Ritual id ${ritualId} is in state ${ritualState}` + ); + } + } + + return ritualId; + } + + private static performRitual = async ( + web3Provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise => { + const isSuccessful = await DkgClient.waitUntilRitualEnd( + web3Provider, + ritualId + ); + + if (!isSuccessful) { + throw new Error(`Ritual initialization failed. Ritual id ${ritualId}`); + } + }; + + private static waitForBlockTime = async ( + web3Provider: ethers.providers.Web3Provider, + endTimestamp: number + ): Promise => { + let currentBlockTime; + do { + const block = await web3Provider.getBlock('latest'); + currentBlockTime = block.timestamp; + if (currentBlockTime < endTimestamp) { + await new Promise((resolve) => setTimeout(resolve, 1000)); // Wait for 1 second before checking again + } + } while (currentBlockTime < endTimestamp); + }; + + private static waitUntilRitualEnd = async ( + web3Provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise => { + return new Promise((resolve, reject) => { + const callback = (successful: boolean) => { + if (successful) { + resolve(true); + } else { + reject(); + } + }; + DkgCoordinatorAgent.onRitualEndEvent(web3Provider, ritualId, callback); + }); + }; + + public static async getExistingRitual( + web3Provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise { + const ritualState = await DkgCoordinatorAgent.getRitualState( + web3Provider, + ritualId + ); + const ritual = await DkgCoordinatorAgent.getRitual(web3Provider, ritualId); const dkgPkBytes = new Uint8Array([ ...fromHexString(ritual.publicKey.word0), ...fromHexString(ritual.publicKey.word1), ]); + return new DkgRitual( + ritualId, + DkgPublicKey.fromBytes(dkgPkBytes), + { + sharesNum: ritual.dkgSize, + threshold: assumedThreshold(ritual.dkgSize), + }, + ritualState + ); + } - return { - id: ritualId, - dkgPublicKey: DkgPublicKey.fromBytes(dkgPkBytes), - threshold: ritualParams.threshold, - } as DkgRitual; + public static async verifyRitual( + web3Provider: ethers.providers.Web3Provider, + ritualId: number + ): Promise { + const ritual = await DkgCoordinatorAgent.getRitual(web3Provider, ritualId); + const participants = await DkgCoordinatorAgent.getParticipants( + web3Provider, + ritualId + ); + + const validatorMessages = participants.map((p) => { + const validatorAddress = EthereumAddress.fromString(p.provider); + // TODO: Replace with real keys + // const publicKey = FerveoPublicKey.fromBytes(fromHexString(p.???)); + const publicKey = DkgClient.getParticipantPublicKey(p.provider); + const validator = new Validator(validatorAddress, publicKey); + return new ValidatorMessage(validator, p.transcript); + }); + const aggregate = new AggregatedTranscript(validatorMessages); + + return aggregate.verify(ritual.dkgSize, validatorMessages); } - // TODO: Without Validator public key in Coordinator, we cannot verify the - // transcript. We need to add it to the Coordinator (nucypher-contracts #77). - // public async verifyRitual(ritualId: number): Promise { - // const ritual = await DkgCoordinatorAgent.getRitual(this.provider, ritualId); - // const participants = await DkgCoordinatorAgent.getParticipants( - // this.provider, - // ritualId - // ); - // - // const validatorMessages = participants.map((p) => { - // const validatorAddress = EthereumAddress.fromString(p.provider); - // const publicKey = FerveoPublicKey.fromBytes(fromHexString(p.???)); - // const validator = new Validator(validatorAddress, publicKey); - // const transcript = Transcript.fromBytes(fromHexString(p.transcript)); - // return new ValidatorMessage(validator, transcript); - // }); - // const aggregate = new AggregatedTranscript(validatorMessages); - // - // return aggregate.verify(ritual.dkgSize, validatorMessages); - // } + public static getParticipantPublicKey = (address: string) => { + // TODO: Without Validator public key in Coordinator, we cannot verify the + // transcript. We need to add it to the Coordinator (nucypher-contracts #77). + const participantPublicKeys: Record = { + '0x210eeAC07542F815ebB6FD6689637D8cA2689392': FerveoPublicKey.fromBytes( + fromHexString( + '6000000000000000ace9d7567b26dafc512b2303cfdaa872850c62b100078ddeaabf8408c7308b3a43dfeb88375c21ef63230fb4008ce7e908764463c6765e556f9b03009eb1757d179eaa26bf875332807cc070d62a385ed2e66e09f4f4766451da12779a09036e' + ) + ), + '0xb15d5A4e2be34f4bE154A1b08a94Ab920FfD8A41': FerveoPublicKey.fromBytes( + fromHexString( + '60000000000000008b373fdb6b43e9dca028bd603c2bf90f0e008ec83ff217a8d7bc006b585570e6ab1ce761bad0e21c1aed1363286145f61134ed0ab53f4ebaa05036396c57f6e587f33d49667c1003cd03b71ad651b09dd4791bc631eaef93f1b313bbee7bd63a' + ) + ), + }; + + const publicKey = participantPublicKeys[address]; + if (!publicKey) { + throw new Error(`No public key for participant: ${address}`); + } + return publicKey; + }; } diff --git a/src/policies/policy.ts b/src/policies/policy.ts index 1ccca6a90..b17d6b81f 100644 --- a/src/policies/policy.ts +++ b/src/policies/policy.ts @@ -118,6 +118,11 @@ export class BlockchainPolicy { public async generatePreEnactedPolicy( ursulas: readonly Ursula[] ): Promise { + if (ursulas.length != this.verifiedKFrags.length) { + throw new Error( + `Number of ursulas must match number of verified kFrags: ${this.verifiedKFrags.length}` + ); + } const treasureMap = this.makeTreasureMap(ursulas, this.verifiedKFrags); const encryptedTreasureMap = this.encryptTreasureMap(treasureMap); // const revocationKit = new RevocationKit(treasureMap, this.publisher.signer); diff --git a/src/sdk/cohort.ts b/src/sdk/cohort.ts index 8571f3816..070027186 100644 --- a/src/sdk/cohort.ts +++ b/src/sdk/cohort.ts @@ -26,6 +26,15 @@ export class Cohort { include: string[] = [], exclude: string[] = [] ) { + if (configuration.threshold > configuration.shares) { + throw new Error('Threshold cannot be greater than the number of shares'); + } + // TODO: Remove this limitation after `nucypher-core@0.11.0` deployment + const isMultipleOf2 = (n: number) => n % 2 === 0; + if (!isMultipleOf2(configuration.shares)) { + throw new Error('Number of shares must be a multiple of 2'); + } + const porter = new Porter(configuration.porterUri); const ursulas = await porter.getUrsulas( configuration.shares, diff --git a/src/sdk/strategy/cbd-strategy.ts b/src/sdk/strategy/cbd-strategy.ts index 0632eeaef..dbc2a6766 100644 --- a/src/sdk/strategy/cbd-strategy.ts +++ b/src/sdk/strategy/cbd-strategy.ts @@ -30,14 +30,21 @@ export class CbdStrategy { } public async deploy( - provider: ethers.providers.Web3Provider + web3Provider: ethers.providers.Web3Provider, + ritualId?: number ): Promise { - const dkgRitualParams = { - threshold: this.cohort.configuration.threshold, - shares: this.cohort.configuration.shares, - }; - const dkgClient = new DkgClient(provider); - const dkgRitual = await dkgClient.initializeRitual(dkgRitualParams); + if (ritualId === undefined) { + ritualId = await DkgClient.initializeRitual( + web3Provider, + this.cohort.ursulaAddresses, + true + ); + } + if (ritualId === undefined) { + // Given that we just initialized the ritual, this should never happen + throw new Error('Ritual ID is undefined'); + } + const dkgRitual = await DkgClient.getExistingRitual(web3Provider, ritualId); return DeployedCbdStrategy.create(this.cohort, dkgRitual); } diff --git a/test/acceptance/alice-grants.test.ts b/test/acceptance/alice-grants.test.ts index a5fce2365..5460d43b3 100644 --- a/test/acceptance/alice-grants.test.ts +++ b/test/acceptance/alice-grants.test.ts @@ -31,7 +31,7 @@ describe('story: alice shares message with bob through policy', () => { const shares = 3; const startDate = new Date(); const endDate = new Date(Date.now() + 60 * 1000); - const mockedUrsulas = fakeUrsulas().slice(0, shares); + const mockedUrsulas = fakeUrsulas(shares); // Intermediate variables used for mocking let encryptedTreasureMap: EncryptedTreasureMap; diff --git a/test/acceptance/delay-enact.test.ts b/test/acceptance/delay-enact.test.ts index 61574df95..857e0e9f0 100644 --- a/test/acceptance/delay-enact.test.ts +++ b/test/acceptance/delay-enact.test.ts @@ -14,7 +14,7 @@ describe('story: alice1 creates a policy but alice2 enacts it', () => { const shares = 3; const startDate = new Date(); const endDate = new Date(Date.now() + 60 * 1000); // 60s later - const mockedUrsulas = fakeUrsulas().slice(0, shares); + const mockedUrsulas = fakeUrsulas(shares); const label = 'fake-data-label'; it('alice generates a new policy', async () => { diff --git a/test/docs/cbd.test.ts b/test/docs/cbd.test.ts index cac579632..7538ba827 100644 --- a/test/docs/cbd.test.ts +++ b/test/docs/cbd.test.ts @@ -61,8 +61,8 @@ describe('Get Started (CBD PoC)', () => { // 2. Build a Cohort const config = { - threshold: 3, - shares: 5, + threshold: 2, + shares: 4, porterUri: 'https://porter-tapir.nucypher.community', }; const newCohort = await Cohort.create(config); diff --git a/test/integration/dkg-client.test.ts b/test/integration/dkg-client.test.ts index acb8337b6..64cc59a75 100644 --- a/test/integration/dkg-client.test.ts +++ b/test/integration/dkg-client.test.ts @@ -1,18 +1,25 @@ import { SecretKey } from '@nucypher/nucypher-core'; -import { DkgCoordinatorAgent } from '../../src/agents/coordinator'; +import { + DkgCoordinatorAgent, + DkgRitualState, +} from '../../src/agents/coordinator'; +import { DkgClient } from '../../src/dkg'; import { fakeCoordinatorRitual, fakeDkgParticipants, fakeRitualId, fakeWeb3Provider, + mockGetParticipantPublicKey, mockGetParticipants, + mockVerifyRitual, } from '../utils'; jest.mock('../../src/agents/coordinator', () => ({ DkgCoordinatorAgent: { getRitual: () => Promise.resolve(fakeCoordinatorRitual(fakeRitualId)), - getParticipants: () => Promise.resolve(fakeDkgParticipants(fakeRitualId)), + getParticipants: () => + Promise.resolve(fakeDkgParticipants(fakeRitualId).participants), }, })); @@ -42,13 +49,134 @@ describe('DkgCoordinatorAgent', () => { }); }); -// TODO: Fix this test after the DkgClient.verifyRitual() method is implemented -// describe('DkgClient', () => { -// it('verifies the dkg ritual', async () => { -// const provider = fakeWeb3Provider(SecretKey.random().toBEBytes()); -// -// const dkgClient = new DkgClient(provider); -// const isValid = await dkgClient.verifyRitual(fakeRitualId); -// expect(isValid).toBeTruthy(); -// }); -// }); +describe('DkgClient', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('verifies the dkg ritual', async () => { + const provider = fakeWeb3Provider(SecretKey.random().toBEBytes()); + const verifyRitualSpy = mockVerifyRitual(); + + const isValid = await DkgClient.verifyRitual(provider, fakeRitualId); + expect(isValid).toBeTruthy(); + expect(verifyRitualSpy).toHaveBeenCalled(); + }); + + it('rejects on missing participant pk', async () => { + const provider = fakeWeb3Provider(SecretKey.random().toBEBytes()); + + await expect(async () => + DkgClient.verifyRitual(provider, fakeRitualId) + ).rejects.toThrow( + 'No public key for participant: 0x0000000000000000000000000000000000000000' + ); + }); + + it('rejects on bad participant pk', async () => { + const provider = fakeWeb3Provider(SecretKey.random().toBEBytes()); + const getParticipantPublicKeysSpy = mockGetParticipantPublicKey(); + + await expect(async () => + DkgClient.verifyRitual(provider, fakeRitualId) + ).rejects.toThrow( + "Transcript aggregate doesn't match the received PVSS instances" + ); + expect(getParticipantPublicKeysSpy).toHaveBeenCalled(); + }); + + it('waits until the ritual end time during initialization', async () => { + jest.useFakeTimers(); + const fakeProvider = fakeWeb3Provider(SecretKey.random().toBEBytes()); + const fakeUrsulas = ['ursula1', 'ursula2', 'ursula3']; + const fakeRitualId = 123; + const initTimestamp = Math.floor(Date.now() / 1000); + const timeout = 10; + + jest + .spyOn(DkgCoordinatorAgent, 'initializeRitual') + .mockResolvedValue(fakeRitualId); + + jest + .spyOn(DkgCoordinatorAgent, 'getRitualInitTime') + .mockResolvedValue(initTimestamp); + + jest.spyOn(DkgCoordinatorAgent, 'getTimeout').mockResolvedValue(timeout); + + jest.spyOn(DkgClient as any, 'performRitual').mockResolvedValue(undefined); + + const promise = DkgClient.initializeRitual(fakeProvider, fakeUrsulas, true); + + jest.advanceTimersByTime(timeout * 1000); + + await expect(promise).resolves.toBe(fakeRitualId); + + expect(DkgCoordinatorAgent.initializeRitual).toHaveBeenCalledWith( + fakeProvider, + fakeUrsulas + ); + + expect(DkgCoordinatorAgent.getRitualInitTime).toHaveBeenCalledWith( + fakeProvider, + fakeRitualId + ); + + expect(DkgCoordinatorAgent.getTimeout).toHaveBeenCalledWith(fakeProvider); + + expect((DkgClient as any).performRitual).toHaveBeenCalledWith( + fakeProvider, + fakeRitualId + ); + }); + + it('throws an error when initialization times out', async () => { + jest.useFakeTimers(); + const fakeProvider = fakeWeb3Provider(SecretKey.random().toBEBytes()); + const fakeUrsulas = ['ursula1', 'ursula2', 'ursula3']; + const fakeRitualId = 123; + const initTimestamp = Math.floor(Date.now() / 1000); + const timeout = 10; + + jest + .spyOn(DkgCoordinatorAgent, 'initializeRitual') + .mockResolvedValue(fakeRitualId); + + jest + .spyOn(DkgCoordinatorAgent, 'getRitualInitTime') + .mockResolvedValue(initTimestamp); + + jest.spyOn(DkgCoordinatorAgent, 'getTimeout').mockResolvedValue(timeout); + + const performRitualSpy = jest + .spyOn(DkgClient as any, 'performRitual') + .mockRejectedValue( + new Error(`Ritual initialization failed. Ritual id ${fakeRitualId}`) + ); + + jest + .spyOn(DkgCoordinatorAgent, 'getRitualState') + .mockResolvedValue(DkgRitualState.TIMEOUT); + + const promise = DkgClient.initializeRitual(fakeProvider, fakeUrsulas, true); + + jest.advanceTimersByTime(timeout * 1000); + + await expect(promise).rejects.toThrow( + `Ritual initialization failed. Ritual id ${fakeRitualId} is in state TIMEOUT` + ); + + expect(DkgCoordinatorAgent.initializeRitual).toHaveBeenCalledWith( + fakeProvider, + fakeUrsulas + ); + + expect(DkgCoordinatorAgent.getRitualInitTime).toHaveBeenCalledWith( + fakeProvider, + fakeRitualId + ); + + expect(DkgCoordinatorAgent.getTimeout).toHaveBeenCalledWith(fakeProvider); + + expect(performRitualSpy).toHaveBeenCalledWith(fakeProvider, fakeRitualId); + }); +}); diff --git a/test/integration/pre.test.ts b/test/integration/pre.test.ts index e3087e16c..6546228f5 100644 --- a/test/integration/pre.test.ts +++ b/test/integration/pre.test.ts @@ -15,7 +15,7 @@ describe('proxy reencryption', () => { const plaintext = toBytes('plaintext-message'); const threshold = 2; const shares = 3; - const ursulas = fakeUrsulas().slice(0, shares); + const ursulas = fakeUrsulas(shares); const label = 'fake-data-label'; const alice = fakeAlice(); const bob = fakeBob(); diff --git a/test/unit/cbd-strategy.test.ts b/test/unit/cbd-strategy.test.ts index 36a8f216b..7f6f5b720 100644 --- a/test/unit/cbd-strategy.test.ts +++ b/test/unit/cbd-strategy.test.ts @@ -14,10 +14,13 @@ import { fakeWeb3Provider, makeCohort, mockCbdDecrypt, + mockGetExistingRitual, mockGetParticipants, + mockGetRitualState, mockGetUrsulas, mockInitializeRitual, mockRandomSessionStaticSecret, + mockVerifyRitual, } from '../utils'; import { aliceSecretKeyBytes } from './testVariables'; @@ -36,8 +39,9 @@ const ownsNFT = new ERC721Ownership({ chain: 5, }); const conditionExpr = new ConditionExpression(ownsNFT); -const ursulas = fakeUrsulas().slice(0, 3); +const ursulas = fakeUrsulas(); const variant = FerveoVariant.Precomputed; +const ritualId = 0; const makeCbdStrategy = async () => { const cohort = await makeCohort(ursulas); @@ -50,14 +54,16 @@ async function makeDeployedCbdStrategy() { const strategy = await makeCbdStrategy(); const mockedDkg = fakeDkgFlow(variant, 0, 4, 4); - const mockedDkgRitual = fakeDkgRitual(mockedDkg, mockedDkg.threshold); + const mockedDkgRitual = fakeDkgRitual(mockedDkg); const web3Provider = fakeWeb3Provider(aliceSecretKey.toBEBytes()); const getUrsulasSpy = mockGetUrsulas(ursulas); - const initializeRitualSpy = mockInitializeRitual(mockedDkgRitual); + const initializeRitualSpy = mockInitializeRitual(ritualId); + const getExistingRitualSpy = mockGetExistingRitual(mockedDkgRitual); const deployedStrategy = await strategy.deploy(web3Provider); expect(getUrsulasSpy).toHaveBeenCalled(); expect(initializeRitualSpy).toHaveBeenCalled(); + expect(getExistingRitualSpy).toHaveBeenCalled(); return { mockedDkg, deployedStrategy }; } @@ -127,6 +133,8 @@ describe('CbdDeployedStrategy', () => { const getParticipantsSpy = mockGetParticipants(participants); const getUrsulasSpy = mockGetUrsulas(ursulas); const sessionKeySpy = mockRandomSessionStaticSecret(requesterSessionKey); + const getRitualStateSpy = mockGetRitualState(); + const verifyRitualSpy = mockVerifyRitual(); const decryptedMessage = await deployedStrategy.decrypter.retrieveAndDecrypt( @@ -135,6 +143,8 @@ describe('CbdDeployedStrategy', () => { variant, ciphertext ); + expect(getRitualStateSpy).toHaveBeenCalled(); + expect(verifyRitualSpy).toHaveBeenCalled(); expect(getUrsulasSpy).toHaveBeenCalled(); expect(getParticipantsSpy).toHaveBeenCalled(); expect(sessionKeySpy).toHaveBeenCalled(); diff --git a/test/unit/cohort.test.ts b/test/unit/cohort.test.ts index 40ab18ee8..e041bcad8 100644 --- a/test/unit/cohort.test.ts +++ b/test/unit/cohort.test.ts @@ -2,7 +2,7 @@ import { Cohort } from '../../src'; import { fakeUrsulas, makeCohort } from '../utils'; describe('Cohort', () => { - const mockedUrsulas = fakeUrsulas().slice(0, 3); + const mockedUrsulas = fakeUrsulas(); it('creates a Cohort', async () => { const cohort = await makeCohort(mockedUrsulas); diff --git a/test/unit/pre-strategy.test.ts b/test/unit/pre-strategy.test.ts index 9fa599a22..6a5e6d2f3 100644 --- a/test/unit/pre-strategy.test.ts +++ b/test/unit/pre-strategy.test.ts @@ -38,7 +38,7 @@ const ownsNFT = new ERC721Ownership({ chain: 5, }); const conditionExpr = new ConditionExpression(ownsNFT); -const mockedUrsulas = fakeUrsulas().slice(0, 3); +const mockedUrsulas = fakeUrsulas(); const makePreStrategy = async () => { const cohort = await makeCohort(mockedUrsulas); diff --git a/test/utils.ts b/test/utils.ts index 46dce7df5..694b2f6fc 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -1,13 +1,25 @@ -// Disabling some of the eslint rules for conveninence. +// Disabling some of the eslint rules for convenience. /* eslint-disable @typescript-eslint/no-non-null-assertion */ /* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/no-unused-vars */ import { Block } from '@ethersproject/providers'; import { + AggregatedTranscript, Capsule, CapsuleFrag, + Ciphertext, + combineDecryptionSharesPrecomputed, + combineDecryptionSharesSimple, + DecryptionSharePrecomputed, + DecryptionShareSimple, + decryptWithSharedSecret, + Dkg, EncryptedThresholdDecryptionResponse, EncryptedTreasureMap, + EthereumAddress, ferveoEncrypt, + FerveoPublicKey, + Keypair, PublicKey, reencrypt, SecretKey, @@ -15,30 +27,22 @@ import { SessionStaticKey, SessionStaticSecret, ThresholdDecryptionResponse, - VerifiedCapsuleFrag, - VerifiedKeyFrag, -} from '@nucypher/nucypher-core'; -import { - AggregatedTranscript, - Ciphertext, - combineDecryptionSharesPrecomputed, - combineDecryptionSharesSimple, - DecryptionSharePrecomputed, - DecryptionShareSimple, - decryptWithSharedSecret, - Dkg, - EthereumAddress, - Keypair, Transcript, Validator, ValidatorMessage, + VerifiedCapsuleFrag, + VerifiedKeyFrag, } from '@nucypher/nucypher-core'; import axios from 'axios'; import { ethers, providers, Wallet } from 'ethers'; import { keccak256 } from 'ethers/lib/utils'; import { Alice, Bob, Cohort, Configuration, RemoteBob } from '../src'; -import { DkgCoordinatorAgent, DkgParticipant } from '../src/agents/coordinator'; +import { + DkgCoordinatorAgent, + DkgParticipant, + DkgRitualState, +} from '../src/agents/coordinator'; import { CbdTDecDecrypter } from '../src/characters/cbd-recipient'; import { CbdDecryptResult, @@ -110,11 +114,12 @@ const genChecksumAddress = (i: number) => '0x' + '0'.repeat(40 - i.toString(16).length) + i.toString(16); const genEthAddr = (i: number) => EthereumAddress.fromString(genChecksumAddress(i)); -export const fakeUrsulas = (): readonly Ursula[] => - [0, 1, 2, 3, 4].map((i: number) => ({ +export const fakeUrsulas = (n = 4): Ursula[] => + // 0...n-1 + Array.from(Array(n).keys()).map((i: number) => ({ encryptingKey: SecretKey.random().publicKey(), checksumAddress: genChecksumAddress(i).toLowerCase(), - uri: 'https://example.a.com:9151', + uri: `https://example.${i}.com:9151`, })); export const mockGetUrsulas = (ursulas: readonly Ursula[]) => { @@ -496,26 +501,64 @@ export const mockRandomSessionStaticSecret = (secret: SessionStaticSecret) => { export const fakeRitualId = 0; -export const fakeDkgRitual = (ritual: { dkg: Dkg }, threshold: number) => { - return new DkgRitual(fakeRitualId, ritual.dkg.publicKey(), threshold); +export const fakeDkgRitual = (ritual: { + dkg: Dkg; + sharesNum: number; + threshold: number; +}) => { + return new DkgRitual( + fakeRitualId, + ritual.dkg.publicKey(), + { + sharesNum: ritual.sharesNum, + threshold: ritual.threshold, + }, + DkgRitualState.FINALIZED + ); }; -export const mockInitializeRitual = (fakeRitual: unknown) => { - return jest - .spyOn(DkgClient.prototype as any, 'initializeRitual') - .mockImplementation(() => { - return Promise.resolve(fakeRitual); - }); +export const mockInitializeRitual = (ritualId: number) => { + return jest.spyOn(DkgClient, 'initializeRitual').mockImplementation(() => { + return Promise.resolve(ritualId); + }); +}; + +export const mockGetExistingRitual = (dkgRitual: DkgRitual) => { + return jest.spyOn(DkgClient, 'getExistingRitual').mockImplementation(() => { + return Promise.resolve(dkgRitual); + }); }; export const makeCohort = async (ursulas: Ursula[]) => { const getUrsulasSpy = mockGetUrsulas(ursulas); const config = { threshold: 2, - shares: 3, + shares: ursulas.length, porterUri: 'https://_this.should.crash', }; const cohort = await Cohort.create(config); expect(getUrsulasSpy).toHaveBeenCalled(); return cohort; }; + +export const mockGetRitualState = (state = DkgRitualState.FINALIZED) => { + return jest + .spyOn(DkgCoordinatorAgent, 'getRitualState') + .mockImplementation((_provider, _ritualId) => Promise.resolve(state)); +}; + +export const mockVerifyRitual = (isValid = true) => { + return jest + .spyOn(DkgClient, 'verifyRitual') + .mockImplementation((_provider, _ritualId) => Promise.resolve(isValid)); +}; + +export const mockGetParticipantPublicKey = (pk = fakeFerveoPublicKey()) => { + return jest + .spyOn(DkgClient, 'getParticipantPublicKey') + .mockImplementation((_address) => pk); +}; + +export const fakeFerveoPublicKey = (): FerveoPublicKey => { + return Keypair.random().publicKey; +};