Skip to content

Commit

Permalink
Finish key assignment integration in driver
Browse files Browse the repository at this point in the history
  • Loading branch information
p-offtermatt committed Jan 19, 2024
1 parent ff0489e commit 35e5d6c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 34 deletions.
99 changes: 66 additions & 33 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/require"

sdktypes "github.com/cosmos/cosmos-sdk/types"

cmttypes "github.com/cometbft/cometbft/types"

providertypes "github.com/cosmos/interchain-security/v3/x/ccv/provider/types"
tmencoding "github.com/cometbft/cometbft/crypto/encoding"
"github.com/cosmos/interchain-security/v3/testutil/integration"

sdktypes "github.com/cosmos/cosmos-sdk/types"

cryptotestutil "github.com/cosmos/interchain-security/v3/testutil/crypto"
providertypes "github.com/cosmos/interchain-security/v3/x/ccv/provider/types"
)

const verbose = false
Expand Down Expand Up @@ -71,6 +72,7 @@ func TestMBT(t *testing.T) {
t.Logf("Number of sent packets: %v", stats.numSentPackets)
t.Logf("Number of blocks: %v", stats.numBlocks)
t.Logf("Number of transactions: %v", stats.numTxs)
t.Logf("Number of key assignments: %v", stats.numKeyAssignments)
t.Logf("Average summed block time delta passed per trace: %v", stats.totalBlockTimePassedPerTrace/time.Duration(numTraces))
}

