From 7f3874c58e6f35262b20b539153d7ffaad25aa29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Migone?= Date: Wed, 27 Nov 2024 17:26:27 -0300 Subject: [PATCH] fix: allow partially collecting RAVs (TRST-M05) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tomás Migone --- .../contracts/interfaces/ITAPCollector.sol | 25 ++++ .../payments/collectors/TAPCollector.sol | 119 ++++++++++-------- .../payments/tap-collector/TAPCollector.t.sol | 18 ++- .../tap-collector/collect/collect.t.sol | 56 ++++++++- 4 files changed, 162 insertions(+), 56 deletions(-) diff --git a/packages/horizon/contracts/interfaces/ITAPCollector.sol b/packages/horizon/contracts/interfaces/ITAPCollector.sol index b364135c3..347ccf565 100644 --- a/packages/horizon/contracts/interfaces/ITAPCollector.sol +++ b/packages/horizon/contracts/interfaces/ITAPCollector.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.27; import { IPaymentsCollector } from "./IPaymentsCollector.sol"; +import { IGraphPayments } from "./IGraphPayments.sol"; /** * @title Interface for the {TAPCollector} contract @@ -175,6 +176,13 @@ interface ITAPCollector is IPaymentsCollector { */ error TAPCollectorInconsistentRAVTokens(uint256 tokens, uint256 tokensCollected); + /** + * Thrown when the attempting to collect more tokens than what it's owed + * @param tokensToCollect The amount of tokens to collect + * @param maxTokensToCollect The maximum amount of tokens to collect + */ + error TAPCollectorInvalidTokensToCollectAmount(uint256 tokensToCollect, uint256 maxTokensToCollect); + /** * @notice Authorize a signer to sign on behalf of the payer. * A signer can not be authorized for multiple payers even after revoking previous authorizations. @@ -237,4 +245,21 @@ interface ITAPCollector is IPaymentsCollector { * @return The hash of the RAV. */ function encodeRAV(ReceiptAggregateVoucher calldata rav) external view returns (bytes32); + + /** + * @notice See {IPaymentsCollector.collect} + * This variant adds the ability to partially collect a RAV by specifying the amount of tokens to collect. + * + * Requirements: + * - The amount of tokens to collect must be less than or equal to the total amount of tokens in the RAV minus + * the tokens already collected. + * @param paymentType The payment type to collect + * @param data Additional data required for the payment collection + * @param tokensToCollect The amount of tokens to collect + */ + function collect( + IGraphPayments.PaymentTypes paymentType, + bytes calldata data, + uint256 tokensToCollect + ) external returns (uint256); } diff --git a/packages/horizon/contracts/payments/collectors/TAPCollector.sol b/packages/horizon/contracts/payments/collectors/TAPCollector.sol index 01fbb9716..ebaa1f33d 100644 --- a/packages/horizon/contracts/payments/collectors/TAPCollector.sol +++ b/packages/horizon/contracts/payments/collectors/TAPCollector.sol @@ -125,37 +125,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { * @notice REVERT: This function may revert if ECDSA.recover fails, check ECDSA library for details. */ function collect(IGraphPayments.PaymentTypes paymentType, bytes memory data) external override returns (uint256) { - // Ensure caller is the RAV data service - (SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(data, (SignedRAV, uint256)); - require( - signedRAV.rav.dataService == msg.sender, - TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService) - ); - - // Ensure RAV signer is authorized for a payer - address signer = _recoverRAVSigner(signedRAV); - require( - authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked, - TAPCollectorInvalidRAVSigner() - ); - - // Ensure RAV payer matches the authorized payer - address payer = signedRAV.rav.payer; - require( - authorizedSigners[signer].payer == payer, - TAPCollectorInvalidRAVPayer(authorizedSigners[signer].payer, payer) - ); - - // Check the service provider has an active provision with the data service - // This prevents an attack where the payer can deny the service provider from collecting payments - // by using a signer as data service to syphon off the tokens in the escrow to an account they control - uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable( - signedRAV.rav.serviceProvider, - signedRAV.rav.dataService - ); - require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService)); + return _collect(paymentType, data, 0); + } - return _collect(paymentType, authorizedSigners[signer].payer, signedRAV, dataServiceCut); + function collect( + IGraphPayments.PaymentTypes paymentType, + bytes memory data, + uint256 tokensToCollect + ) external override returns (uint256) { + return _collect(paymentType, data, tokensToCollect); } /** @@ -177,28 +155,71 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { */ function _collect( IGraphPayments.PaymentTypes _paymentType, - address _payer, - SignedRAV memory _signedRAV, - uint256 _dataServiceCut + bytes memory _data, + uint256 _tokensToCollect ) private returns (uint256) { - address dataService = _signedRAV.rav.dataService; - address receiver = _signedRAV.rav.serviceProvider; + // Ensure caller is the RAV data service + (SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (SignedRAV, uint256)); + require( + signedRAV.rav.dataService == msg.sender, + TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService) + ); + + // Ensure RAV signer is authorized for a payer + address signer = _recoverRAVSigner(signedRAV); + require( + authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked, + TAPCollectorInvalidRAVSigner() + ); - uint256 tokensRAV = _signedRAV.rav.valueAggregate; - uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][_payer]; + // Ensure RAV payer matches the authorized payer + address payer = authorizedSigners[signer].payer; require( - tokensRAV > tokensAlreadyCollected, - TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected) + signedRAV.rav.payer == payer, + TAPCollectorInvalidRAVPayer(payer, signedRAV.rav.payer) ); - uint256 tokensToCollect = tokensRAV - tokensAlreadyCollected; - uint256 tokensDataService = tokensToCollect.mulPPM(_dataServiceCut); + address dataService = signedRAV.rav.dataService; + address receiver = signedRAV.rav.serviceProvider; + + // Check the service provider has an active provision with the data service + // This prevents an attack where the payer can deny the service provider from collecting payments + // by using a signer as data service to syphon off the tokens in the escrow to an account they control + { + uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable( + signedRAV.rav.serviceProvider, + signedRAV.rav.dataService + ); + require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService)); + } + + uint256 tokensToCollect = 0; + { + uint256 tokensRAV = signedRAV.rav.valueAggregate; + uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][payer]; + require( + tokensRAV > tokensAlreadyCollected, + TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected) + ); + + if (_tokensToCollect == 0) { + tokensToCollect = tokensRAV - tokensAlreadyCollected; + } else { + require( + _tokensToCollect <= tokensRAV - tokensAlreadyCollected, + TAPCollectorInvalidTokensToCollectAmount(_tokensToCollect, tokensRAV - tokensAlreadyCollected) + ); + tokensToCollect = _tokensToCollect; + } + } + + uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut); if (tokensToCollect > 0) { - tokensCollected[dataService][receiver][_payer] = tokensRAV; + tokensCollected[dataService][receiver][payer] += tokensToCollect; _graphPaymentsEscrow().collect( _paymentType, - _payer, + payer, receiver, tokensToCollect, dataService, @@ -206,15 +227,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector { ); } - emit PaymentCollected(_paymentType, _payer, receiver, tokensToCollect, dataService, tokensDataService); + emit PaymentCollected(_paymentType, payer, receiver, tokensToCollect, dataService, tokensDataService); emit RAVCollected( - _payer, + payer, dataService, receiver, - _signedRAV.rav.timestampNs, - _signedRAV.rav.valueAggregate, - _signedRAV.rav.metadata, - _signedRAV.signature + signedRAV.rav.timestampNs, + signedRAV.rav.valueAggregate, + signedRAV.rav.metadata, + signedRAV.signature ); return tokensToCollect; } diff --git a/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol b/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol index 1120c5b92..ac67d6552 100644 --- a/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol +++ b/packages/horizon/test/payments/tap-collector/TAPCollector.t.sol @@ -119,12 +119,20 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest } function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data) internal { + __collect(_paymentType, _data, 0); + } + + function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal { + __collect(_paymentType, _data, _tokensToCollect); + } + + function __collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal { (ITAPCollector.SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (ITAPCollector.SignedRAV, uint256)); bytes32 messageHash = tapCollector.encodeRAV(signedRAV.rav); address _signer = ECDSA.recover(messageHash, signedRAV.signature); (address _payer, , ) = tapCollector.authorizedSigners(_signer); uint256 tokensAlreadyCollected = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer); - uint256 tokensToCollect = signedRAV.rav.valueAggregate - tokensAlreadyCollected; + uint256 tokensToCollect = _tokensToCollect == 0 ? signedRAV.rav.valueAggregate - tokensAlreadyCollected : _tokensToCollect; uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut); vm.expectEmit(address(tapCollector)); @@ -136,6 +144,7 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest signedRAV.rav.dataService, tokensDataService ); + vm.expectEmit(address(tapCollector)); emit ITAPCollector.RAVCollected( _payer, signedRAV.rav.dataService, @@ -145,11 +154,10 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest signedRAV.rav.metadata, signedRAV.signature ); - - uint256 tokensCollected = tapCollector.collect(_paymentType, _data); - assertEq(tokensCollected, tokensToCollect); + uint256 tokensCollected = _tokensToCollect == 0 ? tapCollector.collect(_paymentType, _data) : tapCollector.collect(_paymentType, _data, _tokensToCollect); uint256 tokensCollectedAfter = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer); - assertEq(tokensCollectedAfter, signedRAV.rav.valueAggregate); + assertEq(tokensCollected, tokensToCollect); + assertEq(tokensCollectedAfter, _tokensToCollect == 0 ? signedRAV.rav.valueAggregate : tokensAlreadyCollected + _tokensToCollect); } } diff --git a/packages/horizon/test/payments/tap-collector/collect/collect.t.sol b/packages/horizon/test/payments/tap-collector/collect/collect.t.sol index ddb76b919..a4e1eafa7 100644 --- a/packages/horizon/test/payments/tap-collector/collect/collect.t.sol +++ b/packages/horizon/test/payments/tap-collector/collect/collect.t.sol @@ -203,11 +203,12 @@ contract TAPCollectorCollectTest is TAPCollectorTest { tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data); } - function testTAPCollector_Collect_RevertWhen_PayerMismatch(uint256 tokens) public useGateway useSigner { + function testTAPCollector_Collect_RevertWhen_PayerMismatch( + uint256 tokens + ) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner { tokens = bound(tokens, 1, type(uint128).max); resetPrank(users.gateway); - _approveCollector(address(tapCollector), tokens); _depositTokens(address(tapCollector), users.indexer, tokens); (address anotherPayer, ) = makeAddrAndKey("anotherPayer"); @@ -340,4 +341,55 @@ contract TAPCollectorCollectTest is TAPCollectorTest { resetPrank(users.verifier); _collect(IGraphPayments.PaymentTypes.QueryFee, data); } + + function testTAPCollector_CollectPartial( + uint256 tokens, + uint256 tokensToCollect + ) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner { + tokens = bound(tokens, 1, type(uint128).max); + tokensToCollect = bound(tokensToCollect, 1, tokens); + + _depositTokens(address(tapCollector), users.indexer, tokens); + + bytes memory data = _getQueryFeeEncodedData( + signerPrivateKey, + users.gateway, + users.indexer, + users.verifier, + uint128(tokens) + ); + + resetPrank(users.verifier); + _collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect); + } + + function testTAPCollector_CollectPartial_RevertWhen_AmountTooHigh( + uint256 tokens, + uint256 tokensToCollect + ) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner { + tokens = bound(tokens, 1, type(uint128).max - 1); + + _depositTokens(address(tapCollector), users.indexer, tokens); + + bytes memory data = _getQueryFeeEncodedData( + signerPrivateKey, + users.gateway, + users.indexer, + users.verifier, + uint128(tokens) + ); + + resetPrank(users.verifier); + uint256 tokensAlreadyCollected = tapCollector.tokensCollected(users.verifier, users.indexer, users.gateway); + tokensToCollect = bound(tokensToCollect, tokens - tokensAlreadyCollected + 1, type(uint128).max); + + vm.expectRevert( + abi.encodeWithSelector( + ITAPCollector.TAPCollectorInvalidTokensToCollectAmount.selector, + tokensToCollect, + tokens - tokensAlreadyCollected + ) + ); + tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect); + } }