diff --git a/packages/starknet-snap/src/__tests__/helper.ts b/packages/starknet-snap/src/__tests__/helper.ts index 6c6321d7..9af83081 100644 --- a/packages/starknet-snap/src/__tests__/helper.ts +++ b/packages/starknet-snap/src/__tests__/helper.ts @@ -325,7 +325,7 @@ export function generateTransactionRequests({ networkName: 'Sepolia', signer: address, maxFee: '100', - feeToken: + selectedFeeToken: feeTokens[Math.floor(generateRandomValue() * feeTokens.length)].symbol, calls: [ { diff --git a/packages/starknet-snap/src/rpcs/__tests__/helper.ts b/packages/starknet-snap/src/rpcs/__tests__/helper.ts index 40b9b6ee..a78c2f9d 100644 --- a/packages/starknet-snap/src/rpcs/__tests__/helper.ts +++ b/packages/starknet-snap/src/rpcs/__tests__/helper.ts @@ -58,6 +58,20 @@ export function prepareConfirmDialog() { }; } +/** + * + */ +export function prepareConfirmDialogInteractiveUI() { + const confirmDialogSpy = jest.spyOn( + snapHelper, + 'createInteractiveConfirmDialog', + ); + confirmDialogSpy.mockResolvedValue(true); + return { + confirmDialogSpy, + }; +} + /** * */ diff --git a/packages/starknet-snap/src/rpcs/execute-txn.test.ts b/packages/starknet-snap/src/rpcs/execute-txn.test.ts index f9498055..0f10cec8 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.test.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.test.ts @@ -14,7 +14,7 @@ import { executeTxn as executeTxnUtil } from '../utils/starknetUtils'; import { generateRandomFee, mockAccount, - prepareConfirmDialog, + prepareConfirmDialogInteractiveUI, prepareMockAccount, } from './__tests__/helper'; import type { ExecuteTxnParams } from './execute-txn'; @@ -35,7 +35,7 @@ const prepareMockExecuteTxn = async ( networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], transactions: [], }; - const { confirmDialogSpy } = prepareConfirmDialog(); + const { confirmDialogSpy } = prepareConfirmDialogInteractiveUI(); const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); prepareMockAccount(account, state); diff --git a/packages/starknet-snap/src/rpcs/execute-txn.ts b/packages/starknet-snap/src/rpcs/execute-txn.ts index 9e6d6108..2710fb39 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.ts @@ -1,34 +1,30 @@ -import type { Component, Json } from '@metamask/snaps-sdk'; -import convert from 'ethereum-unit-converter'; +import { type Json } from '@metamask/snaps-sdk'; import type { Call, Calldata } from 'starknet'; import { constants, TransactionStatus, TransactionType } from 'starknet'; import type { Infer } from 'superstruct'; import { object, string, assign, optional, any } from 'superstruct'; +import { v4 as uuidv4 } from 'uuid'; import { AccountStateManager } from '../state/account-state-manager'; import { TokenStateManager } from '../state/token-state-manager'; import { TransactionStateManager } from '../state/transaction-state-manager'; import { FeeToken } from '../types/snapApi'; +import type { TransactionRequest } from '../types/snapState'; import { VoyagerTransactionType, type Transaction } from '../types/snapState'; +import { ExecuteTxnUI } from '../ui/components'; +import { generateExecuteTxnFlow } from '../ui/utils'; import type { AccountRpcControllerOptions } from '../utils'; import { AddressStruct, BaseRequestStruct, AccountRpcController, - confirmDialog, UniversalDetailsStruct, CallsStruct, mapDeprecatedParams, - addressUI, - signerUI, - networkUI, - jsonDataUI, - dividerUI, - headerUI, - rowUI, + createInteractiveConfirmDialog, + callToTransactionReqCall, } from '../utils'; import { UserRejectedOpError } from '../utils/exceptions'; -import { logger } from '../utils/logger'; import { createAccount, executeTxn as executeTxnUtil, @@ -112,6 +108,7 @@ export class ExecuteTxnRpc extends AccountRpcController< ): Promise { const { address, calls, abis, details } = params; const { privateKey, publicKey } = this.account; + const callsArray = Array.isArray(calls) ? calls : [calls]; const { includeDeploy, suggestedMaxFee, estimateResults } = await getEstimatedFees( @@ -132,15 +129,38 @@ export class ExecuteTxnRpc extends AccountRpcController< const version = details?.version as unknown as constants.TRANSACTION_VERSION; - if ( - !(await this.getExecuteTxnConsensus( - address, - accountDeployed, - calls, - suggestedMaxFee, - version, - )) - ) { + const formattedCalls = await Promise.all( + callsArray.map(async (call) => + callToTransactionReqCall( + call, + this.network.chainId, + address, + this.tokenStateManager, + ), + ), + ); + + const request: TransactionRequest = { + chainId: this.network.chainId, + networkName: this.network.name, + id: uuidv4(), + interfaceId: '', + type: TransactionType.INVOKE, + signer: address, + maxFee: suggestedMaxFee, + calls: formattedCalls, + selectedFeeToken: + version === constants.TRANSACTION_VERSION.V3 + ? FeeToken.STRK + : FeeToken.ETH, + includeDeploy, + }; + + const interfaceId = await generateExecuteTxnFlow(ExecuteTxnUI, request); + + request.interfaceId = interfaceId; + + if (!(await createInteractiveConfirmDialog(interfaceId))) { throw new UserRejectedOpError() as unknown as Error; } @@ -213,119 +233,6 @@ export class ExecuteTxnRpc extends AccountRpcController< }); } - protected async getExecuteTxnConsensus( - address: string, - accountDeployed: boolean, - calls: Call[] | Call, - maxFee: string, - version?: constants.TRANSACTION_VERSION, - ) { - const { name: chainName, chainId } = this.network; - const callsArray = Array.isArray(calls) ? calls : [calls]; - - const components: Component[] = []; - const feeToken: FeeToken = - version === constants.TRANSACTION_VERSION.V3 - ? FeeToken.STRK - : FeeToken.ETH; - - components.push(headerUI('Do you want to sign this transaction?')); - components.push( - signerUI({ - address, - chainId, - }), - ); - - // Display a message to indicate the signed transaction will include an account deployment - if (!accountDeployed) { - components.push(headerUI(`The account will be deployed`)); - } - - components.push(dividerUI()); - components.push( - rowUI({ - label: `Estimated Gas Fee (${feeToken})`, - value: convert(maxFee, 'wei', 'ether'), - }), - ); - - components.push(dividerUI()); - components.push( - networkUI({ - networkName: chainName, - }), - ); - - // Iterate over each call in the calls array - for (const call of callsArray) { - const { contractAddress, calldata, entrypoint } = call; - components.push(dividerUI()); - components.push( - addressUI({ - label: 'Contract', - address: contractAddress, - chainId, - }), - ); - - components.push( - jsonDataUI({ - label: 'Call Data', - data: calldata, - }), - ); - - // If the contract is an ERC20 token and the function is 'transfer', display sender, recipient, and amount - const token = await this.tokenStateManager.getToken({ - address: contractAddress, - chainId, - }); - - if (token && entrypoint === 'transfer' && calldata) { - try { - const senderAddress = address; - const recipientAddress = calldata[0]; // Assuming the first element in calldata is the recipient - let amount = ''; - - if ([3, 6, 9, 12, 15, 18].includes(token.decimals)) { - amount = convert(calldata[1], -1 * token.decimals, 'ether'); - } else { - amount = ( - Number(calldata[1]) * Math.pow(10, -1 * token.decimals) - ).toFixed(token.decimals); - } - components.push(dividerUI()); - components.push( - addressUI({ - label: 'Sender Address', - address: senderAddress, - chainId, - }), - dividerUI(), - addressUI({ - label: 'Recipient Address', - address: recipientAddress, - chainId, - }), - dividerUI(), - rowUI({ - label: `Amount (${token.symbol})`, - value: amount, - }), - ); - } catch (error) { - logger.warn( - // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `error found in amount conversion: ${error}`, - ); - } - } - } - - return await confirmDialog(components); - } - protected createDeployTxn( address: string, transactionHash: string, diff --git a/packages/starknet-snap/src/state/request-state-manager.test.ts b/packages/starknet-snap/src/state/request-state-manager.test.ts deleted file mode 100644 index 04531eb8..00000000 --- a/packages/starknet-snap/src/state/request-state-manager.test.ts +++ /dev/null @@ -1,173 +0,0 @@ -import { constants } from 'starknet'; - -import { generateTransactionRequests } from '../__tests__/helper'; -import type { TransactionRequest } from '../types/snapState'; -import { mockAcccounts, mockState } from './__tests__/helper'; -import { TransactionRequestStateManager } from './request-state-manager'; -import { StateManagerError } from './state-manager'; - -describe('TransactionRequestStateManager', () => { - const getChainId = () => constants.StarknetChainId.SN_SEPOLIA; - - const prepareMockData = async () => { - const chainId = getChainId(); - const accounts = await mockAcccounts(chainId, 1); - const transactionRequests = generateTransactionRequests({ - chainId, - address: accounts[0].address, - cnt: 10, - }); - - const { state, setDataSpy, getDataSpy } = await mockState({ - transactionRequests, - }); - - return { - state, - setDataSpy, - getDataSpy, - account: accounts[0], - transactionRequests, - }; - }; - - const getNewEntity = (address) => { - const chainId = getChainId(); - const transactionRequests = generateTransactionRequests({ - chainId, - address, - cnt: 1, - }); - - return transactionRequests[0]; - }; - - const getUpdateEntity = (request: TransactionRequest) => { - return { - ...request, - maxFee: '999999', - }; - }; - - describe('getTransactionRequest', () => { - it('returns the transaction request', async () => { - const { - transactionRequests: [transactionRequest], - } = await prepareMockData(); - - const stateManager = new TransactionRequestStateManager(); - const result = await stateManager.getTransactionRequest({ - requestId: transactionRequest.id, - }); - - expect(result).toStrictEqual(transactionRequest); - }); - - it('finds the request by interfaceId', async () => { - const { - transactionRequests: [transactionRequest], - } = await prepareMockData(); - - const stateManager = new TransactionRequestStateManager(); - const result = await stateManager.getTransactionRequest({ - interfaceId: transactionRequest.interfaceId, - }); - - expect(result).toStrictEqual(transactionRequest); - }); - - it('returns null if the transaction request can not be found', async () => { - await prepareMockData(); - - const stateManager = new TransactionRequestStateManager(); - - const result = await stateManager.getTransactionRequest({ - requestId: 'something', - }); - expect(result).toBeNull(); - }); - - it('throws a `At least one search condition must be provided` error if no search criteria given', async () => { - const stateManager = new TransactionRequestStateManager(); - - await expect(stateManager.getTransactionRequest({})).rejects.toThrow( - 'At least one search condition must be provided', - ); - }); - }); - - describe('upsertTransactionRequest', () => { - it('updates the transaction request if the transaction request found', async () => { - const { - state, - transactionRequests: [transactionRequest], - } = await prepareMockData(); - const entity = getUpdateEntity(transactionRequest); - - const stateManager = new TransactionRequestStateManager(); - await stateManager.upsertTransactionRequest(entity); - - expect( - state.transactionRequests.find( - (req) => req.id === transactionRequest.id, - ), - ).toStrictEqual(entity); - }); - - it('add a new transaction request if the transaction request does not found', async () => { - const { state, account } = await prepareMockData(); - const entity = getNewEntity(account.address); - const orgLength = state.transactionRequests.length; - - const stateManager = new TransactionRequestStateManager(); - await stateManager.upsertTransactionRequest(entity); - - expect(state.transactionRequests).toHaveLength(orgLength + 1); - expect( - state.transactionRequests.find((req) => req.id === entity.id), - ).toStrictEqual(entity); - }); - - it('throws a `StateManagerError` error if an error was thrown', async () => { - const { account, setDataSpy } = await prepareMockData(); - const entity = getNewEntity(account.address); - setDataSpy.mockRejectedValue(new Error('Error')); - - const stateManager = new TransactionRequestStateManager(); - - await expect( - stateManager.upsertTransactionRequest(entity), - ).rejects.toThrow(StateManagerError); - }); - }); - - describe('removeTransactionRequests', () => { - it('removes the request', async () => { - const { - transactionRequests: [{ id }], - state, - } = await prepareMockData(); - const stateManager = new TransactionRequestStateManager(); - - await stateManager.removeTransactionRequest(id); - - expect( - state.transactionRequests.filter((req) => req.id === id), - ).toStrictEqual([]); - }); - - it('throws a `StateManagerError` error if an error was thrown', async () => { - const { - transactionRequests: [{ id }], - setDataSpy, - } = await prepareMockData(); - setDataSpy.mockRejectedValue(new Error('Error')); - - const stateManager = new TransactionRequestStateManager(); - - await expect(stateManager.removeTransactionRequest(id)).rejects.toThrow( - StateManagerError, - ); - }); - }); -}); diff --git a/packages/starknet-snap/src/state/request-state-manager.ts b/packages/starknet-snap/src/state/request-state-manager.ts deleted file mode 100644 index 1baa0a37..00000000 --- a/packages/starknet-snap/src/state/request-state-manager.ts +++ /dev/null @@ -1,123 +0,0 @@ -import type { TransactionRequest, SnapState } from '../types/snapState'; -import { logger } from '../utils'; -import type { IFilter } from './filter'; -import { StringFllter } from './filter'; -import { StateManager, StateManagerError } from './state-manager'; - -export type ITransactionRequestFilter = IFilter; - -export class IdFilter - extends StringFllter - implements ITransactionRequestFilter -{ - dataKey = 'id'; -} - -export class InterfaceIdFilter - extends StringFllter - implements ITransactionRequestFilter -{ - dataKey = 'interfaceId'; -} - -export class TransactionRequestStateManager extends StateManager { - protected getCollection(state: SnapState): TransactionRequest[] { - return state.transactionRequests ?? []; - } - - protected updateEntity( - dataInState: TransactionRequest, - data: TransactionRequest, - ): void { - // This is the only field that can be updated - dataInState.maxFee = data.maxFee; - dataInState.feeToken = data.feeToken; - } - - /** - * Finds a `TransactionRequest` object based on the given requestId or interfaceId. - * - * @param param - The param object. - * @param param.requestId - The requestId to search for. - * @param param.interfaceId - The interfaceId to search for. - * @param [state] - The optional SnapState object. - * @returns A Promise that resolves with the `TransactionRequest` object if found, or null if not found. - */ - async getTransactionRequest( - { - requestId, - interfaceId, - }: { - requestId?: string; - interfaceId?: string; - }, - state?: SnapState, - ): Promise { - const filters: ITransactionRequestFilter[] = []; - if (requestId) { - filters.push(new IdFilter([requestId])); - } - if (interfaceId) { - filters.push(new InterfaceIdFilter([interfaceId])); - } - if (filters.length === 0) { - throw new StateManagerError( - 'At least one search condition must be provided', - ); - } - return await this.find(filters, state); - } - - /** - * Upsert a `TransactionRequest` in the state with the given data. - * - * @param data - The `TransactionRequest` object. - * @returns A Promise that resolves when the upsert is complete. - */ - async upsertTransactionRequest(data: TransactionRequest): Promise { - try { - await this.update(async (state: SnapState) => { - const dataInState = await this.getTransactionRequest( - { - requestId: data.id, - }, - state, - ); - - if (dataInState === null) { - this.getCollection(state)?.push(data); - } else { - this.updateEntity(dataInState, data); - } - }); - } catch (error) { - throw new StateManagerError(error.message); - } - } - - /** - * Removes the `TransactionRequest` objects in the state with the given requestId. - * - * @param requestId - The requestId to search for. - * @returns A Promise that resolves when the remove is complete. - */ - async removeTransactionRequest(requestId: string): Promise { - try { - await this.update(async (state: SnapState) => { - const sizeOfTransactionRequests = this.getCollection(state).length; - - state.transactionRequests = this.getCollection(state).filter((req) => { - return req.id !== requestId; - }); - - // Check if the TransactionRequest was removed - if (sizeOfTransactionRequests === this.getCollection(state).length) { - // If the TransactionRequest does not exist, log a warning instead of throwing an error - logger.warn(`TransactionRequest with id ${requestId} does not exist`); - } - }); - } catch (error) { - throw new StateManagerError(error.message); - } - } -} diff --git a/packages/starknet-snap/src/types/snapState.ts b/packages/starknet-snap/src/types/snapState.ts index ee9ed7bc..ceea001a 100644 --- a/packages/starknet-snap/src/types/snapState.ts +++ b/packages/starknet-snap/src/types/snapState.ts @@ -34,7 +34,7 @@ export type TransactionRequest = { networkName: string; maxFee: string; calls: FormattedCallData[]; - feeToken: string; + selectedFeeToken: string; includeDeploy: boolean; }; diff --git a/packages/starknet-snap/src/ui/utils.tsx b/packages/starknet-snap/src/ui/utils.tsx index 1fd9f87d..81d3ba27 100644 --- a/packages/starknet-snap/src/ui/utils.tsx +++ b/packages/starknet-snap/src/ui/utils.tsx @@ -1,5 +1,6 @@ -import type { FormattedCallData } from '../types/snapState'; +import type { FormattedCallData, TransactionRequest } from '../types/snapState'; import { DEFAULT_DECIMAL_PLACES } from '../utils/constants'; +import { ExecuteTxnUI } from './components'; import type { TokenTotals } from './types'; /** @@ -40,3 +41,42 @@ export const accumulateTotals = ( }, ); }; +/** + * Generate the interface for a ExecuteTxnUI + * + * @param request - TransactionRequest + * @returns A Promise that resolves to the interface ID generated by the Snap request. + * The ID can be used for tracking or referencing the created interface. + */ +export async function generateExecuteTxnFlow( + request: TransactionRequest, // Request must match props and include an `id` +) { + const { + signer, + chainId, + networkName, + maxFee, + calls, + selectedFeeToken, + includeDeploy, + } = request; + return await snap.request({ + method: 'snap_createInterface', + params: { + ui: ( + + ), + context: { + request, + }, + }, + }); +} diff --git a/packages/starknet-snap/src/utils/__mocks__/snap.ts b/packages/starknet-snap/src/utils/__mocks__/snap.ts index 75f5e80b..8f0328e9 100644 --- a/packages/starknet-snap/src/utils/__mocks__/snap.ts +++ b/packages/starknet-snap/src/utils/__mocks__/snap.ts @@ -4,6 +4,8 @@ export const getBip44Deriver = jest.fn(); export const confirmDialog = jest.fn(); +export const createInteractiveConfirmDialog = jest.fn(); + export const alertDialog = jest.fn(); export const getStateData = jest.fn(); diff --git a/packages/starknet-snap/src/utils/formatter-utils.ts b/packages/starknet-snap/src/utils/formatter-utils.ts index c91e1488..7f88601f 100644 --- a/packages/starknet-snap/src/utils/formatter-utils.ts +++ b/packages/starknet-snap/src/utils/formatter-utils.ts @@ -1,3 +1,11 @@ +import type { Call } from 'starknet'; +import { assert } from 'superstruct'; + +import type { TokenStateManager } from '../state/token-state-manager'; +import type { FormattedCallData } from '../types/snapState'; +import { logger } from './logger'; +import { AddressStruct, NumberStringStruct } from './superstruct'; + export const hexToString = (hexStr) => { let str = ''; for (let i = 0; i < hexStr.length; i += 2) { @@ -37,3 +45,47 @@ export const mapDeprecatedParams = ( } }); }; + +export const callToTransactionReqCall = async ( + call: Call, + chainId: string, + address: string, + tokenStateManager: TokenStateManager, +): Promise => { + const { contractAddress, calldata, entrypoint } = call; + // Base data object for each call, with transfer fields left as optional + const formattedCall: FormattedCallData = { + contractAddress, + calldata: calldata as string[], + entrypoint, + }; + // Check if the contract is an ERC20 token and entrypoint is 'transfer' to populate transfer fields + const token = await tokenStateManager.getToken({ + address: contractAddress, + chainId, + }); + if (token && entrypoint === 'transfer' && calldata) { + try { + const senderAddress = address; + + // ensure the data is in correct format, + // if an error occur, it will catch and not to format it + assert(calldata[0], AddressStruct); + assert(calldata[1], NumberStringStruct); + const recipientAddress = calldata[0]; // Assuming calldata[0] is the recipient address + const amount = calldata[1]; + // Populate transfer-specific fields + formattedCall.tokenTransferData = { + senderAddress, + recipientAddress, + amount: typeof amount === 'number' ? amount.toString() : amount, + symbol: token.symbol, + decimals: token.decimals, + }; + } catch (error) { + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + logger.warn(`Error in amount conversion: ${error.message}`); + } + } + return formattedCall; +}; diff --git a/packages/starknet-snap/src/utils/snap.test.ts b/packages/starknet-snap/src/utils/snap.test.ts index 8fa99749..dfbeb7ea 100644 --- a/packages/starknet-snap/src/utils/snap.test.ts +++ b/packages/starknet-snap/src/utils/snap.test.ts @@ -22,6 +22,23 @@ describe('getBip44Deriver', () => { }); }); +describe('createInteractiveConfirmDialog', () => { + it('calls snap_dialog', async () => { + const spy = jest.spyOn(snapUtil.getProvider(), 'request'); + const interfaceId = 'test'; + + await snapUtil.createInteractiveConfirmDialog(interfaceId); + + expect(spy).toHaveBeenCalledWith({ + method: 'snap_dialog', + params: { + type: 'confirmation', + id: interfaceId, + }, + }); + }); +}); + describe('confirmDialog', () => { it('calls snap_dialog', async () => { const spy = jest.spyOn(snapUtil.getProvider(), 'request'); diff --git a/packages/starknet-snap/src/utils/snap.ts b/packages/starknet-snap/src/utils/snap.ts index 49856f57..6d9bb14c 100644 --- a/packages/starknet-snap/src/utils/snap.ts +++ b/packages/starknet-snap/src/utils/snap.ts @@ -29,6 +29,24 @@ export async function getBip44Deriver(): Promise { return getBIP44AddressKeyDeriver(bip44Node); } +/** + * Displays a confirmation dialog with the specified interface id. + * + * @param interfaceId - A string representing the id of the interface. + * @returns A Promise that resolves to the result of the dialog. + */ +export async function createInteractiveConfirmDialog( + interfaceId: string, +): Promise { + return snap.request({ + method: 'snap_dialog', + params: { + type: DialogType.Confirmation, + id: interfaceId, + }, + }); +} + /** * Displays a confirmation dialog with the specified components. *