From 40d1217191ee58e4f2e48172069eb8ba225e145f Mon Sep 17 00:00:00 2001 From: Dan Laine Date: Fri, 29 Sep 2023 08:09:54 -0400 Subject: [PATCH] Signature Aggregation Refactor (#883) * remove unused code * remove signature aggregation job * appease linter * remove signature aggregation job * move all signature aggregeation logic to AggregateSignatures * nit * nit * nit * add mock and tests * clean up tests * nit * don't autogen SignatureGetter because it requires manual changes because gomock is broken * nit * comment * comments * lower log level * re-add todo * nit * refactor aggregeateSignatures * comment nits * add tests * use error from avalancheWap; combine functions; nits * add logs * typo fix --- plugin/evm/vm.go | 2 +- warp/aggregator/aggregation_job.go | 170 -------- warp/aggregator/aggregation_job_test.go | 355 ---------------- warp/aggregator/aggregator.go | 188 +++++++- warp/aggregator/aggregator_test.go | 425 +++++++++++++++++++ warp/aggregator/mock_signature_getter.go | 53 +++ warp/aggregator/network_signature_backend.go | 6 +- warp/aggregator/signature_job.go | 60 --- warp/aggregator/signature_job_test.go | 140 ------ 9 files changed, 654 insertions(+), 745 deletions(-) delete mode 100644 warp/aggregator/aggregation_job.go delete mode 100644 warp/aggregator/aggregation_job_test.go create mode 100644 warp/aggregator/aggregator_test.go create mode 100644 warp/aggregator/mock_signature_getter.go delete mode 100644 warp/aggregator/signature_job.go delete mode 100644 warp/aggregator/signature_job_test.go diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index 304551988d..a6ebed2339 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -943,7 +943,7 @@ func (vm *VM) CreateHandlers(context.Context) (map[string]*commonEng.HTTPHandler } if vm.config.WarpAPIEnabled { - warpAggregator := aggregator.NewAggregator(vm.ctx.SubnetID, warpValidators.NewState(vm.ctx), &aggregator.NetworkSigner{Client: vm.client}) + warpAggregator := aggregator.New(vm.ctx.SubnetID, warpValidators.NewState(vm.ctx), &aggregator.NetworkSigner{Client: vm.client}) if err := handler.RegisterName("warp", warp.NewAPI(vm.warpBackend, warpAggregator)); err != nil { return nil, err } diff --git a/warp/aggregator/aggregation_job.go b/warp/aggregator/aggregation_job.go deleted file mode 100644 index 70c18c72ff..0000000000 --- a/warp/aggregator/aggregation_job.go +++ /dev/null @@ -1,170 +0,0 @@ -// (c) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package aggregator - -import ( - "context" - "fmt" - "sync" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/crypto/bls" - "github.com/ava-labs/avalanchego/utils/set" - avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/log" -) - -// signatureAggregationJob fetches signatures for a single unsigned warp message. -type signatureAggregationJob struct { - // SignatureBackend is assumed to be thread-safe and may be used by multiple signature aggregation jobs concurrently - client SignatureBackend - height uint64 - subnetID ids.ID - - // Minimum threshold at which to return the resulting aggregate signature. If this threshold is not reached, - // return an error instead of aggregating the signatures that were fetched. - minValidQuorumNum uint64 - // Threshold at which to cancel fetching further signatures - maxNeededQuorumNum uint64 - // Denominator to use when checking if we've reached the threshold - quorumDen uint64 - state validators.State - msg *avalancheWarp.UnsignedMessage -} - -type AggregateSignatureResult struct { - SignatureWeight uint64 - TotalWeight uint64 - Message *avalancheWarp.Message -} - -func newSignatureAggregationJob( - client SignatureBackend, - height uint64, - subnetID ids.ID, - minValidQuorumNum uint64, - maxNeededQuorumNum uint64, - quorumDen uint64, - state validators.State, - msg *avalancheWarp.UnsignedMessage, -) *signatureAggregationJob { - return &signatureAggregationJob{ - client: client, - height: height, - subnetID: subnetID, - minValidQuorumNum: minValidQuorumNum, - maxNeededQuorumNum: maxNeededQuorumNum, - quorumDen: quorumDen, - state: state, - msg: msg, - } -} - -// Execute aggregates signatures for the requested message -func (a *signatureAggregationJob) Execute(ctx context.Context) (*AggregateSignatureResult, error) { - log.Info("Fetching signature", "subnetID", a.subnetID, "height", a.height) - validators, totalWeight, err := avalancheWarp.GetCanonicalValidatorSet(ctx, a.state, a.height, a.subnetID) - if err != nil { - return nil, fmt.Errorf("failed to get validator set: %w", err) - } - if len(validators) == 0 { - return nil, fmt.Errorf("cannot aggregate signatures from subnet with no validators (SubnetID: %s, Height: %d)", a.subnetID, a.height) - } - - signatureJobs := make([]*signatureJob, len(validators)) - for i, validator := range validators { - signatureJobs[i] = newSignatureJob(a.client, validator, a.msg) - } - - var ( - // [signatureLock] must be held when accessing [blsSignatures], [bitSet], or [signatureWeight] - // in the goroutine below. - signatureLock sync.Mutex - blsSignatures = make([]*bls.Signature, 0, len(signatureJobs)) - bitSet = set.NewBits() - signatureWeight = uint64(0) - ) - - // Create a child context to cancel signature fetching if we reach [maxNeededQuorumNum] threshold - signatureFetchCtx, signatureFetchCancel := context.WithCancel(ctx) - defer signatureFetchCancel() - - wg := sync.WaitGroup{} - wg.Add(len(signatureJobs)) - for i, signatureJob := range signatureJobs { - i := i - signatureJob := signatureJob - go func() { - defer wg.Done() - - log.Info("Fetching warp signature", - "nodeID", signatureJob.nodeID, - "index", i, - ) - - blsSignature, err := signatureJob.Execute(signatureFetchCtx) - if err != nil { - log.Info("Failed to fetch signature at index %d: %s", i, signatureJob) - return - } - log.Info("Retrieved warp signature", - "nodeID", signatureJob.nodeID, - "index", i, - "signature", hexutil.Bytes(bls.SignatureToBytes(blsSignature)), - ) - - // Add the signature and check if we've reached the requested threshold - signatureLock.Lock() - defer signatureLock.Unlock() - - blsSignatures = append(blsSignatures, blsSignature) - bitSet.Add(i) - signatureWeight += signatureJob.weight - log.Info("Updated weight", - "totalWeight", signatureWeight, - "addedWeight", signatureJob.weight, - ) - - // If the signature weight meets the requested threshold, cancel signature fetching - if err := avalancheWarp.VerifyWeight(signatureWeight, totalWeight, a.maxNeededQuorumNum, a.quorumDen); err == nil { - log.Info("Verify weight passed, exiting aggregation early", - "maxNeededQuorumNum", a.maxNeededQuorumNum, - "totalWeight", totalWeight, - "signatureWeight", signatureWeight, - ) - signatureFetchCancel() - } - }() - } - wg.Wait() - - // If I failed to fetch sufficient signature stake, return an error - if err := avalancheWarp.VerifyWeight(signatureWeight, totalWeight, a.minValidQuorumNum, a.quorumDen); err != nil { - return nil, fmt.Errorf("failed to aggregate signature: %w", err) - } - - // Otherwise, return the aggregate signature - aggregateSignature, err := bls.AggregateSignatures(blsSignatures) - if err != nil { - return nil, fmt.Errorf("failed to aggregate BLS signatures: %w", err) - } - - warpSignature := &avalancheWarp.BitSetSignature{ - Signers: bitSet.Bytes(), - } - copy(warpSignature.Signature[:], bls.SignatureToBytes(aggregateSignature)) - - msg, err := avalancheWarp.NewMessage(a.msg, warpSignature) - if err != nil { - return nil, fmt.Errorf("failed to construct warp message: %w", err) - } - - return &AggregateSignatureResult{ - Message: msg, - SignatureWeight: signatureWeight, - TotalWeight: totalWeight, - }, nil -} diff --git a/warp/aggregator/aggregation_job_test.go b/warp/aggregator/aggregation_job_test.go deleted file mode 100644 index a184138489..0000000000 --- a/warp/aggregator/aggregation_job_test.go +++ /dev/null @@ -1,355 +0,0 @@ -// (c) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package aggregator - -import ( - "context" - "errors" - "testing" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/crypto/bls" - "github.com/ava-labs/avalanchego/utils/set" - avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" - "github.com/stretchr/testify/require" -) - -var ( - subnetID = ids.GenerateTestID() - pChainHeight = uint64(10) - getSubnetIDF = func(ctx context.Context, chainID ids.ID) (ids.ID, error) { return subnetID, nil } - getCurrentHeightF = func(ctx context.Context) (uint64, error) { return pChainHeight, nil } -) - -type signatureAggregationTest struct { - ctx context.Context - job *signatureAggregationJob - expectedRes *AggregateSignatureResult - expectedErr error -} - -func executeSignatureAggregationTest(t testing.TB, test signatureAggregationTest) { - t.Helper() - - res, err := test.job.Execute(test.ctx) - if test.expectedErr != nil { - require.ErrorIs(t, err, test.expectedErr) - return - } - - require.Equal(t, res.SignatureWeight, test.expectedRes.SignatureWeight) - require.Equal(t, res.TotalWeight, test.expectedRes.TotalWeight) - require.NoError(t, res.Message.Signature.Verify( - context.Background(), - &res.Message.UnsignedMessage, - networkID, - test.job.state, - pChainHeight, - test.job.minValidQuorumNum, - test.job.quorumDen, - )) -} - -func TestSingleSignatureAggregator(t *testing.T) { - ctx := context.Background() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(context.Context, ids.NodeID, *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - return blsSignatures[0], nil - }, - }, - pChainHeight, - subnetID, - 100, - 100, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - return map[ids.NodeID]*validators.GetValidatorOutput{ - nodeIDs[0]: { - NodeID: nodeIDs[0], - PublicKey: blsPublicKeys[0], - Weight: 100, - }, - }, nil - }, - }, - unsignedMsg, - ) - - signature := &avalancheWarp.BitSetSignature{ - Signers: set.NewBits(0).Bytes(), - } - signedMessage, err := avalancheWarp.NewMessage(unsignedMsg, signature) - require.NoError(t, err) - copy(signature.Signature[:], bls.SignatureToBytes(blsSignatures[0])) - expectedRes := &AggregateSignatureResult{ - SignatureWeight: 100, - TotalWeight: 100, - Message: signedMessage, - } - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedRes: expectedRes, - }) -} - -func TestAggregateAllSignatures(t *testing.T) { - ctx := context.Background() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(_ context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - for i, matchingNodeID := range nodeIDs { - if matchingNodeID == nodeID { - return blsSignatures[i], nil - } - } - panic("request to unexpected nodeID") - }, - }, - pChainHeight, - subnetID, - 100, - 100, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - res := make(map[ids.NodeID]*validators.GetValidatorOutput) - for i := 0; i < 5; i++ { - res[nodeIDs[i]] = &validators.GetValidatorOutput{ - NodeID: nodeIDs[i], - PublicKey: blsPublicKeys[i], - Weight: 100, - } - } - return res, nil - }, - }, - unsignedMsg, - ) - - signature := &avalancheWarp.BitSetSignature{ - Signers: set.NewBits(0, 1, 2, 3, 4).Bytes(), - } - signedMessage, err := avalancheWarp.NewMessage(unsignedMsg, signature) - require.NoError(t, err) - aggregateSignature, err := bls.AggregateSignatures(blsSignatures) - require.NoError(t, err) - copy(signature.Signature[:], bls.SignatureToBytes(aggregateSignature)) - expectedRes := &AggregateSignatureResult{ - SignatureWeight: 500, - TotalWeight: 500, - Message: signedMessage, - } - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedRes: expectedRes, - }) -} - -func TestAggregateThresholdSignatures(t *testing.T) { - ctx := context.Background() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(_ context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - for i, matchingNodeID := range nodeIDs[:3] { - if matchingNodeID == nodeID { - return blsSignatures[i], nil - } - } - return nil, errors.New("what do we say to the god of death") - }, - }, - pChainHeight, - subnetID, - 60, - 60, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - res := make(map[ids.NodeID]*validators.GetValidatorOutput) - for i := 0; i < 5; i++ { - res[nodeIDs[i]] = &validators.GetValidatorOutput{ - NodeID: nodeIDs[i], - PublicKey: blsPublicKeys[i], - Weight: 100, - } - } - return res, nil - }, - }, - unsignedMsg, - ) - - signature := &avalancheWarp.BitSetSignature{ - Signers: set.NewBits(0, 1, 2).Bytes(), - } - signedMessage, err := avalancheWarp.NewMessage(unsignedMsg, signature) - require.NoError(t, err) - aggregateSignature, err := bls.AggregateSignatures(blsSignatures) - require.NoError(t, err) - copy(signature.Signature[:], bls.SignatureToBytes(aggregateSignature)) - expectedRes := &AggregateSignatureResult{ - SignatureWeight: 300, - TotalWeight: 500, - Message: signedMessage, - } - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedRes: expectedRes, - }) -} - -func TestAggregateThresholdSignaturesInsufficientWeight(t *testing.T) { - ctx := context.Background() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(_ context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - for i, matchingNodeID := range nodeIDs[:3] { - if matchingNodeID == nodeID { - return blsSignatures[i], nil - } - } - return nil, errors.New("what do we say to the god of death") - }, - }, - pChainHeight, - subnetID, - 80, - 80, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - res := make(map[ids.NodeID]*validators.GetValidatorOutput) - for i := 0; i < 5; i++ { - res[nodeIDs[i]] = &validators.GetValidatorOutput{ - NodeID: nodeIDs[i], - PublicKey: blsPublicKeys[i], - Weight: 100, - } - } - return res, nil - }, - }, - unsignedMsg, - ) - - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedErr: avalancheWarp.ErrInsufficientWeight, - }) -} - -func TestAggregateThresholdSignaturesBlockingRequests(t *testing.T) { - ctx := context.Background() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(ctx context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - for i, matchingNodeID := range nodeIDs[:3] { - if matchingNodeID == nodeID { - return blsSignatures[i], nil - } - } - - // Block until the context is cancelled and return the error if not available - <-ctx.Done() - return nil, ctx.Err() - }, - }, - pChainHeight, - subnetID, - 60, - 60, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - res := make(map[ids.NodeID]*validators.GetValidatorOutput) - for i := 0; i < 5; i++ { - res[nodeIDs[i]] = &validators.GetValidatorOutput{ - NodeID: nodeIDs[i], - PublicKey: blsPublicKeys[i], - Weight: 100, - } - } - return res, nil - }, - }, - unsignedMsg, - ) - - signature := &avalancheWarp.BitSetSignature{ - Signers: set.NewBits(0, 1, 2).Bytes(), - } - signedMessage, err := avalancheWarp.NewMessage(unsignedMsg, signature) - require.NoError(t, err) - aggregateSignature, err := bls.AggregateSignatures(blsSignatures) - require.NoError(t, err) - copy(signature.Signature[:], bls.SignatureToBytes(aggregateSignature)) - expectedRes := &AggregateSignatureResult{ - SignatureWeight: 300, - TotalWeight: 500, - Message: signedMessage, - } - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedRes: expectedRes, - }) -} - -func TestAggregateThresholdSignaturesParentCtxCancels(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - aggregationJob := newSignatureAggregationJob( - &mockFetcher{ - fetch: func(ctx context.Context, nodeID ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - // Block until the context is cancelled and return the error if not available - <-ctx.Done() - return nil, ctx.Err() - }, - }, - pChainHeight, - subnetID, - 60, - 60, - 100, - &validators.TestState{ - GetSubnetIDF: getSubnetIDF, - GetCurrentHeightF: getCurrentHeightF, - GetValidatorSetF: func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - res := make(map[ids.NodeID]*validators.GetValidatorOutput) - for i := 0; i < 5; i++ { - res[nodeIDs[i]] = &validators.GetValidatorOutput{ - NodeID: nodeIDs[i], - PublicKey: blsPublicKeys[i], - Weight: 100, - } - } - return res, nil - }, - }, - unsignedMsg, - ) - - executeSignatureAggregationTest(t, signatureAggregationTest{ - ctx: ctx, - job: aggregationJob, - expectedErr: avalancheWarp.ErrInsufficientWeight, - }) -} diff --git a/warp/aggregator/aggregator.go b/warp/aggregator/aggregator.go index 3530d4a700..8633205373 100644 --- a/warp/aggregator/aggregator.go +++ b/warp/aggregator/aggregator.go @@ -5,22 +5,51 @@ package aggregator import ( "context" + "errors" + "fmt" + + "github.com/ava-labs/subnet-evm/params" + + "github.com/ethereum/go-ethereum/log" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/set" avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" - "github.com/ava-labs/subnet-evm/params" ) -// Aggregator fulfills requests to aggregate signatures of a subnet's validator set for Avalanche Warp Messages. +var errNoValidators = errors.New("cannot aggregate signatures from subnet with no validators") + +// SignatureGetter defines the minimum network interface to perform signature aggregation +type SignatureGetter interface { + // GetSignature attempts to fetch a BLS Signature from [nodeID] for [unsignedWarpMessage] + GetSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) +} + +type AggregateSignatureResult struct { + // Weight of validators included in the aggregate signature. + SignatureWeight uint64 + // Total weight of all validators in the subnet. + TotalWeight uint64 + // The message with the aggregate signature. + Message *avalancheWarp.Message +} + +// Aggregator requests signatures from validators and +// aggregates them into a single signature. type Aggregator struct { + // Aggregating signatures for a chain validated by this subnet. subnetID ids.ID - client SignatureBackend - state validators.State + // Fetches signatures from validators. + client SignatureGetter + // Validator state for this chain. + state validators.State } -// NewAggregator returns a signature aggregator, which will aggregate Warp Signatures for the given [ -func NewAggregator(subnetID ids.ID, state validators.State, client SignatureBackend) *Aggregator { +// New returns a signature aggregator for the chain with the given [state] on the +// given [subnetID], and where [client] can be used to fetch signatures from validators. +func New(subnetID ids.ID, state validators.State, client SignatureGetter) *Aggregator { return &Aggregator{ subnetID: subnetID, client: client, @@ -28,6 +57,8 @@ func NewAggregator(subnetID ids.ID, state validators.State, client SignatureBack } } +// Returns an aggregate signature over [unsignedMessage]. +// The returned signature's weight exceeds the threshold given by [quorumNum]. func (a *Aggregator) AggregateSignatures(ctx context.Context, unsignedMessage *avalancheWarp.UnsignedMessage, quorumNum uint64) (*AggregateSignatureResult, error) { // Note: we use the current height as a best guess of the canonical validator set when the aggregated signature will be verified // by the recipient chain. If the validator set changes from [pChainHeight] to the P-Chain height that is actually specified by the @@ -36,16 +67,141 @@ func (a *Aggregator) AggregateSignatures(ctx context.Context, unsignedMessage *a if err != nil { return nil, err } - job := newSignatureAggregationJob( - a.client, - pChainHeight, - a.subnetID, - quorumNum, - quorumNum, - params.WarpQuorumDenominator, - a.state, - unsignedMessage, + + log.Debug("Fetching signature", + "a.subnetID", a.subnetID, + "height", pChainHeight, + ) + validators, totalWeight, err := avalancheWarp.GetCanonicalValidatorSet(ctx, a.state, pChainHeight, a.subnetID) + if err != nil { + return nil, fmt.Errorf("failed to get validator set: %w", err) + } + if len(validators) == 0 { + return nil, fmt.Errorf("%w (SubnetID: %s, Height: %d)", errNoValidators, a.subnetID, pChainHeight) + } + + type signatureFetchResult struct { + sig *bls.Signature + index int + weight uint64 + } + + // Create a child context to cancel signature fetching if we reach signature threshold. + signatureFetchCtx, signatureFetchCancel := context.WithCancel(ctx) + defer signatureFetchCancel() + + // Fetch signatures from validators concurrently. + signatureFetchResultChan := make(chan *signatureFetchResult) + for i, validator := range validators { + var ( + i = i + validator = validator + // TODO: update from a single nodeID to the original slice and use extra nodeIDs as backup. + nodeID = validator.NodeIDs[0] + ) + go func() { + log.Debug("Fetching warp signature", + "nodeID", nodeID, + "index", i, + "msgID", unsignedMessage.ID(), + ) + + signature, err := a.client.GetSignature(signatureFetchCtx, nodeID, unsignedMessage) + if err != nil { + log.Debug("Failed to fetch warp signature", + "nodeID", nodeID, + "index", i, + "err", err, + "msgID", unsignedMessage.ID(), + ) + signatureFetchResultChan <- nil + return + } + + log.Debug("Retrieved warp signature", + "nodeID", nodeID, + "msgID", unsignedMessage.ID(), + "index", i, + ) + + if !bls.Verify(validator.PublicKey, signature, unsignedMessage.Bytes()) { + log.Debug("Failed to verify warp signature", + "nodeID", nodeID, + "index", i, + "msgID", unsignedMessage.ID(), + ) + signatureFetchResultChan <- nil + return + } + + signatureFetchResultChan <- &signatureFetchResult{ + sig: signature, + index: i, + weight: validator.Weight, + } + }() + } + + var ( + signatures = make([]*bls.Signature, 0, len(validators)) + signersBitset = set.NewBits() + signaturesWeight = uint64(0) + signaturesPassedThreshold = false ) - return job.Execute(ctx) + for i := 0; i < len(validators); i++ { + signatureFetchResult := <-signatureFetchResultChan + if signatureFetchResult == nil { + continue + } + + signatures = append(signatures, signatureFetchResult.sig) + signersBitset.Add(signatureFetchResult.index) + signaturesWeight += signatureFetchResult.weight + log.Debug("Updated weight", + "totalWeight", signaturesWeight, + "addedWeight", signatureFetchResult.weight, + "msgID", unsignedMessage.ID(), + ) + + // If the signature weight meets the requested threshold, cancel signature fetching + if err := avalancheWarp.VerifyWeight(signaturesWeight, totalWeight, quorumNum, params.WarpQuorumDenominator); err == nil { + log.Debug("Verify weight passed, exiting aggregation early", + "quorumNum", quorumNum, + "totalWeight", totalWeight, + "signatureWeight", signaturesWeight, + "msgID", unsignedMessage.ID(), + ) + signatureFetchCancel() + signaturesPassedThreshold = true + break + } + } + + // If I failed to fetch sufficient signature stake, return an error + if !signaturesPassedThreshold { + return nil, avalancheWarp.ErrInsufficientWeight + } + + // Otherwise, return the aggregate signature + aggregateSignature, err := bls.AggregateSignatures(signatures) + if err != nil { + return nil, fmt.Errorf("failed to aggregate BLS signatures: %w", err) + } + + warpSignature := &avalancheWarp.BitSetSignature{ + Signers: signersBitset.Bytes(), + } + copy(warpSignature.Signature[:], bls.SignatureToBytes(aggregateSignature)) + + msg, err := avalancheWarp.NewMessage(unsignedMessage, warpSignature) + if err != nil { + return nil, fmt.Errorf("failed to construct warp message: %w", err) + } + + return &AggregateSignatureResult{ + Message: msg, + SignatureWeight: signaturesWeight, + TotalWeight: totalWeight, + }, nil } diff --git a/warp/aggregator/aggregator_test.go b/warp/aggregator/aggregator_test.go new file mode 100644 index 0000000000..01225f89c1 --- /dev/null +++ b/warp/aggregator/aggregator_test.go @@ -0,0 +1,425 @@ +// (c) 2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "go.uber.org/mock/gomock" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" +) + +func newValidator(t testing.TB, weight uint64) (*bls.SecretKey, *avalancheWarp.Validator) { + sk, err := bls.NewSecretKey() + require.NoError(t, err) + pk := bls.PublicFromSecretKey(sk) + return sk, &avalancheWarp.Validator{ + PublicKey: pk, + PublicKeyBytes: bls.PublicKeyToBytes(pk), + Weight: weight, + NodeIDs: []ids.NodeID{ids.GenerateTestNodeID()}, + } +} + +func TestAggregateSignatures(t *testing.T) { + subnetID := ids.GenerateTestID() + errTest := errors.New("test error") + pChainHeight := uint64(1337) + unsignedMsg := &avalancheWarp.UnsignedMessage{ + NetworkID: 1338, + SourceChainID: ids.ID{'y', 'e', 'e', 't'}, + Payload: []byte("hello world"), + } + require.NoError(t, unsignedMsg.Initialize()) + + nodeID1, nodeID2, nodeID3 := ids.GenerateTestNodeID(), ids.GenerateTestNodeID(), ids.GenerateTestNodeID() + vdrWeight := uint64(10001) + vdr1sk, vdr1 := newValidator(t, vdrWeight) + vdr2sk, vdr2 := newValidator(t, vdrWeight+1) + vdr3sk, vdr3 := newValidator(t, vdrWeight-1) + sig1 := bls.Sign(vdr1sk, unsignedMsg.Bytes()) + sig2 := bls.Sign(vdr2sk, unsignedMsg.Bytes()) + sig3 := bls.Sign(vdr3sk, unsignedMsg.Bytes()) + vdrToSig := map[*avalancheWarp.Validator]*bls.Signature{ + vdr1: sig1, + vdr2: sig2, + vdr3: sig3, + } + nonVdrSk, err := bls.NewSecretKey() + require.NoError(t, err) + nonVdrSig := bls.Sign(nonVdrSk, unsignedMsg.Bytes()) + vdrSet := map[ids.NodeID]*validators.GetValidatorOutput{ + nodeID1: { + NodeID: nodeID1, + PublicKey: vdr1.PublicKey, + Weight: vdr1.Weight, + }, + nodeID2: { + NodeID: nodeID2, + PublicKey: vdr2.PublicKey, + Weight: vdr2.Weight, + }, + nodeID3: { + NodeID: nodeID3, + PublicKey: vdr3.PublicKey, + Weight: vdr3.Weight, + }, + } + + type test struct { + name string + contextFunc func() context.Context + aggregatorFunc func(*gomock.Controller) *Aggregator + unsignedMsg *avalancheWarp.UnsignedMessage + quorumNum uint64 + expectedSigners []*avalancheWarp.Validator + expectedErr error + } + + tests := []test{ + { + name: "can't get height", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(uint64(0), errTest) + return New(subnetID, state, nil) + }, + unsignedMsg: nil, + quorumNum: 0, + expectedErr: errTest, + }, + { + name: "can't get validator set", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errTest) + return New(subnetID, state, nil) + }, + unsignedMsg: nil, + expectedErr: errTest, + }, + { + name: "no validators exist", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) + return New(subnetID, state, nil) + }, + unsignedMsg: nil, + quorumNum: 0, + expectedErr: errNoValidators, + }, + { + name: "0/3 validators reply with signature", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errTest).AnyTimes() + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 1, + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "1/3 validators reply with signature; insufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(sig1, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(nil, errTest) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(nil, errTest) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 35, // Require >1/3 of weight + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "2/3 validators reply with signature; insufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(sig1, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(sig2, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(nil, errTest) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 69, // Require >2/3 of weight + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "2/3 validators reply with signature; sufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(sig1, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(sig2, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(nil, errTest) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 65, // Require <2/3 of weight + expectedSigners: []*avalancheWarp.Validator{vdr1, vdr2}, + expectedErr: nil, + }, + { + name: "3/3 validators reply with signature; sufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(sig1, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(sig2, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(sig3, nil) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 100, // Require all weight + expectedSigners: []*avalancheWarp.Validator{vdr1, vdr2, vdr3}, + expectedErr: nil, + }, + { + name: "3/3 validators reply with signature; 1 invalid signature; sufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(sig2, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(sig3, nil) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 64, + expectedSigners: []*avalancheWarp.Validator{vdr2, vdr3}, + expectedErr: nil, + }, + { + name: "3/3 validators reply with signature; 3 invalid signatures; insufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(nonVdrSig, nil) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 1, + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "3/3 validators reply with signature; 2 invalid signatures; insufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(sig3, nil) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 40, + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "2/3 validators reply with signature; 1 invalid signature; sufficient weight", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(nonVdrSig, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(nil, errTest) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).Return(sig3, nil) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 30, + expectedSigners: []*avalancheWarp.Validator{vdr3}, + expectedErr: nil, + }, + { + name: "early termination of signature fetching on parent context cancelation", + contextFunc: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + // Assert that the context passed into each goroutine is canceled + // because the parent context is canceled. + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).DoAndReturn( + func(ctx context.Context, _ ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { + <-ctx.Done() + err := ctx.Err() + require.ErrorIs(t, err, context.Canceled) + return nil, err + }, + ) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).DoAndReturn( + func(ctx context.Context, _ ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { + <-ctx.Done() + err := ctx.Err() + require.ErrorIs(t, err, context.Canceled) + return nil, err + }, + ) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).DoAndReturn( + func(ctx context.Context, _ ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { + <-ctx.Done() + err := ctx.Err() + require.ErrorIs(t, err, context.Canceled) + return nil, err + }, + ) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 60, // Require 2/3 validators + expectedSigners: []*avalancheWarp.Validator{vdr1, vdr2}, + expectedErr: avalancheWarp.ErrInsufficientWeight, + }, + { + name: "early termination of signature fetching on passing threshold", + contextFunc: context.Background, + aggregatorFunc: func(ctrl *gomock.Controller) *Aggregator { + state := validators.NewMockState(ctrl) + state.EXPECT().GetCurrentHeight(gomock.Any()).Return(pChainHeight, nil) + state.EXPECT().GetValidatorSet(gomock.Any(), gomock.Any(), gomock.Any()).Return( + vdrSet, nil, + ) + + client := NewMockSignatureGetter(ctrl) + client.EXPECT().GetSignature(gomock.Any(), nodeID1, gomock.Any()).Return(sig1, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID2, gomock.Any()).Return(sig2, nil) + client.EXPECT().GetSignature(gomock.Any(), nodeID3, gomock.Any()).DoAndReturn( + // The aggregator will receive sig1 and sig2 which is sufficient weight, + // so the remaining outstanding goroutine should be cancelled. + func(ctx context.Context, _ ids.NodeID, _ *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { + <-ctx.Done() + err := ctx.Err() + require.ErrorIs(t, err, context.Canceled) + return nil, err + }, + ) + return New(subnetID, state, client) + }, + unsignedMsg: unsignedMsg, + quorumNum: 60, // Require 2/3 validators + expectedSigners: []*avalancheWarp.Validator{vdr1, vdr2}, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + require := require.New(t) + + a := tt.aggregatorFunc(ctrl) + + res, err := a.AggregateSignatures(tt.contextFunc(), tt.unsignedMsg, tt.quorumNum) + require.ErrorIs(err, tt.expectedErr) + if err != nil { + return + } + + require.Equal(unsignedMsg, &res.Message.UnsignedMessage) + + expectedSigWeight := uint64(0) + for _, vdr := range tt.expectedSigners { + expectedSigWeight += vdr.Weight + } + require.Equal(expectedSigWeight, res.SignatureWeight) + require.Equal(vdr1.Weight+vdr2.Weight+vdr3.Weight, res.TotalWeight) + + expectedSigs := []*bls.Signature{} + for _, vdr := range tt.expectedSigners { + expectedSigs = append(expectedSigs, vdrToSig[vdr]) + } + expectedSig, err := bls.AggregateSignatures(expectedSigs) + require.NoError(err) + gotBLSSig, ok := res.Message.Signature.(*avalancheWarp.BitSetSignature) + require.True(ok) + require.Equal(bls.SignatureToBytes(expectedSig), gotBLSSig.Signature[:]) + + numSigners, err := res.Message.Signature.NumSigners() + require.NoError(err) + require.Len(tt.expectedSigners, numSigners) + }) + } +} diff --git a/warp/aggregator/mock_signature_getter.go b/warp/aggregator/mock_signature_getter.go new file mode 100644 index 0000000000..f00bb920fa --- /dev/null +++ b/warp/aggregator/mock_signature_getter.go @@ -0,0 +1,53 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ava-labs/subnet-evm/warp/aggregator (interfaces: SignatureGetter) + +// Package aggregator is a generated GoMock package. +package aggregator + +import ( + context "context" + reflect "reflect" + + bls "github.com/ava-labs/avalanchego/utils/crypto/bls" + ids "github.com/ava-labs/avalanchego/ids" + warp "github.com/ava-labs/avalanchego/vms/platformvm/warp" + gomock "go.uber.org/mock/gomock" +) + +// MockSignatureGetter is a mock of SignatureGetter interface. +type MockSignatureGetter struct { + ctrl *gomock.Controller + recorder *MockSignatureGetterMockRecorder +} + +// MockSignatureGetterMockRecorder is the mock recorder for MockSignatureGetter. +type MockSignatureGetterMockRecorder struct { + mock *MockSignatureGetter +} + +// NewMockSignatureGetter creates a new mock instance. +func NewMockSignatureGetter(ctrl *gomock.Controller) *MockSignatureGetter { + mock := &MockSignatureGetter{ctrl: ctrl} + mock.recorder = &MockSignatureGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSignatureGetter) EXPECT() *MockSignatureGetterMockRecorder { + return m.recorder +} + +// GetSignature mocks base method. +func (m *MockSignatureGetter) GetSignature(arg0 context.Context, arg1 ids.NodeID, arg2 *warp.UnsignedMessage) (*bls.Signature, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSignature", arg0, arg1, arg2) + ret0, _ := ret[0].(*bls.Signature) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSignature indicates an expected call of GetSignature. +func (mr *MockSignatureGetterMockRecorder) GetSignature(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSignature", reflect.TypeOf((*MockSignatureGetter)(nil).GetSignature), arg0, arg1, arg2) +} diff --git a/warp/aggregator/network_signature_backend.go b/warp/aggregator/network_signature_backend.go index b4079e8a22..6f1cfd2770 100644 --- a/warp/aggregator/network_signature_backend.go +++ b/warp/aggregator/network_signature_backend.go @@ -19,7 +19,7 @@ const ( retryBackoffFactor = 2 ) -var _ SignatureBackend = (*NetworkSigner)(nil) +var _ SignatureGetter = (*NetworkSigner)(nil) type NetworkClient interface { SendAppRequest(nodeID ids.NodeID, message []byte) ([]byte, error) @@ -30,11 +30,11 @@ type NetworkSigner struct { Client NetworkClient } -// FetchWarpSignature attempts to fetch a BLS Signature of [unsignedWarpMessage] from [nodeID] until it succeeds or receives an invalid response +// GetSignature attempts to fetch a BLS Signature of [unsignedWarpMessage] from [nodeID] until it succeeds or receives an invalid response // // Note: this function will continue attempting to fetch the signature from [nodeID] until it receives an invalid value or [ctx] is cancelled. // The caller is responsible to cancel [ctx] if it no longer needs to fetch this signature. -func (s *NetworkSigner) FetchWarpSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { +func (s *NetworkSigner) GetSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { signatureReq := message.SignatureRequest{ MessageID: unsignedWarpMessage.ID(), } diff --git a/warp/aggregator/signature_job.go b/warp/aggregator/signature_job.go deleted file mode 100644 index 85fc91ab54..0000000000 --- a/warp/aggregator/signature_job.go +++ /dev/null @@ -1,60 +0,0 @@ -// (c) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package aggregator - -import ( - "context" - "errors" - "fmt" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/crypto/bls" - avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" - "github.com/ethereum/go-ethereum/common/hexutil" -) - -var errInvalidSignature = errors.New("invalid signature") - -// SignatureBackend defines the minimum network interface to perform signature aggregation -type SignatureBackend interface { - // FetchWarpSignature attempts to fetch a BLS Signature from [nodeID] for [unsignedWarpMessage] - FetchWarpSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) -} - -// signatureJob fetches a single signature using the injected dependency SignatureBackend and returns a verified signature of the requested message. -type signatureJob struct { - backend SignatureBackend - msg *avalancheWarp.UnsignedMessage - - nodeID ids.NodeID - publicKey *bls.PublicKey - weight uint64 -} - -func (s *signatureJob) String() string { - return fmt.Sprintf("(NodeID: %s, UnsignedMsgID: %s)", s.nodeID, s.msg.ID()) -} - -func newSignatureJob(backend SignatureBackend, validator *avalancheWarp.Validator, msg *avalancheWarp.UnsignedMessage) *signatureJob { - return &signatureJob{ - backend: backend, - msg: msg, - nodeID: validator.NodeIDs[0], // TODO: update from a single nodeID to the original slice and use extra nodeIDs as backup. - publicKey: validator.PublicKey, - weight: validator.Weight, - } -} - -// Execute attempts to fetch the signature from the nodeID specified in this job and then verifies and returns the signature -func (s *signatureJob) Execute(ctx context.Context) (*bls.Signature, error) { - signature, err := s.backend.FetchWarpSignature(ctx, s.nodeID, s.msg) - if err != nil { - return nil, err - } - - if !bls.Verify(s.publicKey, signature, s.msg.Bytes()) { - return nil, fmt.Errorf("%w: node %s returned invalid signature %s for msg %s", errInvalidSignature, s.nodeID, hexutil.Bytes(bls.SignatureToBytes(signature)), s.msg.ID()) - } - return signature, nil -} diff --git a/warp/aggregator/signature_job_test.go b/warp/aggregator/signature_job_test.go deleted file mode 100644 index 0ed52a0eeb..0000000000 --- a/warp/aggregator/signature_job_test.go +++ /dev/null @@ -1,140 +0,0 @@ -// (c) 2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package aggregator - -import ( - "context" - "errors" - "testing" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/utils/crypto/bls" - avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" - "github.com/stretchr/testify/require" -) - -type mockFetcher struct { - fetch func(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) -} - -func (m *mockFetcher) FetchWarpSignature(ctx context.Context, nodeID ids.NodeID, unsignedWarpMessage *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - return m.fetch(ctx, nodeID, unsignedWarpMessage) -} - -var ( - nodeIDs []ids.NodeID - blsSecretKeys []*bls.SecretKey - blsPublicKeys []*bls.PublicKey - networkID uint32 = 54321 - sourceChainID = ids.GenerateTestID() - unsignedMsg *avalancheWarp.UnsignedMessage - blsSignatures []*bls.Signature -) - -func init() { - var err error - unsignedMsg, err = avalancheWarp.NewUnsignedMessage(networkID, sourceChainID, []byte{1, 2, 3}) - if err != nil { - panic(err) - } - for i := 0; i < 5; i++ { - nodeIDs = append(nodeIDs, ids.GenerateTestNodeID()) - - blsSecretKey, err := bls.NewSecretKey() - if err != nil { - panic(err) - } - blsPublicKey := bls.PublicFromSecretKey(blsSecretKey) - blsSignature := bls.Sign(blsSecretKey, unsignedMsg.Bytes()) - blsSecretKeys = append(blsSecretKeys, blsSecretKey) - blsPublicKeys = append(blsPublicKeys, blsPublicKey) - blsSignatures = append(blsSignatures, blsSignature) - } -} - -type signatureJobTest struct { - ctx context.Context - job *signatureJob - expectedSignature *bls.Signature - expectedErr error -} - -func executeSignatureJobTest(t testing.TB, test signatureJobTest) { - t.Helper() - - blsSignature, err := test.job.Execute(test.ctx) - if test.expectedErr != nil { - require.ErrorIs(t, err, test.expectedErr) - return - } - require.NoError(t, err) - require.Equal(t, bls.SignatureToBytes(blsSignature), bls.SignatureToBytes(test.expectedSignature)) -} - -func TestSignatureRequestSuccess(t *testing.T) { - job := newSignatureJob( - &mockFetcher{ - fetch: func(context.Context, ids.NodeID, *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - return blsSignatures[0], nil - }, - }, - &avalancheWarp.Validator{ - NodeIDs: nodeIDs[:1], - PublicKey: blsPublicKeys[0], - Weight: 10, - }, - unsignedMsg, - ) - - executeSignatureJobTest(t, signatureJobTest{ - ctx: context.Background(), - job: job, - expectedSignature: blsSignatures[0], - }) -} - -func TestSignatureRequestFails(t *testing.T) { - err := errors.New("expected error") - job := newSignatureJob( - &mockFetcher{ - fetch: func(context.Context, ids.NodeID, *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - return nil, err - }, - }, - &avalancheWarp.Validator{ - NodeIDs: nodeIDs[:1], - PublicKey: blsPublicKeys[0], - Weight: 10, - }, - unsignedMsg, - ) - - executeSignatureJobTest(t, signatureJobTest{ - ctx: context.Background(), - job: job, - expectedErr: err, - }) -} - -func TestSignatureRequestInvalidSignature(t *testing.T) { - job := newSignatureJob( - &mockFetcher{ - fetch: func(context.Context, ids.NodeID, *avalancheWarp.UnsignedMessage) (*bls.Signature, error) { - return blsSignatures[1], nil - }, - }, - &avalancheWarp.Validator{ - NodeIDs: nodeIDs[:1], - PublicKey: blsPublicKeys[0], - Weight: 10, - }, - unsignedMsg, - ) - - executeSignatureJobTest(t, signatureJobTest{ - ctx: context.Background(), - job: job, - expectedErr: errInvalidSignature, - }) -}