Skip to content

Commit

Permalink
refactor: revamp starkNet_estimateFee (#329)
Browse files Browse the repository at this point in the history
* chore: revamp estimate fee and removed unused rpc calls

* fix: verify if account need upgrade or deploy does not always throw

* fix: address comments

* refactor: revamp sign delcare transaction

* feat: match starknet.js signature in estimate fee

* feat: createStructWithAdditionalProperties superstruct factory

* fix: improve error message in rpc input validation

* refactor: add universal details and invocations in superstruct util

* chore: lint + prettier

* chore: add superstruct test

* chore: update get-starknet interface

* fix: address comments

* chore: restrict tx version to v3 and v2

* fix: address further comments

* fix: address further comments

* refactor: rpc estimatefee test ts

* chore: lint + prettier

* feat: use define superstruct instead of union

* fix: getEstimatedFees

* chore: refine estimate fee (#337)

* fix: address comments

* chore: lint + prettier

* fix: address comments

* feat: remove transactionVersion, use invocationDetails

* fix: last one

---------

Co-authored-by: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com>
  • Loading branch information
khanti42 and stanleyyconsensys authored Aug 28, 2024
1 parent d7708da commit e6d2518
Show file tree
Hide file tree
Showing 19 changed files with 5,814 additions and 43 deletions.
4,969 changes: 4,969 additions & 0 deletions packages/starknet-snap/src/__tests__/fixture/contract-example.json

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions packages/starknet-snap/src/estimateFee.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { Invocations } from 'starknet';
import { TransactionType } from 'starknet';

import type {
ApiParamsWithKeyDeriver,
EstimateFeeRequestParams,
import {
FeeTokenUnit,
type ApiParamsWithKeyDeriver,
type EstimateFeeRequestParams,
} from './types/snapApi';
import { ACCOUNT_CLASS_HASH } from './utils/constants';
import { logger } from './utils/logger';
Expand Down Expand Up @@ -130,7 +131,7 @@ export async function estimateFee(params: ApiParamsWithKeyDeriver) {
const resp = {
suggestedMaxFee: estimateFeeResp.suggestedMaxFee.toString(10),
overallFee: estimateFeeResp.overall_fee.toString(10),
unit: 'wei',
unit: FeeTokenUnit.ETH,
includeDeploy: !accountDeployed,
};
logger.log(`estimateFee:\nresp: ${toJson(resp)}`);
Expand Down
8 changes: 4 additions & 4 deletions packages/starknet-snap/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import { addNetwork } from './addNetwork';
import { createAccount } from './createAccount';
import { declareContract } from './declareContract';
import { estimateAccDeployFee } from './estimateAccountDeployFee';
import { estimateFee } from './estimateFee';
import { estimateFees } from './estimateFees';
import { executeTxn } from './executeTxn';
import { extractPublicKey } from './extractPublicKey';
Expand All @@ -37,13 +36,15 @@ import { getValue } from './getValue';
import { recoverAccounts } from './recoverAccounts';
import type {
DisplayPrivateKeyParams,
EstimateFeeParams,
SignMessageParams,
SignTransactionParams,
SignDeclareTransactionParams,
VerifySignatureParams,
} from './rpcs';
import {
displayPrivateKey,
estimateFee,
signMessage,
signTransaction,
signDeclareTransaction,
Expand Down Expand Up @@ -224,9 +225,8 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
return await getValue(apiParams);

case 'starkNet_estimateFee':
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await estimateFee(
apiParams as unknown as ApiParamsWithKeyDeriver,
return await estimateFee.execute(
apiParams.requestParams as unknown as EstimateFeeParams,
);

case 'starkNet_estimateAccountDeployFee':
Expand Down
100 changes: 100 additions & 0 deletions packages/starknet-snap/src/rpcs/estimateFee.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import { InvalidParamsError } from '@metamask/snaps-sdk';
import type { Invocations } from 'starknet';
import { constants, TransactionType } from 'starknet';
import type { Infer } from 'superstruct';

import { FeeTokenUnit } from '../types/snapApi';
import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants';
import * as starknetUtils from '../utils/starknetUtils';
import type { TxVersionStruct } from '../utils/superstruct';
import { mockAccount, prepareMockAccount } from './__tests__/helper';
import { estimateFee } from './estimateFee';
import type { EstimateFeeParams } from './estimateFee';

jest.mock('../utils/snap');
jest.mock('../utils/logger');

const prepareMockEstimateFee = ({
chainId,
address,
version,
includeDeploy = false,
}: {
chainId: constants.StarknetChainId;
address: string;
version: Infer<typeof TxVersionStruct>;
includeDeploy?: boolean;
}) => {
const invocations: Invocations = [
{
type: TransactionType.INVOKE,
payload: {
contractAddress:
'0x00b28a089e7fb83debee4607b6334d687918644796b47d9e9e38ea8213833137',
entrypoint: 'functionName',
calldata: ['1', '1'],
},
},
];

const request = {
chainId,
address,
invocations,
details: { version },
} as unknown as EstimateFeeParams;

const estimateBulkFeeRespMock = {
suggestedMaxFee: BigInt(1000000000000000).toString(10),
overallFee: BigInt(1500000000000000).toString(10),
unit: FeeTokenUnit.ETH,
includeDeploy,
};

const getEstimatedFeesSpy = jest.spyOn(starknetUtils, 'getEstimatedFees');
getEstimatedFeesSpy.mockResolvedValue(estimateBulkFeeRespMock);

return { estimateBulkFeeRespMock, invocations, request, getEstimatedFeesSpy };
};

describe('estimateFee', () => {
const state = {
accContracts: [],
erc20Tokens: [],
networks: [STARKNET_SEPOLIA_TESTNET_NETWORK],
transactions: [],
};

it('estimates fee correctly', async () => {
const chainId = constants.StarknetChainId.SN_SEPOLIA;
const account = await mockAccount(chainId);
prepareMockAccount(account, state);
const { request, getEstimatedFeesSpy, estimateBulkFeeRespMock } =
prepareMockEstimateFee({
includeDeploy: false,
chainId,
address: account.address,
version: constants.TRANSACTION_VERSION.V2,
});

const result = await estimateFee.execute(request);

expect(getEstimatedFeesSpy).toHaveBeenCalledWith(
STARKNET_SEPOLIA_TESTNET_NETWORK,
account.address,
account.privateKey,
account.publicKey,
request.invocations,
{
version: constants.TRANSACTION_VERSION.V2,
},
);
expect(result).toStrictEqual(estimateBulkFeeRespMock);
});

it('throws `InvalidParamsError` when request parameter is not correct', async () => {
await expect(
estimateFee.execute({} as unknown as EstimateFeeParams),
).rejects.toThrow(InvalidParamsError);
});
});
81 changes: 81 additions & 0 deletions packages/starknet-snap/src/rpcs/estimateFee.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import type { Json } from '@metamask/snaps-sdk';
import type { UniversalDetails } from 'starknet';
import type { Infer } from 'superstruct';
import { object, string, assign, boolean, optional, enums } from 'superstruct';

import { FeeTokenUnit } from '../types/snapApi';
import {
AddressStruct,
BaseRequestStruct,
AccountRpcController,
UniversalDetailsStruct,
InvocationsStruct,
} from '../utils';
import { getEstimatedFees } from '../utils/starknetUtils';

export const EstimateFeeRequestStruct = assign(
object({
address: AddressStruct,
invocations: InvocationsStruct,
details: optional(UniversalDetailsStruct),
}),
BaseRequestStruct,
);

export const EstimateFeeResponseStruct = object({
suggestedMaxFee: string(),
overallFee: string(),
unit: enums(Object.values(FeeTokenUnit)),
includeDeploy: boolean(),
});

export type EstimateFeeParams = Infer<typeof EstimateFeeRequestStruct> & Json;

export type EstimateFeeResponse = Infer<typeof EstimateFeeResponseStruct>;

/**
* The RPC handler to estimate fee of a transaction.
*/
export class EstimateFeeRpc extends AccountRpcController<
EstimateFeeParams,
EstimateFeeResponse
> {
protected requestStruct = EstimateFeeRequestStruct;

protected responseStruct = EstimateFeeResponseStruct;

/**
* Execute the bulk estimate transaction fee request handler.
*
* @param params - The parameters of the request.
* @param params.address - The account address.
* @param params.invocations - The invocations to estimate fee. Reference: https://starknetjs.com/docs/API/namespaces/types#invocations
* @param params.details - The universal details associated to the invocations. Reference: https://starknetjs.com/docs/API/interfaces/types.EstimateFeeDetails
* @param params.chainId - The chain id of the network.
* @returns A promise that resolves to a EstimateFeeResponse object.
*/
async execute(params: EstimateFeeParams): Promise<EstimateFeeResponse> {
return super.execute(params);
}

protected async handleRequest(
params: EstimateFeeParams,
): Promise<EstimateFeeResponse> {
const { address, invocations, details } = params;

const estimateFeeResp = await getEstimatedFees(
this.network,
address,
this.account.privateKey,
this.account.publicKey,
invocations,
details as UniversalDetails,
);

return estimateFeeResp;
}
}

export const estimateFee = new EstimateFeeRpc({
showInvalidAccountAlert: false,
});
3 changes: 2 additions & 1 deletion packages/starknet-snap/src/rpcs/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export * from './signMessage';
export * from './displayPrivateKey';
export * from './estimateFee';
export * from './signMessage';
export * from './signTransaction';
export * from './sign-declare-transaction';
export * from './verify-signature';
10 changes: 10 additions & 0 deletions packages/starknet-snap/src/types/snapApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,14 @@ export type GetStarkNameRequestParam = {
userAddress: string;
} & BaseRequestParams;

export enum FeeToken {
ETH = 'ETH',
STRK = 'STRK',
}

export enum FeeTokenUnit {
ETH = 'wei',
STRK = 'fri',
}

/* eslint-disable */
12 changes: 12 additions & 0 deletions packages/starknet-snap/src/types/starknet.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import type { GetTransactionResponse } from 'starknet';
import type { Infer } from 'superstruct';

import type { TxVersionStruct } from '../utils';

export type TransactionStatuses = {
executionStatus: string | undefined;
Expand All @@ -13,3 +16,12 @@ export type TransactionResponse = GetTransactionResponse & {
contract_address?: string;
calldata?: string[];
};

export type TransactionVersion = Infer<typeof TxVersionStruct>;

export type DeployAccountPayload = {
classHash: string;
contractAddress: string;
constructorCalldata: string[];
addressSalt: string;
};
5 changes: 5 additions & 0 deletions packages/starknet-snap/src/utils/rpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ describe('validateRequest', () => {
expect(() =>
validateRequest(requestParams, validateStruct as unknown as Struct),
).toThrow(InvalidParamsError);
expect(() =>
validateRequest(requestParams, validateStruct as unknown as Struct),
).toThrow(
'At path: signerAddress -- Expected a string, but received: 1234',
);
});
});

Expand Down
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/utils/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export type Account = {
derivationPath: ReturnType<typeof getBIP44ChangePathString>;
};

export type AccountRpcControllerOptions = Json & {
export type AccountRpcControllerOptions = {
showInvalidAccountAlert: boolean;
};

Expand Down
9 changes: 5 additions & 4 deletions packages/starknet-snap/src/utils/snapUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import type {
DeployAccountSignerDetails,
} from 'starknet';

import type {
AddErc20TokenRequestParams,
AddNetworkRequestParams,
import {
FeeToken,
type AddErc20TokenRequestParams,
type AddNetworkRequestParams,
} from '../types/snapApi';
import { TransactionStatus } from '../types/snapState';
import type {
Expand Down Expand Up @@ -391,7 +392,7 @@ export function getTxnSnapTxt(
if (invocationsDetails?.maxFee) {
addDialogTxt(
components,
'Max Fee(ETH)',
`Max Fee(${FeeToken.ETH})`,
convert(invocationsDetails.maxFee, 'wei', 'ether'),
);
}
Expand Down
Loading

0 comments on commit e6d2518

Please sign in to comment.