diff --git a/rolling-shutter/go.mod b/rolling-shutter/go.mod index 0cfaccd5..238d671e 100644 --- a/rolling-shutter/go.mod +++ b/rolling-shutter/go.mod @@ -16,6 +16,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/icza/gog v0.0.0-20240529172513-3355cf65d018 github.com/ipfs/go-log/v2 v2.5.1 github.com/jackc/pgconn v1.14.1 @@ -127,7 +128,6 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-bexpr v0.1.11 // indirect github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d // indirect - github.com/hashicorp/golang-lru/v2 v2.0.5 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/holiman/uint256 v1.2.4 // indirect diff --git a/rolling-shutter/go.sum b/rolling-shutter/go.sum index 0fbd8838..81d8917d 100644 --- a/rolling-shutter/go.sum +++ b/rolling-shutter/go.sum @@ -418,8 +418,8 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d h1:dg1dEPuWpEqDnvIw251EVy4zlP8gWbsGj4BsUKCRpYs= github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= -github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4= -github.com/hashicorp/golang-lru/v2 v2.0.5/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/holiman/billy v0.0.0-20230718173358-1c7e68d277a7 h1:3JQNjnMRil1yD0IfZKHF9GxxWKDJGj8I0IqOUol//sw= diff --git a/rolling-shutter/keyper/epochkghandler/key.go b/rolling-shutter/keyper/epochkghandler/key.go index 2f06830f..3667eddd 100644 --- a/rolling-shutter/keyper/epochkghandler/key.go +++ b/rolling-shutter/keyper/epochkghandler/key.go @@ -3,8 +3,8 @@ package epochkghandler import ( "bytes" "context" - "math" + lru "github.com/hashicorp/golang-lru/v2" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" pubsub "github.com/libp2p/go-libp2p-pubsub" @@ -14,18 +14,23 @@ import ( "github.com/shutter-network/shutter/shlib/shcrypto" "github.com/shutter-network/rolling-shutter/rolling-shutter/keyper/database" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley" "github.com/shutter-network/rolling-shutter/rolling-shutter/p2p" "github.com/shutter-network/rolling-shutter/rolling-shutter/p2pmsg" "github.com/shutter-network/rolling-shutter/rolling-shutter/shdb" ) func NewDecryptionKeyHandler(config Config, dbpool *pgxpool.Pool) p2p.MessageHandler { - return &DecryptionKeyHandler{config: config, dbpool: dbpool} + // Not catching the error as it only can happen if non-positive size was applied + cache, _ := lru.New[shcrypto.EpochSecretKey, []byte](1024) + return &DecryptionKeyHandler{config: config, dbpool: dbpool, cache: cache} } type DecryptionKeyHandler struct { config Config dbpool *pgxpool.Pool + // keep 1024 verified keys in Cache to skip additional verifications + cache *lru.Cache[shcrypto.EpochSecretKey, []byte] } func (*DecryptionKeyHandler) MessagePrototypes() []p2pmsg.Message { @@ -38,13 +43,13 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2 return pubsub.ValidationReject, errors.Errorf("instance ID mismatch (want=%d, have=%d)", handler.config.GetInstanceID(), key.GetInstanceID()) } - if key.Eon > math.MaxInt64 { - return pubsub.ValidationReject, errors.Errorf("eon %d overflows int64", key.Eon) + eon, err := medley.Uint64ToInt64Safe(key.Eon) + if err != nil { + return pubsub.ValidationReject, errors.Wrapf(err, "overflow error while converting eon to int64 %d", eon) } queries := database.New(handler.dbpool) - - _, isKeyper, err := queries.GetKeyperIndex(ctx, int64(key.Eon), handler.config.GetAddress()) + _, isKeyper, err := queries.GetKeyperIndex(ctx, eon, handler.config.GetAddress()) if err != nil { return pubsub.ValidationReject, err } @@ -52,20 +57,19 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2 log.Debug().Uint64("eon", key.Eon).Msg("Ignoring decryptionKey for eon; we're not a Keyper") return pubsub.ValidationReject, nil } - - dkgResultDB, err := queries.GetDKGResultForKeyperConfigIndex(ctx, int64(key.Eon)) - if err == pgx.ErrNoRows { - return pubsub.ValidationReject, errors.Errorf("no DKG result found for eon %d", key.Eon) + dkgResultDB, err := queries.GetDKGResultForKeyperConfigIndex(ctx, eon) + if errors.Is(err, pgx.ErrNoRows) { + return pubsub.ValidationReject, errors.Errorf("no DKG result found for eon %d", eon) } if err != nil { - return pubsub.ValidationReject, errors.Wrapf(err, "failed to get dkg result for eon %d from db", key.Eon) + return pubsub.ValidationReject, errors.Wrapf(err, "failed to get dkg result for eon %d from db", eon) } if !dkgResultDB.Success { - return pubsub.ValidationReject, errors.Errorf("no successful DKG result found for eon %d", key.Eon) + return pubsub.ValidationReject, errors.Errorf("no successful DKG result found for eon %d", eon) } pureDKGResult, err := shdb.DecodePureDKGResult(dkgResultDB.PureResult) if err != nil { - return pubsub.ValidationReject, errors.Wrapf(err, "error while decoding pure DKG result for eon %d", key.Eon) + return pubsub.ValidationReject, errors.Wrapf(err, "error while decoding pure DKG result for eon %d", eon) } if len(key.Keys) == 0 { @@ -74,11 +78,19 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2 if len(key.Keys) > int(handler.config.GetMaxNumKeysPerMessage()) { return pubsub.ValidationReject, errors.Errorf("too many keys in message (%d > %d)", len(key.Keys), handler.config.GetMaxNumKeysPerMessage()) } + for i, k := range key.Keys { epochSecretKey, err := k.GetEpochSecretKey() if err != nil { return pubsub.ValidationReject, err } + identity, exists := handler.cache.Get(*epochSecretKey) + if exists { + if bytes.Equal(k.Identity, identity) { + continue + } + return pubsub.ValidationReject, errors.Errorf("epoch secret key for identity %x is not valid", k.Identity) + } ok, err := shcrypto.VerifyEpochSecretKey(epochSecretKey, pureDKGResult.PublicKey, k.Identity) if err != nil { return pubsub.ValidationReject, errors.Wrapf(err, "error while checking epoch secret key for identity %x", k.Identity) @@ -86,7 +98,6 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2 if !ok { return pubsub.ValidationReject, errors.Errorf("epoch secret key for identity %x is not valid", k.Identity) } - if i > 0 && bytes.Compare(k.Identity, key.Keys[i-1].Identity) < 0 { return pubsub.ValidationReject, errors.Errorf("keys not ordered") } @@ -97,7 +108,15 @@ func (handler *DecryptionKeyHandler) ValidateMessage(ctx context.Context, msg p2 func (handler *DecryptionKeyHandler) HandleMessage(ctx context.Context, msg p2pmsg.Message) ([]p2pmsg.Message, error) { metricsEpochKGDecryptionKeysReceived.Inc() key := msg.(*p2pmsg.DecryptionKeys) - // Insert the key into the db. We assume that it's valid as it already passed the libp2p - // validator. + // We assume that it's valid as it already passed the libp2p validator. + // Insert the key into the cache. + for _, k := range key.Keys { + epochSecretKey, err := k.GetEpochSecretKey() + if err != nil { + return nil, err + } + handler.cache.Add(*epochSecretKey, k.Identity) + } + // Insert the key into the db. return nil, database.New(handler.dbpool).InsertDecryptionKeysMsg(ctx, key) }