Skip to content

Commit

Permalink
cleanups around checkpoint-recovery (#6071)
Browse files Browse the repository at this point in the history
## Motivation

few refactors to simplify and untangle the code
  • Loading branch information
poszu committed Jun 28, 2024
1 parent 2dde3a8 commit 679a98d
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 163 deletions.
7 changes: 7 additions & 0 deletions checkpoint/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ func Recover(
fs afero.Fs,
cfg *RecoverConfig,
) (*PreservedData, error) {
if len(cfg.Uri) == 0 {
return nil, errors.New("recovery uri not set")
}
if cfg.Restore == 0 {
return nil, errors.New("restore layer not set")
}
logger.Info("recovering from checkpoint", zap.String("url", cfg.Uri), zap.Stringer("restore", cfg.Restore))
db, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile))
if err != nil {
return nil, fmt.Errorf("open old database: %w", err)
Expand Down
75 changes: 51 additions & 24 deletions checkpoint/recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func atxEqual(
commitAtx types.ATXID,
vrfnonce types.VRFPostIndex,
) {
tb.Helper()
require.True(tb, bytes.Equal(sAtx.ID, vAtx.ID().Bytes()))
require.EqualValues(tb, sAtx.Epoch, vAtx.PublishEpoch)
require.True(tb, bytes.Equal(sAtx.CommitmentAtx, commitAtx.Bytes()))
Expand All @@ -63,6 +64,7 @@ func atxEqual(
}

func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Account) {
tb.Helper()
require.True(tb, bytes.Equal(cacct.Address, acct.Address.Bytes()))
require.Equal(tb, cacct.Balance, acct.Balance)
require.Equal(tb, cacct.Nonce, acct.NextNonce)
Expand All @@ -76,6 +78,7 @@ func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun
}

func verifyDbContent(tb testing.TB, db *sql.Database) {
tb.Helper()
var expected types.Checkpoint
require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected))
expAtx := map[types.ATXID]types.AtxSnapshot{}
Expand Down Expand Up @@ -117,6 +120,7 @@ func verifyDbContent(tb testing.TB, db *sql.Database) {
}