Expand Down Expand Up @@ -119,6 +121,21 @@ func RunItfTrace(t *testing.T, path string) {

t.Log("Chains are: ", chains)

// generate keys that can be assigned on consumers, according to the ConsumerAddresses in the trace
consumerAddressesExpr := params["ConsumerAddresses"].Value.(itf.ListExprType)

consumerPrivVals, err := integration.CreateSigners(len(consumerAddressesExpr))
require.NoError(t, err, "Error creating consumer signers")

consumerAddrNamesToPrivVals := make(map[string]cmttypes.PrivValidator, len(consumerAddressesExpr))
realAddrsToModelConsAddrs := make(map[string]string, len(consumerAddressesExpr))
i := 0
for address, privVal := range consumerPrivVals {
consumerAddrNamesToPrivVals[consumerAddressesExpr[i].Value.(string)] = privVal
realAddrsToModelConsAddrs[address] = consumerAddressesExpr[i].Value.(string)
i++
}

// create params struct
vscTimeout := time.Duration(params["VscTimeout"].Value.(int64)) * time.Second

Expand Down Expand Up @@ -147,6 +164,15 @@ func RunItfTrace(t *testing.T, path string) {
valSet, addressMap, signers, err := CreateValSet(initialValSet)
require.NoError(t, err, "Error creating validator set")

// get the set of signers for consumers: the validator signers, plus signers for the assignable addresses
consumerSigners := make(map[string]cmttypes.PrivValidator, 0)
for consAddr, consPrivVal := range consumerPrivVals {
consumerSigners[consAddr] = consPrivVal
}
for consAddr, signer := range signers {
consumerSigners[consAddr] = signer
}

// get a slice of validators in the right order
nodes := make([]*cmttypes.Validator, len(valNames))
for i, valName := range valNames {
Expand All @@ -158,17 +184,6 @@ func RunItfTrace(t *testing.T, path string) {

driver.setupProvider(modelParams, valSet, signers, nodes, valNames)

// generate keys that can be assigned on consumers, according to the ConsumerAddresses in the trace
consumerAddressesExpr := params["ConsumerAddresses"].Value.(itf.ListExprType)

assignableIdentities := make(map[string]*cryptotestutil.CryptoIdentity, 0)
i := 0
for _, consumerAddr := range consumerAddressesExpr {
// future TOOD: does this need logic to avoid doubling on keys with the validators? probably so statistically unlikely that it's fine
assignableIdentities[consumerAddr.Value.(string)] = cryptotestutil.NewCryptoIdentityFromIntSeed(i)
i++
}

// remember the time offsets to be able to compare times to the model
// this is necessary because the system needs to do many steps to initialize the chains,
// which is abstracted away in the model
Expand Down Expand Up @@ -256,7 +271,7 @@ func RunItfTrace(t *testing.T, path string) {
consumer.Value.(string),
modelParams,
driver.providerChain().Vals,
signers,
consumerSigners,
nodes,
valNames,
driver.providerChain(),
Expand Down Expand Up @@ -347,11 +362,17 @@ func RunItfTrace(t *testing.T, path string) {
consumerAddr := lastAction["consumerAddr"].Value.(string)

t.Log("KeyAssignment", consumerChain, node, consumerAddr)
stats.numKeyAssignments++

valIndex := getIndexOfString(node, valNames)
assignedKey := assignableIdentities[consumerAddr].TMProtoCryptoPublicKey()
assignedPrivVal := consumerAddrNamesToPrivVals[consumerAddr]
assignedKey, err := assignedPrivVal.GetPubKey()
require.NoError(t, err, "Error getting pubkey")

error := driver.AssignKey(ChainId(consumerChain), int64(valIndex), assignedKey)
protoPubKey, err := tmencoding.PubKeyToProto(assignedKey)
require.NoError(t, err, "Error converting pubkey to proto")

error := driver.AssignKey(ChainId(consumerChain), int64(valIndex), protoPubKey)
require.NoError(t, error, "Error assigning key")

default:
Expand Down Expand Up @@ -389,7 +410,7 @@ func RunItfTrace(t *testing.T, path string) {
require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match")

// check validator sets - provider current validator set should be the one from the staking keeper
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers)
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs)

// check times - sanity check that the block times match the ones from the model
CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset)
Expand All @@ -408,7 +429,14 @@ func RunItfTrace(t *testing.T, path string) {
t.Log("🟢 Trace is ok!")
}

func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[string]itf.Expr, consumers []string) {
func CompareValidatorSets(
t *testing.T,
driver *Driver,
currentModelState map[string]itf.Expr,
consumers []string,
// a map from real addresses to the names of those consumer addresses in the model
keyAddrsToModelConsAddrName map[string]string,
) {
t.Helper()
modelValSet := ValidatorSet(currentModelState, "provider")

Expand All @@ -432,23 +460,28 @@ func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[st
pubkey, err := val.ConsPubKey()
require.NoError(t, err, "Error getting pubkey")

consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))
consAddrModelName, ok := keyAddrsToModelConsAddrName[pubkey.Address().String()]
if ok { // the node has a key assigned, use the name of the consumer address in the model
consumerCurValSet[consAddrModelName] = val.Power
} else { // the node doesn't have a key assigned yet, get the validator moniker
consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))

// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map
// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map

// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}
// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}

// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")
// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")

// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
}
}
require.NoError(t, CompareValSet(modelValSet, consumerCurValSet), "Validator sets do not match for consumer %v", consumer)
}
Expand Down
2 changes: 2 additions & 0 deletions tests/mbt/driver/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ type Stats struct {
numTxs int

totalBlockTimePassedPerTrace time.Duration

numKeyAssignments int
}
3 changes: 2 additions & 1 deletion tests/mbt/model/ccv_model.qnt
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ module ccv_model {
any {
nondet node = oneOf(nodes)
nondet consumerAddr = oneOf(consumerAddresses)
KeyAssignment("consumer1", node, consumerAddr),
nondet consumer = oneOf(runningConsumers)
KeyAssignment(consumer, node, consumerAddr),
}

action KeyAssignment(
Expand Down
6 changes: 6 additions & 0 deletions testutil/integration/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import (
tmtypes "github.com/cometbft/cometbft/types"
)

// creates n signers. does not return a validator set, just only the PrivValidators by their addresses
func CreateSigners(n int) (map[string]tmtypes.PrivValidator, error) {
_, _, signersByAddress, err := CreateValidators(n)
return signersByAddress, err
}

func CreateValidators(n int) (
*tmtypes.ValidatorSet, []types.ValidatorUpdate, map[string]tmtypes.PrivValidator, error,
) {
Expand Down

0 comments on commit 35e5d6c

Please sign in to comment.