diff --git a/blssig/aggregation.go b/blssig/aggregation.go index ea468305..bb14f6ed 100644 --- a/blssig/aggregation.go +++ b/blssig/aggregation.go @@ -12,12 +12,18 @@ import ( "github.com/drand/kyber" "github.com/drand/kyber/sign" + "github.com/drand/kyber/sign/bdn" ) // Max size of the point cache. const maxPointCacheSize = 10_000 -func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg []byte, _err error) { +type aggregation struct { + mask *bdn.CachedMask + scheme *bdn.Scheme +} + +func (a *aggregation) Aggregate(mask []int, signatures [][]byte) (_agg []byte, _err error) { defer func() { status := measurements.AttrStatusSuccess if _err != nil { @@ -32,22 +38,24 @@ func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg } metrics.aggregate.Record( - context.TODO(), int64(len(pubkeys)), + context.TODO(), int64(len(mask)), metric.WithAttributes(status), ) }() - if len(pubkeys) != len(signatures) { + if len(mask) != len(signatures) { return nil, fmt.Errorf("lengths of pubkeys and sigs does not match %d != %d", - len(pubkeys), len(signatures)) + len(mask), len(signatures)) } - mask, err := v.pubkeysToMask(pubkeys) - if err != nil { - return nil, fmt.Errorf("converting public keys to mask: %w", err) + bdnMask := a.mask.Clone() + for _, bit := range mask { + if err := bdnMask.SetBit(bit, true); err != nil { + return nil, err + } } - aggSigPoint, err := v.scheme.AggregateSignatures(signatures, mask) + aggSigPoint, err := a.scheme.AggregateSignatures(signatures, bdnMask) if err != nil { return nil, fmt.Errorf("computing aggregate signature: %w", err) } @@ -59,7 +67,7 @@ func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey, signatures [][]byte) (_agg return aggSig, nil } -func (v *Verifier) VerifyAggregate(msg []byte, signature []byte, pubkeys []gpbft.PubKey) (_err error) { +func (a *aggregation) VerifyAggregate(mask []int, msg []byte, signature []byte) (_err error) { defer func() { status := measurements.AttrStatusSuccess if _err != nil { @@ -75,25 +83,46 @@ func (v *Verifier) VerifyAggregate(msg []byte, signature []byte, pubkeys []gpbft } metrics.verifyAggregate.Record( - context.TODO(), int64(len(pubkeys)), + context.TODO(), int64(len(mask)), metric.WithAttributes(status), ) }() - mask, err := v.pubkeysToMask(pubkeys) - if err != nil { - return fmt.Errorf("converting public keys to mask: %w", err) + bdnMask := a.mask.Clone() + for _, bit := range mask { + if err := bdnMask.SetBit(bit, true); err != nil { + return err + } } - aggPubKey, err := v.scheme.AggregatePublicKeys(mask) + aggPubKey, err := a.scheme.AggregatePublicKeys(bdnMask) if err != nil { return fmt.Errorf("aggregating public keys: %w", err) } - return v.scheme.Verify(aggPubKey, msg, signature) + return a.scheme.Verify(aggPubKey, msg, signature) } -func (v *Verifier) pubkeysToMask(pubkeys []gpbft.PubKey) (*sign.Mask, error) { +func (v *Verifier) Aggregate(pubkeys []gpbft.PubKey) (_agg gpbft.Aggregate, _err error) { + defer func() { + status := measurements.AttrStatusSuccess + if _err != nil { + status = measurements.AttrStatusError + } + + if perr := recover(); perr != nil { + _err = fmt.Errorf("panicked aggregating public keys: %v\n%s", + perr, string(debug.Stack())) + log.Error(_err) + status = measurements.AttrStatusPanic + } + + metrics.aggregate.Record( + context.TODO(), int64(len(pubkeys)), + metric.WithAttributes(status), + ) + }() + kPubkeys := make([]kyber.Point, 0, len(pubkeys)) for i, p := range pubkeys { point, err := v.pubkeyToPoint(p) @@ -107,11 +136,12 @@ func (v *Verifier) pubkeysToMask(pubkeys []gpbft.PubKey) (*sign.Mask, error) { if err != nil { return nil, fmt.Errorf("creating key mask: %w", err) } - for i := range kPubkeys { - err := mask.SetBit(i, true) - if err != nil { - return nil, fmt.Errorf("setting mask bit %d: %w", i, err) - } + cmask, err := bdn.NewCachedMask(mask) + if err != nil { + return nil, fmt.Errorf("creating key mask: %w", err) } - return mask, nil + return &aggregation{ + mask: cmask, + scheme: v.scheme, + }, nil } diff --git a/certs/certs.go b/certs/certs.go index 02034fa8..eb8e1804 100644 --- a/certs/certs.go +++ b/certs/certs.go @@ -149,7 +149,8 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf return fmt.Errorf("failed to scale power table: %w", err) } - signers := make([]gpbft.PubKey, 0, len(powerTable)) + keys := powerTable.PublicKeys() + mask := make([]int, 0, len(powerTable)) var signerPowers int64 if err := cert.Signers.ForEach(func(i uint64) error { if i >= uint64(len(powerTable)) { @@ -164,7 +165,7 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf cert.GPBFTInstance, powerTable[i].ID) } signerPowers += power - signers = append(signers, powerTable[i].PubKey) + mask = append(mask, int(i)) return nil }); err != nil { return err @@ -191,7 +192,12 @@ func verifyFinalityCertificateSignature(verifier gpbft.Verifier, powerTable gpbf signedBytes = payload.MarshalForSigning(nn) } - if err := verifier.VerifyAggregate(signedBytes, cert.Signature, signers); err != nil { + aggregate, err := verifier.Aggregate(keys) + if err != nil { + return err + } + + if err := aggregate.VerifyAggregate(mask, signedBytes, cert.Signature); err != nil { return fmt.Errorf("invalid signature on finality certificate for instance %d: %w", cert.GPBFTInstance, err) } return nil diff --git a/emulator/instance.go b/emulator/instance.go index 5cfea50a..71193353 100644 --- a/emulator/instance.go +++ b/emulator/instance.go @@ -14,13 +14,14 @@ import ( // Instance represents a GPBFT instance capturing all the information necessary // for GPBFT to function, along with the final decision reached if any. type Instance struct { - t *testing.T - id uint64 - supplementalData gpbft.SupplementalData - proposal gpbft.ECChain - powerTable *gpbft.PowerTable - beacon []byte - decision *gpbft.Justification + t *testing.T + id uint64 + supplementalData gpbft.SupplementalData + proposal gpbft.ECChain + powerTable *gpbft.PowerTable + aggregateVerifier gpbft.Aggregate + beacon []byte + decision *gpbft.Justification } // NewInstance instantiates a new Instance for emulation. If absent, the @@ -57,12 +58,17 @@ func NewInstance(t *testing.T, id uint64, powerEntries gpbft.PowerEntries, propo } proposalChain, err := gpbft.NewChain(proposal[0], proposal[1:]...) require.NoError(t, err) + + aggVerifier, err := signing.Aggregate(pt.Entries.PublicKeys()) + require.NoError(t, err) + return &Instance{ - t: t, - id: id, - powerTable: pt, - beacon: []byte(fmt.Sprintf("🥓%d", id)), - proposal: proposalChain, + t: t, + id: id, + powerTable: pt, + aggregateVerifier: aggVerifier, + beacon: []byte(fmt.Sprintf("🥓%d", id)), + proposal: proposalChain, } } @@ -129,7 +135,6 @@ func (i *Instance) NewJustification(round uint64, step gpbft.Phase, vote gpbft.E msg := signing.MarshalPayloadForSigning(networkName, &payload) qr := gpbft.QuorumResult{ Signers: make([]int, len(from)), - PubKeys: make([]gpbft.PubKey, len(from)), Signatures: make([][]byte, len(from)), } for j, actor := range from { @@ -139,10 +144,9 @@ func (i *Instance) NewJustification(round uint64, step gpbft.Phase, vote gpbft.E signature, err := signing.Sign(context.Background(), entry.PubKey, msg) require.NoError(i.t, err) qr.Signatures[j] = signature - qr.PubKeys[j] = entry.PubKey qr.Signers[j] = index } - aggregate, err := signing.Aggregate(qr.PubKeys, qr.Signatures) + aggregate, err := i.aggregateVerifier.Aggregate(qr.Signers, qr.Signatures) require.NoError(i.t, err) return &gpbft.Justification{ Vote: payload, diff --git a/emulator/signing.go b/emulator/signing.go index 0bb35476..bcf95c27 100644 --- a/emulator/signing.go +++ b/emulator/signing.go @@ -3,6 +3,7 @@ package emulator import ( "bytes" "context" + "encoding/binary" "errors" "hash/crc32" @@ -46,13 +47,22 @@ func (s adhocSigning) Verify(sender gpbft.PubKey, msg, got []byte) error { } } -func (s adhocSigning) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - if len(signers) != len(sigs) { +type aggregate struct { + keys []gpbft.PubKey + signing adhocSigning +} + +// Aggregate implements gpbft.Aggregate. +func (a *aggregate) Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) { + if len(signerMask) != len(sigs) { return nil, errors.New("public keys and signatures length mismatch") } hasher := crc32.NewIEEE() - for i, signer := range signers { - if _, err := hasher.Write(signer); err != nil { + for i, bit := range signerMask { + if err := binary.Write(hasher, binary.BigEndian, uint64(bit)); err != nil { + return nil, err + } + if _, err := hasher.Write(a.keys[bit]); err != nil { return nil, err } if _, err := hasher.Write(sigs[i]); err != nil { @@ -62,16 +72,17 @@ func (s adhocSigning) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, return hasher.Sum(nil), nil } -func (s adhocSigning) VerifyAggregate(payload, got []byte, signers []gpbft.PubKey) error { - signatures := make([][]byte, len(signers)) +// VerifyAggregate implements gpbft.Aggregate. +func (a *aggregate) VerifyAggregate(signerMask []int, payload []byte, got []byte) error { + signatures := make([][]byte, len(signerMask)) var err error - for i, signer := range signers { - signatures[i], err = s.Sign(context.Background(), signer, payload) + for i, bit := range signerMask { + signatures[i], err = a.signing.Sign(context.Background(), a.keys[bit], payload) if err != nil { return err } } - want, err := s.Aggregate(signers, signatures) + want, err := a.Aggregate(signerMask, signatures) if err != nil { return err } @@ -81,6 +92,12 @@ func (s adhocSigning) VerifyAggregate(payload, got []byte, signers []gpbft.PubKe return nil } +func (s adhocSigning) Aggregate(keys []gpbft.PubKey) (gpbft.Aggregate, error) { + return &aggregate{keys: keys, + signing: s, + }, nil +} + func (s adhocSigning) MarshalPayloadForSigning(name gpbft.NetworkName, payload *gpbft.Payload) []byte { return payload.MarshalForSigning(name) } diff --git a/go.mod b/go.mod index 2f949642..16c812dc 100644 --- a/go.mod +++ b/go.mod @@ -143,3 +143,5 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect ) + +replace github.com/drand/kyber => github.com/Stebalien/kyber v1.3.2-0.20240827162216-c96a0e427578 diff --git a/go.sum b/go.sum index 045636d4..2c4cf26e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGy github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Kubuxu/go-broadcast v0.0.0-20240621161059-1a8c90734cd6 h1:yh2/1fz3ajTaeKskSWxtSBNScdRZfQ/A5nyd9+64T6M= github.com/Kubuxu/go-broadcast v0.0.0-20240621161059-1a8c90734cd6/go.mod h1:5LOj/fF3Oc/cvJqzDiyfx4XwtBPRWUYEz+V+b13sH5U= +github.com/Stebalien/kyber v1.3.2-0.20240827162216-c96a0e427578 h1:dx1hCR7KbG1HbehvPPRJKExoI9COfy8eMg7sCidKJEs= +github.com/Stebalien/kyber v1.3.2-0.20240827162216-c96a0e427578/go.mod h1:f+mNHjiGT++CuueBrpeMhFNdKZAsy0tu03bKq9D5LPA= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= @@ -49,8 +51,6 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3 github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/drand/kyber v1.3.1 h1:E0p6M3II+loMVwTlAp5zu4+GGZFNiRfq02qZxzw2T+Y= -github.com/drand/kyber v1.3.1/go.mod h1:f+mNHjiGT++CuueBrpeMhFNdKZAsy0tu03bKq9D5LPA= github.com/drand/kyber-bls12381 v0.3.1 h1:KWb8l/zYTP5yrvKTgvhOrk2eNPscbMiUOIeWBnmUxGo= github.com/drand/kyber-bls12381 v0.3.1/go.mod h1:H4y9bLPu7KZA/1efDg+jtJ7emKx+ro3PU7/jWUVt140= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= diff --git a/gpbft/api.go b/gpbft/api.go index ad8a5c89..bfaa6762 100644 --- a/gpbft/api.go +++ b/gpbft/api.go @@ -100,15 +100,27 @@ type SigningMarshaler interface { MarshalPayloadForSigning(NetworkName, *Payload) []byte } +type Aggregate interface { + // Aggregates signatures from a participants. + // + // Implementations must be safe for concurrent use. + Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) + // VerifyAggregate verifies an aggregate signature. + // + // Implementations must be safe for concurrent use. + VerifyAggregate(signerMask []int, payload, aggSig []byte) error +} + type Verifier interface { // Verifies a signature for the given public key. + // // Implementations must be safe for concurrent use. Verify(pubKey PubKey, msg, sig []byte) error - // Aggregates signatures from a participants. - Aggregate(pubKeys []PubKey, sigs [][]byte) ([]byte, error) - // VerifyAggregate verifies an aggregate signature. + // Return an Aggregate that can aggregate and verify aggregate signatures made by the given + // public keys. + // // Implementations must be safe for concurrent use. - VerifyAggregate(payload, aggSig []byte, signers []PubKey) error + Aggregate(pubKeys []PubKey) (Aggregate, error) } type Signatures interface { diff --git a/gpbft/gpbft.go b/gpbft/gpbft.go index 20a1e695..598cd6bc 100644 --- a/gpbft/gpbft.go +++ b/gpbft/gpbft.go @@ -156,6 +156,8 @@ type instance struct { input ECChain // The power table for the base chain, used for power in this instance. powerTable PowerTable + // The aggregate signature verifier/aggregator. + aggregateVerifier Aggregate // The beacon value from the base chain, used for tickets in this instance. beacon []byte // Current round number. @@ -217,6 +219,7 @@ func newInstance( input ECChain, data *SupplementalData, powerTable PowerTable, + aggregateVerifier Aggregate, beacon []byte) (*instance, error) { if input.IsZero() { return nil, fmt.Errorf("input is empty") @@ -227,19 +230,20 @@ func newInstance( metrics.currentRound.Record(context.TODO(), 0) return &instance{ - participant: participant, - instanceID: instanceID, - input: input, - powerTable: powerTable, - beacon: beacon, - round: 0, - phase: INITIAL_PHASE, - supplementalData: data, - proposal: input, - broadcasted: newBroadcastState(), - value: ECChain{}, - candidates: []ECChain{input.BaseChain()}, - quality: newQuorumState(powerTable), + participant: participant, + instanceID: instanceID, + input: input, + powerTable: powerTable, + aggregateVerifier: aggregateVerifier, + beacon: beacon, + round: 0, + phase: INITIAL_PHASE, + supplementalData: data, + proposal: input, + broadcasted: newBroadcastState(), + value: ECChain{}, + candidates: []ECChain{input.BaseChain()}, + quality: newQuorumState(powerTable), rounds: map[uint64]*roundState{ 0: newRoundState(powerTable), }, @@ -944,7 +948,7 @@ func (i *instance) alarmAfterSynchrony() time.Time { // Builds a justification for a value from a quorum result. func (i *instance) buildJustification(quorum QuorumResult, round uint64, phase Phase, value ECChain) *Justification { - aggSignature, err := quorum.Aggregate(i.participant.host) + aggSignature, err := quorum.Aggregate(i.aggregateVerifier) if err != nil { panic(fmt.Errorf("aggregating for phase %v: %v", phase, err)) } @@ -1132,12 +1136,11 @@ func (q *quorumState) CouldReachStrongQuorumFor(key ChainKey, withAdversary bool type QuorumResult struct { // Signers is an array of indexes into the powertable, sorted in increasing order Signers []int - PubKeys []PubKey Signatures [][]byte } -func (q QuorumResult) Aggregate(v Verifier) ([]byte, error) { - return v.Aggregate(q.PubKeys, q.Signatures) +func (q QuorumResult) Aggregate(v Aggregate) ([]byte, error) { + return v.Aggregate(q.Signers, q.Signatures) } func (q QuorumResult) SignersBitfield() bitfield.BitField { @@ -1174,7 +1177,6 @@ func (q *quorumState) FindStrongQuorumFor(key ChainKey) (QuorumResult, bool) { // Accumulate signers and signatures until they reach a strong quorum. signatures := make([][]byte, 0, len(chainSupport.signatures)) - pubkeys := make([]PubKey, 0, len(signatures)) var justificationPower int64 for i, idx := range signers { if idx >= len(q.powerTable.Entries) { @@ -1184,11 +1186,9 @@ func (q *quorumState) FindStrongQuorumFor(key ChainKey) (QuorumResult, bool) { entry := q.powerTable.Entries[idx] justificationPower += power signatures = append(signatures, chainSupport.signatures[entry.ID]) - pubkeys = append(pubkeys, entry.PubKey) if IsStrongQuorum(justificationPower, q.powerTable.ScaledTotal) { return QuorumResult{ Signers: signers[:i+1], - PubKeys: pubkeys, Signatures: signatures, }, true } diff --git a/gpbft/mock_host_test.go b/gpbft/mock_host_test.go index 4bbbca06..36b50602 100644 --- a/gpbft/mock_host_test.go +++ b/gpbft/mock_host_test.go @@ -21,29 +21,29 @@ func (_m *MockHost) EXPECT() *MockHost_Expecter { return &MockHost_Expecter{mock: &_m.Mock} } -// Aggregate provides a mock function with given fields: pubKeys, sigs -func (_m *MockHost) Aggregate(pubKeys []PubKey, sigs [][]byte) ([]byte, error) { - ret := _m.Called(pubKeys, sigs) +// Aggregate provides a mock function with given fields: pubKeys +func (_m *MockHost) Aggregate(pubKeys []PubKey) (Aggregate, error) { + ret := _m.Called(pubKeys) if len(ret) == 0 { panic("no return value specified for Aggregate") } - var r0 []byte + var r0 Aggregate var r1 error - if rf, ok := ret.Get(0).(func([]PubKey, [][]byte) ([]byte, error)); ok { - return rf(pubKeys, sigs) + if rf, ok := ret.Get(0).(func([]PubKey) (Aggregate, error)); ok { + return rf(pubKeys) } - if rf, ok := ret.Get(0).(func([]PubKey, [][]byte) []byte); ok { - r0 = rf(pubKeys, sigs) + if rf, ok := ret.Get(0).(func([]PubKey) Aggregate); ok { + r0 = rf(pubKeys) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) + r0 = ret.Get(0).(Aggregate) } } - if rf, ok := ret.Get(1).(func([]PubKey, [][]byte) error); ok { - r1 = rf(pubKeys, sigs) + if rf, ok := ret.Get(1).(func([]PubKey) error); ok { + r1 = rf(pubKeys) } else { r1 = ret.Error(1) } @@ -58,24 +58,23 @@ type MockHost_Aggregate_Call struct { // Aggregate is a helper method to define mock.On call // - pubKeys []PubKey -// - sigs [][]byte -func (_e *MockHost_Expecter) Aggregate(pubKeys interface{}, sigs interface{}) *MockHost_Aggregate_Call { - return &MockHost_Aggregate_Call{Call: _e.mock.On("Aggregate", pubKeys, sigs)} +func (_e *MockHost_Expecter) Aggregate(pubKeys interface{}) *MockHost_Aggregate_Call { + return &MockHost_Aggregate_Call{Call: _e.mock.On("Aggregate", pubKeys)} } -func (_c *MockHost_Aggregate_Call) Run(run func(pubKeys []PubKey, sigs [][]byte)) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) Run(run func(pubKeys []PubKey)) *MockHost_Aggregate_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]PubKey), args[1].([][]byte)) + run(args[0].([]PubKey)) }) return _c } -func (_c *MockHost_Aggregate_Call) Return(_a0 []byte, _a1 error) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) Return(_a0 Aggregate, _a1 error) *MockHost_Aggregate_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockHost_Aggregate_Call) RunAndReturn(run func([]PubKey, [][]byte) ([]byte, error)) *MockHost_Aggregate_Call { +func (_c *MockHost_Aggregate_Call) RunAndReturn(run func([]PubKey) (Aggregate, error)) *MockHost_Aggregate_Call { _c.Call.Return(run) return _c } @@ -536,54 +535,6 @@ func (_c *MockHost_Verify_Call) RunAndReturn(run func(PubKey, []byte, []byte) er return _c } -// VerifyAggregate provides a mock function with given fields: payload, aggSig, signers -func (_m *MockHost) VerifyAggregate(payload []byte, aggSig []byte, signers []PubKey) error { - ret := _m.Called(payload, aggSig, signers) - - if len(ret) == 0 { - panic("no return value specified for VerifyAggregate") - } - - var r0 error - if rf, ok := ret.Get(0).(func([]byte, []byte, []PubKey) error); ok { - r0 = rf(payload, aggSig, signers) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockHost_VerifyAggregate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAggregate' -type MockHost_VerifyAggregate_Call struct { - *mock.Call -} - -// VerifyAggregate is a helper method to define mock.On call -// - payload []byte -// - aggSig []byte -// - signers []PubKey -func (_e *MockHost_Expecter) VerifyAggregate(payload interface{}, aggSig interface{}, signers interface{}) *MockHost_VerifyAggregate_Call { - return &MockHost_VerifyAggregate_Call{Call: _e.mock.On("VerifyAggregate", payload, aggSig, signers)} -} - -func (_c *MockHost_VerifyAggregate_Call) Run(run func(payload []byte, aggSig []byte, signers []PubKey)) *MockHost_VerifyAggregate_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]byte), args[1].([]byte), args[2].([]PubKey)) - }) - return _c -} - -func (_c *MockHost_VerifyAggregate_Call) Return(_a0 error) *MockHost_VerifyAggregate_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockHost_VerifyAggregate_Call) RunAndReturn(run func([]byte, []byte, []PubKey) error) *MockHost_VerifyAggregate_Call { - _c.Call.Return(run) - return _c -} - // NewMockHost creates a new instance of MockHost. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockHost(t interface { diff --git a/gpbft/participant.go b/gpbft/participant.go index 88df03ff..96fc9c3f 100644 --- a/gpbft/participant.go +++ b/gpbft/participant.go @@ -333,7 +333,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro // Check justification power and signature. var justificationPower int64 - signers := make([]PubKey, 0) + signers := make([]int, 0) if err := msg.Justification.Signers.ForEach(func(bit uint64) error { if int(bit) >= len(comt.power.Entries) { return fmt.Errorf("invalid signer index: %d", bit) @@ -343,7 +343,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro return fmt.Errorf("signer with ID %d has no power", comt.power.Entries[bit].ID) } justificationPower += power - signers = append(signers, comt.power.Entries[bit].PubKey) + signers = append(signers, int(bit)) return nil }); err != nil { return fmt.Errorf("failed to iterate over signers: %w", err) @@ -354,7 +354,7 @@ func (p *Participant) validateJustification(msg *GMessage, comt *committee) erro } payload := p.host.MarshalPayloadForSigning(p.host.NetworkName(), &msg.Justification.Vote) - if err := p.host.VerifyAggregate(payload, msg.Justification.Signature, signers); err != nil { + if err := comt.aggregateVerifier.VerifyAggregate(signers, payload, msg.Justification.Signature); err != nil { return fmt.Errorf("verification of the aggregate failed: %+v: %w", msg.Justification, err) } @@ -445,7 +445,7 @@ func (p *Participant) beginInstance() error { if err != nil { return err } - if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, *comt.power, comt.beacon); err != nil { + if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, *comt.power, comt.aggregateVerifier, comt.beacon); err != nil { return fmt.Errorf("failed creating new gpbft instance: %w", err) } if err := p.gpbft.Start(); err != nil { @@ -485,7 +485,16 @@ func (p *Participant) fetchCommittee(instance uint64) (*committee, error) { if err := power.Validate(); err != nil { return nil, fmt.Errorf("instance %d: %w: invalid power: %w", instance, ErrValidationNoCommittee, err) } - comt = &committee{power: power, beacon: beacon} + + // TODO: filter out participants with no effective power after rounding? + // TODO: this is slow and under a lock, but we only want to do it once per + // instance... ideally we'd have a per-instance lock/once, but that probably isn't + // worth it. + agg, err := p.host.Aggregate(power.Entries.PublicKeys()) + if err != nil { + return nil, fmt.Errorf("failed to pre-compute aggregate mask for instance %d: %w: %w", instance, ErrValidationNoCommittee, err) + } + comt = &committee{power: power, beacon: beacon, aggregateVerifier: agg} p.committees[instance] = comt } return comt, nil @@ -544,8 +553,9 @@ func (p *Participant) Describe() string { // A power table and beacon value used as the committee inputs to an instance. type committee struct { - power *PowerTable - beacon []byte + power *PowerTable + beacon []byte + aggregateVerifier Aggregate } // A collection of messages queued for delivery for a future instance. diff --git a/gpbft/participant_test.go b/gpbft/participant_test.go index 24ed4a9a..5d27f3af 100644 --- a/gpbft/participant_test.go +++ b/gpbft/participant_test.go @@ -91,9 +91,12 @@ func (pt *participantTestSubject) Log(format string, args ...any) { } func (pt *participantTestSubject) expectBeginInstance() { + publicKeys := pt.powerTable.Entries.PublicKeys() + // Prepare the test host. pt.host.On("GetProposalForInstance", pt.instance).Return(pt.supplementalData, pt.canonicalChain, nil) pt.host.On("GetCommitteeForInstance", pt.instance).Return(pt.powerTable, pt.beacon, nil).Once() + pt.host.On("Aggregate", publicKeys).Return(nil, nil) pt.host.On("Time").Return(pt.time) pt.host.On("NetworkName").Return(pt.networkName).Maybe() // We need to use `Maybe` here because `MarshalPayloadForSigning` may be called @@ -106,6 +109,7 @@ func (pt *participantTestSubject) expectBeginInstance() { // Expect calls to get the host state prior to beginning of an instance. pt.host.EXPECT().GetProposalForInstance(pt.instance) pt.host.EXPECT().GetCommitteeForInstance(pt.instance) + pt.host.EXPECT().Aggregate(publicKeys) pt.host.EXPECT().Time() // Expect alarm is set to 2X of configured delta. @@ -189,6 +193,7 @@ func (pt *participantTestSubject) mockValidTicket(target gpbft.PubKey, ticket gp func (pt *participantTestSubject) mockCommitteeForInstance(instance uint64, powerTable *gpbft.PowerTable, beacon []byte) { pt.host.On("GetCommitteeForInstance", instance).Return(powerTable, beacon, nil).Once() + pt.host.On("Aggregate", powerTable.Entries.PublicKeys()).Return(nil, nil) } func (pt *participantTestSubject) mockCommitteeUnavailableForInstance(instance uint64) { diff --git a/gpbft/powertable.go b/gpbft/powertable.go index a876b5ab..e63568b6 100644 --- a/gpbft/powertable.go +++ b/gpbft/powertable.go @@ -33,6 +33,14 @@ type PowerTable struct { ScaledTotal int64 } +func (e PowerEntries) PublicKeys() []PubKey { + keys := make([]PubKey, len(e)) + for i, e := range e { + keys[i] = e.PubKey + } + return keys +} + func (p *PowerEntry) Equal(o *PowerEntry) bool { return p.ID == o.ID && p.Power.Equals(o.Power) && bytes.Equal(p.PubKey, o.PubKey) } diff --git a/host.go b/host.go index 9dd8dbbd..56cbda78 100644 --- a/host.go +++ b/host.go @@ -700,13 +700,6 @@ func (h *gpbftHost) Verify(pubKey gpbft.PubKey, msg []byte, sig []byte) error { return h.verifier.Verify(pubKey, msg, sig) } -// Aggregates signatures from a participants. -func (h *gpbftHost) Aggregate(pubKeys []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - return h.verifier.Aggregate(pubKeys, sigs) -} - -// VerifyAggregate verifies an aggregate signature. -// Implementations must be safe for concurrent use. -func (h *gpbftHost) VerifyAggregate(payload []byte, aggSig []byte, signers []gpbft.PubKey) error { - return h.verifier.VerifyAggregate(payload, aggSig, signers) +func (h *gpbftHost) Aggregate(pubKeys []gpbft.PubKey) (gpbft.Aggregate, error) { + return h.verifier.Aggregate(pubKeys) } diff --git a/sim/adversary/decide.go b/sim/adversary/decide.go index 899d6fc9..22827f10 100644 --- a/sim/adversary/decide.go +++ b/sim/adversary/decide.go @@ -97,30 +97,32 @@ func (i *ImmediateDecide) StartInstanceAt(instance uint64, _when time.Time) erro } var ( - pubkeys []gpbft.PubKey - sigs [][]byte + mask []int + sigs [][]byte ) if err := signers.ForEach(func(j uint64) error { - pubkey := gpbft.PubKey("fake pubkeyaaaaa") - sig := []byte("fake sig") - if j < uint64(len(powertable.Entries)) { - pubkey = powertable.Entries[j].PubKey - var err error - sig, err = i.host.Sign(context.Background(), pubkey, sigPayload) - if err != nil { - return err - } + if j >= uint64(len(powertable.Entries)) { + return nil + } + pubkey := powertable.Entries[j].PubKey + sig, err := i.host.Sign(context.Background(), pubkey, sigPayload) + if err != nil { + return err } - pubkeys = append(pubkeys, pubkey) + mask = append(mask, int(j)) sigs = append(sigs, sig) return nil }); err != nil { panic(err) } - aggregatedSig, err := i.host.Aggregate(pubkeys, sigs) + agg, err := i.host.Aggregate(powertable.Entries.PublicKeys()) + if err != nil { + panic(err) + } + aggregatedSig, err := agg.Aggregate(mask, sigs) if err != nil { panic(err) } diff --git a/sim/adversary/withhold.go b/sim/adversary/withhold.go index d3a0e43b..cb3814ad 100644 --- a/sim/adversary/withhold.go +++ b/sim/adversary/withhold.go @@ -106,15 +106,19 @@ func (w *WithholdCommit) StartInstanceAt(instance uint64, _when time.Time) error sort.Ints(signers) signatures := make([][]byte, 0) - pubKeys := make([]gpbft.PubKey, 0) + mask := make([]int, 0) prepareMarshalled := w.host.MarshalPayloadForSigning(w.host.NetworkName(), &preparePayload) for _, signerIndex := range signers { entry := powertable.Entries[signerIndex] signatures = append(signatures, w.sign(entry.PubKey, prepareMarshalled)) - pubKeys = append(pubKeys, entry.PubKey) + mask = append(mask, signerIndex) justification.Signers.Set(uint64(signerIndex)) } - justification.Signature, err = w.host.Aggregate(pubKeys, signatures) + agg, err := w.host.Aggregate(powertable.Entries.PublicKeys()) + if err != nil { + panic(err) + } + justification.Signature, err = agg.Aggregate(mask, signatures) if err != nil { panic(err) } diff --git a/sim/ec.go b/sim/ec.go index d9f7ae4c..197dee09 100644 --- a/sim/ec.go +++ b/sim/ec.go @@ -36,8 +36,9 @@ type ECInstance struct { // SupplementalData is the additional data for this instance. SupplementalData *gpbft.SupplementalData - ec *simEC - decisions map[gpbft.ActorID]*gpbft.Justification + ec *simEC + decisions map[gpbft.ActorID]*gpbft.Justification + aggregateVerifier gpbft.Aggregate } type errGroup []error @@ -64,6 +65,12 @@ func (ec *simEC) BeginInstance(baseChain gpbft.ECChain, pt *gpbft.PowerTable) *E // Note a real beacon value will come from a finalised chain with some lookback. beacon := baseChain.Head().Key nextInstanceID := uint64(ec.Len()) + + agg, err := ec.verifier.Aggregate(pt.Entries.PublicKeys()) + if err != nil { + panic(err) + } + instance := &ECInstance{ Instance: nextInstanceID, BaseChain: baseChain, @@ -72,8 +79,9 @@ func (ec *simEC) BeginInstance(baseChain gpbft.ECChain, pt *gpbft.PowerTable) *E SupplementalData: &gpbft.SupplementalData{ PowerTable: gpbft.CID(fmt.Sprintf("supp-data-pt@%d", nextInstanceID)), }, - ec: ec, - decisions: make(map[gpbft.ActorID]*gpbft.Justification), + ec: ec, + aggregateVerifier: agg, + decisions: make(map[gpbft.ActorID]*gpbft.Justification), } ec.instances = append(ec.instances, instance) return instance @@ -123,14 +131,14 @@ func (eci *ECInstance) validateDecision(decision *gpbft.Justification) error { // Extract signers. justificationPower := gpbft.NewStoragePower(0) - signers := make([]gpbft.PubKey, 0) + signers := make([]int, 0) powerTable := eci.PowerTable if err := decision.Signers.ForEach(func(bit uint64) error { if int(bit) >= len(powerTable.Entries) { return fmt.Errorf("invalid signer index: %d", bit) } justificationPower = big.Add(justificationPower, powerTable.Entries[bit].Power) - signers = append(signers, powerTable.Entries[bit].PubKey) + signers = append(signers, int(bit)) return nil }); err != nil { return fmt.Errorf("failed to iterate over signers: %w", err) @@ -144,7 +152,8 @@ func (eci *ECInstance) validateDecision(decision *gpbft.Justification) error { } // Verify aggregate signature payload := eci.ec.verifier.MarshalPayloadForSigning(eci.ec.networkName, &decision.Vote) - if err := eci.ec.verifier.VerifyAggregate(payload, decision.Signature, signers); err != nil { + + if err := eci.aggregateVerifier.VerifyAggregate(signers, payload, decision.Signature); err != nil { return fmt.Errorf("invalid aggregate signature: %v: %w", decision, err) } diff --git a/sim/justification.go b/sim/justification.go index d6ac7065..935e21ba 100644 --- a/sim/justification.go +++ b/sim/justification.go @@ -67,6 +67,8 @@ func MakeJustification(backend signing.Backend, nn gpbft.NetworkName, chain gpbf slices.SortFunc(votes, func(a, b vote) int { return cmp.Compare(a.index, b.index) }) + signers = signers[:len(votes)] + slices.Sort(signers) pks := make([]gpbft.PubKey, len(votes)) sigs := make([][]byte, len(votes)) for i, vote := range votes { @@ -74,7 +76,12 @@ func MakeJustification(backend signing.Backend, nn gpbft.NetworkName, chain gpbf sigs[i] = vote.sig } - sig, err := backend.Aggregate(pks, sigs) + agg, err := backend.Aggregate(powerTable.PublicKeys()) + if err != nil { + return nil, err + } + + sig, err := agg.Aggregate(signers, sigs) if err != nil { return nil, err } diff --git a/sim/signing/fake.go b/sim/signing/fake.go index ca44a9f0..5f524e95 100644 --- a/sim/signing/fake.go +++ b/sim/signing/fake.go @@ -76,35 +76,18 @@ func (s *FakeBackend) Verify(signer gpbft.PubKey, msg, sig []byte) error { } } -func (*FakeBackend) Aggregate(signers []gpbft.PubKey, sigs [][]byte) ([]byte, error) { - if len(signers) != len(sigs) { - return nil, errors.New("public keys and signatures length mismatch") - } - hasher := sha256.New() - for i, signer := range signers { +func (s *FakeBackend) Aggregate(keys []gpbft.PubKey) (gpbft.Aggregate, error) { + for i, signer := range keys { if len(signer) != 16 { - return nil, fmt.Errorf("wrong signer pubkey length: %d != 16", len(signer)) + return nil, fmt.Errorf("wrong signer %d pubkey length: %d != 16", i, len(signer)) } - hasher.Write(signer) - hasher.Write(sigs[i]) } - return hasher.Sum(nil), nil -} -func (s *FakeBackend) VerifyAggregate(payload, aggSig []byte, signers []gpbft.PubKey) error { - hasher := sha256.New() - for _, signer := range signers { - sig, err := s.generateSignature(signer, payload) - if err != nil { - return err - } - hasher.Write(signer) - hasher.Write(sig) - } - if !bytes.Equal(aggSig, hasher.Sum(nil)) { - return errors.New("signature is not valid") - } - return nil + return &fakeAggregate{ + keys: keys, + backend: s, + }, nil + } func (v *FakeBackend) MarshalPayloadForSigning(nn gpbft.NetworkName, p *gpbft.Payload) []byte { @@ -142,3 +125,44 @@ func (v *FakeBackend) MarshalPayloadForSigning(nn gpbft.NetworkName, p *gpbft.Pa } return buf.Bytes() } + +type fakeAggregate struct { + keys []gpbft.PubKey + backend *FakeBackend +} + +// Aggregate implements gpbft.Aggregate. +func (f *fakeAggregate) Aggregate(signerMask []int, sigs [][]byte) ([]byte, error) { + if len(signerMask) != len(sigs) { + return nil, errors.New("public keys and signatures length mismatch") + } + hasher := sha256.New() + for i, bit := range signerMask { + if bit >= len(f.keys) { + return nil, fmt.Errorf("signer %d out of range", bit) + } + binary.Write(hasher, binary.BigEndian, int64(bit)) + hasher.Write(f.keys[bit]) + hasher.Write(sigs[i]) + } + return hasher.Sum(nil), nil +} + +// VerifyAggregate implements gpbft.Aggregate. +func (f *fakeAggregate) VerifyAggregate(signerMask []int, payload []byte, aggSig []byte) error { + hasher := sha256.New() + for _, bit := range signerMask { + signer := f.keys[bit] + sig, err := f.backend.generateSignature(signer, payload) + if err != nil { + return err + } + binary.Write(hasher, binary.BigEndian, int64(bit)) + hasher.Write(signer) + hasher.Write(sig) + } + if !bytes.Equal(aggSig, hasher.Sum(nil)) { + return errors.New("signature is not valid") + } + return nil +} diff --git a/test/signing_suite_test.go b/test/signing_suite_test.go index 95dd9cb9..4b500211 100644 --- a/test/signing_suite_test.go +++ b/test/signing_suite_test.go @@ -89,44 +89,47 @@ func (s *SigningTestSuite) TestAggregateAndVerify() { pubKey2, signer2 := s.signerTestSubject(s.T()) pubKeys := []gpbft.PubKey{pubKey1, pubKey2} + aggregator, err := s.verifier.Aggregate(pubKeys) + require.NoError(s.T(), err) + + mask := []int{0, 1} sigs := make([][]byte, len(pubKeys)) - var err error sigs[0], err = signer1.Sign(ctx, pubKey1, msg) require.NoError(s.T(), err) sigs[1], err = signer2.Sign(ctx, pubKey2, msg) require.NoError(s.T(), err) - aggSig, err := s.verifier.Aggregate(pubKeys, sigs) + aggSig, err := aggregator.Aggregate(mask, sigs) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.NoError(t, err) - aggSig, err = s.verifier.Aggregate(pubKeys[0:1], sigs[0:1]) + aggSig, err = aggregator.Aggregate(mask[0:1], sigs[0:1]) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.Error(t, err) - aggSig, err = s.verifier.Aggregate(pubKeys, [][]byte{sigs[0], sigs[0]}) + aggSig, err = aggregator.Aggregate(mask, [][]byte{sigs[0], sigs[0]}) require.NoError(t, err) - err = s.verifier.VerifyAggregate(msg, aggSig, pubKeys) + err = aggregator.VerifyAggregate(mask, msg, aggSig) require.Error(t, err) - err = s.verifier.VerifyAggregate(msg, []byte("bad sig"), pubKeys) + err = aggregator.VerifyAggregate(mask, msg, []byte("bad sig")) require.Error(t, err) - _, err = s.verifier.Aggregate(pubKeys, [][]byte{sigs[0]}) + _, err = aggregator.Aggregate(mask, [][]byte{sigs[0]}) require.Error(t, err, "Missmatched pubkeys and sigs lengths should fail") { pubKeys2 := slices.Clone(pubKeys) - pubKeys2[0] = slices.Clone(pubKeys2[0]) - pubKeys2[0] = pubKeys2[0][1:len(pubKeys2)] - _, err = s.verifier.Aggregate(pubKeys2, sigs) - require.Error(t, err, "damaged pubkey should error") + pubKey3, _ := s.signerTestSubject(s.T()) + pubKeys2[0] = pubKey3 + wrongKeyAggregator, err := s.verifier.Aggregate(pubKeys2) + require.NoError(t, err) - require.Error(t, s.verifier.VerifyAggregate(msg, aggSig, pubKeys2), "damaged pubkey should error") + require.Error(t, wrongKeyAggregator.VerifyAggregate(mask, msg, aggSig), "wrong pubkey should error") } }