diff --git a/utils/multicall/multicaller.go b/utils/multicall/multicaller.go index ec03c5e2..e502c239 100644 --- a/utils/multicall/multicaller.go +++ b/utils/multicall/multicaller.go @@ -17,22 +17,18 @@ import ( ) type Call struct { - Method string `json:"method"` - Target common.Address `json:"target"` - CallData []byte `json:"call_data"` - Contract *rocketpool.Contract - output interface{} + Target common.Address `json:"target"` + CallData []byte `json:"call_data"` + UnpackFunc func([]byte) error } type CallResponse struct { - Method string Status bool ReturnDataRaw []byte `json:"returnData"` } type Result struct { Success bool `json:"success"` - Output interface{} } func (call Call) GetMultiCall() MultiCall { @@ -43,7 +39,7 @@ type MultiCaller struct { Client rocketpool.ExecutionClient ABI abi.ABI ContractAddress common.Address - calls []Call + Calls []Call } func NewMultiCaller(client rocketpool.ExecutionClient, multicallerAddress common.Address) (*MultiCaller, error) { @@ -56,7 +52,7 @@ func NewMultiCaller(client rocketpool.ExecutionClient, multicallerAddress common Client: client, ABI: mcAbi, ContractAddress: multicallerAddress, - calls: []Call{}, + Calls: []Call{}, }, nil } @@ -66,19 +62,19 @@ func (caller *MultiCaller) AddCall(contract *rocketpool.Contract, output interfa return fmt.Errorf("error adding call [%s]: %w", method, err) } call := Call{ - Method: method, Target: *contract.Address, CallData: callData, - Contract: contract, - output: output, + UnpackFunc: func(rawData []byte) error { + return contract.ABI.UnpackIntoInterface(output, method, rawData) + }, } - caller.calls = append(caller.calls, call) + caller.Calls = append(caller.Calls, call) return nil } func (caller *MultiCaller) Execute(requireSuccess bool, opts *bind.CallOpts) ([]CallResponse, error) { - var multiCalls = make([]MultiCall, 0, len(caller.calls)) - for _, call := range caller.calls { + var multiCalls = make([]MultiCall, 0, len(caller.Calls)) + for _, call := range caller.Calls { multiCalls = append(multiCalls, call.GetMultiCall()) } callData, err := caller.ABI.Pack("tryAggregate", requireSuccess, multiCalls) @@ -97,12 +93,11 @@ func (caller *MultiCaller) Execute(requireSuccess bool, opts *bind.CallOpts) ([] return nil, err } - results := make([]CallResponse, len(caller.calls)) + results := make([]CallResponse, len(caller.Calls)) for i, response := range responses[0].([]struct { Success bool `json:"success"` ReturnData []byte `json:"returnData"` }) { - results[i].Method = caller.calls[i].Method results[i].ReturnDataRaw = response.ReturnData results[i].Status = response.Success } @@ -110,24 +105,23 @@ func (caller *MultiCaller) Execute(requireSuccess bool, opts *bind.CallOpts) ([] } func (caller *MultiCaller) FlexibleCall(requireSuccess bool, opts *bind.CallOpts) ([]Result, error) { - res := make([]Result, len(caller.calls)) + res := make([]Result, len(caller.Calls)) results, err := caller.Execute(requireSuccess, opts) if err != nil { - caller.calls = []Call{} + caller.Calls = []Call{} return nil, err } - for i, call := range caller.calls { + for i, call := range caller.Calls { callSuccess := results[i].Status if callSuccess { - err := call.Contract.ABI.UnpackIntoInterface(call.output, call.Method, results[i].ReturnDataRaw) + err := call.UnpackFunc(results[i].ReturnDataRaw) if err != nil { - caller.calls = []Call{} + caller.Calls = []Call{} return nil, err } } res[i].Success = callSuccess - res[i].Output = call.output } - caller.calls = []Call{} + caller.Calls = []Call{} return res, err }