diff --git a/packages/get-starknet/src/wallet.ts b/packages/get-starknet/src/wallet.ts index b3314700..2cfc0d52 100644 --- a/packages/get-starknet/src/wallet.ts +++ b/packages/get-starknet/src/wallet.ts @@ -57,9 +57,11 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { lock: MutexInterface; - #networkChangeController: AbortController | undefined; + #pollingController: AbortController | undefined; - #accountChangeController: AbortController | undefined; + #accountChangeHandlers: AccountChangeEventHandler[] = []; + + #networkChangeHandlers: NetworkChangeEventHandler[] = []; // eslint-disable-next-line @typescript-eslint/naming-convention, no-restricted-globals static readonly snapId = process.env.SNAP_ID ?? 'npm:@consensys/starknet-snap'; @@ -237,10 +239,14 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { * @param handleEvent - The event handler function */ on(event: Event, handleEvent: WalletEventHandlers[Event]): void { + if (!this.#pollingController) { + this.#startPolling(); + } + if (event === 'accountsChanged') { - this.onAccountChanged(handleEvent as AccountChangeEventHandler); + this.#accountChangeHandlers.push(handleEvent as AccountChangeEventHandler); } else if (event === 'networkChanged') { - this.onNetworkChanged(handleEvent as NetworkChangeEventHandler); + this.#networkChangeHandlers.push(handleEvent as NetworkChangeEventHandler); } else { throw new Error(`Unsupported event: ${String(event)}`); } @@ -254,89 +260,60 @@ export class MetaMaskSnapWallet implements StarknetWindowObject { */ off(event: Event, _handleEvent?: WalletEventHandlers[Event]): void { if (event === 'accountsChanged') { - this.offAccountChanged(); + this.#accountChangeHandlers = []; } else if (event === 'networkChanged') { - this.offNetworkChanged(); + this.#networkChangeHandlers = []; } else { throw new Error(`Unsupported event: ${String(event)}`); } - } - - /** - * Starts polling for account changes and calls the callback when a change is detected. - * @param callback - The function to call when an account change is detected. - */ - onAccountChanged(callback: AccountChangeEventHandler): void { - // Set up an AbortController to manage the polling loop - this.#accountChangeController = new AbortController(); - const { signal } = this.#accountChangeController; - - const pollForAccountChange = async () => { - while (!signal.aborted) { - const previousNetwork = this.#network; - await this.#init(); - if (previousNetwork.chainId !== this.#network.chainId || this.#latestAddress !== this.#selectedAddress) { - this.#latestAddress = this.#selectedAddress; - callback([this.#selectedAddress]); - } - - await new Promise((resolve) => setTimeout(resolve, 5000)); - } - }; - pollForAccountChange().catch((error) => { - if (!signal.aborted) { - console.error('Error in account change polling:', error); - } - }); - } - - /** - * Stops polling for account changes. - */ - offAccountChanged(): void { - if (this.#accountChangeController) { - this.#accountChangeController.abort(); - this.#accountChangeController = undefined; + if (this.#accountChangeHandlers.length === 0 && this.#networkChangeHandlers.length === 0) { + this.#stopPolling(); } } /** - * Starts polling for network changes and calls the callback when a change is detected. - * @param callback - The function to call when a network change is detected. + * Starts polling for account or network changes and calls the respective callbacks. */ - onNetworkChanged(callback: NetworkChangeEventHandler): void { - // Set up an AbortController to manage the polling loop - this.#networkChangeController = new AbortController(); - const { signal } = this.#networkChangeController; + #startPolling(): void { + this.#pollingController = new AbortController(); + const { signal } = this.#pollingController; - const pollForNetworkChange = async () => { + const pollForChanges = async () => { while (!signal.aborted) { const previousNetwork = this.#network; + const previousAddress = this.#latestAddress; + await this.#init(); + // Check for network change if (previousNetwork.chainId !== this.#network.chainId) { - callback(this.#network.chainId, [this.#selectedAddress]); + this.#networkChangeHandlers.forEach((callback) => callback(this.#network.chainId, [this.#selectedAddress])); + } + + // Check for account change + if (previousAddress !== this.#selectedAddress) { + this.#accountChangeHandlers.forEach((callback) => callback([this.#selectedAddress])); } await new Promise((resolve) => setTimeout(resolve, 5000)); } }; - pollForNetworkChange().catch((error) => { + pollForChanges().catch((error) => { if (!signal.aborted) { - console.error('Error in network change polling:', error); + console.error('Error in polling:', error); } }); } /** - * Stops polling for network changes. + * Stops polling for account or network changes. */ - offNetworkChanged(): void { - if (this.#networkChangeController) { - this.#networkChangeController.abort(); - this.#networkChangeController = undefined; + #stopPolling(): void { + if (this.#pollingController) { + this.#pollingController.abort(); + this.#pollingController = undefined; } } }