Skip to content

Commit

Permalink
feat! add ritual initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Jun 30, 2023
1 parent e0dadba commit b23b38f
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 51 deletions.
41 changes: 40 additions & 1 deletion src/agents/coordinator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import {
Coordinator__factory,
} from '../../types/ethers-contracts';
import { BLS12381 } from '../../types/ethers-contracts/Coordinator';
import { ChecksumAddress } from '../types';
import { fromHexString } from '../utils';

import { getContract } from './contracts';
import { DEFAULT_WAIT_N_CONFIRMATIONS, getContract } from './contracts';

export interface CoordinatorRitual {
initiator: string;
Expand Down Expand Up @@ -57,6 +58,20 @@ export class DkgCoordinatorAgent {
});
}

public static async initializeRitual(
provider: ethers.providers.Web3Provider,
providers: ChecksumAddress[]
): Promise<number> {
const Coordinator = await this.connectReadWrite(provider);
const tx = await Coordinator.initiateRitual(providers);
const txReceipt = await tx.wait(DEFAULT_WAIT_N_CONFIRMATIONS);
const [ritualStartEvent] = txReceipt.events ?? [];
if (!ritualStartEvent) {
throw new Error('Ritual start event not found');
}
return ritualStartEvent.args?.ritualId.toNumber();
}

