diff --git a/packages/starknet-snap/src/__tests__/helper.ts b/packages/starknet-snap/src/__tests__/helper.ts index 6e63fc8b..732ca50f 100644 --- a/packages/starknet-snap/src/__tests__/helper.ts +++ b/packages/starknet-snap/src/__tests__/helper.ts @@ -22,17 +22,14 @@ import { v4 as uuidv4 } from 'uuid'; import { FeeToken } from '../types/snapApi'; import type { AccContract, - Erc20Token, Transaction, TransactionRequest, } from '../types/snapState'; import { ACCOUNT_CLASS_HASH, ACCOUNT_CLASS_HASH_LEGACY, - ETHER_MAINNET, PRELOADED_TOKENS, PROXY_CONTRACT_HASH, - STRK_MAINNET, } from '../utils/constants'; import { grindKey } from '../utils/keyPair'; @@ -299,17 +296,16 @@ export function generateTransactions({ export function generateTransactionRequests({ chainId, address, - selectedFeeToken, + selectedFeeTokens = Object.values(FeeToken), contractAddresses = PRELOADED_TOKENS.map((token) => token.address), cnt = 1, }: { chainId: constants.StarknetChainId | string; address: string; - selectedFeeToken?: Erc20Token; + selectedFeeTokens?: FeeToken[]; contractAddresses?: string[]; cnt?: number; }): TransactionRequest[] { - const feeTokens = [STRK_MAINNET, ETHER_MAINNET]; const request = { chainId: chainId, id: '', @@ -333,8 +329,9 @@ export function generateTransactionRequests({ addressIndex: 0, maxFee: '100', selectedFeeToken: - selectedFeeToken?.symbol ?? - feeTokens[Math.floor(generateRandomValue() * feeTokens.length)].symbol, + selectedFeeTokens[ + Math.floor(generateRandomValue() * selectedFeeTokens.length) + ], calls: [ { contractAddress: diff --git a/packages/starknet-snap/src/ui/controllers/user-input-event-controller.test.ts b/packages/starknet-snap/src/ui/controllers/user-input-event-controller.test.ts index 18209137..05b54532 100644 --- a/packages/starknet-snap/src/ui/controllers/user-input-event-controller.test.ts +++ b/packages/starknet-snap/src/ui/controllers/user-input-event-controller.test.ts @@ -315,23 +315,47 @@ describe('UserInputEventController', () => { }); describe('handleFeeTokenChange', () => { + type PrepareHandleFeeTokenChangeArg = { + feeToken: { + // The fee token that we change from + changeFrom: FeeToken; + // The fee token that we change to + changeTo: FeeToken; + }; + }; + const prepareHandleFeeTokenChange = async ( - feeToken: FeeToken = FeeToken.STRK, - selectedFeeToken?: Erc20Token, + arg: PrepareHandleFeeTokenChangeArg = { + feeToken: { + changeFrom: FeeToken.STRK, + changeTo: FeeToken.ETH, + }, + }, ) => { const network = STARKNET_TESTNET_NETWORK; const { chainId } = network; + const { feeToken } = arg; const [account] = await generateAccounts(chainId, 1); const [transactionRequest] = generateTransactionRequests({ chainId, address: account.address, - selectedFeeToken, + selectedFeeTokens: [feeToken.changeFrom], }); + // Create a copy of the original transaction request, for testing if the transaction request is updated / rolled back + const originalTransactionRequest = { + ...transactionRequest, + // Since only `maxFee`, `selectedFeeToken`, `includeDeploy` and `resourceBounds` has been updated, hence we only need to copy these fields + maxFee: transactionRequest.maxFee, + selectedFeeToken: transactionRequest.selectedFeeToken, + includeDeploy: transactionRequest.includeDeploy, + resourceBounds: [...transactionRequest.resourceBounds], + }; + const event = generateInputEvent({ transactionRequest, - eventValue: feeToken, + eventValue: feeToken.changeTo, }); mockNetworkStateManager(network); @@ -341,24 +365,23 @@ describe('UserInputEventController', () => { return { ...mockHasSufficientFundsForFee(), ...mockUpdateExecuteTxnFlow(), - ...mockEstimateFee(feeToken), + ...mockEstimateFee(feeToken.changeTo), ...mockTransactionRequestStateManager(), event, + originalTransactionRequest, transactionRequest, account, network, - feeToken, + feeToken: feeToken.changeTo, }; }; it.each([STRK_SEPOLIA_TESTNET, ETHER_SEPOLIA_TESTNET])( - 'updates the transaction request with the updated estimated fee: feeToken - %symbol', + 'updates the transaction request with the updated estimated fee: feeToken - $symbol', async (token: Erc20Token) => { - const feeToken = FeeToken[token.symbol]; - const selectedFeeToken = - token.symbol === FeeToken.ETH - ? STRK_SEPOLIA_TESTNET - : ETHER_SEPOLIA_TESTNET; + const feeTokenChangeTo = FeeToken[token.symbol]; + const feeTokenChangeFrom = + token.symbol === FeeToken.ETH ? FeeToken.STRK : FeeToken.ETH; const { event, account, @@ -369,7 +392,13 @@ describe('UserInputEventController', () => { mockGetEstimatedFeesResponse, upsertTransactionRequestSpy, transactionRequest, - } = await prepareHandleFeeTokenChange(feeToken, selectedFeeToken); + } = await prepareHandleFeeTokenChange({ + feeToken: { + changeFrom: feeTokenChangeFrom, + changeTo: feeTokenChangeTo, + }, + }); + const feeTokenAddress = token.address; const { signer, calls } = transactionRequest; const { publicKey, privateKey, address } = account; @@ -394,7 +423,7 @@ describe('UserInputEventController', () => { }, ], { - version: controller.feeTokenToTransactionVersion(feeToken), + version: controller.feeTokenToTransactionVersion(feeTokenChangeTo), }, ); expect(hasSufficientFundsForFeeSpy).toHaveBeenCalledWith({ @@ -410,9 +439,10 @@ describe('UserInputEventController', () => { controller.eventId, transactionRequest, ); + // Make sure the `selectedFeeToken` transaction request has been updated expect(upsertTransactionRequestSpy).toHaveBeenCalledWith({ ...transactionRequest, - selectedFeeToken: feeToken, + selectedFeeToken: feeTokenChangeTo, }); }, ); @@ -421,53 +451,67 @@ describe('UserInputEventController', () => { const { event, hasSufficientFundsForFeeSpy, - transactionRequest, + originalTransactionRequest, updateExecuteTxnFlowSpy, upsertTransactionRequestSpy, feeToken, - } = await prepareHandleFeeTokenChange(); + } = await prepareHandleFeeTokenChange({ + feeToken: { + changeFrom: FeeToken.STRK, + changeTo: FeeToken.ETH, + }, + }); hasSufficientFundsForFeeSpy.mockResolvedValue(false); const controller = createMockController(event); await controller.handleFeeTokenChange(); expect(upsertTransactionRequestSpy).not.toHaveBeenCalled(); + expect(feeToken).not.toStrictEqual( + originalTransactionRequest.selectedFeeToken, + ); + // If the account balance is insufficient to cover the fee, the transaction request should be rolled back or not updated. expect(updateExecuteTxnFlowSpy).toHaveBeenCalledWith( controller.eventId, - transactionRequest, + originalTransactionRequest, { errors: { - fees: `Not enough ${feeToken} to pay for fee, switching back to ${transactionRequest.selectedFeeToken}`, + fees: `Not enough ${feeToken} to pay for fee, switching back to ${originalTransactionRequest.selectedFeeToken}`, }, }, ); }); - it('rollback the transaction request and show a general error message if other error was thrown.', async () => { + it('rollbacks the transaction request with a general error message if another Error was thrown.', async () => { const { event, - transactionRequest, + originalTransactionRequest, updateExecuteTxnFlowSpy, upsertTransactionRequestSpy, - } = await prepareHandleFeeTokenChange(); + feeToken, + } = await prepareHandleFeeTokenChange({ + feeToken: { + changeFrom: FeeToken.STRK, + changeTo: FeeToken.ETH, + }, + }); + // Simulate an error thrown to test the error handling upsertTransactionRequestSpy.mockRejectedValue(new Error('Failed!')); - const rollbackSnapshot = { - maxFee: transactionRequest.maxFee, - selectedFeeToken: transactionRequest.selectedFeeToken, - includeDeploy: transactionRequest.includeDeploy, - resourceBounds: [...transactionRequest.resourceBounds], - }; const controller = createMockController(event); await controller.handleFeeTokenChange(); + expect(feeToken).not.toStrictEqual( + originalTransactionRequest.selectedFeeToken, + ); + // if any Error was thrown, the transaction request should be rolled back or not updated. expect(updateExecuteTxnFlowSpy).toHaveBeenCalledWith( controller.eventId, - { ...transactionRequest, ...rollbackSnapshot }, + originalTransactionRequest, { errors: { - fees: `Failed to calculate the fees, switching back to ${transactionRequest.selectedFeeToken}`, + fees: `Failed to calculate the fees, switching back to ${originalTransactionRequest.selectedFeeToken}`, }, }, ); diff --git a/packages/starknet-snap/src/ui/controllers/user-input-event-controller.ts b/packages/starknet-snap/src/ui/controllers/user-input-event-controller.ts index 28520697..f9224a61 100644 --- a/packages/starknet-snap/src/ui/controllers/user-input-event-controller.ts +++ b/packages/starknet-snap/src/ui/controllers/user-input-event-controller.ts @@ -132,8 +132,9 @@ export class UserInputEventController { protected createRollbackSnapshot( request: TransactionRequest, - ): Partial { + ): TransactionRequest { return { + ...request, maxFee: request.maxFee, selectedFeeToken: request.selectedFeeToken, includeDeploy: request.includeDeploy, @@ -143,7 +144,7 @@ export class UserInputEventController { protected async handleFeeTokenChange() { const request = this.context?.request as TransactionRequest; - const rollbackSnapshot = this.createRollbackSnapshot(request); + const originalRequest = this.createRollbackSnapshot(request); const { addressIndex, calls, signer, chainId } = request; const feeToken = (this.event as InputChangeEvent) .value as unknown as FeeToken; @@ -204,17 +205,12 @@ export class UserInputEventController { } catch (error) { const errorMessage = error instanceof InsufficientFundsError - ? `Not enough ${feeToken} to pay for fee, switching back to ${request.selectedFeeToken}` - : `Failed to calculate the fees, switching back to ${request.selectedFeeToken}`; - + ? `Not enough ${feeToken} to pay for fee, switching back to ${originalRequest.selectedFeeToken}` + : `Failed to calculate the fees, switching back to ${originalRequest.selectedFeeToken}`; // On failure, display ExecuteTxnUI with an error message - await updateExecuteTxnFlow( - this.eventId, - { ...request, ...rollbackSnapshot }, - { - errors: { fees: errorMessage }, - }, - ); + await updateExecuteTxnFlow(this.eventId, originalRequest, { + errors: { fees: errorMessage }, + }); } } }