func TestRecover(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -148,9 +152,6 @@ func TestRecover(t *testing.T) {

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

fs := afero.NewMemMapFs()
cfg := &checkpoint.RecoverConfig{
GoldenAtx: goldenAtx,
Expand All @@ -166,13 +167,13 @@ func TestRecover(t *testing.T) {
require.NoError(t, fs.MkdirAll(bsdir, 0o700))
db := sql.InMemory()
localDB := localsql.InMemory()
preserve, err := checkpoint.RecoverWithDb(ctx, zaptest.NewLogger(t), db, localDB, fs, cfg)
data, err := checkpoint.RecoverWithDb(context.Background(), zaptest.NewLogger(t), db, localDB, fs, cfg)
if tc.expErr != nil {
require.ErrorIs(t, err, tc.expErr)
return
}
require.NoError(t, err)
require.Nil(t, preserve)
require.Nil(t, data)
newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile))
require.NoError(t, err)
require.NotNil(t, newDB)
Expand All @@ -189,6 +190,7 @@ func TestRecover(t *testing.T) {
}

func TestRecover_SameRecoveryInfo(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -224,11 +226,29 @@ func TestRecover_SameRecoveryInfo(t *testing.T) {
require.True(t, exist)
}

func TestRecover_URIMustBeSet(t *testing.T) {
t.Parallel()
cfg := &checkpoint.RecoverConfig{}
d, err := checkpoint.Recover(context.Background(), zaptest.NewLogger(t), afero.NewMemMapFs(), cfg)
require.ErrorContains(t, err, "uri not set")
require.Nil(t, d)
}

func TestRecover_RestoreLayerCannotBeZero(t *testing.T) {
t.Parallel()
cfg := &checkpoint.RecoverConfig{
Uri: "http://nowhere/snapshot-15",
}
_, err := checkpoint.Recover(context.Background(), zaptest.NewLogger(t), afero.NewMemMapFs(), cfg)
require.ErrorContains(t, err, "restore layer not set")
}

func validateAndPreserveData(
tb testing.TB,
db *sql.Database,
deps []*checkpoint.AtxDep,
) {
tb.Helper()
lg := zaptest.NewLogger(tb)
ctrl := gomock.NewController(tb)
mclock := activation.NewMocklayerClock(ctrl)
Expand Down Expand Up @@ -299,7 +319,6 @@ func validateAndPreserveData(
}

func newChainedAtx(
tb testing.TB,
prev, pos types.ATXID,
commitAtx *types.ATXID,
poetProofRef types.PoetProofRef,
Expand Down Expand Up @@ -377,24 +396,24 @@ func createInterlinkedAtxChain(
}

// epoch 2
sig1Atx1 := newChainedAtx(tb, types.EmptyATXID, goldenAtx, &goldenAtx, poetRef(), 2, 0, 113, sig1)
sig1Atx1 := newChainedAtx(types.EmptyATXID, goldenAtx, &goldenAtx, poetRef(), 2, 0, 113, sig1)
// epoch 3
sig1Atx2 := newChainedAtx(tb, sig1Atx1.ID, sig1Atx1.ID, nil, poetRef(), 3, 1, 0, sig1)
sig1Atx2 := newChainedAtx(sig1Atx1.ID, sig1Atx1.ID, nil, poetRef(), 3, 1, 0, sig1)
// epoch 4
sig1Atx3 := newChainedAtx(tb, sig1Atx2.ID, sig1Atx2.ID, nil, poetRef(), 4, 2, 0, sig1)
sig1Atx3 := newChainedAtx(sig1Atx2.ID, sig1Atx2.ID, nil, poetRef(), 4, 2, 0, sig1)
commitAtxID := sig1Atx2.ID
sig2Atx1 := newChainedAtx(tb, types.EmptyATXID, sig1Atx2.ID, &commitAtxID, poetRef(), 4, 0, 513, sig2)
sig2Atx1 := newChainedAtx(types.EmptyATXID, sig1Atx2.ID, &commitAtxID, poetRef(), 4, 0, 513, sig2)
// epoch 5
sig1Atx4 := newChainedAtx(tb, sig1Atx3.ID, sig2Atx1.ID, nil, poetRef(), 5, 3, 0, sig1)
sig1Atx4 := newChainedAtx(sig1Atx3.ID, sig2Atx1.ID, nil, poetRef(), 5, 3, 0, sig1)
// epoch 6
sig1Atx5 := newChainedAtx(tb, sig1Atx4.ID, sig1Atx4.ID, nil, poetRef(), 6, 4, 0, sig1)
sig2Atx2 := newChainedAtx(tb, sig2Atx1.ID, sig1Atx4.ID, nil, poetRef(), 6, 1, 0, sig2)
sig1Atx5 := newChainedAtx(sig1Atx4.ID, sig1Atx4.ID, nil, poetRef(), 6, 4, 0, sig1)
sig2Atx2 := newChainedAtx(sig2Atx1.ID, sig1Atx4.ID, nil, poetRef(), 6, 1, 0, sig2)
// epoch 7
sig1Atx6 := newChainedAtx(tb, sig1Atx5.ID, sig2Atx2.ID, nil, poetRef(), 7, 5, 0, sig1)
sig1Atx6 := newChainedAtx(sig1Atx5.ID, sig2Atx2.ID, nil, poetRef(), 7, 5, 0, sig1)
// epoch 8
sig2Atx3 := newChainedAtx(tb, sig2Atx2.ID, sig1Atx6.ID, nil, poetRef(), 8, 2, 0, sig2)
sig2Atx3 := newChainedAtx(sig2Atx2.ID, sig1Atx6.ID, nil, poetRef(), 8, 2, 0, sig2)
// epoch 9
sig1Atx7 := newChainedAtx(tb, sig1Atx6.ID, sig2Atx3.ID, nil, poetRef(), 9, 6, 0, sig1)
sig1Atx7 := newChainedAtx(sig1Atx6.ID, sig2Atx3.ID, nil, poetRef(), 9, 6, 0, sig1)

vAtxs := []*checkpoint.AtxDep{
sig1Atx1,
Expand All @@ -413,12 +432,14 @@ func createInterlinkedAtxChain(
}

func createAtxChain(tb testing.TB, sig *signing.EdSigner) ([]*checkpoint.AtxDep, []*types.PoetProofMessage) {
tb.Helper()
other, err := signing.NewEdSigner()
require.NoError(tb, err)
return createInterlinkedAtxChain(tb, other, sig)
}

func createAtxChainDepsOnly(tb testing.TB) ([]*checkpoint.AtxDep, []*types.PoetProofMessage) {
tb.Helper()
other, err := signing.NewEdSigner()
require.NoError(tb, err)

Expand All @@ -430,11 +451,11 @@ func createAtxChainDepsOnly(tb testing.TB) ([]*checkpoint.AtxDep, []*types.PoetP
}

// epoch 2
othAtx1 := newChainedAtx(tb, types.EmptyATXID, goldenAtx, &goldenAtx, poetRef(), 2, 0, 113, other)
othAtx1 := newChainedAtx(types.EmptyATXID, goldenAtx, &goldenAtx, poetRef(), 2, 0, 113, other)
// epoch 3
othAtx2 := newChainedAtx(tb, othAtx1.ID, othAtx1.ID, nil, poetRef(), 3, 1, 0, other)
othAtx2 := newChainedAtx(othAtx1.ID, othAtx1.ID, nil, poetRef(), 3, 1, 0, other)
// epoch 4
othAtx3 := newChainedAtx(tb, othAtx2.ID, othAtx2.ID, nil, poetRef(), 4, 2, 0, other)
othAtx3 := newChainedAtx(othAtx2.ID, othAtx2.ID, nil, poetRef(), 4, 2, 0, other)
atxDeps := []*checkpoint.AtxDep{othAtx1, othAtx2, othAtx3}

return atxDeps, proofs
Expand Down Expand Up @@ -505,6 +526,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) {
for _, proof := range proofs {
encoded, err := codec.Encode(proof)
require.NoError(t, err)

ref, err := proof.Ref()
require.NoError(t, err)

Expand Down Expand Up @@ -550,6 +572,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) {
}

func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -591,6 +614,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) {
for _, proof := range proofs {
encoded, err := codec.Encode(proof)
require.NoError(t, err)

ref, err := proof.Ref()
require.NoError(t, err)

Expand Down Expand Up @@ -671,6 +695,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) {
}

func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -770,6 +795,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T)
}

func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -852,6 +878,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) {
}

func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -917,6 +944,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) {
}

func TestRecover_OwnAtxInCheckpoint(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
Expand All @@ -927,20 +955,19 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

data, err := hex.DecodeString("0230c5d75d42b84f98800eceb47bc9cc4d803058900a50346a09ff61d56b6582")
nid, err := hex.DecodeString("0230c5d75d42b84f98800eceb47bc9cc4d803058900a50346a09ff61d56b6582")
require.NoError(t, err)
nid := types.BytesToNodeID(data)
data, err = hex.DecodeString("98e47278c1f58acfd2b670a730f28898f74eb140482a07b91ff81f9ff0b7d9f4")
atxid, err := hex.DecodeString("98e47278c1f58acfd2b670a730f28898f74eb140482a07b91ff81f9ff0b7d9f4")
require.NoError(t, err)
atx := newAtx(types.ATXID(types.BytesToHash(data)), types.EmptyATXID, nil, 3, 1, 0, nid)
atx := newAtx(types.ATXID(atxid), types.EmptyATXID, nil, 3, 1, 0, nid)

cfg := &checkpoint.RecoverConfig{
GoldenAtx: goldenAtx,
DataDir: t.TempDir(),
DbFile: "test.sql",
LocalDbFile: "local.sql",
PreserveOwnAtx: true,
NodeIDs: []types.NodeID{nid},
NodeIDs: []types.NodeID{types.BytesToNodeID(nid)},
Uri: fmt.Sprintf("%s/snapshot-15", ts.URL),
Restore: types.LayerID(recoverLayer),
}
Expand Down
Loading

0 comments on commit 679a98d

Please sign in to comment.