public static async getRitual(
provider: ethers.providers.Web3Provider,
ritualId: number
Expand All @@ -73,12 +88,36 @@ export class DkgCoordinatorAgent {
return await Coordinator.getRitualState(ritualId);
}

public static async onRitualEndEvent(
provider: ethers.providers.Web3Provider,
ritualId: number,
callback: (successful: boolean) => void
): Promise<void> {
const Coordinator = await this.connectReadOnly(provider);
// We leave `initiator` undefined because we don't care who the initiator is
// We leave `successful` undefined because we don't care if the ritual was successful
const eventFilter = Coordinator.filters.EndRitual(
ritualId,
undefined,
undefined
);
Coordinator.once(eventFilter, (_ritualId, _initiator, successful) => {
callback(successful);
});
}

private static async connectReadOnly(
provider: ethers.providers.Web3Provider
) {
return await this.connect(provider);
}

private static async connectReadWrite(
web3Provider: ethers.providers.Web3Provider
) {
return await this.connect(web3Provider, web3Provider.getSigner());
}

private static async connect(
provider: ethers.providers.Web3Provider,
signer?: ethers.providers.JsonRpcSigner
Expand Down
2 changes: 1 addition & 1 deletion src/characters/cbd-recipient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class CbdTDecDecrypter {
return new CbdTDecDecrypter(
new Porter(porterUri),
dkgRitual.id,
dkgRitual.threshold
dkgRitual.dkgParams.threshold
);
}

Expand Down
106 changes: 87 additions & 19 deletions src/dkg.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import {
} from '@nucypher/nucypher-core';
import { ethers } from 'ethers';

import { DkgCoordinatorAgent } from './agents/coordinator';
import { bytesEquals, fromHexString } from './utils';
import { DkgCoordinatorAgent, DkgRitualState } from './agents/coordinator';
import { ChecksumAddress } from './types';
import { bytesEquals, fromHexString, objectEquals } from './utils';

// TODO: Expose from @nucypher/nucypher-core
export enum FerveoVariant {
Expand Down Expand Up @@ -50,66 +51,133 @@ export function getCombineDecryptionSharesFunction(
}
}

export type DkgRitualParameters = {
sharesNum: number;
threshold: number;
};

export interface DkgRitualJSON {
id: number;
dkgPublicKey: Uint8Array;
threshold: number;
dkgParams: DkgRitualParameters;
state: DkgRitualState;
}

export class DkgRitual {
constructor(
public readonly id: number,
public readonly dkgPublicKey: DkgPublicKey,
public readonly threshold: number
public readonly dkgParams: DkgRitualParameters,
public readonly state: DkgRitualState
) {}

public toObj(): DkgRitualJSON {
return {
id: this.id,
dkgPublicKey: this.dkgPublicKey.toBytes(),
threshold: this.threshold,
dkgParams: this.dkgParams,
state: this.state,
};
}

public static fromObj({
id,
dkgPublicKey,
threshold,
dkgParams,
state,
}: DkgRitualJSON): DkgRitual {
return new DkgRitual(id, DkgPublicKey.fromBytes(dkgPublicKey), threshold);
return new DkgRitual(
id,
DkgPublicKey.fromBytes(dkgPublicKey),
dkgParams,
state
);
}

public equals(other: DkgRitual): boolean {
return (
this.id === other.id &&
// TODO: Replace with `equals` after https://github.com/nucypher/nucypher-core/issues/56 is fixed
bytesEquals(this.dkgPublicKey.toBytes(), other.dkgPublicKey.toBytes()) &&
this.threshold === other.threshold
objectEquals(this.dkgParams, other.dkgParams) &&
this.state === other.state
);
}
}

// TODO: Currently, we're assuming that the threshold is always `floor(sharesNum / 2) + 1`.
// https://github.com/nucypher/nucypher/issues/3095
const assumedThreshold = (sharesNum: number): number =>
Math.floor(sharesNum / 2) + 1;

export class DkgClient {
// TODO: Update API: Replace with getExistingRitual and support ritualId in Strategy
public static async initializeRitual(
web3Provider: ethers.providers.Web3Provider,
ritualParams: {
shares: number;
threshold: number;
ursulas: ChecksumAddress[],
waitUntilEnd = false
): Promise<DkgRitual> {
const ritualId = await DkgCoordinatorAgent.initializeRitual(
web3Provider,
ursulas
);

if (waitUntilEnd) {
const isSuccessful = await DkgClient.waitUntilRitualEnd(
web3Provider,
ritualId
);
if (!isSuccessful) {
const ritualState = await DkgCoordinatorAgent.getRitualState(
web3Provider,
ritualId
);
throw new Error(
`Ritual initialization failed. Ritual id ${ritualId} is in state ${ritualState}`
);
}
}

return this.getExistingRitual(web3Provider, ritualId);
}

private static waitUntilRitualEnd = async (
web3Provider: ethers.providers.Web3Provider,
ritualId: number
): Promise<boolean> => {
return new Promise((resolve, reject) => {
const callback = (successful: boolean) => {
if (successful) {
resolve(true);
} else {
reject();
}
};
DkgCoordinatorAgent.onRitualEndEvent(web3Provider, ritualId, callback);
});
};

public static async getExistingRitual(
web3Provider: ethers.providers.Web3Provider,
ritualId: number
): Promise<DkgRitual> {
const ritualId = 2;
const ritualState = await DkgCoordinatorAgent.getRitualState(
web3Provider,
ritualId
);
const ritual = await DkgCoordinatorAgent.getRitual(web3Provider, ritualId);
const dkgPkBytes = new Uint8Array([
...fromHexString(ritual.publicKey.word0),
...fromHexString(ritual.publicKey.word1),
]);

return {
id: ritualId,
dkgPublicKey: DkgPublicKey.fromBytes(dkgPkBytes),
threshold: ritualParams.threshold,
} as DkgRitual;
return new DkgRitual(
ritualId,
DkgPublicKey.fromBytes(dkgPkBytes),
{
sharesNum: ritual.dkgSize,
threshold: assumedThreshold(ritual.dkgSize),
},
ritualState
);
}

public static async verifyRitual(
Expand Down
6 changes: 1 addition & 5 deletions src/sdk/strategy/cbd-strategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ export class CbdStrategy {
public async deploy(
web3Provider: ethers.providers.Web3Provider
): Promise<DeployedCbdStrategy> {
const dkgRitualParams = {
threshold: this.cohort.configuration.threshold,
shares: this.cohort.configuration.shares,
};
const dkgRitual = await DkgClient.initializeRitual(
web3Provider,
dkgRitualParams
this.cohort.ursulaAddresses
);
return DeployedCbdStrategy.create(this.cohort, dkgRitual);
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit/cbd-strategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async function makeDeployedCbdStrategy() {
const strategy = await makeCbdStrategy();

const mockedDkg = fakeDkgFlow(variant, 0, 4, 4);
const mockedDkgRitual = fakeDkgRitual(mockedDkg, mockedDkg.threshold);
const mockedDkgRitual = fakeDkgRitual(mockedDkg);
const web3Provider = fakeWeb3Provider(aliceSecretKey.toBEBytes());
const getUrsulasSpy = mockGetUrsulas(ursulas);
const initializeRitualSpy = mockInitializeRitual(mockedDkgRitual);
Expand Down
55 changes: 31 additions & 24 deletions test/utils.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
// Disabling some of the eslint rules for conveninence.
// Disabling some of the eslint rules for convenience.
/* eslint-disable @typescript-eslint/no-non-null-assertion */
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */
import { Block } from '@ethersproject/providers';
import {
AggregatedTranscript,
Capsule,
CapsuleFrag,
Ciphertext,
combineDecryptionSharesPrecomputed,
combineDecryptionSharesSimple,
DecryptionSharePrecomputed,
DecryptionShareSimple,
decryptWithSharedSecret,
Dkg,
EncryptedThresholdDecryptionResponse,
EncryptedTreasureMap,
EthereumAddress,
ferveoEncrypt,
FerveoPublicKey,
Keypair,
PublicKey,
reencrypt,
SecretKey,
SessionSecretFactory,
SessionStaticKey,
SessionStaticSecret,
ThresholdDecryptionResponse,
VerifiedCapsuleFrag,
VerifiedKeyFrag,
} from '@nucypher/nucypher-core';
import {
AggregatedTranscript,
Ciphertext,
combineDecryptionSharesPrecomputed,
combineDecryptionSharesSimple,
DecryptionSharePrecomputed,
DecryptionShareSimple,
decryptWithSharedSecret,
Dkg,
EthereumAddress,
Keypair,
Transcript,
Validator,
ValidatorMessage,
VerifiedCapsuleFrag,
VerifiedKeyFrag,
} from '@nucypher/nucypher-core';
import axios from 'axios';
import { ethers, providers, Wallet } from 'ethers';
Expand Down Expand Up @@ -501,17 +500,28 @@ export const mockRandomSessionStaticSecret = (secret: SessionStaticSecret) => {

export const fakeRitualId = 0;

export const fakeDkgRitual = (ritual: { dkg: Dkg }, threshold: number) => {
return new DkgRitual(fakeRitualId, ritual.dkg.publicKey(), threshold);
export const fakeDkgRitual = (ritual: {
dkg: Dkg;
sharesNum: number;
threshold: number;
}) => {
return new DkgRitual(
fakeRitualId,
ritual.dkg.publicKey(),
{
sharesNum: ritual.sharesNum,
threshold: ritual.threshold,
},
DkgRitualState.FINALIZED
);
};

export const mockInitializeRitual = (fakeRitual: unknown) => {
export const mockInitializeRitual = (dkgRitual: DkgRitual) => {
return (
jest
.spyOn(DkgClient, 'initializeRitual')
// eslint-disable-next-line @typescript-eslint/no-unused-vars
.mockImplementation((_web3Provider, _ritualParams) => {
return Promise.resolve(fakeRitual) as Promise<DkgRitual>;
.mockImplementation(() => {
return Promise.resolve(dkgRitual);
})
);
};
Expand All @@ -530,21 +540,18 @@ export const makeCohort = async (ursulas: Ursula[]) => {

export const mockGetRitualState = (state = DkgRitualState.FINALIZED) => {
return jest.spyOn(DkgCoordinatorAgent, 'getRitualState').mockImplementation(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(_provider, _ritualId) => Promise.resolve(state)
);
};

export const mockVerifyRitual = (isValid = true) => {
return jest.spyOn(DkgClient, 'verifyRitual').mockImplementation(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(_provider, _ritualId) => Promise.resolve(isValid)
);
};

export const mockGetParticipantPublicKey = (pk = fakeFerveoPublicKey()) => {
return jest.spyOn(DkgClient, 'getParticipantPublicKey').mockImplementation(
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(_address) => pk
);
};
Expand Down

0 comments on commit b23b38f

Please sign in to comment.