diff --git a/CHANGELOG.md b/CHANGELOG.md index e237b3aafba..ac0a08d403b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -154,6 +154,9 @@ operating your own PoET and want to use certificate authentication please refer ATXs. This vulnerability allows an attacker to claim rewards for a full tick amount although they should not be eligible for them. +* [#6003](https://github.com/spacemeshos/go-spacemesh/pull/6003) Improve database schema handling. + This includes schema drift detection which may happen after running unreleased versions. + * [#6031](https://github.com/spacemeshos/go-spacemesh/pull/6031) Fixed an edge case where the storage units might have changed after the initial PoST was generated but before the first ATX has been emitted, invalidating the initial PoST. The node will now try to verify the initial PoST and regenerate it if necessary. diff --git a/CODING.md b/CODING.md index fbb458d693d..0e373b3746a 100644 --- a/CODING.md +++ b/CODING.md @@ -105,3 +105,79 @@ Some useful logging recommendations for cleaner output: ## Commit Messages For commit messages, follow [this guideline](https://www.conventionalcommits.org/en/v1.0.0/). Use reasonable length for the subject and body, ideally no longer than 72 characters. Use the imperative mood for subject lines. + +## Handling database schema changes + +go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. + +When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: +* `sql/statesql/schema/schema.sql` for `state.sql` +* `sql/localsql/schema/schema.sql` for `local.sql` +The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). + +For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: +* `sql/statesql/schema/migrations` for `state.sql` +* `sql/localsql/schema/migrations` for `local.sql` + +Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. + +After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, fails with an error message: +``` +Error: open sqlite db schema drift detected (uri file:data/state.sql): + ( + """ + ... // 82 identical lines + PRIMARY KEY (layer, block) + ); ++ CREATE TABLE foo(id int); + CREATE TABLE identities + ( + ... // 66 identical lines + """ + ) +``` + +In this case, a table named `foo` has somehow appeared in the database, causing go-spacemesh to fail due to the schema drift. The possible reasons for schema drift can be the following: +* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens +* manual changes in the database +* external SQLite tooling used on the database that adds some tables, indices etc. + +In case if you want to run go-spacemesh with schema drift anyway, you can set `main.db-allow-schema-drift` to true. In this case, a warning with schema diff will be logged instead of failing. + +The schema changes in go-spacemesh code should be always done by means of adding migrations. Let's for example create a new migration (use zero-padded N+1 instead of 0010 with N being the number of the last migration for the local db): + +```console +$ echo 'CREATE TABLE foo(id int);' >sql/localsql/schema/migrations/0010_foo.sql +``` + +After that, we update the schema files +```console +$ make generate +$ # alternative: cd sql/localsql && go generate +$ git diff sql/localsql/schema/schema.sql +diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql +index 02c44d3cc..ebcdf4278 100755 +--- a/sql/localsql/schema/schema.sql ++++ b/sql/localsql/schema/schema.sql +@@ -1,4 +1,4 @@ +-PRAGMA user_version = 9; ++PRAGMA user_version = 10; + CREATE TABLE atx_sync_requests + ( + epoch INT NOT NULL, +@@ -24,6 +24,7 @@ CREATE TABLE "challenge" + post_indices VARCHAR, + post_pow UNSIGNED LONG INT + , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; ++CREATE TABLE foo(id int); + CREATE TABLE malfeasance_sync_state + ( + id INT NOT NULL PRIMARY KEY, +``` + +Note that the changes include both the new table and an updated `PRAGMA user_version` line. +The changes in the schema file must be committed along with the migration we added. +```console +$ git add sql/localsql/schema/migrations/0010_foo.sql sql/localsql/schema.sql +$ git commit -m "sql: add a test migration" +``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 13c79edce51..d5e4d4f1346 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. We welcome contributions large and small and we actively accept contributions. - go-spacemesh is part of [The Spacemesh open source project](https://spacemesh.io), and is MIT licensed free open source software. -- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). +- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). - Search the closed ones before reporting things, and help us with the open ones. - Make sure that you are able to contribute to MIT licensed free open software (no legal issues please). - Introduce yourself, ask questions about issues or talk about things on our [discord server](https://chat.spacemesh.io/). @@ -39,7 +39,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. # Code Guidelines Please follow these guidelines for your PR to be reviewed and be considered for merging into the project. -1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). +1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). 2. Provide at least one unit test for each function and method. 3. Provide at least one integration test for each feature with a flow which involves more than one function call. Your tests should reflect the main ways that your code should be used. 4. Run `go mod tidy`, `go fmt ./...` and `make lint` to format and lint your code before submitting your PR. @@ -49,7 +49,7 @@ Please follow these guidelines for your PR to be reviewed and be considered for - Check for existing 3rd-party packages in the vendor folder `./vendor` before adding a new dependency. - Use [govendor](https://github.com/kardianos/govendor) to add a new dependency. -# Working on a funded issue +# Working on a funded issue ## Step 1 - Discover :sunrise_over_mountains: - Browse the [open funded issues](https://github.com/spacemeshos/go-spacemesh/labels/funded%20%3Amoneybag%3A) in our github repo, or on our [gitcoin.io funded issues page](https://gitcoin.co/profile/spacemeshos). @@ -68,6 +68,6 @@ Please follow these guidelines for your PR to be reviewed and be considered for ## Step 3 - Get paid :moneybag: - When ready, submit your PR for review and go through the code review process with one of our maintainers. - Expect a review process that ensures that you have followed our code guidelines at that your design and implementation are solid. You are expected to revision your code based on reviewers comments. -- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. +- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. Please review our funded issues program [legal notes](https://github.com/spacemeshos/go-spacemesh/blob/master/legal.md). diff --git a/activation/activation.go b/activation/activation.go index 567332e36e2..d6db514e9ca 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -28,7 +28,6 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -78,7 +77,7 @@ type Builder struct { conf Config db sql.Executor atxsdata *atxsdata.Data - localDB *localsql.Database + localDB sql.LocalDatabase publisher pubsub.Publisher nipostBuilder nipostBuilder validator nipostValidator @@ -174,7 +173,7 @@ func NewBuilder( conf Config, db sql.Executor, atxsdata *atxsdata.Data, - localDB *localsql.Database, + localDB sql.LocalDatabase, publisher pubsub.Publisher, nipostBuilder nipostBuilder, layerClock layerClock, diff --git a/activation/activation_test.go b/activation/activation_test.go index e073d2a334c..0305add4c31 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -33,6 +33,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" sqlmocks "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) // ========== Vars / Consts ========== @@ -54,7 +55,7 @@ func TestMain(m *testing.M) { type testAtxBuilder struct { *Builder db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase goldenATXID types.ATXID observedLogs *observer.ObservedLogs @@ -77,7 +78,7 @@ func newTestBuilder(tb testing.TB, numSigners int, opts ...BuilderOption) *testA ctrl := gomock.NewController(tb) tab := &testAtxBuilder{ - db: sql.InMemory(), + db: statesql.InMemory(), localDb: localsql.InMemory(sql.WithConnections(numSigners)), goldenATXID: types.ATXID(types.HexToHash32("77777")), diff --git a/activation/certifier.go b/activation/certifier.go index bbd7cf60f65..636a14ca78e 100644 --- a/activation/certifier.go +++ b/activation/certifier.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" certifierdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -80,14 +79,14 @@ type CertifyResponse struct { type Certifier struct { logger *zap.Logger - db *localsql.Database + db sql.LocalDatabase client certifierClient certifications singleflight.Group } func NewCertifier( - db *localsql.Database, + db sql.LocalDatabase, logger *zap.Logger, client certifierClient, ) *Certifier { @@ -145,7 +144,7 @@ type CertifierClient struct { client *retryablehttp.Client logger *zap.Logger db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase } type certifierClientOpts func(*CertifierClient) @@ -160,7 +159,7 @@ func WithCertifierClientConfig(cfg CertifierClientConfig) certifierClientOpts { func NewCertifierClient( db sql.Executor, - localDb *localsql.Database, + localDb sql.LocalDatabase, logger *zap.Logger, opts ...certifierClientOpts, ) *CertifierClient { diff --git a/activation/certifier_test.go b/activation/certifier_test.go index fa0db876c01..cb2ad5d3067 100644 --- a/activation/certifier_test.go +++ b/activation/certifier_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" certdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPersistsCerts(t *testing.T) { @@ -113,7 +114,7 @@ func TestObtainingPost(t *testing.T) { id := types.RandomNodeID() t.Run("no POST or ATX", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certifier := NewCertifierClient(db, localDb, zaptest.NewLogger(t)) @@ -121,7 +122,7 @@ func TestObtainingPost(t *testing.T) { require.ErrorContains(t, err, "PoST not found") }) t.Run("initial POST available", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() post := nipost.Post{ @@ -142,7 +143,7 @@ func TestObtainingPost(t *testing.T) { require.Equal(t, post, *got) }) t.Run("initial POST unavailable but ATX exists", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() atx := newInitialATXv1(t, types.RandomATXID()) diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index e2bd4dd0983..1472f4453c6 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -27,9 +27,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -61,7 +61,7 @@ func Test_BuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() svc := grpcserver.NewPostService(logger, grpcserver.PostServiceQueryInterval(100*time.Millisecond)) diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index e0dfcca0ff2..8b1fdab98c2 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -27,11 +27,11 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" @@ -205,7 +205,7 @@ func Test_MarryAndMerge(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) localDB := localsql.InMemory() diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 2baf3260606..af8fad6a550 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -25,9 +25,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -53,7 +53,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) opts := testPostSetupOpts(t) diff --git a/activation/e2e/certifier_client_test.go b/activation/e2e/certifier_client_test.go index 9ff1a358870..2f09e1cd958 100644 --- a/activation/e2e/certifier_client_test.go +++ b/activation/e2e/certifier_client_test.go @@ -23,9 +23,9 @@ import ( "github.com/spacemeshos/go-spacemesh/api/grpcserver" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCertification(t *testing.T) { @@ -34,7 +34,7 @@ func TestCertification(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() opts := testPostSetupOpts(t) diff --git a/activation/e2e/checkpoint_merged_test.go b/activation/e2e/checkpoint_merged_test.go index f06c319f3dd..970ac803b38 100644 --- a/activation/e2e/checkpoint_merged_test.go +++ b/activation/e2e/checkpoint_merged_test.go @@ -25,11 +25,11 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -43,7 +43,7 @@ func Test_CheckpointAfterMerge(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) localDB := localsql.InMemory() @@ -261,7 +261,7 @@ func Test_CheckpointAfterMerge(t *testing.T) { require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + recoveryCfg.DbPath()) + newDB, err := statesql.Open("file:" + recoveryCfg.DbPath()) require.NoError(t, err) defer newDB.Close() diff --git a/activation/e2e/checkpoint_test.go b/activation/e2e/checkpoint_test.go index 0120189fe41..0e28ee72d9d 100644 --- a/activation/e2e/checkpoint_test.go +++ b/activation/e2e/checkpoint_test.go @@ -25,9 +25,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -46,7 +46,7 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) opts := testPostSetupOpts(t) @@ -181,7 +181,7 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + recoveryCfg.DbPath()) + newDB, err := statesql.Open("file:" + recoveryCfg.DbPath()) require.NoError(t, err) defer newDB.Close() diff --git a/activation/e2e/nipost_test.go b/activation/e2e/nipost_test.go index 74a9e29b579..61891ac2aa3 100644 --- a/activation/e2e/nipost_test.go +++ b/activation/e2e/nipost_test.go @@ -23,9 +23,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -136,7 +136,7 @@ func initPost( logger := zaptest.NewLogger(tb) syncer := syncedSyncer(tb) - db := sql.InMemory() + db := statesql.InMemory() mgr, err := activation.NewPostSetupManager(cfg, logger, db, atxsdata.New(), golden, syncer, nil) require.NoError(tb, err) @@ -157,7 +157,7 @@ func TestNIPostBuilderWithClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() opts := testPostSetupOpts(t) @@ -243,7 +243,7 @@ func Test_NIPostBuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() opts := testPostSetupOpts(t) svc := grpcserver.NewPostService(logger, grpcserver.PostServiceQueryInterval(100*time.Millisecond)) diff --git a/activation/e2e/validation_test.go b/activation/e2e/validation_test.go index dde1ff6162e..0880f2346f1 100644 --- a/activation/e2e/validation_test.go +++ b/activation/e2e/validation_test.go @@ -16,8 +16,8 @@ import ( "github.com/spacemeshos/go-spacemesh/api/grpcserver" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestValidator_Validate(t *testing.T) { @@ -29,7 +29,7 @@ func TestValidator_Validate(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() validator := activation.NewMocknipostValidator(gomock.NewController(t)) @@ -50,7 +50,7 @@ func TestValidator_Validate(t *testing.T) { GracePeriod: epoch / 4, } - poetDb := activation.NewPoetDb(sql.InMemory(), logger.Named("poetDb")) + poetDb := activation.NewPoetDb(statesql.InMemory(), logger.Named("poetDb")) client := ae2e.NewTestPoetClient(1, poetCfg) poetService := activation.NewPoetServiceWithClient(poetDb, client, poetCfg, logger) diff --git a/activation/handler_test.go b/activation/handler_test.go index fd7a6689622..29a0f6e3e78 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -195,7 +196,7 @@ func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOption) *testHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) edVerifier := signing.NewEdVerifier() mocks := newTestHandlerMocks(tb, goldenATXID) diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 46b24cd2bd4..a3be09c5b4a 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -448,7 +448,7 @@ func (h *HandlerV1) checkWrongPrevAtx( func (h *HandlerV1) checkMalicious( ctx context.Context, - tx *sql.Tx, + tx sql.Transaction, watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { malicious, err := identities.IsMalicious(tx, watx.SmesherID) @@ -472,7 +472,7 @@ func (h *HandlerV1) storeAtx( watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { var proof *mwire.MalfeasanceProof - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { var err error proof, err = h.checkMalicious(ctx, tx, watx) if err != nil { diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index 5aba8cb7f95..0ebb6f44dd7 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -20,9 +20,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v1TestHandler struct { @@ -33,7 +33,7 @@ type v1TestHandler struct { func newV1TestHandler(tb testing.TB, goldenATXID types.ATXID) *v1TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, goldenATXID) return &v1TestHandler{ HandlerV1: &HandlerV1{ diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 6755b9f7e0d..d1d8dd4a15b 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -666,7 +666,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( return &result, nil } -func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error { +func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx *activationTx) error { malicious, err := identities.IsMalicious(tx, atx.SmesherID) if err != nil { return fmt.Errorf("checking if node is malicious: %w", err) @@ -707,7 +707,7 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activat return nil } -func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { +func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { for _, m := range atx.marriages { mATX, err := identities.MarriageATX(tx, m.id) if err != nil { @@ -737,7 +737,7 @@ func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activ return false, nil } -func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { +func (h *HandlerV2) checkDoublePost(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { for id := range atx.ids { atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) switch { @@ -762,7 +762,7 @@ func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activa return false, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) { +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, watx *activationTx) (bool, error) { if watx.MarriageATX == nil { return false, nil } @@ -789,7 +789,7 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *acti // Store an ATX in the DB. func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *activationTx) error { - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { if len(watx.marriages) != 0 { marriageData := identities.MarriageData{ ATX: atx.ID(), @@ -822,7 +822,7 @@ func (h *HandlerV2) storeAtx(ctx context.Context, atx *types.ActivationTx, watx atxs.AtxAdded(h.cdb, atx) var malicious bool - err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { // malfeasance check happens after storing the ATX because storing updates the marriage set // that is needed for the malfeasance proof // TODO(mafa): don't store own ATX if it would mark the node as malicious diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 46f6e76834b..56dbe7303f9 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -26,6 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v2TestHandler struct { @@ -46,7 +47,7 @@ const ( func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, golden) return &v2TestHandler{ HandlerV2: &HandlerV2{ diff --git a/activation/malfeasance_test.go b/activation/malfeasance_test.go index a90ac6c5dee..cc2dc07b893 100644 --- a/activation/malfeasance_test.go +++ b/activation/malfeasance_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { @@ -40,11 +41,11 @@ type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -237,12 +238,12 @@ type testInvalidPostIndexHandler struct { *InvalidPostIndexHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase mockPostVerifier *MockPostVerifier } func newTestInvalidPostIndexHandler(tb testing.TB) *testInvalidPostIndexHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -428,11 +429,11 @@ type testInvalidPrevATXHandler struct { *InvalidPrevATXHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestInvalidPrevATXHandler(tb testing.TB) *testInvalidPrevATXHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/activation/nipost.go b/activation/nipost.go index ec9f8ac9c51..1682861889e 100644 --- a/activation/nipost.go +++ b/activation/nipost.go @@ -22,7 +22,6 @@ import ( "github.com/spacemeshos/go-spacemesh/metrics/public" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -48,7 +47,7 @@ var ErrInvalidInitialPost = errors.New("invalid initial post") // NIPostBuilder holds the required state and dependencies to create Non-Interactive Proofs of Space-Time (NIPost). type NIPostBuilder struct { - localDB *localsql.Database + localDB sql.LocalDatabase poetProvers map[string]PoetService postService postService @@ -78,7 +77,7 @@ func NipostbuilderWithPostStates(ps PostStates) NIPostBuilderOption { // NewNIPostBuilder returns a NIPostBuilder. func NewNIPostBuilder( - db *localsql.Database, + db sql.LocalDatabase, postService postService, lg *zap.Logger, poetCfg PoetConfig, diff --git a/activation/nipost_test.go b/activation/nipost_test.go index 41c5b8c7822..24afc65e8d6 100644 --- a/activation/nipost_test.go +++ b/activation/nipost_test.go @@ -52,7 +52,7 @@ type testNIPostBuilder struct { observedLogs *observer.ObservedLogs eventSub <-chan events.UserEvent - mDb *localsql.Database + mDb sql.LocalDatabase mLogger *zap.Logger mPoetDb *MockpoetDbAPI mClock *MocklayerClock diff --git a/activation/poet_client_test.go b/activation/poet_client_test.go index 397fbc06b09..866348d9b6a 100644 --- a/activation/poet_client_test.go +++ b/activation/poet_client_test.go @@ -23,8 +23,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_HTTPPoetClient_ParsesURL(t *testing.T) { @@ -459,7 +459,7 @@ func TestPoetService_CachesCertifierInfo(t *testing.T) { cfg := DefaultPoetConfig() cfg.CertifierInfoCacheTTL = tc.ttl client := NewMockPoetClient(gomock.NewController(t)) - db := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + db := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) client.EXPECT().Address().Return("some_addr").AnyTimes() client.EXPECT().Info(gomock.Any()).Return(&types.PoetInfo{}, nil) diff --git a/activation/poetdb.go b/activation/poetdb.go index c5123bb56c0..3905564f89f 100644 --- a/activation/poetdb.go +++ b/activation/poetdb.go @@ -21,12 +21,12 @@ import ( // PoetDb is a database for PoET proofs. type PoetDb struct { - sqlDB *sql.Database + sqlDB sql.StateDatabase logger *zap.Logger } // NewPoetDb returns a new PoET handler. -func NewPoetDb(db *sql.Database, log *zap.Logger) *PoetDb { +func NewPoetDb(db sql.StateDatabase, log *zap.Logger) *PoetDb { return &PoetDb{sqlDB: db, logger: log} } diff --git a/activation/poetdb_test.go b/activation/poetdb_test.go index 6e86452507e..10fd77283e3 100644 --- a/activation/poetdb_test.go +++ b/activation/poetdb_test.go @@ -16,7 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ( @@ -63,7 +63,7 @@ func getPoetProof(t *testing.T) types.PoetProofMessage { func TestPoetDbHappyFlow(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) r.NoError(poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature)) ref, err := msg.Ref() @@ -83,7 +83,7 @@ func TestPoetDbHappyFlow(t *testing.T) { func TestPoetDbInvalidPoetProof(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.PoetProof.Root = []byte("some other root") err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -99,7 +99,7 @@ func TestPoetDbInvalidPoetProof(t *testing.T) { func TestPoetDbInvalidPoetStatement(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.Statement = types.CalcHash32([]byte("some other statement")) err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -115,7 +115,7 @@ func TestPoetDbInvalidPoetStatement(t *testing.T) { func TestPoetDbNonExistingKeys(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) _, err := poetDb.GetProofRef(msg.PoetServiceID, "0") r.EqualError( diff --git a/activation/post_supervisor_test.go b/activation/post_supervisor_test.go index f26021a3645..7c21cbf4c37 100644 --- a/activation/post_supervisor_test.go +++ b/activation/post_supervisor_test.go @@ -24,7 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func closedChan() <-chan struct{} { @@ -56,7 +56,7 @@ func newPostManager(t *testing.T, cfg PostConfig, opts PostSetupOpts) *PostSetup close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() mgr, err := NewPostSetupManager(cfg, zaptest.NewLogger(t), db, atxsdata, types.RandomATXID(), syncer, validator) require.NoError(t, err) diff --git a/activation/post_test.go b/activation/post_test.go index b3e24eabda3..58b111fc291 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPostSetupManager(t *testing.T) { @@ -369,7 +369,7 @@ func newTestPostManager(tb testing.TB) *testPostManager { syncer.EXPECT().RegisterForATXSynced().AnyTimes().Return(synced) logger := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) mgr, err := NewPostSetupManager(DefaultPostConfig(), logger, cdb, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/activation/validation_test.go b/activation/validation_test.go index 6ab64fd5c7f..09c5d7c5edc 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -16,8 +16,8 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Validation_VRFNonce(t *testing.T) { @@ -476,7 +476,7 @@ func TestValidateMerkleProof(t *testing.T) { } func TestVerifyChainDeps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() @@ -662,7 +662,7 @@ func TestVerifyChainDeps(t *testing.T) { } func TestVerifyChainDepsAfterCheckpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index 351aa265ed1..90bfd49de98 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -1,3 +1,7 @@ +//go:build exclude + +// FIXME: tmp circular dep fix + package wire import ( @@ -9,8 +13,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_DoubleMarryProof(t *testing.T) { @@ -21,7 +25,7 @@ func Test_DoubleMarryProof(t *testing.T) { require.NoError(t, err) t.Run("valid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -50,7 +54,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("does not contain same certificate owner", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx1 := newActivationTxV2( withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), @@ -79,7 +83,7 @@ func Test_DoubleMarryProof(t *testing.T) { atx1 := newActivationTxV2() atx1.Sign(sig) - db := sql.InMemory() + db := statesql.InMemory() proof, err := NewDoubleMarryProof(db, atx1, atx1, sig.NodeID()) require.ErrorContains(t, err, "ATXs have the same ID") require.Nil(t, proof) @@ -103,7 +107,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("invalid marriage proof", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -150,7 +154,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("invalid certificate proof", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -197,7 +201,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("invalid atx signature", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -233,7 +237,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("invalid certificate signature", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -269,7 +273,7 @@ func Test_DoubleMarryProof(t *testing.T) { }) t.Run("unknown reference ATX", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx1 := newActivationTxV2( withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index e0303affb07..9c3cd5f3f6f 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -1,58 +1,56 @@ package wire import ( - "math/rand/v2" "testing" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/signing" ) -type testAtxV2Opt func(*ActivationTxV2) +// type testAtxV2Opt func(*ActivationTxV2) -func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { - return func(atx *ActivationTxV2) { - certificate := MarriageCertificate{ - ReferenceAtx: refAtx, - Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), - } - atx.Marriages = append(atx.Marriages, certificate) - } -} +// func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { +// return func(atx *ActivationTxV2) { +// certificate := MarriageCertificate{ +// ReferenceAtx: refAtx, +// Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), +// } +// atx.Marriages = append(atx.Marriages, certificate) +// } +// } -func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { - atx := &ActivationTxV2{ - PublishEpoch: rand.N(types.EpochID(255)), - PositioningATX: types.RandomATXID(), - PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), - NiPosts: []NiPostsV2{ - { - Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - }, - Challenge: types.RandomHash(), - Posts: []SubPostV2{ - { - MarriageIndex: rand.Uint32N(256), - PrevATXIndex: 0, - Post: PostV1{ - Nonce: 0, - Indices: make([]byte, 800), - Pow: 0, - }, - }, - }, - }, - }, - } - for _, opt := range opts { - opt(atx) - } - return atx -} +// func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { +// atx := &ActivationTxV2{ +// PublishEpoch: rand.N(types.EpochID(255)), +// PositioningATX: types.RandomATXID(), +// PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), +// NiPosts: []NiPostsV2{ +// { +// Membership: MerkleProofV2{ +// Nodes: make([]types.Hash32, 32), +// }, +// Challenge: types.RandomHash(), +// Posts: []SubPostV2{ +// { +// MarriageIndex: rand.Uint32N(256), +// PrevATXIndex: 0, +// Post: PostV1{ +// Nonce: 0, +// Indices: make([]byte, 800), +// Pow: 0, +// }, +// }, +// }, +// }, +// }, +// } +// for _, opt := range opts { +// opt(atx) +// } +// return atx +// } func Benchmark_ATXv2ID(b *testing.B) { f := fuzz.New() diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index eac99561f68..3339cdb4dba 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -17,6 +17,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Highest_ReturnsGoldenAtxOnError(t *testing.T) { @@ -156,7 +157,7 @@ func TestGet_IdentityCanceled(t *testing.T) { atxProvider := grpcserver.NewMockatxProvider(ctrl) activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) - smesher, proof := grpcserver.BallotMalfeasance(t, sql.InMemory()) + smesher, proof := grpcserver.BallotMalfeasance(t, statesql.InMemory()) previous := types.RandomATXID() id := types.RandomATXID() atx := types.ActivationTx{ diff --git a/api/grpcserver/admin_service.go b/api/grpcserver/admin_service.go index eb7f4e99882..75bcbd15fff 100644 --- a/api/grpcserver/admin_service.go +++ b/api/grpcserver/admin_service.go @@ -34,14 +34,14 @@ const ( // AdminService exposes endpoints for node administration. type AdminService struct { - db *sql.Database + db sql.StateDatabase dataDir string recover func() p peers } // NewAdminService creates a new admin grpc service. -func NewAdminService(db *sql.Database, dataDir string, p peers) *AdminService { +func NewAdminService(db sql.StateDatabase, dataDir string, p peers) *AdminService { return &AdminService{ db: db, dataDir: dataDir, diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index cdd087c3036..02bb7f805ca 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -19,11 +19,12 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const snapshot uint32 = 15 -func newAtx(tb testing.TB, db *sql.Database) { +func newAtx(tb testing.TB, db sql.StateDatabase) { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(2), Sequence: 0, @@ -41,7 +42,7 @@ func newAtx(tb testing.TB, db *sql.Database) { require.NoError(tb, atxs.SetPost(db, atx.ID(), types.EmptyATXID, 0, atx.SmesherID, atx.NumUnits)) } -func createMesh(tb testing.TB, db *sql.Database) { +func createMesh(tb testing.TB, db sql.StateDatabase) { for range 10 { newAtx(tb, db) } @@ -57,7 +58,7 @@ func createMesh(tb testing.TB, db *sql.Database) { } func TestAdminService_Checkpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() createMesh(t, db) svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) @@ -94,7 +95,7 @@ func TestAdminService_Checkpoint(t *testing.T) { } func TestAdminService_CheckpointError(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -111,7 +112,7 @@ func TestAdminService_CheckpointError(t *testing.T) { } func TestAdminService_Recovery(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() recoveryCalled := atomic.Bool{} svc := NewAdminService(db, t.TempDir(), nil) svc.recover = func() { recoveryCalled.Store(true) } @@ -133,7 +134,7 @@ func TestAdminService_PeerInfo(t *testing.T) { ctrl := gomock.NewController(t) p := NewMockpeers(ctrl) - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), p) cfg, cleanup := launchServer(t, svc) diff --git a/api/grpcserver/debug_service.go b/api/grpcserver/debug_service.go index 5d13b47a004..cd107b41eb8 100644 --- a/api/grpcserver/debug_service.go +++ b/api/grpcserver/debug_service.go @@ -24,7 +24,7 @@ import ( // DebugService exposes global state data, output from the STF. type DebugService struct { - db *sql.Database + db sql.StateDatabase conState conservativeState netInfo networkInfo oracle oracle @@ -46,7 +46,7 @@ func (d *DebugService) String() string { } // NewDebugService creates a new grpc service using config data. -func NewDebugService(db *sql.Database, conState conservativeState, host networkInfo, oracle oracle, +func NewDebugService(db sql.StateDatabase, conState conservativeState, host networkInfo, oracle oracle, loggers map[string]*zap.AtomicLevel, ) *DebugService { return &DebugService{ diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index b8862f1bd2f..e4dd0ccad70 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -47,11 +47,11 @@ import ( peerinfomocks "github.com/spacemeshos/go-spacemesh/p2p/peerinfo/mocks" pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/txs" ) @@ -720,7 +720,7 @@ func TestMeshService(t *testing.T) { genesis := time.Unix(genTimeUnix, 0) genTime.EXPECT().GenesisTime().Return(genesis) genTime.EXPECT().CurrentLayer().Return(layerCurrent).AnyTimes() - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) svc := NewMeshService( db, meshAPIMock, @@ -1249,7 +1249,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + svc := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -1288,7 +1288,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1321,7 +1321,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1349,7 +1349,7 @@ func TestTransactionService(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1657,7 +1657,7 @@ func TestAccountMeshDataStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) grpcService := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -1836,7 +1836,7 @@ func TestLayerStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) grpcService := NewMeshService( db, @@ -1983,7 +1983,7 @@ func TestMultiService(t *testing.T) { genTime.EXPECT().GenesisTime().Return(genesis) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, "v0.0.0", "cafebabe") svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -2028,7 +2028,7 @@ func TestDebugService(t *testing.T) { ctrl := gomock.NewController(t) netInfo := NewMocknetworkInfo(ctrl) mOracle := NewMockoracle(ctrl) - db := sql.InMemory() + db := statesql.InMemory() testLog := zap.NewAtomicLevel() loggers := map[string]*zap.AtomicLevel{ @@ -2220,7 +2220,7 @@ func TestEventsReceived(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - txService := NewTransactionService(sql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) + txService := NewTransactionService(statesql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) gsService := NewGlobalStateService(meshAPIMock, conStateAPI) cfg, cleanup := launchServer(t, txService, gsService) t.Cleanup(cleanup) @@ -2270,8 +2270,8 @@ func TestEventsReceived(t *testing.T) { time.Sleep(time.Millisecond * 50) lg := zaptest.NewLogger(t) - svm := vm.New(sql.InMemory(), vm.WithLogger(lg)) - conState := txs.NewConservativeState(svm, sql.InMemory(), txs.WithLogger(lg.Named("conState"))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(lg)) + conState := txs.NewConservativeState(svm, statesql.InMemory(), txs.WithLogger(lg.Named("conState"))) conState.AddToCache(context.Background(), globalTx, time.Now()) weight := new(big.Rat).SetFloat64(18.7) @@ -2334,7 +2334,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(zaptest.NewLogger(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(zaptest.NewLogger(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2355,7 +2355,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(zaptest.NewLogger(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(zaptest.NewLogger(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2374,7 +2374,7 @@ func TestVMAccountUpdates(t *testing.T) { events.InitializeReporter() // in memory database doesn't allow reads while writer locked db - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) require.NoError(t, err) t.Cleanup(func() { db.Close() }) svm := vm.New(db, vm.WithLogger(zaptest.NewLogger(t))) @@ -2470,7 +2470,7 @@ func createAtxs(tb testing.TB, epoch types.EpochID, atxids []types.ATXID) []*typ func TestMeshService_EpochStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, diff --git a/api/grpcserver/http_server_test.go b/api/grpcserver/http_server_test.go index 9f72c576f4c..beb70be6f07 100644 --- a/api/grpcserver/http_server_test.go +++ b/api/grpcserver/http_server_test.go @@ -18,7 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchJsonServer(tb testing.TB, services ...ServiceAPI) (Config, func()) { @@ -66,7 +66,7 @@ func TestJsonApi(t *testing.T) { conStateAPI := NewMockconservativeState(ctrl) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, version, build) svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, diff --git a/api/grpcserver/mesh_service_test.go b/api/grpcserver/mesh_service_test.go index 82a8fbd0873..57c1fc02418 100644 --- a/api/grpcserver/mesh_service_test.go +++ b/api/grpcserver/mesh_service_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -142,7 +143,7 @@ func HareMalfeasance(tb testing.TB, db sql.Executor) (types.NodeID, *wire.Malfea func TestMeshService_MalfeasanceQuery(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -193,7 +194,7 @@ func TestMeshService_MalfeasanceStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -299,7 +300,7 @@ func (t *ConStateAPIMockInstrumented) GetLayerStateRoot(types.LayerID) (types.Ha func TestReadLayer(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), &MeshAPIMockInstrumented{}, diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index f3fddd506b0..464ef984a16 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -23,7 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchPostSupervisor( @@ -58,7 +58,7 @@ func launchPostSupervisor( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post manager") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) @@ -102,7 +102,7 @@ func launchPostSupervisorTLS( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post supervisor") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/api/grpcserver/transaction_service.go b/api/grpcserver/transaction_service.go index ba0eea8cfe5..f5627fd27a5 100644 --- a/api/grpcserver/transaction_service.go +++ b/api/grpcserver/transaction_service.go @@ -28,7 +28,7 @@ import ( // TransactionService exposes transaction data, and a submit tx endpoint. type TransactionService struct { - db *sql.Database + db sql.StateDatabase publisher pubsub.Publisher // P2P Swarm mesh meshAPI // Mesh conState conservativeState @@ -52,7 +52,7 @@ func (s *TransactionService) String() string { // NewTransactionService creates a new grpc service using config data. func NewTransactionService( - db *sql.Database, + db sql.StateDatabase, publisher pubsub.Publisher, msh meshAPI, conState conservativeState, diff --git a/api/grpcserver/transaction_service_test.go b/api/grpcserver/transaction_service_test.go index 7fe12a9e340..c6b6165ea76 100644 --- a/api/grpcserver/transaction_service_test.go +++ b/api/grpcserver/transaction_service_test.go @@ -23,12 +23,13 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_StreamResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -36,7 +37,7 @@ func TestTransactionService_StreamResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator(). WithAddresses(2) txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -134,7 +135,7 @@ func TestTransactionService_StreamResults(t *testing.T) { } func BenchmarkStreamResults(b *testing.B) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -216,7 +217,7 @@ func parseOk() parseExpectation { } func TestParseTransactions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) cfg, cleanup := launchServer(t, NewTransactionService(db, nil, nil, txs.NewConservativeState(vminst, db), nil, nil)) diff --git a/api/grpcserver/v2alpha1/account_test.go b/api/grpcserver/v2alpha1/account_test.go index a61d1c6fe0f..3b80a98969b 100644 --- a/api/grpcserver/v2alpha1/account_test.go +++ b/api/grpcserver/v2alpha1/account_test.go @@ -14,8 +14,8 @@ import ( "google.golang.org/grpc/status" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testAccount struct { @@ -27,7 +27,7 @@ type testAccount struct { } func TestAccountService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctrl, ctx := gomock.WithContext(context.Background(), t) conState := NewMockaccountConState(ctrl) diff --git a/api/grpcserver/v2alpha1/activation_test.go b/api/grpcserver/v2alpha1/activation_test.go index 183bd11c8be..13bd09040fa 100644 --- a/api/grpcserver/v2alpha1/activation_test.go +++ b/api/grpcserver/v2alpha1/activation_test.go @@ -16,12 +16,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActivationService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -104,7 +104,7 @@ func TestActivationService_List(t *testing.T) { } func TestActivationStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -214,7 +214,7 @@ func TestActivationStreamService_Stream(t *testing.T) { } func TestActivationService_ActivationsCount(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() genEpoch3 := fixture.NewAtxsGenerator().WithEpochs(3, 1) diff --git a/api/grpcserver/v2alpha1/layer_test.go b/api/grpcserver/v2alpha1/layer_test.go index 59af523ea9a..6a63271d9fc 100644 --- a/api/grpcserver/v2alpha1/layer_test.go +++ b/api/grpcserver/v2alpha1/layer_test.go @@ -19,10 +19,11 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestLayerService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -98,7 +99,7 @@ func TestLayerConvertEventStatus(t *testing.T) { } func TestLayerStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -225,7 +226,7 @@ func layerGenWithBlock(withBlock bool) layerGenOpt { } } -func generateLayer(db *sql.Database, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { +func generateLayer(db sql.StateDatabase, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { g := &layerGenOpts{} for _, opt := range opts { opt(g) diff --git a/api/grpcserver/v2alpha1/reward_test.go b/api/grpcserver/v2alpha1/reward_test.go index d0afdda5be3..624b03e64a3 100644 --- a/api/grpcserver/v2alpha1/reward_test.go +++ b/api/grpcserver/v2alpha1/reward_test.go @@ -15,12 +15,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/rewards" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewardService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator().WithAddresses(100).WithUniqueCoinbase() @@ -103,7 +103,7 @@ func TestRewardService_List(t *testing.T) { } func TestRewardStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator() diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index 4929015b0c8..8178e1c7d2e 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -31,18 +31,19 @@ import ( pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_List(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewTransactionResultGenerator().WithAddresses(2) txsList := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txsList { tx := gen.Next() @@ -178,7 +179,7 @@ func TestTransactionService_List(t *testing.T) { func TestTransactionService_EstimateGas(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -241,7 +242,7 @@ func TestTransactionService_EstimateGas(t *testing.T) { func TestTransactionService_ParseTransaction(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -357,7 +358,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -400,7 +401,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -437,7 +438,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) diff --git a/atxsdata/warmup.go b/atxsdata/warmup.go index 6f4d5e5abcb..f4ea14eefb3 100644 --- a/atxsdata/warmup.go +++ b/atxsdata/warmup.go @@ -14,7 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" ) -func Warm(db *sql.Database, keep types.EpochID, logger *zap.Logger) (*Data, error) { +func Warm(db sql.StateDatabase, keep types.EpochID, logger *zap.Logger) (*Data, error) { cache := New() tx, err := db.Tx(context.Background()) if err != nil { diff --git a/atxsdata/warmup_test.go b/atxsdata/warmup_test.go index c2051fa1c71..bf14b257cfe 100644 --- a/atxsdata/warmup_test.go +++ b/atxsdata/warmup_test.go @@ -15,6 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func gatx( @@ -38,7 +39,7 @@ func gatx( func TestWarmup(t *testing.T) { types.SetLayersPerEpoch(3) t.Run("sanity", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() applied := types.LayerID(10) nonce := types.VRFPostIndex(1) data := []types.ActivationTx{ @@ -61,19 +62,19 @@ func TestWarmup(t *testing.T) { } }) t.Run("no data", func(t *testing.T) { - c, err := Warm(sql.InMemory(), 1, zaptest.NewLogger(t)) + c, err := Warm(statesql.InMemory(), 1, zaptest.NewLogger(t)) require.NoError(t, err) require.NotNil(t, c) }) t.Run("closed db", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, db.Close()) c, err := Warm(db, 1, zaptest.NewLogger(t)) require.Error(t, err) require.Nil(t, c) }) t.Run("db failures", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonce := types.VRFPostIndex(1) data := gatx(types.ATXID{1, 1}, 1, types.NodeID{1}, nonce) require.NoError(t, atxs.Add(db, &data, types.AtxBlob{})) diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index 89be3a07a02..1ef3c134d56 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -30,6 +30,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -91,7 +92,7 @@ func newTestDriver(tb testing.TB, cfg Config, p pubsub.Publisher, miners int, id tpd.mVerifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(true) - tpd.cdb = datastore.NewCachedDB(sql.InMemory(), lg) + tpd.cdb = datastore.NewCachedDB(statesql.InMemory(), lg) tpd.ProtocolDriver = New(p, signing.NewEdVerifier(), tpd.mVerifier, tpd.cdb, tpd.mClock, WithConfig(cfg), WithLogger(lg), @@ -493,7 +494,7 @@ func TestBeacon_NoRaceOnClose(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, closed: make(chan struct{}), results: make(chan result.Beacon, 100), @@ -528,7 +529,7 @@ func TestBeacon_BeaconsWithDatabase(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch3 := types.EpochID(3) @@ -581,7 +582,7 @@ func TestBeacon_BeaconsWithDatabaseFailure(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch := types.EpochID(3) @@ -599,7 +600,7 @@ func TestBeacon_BeaconsCleanupOldEpoch(t *testing.T) { mclock := NewMocklayerClock(gomock.NewController(t)) pd := &ProtocolDriver{ logger: lg.Named("Beacon"), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -704,7 +705,7 @@ func TestBeacon_ReportBeaconFromBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -740,7 +741,7 @@ func TestBeacon_ReportBeaconFromBallot_SameBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, diff --git a/blocks/certifier.go b/blocks/certifier.go index d5045149e93..b3461a98fc0 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -81,7 +81,7 @@ type Certifier struct { stop func() stopped atomic.Bool - db *sql.Database + db sql.StateDatabase oracle eligibility.Rolacle signers map[types.NodeID]*signing.EdSigner edVerifier *signing.EdVerifier @@ -99,7 +99,7 @@ type Certifier struct { // NewCertifier creates new block certifier. func NewCertifier( - db *sql.Database, + db sql.StateDatabase, o eligibility.Rolacle, v *signing.EdVerifier, @@ -567,7 +567,7 @@ func (c *Certifier) save( if len(valid)+len(invalid) == 0 { return certificates.Add(c.db, lid, cert) } - return c.db.WithTx(ctx, func(dbtx *sql.Tx) error { + return c.db.WithTx(ctx, func(dbtx sql.Transaction) error { if err := certificates.Add(dbtx, lid, cert); err != nil { return err } diff --git a/blocks/certifier_test.go b/blocks/certifier_test.go index 8a3cc721206..d4a0c7faef0 100644 --- a/blocks/certifier_test.go +++ b/blocks/certifier_test.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -27,7 +28,7 @@ const defaultCnt = uint16(2) type testCertifier struct { *Certifier - db *sql.Database + db sql.StateDatabase mOracle *eligibility.MockRolacle mPub *pubsubmock.MockPublisher mClk *mocks.MocklayerClock @@ -38,7 +39,7 @@ type testCertifier struct { func newTestCertifier(t *testing.T, signers int) *testCertifier { t.Helper() types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) mo := eligibility.NewMockRolacle(ctrl) mp := pubsubmock.NewMockPublisher(ctrl) diff --git a/blocks/generator.go b/blocks/generator.go index b7070c2ca09..1c7ab6b196c 100644 --- a/blocks/generator.go +++ b/blocks/generator.go @@ -30,7 +30,7 @@ type Generator struct { eg errgroup.Group stop func() - db *sql.Database + db sql.StateDatabase atxs *atxsdata.Data proposals *store.Store msh meshProvider @@ -84,7 +84,7 @@ func WithHareOutputChan(ch <-chan hare4.ConsensusOutput) GeneratorOpt { // NewGenerator creates new block generator. func NewGenerator( - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, proposals *store.Store, exec executor, diff --git a/blocks/generator_test.go b/blocks/generator_test.go index f77d0fbc306..6a6bcec9585 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -73,7 +74,7 @@ func createTestGenerator(t *testing.T) *testGenerator { } tg.mockMesh.EXPECT().ProcessedLayer().Return(types.LayerID(1)).AnyTimes() lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() proposals := store.New() tg.Generator = NewGenerator( @@ -266,7 +267,7 @@ func Test_StopBeforeStart(t *testing.T) { func genData( t *testing.T, - db *sql.Database, + db sql.StateDatabase, data *atxsdata.Data, store *store.Store, lid types.LayerID, diff --git a/blocks/handler.go b/blocks/handler.go index 9260fb8ed5d..e892d7f1f36 100644 --- a/blocks/handler.go +++ b/blocks/handler.go @@ -28,7 +28,7 @@ type Handler struct { logger *zap.Logger fetcher system.Fetcher - db *sql.Database + db sql.StateDatabase tortoise tortoiseProvider mesh meshProvider } @@ -44,7 +44,13 @@ func WithLogger(logger *zap.Logger) Opt { } // NewHandler creates new Handler. -func NewHandler(f system.Fetcher, db *sql.Database, tortoise tortoiseProvider, m meshProvider, opts ...Opt) *Handler { +func NewHandler( + f system.Fetcher, + db sql.StateDatabase, + tortoise tortoiseProvider, + m meshProvider, + opts ...Opt, +) *Handler { h := &Handler{ logger: zap.NewNop(), fetcher: f, diff --git a/blocks/handler_test.go b/blocks/handler_test.go index 8c44db02726..7ccb2078406 100644 --- a/blocks/handler_test.go +++ b/blocks/handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -34,7 +34,7 @@ func createTestHandler(t *testing.T) *testHandler { } th.Handler = NewHandler( th.mockFetcher, - sql.InMemory(), + statesql.InMemory(), th.mockTortoise, th.mockMesh, WithLogger(zaptest.NewLogger(t)), diff --git a/blocks/utils.go b/blocks/utils.go index 96dcd82b3ad..b2d03622874 100644 --- a/blocks/utils.go +++ b/blocks/utils.go @@ -50,7 +50,7 @@ type proposalMetadata struct { func getProposalMetadata( ctx context.Context, logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, cfg Config, lid types.LayerID, @@ -232,7 +232,7 @@ func toUint64Slice(b []byte) []uint64 { func rewardInfoAndHeight( cfg Config, - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, props []*types.Proposal, ) (uint64, []types.AnyReward, error) { diff --git a/blocks/utils_test.go b/blocks/utils_test.go index 11776145b5b..d7f9b1f3ca4 100644 --- a/blocks/utils_test.go +++ b/blocks/utils_test.go @@ -14,9 +14,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -159,7 +159,7 @@ func Test_getBlockTXs_expected_order(t *testing.T) { func Test_getProposalMetadata(t *testing.T) { lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() cfg := Config{OptFilterThreshold: 70} lid := types.LayerID(111) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 31df1cdff18..652eda2cd46 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -29,6 +29,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const recoveryDir = "recovery" @@ -120,7 +121,7 @@ func Recover( return nil, errors.New("restore layer not set") } logger.Info("recovering from checkpoint", zap.String("url", cfg.Uri), zap.Stringer("restore", cfg.Restore)) - db, err := sql.Open("file:" + cfg.DbPath()) + db, err := statesql.Open("file:" + cfg.DbPath()) if err != nil { return nil, fmt.Errorf("open old database: %w", err) } @@ -131,7 +132,7 @@ func Recover( } defer localDB.Close() logger.Info("clearing atx and malfeasance sync metadata from local database") - if err := localDB.WithTx(ctx, func(tx *sql.Tx) error { + if err := localDB.WithTx(ctx, func(tx sql.Transaction) error { if err := atxsync.Clear(tx); err != nil { return err } @@ -153,8 +154,8 @@ func Recover( func RecoverWithDb( ctx context.Context, logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, ) (*PreservedData, error) { @@ -187,8 +188,8 @@ type recoveryData struct { func RecoverFromLocalFile( ctx context.Context, logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, file string, @@ -257,8 +258,7 @@ func RecoverFromLocalFile( } logger.Info("backed up old database", log.ZContext(ctx), zap.String("backup dir", backupDir)) - var newDB *sql.Database - newDB, err = sql.Open("file:" + cfg.DbPath()) + newDB, err := statesql.Open("file:" + cfg.DbPath()) if err != nil { return nil, fmt.Errorf("creating new DB: %w", err) } @@ -268,7 +268,7 @@ func RecoverFromLocalFile( zap.Int("num accounts", len(data.accounts)), zap.Int("num atxs", len(data.atxs)), ) - if err = newDB.WithTx(ctx, func(tx *sql.Tx) error { + if err = newDB.WithTx(ctx, func(tx sql.Transaction) error { for _, acct := range data.accounts { if err = accounts.Update(tx, acct); err != nil { return fmt.Errorf("restore account snapshot: %w", err) @@ -392,8 +392,8 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove func collectOwnAtxDeps( logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, nodeID types.NodeID, goldenATX types.ATXID, data *recoveryData, @@ -454,7 +454,7 @@ func collectOwnAtxDeps( } func collectDeps( - db *sql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, ) (map[types.ATXID]*AtxDep, map[types.PoetProofRef]*types.PoetProofMessage, error) { @@ -470,7 +470,7 @@ func collectDeps( } func collect( - db *sql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, deps map[types.ATXID]*AtxDep, @@ -535,7 +535,7 @@ func collect( } func poetProofs( - db *sql.Database, + db sql.StateDatabase, atxIds map[types.ATXID]*AtxDep, ) (map[types.PoetProofRef]*types.PoetProofMessage, error) { proofs := make(map[types.PoetProofRef]*types.PoetProofMessage, len(atxIds)) diff --git a/checkpoint/recovery_collecting_deps_test.go b/checkpoint/recovery_collecting_deps_test.go index 1df6ca0a9b7..4bbc892c105 100644 --- a/checkpoint/recovery_collecting_deps_test.go +++ b/checkpoint/recovery_collecting_deps_test.go @@ -10,15 +10,15 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCollectingDeps(t *testing.T) { golden := types.RandomATXID() t.Run("collect marriage ATXs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() marriageATX := &wire.ActivationTxV1{ InnerActivationTxV1: wire.InnerActivationTxV1{ diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index d67a259225e..eb0e6d0a58d 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -35,6 +35,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -77,7 +78,7 @@ func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun } } -func verifyDbContent(tb testing.TB, db *sql.Database) { +func verifyDbContent(tb testing.TB, db sql.StateDatabase) { tb.Helper() var expected types.Checkpoint require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected)) @@ -168,7 +169,7 @@ func TestRecover(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() data, err := checkpoint.RecoverWithDb(context.Background(), zaptest.NewLogger(t), db, localDB, fs, cfg) if tc.expErr != nil { @@ -177,7 +178,7 @@ func TestRecover(t *testing.T) { } require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) defer newDB.Close() @@ -209,7 +210,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() types.SetEffectiveGenesis(0) require.NoError(t, recovery.SetCheckpoint(db, types.LayerID(recoverLayer))) @@ -241,7 +242,7 @@ func TestRecover_RestoreLayerCannotBeZero(t *testing.T) { func validateAndPreserveData( tb testing.TB, - db *sql.Database, + db sql.StateDatabase, deps []*checkpoint.AtxDep, ) { tb.Helper() @@ -496,7 +497,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -541,7 +542,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -581,7 +582,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -653,7 +654,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -691,7 +692,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -748,7 +749,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -781,7 +782,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -824,7 +825,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -857,7 +858,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -883,7 +884,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -921,7 +922,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) require.NoError(t, atxs.Add(oldDB, atx, types.AtxBlob{})) @@ -931,7 +932,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 7039aa54c0f..965dd1ca00c 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -28,7 +28,7 @@ const ( func checkpointDB( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, snapshot types.LayerID, numAtxs int, ) (*types.Checkpoint, error) { @@ -171,7 +171,7 @@ func checkpointDB( func Generate( ctx context.Context, fs afero.Fs, - db *sql.Database, + db sql.StateDatabase, dataDir string, snapshot types.LayerID, numAtxs int, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index fc1a33eda32..727da2843f6 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -262,7 +263,7 @@ func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { } } -func createMesh(t testing.TB, db *sql.Database, miners []miner, accts []*types.Account) { +func createMesh(t testing.TB, db sql.StateDatabase, miners []miner, accts []*types.Account) { t.Helper() for _, miner := range miners { for _, atx := range miner.atxs { @@ -309,7 +310,7 @@ func TestRunner_Generate(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, tc.miners, tc.accts) @@ -345,7 +346,7 @@ func TestRunner_Generate_Error(t *testing.T) { t.Parallel() t.Run("no commitment atx", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) atx := newAtx(types.ATXID{13}, nil, 2, 1, 11, types.RandomNodeID().Bytes()) @@ -359,7 +360,7 @@ func TestRunner_Generate_Error(t *testing.T) { }) t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, nil, allAccounts) @@ -372,7 +373,7 @@ func TestRunner_Generate_Error(t *testing.T) { }) t.Run("no accounts", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, allMiners, nil) @@ -387,7 +388,7 @@ func TestRunner_Generate_Error(t *testing.T) { func TestRunner_Generate_PreservesMarriageATX(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, accounts.Update(db, &types.Account{Address: types.Address{1, 1}})) diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 1046916b024..1415972e6b5 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -9,8 +9,8 @@ import ( "strconv" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func main() { @@ -30,7 +30,7 @@ Example: if len(dbpath) == 0 { must(errors.New("dbpath is empty"), "dbpath is empty\n") } - db, err := sql.Open("file:" + dbpath) + db, err := statesql.Open("file:" + dbpath) must(err, "can't open db at dbpath=%v. err=%s\n", dbpath, err) ids, err := atxs.GetIDsByEpoch(context.Background(), db, types.EpochID(publish)) diff --git a/cmd/bootstrapper/generator_test.go b/cmd/bootstrapper/generator_test.go index 4c1aec7bb43..e201d9186bc 100644 --- a/cmd/bootstrapper/generator_test.go +++ b/cmd/bootstrapper/generator_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -102,7 +103,7 @@ func verifyUpdate(tb testing.TB, data []byte, epoch types.EpochID, expBeacon str func TestGenerator_Generate(t *testing.T) { t.Parallel() targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) @@ -168,7 +169,7 @@ func TestGenerator_Generate(t *testing.T) { func TestGenerator_CheckAPI(t *testing.T) { t.Parallel() targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() lg := zaptest.NewLogger(t) createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, lg)) diff --git a/cmd/bootstrapper/server_test.go b/cmd/bootstrapper/server_test.go index 1bdb61c6bb0..df0f20a36f9 100644 --- a/cmd/bootstrapper/server_test.go +++ b/cmd/bootstrapper/server_test.go @@ -19,7 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/bootstrap" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) //go:embed checkpointdata.json @@ -56,7 +56,7 @@ func updateCheckpoint(t *testing.T, ctx context.Context, data string) { } func TestServer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) diff --git a/cmd/merge-nodes/internal/errors.go b/cmd/merge-nodes/internal/errors.go index 53af669de08..b3e3449b0b5 100644 --- a/cmd/merge-nodes/internal/errors.go +++ b/cmd/merge-nodes/internal/errors.go @@ -4,7 +4,4 @@ import ( "errors" ) -var ( - ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") - ErrInvalidSchema = errors.New("database has an invalid schema version") -) +var ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index c516ff297c5..426ad82f2c4 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -26,7 +26,7 @@ const ( func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { // Open the target database - var dstDB *localsql.Database + var dstDB sql.LocalDatabase var err error dstDB, err = openDB(dbLog, to) switch { @@ -158,7 +158,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { } dbLog.Info("merging databases", zap.String("from", from), zap.String("to", to)) - err = dstDB.WithTx(ctx, func(tx *sql.Tx) error { + err = dstDB.WithTx(ctx, func(tx sql.Transaction) error { enc := func(stmt *sql.Statement) { stmt.BindText(1, filepath.Join(from, localDbFile)) } @@ -191,38 +191,20 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { return nil } -func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) { +func openDB(dbLog *zap.Logger, path string) (sql.LocalDatabase, error) { dbPath := filepath.Join(path, localDbFile) if _, err := os.Stat(dbPath); err != nil { - return nil, fmt.Errorf("open database %s: %w", dbPath, err) - } - - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, fmt.Errorf("get local migrations: %w", err) + return nil, fmt.Errorf("stat source database %s: %w", dbPath, err) } db, err := localsql.Open("file:"+dbPath, sql.WithLogger(dbLog), - sql.WithMigrations(nil), // do not migrate database when opening + sql.WithMigrationsDisabled(), ) if err != nil { return nil, fmt.Errorf("open source database %s: %w", dbPath, err) } - // check if the source database has the right schema - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - if err != nil { - return nil, fmt.Errorf("get source database schema for %s: %w", dbPath, err) - } - if version != len(migrations) { - db.Close() - return nil, ErrInvalidSchema - } return db, nil } diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index 405676cb987..7fae559a48c 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -24,20 +24,26 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) -func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { - tmpDst := t.TempDir() - - migrations, err := sql.LocalMigrations() +func oldSchema(t *testing.T) *sql.Schema { + schema, err := localsql.Schema() require.NoError(t, err) + schema.Migrations = schema.Migrations[:2] + return schema +} + +func Test_MergeDBs_InvalidTargetSchema(t *testing.T) { + tmpDst := t.TempDir() db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), "", tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "target database") } @@ -79,25 +85,24 @@ func Test_MergeDBs_InvalidSourcePath(t *testing.T) { require.ErrorIs(t, err, fs.ErrNotExist) } -func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { +func Test_MergeDBs_InvalidSourceSchema(t *testing.T) { tmpDst := t.TempDir() - migrations, err := sql.LocalMigrations() - require.NoError(t, err) - db, err := localsql.Open("file:" + filepath.Join(tmpDst, localDbFile)) require.NoError(t, err) require.NoError(t, db.Close()) tmpSrc := t.TempDir() db, err = localsql.Open("file:"+filepath.Join(tmpSrc, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), tmpSrc, tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "source database") } diff --git a/cmd/node/main.go b/cmd/node/main.go index 99ef2265e49..0fb230794db 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -3,7 +3,6 @@ package main import ( - "fmt" _ "net/http/pprof" "os" @@ -24,7 +23,8 @@ func main() { // run the app cmd.Branch = branch cmd.NoMainNet = noMainNet == "true" if err := node.GetCommand().Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) + // Do not print error as cmd.SilenceErrors is false + // and the error was already printed os.Exit(1) } } diff --git a/config/config.go b/config/config.go index b289ba722bc..0c737b75344 100644 --- a/config/config.go +++ b/config/config.go @@ -119,6 +119,7 @@ type BaseConfig struct { DatabaseSkipMigrations []int `mapstructure:"db-skip-migrations"` DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` + DatabaseSchemaAllowDrift bool `mapstructure:"db-allow-schema-drift"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` diff --git a/datastore/mocks/mocks.go b/datastore/mocks/mocks.go deleted file mode 100644 index bbb150d389f..00000000000 --- a/datastore/mocks/mocks.go +++ /dev/null @@ -1,156 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./store.go -// -// Generated by this command: -// -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - sql "github.com/spacemeshos/go-spacemesh/sql" - gomock "go.uber.org/mock/gomock" -) - -// MockExecutor is a mock of Executor interface. -type MockExecutor struct { - ctrl *gomock.Controller - recorder *MockExecutorMockRecorder -} - -// MockExecutorMockRecorder is the mock recorder for MockExecutor. -type MockExecutorMockRecorder struct { - mock *MockExecutor -} - -// NewMockExecutor creates a new mock instance. -func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { - mock := &MockExecutor{ctrl: ctrl} - mock.recorder = &MockExecutorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { - return m.recorder -} - -// Exec mocks base method. -func (m *MockExecutor) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockExecutorMockRecorder) Exec(arg0, arg1, arg2 any) *MockExecutorExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockExecutor)(nil).Exec), arg0, arg1, arg2) - return &MockExecutorExecCall{Call: call} -} - -// MockExecutorExecCall wrap *gomock.Call -type MockExecutorExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorExecCall) Return(arg0 int, arg1 error) *MockExecutorExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockExecutor) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockExecutorMockRecorder) QueryCache() *MockExecutorQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockExecutor)(nil).QueryCache)) - return &MockExecutorQueryCacheCall{Call: call} -} - -// MockExecutorQueryCacheCall wrap *gomock.Call -type MockExecutorQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorQueryCacheCall) Return(arg0 sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorQueryCacheCall) Do(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockExecutor) WithTx(arg0 context.Context, arg1 func(*sql.Tx) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockExecutorMockRecorder) WithTx(arg0, arg1 any) *MockExecutorWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockExecutor)(nil).WithTx), arg0, arg1) - return &MockExecutorWithTxCall{Call: call} -} - -// MockExecutorWithTxCall wrap *gomock.Call -type MockExecutorWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorWithTxCall) Return(arg0 error) *MockExecutorWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorWithTxCall) Do(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorWithTxCall) DoAndReturn(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/datastore/store.go b/datastore/store.go index 7a2cccb4168..073c7bacdc9 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -30,18 +30,9 @@ type VrfNonceKey struct { Epoch types.EpochID } -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go - -type Executor interface { - sql.Executor - WithTx(context.Context, func(*sql.Tx) error) error - QueryCache() sql.QueryCache -} - // CachedDB is simply a database injected with cache. type CachedDB struct { - Executor - sql.QueryCache + sql.Database logger *zap.Logger // cache is optional in tests. It MUST be set for the 'App' @@ -91,7 +82,7 @@ func WithConsensusCache(c *atxsdata.Data) Opt { } // NewCachedDB create an instance of a CachedDB. -func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { +func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { o := cacheOpts{cfg: DefaultConfig()} for _, opt := range opts { opt(&o) @@ -114,8 +105,7 @@ func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { } return &CachedDB{ - Executor: db, - QueryCache: db.QueryCache(), + Database: db, logger: lg, atxsdata: o.atxsdata, atxCache: atxHdrCache, @@ -143,7 +133,7 @@ func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof return proof, nil } - proof, err := identities.GetMalfeasanceProof(db.Executor, id) + proof, err := identities.GetMalfeasanceProof(db.Database, id) if err != nil && err != sql.ErrNotFound { return nil, err } diff --git a/datastore/store_test.go b/datastore/store_test.go index e928e1b226a..13372858b16 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -53,7 +54,7 @@ func getBytes( } func TestMalfeasanceProof_Dishonest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) @@ -81,7 +82,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) { } func TestBlobStore_GetATXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -124,7 +125,7 @@ func TestBlobStore_GetATXBlob(t *testing.T) { } func TestBlobStore_GetBallotBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -159,7 +160,7 @@ func TestBlobStore_GetBallotBlob(t *testing.T) { } func TestBlobStore_GetBlockBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -194,7 +195,7 @@ func TestBlobStore_GetBlockBlob(t *testing.T) { } func TestBlobStore_GetPoetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -223,7 +224,7 @@ func TestBlobStore_GetPoetBlob(t *testing.T) { } func TestBlobStore_GetProposalBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := store.New() bs := datastore.NewBlobStore(db, proposals) ctx := context.Background() @@ -261,7 +262,7 @@ func TestBlobStore_GetProposalBlob(t *testing.T) { } func TestBlobStore_GetTXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -289,7 +290,7 @@ func TestBlobStore_GetTXBlob(t *testing.T) { } func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -323,7 +324,7 @@ func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { } func TestBlobStore_GetActiveSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() diff --git a/fetch/fetch_test.go b/fetch/fetch_test.go index 8a525ed8c7d..d136cb4ac07 100644 --- a/fetch/fetch_test.go +++ b/fetch/fetch_test.go @@ -21,7 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testFetch struct { @@ -80,7 +80,7 @@ func createFetch(tb testing.TB) *testFetch { } lg := zaptest.NewLogger(tb) - tf.Fetch = NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), nil, + tf.Fetch = NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), nil, WithContext(context.TODO()), WithConfig(cfg), WithLogger(lg), @@ -117,7 +117,7 @@ func badReceiver(context.Context, types.Hash32, p2p.Peer, []byte) error { func TestFetch_Start(t *testing.T) { lg := zaptest.NewLogger(t) - f := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), nil, + f := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), nil, WithContext(context.TODO()), WithConfig(DefaultConfig()), WithLogger(lg), @@ -382,7 +382,7 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { }) defer eg.Wait() - fetcher := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), h, + fetcher := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), h, WithContext(ctx), WithConfig(cfg), WithLogger(lg), diff --git a/fetch/handler_test.go b/fetch/handler_test.go index d42900a28a5..45905c1dc23 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -24,17 +24,18 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testHandler struct { *handler - db *sql.Database + db sql.StateDatabase cdb *datastore.CachedDB } func createTestHandler(t testing.TB, opts ...sql.Opt) *testHandler { lg := zaptest.NewLogger(t) - db := sql.InMemory(opts...) + db := statesql.InMemory(opts...) cdb := datastore.NewCachedDB(db, lg) return &testHandler{ handler: newHandler(cdb, datastore.NewBlobStore(cdb, store.New()), lg), @@ -349,6 +350,8 @@ func testHandleEpochInfoReqWithQueryCache( getInfo func(th *testHandler, req []byte, ed *EpochData), ) { th := createTestHandler(t, sql.WithQueryCache(true)) + require.True(t, th.cdb.Database.IsCached()) + require.True(t, sql.IsCached(th.cdb)) epoch := types.EpochID(11) var expected EpochData @@ -359,8 +362,7 @@ func testHandleEpochInfoReqWithQueryCache( expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } - qc := th.cdb.Executor.(interface{ QueryCount() int }) - require.Equal(t, 20, qc.QueryCount()) + require.Equal(t, 20, th.cdb.Database.QueryCount()) epochBytes, err := codec.Encode(epoch) require.NoError(t, err) @@ -368,7 +370,7 @@ func testHandleEpochInfoReqWithQueryCache( for i := 0; i < 3; i++ { getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) - require.Equal(t, 21, qc.QueryCount()) + require.Equal(t, 21, th.cdb.Database.QueryCount(), "query count @ i = %d", i) } // Add another ATX which should be appended to the cached slice @@ -376,14 +378,14 @@ func testHandleEpochInfoReqWithQueryCache( require.NoError(t, atxs.Add(th.cdb, vatx, blob)) atxs.AtxAdded(th.cdb, vatx) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) // The query count is not incremented as the slice is still // cached and the new atx is just appended to it, even though // the response is re-serialized. - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) } func TestHandleEpochInfoReqWithQueryCache(t *testing.T) { diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index a4084653b19..31205448458 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -26,7 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -999,7 +999,7 @@ func Test_GetAtxsLimiting(t *testing.T) { cfg.QueueSize = 1000 cfg.GetAtxsConcurrency = getAtxConcurrency - cdb := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + cdb := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) client := server.New(wrapHost(mesh.Hosts()[0]), hashProtocol, nil) host, err := p2p.Upgrade(mesh.Hosts()[0]) require.NoError(t, err) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index a9dc85a8254..26dd5524c5f 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -37,13 +38,13 @@ type blobKey struct { type testP2PFetch struct { t *testing.T - clientDB *sql.Database + clientDB sql.StateDatabase // client proposals clientPDB *store.Store clientCDB *datastore.CachedDB clientFetch *Fetch serverID peer.ID - serverDB *sql.Database + serverDB sql.StateDatabase // server proposals serverPDB *store.Store serverCDB *datastore.CachedDB @@ -110,8 +111,8 @@ func createP2PFetch( if sqlCache { sqlOpts = []sql.Opt{sql.WithQueryCache(true)} } - clientDB := sql.InMemory(sqlOpts...) - serverDB := sql.InMemory(sqlOpts...) + clientDB := statesql.InMemory(sqlOpts...) + serverDB := statesql.InMemory(sqlOpts...) tpf := &testP2PFetch{ t: t, clientDB: clientDB, diff --git a/genvm/core/context_test.go b/genvm/core/context_test.go index e13a8287ff5..92d8cb627db 100644 --- a/genvm/core/context_test.go +++ b/genvm/core/context_test.go @@ -10,23 +10,23 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/core" "github.com/spacemeshos/go-spacemesh/genvm/core/mocks" "github.com/spacemeshos/go-spacemesh/genvm/registry" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestTransfer(t *testing.T) { t.Run("NoBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} require.ErrorIs(t, ctx.Transfer(core.Address{}, 100), core.ErrNoBalance) }) t.Run("MaxSpend", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 100 require.NoError(t, ctx.Transfer(core.Address{1}, 50)) require.ErrorIs(t, ctx.Transfer(core.Address{2}, 100), core.ErrMaxSpend) }) t.Run("ReducesBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 for _, amount := range []uint64{50, 100, 200, 255} { @@ -67,7 +67,7 @@ func TestConsume(t *testing.T) { func TestApply(t *testing.T) { t.Run("UpdatesNonce", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Address = core.Address{1} ctx.Header.Nonce = 10 @@ -80,7 +80,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.PrincipalAccount.NextNonce, account.NextNonce) }) t.Run("ConsumeMaxGas", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Balance = 1000 @@ -97,7 +97,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.Fee(), ctx.Header.MaxGas*ctx.Header.GasPrice) }) t.Run("PreserveTransferOrder", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Address = core.Address{1} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 @@ -129,7 +129,7 @@ func TestRelay(t *testing.T) { remote = core.Address{'r', 'e', 'm'} ) t.Run("not spawned", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: cache} call := func(remote core.Host) error { require.Fail(t, "not expected to be called") @@ -138,7 +138,7 @@ func TestRelay(t *testing.T) { require.ErrorIs(t, ctx.Relay(template, remote, call), core.ErrNotSpawned) }) t.Run("mismatched template", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) require.NoError(t, cache.Update(core.Account{ Address: remote, TemplateAddress: &core.Address{'m', 'i', 's'}, @@ -166,7 +166,7 @@ func TestRelay(t *testing.T) { reg := registry.New() reg.Register(template, handler) - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) receiver2 := core.Address{'f'} const ( total = 1000 diff --git a/genvm/core/staged_cache_test.go b/genvm/core/staged_cache_test.go index 8a6b29519ea..a0189309555 100644 --- a/genvm/core/staged_cache_test.go +++ b/genvm/core/staged_cache_test.go @@ -6,11 +6,11 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCacheGetCopies(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) address := core.Address{1} account, err := ss.Get(address) @@ -23,7 +23,7 @@ func TestCacheGetCopies(t *testing.T) { } func TestCacheUpdatePreserveOrder(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) order := []core.Address{{3}, {1}, {2}} for _, address := range order { diff --git a/genvm/templates/vault/vault_test.go b/genvm/templates/vault/vault_test.go index 28fef09d14d..e1f461fd9c8 100644 --- a/genvm/templates/vault/vault_test.go +++ b/genvm/templates/vault/vault_test.go @@ -9,7 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestVested(t *testing.T) { @@ -392,7 +392,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(tc.lid), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, @@ -419,7 +419,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(2), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, diff --git a/genvm/vm.go b/genvm/vm.go index 0636e0879e0..f00c01b721a 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -59,7 +59,7 @@ func WithConfig(cfg Config) Opt { } // New returns VM instance. -func New(db *sql.Database, opts ...Opt) *VM { +func New(db sql.StateDatabase, opts ...Opt) *VM { vm := &VM{ logger: zap.NewNop(), db: db, @@ -79,7 +79,7 @@ func New(db *sql.Database, opts ...Opt) *VM { // VM handles modifications to the account state. type VM struct { logger *zap.Logger - db *sql.Database + db sql.StateDatabase cfg Config registry *registry.Registry } diff --git a/genvm/vm_test.go b/genvm/vm_test.go index 4ee7934da8a..eb36dcbeab8 100644 --- a/genvm/vm_test.go +++ b/genvm/vm_test.go @@ -34,9 +34,9 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/templates/wallet" "github.com/spacemeshos/go-spacemesh/hash" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func testContext(lid types.LayerID) ApplyContext { @@ -48,7 +48,7 @@ func testContext(lid types.LayerID) ApplyContext { func newTester(tb testing.TB) *tester { return &tester{ TB: tb, - VM: New(sql.InMemory(), + VM: New(statesql.InMemory(), WithLogger(zaptest.NewLogger(tb)), WithConfig(Config{GasLimit: math.MaxUint64}), ), @@ -279,7 +279,7 @@ type tester struct { } func (t *tester) persistent() *tester { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) t.VM = New(db, WithLogger(zaptest.NewLogger(t)), @@ -2573,7 +2573,7 @@ func TestVestingData(t *testing.T) { spendAccountNonce := t2.nonces[0] spendAmount := uint64(1_000_000) - vm := New(sql.InMemory(), WithLogger(zaptest.NewLogger(t))) + vm := New(statesql.InMemory(), WithLogger(zaptest.NewLogger(t))) require.NoError(t, vm.ApplyGenesis( []core.Account{ { diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index d6ae2e4a582..88c684278a2 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -46,14 +47,14 @@ func TestMain(m *testing.M) { type testOracle struct { *Oracle tb testing.TB - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mBeacon *mocks.MockBeaconGetter mVerifier *MockvrfVerifier } func defaultOracle(tb testing.TB) *testOracle { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(tb) diff --git a/hare3/hare.go b/hare3/hare.go index 8538243a18b..2f0076964c7 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -172,7 +172,7 @@ type nodeclock interface { func New( nodeclock nodeclock, pubsub pubsub.PublishSubsciber, - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals *store.Store, verifier *signing.EdVerifier, @@ -234,7 +234,7 @@ type Hare struct { // dependencies nodeclock nodeclock pubsub pubsub.PublishSubsciber - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store verifier *signing.EdVerifier diff --git a/hare3/hare_test.go b/hare3/hare_test.go index 0030041ead5..76543ffafe6 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -116,7 +117,7 @@ type node struct { vrfsigner *signing.VRFSigner atx *types.ActivationTx oracle *eligibility.Oracle - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store @@ -148,7 +149,7 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { } func (n *node) withDb(tb testing.TB) *node { - n.db = sql.InMemoryTest(tb) + n.db = statesql.InMemoryTest(tb) n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -891,7 +892,7 @@ func TestProposals(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() proposals := store.New() hare := New( diff --git a/hare3/malfeasance_test.go b/hare3/malfeasance_test.go index 7ff1cf52b1d..0382ad3741e 100644 --- a/hare3/malfeasance_test.go +++ b/hare3/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemoryTest(tb) + db := statesql.InMemoryTest(tb) observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/hare4/eligibility/oracle_test.go b/hare4/eligibility/oracle_test.go index 5dfdf877b6c..fac51a5b6b7 100644 --- a/hare4/eligibility/oracle_test.go +++ b/hare4/eligibility/oracle_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -46,14 +47,14 @@ func TestMain(m *testing.M) { type testOracle struct { *Oracle tb testing.TB - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mBeacon *mocks.MockBeaconGetter mVerifier *MockvrfVerifier } func defaultOracle(tb testing.TB) *testOracle { - db := sql.InMemoryTest(tb) + db := statesql.InMemoryTest(tb) atxsdata := atxsdata.New() ctrl := gomock.NewController(tb) diff --git a/hare4/hare.go b/hare4/hare.go index 7789c6c723f..91f8b298bec 100644 --- a/hare4/hare.go +++ b/hare4/hare.go @@ -194,7 +194,7 @@ type nodeclock interface { func New( nodeclock nodeclock, pubsub pubsub.PublishSubsciber, - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals *store.Store, verif verifier, @@ -263,7 +263,7 @@ type Hare struct { // dependencies nodeclock nodeclock pubsub pubsub.PublishSubsciber - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store verifier verifier diff --git a/hare4/hare_test.go b/hare4/hare_test.go index be619142a08..5e2dd169214 100644 --- a/hare4/hare_test.go +++ b/hare4/hare_test.go @@ -38,6 +38,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -125,7 +126,7 @@ type node struct { vrfsigner *signing.VRFSigner atx *types.ActivationTx oracle *eligibility.Oracle - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store @@ -159,7 +160,7 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { } func (n *node) withDb(tb testing.TB) *node { - n.db = sql.InMemoryTest(tb) + n.db = statesql.InMemoryTest(tb) n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -1066,7 +1067,7 @@ func TestProposals(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() proposals := store.New() hare := New( diff --git a/hare4/malfeasance_test.go b/hare4/malfeasance_test.go index 2bdaa119af0..6611e36ea51 100644 --- a/hare4/malfeasance_test.go +++ b/hare4/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 94779563d66..45acdfcd543 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -168,7 +168,7 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip h.countInvalidProof(&p.MalfeasanceProof) return types.EmptyNodeID, errors.Join(err, pubsub.ErrValidationReject) } - if err := h.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { malicious, err := identities.IsMalicious(dbtx, nodeID) if err != nil { return fmt.Errorf("check known malicious: %w", err) diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index f6d87ef186f..86532ce6771 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -21,18 +21,19 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *Handler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase mockTrt *Mocktortoise } func newHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 30ca2abb017..c15367f715d 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -33,7 +34,7 @@ func TestMain(m *testing.M) { type testExecutor struct { tb testing.TB exec *mesh.Executor - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mcs *mocks.MockconservativeState mvm *mocks.MockvmState @@ -43,7 +44,7 @@ func newTestExecutor(t *testing.T) *testExecutor { ctrl := gomock.NewController(t) te := &testExecutor{ tb: t, - db: sql.InMemory(), + db: statesql.InMemory(), atxsdata: atxsdata.New(), mvm: mocks.NewMockvmState(ctrl), mcs: mocks.NewMockconservativeState(ctrl), diff --git a/mesh/malfeasance_test.go b/mesh/malfeasance_test.go index 8e4c607bf89..6d473c517a5 100644 --- a/mesh/malfeasance_test.go +++ b/mesh/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/mesh/mesh.go b/mesh/mesh.go index b32bcc4605d..61f05a25d8a 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -36,7 +36,7 @@ import ( // Mesh is the logic layer above our mesh.DB database. type Mesh struct { logger *zap.Logger - cdb *sql.Database + cdb sql.StateDatabase atxsdata *atxsdata.Data clock layerClock @@ -59,7 +59,7 @@ type Mesh struct { // NewMesh creates a new instant of a mesh. func NewMesh( - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, c layerClock, trtl system.Tortoise, @@ -93,7 +93,7 @@ func NewMesh( } genesis := types.GetEffectiveGenesis() - if err = db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err = db.WithTx(context.Background(), func(dbtx sql.Transaction) error { if err = layers.SetProcessed(dbtx, genesis); err != nil { return fmt.Errorf("mesh init: %w", err) } @@ -373,7 +373,7 @@ func (msh *Mesh) applyResults(ctx context.Context, results []result.Layer) error return fmt.Errorf("execute block %v/%v: %w", layer.Layer, target, err) } } - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if err := layers.SetApplied(dbtx, layer.Layer, target); err != nil { return fmt.Errorf("set applied for %v/%v: %w", layer.Layer, target, err) } @@ -423,7 +423,7 @@ func (msh *Mesh) saveHareOutput(ctx context.Context, lid types.LayerID, bid type certs []certificates.CertValidity err error ) - if err = msh.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err = msh.cdb.WithTx(ctx, func(tx sql.Transaction) error { // check if a certificate has been generated or sync'ed. // - node generated the certificate when it collected enough certify messages // - hare outputs are processed in layer order. i.e. when hare fails for a previous layer N, @@ -545,7 +545,7 @@ func (msh *Mesh) AddBallot( var proof *wire.MalfeasanceProof // ballots.LayerBallotByNodeID and ballots.Add should be atomic // otherwise concurrent ballots.Add from the same smesher may not be noticed - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if !malicious { prev, err := ballots.LayerBallotByNodeID(dbtx, ballot.Layer, ballot.SmesherID) if err != nil && !errors.Is(err, sql.ErrNotFound) { diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index 290d4af93df..f32362948ca 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -36,7 +37,7 @@ const ( type testMesh struct { *Mesh - db *sql.Database + db sql.StateDatabase // it is used in malfeasence.Validate, which is called in the tests cdb *datastore.CachedDB atxsdata *atxsdata.Data @@ -50,7 +51,7 @@ func createTestMesh(t *testing.T) *testMesh { t.Helper() types.SetLayersPerEpoch(3) lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(t) tm := &testMesh{ diff --git a/miner/active_set_generator_test.go b/miner/active_set_generator_test.go index 660952d03c0..c1e381c2ce9 100644 --- a/miner/active_set_generator_test.go +++ b/miner/active_set_generator_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/activeset" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type expect struct { @@ -65,7 +66,7 @@ func unixPtr(sec, nsec int64) *time.Time { func newTesterActiveSetGenerator(tb testing.TB, cfg config) *testerActiveSetGenerator { var ( - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ctrl = gomock.NewController(tb) @@ -97,8 +98,8 @@ type testerActiveSetGenerator struct { tb testing.TB gen *activeSetGenerator - db *sql.Database - localdb *localsql.Database + db sql.StateDatabase + localdb sql.LocalDatabase atxsdata *atxsdata.Data ctrl *gomock.Controller clock *mocks.MocklayerClock diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 6df1c625924..1e9925bdf71 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -25,7 +25,6 @@ import ( pmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/proposals" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" @@ -35,6 +34,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -749,7 +749,7 @@ func TestBuild(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) @@ -905,7 +905,7 @@ func TestStartStop(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) diff --git a/node/node.go b/node/node.go index a2d6573f667..5643d6a6d25 100644 --- a/node/node.go +++ b/node/node.go @@ -77,7 +77,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" dbmetrics "github.com/spacemeshos/go-spacemesh/sql/metrics" - "github.com/spacemeshos/go-spacemesh/sql/migrations" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/blockssync" @@ -380,10 +380,10 @@ type App struct { fileLock *flock.Flock signers []*signing.EdSigner Config *config.Config - db *sql.Database + db sql.StateDatabase cachedDB *datastore.CachedDB dbMetrics *dbmetrics.DBMetricsCollector - localDB *localsql.Database + localDB sql.LocalDatabase grpcPublicServer *grpcserver.Server grpcPrivateServer *grpcserver.Server grpcPostServer *grpcserver.Server @@ -1950,18 +1950,20 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { return fmt.Errorf("failed to create %s: %w", dbPath, err) } dbLog := app.addLogger(StateDbLogger, lg).Zap() - m21 := migrations.New0021Migration(dbLog, 1_000_000) - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() if err != nil { - return fmt.Errorf("failed to load migrations: %w", err) + return fmt.Errorf("error loading db schema: %w", err) + } + if len(app.Config.DatabaseSkipMigrations) > 0 { + schema.SkipMigrations(app.Config.DatabaseSkipMigrations...) } dbopts := []sql.Opt{ sql.WithLogger(dbLog), - sql.WithMigrations(migrations), - sql.WithMigration(m21), + sql.WithDatabaseSchema(schema), sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindEpochATXs: app.Config.DatabaseQueryCacheSizes.EpochATXs, @@ -1969,10 +1971,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { activesets.CacheKindActiveSetBlob: app.Config.DatabaseQueryCacheSizes.ActiveSetBlob, }), } - if len(app.Config.DatabaseSkipMigrations) > 0 { - dbopts = append(dbopts, sql.WithSkipMigrations(app.Config.DatabaseSkipMigrations...)) - } - sqlDB, err := sql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) + sqlDB, err := statesql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) if err != nil { return fmt.Errorf("open sqlite db %w", err) } @@ -2009,14 +2008,10 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { datastore.WithConsensusCache(app.atxsdata), ) - migrations, err = sql.LocalMigrations() - if err != nil { - return fmt.Errorf("load local migrations: %w", err) - } localDB, err := localsql.Open("file:"+filepath.Join(dbPath, localDbFile), sql.WithLogger(dbLog), - sql.WithMigrations(migrations), sql.WithConnections(app.Config.DatabaseConnections), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), ) if err != nil { return fmt.Errorf("open sqlite db %w", err) diff --git a/node/node_version_check_test.go b/node/node_version_check_test.go index 8e8d8a6d766..affc503f86e 100644 --- a/node/node_version_check_test.go +++ b/node/node_version_check_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/config" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestUpgradeToV15(t *testing.T) { @@ -37,10 +38,15 @@ func TestUpgradeToV15(t *testing.T) { uri := path.Join(cfg.DataDir(), localDbFile) - migrations, err := sql.LocalMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - db, err := sql.Open(uri, sql.WithMigrations(migrations[:2])) + schema.Migrations = schema.Migrations[:2] + + db, err := statesql.Open(uri, + sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift()) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/proposals/handler.go b/proposals/handler.go index c4497608660..c1bac838694 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -50,7 +50,7 @@ type Handler struct { logger *zap.Logger cfg Config - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data activeSets *lru.Cache[types.Hash32, uint64] edVerifier *signing.EdVerifier @@ -109,7 +109,7 @@ func WithConfig(cfg Config) Opt { // NewHandler creates new Handler. func NewHandler( - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals proposalsConsumer, edVerifier *signing.EdVerifier, diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 37485e74287..8a466fba354 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -93,7 +94,7 @@ func fullMockSet(tb testing.TB) *mockSet { func createTestHandler(t *testing.T) *testHandler { types.SetLayersPerEpoch(layersPerEpoch) ms := fullMockSet(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ms.md.EXPECT().GetBallot(gomock.Any()).AnyTimes().DoAndReturn(func(id types.BallotID) *tortoise.BallotData { ballot, err := ballots.Get(db, id) @@ -236,7 +237,7 @@ func createProposal(t *testing.T, opts ...any) *types.Proposal { return p } -func createAtx(t *testing.T, db *sql.Database, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { +func createAtx(t *testing.T, db sql.StateDatabase, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { atx := &types.ActivationTx{ PublishEpoch: epoch, NumUnits: 1, diff --git a/prune/prune.go b/prune/prune.go index 0f7ddb4218c..2fe9ef20c41 100644 --- a/prune/prune.go +++ b/prune/prune.go @@ -22,7 +22,7 @@ func WithLogger(logger *zap.Logger) Opt { } } -func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { +func New(db sql.StateDatabase, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { p := &Pruner{ logger: zap.NewNop(), db: db, @@ -37,7 +37,7 @@ func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts . type Pruner struct { logger *zap.Logger - db *sql.Database + db sql.StateDatabase safeDist uint32 activesetEpoch types.EpochID } diff --git a/prune/prune_test.go b/prune/prune_test.go index efefdf6dc14..b914f63a50f 100644 --- a/prune/prune_test.go +++ b/prune/prune_test.go @@ -11,13 +11,14 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) func TestPrune(t *testing.T) { types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() current := types.LayerID(10) lyrProps := make([]*types.Proposal, 0, current) diff --git a/sql/accounts/accounts_test.go b/sql/accounts/accounts_test.go index 21d34dbea2c..4556ad332a6 100644 --- a/sql/accounts/accounts_test.go +++ b/sql/accounts/accounts_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func genSeq(address types.Address, n int) []*types.Account { @@ -21,7 +22,7 @@ func genSeq(address types.Address, n int) []*types.Account { func TestUpdate(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() seq := genSeq(address, 2) for _, update := range seq { require.NoError(t, Update(db, update)) @@ -34,7 +35,7 @@ func TestUpdate(t *testing.T) { func TestHas(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() has, err := Has(db, address) require.NoError(t, err) require.False(t, has) @@ -50,7 +51,7 @@ func TestHas(t *testing.T) { func TestRevert(t *testing.T) { address := types.Address{1, 1} seq := genSeq(address, 10) - db := sql.InMemory() + db := statesql.InMemory() for _, update := range seq { require.NoError(t, Update(db, update)) } @@ -62,7 +63,7 @@ func TestRevert(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() addresses := []types.Address{{1, 1}, {2, 2}, {3, 3}} n := []int{10, 7, 20} for i, address := range addresses { @@ -81,7 +82,7 @@ func TestAll(t *testing.T) { } func TestSnapshot(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Snapshot(db, types.LayerID(1)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -108,7 +109,7 @@ func TestSnapshot(t *testing.T) { } func TestIterateAccountsOps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i := 0; i < 100; i++ { addr := types.Address{} diff --git a/sql/activesets/activesets_test.go b/sql/activesets/activesets_test.go index 8828fa417e6..acb8fe53ce9 100644 --- a/sql/activesets/activesets_test.go +++ b/sql/activesets/activesets_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActiveSet(t *testing.T) { @@ -19,7 +20,7 @@ func TestActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{1}, {2}}, } - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, ids[0], set)) require.ErrorIs(t, Add(db, ids[0], set), sql.ErrObjectExists) @@ -68,7 +69,7 @@ func TestCachedActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{3}, {4}}, } - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) require.NoError(t, Add(db, ids[0], set0)) require.NoError(t, Add(db, ids[1], set1)) @@ -78,12 +79,12 @@ func TestCachedActiveSet(t *testing.T) { for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[0].Bytes(), &b)) require.Equal(t, codec.MustEncode(set0), b.Bytes) - require.Equal(t, 3, db.QueryCount()) + require.Equal(t, 3, db.QueryCount(), "ids[0]: QueryCount at i=%d", i) } for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[1].Bytes(), &b)) require.Equal(t, codec.MustEncode(set1), b.Bytes) - require.Equal(t, 4, db.QueryCount()) + require.Equal(t, 4, db.QueryCount(), "ids[1]: QueryCount at i=%d", i) } } diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index eb507fe969a..b66ef3157e4 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -30,7 +31,7 @@ func TestMain(m *testing.M) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -52,7 +53,7 @@ func TestGet(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var expected []types.ATXID for i := 0; i < 3; i++ { @@ -69,7 +70,7 @@ func TestAll(t *testing.T) { } func TestHasID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -92,7 +93,7 @@ func TestHasID(t *testing.T) { } func Test_IdentityExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -110,7 +111,7 @@ func Test_IdentityExists(t *testing.T) { } func TestGetFirstIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -143,7 +144,7 @@ func TestGetFirstIDByNodeID(t *testing.T) { } func TestLatestN(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -235,7 +236,7 @@ func TestLatestN(t *testing.T) { } func TestGetByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -267,7 +268,7 @@ func TestGetByEpochAndNodeID(t *testing.T) { } func TestGetLastIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -301,7 +302,7 @@ func TestGetLastIDByNodeID(t *testing.T) { } func TestGetIDByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -345,7 +346,7 @@ func TestGetIDByEpochAndNodeID(t *testing.T) { } func TestGetIDsByEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -381,7 +382,7 @@ func TestGetIDsByEpoch(t *testing.T) { } func TestGetIDsByEpochCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -430,7 +431,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 11, db.QueryCount()) } - require.NoError(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx5, types.AtxBlob{}) return nil })) @@ -442,7 +443,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.ElementsMatch(t, []types.ATXID{atx4.ID(), atx5.ID()}, ids3) require.Equal(t, 13, db.QueryCount()) // not incremented after Add - require.Error(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx6, types.AtxBlob{}) return errors.New("fail") // rollback })) @@ -455,7 +456,7 @@ func TestGetIDsByEpochCached(t *testing.T) { } func Test_IterateAtxsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -485,7 +486,7 @@ func Test_IterateAtxsWithMalfeasance(t *testing.T) { } func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -516,7 +517,7 @@ func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { func TestVRFNonce(t *testing.T) { // Arrange - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -549,7 +550,7 @@ func TestVRFNonce(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -601,7 +602,7 @@ func TestLoadBlob(t *testing.T) { } func TestLoadBlob_DefaultsToV1(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -618,7 +619,7 @@ func TestLoadBlob_DefaultsToV1(t *testing.T) { } func TestGetBlobCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig, err := signing.NewEdSigner() @@ -640,7 +641,7 @@ func TestGetBlobCached(t *testing.T) { // Test that we don't put in the cache a reference to the blob that was passed to LoadBlob. // Each cache entry must use a unique slice for the blob. func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) @@ -672,7 +673,7 @@ func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { // Test that the cached blob is not shared with the caller // but copied into the provided blob. func TestGetBlobCached_OverwriteSafety(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) blob := types.AtxBlob{Blob: []byte("original blob")} @@ -690,7 +691,7 @@ func TestGetBlobCached_OverwriteSafety(t *testing.T) { } func TestCachedBlobEviction(t *testing.T) { - db := sql.InMemory( + db := statesql.InMemory( sql.WithQueryCache(true), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindATXBlob: 10, @@ -731,7 +732,7 @@ func TestCachedBlobEviction(t *testing.T) { } func TestCheckpointATX(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -778,7 +779,7 @@ func TestCheckpointATX(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonExistingATXID := types.ATXID(types.CalcHash32([]byte("0"))) _, err := atxs.Get(db, nonExistingATXID) @@ -850,7 +851,7 @@ type header struct { filteredOut bool } -func createAtx(tb testing.TB, db *sql.Database, hdr header) (types.ATXID, *signing.EdSigner) { +func createAtx(tb testing.TB, db sql.StateDatabase, hdr header) (types.ATXID, *signing.EdSigner) { sig, err := signing.NewEdSigner() require.NoError(tb, err) @@ -964,7 +965,7 @@ func TestGetIDWithMaxHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var sigs []*signing.EdSigner var ids []types.ATXID filtered := make(map[types.ATXID]struct{}) @@ -1005,7 +1006,7 @@ func TestLatest(t *testing.T) { {"out of order", []uint32{3, 4, 1, 2}, 4}, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, epoch := range tc.epochs { full := &types.ActivationTx{ PublishEpoch: types.EpochID(epoch), @@ -1024,7 +1025,7 @@ func TestLatest(t *testing.T) { } func Test_PrevATXCollisions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1077,13 +1078,13 @@ func TestCoinbase(t *testing.T) { t.Parallel() t.Run("not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Coinbase(db, types.NodeID{}) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx, blob := newAtx(t, sig, withCoinbase(types.Address{1, 2, 3})) @@ -1094,7 +1095,7 @@ func TestCoinbase(t *testing.T) { }) t.Run("picks last", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx1, blob1 := newAtx(t, sig, withPublishEpoch(1), withCoinbase(types.Address{1, 2, 3})) @@ -1111,13 +1112,13 @@ func TestUnits(t *testing.T) { t.Parallel() t.Run("ATX not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Units(db, types.RandomATXID(), types.RandomNodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("smesher has no units in ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() require.NoError(t, atxs.SetPost(db, atxID, types.EmptyATXID, 0, types.RandomNodeID(), 10)) _, err := atxs.Units(db, atxID, types.RandomNodeID()) @@ -1125,7 +1126,7 @@ func TestUnits(t *testing.T) { }) t.Run("returns units for given smesher in given ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() units := map[types.NodeID]uint32{ {1, 2, 3}: 10, @@ -1154,12 +1155,12 @@ func Test_AtxWithPrevious(t *testing.T) { prev := types.RandomATXID() t.Run("no atxs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.AtxWithPrevious(db, prev, sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("finds other ATX with same previous", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() prev := types.RandomATXID() atx, blob := newAtx(t, sig) @@ -1171,7 +1172,7 @@ func Test_AtxWithPrevious(t *testing.T) { require.Equal(t, atx.ID(), id) }) t.Run("finds other ATX with same previous (empty)", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx, blob)) @@ -1182,7 +1183,7 @@ func Test_AtxWithPrevious(t *testing.T) { require.Equal(t, atx.ID(), id) }) t.Run("same previous used by 2 IDs in two ATXs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig2, err := signing.NewEdSigner() require.NoError(t, err) @@ -1212,14 +1213,14 @@ func Test_FindDoublePublish(t *testing.T) { require.NoError(t, err) t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.FindDoublePublish(db, types.RandomNodeID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("no double publish", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() // one atx atx0, blob := newAtx(t, sig, withPublishEpoch(1)) @@ -1239,7 +1240,7 @@ func Test_FindDoublePublish(t *testing.T) { }) t.Run("double publish", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx0, blob := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx0, blob)) @@ -1259,7 +1260,7 @@ func Test_FindDoublePublish(t *testing.T) { }) t.Run("double publish different smesher", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx0Signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -1287,13 +1288,13 @@ func Test_MergeConflict(t *testing.T) { t.Parallel() t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.MergeConflict(db, types.RandomATXID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("no conflict", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() marriage := types.RandomATXID() atx := types.ActivationTx{MarriageATX: &marriage} @@ -1305,7 +1306,7 @@ func Test_MergeConflict(t *testing.T) { }) t.Run("finds conflict", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() marriage := types.RandomATXID() atx0 := types.ActivationTx{MarriageATX: &marriage} @@ -1337,12 +1338,12 @@ func Test_MergeConflict(t *testing.T) { func Test_Previous(t *testing.T) { t.Run("not found", func(t *testing.T) { - db := sql.InMemoryTest(t) + db := statesql.InMemoryTest(t) _, err := atxs.Previous(db, types.RandomATXID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("returns ATXs in order", func(t *testing.T) { - db := sql.InMemoryTest(t) + db := statesql.InMemoryTest(t) atx := types.RandomATXID() var previousAtxs []types.ATXID diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 519dbf80aa5..8c49b67b86c 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 3 @@ -26,7 +27,7 @@ func TestMain(m *testing.M) { } func TestLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) pub := types.BytesToNodeID([]byte{1, 1, 1}) ballots := []types.Ballot{ @@ -65,7 +66,7 @@ func TestLayer(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.RandomNodeID() ballot := types.NewExistingBallot(types.BallotID{1}, types.RandomEdSignature(), nodeID, types.LayerID(0)) _, err := Get(db, ballot.ID()) @@ -85,7 +86,7 @@ func TestAdd(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ballot := types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, types.EmptyNodeID, types.LayerID(0)) exists, err := Has(db, ballot.ID()) @@ -99,7 +100,7 @@ func TestHas(t *testing.T) { } func TestLatest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() latest, err := LatestLayer(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), latest) @@ -123,7 +124,7 @@ func TestLatest(t *testing.T) { } func TestLayerBallotBySmesher(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) nodeID1 := types.RandomNodeID() nodeID2 := types.RandomNodeID() @@ -157,7 +158,7 @@ func newAtx(signer *signing.EdSigner, layerID types.LayerID) *types.ActivationTx } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(layersPerEpoch * 2) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -284,7 +285,7 @@ func TestAllFirstInEpoch(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() for _, ballot := range tc.ballots { require.NoError(t, Add(db, &ballot)) } @@ -300,7 +301,7 @@ func TestAllFirstInEpoch(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() ballot1 := types.NewExistingBallot( diff --git a/sql/beacons/beacons_test.go b/sql/beacons/beacons_test.go index 1d5648d30d6..8a012e54ba9 100644 --- a/sql/beacons/beacons_test.go +++ b/sql/beacons/beacons_test.go @@ -7,12 +7,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const baseEpoch = 3 func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() beacons := []types.Beacon{ types.HexToBeacon("0x1"), @@ -35,7 +36,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -50,7 +51,7 @@ func TestAdd(t *testing.T) { } func TestSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) diff --git a/sql/blocks/blocks_test.go b/sql/blocks/blocks_test.go index 8b64d2ab65c..038b542caab 100644 --- a/sql/blocks/blocks_test.go +++ b/sql/blocks/blocks_test.go @@ -10,10 +10,11 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestAddGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1, 1}, types.InnerBlock{LayerIndex: types.LayerID(1)}, @@ -26,7 +27,7 @@ func TestAddGet(t *testing.T) { } func TestAlreadyExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -36,7 +37,7 @@ func TestAlreadyExists(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -52,7 +53,7 @@ func TestHas(t *testing.T) { } func TestValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -86,7 +87,7 @@ func TestValidity(t *testing.T) { } func TestLayerFilter(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -122,7 +123,7 @@ func TestLayerFilter(t *testing.T) { } func TestLayerOrdered(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -153,7 +154,7 @@ func TestLayerOrdered(t *testing.T) { } func TestContextualValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -197,7 +198,7 @@ func TestContextualValidity(t *testing.T) { } func TestGetLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid1 := types.LayerID(11) block1 := types.NewExistingBlock( types.BlockID{1, 1}, @@ -222,12 +223,12 @@ func TestGetLayer(t *testing.T) { func TestLastValid(t *testing.T) { t.Run("empty", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := LastValid(db) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("all valid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { lid types.LayerID }{ @@ -248,7 +249,7 @@ func TestLastValid(t *testing.T) { require.Equal(t, 33, int(last)) }) t.Run("last is invalid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { invalid bool lid types.LayerID @@ -274,7 +275,7 @@ func TestLastValid(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lid1 := types.LayerID(11) @@ -315,7 +316,7 @@ func TestLoadBlob(t *testing.T) { } func TestLayerForMangledBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := db.Exec("insert into blocks (id, layer, block) values (?1, ?2, ?3);", func(stmt *sql.Statement) { stmt.BindBytes(1, []byte(`mangled-block-id`)) diff --git a/sql/certificates/certs_test.go b/sql/certificates/certs_test.go index cfe4ff9104e..2056f6b0919 100644 --- a/sql/certificates/certs_test.go +++ b/sql/certificates/certs_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -44,7 +45,7 @@ func makeCert(lid types.LayerID, bid types.BlockID) *types.Certificate { } func TestCertificates(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) got, err := Get(db, lid) @@ -94,7 +95,7 @@ func TestCertificates(t *testing.T) { } func TestHareOutput(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) ho, err := GetHareOutput(db, lid) @@ -132,7 +133,7 @@ func TestHareOutput(t *testing.T) { } func TestCertifiedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 @@ -161,7 +162,7 @@ func TestCertifiedBlock(t *testing.T) { } func TestDeleteCert(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, types.LayerID(2), &types.Certificate{BlockID: types.BlockID{2}})) require.NoError(t, Add(db, types.LayerID(3), &types.Certificate{BlockID: types.BlockID{3}})) require.NoError(t, Add(db, types.LayerID(4), &types.Certificate{BlockID: types.BlockID{4}})) @@ -177,7 +178,7 @@ func TestDeleteCert(t *testing.T) { } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 diff --git a/sql/database.go b/sql/database.go index e20a393b632..6e89049ca66 100644 --- a/sql/database.go +++ b/sql/database.go @@ -6,12 +6,9 @@ import ( "errors" "fmt" "maps" - "slices" - "sort" "strings" "sync" "sync/atomic" - "testing" "time" sqlite "github.com/go-llsqlite/crawshaw" @@ -31,6 +28,9 @@ var ( ErrObjectExists = errors.New("database: object exists") // ErrTooNew is returned if database version is newer than expected. ErrTooNew = errors.New("database version is too new") + // ErrOldSchema is returned when the database version differs from the expected one + // and migrations are disabled. + ErrOldSchema = errors.New("old database version") ) const ( @@ -38,7 +38,7 @@ const ( beginImmediate = "BEGIN IMMEDIATE;" ) -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // Executor is an interface for executing raw statement. type Executor interface { @@ -62,29 +62,27 @@ type Encoder func(*Statement) type Decoder func(*Statement) bool func defaultConf() *conf { - migrations, err := StateMigrations() - if err != nil { - panic(err) - } - return &conf{ - connections: 16, - migrations: migrations, - skipMigration: map[int]struct{}{}, - logger: zap.NewNop(), + enableMigrations: true, + connections: 16, + logger: zap.NewNop(), + schema: &Schema{}, } } type conf struct { - flags sqlite.OpenFlags - connections int - skipMigration map[int]struct{} - vacuumState int - migrations []Migration - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + allowSchemaDrift bool + ignoreSchemaDrift bool } // WithConnections overwrites number of pooled connections. @@ -94,48 +92,18 @@ func WithConnections(n int) Opt { } } +// WithLogger specifies logger for the database. func WithLogger(logger *zap.Logger) Opt { return func(c *conf) { c.logger = logger } } -// WithMigrations overwrites embedded migrations. -// Migrations are sorted by order before applying. -func WithMigrations(migrations []Migration) Opt { +// WithMigrationsDisabled disables migrations for the database. +// The migrations are enabled by default. +func WithMigrationsDisabled() Opt { return func(c *conf) { - sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Order() < migrations[j].Order() - }) - c.migrations = migrations - } -} - -// WithMigration adds migration to the list of migrations. -// It will overwrite an existing migration with the same order. -func WithMigration(migration Migration) Opt { - return func(c *conf) { - for i, m := range c.migrations { - if m.Order() == migration.Order() { - c.migrations[i] = migration - return - } - if m.Order() > migration.Order() { - c.migrations = slices.Insert(c.migrations, i, migration) - return - } - } - c.migrations = append(c.migrations, migration) - } -} - -// WithSkipMigrations will update database version with executing associated migrations. -// It should be used at your own risk. -func WithSkipMigrations(i ...int) Opt { - return func(c *conf) { - for _, index := range i { - c.skipMigration[index] = struct{}{} - } + c.enableMigrations = false } } @@ -173,28 +141,57 @@ func WithQueryCacheSizes(sizes map[QueryCacheKind]int) Opt { } } +// WithForceMigrations forces database to run all the migrations instead +// of using a schema snapshot in case of a fresh database. +func WithForceMigrations(force bool) Opt { + return func(c *conf) { + c.forceMigrations = true + } +} + +// WithSchema specifies database schema script. +func WithDatabaseSchema(schema *Schema) Opt { + return func(c *conf) { + c.schema = schema + } +} + +// WithAllowSchemaDrift prevents Open from failing upon schema +// drift. A warning is printed instead. +func WithAllowSchemaDrift(allow bool) Opt { + return func(c *conf) { + c.allowSchemaDrift = allow + } +} + +func WithIgnoreSchemaDrift() Opt { + return func(c *conf) { + c.ignoreSchemaDrift = true + } +} + +func withForceFresh() Opt { + return func(c *conf) { + c.forceFresh = true + } +} + // Opt for configuring database. type Opt func(c *conf) -// InMemory database for testing. -// Please use InMemoryTest for automatic closing of the returned db during `tb.Cleanup`. -func InMemory(opts ...Opt) *Database { - opts = append(opts, WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) - if err != nil { - panic(err) - } - return db +// OpenInMemory creates an in-memory database. +func OpenInMemory(opts ...Opt) (*sqliteDatabase, error) { + opts = append(opts, WithConnections(1), withForceFresh()) + return Open("file::memory:?mode=memory", opts...) } -// InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. -func InMemoryTest(tb testing.TB, opts ...Opt) *Database { - opts = append(opts, WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) +// InMemory creates an in-memory database for testing and panics if +// there's an error. +func InMemory(opts ...Opt) *sqliteDatabase { + db, err := OpenInMemory(opts...) if err != nil { panic(err) } - tb.Cleanup(func() { db.Close() }) return db } @@ -203,82 +200,92 @@ func InMemoryTest(tb testing.TB, opts ...Opt) *Database { // Database is opened in WAL mode and pragma synchronous=normal. // https://sqlite.org/wal.html // https://www.sqlite.org/pragma.html#pragma_synchronous -func Open(uri string, opts ...Opt) (*Database, error) { +func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { config := defaultConf() for _, opt := range opts { opt(config) } - pool, err := sqlitex.Open(uri, config.flags, config.connections) + logger := config.logger.With(zap.String("uri", uri)) + var flags sqlite.OpenFlags + if !config.forceFresh { + flags = sqlite.SQLITE_OPEN_READWRITE | + sqlite.SQLITE_OPEN_WAL | + sqlite.SQLITE_OPEN_URI | + sqlite.SQLITE_OPEN_NOMUTEX + } + freshDB := config.forceFresh + pool, err := sqlitex.Open(uri, flags, config.connections) if err != nil { - return nil, fmt.Errorf("open db %s: %w", uri, err) + if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { + return nil, fmt.Errorf("open db %s: %w", uri, err) + } + flags |= sqlite.SQLITE_OPEN_CREATE + freshDB = true + pool, err = sqlitex.Open(uri, flags, config.connections) + if err != nil { + return nil, fmt.Errorf("create db %s: %w", uri, err) + } } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} if config.enableLatency { db.latency = newQueryLatency() } - //nolint:nestif - if config.migrations != nil { - before, err := version(db) - if err != nil { - return nil, err - } - after := 0 - if len(config.migrations) > 0 { - after = config.migrations[len(config.migrations)-1].Order() + if freshDB && !config.forceMigrations { + if err := config.schema.Apply(db); err != nil { + return nil, errors.Join( + fmt.Errorf("error running schema script: %w", err), + db.Close()) } - if before > after { - pool.Close() - config.logger.Error("database version is newer than expected - downgrade is not supported", - zap.String("uri", uri), + } else { + before, after, err := config.schema.CheckDBVersion(logger, db) + switch { + case err != nil: + return nil, errors.Join(err, db.Close()) + case before != after && config.enableMigrations: + logger.Info("running migrations", zap.Int("current version", before), zap.Int("target version", after), ) - return nil, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) - } - config.logger.Info("running migrations", - zap.String("uri", uri), - zap.Int("current version", before), - zap.Int("target version", after), - ) - for i, m := range config.migrations { - if m.Order() <= before { - continue - } - if err := db.WithTx(context.Background(), func(tx *Tx) error { - if _, ok := config.skipMigration[m.Order()]; !ok { - if err := m.Apply(tx); err != nil { - for j := i; j >= 0 && config.migrations[j].Order() > before; j-- { - if e := config.migrations[j].Rollback(); e != nil { - err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) - break - } - } - - return fmt.Errorf("apply %s: %w", m.Name(), err) - } - } - // version is set intentionally even if actual migration was skipped - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.Order(), err) - } - return nil - }); err != nil { - err = errors.Join(err, db.Close()) - return nil, err + if err := config.schema.Migrate( + logger, db, before, config.vacuumState, + ); err != nil { + return nil, errors.Join(err, db.Close()) } - - if config.vacuumState != 0 && before <= config.vacuumState { - if err := Vacuum(db); err != nil { - err = errors.Join(err, db.Close()) - return nil, err - } - } - before = m.Order() + case before != after: + logger.Error("database version is too old", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return nil, errors.Join( + fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after), + db.Close()) } + } + if !config.ignoreSchemaDrift { + loaded, err := LoadDBSchemaScript(db) + if err != nil { + return nil, errors.Join( + fmt.Errorf("error loading database schema: %w", err), + db.Close()) + } + diff := config.schema.Diff(loaded) + switch { + case diff == "": // ok + case config.allowSchemaDrift: + logger.Warn("database schema drift detected", + zap.String("uri", uri), + zap.String("diff", diff), + ) + default: + return nil, errors.Join( + fmt.Errorf("schema drift detected (uri %s):\n%s", uri, diff), + db.Close()) + } } + if config.cache { - config.logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) + logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} } db.queryCount.Store(0) @@ -290,13 +297,32 @@ func Version(uri string) (int, error) { if err != nil { return 0, fmt.Errorf("open db %s: %w", uri, err) } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} defer db.Close() return version(db) } -// Database is an instance of sqlite database. -type Database struct { +// Database represents a database. +type Database interface { + Executor + QueryCache + Close() error + QueryCount() int + QueryCache() QueryCache + Tx(ctx context.Context) (Transaction, error) + WithTx(ctx context.Context, exec func(Transaction) error) error + TxImmediate(ctx context.Context) (Transaction, error) + WithTxImmediate(ctx context.Context, exec func(Transaction) error) error +} + +// Transaction represents a transaction. +type Transaction interface { + Executor + Commit() error + Release() error +} + +type sqliteDatabase struct { *queryCache pool *sqlitex.Pool @@ -307,7 +333,9 @@ type Database struct { queryCount atomic.Int64 } -func (db *Database) getConn(ctx context.Context) *sqlite.Conn { +var _ Database = &sqliteDatabase{} + +func (db *sqliteDatabase) getConn(ctx context.Context) *sqlite.Conn { start := time.Now() conn := db.pool.Get(ctx) if conn != nil { @@ -316,19 +344,19 @@ func (db *Database) getConn(ctx context.Context) *sqlite.Conn { return conn } -func (db *Database) getTx(ctx context.Context, initstmt string) (*Tx, error) { +func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx, error) { conn := db.getConn(ctx) if conn == nil { return nil, ErrNoConnection } - tx := &Tx{queryCache: db.queryCache, db: db, conn: conn} + tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn} if err := tx.begin(initstmt); err != nil { return nil, err } return tx, nil } -func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) error) error { +func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) error { tx, err := db.getTx(ctx, initstmt) if err != nil { return err @@ -348,13 +376,13 @@ func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) // after one of the write statements. // // https://www.sqlite.org/lang_transaction.html -func (db *Database) Tx(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) Tx(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginDefault) } // WithTx will pass initialized deferred transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTx(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginImmediate, exec) } @@ -363,13 +391,16 @@ func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { // IMMEDIATE cause the database connection to start a new write immediately, without waiting // for a write statement. The BEGIN IMMEDIATE might fail with SQLITE_BUSY if another write // transaction is already active on another database connection. -func (db *Database) TxImmediate(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) TxImmediate(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginImmediate) } // WithTxImmediate will pass initialized immediate transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTxImmediate( + ctx context.Context, + exec func(Transaction) error, +) error { return db.withTx(ctx, beginImmediate, exec) } @@ -381,7 +412,7 @@ func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) e // // Note that Exec will block until database is closed or statement has finished. // If application needs to control statement execution lifetime use one of the transaction. -func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { db.queryCount.Add(1) conn := db.getConn(context.Background()) if conn == nil { @@ -398,7 +429,7 @@ func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, e } // Close closes all pooled connections. -func (db *Database) Close() error { +func (db *sqliteDatabase) Close() error { db.closeMux.Lock() defer db.closeMux.Unlock() if db.closed { @@ -413,12 +444,12 @@ func (db *Database) Close() error { // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. -func (db *Database) QueryCount() int { +func (db *sqliteDatabase) QueryCount() int { return int(db.queryCount.Load()) } // Return database's QueryCache. -func (db *Database) QueryCache() QueryCache { +func (db *sqliteDatabase) QueryCache() QueryCache { return db.queryCache } @@ -436,11 +467,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in for { row, err := stmt.Step() if err != nil { - code := sqlite.ErrCode(err) - if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { - return 0, ErrObjectExists - } - return 0, fmt.Errorf("step %d: %w", rows, err) + return 0, fmt.Errorf("step %d: %w", rows, mapSqliteError(err)) } if !row { return rows, nil @@ -459,48 +486,48 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in } } -// Tx is wrapper for database transaction. -type Tx struct { +// sqliteTx is wrapper for database transaction. +type sqliteTx struct { *queryCache - db *Database + db *sqliteDatabase conn *sqlite.Conn committed bool err error } -func (tx *Tx) begin(initstmt string) error { +func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() if err != nil { - return fmt.Errorf("begin: %w", err) + return fmt.Errorf("begin: %w", mapSqliteError(err)) } return nil } // Commit transaction. -func (tx *Tx) Commit() error { +func (tx *sqliteTx) Commit() error { stmt := tx.conn.Prep("COMMIT;") _, tx.err = stmt.Step() if tx.err != nil { - return tx.err + return mapSqliteError(tx.err) } tx.committed = true return nil } // Release transaction. Every transaction that was created must be released. -func (tx *Tx) Release() error { +func (tx *sqliteTx) Release() error { defer tx.db.pool.Put(tx.conn) if tx.committed { return nil } stmt := tx.conn.Prep("ROLLBACK") _, tx.err = stmt.Step() - return tx.err + return mapSqliteError(tx.err) } // Exec query. -func (tx *Tx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { tx.db.queryCount.Add(1) if tx.db.latency != nil { start := time.Now() @@ -511,6 +538,20 @@ func (tx *Tx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) return exec(tx.conn, query, encoder, decoder) } +func mapSqliteError(err error) error { + code := sqlite.ErrCode(err) + if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { + return ErrObjectExists + } + if code == sqlite.SQLITE_INTERRUPT { + // TODO: we probably should check if there was indeed a context + // that was canceled. But we're likely to replace crawshaw library + // in future so this part should be rewritten anyway + return context.Canceled + } + return err +} + // Blob represents a binary blob data. It can be reused efficiently // across multiple data retrieval operations, minimizing reallocations // of the underlying byte slice. @@ -595,3 +636,15 @@ func LoadBlob(db Executor, cmd string, id []byte, blob *Blob) error { func IsNull(stmt *Statement, col int) bool { return stmt.ColumnType(col) == sqlite.SQLITE_NULL } + +// StateDatabase is a Database used for Spacemesh state. +type StateDatabase interface { + Database + IsStateDatabase() +} + +// LocalDatabase is a Database used for local node data. +type LocalDatabase interface { + Database + IsLocalDatabase() +} diff --git a/sql/database_test.go b/sql/database_test.go index ab7d71438a7..e4f222143ea 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -9,29 +9,25 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" ) func Test_Transaction_Isolation(t *testing.T) { - ctrl := gomock.NewController(t) - testMigration := NewMockMigration(ctrl) - testMigration.EXPECT().Name().Return("test").AnyTimes() - testMigration.EXPECT().Order().Return(1).AnyTimes() - testMigration.EXPECT().Apply(gomock.Any()).DoAndReturn(func(e Executor) error { - if _, err := e.Exec(`create table testing1 ( - id varchar primary key, - field int - )`, nil, nil); err != nil { - return err - } - return nil - }) - db := InMemory( - WithMigrations([]Migration{testMigration}), + WithLogger(zaptest.NewLogger(t)), WithConnections(10), WithLatencyMetering(true), + WithDatabaseSchema(&Schema{ + Script: `create table testing1 ( + id varchar primary key, + field int + );`, + }), + WithIgnoreSchemaDrift(), ) - tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -67,28 +63,37 @@ func Test_Migration_Rollback(t *testing.T) { migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil) - migration2.EXPECT().Apply(gomock.Any()).Return(errors.New("migration 2 failed")) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) migration2.EXPECT().Rollback().Return(nil) dbFile := filepath.Join(t.TempDir(), "test.sql") _, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithForceMigrations(true), ) require.ErrorContains(t, err, "migration 2 failed") } func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) migration1.EXPECT().Name().Return("test").AnyTimes() migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -96,15 +101,48 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2 := NewMockMigration(ctrl) migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(errors.New("migration 2 failed")) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), ) require.ErrorContains(t, err, "migration 2 failed") } +func Test_Migration_Disabled(t *testing.T) { + ctrl := gomock.NewController(t) + migration1 := NewMockMigration(ctrl) + migration1.EXPECT().Name().Return("test").AnyTimes() + migration1.EXPECT().Order().Return(1).AnyTimes() + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + + dbFile := filepath.Join(t.TempDir(), "test.sql") + db, err := Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + migration2 := NewMockMigration(ctrl) + migration2.EXPECT().Order().Return(2).AnyTimes() + + _, err = Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithMigrationsDisabled(), + ) + require.ErrorIs(t, err, ErrOldSchema) +} + func TestDatabaseSkipMigrations(t *testing.T) { ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -113,12 +151,17 @@ func TestDatabaseSkipMigrations(t *testing.T) { migration2 := NewMockMigration(ctrl) migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(nil) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + schema := &Schema{ + Migrations: MigrationList{migration1, migration2}, + } + schema.SkipMigrations(1) dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), - WithSkipMigrations(1), + WithDatabaseSchema(schema), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -126,26 +169,36 @@ func TestDatabaseSkipMigrations(t *testing.T) { func TestDatabaseVacuumState(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) migration2 := NewMockMigration(ctrl) migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), WithVacuumState(2), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -156,7 +209,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory() + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -171,6 +224,7 @@ func TestQueryCount(t *testing.T) { func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -180,14 +234,71 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:" + dbFile) + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithForceMigrations(true), + WithIgnoreSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) require.NoError(t, db.Close()) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithIgnoreSchemaDrift(), ) require.ErrorIs(t, err, ErrTooNew) } + +func TestSchemaDrift(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + dbFile := filepath.Join(t.TempDir(), "test.sql") + schema := &Schema{ + // Not using ` here to avoid schema drift warnings due to whitespace + // TODO: ignore whitespace and comments during schema comparison + Script: "PRAGMA user_version = 0;\n" + + "CREATE TABLE testing1 (\n" + + " id varchar primary key,\n" + + " field int\n" + + ");\n", + } + db, err := Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = db.Exec("create table newtbl (id int)", nil, nil) + require.NoError(t, err) + + require.NoError(t, db.Close()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + _, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.Error(t, err) + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, err.Error()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + db, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + WithAllowSchemaDrift(true), + ) + require.NoError(t, db.Close()) + require.NoError(t, err) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + require.Equal(t, "database schema drift detected", observedLogs.All()[0].Message) + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, + observedLogs.All()[0].ContextMap()["diff"]) +} diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 3c00165702a..0d18ef5b79b 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -11,10 +11,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.NodeID{1, 1, 1, 1} mal, err := IsMalicious(db, nodeID) @@ -60,7 +61,7 @@ func TestMalicious(t *testing.T) { } func Test_GetMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() got, err := GetMalicious(db) require.NoError(t, err) require.Nil(t, got) @@ -78,7 +79,7 @@ func Test_GetMalicious(t *testing.T) { } func TestLoadMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() nid1 := types.RandomNodeID() @@ -123,7 +124,7 @@ func TestMarriageATX(t *testing.T) { t.Parallel() t.Run("not married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() _, err := MarriageATX(db, id) @@ -131,7 +132,7 @@ func TestMarriageATX(t *testing.T) { }) t.Run("married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -150,7 +151,7 @@ func TestMarriageATX(t *testing.T) { func TestMarriage(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -169,7 +170,7 @@ func TestEquivocationSet(t *testing.T) { t.Parallel() t.Run("equivocation set of married IDs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ @@ -196,7 +197,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("equivocation set for unmarried ID contains itself only", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() set, err := EquivocationSet(db, id) require.NoError(t, err) @@ -204,7 +205,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("can't escape the marriage", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -237,7 +238,7 @@ func TestEquivocationSet(t *testing.T) { } }) t.Run("married doesn't become malicious immediately", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() id := types.RandomNodeID() require.NoError(t, SetMarriage(db, id, &MarriageData{ATX: atx})) @@ -256,7 +257,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("all IDs in equivocation set are malicious if one is", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -280,7 +281,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { t.Parallel() t.Run("married IDs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ids := []types.NodeID{ types.RandomNodeID(), types.RandomNodeID(), @@ -296,7 +297,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { require.Equal(t, ids, set) }) t.Run("empty set", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() set, err := EquivocationSetByMarriageATX(db, types.RandomATXID()) require.NoError(t, err) require.Empty(t, set) diff --git a/sql/interface.go b/sql/interface.go index 3d3931b60e1..23859f5202b 100644 --- a/sql/interface.go +++ b/sql/interface.go @@ -1,10 +1,12 @@ package sql +import "go.uber.org/zap" + //go:generate mockgen -typed -package=sql -destination=./mocks.go -source=./interface.go // Migration is interface for migrations provider. type Migration interface { - Apply(db Executor) error + Apply(db Executor, logger *zap.Logger) error Rollback() error Name() string Order() int diff --git a/sql/layers/layers_test.go b/sql/layers/layers_test.go index f16a5e40d71..70f1f57bf15 100644 --- a/sql/layers/layers_test.go +++ b/sql/layers/layers_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 4 @@ -20,7 +21,7 @@ func TestMain(m *testing.M) { } func TestWeakCoin(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetWeakCoin(db, lid) @@ -38,7 +39,7 @@ func TestWeakCoin(t *testing.T) { } func TestAppliedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetApplied(db, lid) @@ -67,7 +68,7 @@ func TestAppliedBlock(t *testing.T) { } func TestFirstAppliedInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blks := map[types.LayerID]types.BlockID{ types.EpochID(1).FirstLayer(): {1}, types.EpochID(2).FirstLayer(): types.EmptyBlockID, @@ -107,7 +108,7 @@ func TestFirstAppliedInEpoch(t *testing.T) { } func TestUnsetAppliedFrom(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) last := lid.Add(99) for i := lid; !i.After(last); i = i.Add(1) { @@ -123,7 +124,7 @@ func TestUnsetAppliedFrom(t *testing.T) { } func TestStateHash(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() layers := []uint32{9, 10, 8, 7} hashes := []types.Hash32{{1}, {2}, {3}, {4}} for i := range layers { @@ -147,7 +148,7 @@ func TestStateHash(t *testing.T) { } func TestSetHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := GetAggregatedHash(db, types.LayerID(11)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -178,7 +179,7 @@ func TestSetHashes(t *testing.T) { } func TestProcessed(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid, err := GetProcessed(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), lid) @@ -193,7 +194,7 @@ func TestProcessed(t *testing.T) { } func TestGetAggHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() hashes := make(map[types.LayerID]types.Hash32) diff --git a/sql/localsql/local.go b/sql/localsql/local.go deleted file mode 100644 index cbf1d9f2f4e..00000000000 --- a/sql/localsql/local.go +++ /dev/null @@ -1,38 +0,0 @@ -package localsql - -import "github.com/spacemeshos/go-spacemesh/sql" - -type Database struct { - *sql.Database -} - -func Open(uri string, opts ...sql.Opt) (*Database, error) { - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, err - } - defaultOpts := []sql.Opt{ - sql.WithConnections(16), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db, err := sql.Open(uri, opts...) - if err != nil { - return nil, err - } - return &Database{Database: db}, nil -} - -func InMemory(opts ...sql.Opt) *Database { - migrations, err := sql.LocalMigrations() - if err != nil { - panic(err) - } - defaultOpts := []sql.Opt{ - sql.WithConnections(1), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db := sql.InMemory(opts...) - return &Database{Database: db} -} diff --git a/sql/localsql/local_test.go b/sql/localsql/local_test.go deleted file mode 100644 index 1fed142b1af..00000000000 --- a/sql/localsql/local_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package localsql - -import ( - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" -) - -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := Open("file:" + file) - require.NoError(t, err) - - var sqls1 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls1 = append(sqls1, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - db, err = Open("file:" + file) - require.NoError(t, err) - var sqls2 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls2 = append(sqls2, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - require.Equal(t, sqls1, sqls2) -} diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go new file mode 100644 index 00000000000..2dc32f18adc --- /dev/null +++ b/sql/localsql/localsql.go @@ -0,0 +1,83 @@ +package localsql + +import ( + "embed" + "strings" + "testing" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate go run ../schemagen -dbtype local -output schema/schema.sql + +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS + +type database struct { + sql.Database +} + +var _ sql.LocalDatabase = &database{} + +func (d *database) IsLocalDatabase() {} + +// Schema returns the schema for the local database. +func Schema() (*sql.Schema, error) { + sqlMigrations, err := sql.LoadSQLMigrations(migrations) + if err != nil { + return nil, err + } + // NOTE: coded state migrations can be added here + // They can be a part of this localsql package + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil +} + +// Open opens a local database. +func Open(uri string, opts ...sql.Opt) (*database, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + defaultOpts := []sql.Opt{ + sql.WithConnections(16), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &database{Database: db}, nil +} + +// Open opens an in-memory local database. +func InMemory(opts ...sql.Opt) *database { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithConnections(1), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &database{Database: db} +} + +// InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. +func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.LocalDatabase { + opts = append(opts, sql.WithConnections(1)) + db, err := Open("file::memory:?mode=memory", opts...) + if err != nil { + panic(err) + } + tb.Cleanup(func() { db.Close() }) + return db +} diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go new file mode 100644 index 00000000000..320702d7ebd --- /dev/null +++ b/sql/localsql/localsql_test.go @@ -0,0 +1,73 @@ +package localsql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) + + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") +} diff --git a/sql/migrations/local/0001_initial.sql b/sql/localsql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/local/0001_initial.sql rename to sql/localsql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/local/0002_extend_initial_post.sql b/sql/localsql/schema/migrations/0002_extend_initial_post.sql similarity index 100% rename from sql/migrations/local/0002_extend_initial_post.sql rename to sql/localsql/schema/migrations/0002_extend_initial_post.sql diff --git a/sql/migrations/local/0003_add_nipost_builder_state.sql b/sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql similarity index 100% rename from sql/migrations/local/0003_add_nipost_builder_state.sql rename to sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql diff --git a/sql/migrations/local/0004_atx_sync.sql b/sql/localsql/schema/migrations/0004_atx_sync.sql similarity index 100% rename from sql/migrations/local/0004_atx_sync.sql rename to sql/localsql/schema/migrations/0004_atx_sync.sql diff --git a/sql/migrations/local/0005_fast_startup.sql b/sql/localsql/schema/migrations/0005_fast_startup.sql similarity index 100% rename from sql/migrations/local/0005_fast_startup.sql rename to sql/localsql/schema/migrations/0005_fast_startup.sql diff --git a/sql/migrations/local/0006_prepared_activeset.sql b/sql/localsql/schema/migrations/0006_prepared_activeset.sql similarity index 100% rename from sql/migrations/local/0006_prepared_activeset.sql rename to sql/localsql/schema/migrations/0006_prepared_activeset.sql diff --git a/sql/migrations/local/0007_malfeasance_sync.sql b/sql/localsql/schema/migrations/0007_malfeasance_sync.sql similarity index 100% rename from sql/migrations/local/0007_malfeasance_sync.sql rename to sql/localsql/schema/migrations/0007_malfeasance_sync.sql diff --git a/sql/migrations/local/0008_next.sql b/sql/localsql/schema/migrations/0008_next.sql similarity index 100% rename from sql/migrations/local/0008_next.sql rename to sql/localsql/schema/migrations/0008_next.sql diff --git a/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql new file mode 100644 index 00000000000..d0ade1f535a --- /dev/null +++ b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql @@ -0,0 +1,12 @@ +ALTER TABLE malfeasance_sync_state RENAME TO malfeasance_sync_state_old; + +CREATE TABLE malfeasance_sync_state +( + id INT NOT NULL PRIMARY KEY, + timestamp INT NOT NULL +); + +INSERT INTO malfeasance_sync_state (id, timestamp) +SELECT 1, timestamp FROM malfeasance_sync_state_old LIMIT 1; + +DROP TABLE malfeasance_sync_state_old; diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql new file mode 100755 index 00000000000..02c44d3ccbd --- /dev/null +++ b/sql/localsql/schema/schema.sql @@ -0,0 +1,83 @@ +PRAGMA user_version = 9; +CREATE TABLE atx_sync_requests +( + epoch INT NOT NULL, + timestamp INT NOT NULL, total INTEGER, downloaded INTEGER, + PRIMARY KEY (epoch) +) WITHOUT ROWID; +CREATE TABLE atx_sync_state +( + epoch INT NOT NULL, + id CHAR(32) NOT NULL, + requests INT NOT NULL DEFAULT 0, + PRIMARY KEY (epoch, id) +) WITHOUT ROWID; +CREATE TABLE "challenge" +( + id CHAR(32) PRIMARY KEY, + epoch UNSIGNED INT NOT NULL, + sequence UNSIGNED INT NOT NULL, + prev_atx CHAR(32) NOT NULL, + pos_atx CHAR(32) NOT NULL, + commit_atx CHAR(32), + post_nonce UNSIGNED INT, + post_indices VARCHAR, + post_pow UNSIGNED LONG INT +, poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; +CREATE TABLE malfeasance_sync_state +( + id INT NOT NULL PRIMARY KEY, + timestamp INT NOT NULL +); +CREATE TABLE nipost +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL, + + poet_proof_membership VARCHAR NOT NULL, + poet_proof_ref CHAR(32) NOT NULL, + labels_per_unit UNSIGNED INT NOT NULL +) WITHOUT ROWID; +CREATE TABLE poet_certificates +( + node_id BLOB NOT NULL, + certifier_id BLOB NOT NULL, + certificate BLOB NOT NULL, + signature BLOB NOT NULL +); +CREATE UNIQUE INDEX idx_poet_certificates ON poet_certificates (node_id, certifier_id); +CREATE TABLE poet_registration +( + id CHAR(32) NOT NULL, + hash CHAR(32) NOT NULL, + address VARCHAR NOT NULL, + round_id VARCHAR NOT NULL, + round_end INT NOT NULL, + + PRIMARY KEY (id, address) +) WITHOUT ROWID; +CREATE TABLE "post" +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + commit_atx CHAR(32) NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL +, challenge BLOB NOT NULL DEFAULT x'0000000000000000000000000000000000000000000000000000000000000000'); +CREATE TABLE prepared_activeset +( + kind UNSIGNED INT NOT NULL, + epoch UNSIGNED INT NOT NULL, + id CHAR(32) NOT NULL, + weight UNSIGNED INT NOT NULL, + data BLOB NOT NULL, + PRIMARY KEY (kind, epoch) +) WITHOUT ROWID; diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go index 9ae572b0ade..fac8b50fbaa 100644 --- a/sql/malsync/malsync.go +++ b/sql/malsync/malsync.go @@ -9,7 +9,7 @@ import ( func GetSyncState(db sql.Executor) (time.Time, error) { var timestamp time.Time - rows, err := db.Exec("select timestamp from malfeasance_sync_state", + rows, err := db.Exec("select timestamp from malfeasance_sync_state where id = 1", nil, func(stmt *sql.Statement) bool { v := stmt.ColumnInt64(0) if v > 0 { @@ -17,21 +17,25 @@ func GetSyncState(db sql.Executor) (time.Time, error) { } return true }) - if err != nil { + switch { + case err != nil: return time.Time{}, fmt.Errorf("error getting malfeasance sync state: %w", err) - } else if rows != 1 { + case rows <= 1: + return timestamp, nil + default: return time.Time{}, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) } - return timestamp, nil } func updateSyncState(db sql.Executor, ts int64) error { - _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", + if _, err := db.Exec( + `insert into malfeasance_sync_state (id, timestamp) values(1, ?1) + on conflict (id) do update set timestamp = ?1`, func(stmt *sql.Statement) { stmt.BindInt64(1, ts) - }, nil) - if err != nil { - return fmt.Errorf("error updating malfeasance sync state: %w", err) + }, nil, + ); err != nil { + return fmt.Errorf("error initializing malfeasance sync state: %w", err) } return nil } diff --git a/sql/metrics/prometheus.go b/sql/metrics/prometheus.go index 30991749b1b..fb35490ee08 100644 --- a/sql/metrics/prometheus.go +++ b/sql/metrics/prometheus.go @@ -22,7 +22,7 @@ const ( type DBMetricsCollector struct { logger *zap.Logger checkInterval time.Duration - db *sql.Database + db sql.StateDatabase tablesList map[string]struct{} eg errgroup.Group cancel context.CancelFunc @@ -35,7 +35,7 @@ type DBMetricsCollector struct { // NewDBMetricsCollector creates new DBMetricsCollector. func NewDBMetricsCollector( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, logger *zap.Logger, checkInterval time.Duration, ) *DBMetricsCollector { diff --git a/sql/migrations.go b/sql/migrations.go index 3b9a019e537..3a4d37844dc 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -1,25 +1,53 @@ package sql import ( - "bufio" - "bytes" - "embed" "fmt" "io/fs" + "regexp" + "slices" "strconv" "strings" + + "go.uber.org/zap" ) -//go:embed migrations/**/*.sql -var embedded embed.FS +// MigrationList denotes a list of migrations. +type MigrationList []Migration + +// AddMigration adds a Migration to the MigrationList, overriding the migration with the +// same order number if it already exists. The function returns updated migration list. +// The state of the original migration list is undefined after calling this function. +func (l MigrationList) AddMigration(migration Migration) MigrationList { + for i, m := range l { + if m.Order() == migration.Order() { + l[i] = migration + return l + } + if m.Order() > migration.Order() { + l = slices.Insert(l, i, migration) + return l + } + } + return append(l, migration) +} + +// Version returns database version for the specified migration list. +func (l MigrationList) Version() int { + if len(l) == 0 { + return 0 + } + return l[len(l)-1].Order() +} type sqlMigration struct { order int name string - content *bufio.Scanner + content string } -func (m *sqlMigration) Apply(db Executor) error { +var sqlCommentRx = regexp.MustCompile(`(?m)--.*$`) + +func (m *sqlMigration) Apply(db Executor, logger *zap.Logger) error { current, err := version(db) if err != nil { return err @@ -28,9 +56,14 @@ func (m *sqlMigration) Apply(db Executor) error { if m.order <= current { return nil } - for m.content.Scan() { - if _, err := db.Exec(m.content.Text(), nil, nil); err != nil { - return fmt.Errorf("exec %s: %w", m.content.Text(), err) + // TODO: use more advanced approach to split the SQL script + // into commands + for _, cmd := range strings.Split(m.content, ";") { + cmd = sqlCommentRx.ReplaceAllString(cmd, "") + if strings.TrimSpace(cmd) != "" { + if _, err := db.Exec(cmd, nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", cmd, err) + } } } // binding values in pragma statement is not allowed @@ -65,18 +98,9 @@ func version(db Executor) (int, error) { return current, nil } -func StateMigrations() ([]Migration, error) { - return sqlMigrations("state") -} - -func LocalMigrations() ([]Migration, error) { - return sqlMigrations("local") -} - -func sqlMigrations(dbname string) ([]Migration, error) { - var migrations []Migration - root := fmt.Sprintf("migrations/%s", dbname) - err := fs.WalkDir(embedded, root, func(path string, d fs.DirEntry, err error) error { +func LoadSQLMigrations(fsys fs.FS) (MigrationList, error) { + var migrations MigrationList + err := fs.WalkDir(fsys, "schema/migrations", func(path string, d fs.DirEntry, err error) error { if err != nil { return fmt.Errorf("walkdir %s: %w", path, err) } @@ -91,21 +115,14 @@ func sqlMigrations(dbname string) ([]Migration, error) { if err != nil { return fmt.Errorf("invalid migration %s: %w", d.Name(), err) } - f, err := embedded.Open(path) + script, err := fs.ReadFile(fsys, path) if err != nil { return fmt.Errorf("read file %s: %w", path, err) } - scanner := bufio.NewScanner(f) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if i := bytes.Index(data, []byte(";")); i >= 0 { - return i + 1, data[0 : i+1], nil - } - return 0, nil, nil - }) migrations = append(migrations, &sqlMigration{ order: order, name: d.Name(), - content: scanner, + content: string(script), }) return nil }) diff --git a/sql/migrations_test.go b/sql/migrations_test.go deleted file mode 100644 index 13bfb86b806..00000000000 --- a/sql/migrations_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package sql - -import ( - "slices" - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_MigrationsAppliedOnce(t *testing.T) { - db := InMemory() - - var version int - _, err := db.Exec("PRAGMA user_version;", nil, func(stmt *Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - migrations, err := StateMigrations() - require.NoError(t, err) - expectedVersion := slices.MaxFunc(migrations, func(a, b Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) -} diff --git a/sql/mocks.go b/sql/mocks.go index bbd869f3d49..975de979115 100644 --- a/sql/mocks.go +++ b/sql/mocks.go @@ -13,6 +13,7 @@ import ( reflect "reflect" gomock "go.uber.org/mock/gomock" + zap "go.uber.org/zap" ) // MockMigration is a mock of Migration interface. @@ -39,17 +40,17 @@ func (m *MockMigration) EXPECT() *MockMigrationMockRecorder { } // Apply mocks base method. -func (m *MockMigration) Apply(db Executor) error { +func (m *MockMigration) Apply(db Executor, logger *zap.Logger) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Apply", db) + ret := m.ctrl.Call(m, "Apply", db, logger) ret0, _ := ret[0].(error) return ret0 } // Apply indicates an expected call of Apply. -func (mr *MockMigrationMockRecorder) Apply(db any) *MockMigrationApplyCall { +func (mr *MockMigrationMockRecorder) Apply(db, logger any) *MockMigrationApplyCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockMigration)(nil).Apply), db) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockMigration)(nil).Apply), db, logger) return &MockMigrationApplyCall{Call: call} } @@ -65,13 +66,13 @@ func (c *MockMigrationApplyCall) Return(arg0 error) *MockMigrationApplyCall { } // Do rewrite *gomock.Call.Do -func (c *MockMigrationApplyCall) Do(f func(Executor) error) *MockMigrationApplyCall { +func (c *MockMigrationApplyCall) Do(f func(Executor, *zap.Logger) error) *MockMigrationApplyCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMigrationApplyCall) DoAndReturn(f func(Executor) error) *MockMigrationApplyCall { +func (c *MockMigrationApplyCall) DoAndReturn(f func(Executor, *zap.Logger) error) *MockMigrationApplyCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index c5133b30640..8d93a4a118b 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./database.go +// Source: github.com/spacemeshos/go-spacemesh/sql (interfaces: Executor) // // Generated by this command: // -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +// mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // // Package mocks is a generated GoMock package. diff --git a/sql/poets/poets_test.go b/sql/poets/poets_test.go index 9545f0075be..be7d88e1926 100644 --- a/sql/poets/poets_test.go +++ b/sql/poets/poets_test.go @@ -8,10 +8,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -51,7 +52,7 @@ func TestHas(t *testing.T) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -102,7 +103,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ref := types.PoetProofRef{0xca, 0xfe} poet := []byte("proof0") @@ -121,7 +122,7 @@ func TestAdd(t *testing.T) { } func TestGetRef(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sids := [][]byte{ []byte("sid1"), diff --git a/sql/querycache.go b/sql/querycache.go index e9840a29599..38db15f8803 100644 --- a/sql/querycache.go +++ b/sql/querycache.go @@ -17,14 +17,14 @@ type ( var NullQueryCache QueryCache = (*queryCache)(nil) -type queryCacheKey struct { +type QueryCacheItemKey struct { Kind QueryCacheKind Key string } // QueryCacheKey creates a key for QueryCache. -func QueryCacheKey(kind QueryCacheKind, key string) queryCacheKey { - return queryCacheKey{Kind: kind, Key: key} +func QueryCacheKey(kind QueryCacheKind, key string) QueryCacheItemKey { + return QueryCacheItemKey{Kind: kind, Key: key} } // QueryCacheSubKey denotes a cache subkey. The empty subkey refers to the main @@ -60,14 +60,14 @@ type QueryCache interface { // called for this cache. GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) // UpdateSlice updates the slice stored in the cache by invoking the // specified SliceAppender. If the entry is not cached, the method does // nothing. - UpdateSlice(key queryCacheKey, update SliceAppender) + UpdateSlice(key QueryCacheItemKey, update SliceAppender) // ClearCache empties the cache. ClearCache() } @@ -87,7 +87,7 @@ func IsCached(db any) bool { func WithCachedValue[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { return WithCachedSubKey(ctx, db, key, mainSubKey, retrieve) @@ -100,7 +100,7 @@ func WithCachedValue[T any]( func WithCachedSubKey[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { @@ -124,7 +124,7 @@ func WithCachedSubKey[T any]( // AppendToCachedSlice adds a value to the slice stored in the cache by invoking // the specified SliceAppender. If the entry is not cached, the function does // nothing. -func AppendToCachedSlice[T any](db any, key queryCacheKey, v T) { +func AppendToCachedSlice[T any](db any, key QueryCacheItemKey, v T) { if cache, ok := db.(QueryCache); ok { cache.UpdateSlice(key, func(s any) any { if s == nil { @@ -145,7 +145,7 @@ type lru = simplelru.LRU[lruCacheKey, any] type queryCache struct { sync.Mutex updateMtx sync.RWMutex - subKeyMap map[queryCacheKey][]QueryCacheSubKey + subKeyMap map[QueryCacheItemKey][]QueryCacheSubKey cacheSizesByKind map[QueryCacheKind]int caches map[QueryCacheKind]*lru } @@ -162,7 +162,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { } lruForKind, err := simplelru.NewLRU[lruCacheKey, any](size, func(k lruCacheKey, v any) { if k.subKey == mainSubKey { - c.clearSubKeys(queryCacheKey{Kind: kind, Key: k.key}) + c.clearSubKeys(QueryCacheItemKey{Kind: kind, Key: k.key}) } }) if err != nil { @@ -175,7 +175,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { return lruForKind } -func (c *queryCache) clearSubKeys(key queryCacheKey) { +func (c *queryCache) clearSubKeys(key QueryCacheItemKey) { lru, found := c.caches[key.Kind] if !found { return @@ -188,7 +188,7 @@ func (c *queryCache) clearSubKeys(key queryCacheKey) { } } -func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) { +func (c *queryCache) get(key QueryCacheItemKey, subKey QueryCacheSubKey) (any, bool) { c.Lock() defer c.Unlock() lru, found := c.caches[key.Kind] @@ -202,14 +202,14 @@ func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) }) } -func (c *queryCache) set(key queryCacheKey, subKey QueryCacheSubKey, v any) { +func (c *queryCache) set(key QueryCacheItemKey, subKey QueryCacheSubKey, v any) { c.Lock() defer c.Unlock() if subKey != mainSubKey { sks := c.subKeyMap[key] if slices.Index(sks, subKey) < 0 { if c.subKeyMap == nil { - c.subKeyMap = make(map[queryCacheKey][]QueryCacheSubKey) + c.subKeyMap = make(map[QueryCacheItemKey][]QueryCacheSubKey) } c.subKeyMap[key] = append(sks, subKey) } @@ -224,7 +224,7 @@ func (c *queryCache) IsCached() bool { func (c *queryCache) GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) { @@ -251,7 +251,7 @@ func (c *queryCache) GetValue( return v, err } -func (c *queryCache) UpdateSlice(key queryCacheKey, update SliceAppender) { +func (c *queryCache) UpdateSlice(key QueryCacheItemKey, update SliceAppender) { if c == nil { return } diff --git a/sql/recovery/recovery_test.go b/sql/recovery/recovery_test.go index 976091e0777..7f4e396925b 100644 --- a/sql/recovery/recovery_test.go +++ b/sql/recovery/recovery_test.go @@ -6,12 +6,12 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRecoveryInfo(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() restore := types.LayerID(12) got, err := recovery.CheckpointInfo(db) diff --git a/sql/rewards/rewards_test.go b/sql/rewards/rewards_test.go index 5cd5715a530..71fd23e090a 100644 --- a/sql/rewards/rewards_test.go +++ b/sql/rewards/rewards_test.go @@ -6,13 +6,15 @@ import ( "testing" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewards(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var part uint64 = math.MaxUint64 / 2 lyrReward := part / 2 @@ -199,13 +201,17 @@ func TestRewards(t *testing.T) { } func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), ) // verify that the DB is empty @@ -217,7 +223,7 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db, zaptest.NewLogger(t)) require.NoError(t, err) // verify that db is still empty @@ -230,13 +236,19 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { } func Test_0008Migration(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true), + sql.WithAllowSchemaDrift(true), ) // verify that the DB is empty @@ -279,7 +291,7 @@ func Test_0008Migration(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db, zaptest.NewLogger(t)) require.NoError(t, err) // verify that one row is still present diff --git a/sql/schema.go b/sql/schema.go new file mode 100644 index 00000000000..cf125f162ea --- /dev/null +++ b/sql/schema.go @@ -0,0 +1,215 @@ +package sql + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" +) + +const ( + SchemaPath = "schema/schema.sql" + UpdatedSchemaPath = "schema/schema.sql.updated" +) + +// LoadDBSchemaScript retrieves the database schema as text. +func LoadDBSchemaScript(db Executor) (string, error) { + var ( + err error + sb strings.Builder + ) + version, err := version(db) + if err != nil { + return "", err + } + fmt.Fprintf(&sb, "PRAGMA user_version = %d;\n", version) + if _, err = db.Exec( + // Type is either 'index' or 'table', we want tables to go first + `select tbl_name, sql || ';' from sqlite_master + where sql is not null + order by tbl_name, type desc, name`, + nil, func(st *Statement) bool { + fmt.Fprintln(&sb, st.ColumnText(1)) + return true + }); err != nil { + return "", fmt.Errorf("error retrieving DB schema: %w", err) + } + // On Windows, the result contains extra carriage returns + return strings.ReplaceAll(sb.String(), "\r", ""), nil +} + +// Schema represents database schema. +type Schema struct { + Script string + Migrations MigrationList + skipMigration map[int]struct{} +} + +// Diff diffs the database schema against the actual schema. +// If there's no differences, it returns an empty string. +func (s *Schema) Diff(actualScript string) string { + return cmp.Diff(s.Script, actualScript) +} + +// WriteToFile writes the schema to the corresponding updated schema file. +func (s *Schema) WriteToFile(basedir string) error { + path := filepath.Join(basedir, UpdatedSchemaPath) + if err := os.WriteFile(path, []byte(s.Script), 0o777); err != nil { + return fmt.Errorf("error writing schema file %s: %w", path, err) + } + return nil +} + +// SkipMigrations skips the specified migrations. +func (s *Schema) SkipMigrations(i ...int) { + if s.skipMigration == nil { + s.skipMigration = make(map[int]struct{}) + } + for _, index := range i { + s.skipMigration[index] = struct{}{} + } +} + +// Apply applies the schema to the database. +func (s *Schema) Apply(db Database) error { + return db.WithTx(context.Background(), func(tx Transaction) error { + scanner := bufio.NewScanner(strings.NewReader(s.Script)) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if i := bytes.Index(data, []byte(";")); i >= 0 { + return i + 1, data[0 : i+1], nil + } + return 0, nil, nil + }) + for scanner.Scan() { + if _, err := tx.Exec(scanner.Text(), nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", scanner.Text(), err) + } + } + return nil + }) +} + +func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after int, err error) { + if len(s.Migrations) == 0 { + return 0, 0, nil + } + before, err = version(db) + if err != nil { + return 0, 0, err + } + after = s.Migrations.Version() + if before > after { + logger.Error("database version is newer than expected - downgrade is not supported", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return before, after, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + } + + return before, after, nil +} + +// Migrate performs database migration. In case if migrations are disabled, the database +// version is checked but no migrations are run, and if the database is too old and +// migrations are disabled, an error is returned. +func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState int) error { + for i, m := range s.Migrations { + if m.Order() <= before { + continue + } + if err := db.WithTx(context.Background(), func(tx Transaction) error { + if _, ok := s.skipMigration[m.Order()]; !ok { + if err := m.Apply(tx, logger); err != nil { + for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { + if e := s.Migrations[j].Rollback(); e != nil { + err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) + break + } + } + + return fmt.Errorf("apply %s: %w", m.Name(), err) + } + } + // version is set intentionally even if actual migration was skipped + if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { + return fmt.Errorf("update user_version to %d: %w", m.Order(), err) + } + return nil + }); err != nil { + err = errors.Join(err, db.Close()) + return err + } + + if vacuumState != 0 && before <= vacuumState { + if err := Vacuum(db); err != nil { + err = errors.Join(err, db.Close()) + return err + } + } + before = m.Order() + } + return nil +} + +// SchemaGenOpt represents a schema generator option. +type SchemaGenOpt func(g *SchemaGen) + +func withDefaultOut(w io.Writer) SchemaGenOpt { + return func(g *SchemaGen) { + g.defaultOut = w + } +} + +// SchemaGen generates database schema files. +type SchemaGen struct { + logger *zap.Logger + schema *Schema + defaultOut io.Writer +} + +// NewSchemaGen creates a new SchemaGen instance. +func NewSchemaGen(logger *zap.Logger, schema *Schema, opts ...SchemaGenOpt) *SchemaGen { + g := &SchemaGen{logger: logger, schema: schema, defaultOut: os.Stdout} + for _, opt := range opts { + opt(g) + } + return g +} + +// Generate generates database schema and writes it to the specified file. +// If an empty string is specified as outputFile, os.Stdout is used for output. +func (g *SchemaGen) Generate(outputFile string) error { + db, err := OpenInMemory( + WithLogger(g.logger), + WithDatabaseSchema(g.schema), + WithForceMigrations(true), + WithIgnoreSchemaDrift()) + if err != nil { + return fmt.Errorf("error opening in-memory db: %w", err) + } + defer func() { + if err := db.Close(); err != nil { + g.logger.Error("error closing in-memory db: %w", zap.Error(err)) + } + }() + loadedScript, err := LoadDBSchemaScript(db) + if err != nil { + return fmt.Errorf("error loading DB schema script: %w", err) + } + if outputFile == "" { + if _, err := io.WriteString(g.defaultOut, loadedScript); err != nil { + return fmt.Errorf("error writing schema file: %w", err) + } + } else if err := os.WriteFile(outputFile, []byte(loadedScript), 0o777); err != nil { + return fmt.Errorf("error writing schema file %q: %w", outputFile, err) + } + return nil +} diff --git a/sql/schema_test.go b/sql/schema_test.go new file mode 100644 index 00000000000..f9ccc083fe1 --- /dev/null +++ b/sql/schema_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "os" + "path/filepath" + "strings" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" +) + +func TestSchemaGen(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + fs := fstest.MapFS{ + "schema/migrations/0001_first.sql": &fstest.MapFile{ + Data: []byte("create table foo(id int);"), + }, + "schema/migrations/0002_second.sql": &fstest.MapFile{ + Data: []byte("create table bar(id int);"), + }, + } + migrations, err := LoadSQLMigrations(fs) + require.NoError(t, err) + require.Len(t, migrations, 2) + schema := &Schema{ + Script: "this should not be run", + Migrations: migrations, + } + var sb strings.Builder + g := NewSchemaGen(logger, schema, withDefaultOut(&sb)) + tempDir := t.TempDir() + schemaPath := filepath.Join(tempDir, "schema.sql") + require.NoError(t, g.Generate(schemaPath)) + contents, err := os.ReadFile(schemaPath) + require.NoError(t, err) + require.Equal(t, + "PRAGMA user_version = 2;\nCREATE TABLE bar(id int);\nCREATE TABLE foo(id int);\n", + string(contents)) + require.NoError(t, g.Generate("")) + require.Equal(t, string(contents), sb.String()) + + require.Equal(t, 0, observedLogs.Len(), + "expected 0 warning messages in the log (schema drift warnings?)") +} diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go new file mode 100644 index 00000000000..c55d2b3ad46 --- /dev/null +++ b/sql/schemagen/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +var ( + level = zap.LevelFlag("level", zapcore.ErrorLevel, "set log verbosity level") + dbType = flag.String("dbtype", "state", "database type (state, local, default state)") + output = flag.String("output", "", "output file (defaults to stdout)") +) + +func main() { + var ( + err error + schema *sql.Schema + ) + flag.Parse() + core := zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()), + os.Stderr, + zap.NewAtomicLevelAt(*level), + ) + logger := zap.New(core).With(zap.String("dbType", *dbType)) + switch *dbType { + case "state": + schema, err = statesql.Schema() + case "local": + schema, err = localsql.Schema() + default: + logger.Fatal("unknown database type, must be state or local") + } + if err != nil { + logger.Fatal("error loading db schema", zap.Error(err)) + } + g := sql.NewSchemaGen(logger, schema) + if err := g.Generate(*output); err != nil { + logger.Fatal("error generating schema", zap.Error(err), zap.String("output", *output)) + } +} diff --git a/sql/migrations/state/0001_initial.sql b/sql/statesql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/state/0001_initial.sql rename to sql/statesql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/state/0002_v1.0.3.sql b/sql/statesql/schema/migrations/0002_v1.0.3.sql similarity index 100% rename from sql/migrations/state/0002_v1.0.3.sql rename to sql/statesql/schema/migrations/0002_v1.0.3.sql diff --git a/sql/migrations/state/0003_v1.1.5.sql b/sql/statesql/schema/migrations/0003_v1.1.5.sql similarity index 100% rename from sql/migrations/state/0003_v1.1.5.sql rename to sql/statesql/schema/migrations/0003_v1.1.5.sql diff --git a/sql/migrations/state/0004_v1.1.7.sql b/sql/statesql/schema/migrations/0004_v1.1.7.sql similarity index 100% rename from sql/migrations/state/0004_v1.1.7.sql rename to sql/statesql/schema/migrations/0004_v1.1.7.sql diff --git a/sql/migrations/state/0005_v1.2.0.sql b/sql/statesql/schema/migrations/0005_v1.2.0.sql similarity index 100% rename from sql/migrations/state/0005_v1.2.0.sql rename to sql/statesql/schema/migrations/0005_v1.2.0.sql diff --git a/sql/migrations/state/0006_v1.2.2.sql b/sql/statesql/schema/migrations/0006_v1.2.2.sql similarity index 100% rename from sql/migrations/state/0006_v1.2.2.sql rename to sql/statesql/schema/migrations/0006_v1.2.2.sql diff --git a/sql/migrations/state/0007_v1.3.0.sql b/sql/statesql/schema/migrations/0007_v1.3.0.sql similarity index 100% rename from sql/migrations/state/0007_v1.3.0.sql rename to sql/statesql/schema/migrations/0007_v1.3.0.sql diff --git a/sql/migrations/state/0008_rewards.sql b/sql/statesql/schema/migrations/0008_rewards.sql similarity index 100% rename from sql/migrations/state/0008_rewards.sql rename to sql/statesql/schema/migrations/0008_rewards.sql diff --git a/sql/migrations/state/0009_prune_activesets.sql b/sql/statesql/schema/migrations/0009_prune_activesets.sql similarity index 100% rename from sql/migrations/state/0009_prune_activesets.sql rename to sql/statesql/schema/migrations/0009_prune_activesets.sql diff --git a/sql/migrations/state/0010_rowid.sql b/sql/statesql/schema/migrations/0010_rowid.sql similarity index 100% rename from sql/migrations/state/0010_rowid.sql rename to sql/statesql/schema/migrations/0010_rowid.sql diff --git a/sql/migrations/state/0011_atxs_extra_index.sql b/sql/statesql/schema/migrations/0011_atxs_extra_index.sql similarity index 100% rename from sql/migrations/state/0011_atxs_extra_index.sql rename to sql/statesql/schema/migrations/0011_atxs_extra_index.sql diff --git a/sql/migrations/state/0012_atx_validity.sql b/sql/statesql/schema/migrations/0012_atx_validity.sql similarity index 100% rename from sql/migrations/state/0012_atx_validity.sql rename to sql/statesql/schema/migrations/0012_atx_validity.sql diff --git a/sql/migrations/state/0013_atx_coinbase_index.sql b/sql/statesql/schema/migrations/0013_atx_coinbase_index.sql similarity index 100% rename from sql/migrations/state/0013_atx_coinbase_index.sql rename to sql/statesql/schema/migrations/0013_atx_coinbase_index.sql diff --git a/sql/migrations/state/0014_remove_proposals.sql b/sql/statesql/schema/migrations/0014_remove_proposals.sql similarity index 100% rename from sql/migrations/state/0014_remove_proposals.sql rename to sql/statesql/schema/migrations/0014_remove_proposals.sql diff --git a/sql/migrations/state/0015_nonce_index.sql b/sql/statesql/schema/migrations/0015_nonce_index.sql similarity index 100% rename from sql/migrations/state/0015_nonce_index.sql rename to sql/statesql/schema/migrations/0015_nonce_index.sql diff --git a/sql/migrations/state/0016_atx_blob.sql b/sql/statesql/schema/migrations/0016_atx_blob.sql similarity index 100% rename from sql/migrations/state/0016_atx_blob.sql rename to sql/statesql/schema/migrations/0016_atx_blob.sql diff --git a/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql b/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql similarity index 100% rename from sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql rename to sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql diff --git a/sql/migrations/state/0018_atx_blob_version.sql b/sql/statesql/schema/migrations/0018_atx_blob_version.sql similarity index 100% rename from sql/migrations/state/0018_atx_blob_version.sql rename to sql/statesql/schema/migrations/0018_atx_blob_version.sql diff --git a/sql/migrations/state/0019_marriages.sql b/sql/statesql/schema/migrations/0019_marriages.sql similarity index 100% rename from sql/migrations/state/0019_marriages.sql rename to sql/statesql/schema/migrations/0019_marriages.sql diff --git a/sql/migrations/state/0020_atx_merge.sql b/sql/statesql/schema/migrations/0020_atx_merge.sql similarity index 100% rename from sql/migrations/state/0020_atx_merge.sql rename to sql/statesql/schema/migrations/0020_atx_merge.sql diff --git a/sql/migrations/state/0021_atx_posts.sql b/sql/statesql/schema/migrations/0021_atx_posts.sql similarity index 100% rename from sql/migrations/state/0021_atx_posts.sql rename to sql/statesql/schema/migrations/0021_atx_posts.sql diff --git a/sql/statesql/schema/migrations/0022_schema_cleanup.sql b/sql/statesql/schema/migrations/0022_schema_cleanup.sql new file mode 100644 index 00000000000..215bca7adcd --- /dev/null +++ b/sql/statesql/schema/migrations/0022_schema_cleanup.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS _litestream_seq; +DROP TABLE IF EXISTS _litestream_lock; diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql new file mode 100755 index 00000000000..3eee3106003 --- /dev/null +++ b/sql/statesql/schema/schema.sql @@ -0,0 +1,159 @@ +PRAGMA user_version = 22; +CREATE TABLE accounts +( + address CHAR(24), + balance UNSIGNED LONG INT, + next_nonce UNSIGNED LONG INT, + layer_updated UNSIGNED LONG INT, + template CHAR(24), + state BLOB, + PRIMARY KEY (address, layer_updated DESC) +); +CREATE INDEX accounts_by_layer_updated ON accounts (layer_updated); +CREATE TABLE activesets +( + id CHAR(32) PRIMARY KEY, + active_set BLOB +, epoch INT DEFAULT 0 NOT NULL) WITHOUT ROWID; +CREATE INDEX activesets_by_epoch ON activesets (epoch asc); +CREATE TABLE atx_blobs +( + id CHAR(32), + atx BLOB +, version INTEGER); +CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); +CREATE TABLE atxs +( + id CHAR(32), + epoch INT NOT NULL, + effective_num_units INT NOT NULL, + commitment_atx CHAR(32), + nonce UNSIGNED LONG INT, + base_tick_height UNSIGNED LONG INT, + tick_count UNSIGNED LONG INT, + sequence UNSIGNED LONG INT, + pubkey CHAR(32), + coinbase CHAR(24), + received INT NOT NULL, + validity INTEGER DEFAULT false +, marriage_atx CHAR(32), weight INTEGER); +CREATE INDEX atxs_by_coinbase ON atxs (coinbase); +CREATE INDEX atxs_by_epoch_by_pubkey ON atxs (epoch, pubkey); +CREATE INDEX atxs_by_epoch_by_pubkey_nonce ON atxs (pubkey, epoch desc, nonce) WHERE nonce IS NOT NULL; +CREATE INDEX atxs_by_epoch_id on atxs (epoch, id); +CREATE INDEX atxs_by_pubkey_by_epoch_desc ON atxs (pubkey, epoch desc); +CREATE UNIQUE INDEX atxs_id ON atxs (id); +CREATE TABLE ballots +( + id CHAR(20) PRIMARY KEY, + atx CHAR(32) NOT NULL, + layer INT NOT NULL, + pubkey VARCHAR, + ballot BLOB +); +CREATE INDEX ballots_by_atx_by_layer ON ballots (atx, layer asc); +CREATE INDEX ballots_by_layer_by_pubkey ON ballots (layer asc, pubkey); +CREATE TABLE beacons +( + epoch INT NOT NULL PRIMARY KEY, + beacon CHAR(4) +) WITHOUT ROWID; +CREATE TABLE block_transactions +( + tid CHAR(32), + bid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, bid) +) WITHOUT ROWID; +CREATE TABLE blocks +( + id CHAR(20) PRIMARY KEY, + layer INT NOT NULL, + validity SMALL INT, + block BLOB +); +CREATE INDEX blocks_by_layer ON blocks (layer, id asc); +CREATE TABLE certificates +( + layer INT NOT NULL, + block VARCHAR NOT NULL, + cert BLOB, + valid bool NOT NULL, + PRIMARY KEY (layer, block) +); +CREATE TABLE identities +( + pubkey VARCHAR PRIMARY KEY, + proof BLOB +, received INT DEFAULT 0 NOT NULL, marriage_atx CHAR(32), marriage_idx INTEGER, marriage_target CHAR(32), marriage_signature CHAR(64)) WITHOUT ROWID; +CREATE TABLE layers +( + id INT PRIMARY KEY DESC, + weak_coin SMALL INT, + processed SMALL INT, + applied_block VARCHAR, + state_hash CHAR(32), + aggregated_hash CHAR(32) +) WITHOUT ROWID; +CREATE INDEX layers_by_processed ON layers (processed); +CREATE TABLE poets +( + ref VARCHAR PRIMARY KEY, + poet BLOB, + service_id VARCHAR, + round_id VARCHAR +); +CREATE INDEX poets_by_service_id_by_round_id ON poets (service_id, round_id); +CREATE TABLE posts ( + atxid CHAR(32) NOT NULL, + pubkey CHAR(32) NOT NULL, + prev_atxid CHAR(32), + prev_atx_index INT, + units INT NOT NULL, + UNIQUE (atxid, pubkey) + ); +CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey, prev_atxid); +CREATE TABLE proposal_transactions +( + tid CHAR(32), + pid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, pid) +) WITHOUT ROWID; +CREATE TABLE recovery +( + id INTEGER PRIMARY KEY CHECK (id = 1), + restore INT NOT NULL +); +CREATE TABLE rewards +( + pubkey CHAR(32), + coinbase CHAR(24) NOT NULL, + layer INT NOT NULL, + total_reward UNSIGNED LONG INT, + layer_reward UNSIGNED LONG INT, + PRIMARY KEY (pubkey, layer) +); +CREATE INDEX rewards_by_coinbase ON rewards (coinbase, layer); +CREATE INDEX rewards_by_layer ON rewards (layer asc); +CREATE TABLE transactions +( + id CHAR(32) PRIMARY KEY, + tx BLOB, + header BLOB, + result BLOB, + layer INT, + block CHAR(20), + principal CHAR(24), + nonce BLOB, + timestamp INT NOT NULL +) WITHOUT ROWID; +CREATE INDEX transaction_by_layer_principal ON transactions (layer asc, principal); +CREATE INDEX transaction_by_principal_nonce ON transactions (principal, nonce); +CREATE TABLE transactions_results_addresses +( + address CHAR(24), + tid CHAR(32), + PRIMARY KEY (tid, address) +) WITHOUT ROWID; +CREATE INDEX transactions_results_addresses_by_address ON transactions_results_addresses(address); diff --git a/sql/migrations/state_0021_migration.go b/sql/statesql/state_0021_migration.go similarity index 91% rename from sql/migrations/state_0021_migration.go rename to sql/statesql/state_0021_migration.go index 221289788b2..ec5d5ae9fe0 100644 --- a/sql/migrations/state_0021_migration.go +++ b/sql/statesql/state_0021_migration.go @@ -1,4 +1,4 @@ -package migrations +package statesql import ( "errors" @@ -14,14 +14,14 @@ import ( ) type migration0021 struct { - batch int - logger *zap.Logger + batch int } -func New0021Migration(log *zap.Logger, batch int) *migration0021 { +var _ sql.Migration = &migration0021{} + +func New0021Migration(batch int) *migration0021 { return &migration0021{ - logger: log, - batch: batch, + batch: batch, } } @@ -37,7 +37,7 @@ func (*migration0021) Rollback() error { return nil } -func (m *migration0021) Apply(db sql.Executor) error { +func (m *migration0021) Apply(db sql.Executor, logger *zap.Logger) error { if err := m.applySql(db); err != nil { return err } @@ -49,7 +49,7 @@ func (m *migration0021) Apply(db sql.Executor) error { if err != nil { return fmt.Errorf("counting all ATXs %w", err) } - m.logger.Info("applying migration 21", zap.Int("total", total)) + logger.Info("applying migration 21", zap.Int("total", total)) for offset := 0; ; offset += m.batch { n, err := m.processBatch(db, offset, m.batch) @@ -59,7 +59,7 @@ func (m *migration0021) Apply(db sql.Executor) error { processed := offset + n progress := float64(processed) * 100.0 / float64(total) - m.logger.Info("processed ATXs", zap.Float64("progress [%]", progress)) + logger.Info("processed ATXs", zap.Float64("progress [%]", progress)) if processed >= total { return nil } diff --git a/sql/migrations/state_0021_migration_test.go b/sql/statesql/state_0021_migration_test.go similarity index 57% rename from sql/migrations/state_0021_migration_test.go rename to sql/statesql/state_0021_migration_test.go index 76f50d6d1c1..65edcbd3c62 100644 --- a/sql/migrations/state_0021_migration_test.go +++ b/sql/statesql/state_0021_migration_test.go @@ -1,7 +1,7 @@ -package migrations +package statesql import ( - "strings" + "slices" "testing" "github.com/stretchr/testify/require" @@ -15,42 +15,21 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" ) -// Test that in-code migration results in the same schema as the .sql one. -func Test0021Migration_CompatibleSchema(t *testing.T) { - db := sql.InMemory( - sql.WithLogger(zaptest.NewLogger(t)), - sql.WithMigration(New0021Migration(zaptest.NewLogger(t), 1000)), - ) - - var schemasInCode []string - _, err := db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - schemasInCode = append(schemasInCode, sql) - return true - }) +func Test0021Migration(t *testing.T) { + schema, err := Schema() require.NoError(t, err) - require.NoError(t, db.Close()) - - db = sql.InMemory() - - var schemasInFile []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - schemasInFile = append(schemasInFile, sql) - return true + schema.Migrations = slices.DeleteFunc(schema.Migrations, func(m sql.Migration) bool { + if m.Order() != 21 { + t.Logf("QQQQQ: include migration %d -- %s", m.Order(), m.Name()) + } + return m.Order() == 21 }) - require.NoError(t, err) - require.NoError(t, db.Close()) - require.Equal(t, schemasInFile, schemasInCode) -} - -func Test0021Migration(t *testing.T) { db := sql.InMemory( sql.WithLogger(zaptest.NewLogger(t)), - sql.WithSkipMigrations(21), + sql.WithDatabaseSchema(schema), + sql.WithIgnoreSchemaDrift(), + sql.WithForceMigrations(true), ) var signers [177]*signing.EdSigner @@ -82,9 +61,9 @@ func Test0021Migration(t *testing.T) { } } - m := New0021Migration(zaptest.NewLogger(t), 1000) + m := New0021Migration(1000) require.Equal(t, 21, m.Order()) - require.NoError(t, m.Apply(db)) + require.NoError(t, m.Apply(db, zaptest.NewLogger(t))) for _, posts := range allPosts { for atx, post := range posts { diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go new file mode 100644 index 00000000000..b755237fb6d --- /dev/null +++ b/sql/statesql/statesql.go @@ -0,0 +1,79 @@ +package statesql + +import ( + "embed" + "strings" + "testing" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate go run ../schemagen -dbtype state -output schema/schema.sql + +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS + +type database struct { + sql.Database +} + +var _ sql.StateDatabase = &database{} + +func (db *database) IsStateDatabase() {} + +// Schema returns the schema for the state database. +func Schema() (*sql.Schema, error) { + sqlMigrations, err := sql.LoadSQLMigrations(migrations) + if err != nil { + return nil, err + } + sqlMigrations = sqlMigrations.AddMigration(New0021Migration(1_000_000)) + // NOTE: coded state migrations can be added here + // They can be a part of this localsql package + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil +} + +// Open opens a state database. +func Open(uri string, opts ...sql.Opt) (sql.StateDatabase, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + opts = append([]sql.Opt{sql.WithDatabaseSchema(schema)}, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &database{Database: db}, nil +} + +// Open opens an in-memory state database. +func InMemory(opts ...sql.Opt) sql.StateDatabase { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &database{Database: db} +} + +// InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. +func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.StateDatabase { + opts = append(opts, sql.WithConnections(1)) + db, err := Open("file::memory:?mode=memory", opts...) + if err != nil { + panic(err) + } + tb.Cleanup(func() { db.Close() }) + return db +} diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go new file mode 100644 index 00000000000..207d7f2a5a2 --- /dev/null +++ b/sql/statesql/statesql_test.go @@ -0,0 +1,74 @@ +package statesql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + // "running migrations" + "applying migration 21" + "processed ATXs" + require.Equal(t, 3, observedLogs.Len(), "expected count of log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) + + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 3, observedLogs.Len(), "expected 1 log message") +} diff --git a/sql/transactions/iterator_test.go b/sql/transactions/iterator_test.go index fd3cad3a8e0..334cef2a97e 100644 --- a/sql/transactions/iterator_test.go +++ b/sql/transactions/iterator_test.go @@ -15,6 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func matchTx(tx types.TransactionWithResult, filter ResultsFilter) bool { @@ -59,11 +60,11 @@ func filterTxs(txs []types.TransactionWithResult, filter ResultsFilter) []types. } func TestIterateResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() gen := fixture.NewTransactionResultGenerator() txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -142,12 +143,12 @@ func TestIterateResults(t *testing.T) { } func TestIterateSnapshot(t *testing.T) { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) gen := fixture.NewTransactionResultGenerator() expect := 10 - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { for i := 0; i < expect; i++ { tx := gen.Next() @@ -175,7 +176,7 @@ func TestIterateSnapshot(t *testing.T) { }() <-initialized - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := 0; i < 10; i++ { tx := gen.Next() diff --git a/sql/transactions/transactions.go b/sql/transactions/transactions.go index f0830e8e3ad..b1e23879708 100644 --- a/sql/transactions/transactions.go +++ b/sql/transactions/transactions.go @@ -28,8 +28,8 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { if _, err = db.Exec(` insert into transactions (id, tx, header, principal, nonce, timestamp) values (?1, ?2, ?3, ?4, ?5, ?6) - on conflict(id) do update set - header=?3, principal=?4, nonce=?5 + on conflict(id) do update set + header=?3, principal=?4, nonce=?5 where header is null;`, func(stmt *sql.Statement) { stmt.BindBytes(1, tx.ID.Bytes()) @@ -49,7 +49,7 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { // AddToProposal associates a transaction with a proposal. func AddToProposal(db sql.Executor, tid types.TransactionID, lid types.LayerID, pid types.ProposalID) error { if _, err := db.Exec(` - insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) + insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) on conflict(tid, pid) do nothing;`, func(stmt *sql.Statement) { stmt.BindBytes(1, pid.Bytes()) @@ -136,8 +136,8 @@ func GetAppliedLayer(db sql.Executor, tid types.TransactionID) (types.LayerID, e } // UndoLayers unset all transactions to `statePending` from `from` layer to the max layer with applied transactions. -func UndoLayers(db *sql.Tx, from types.LayerID) error { - _, err := db.Exec(`delete from transactions_results_addresses +func UndoLayers(tx sql.Transaction, from types.LayerID) error { + _, err := tx.Exec(`delete from transactions_results_addresses where tid in (select id from transactions where layer >= ?1);`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -145,8 +145,8 @@ func UndoLayers(db *sql.Tx, from types.LayerID) error { if err != nil { return fmt.Errorf("delete addresses mapping %w", err) } - _, err = db.Exec(`update transactions - set layer = null, block = null, result = null + _, err = tx.Exec(`update transactions + set layer = null, block = null, result = null where layer >= ?1`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -287,7 +287,7 @@ func AddressesWithPendingTransactions(db sql.Executor) ([]types.AddressNonce, er // GetAcctPendingFromNonce get all pending transactions with nonce after `from` for the given address. func GetAcctPendingFromNonce(db sql.Executor, address types.Address, from uint64) ([]*types.MeshTransaction, error) { return queryPending(db, `select tx, header, layer, block, timestamp, id from transactions - where principal = ?1 and nonce >= ?2 and result is null + where principal = ?1 and nonce >= ?2 and result is null order by nonce asc, timestamp asc`, func(stmt *sql.Statement) { stmt.BindBytes(1, address.Bytes()) @@ -321,14 +321,14 @@ func queryPending( } // AddResult adds result for the transaction. -func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) error { +func AddResult(tx sql.Transaction, id types.TransactionID, rst *types.TransactionResult) error { buf, err := codec.Encode(rst) if err != nil { return fmt.Errorf("encode %w", err) } - if rows, err := db.Exec(`update transactions - set result = ?2, layer = ?3, block = ?4 + if rows, err := tx.Exec(`update transactions + set result = ?2, layer = ?3, block = ?4 where id = ?1 and result is null returning id;`, func(stmt *sql.Statement) { stmt.BindBytes(1, id[:]) @@ -345,7 +345,7 @@ func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) return fmt.Errorf("invalid state for %s", id) } for i := range rst.Addresses { - if _, err := db.Exec(`insert into transactions_results_addresses + if _, err := tx.Exec(`insert into transactions_results_addresses (address, tid) values (?1, ?2);`, func(stmt *sql.Statement) { stmt.BindBytes(1, rst.Addresses[i][:]) @@ -418,7 +418,7 @@ func IterateTransactionsOps( fn func(tx *types.MeshTransaction, result *types.TransactionResult) bool, ) error { var derr error - _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result + _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result from transactions left join transactions_results_addresses on id=tid`+builder.FilterFrom(operations), builder.BindingsFrom(operations), diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index ed2eef3cc95..4711e9fba2e 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -70,7 +71,7 @@ func checkMeshTXEqual(t *testing.T, expected, got types.MeshTransaction) { } func TestAddGetHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -109,7 +110,7 @@ func TestAddGetHas(t *testing.T) { } func TestAddUpdatesHeader(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() txs := []*types.Transaction{ { RawTx: types.NewRawTx([]byte{1, 2, 3}), @@ -142,7 +143,7 @@ func TestAddUpdatesHeader(t *testing.T) { } func TestAddToProposal(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -166,7 +167,7 @@ func TestAddToProposal(t *testing.T) { } func TestDeleteProposalTxs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := map[types.LayerID][]types.ProposalID{ types.LayerID(10): {{1, 1}, {1, 2}}, types.LayerID(11): {{2, 1}, {2, 2}}, @@ -197,7 +198,7 @@ func TestDeleteProposalTxs(t *testing.T) { } func TestAddToBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -221,7 +222,7 @@ func TestAddToBlock(t *testing.T) { } func TestApply_AlreadyApplied(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) lid := types.LayerID(10) @@ -231,17 +232,17 @@ func TestApply_AlreadyApplied(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // same block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // different block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult( dtx, tx.ID, @@ -251,15 +252,15 @@ func TestApply_AlreadyApplied(t *testing.T) { } func TestUndoLayers_Empty(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, types.LayerID(199)) })) } func TestApplyAndUndoLayers(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) firstLayer := types.LayerID(10) @@ -272,7 +273,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) applied = append(applied, tx.ID) @@ -284,7 +285,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.Equal(t, types.APPLIED, mtx.State) } // revert to firstLayer - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, firstLayer.Add(1)) })) @@ -300,7 +301,7 @@ func TestApplyAndUndoLayers(t *testing.T) { } func TestGetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() rng := rand.New(rand.NewSource(1001)) @@ -333,7 +334,7 @@ func TestGetBlob(t *testing.T) { } func TestGetByAddress(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -348,7 +349,7 @@ func TestGetByAddress(t *testing.T) { createTX(t, signer1, signer2Address, 1, 191, 1), } received := time.Now() - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tx := range txs { require.NoError(t, transactions.Add(dbtx, tx, received)) require.NoError(t, transactions.AddResult(dbtx, tx.ID, &types.TransactionResult{Layer: lid})) @@ -369,7 +370,7 @@ func TestGetByAddress(t *testing.T) { } func TestGetAcctPendingFromNonce(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -404,7 +405,7 @@ func TestGetAcctPendingFromNonce(t *testing.T) { } func TestAppliedLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) require.NoError(t, err) @@ -417,7 +418,7 @@ func TestAppliedLayer(t *testing.T) { for _, tx := range txs { require.NoError(t, transactions.Add(db, tx, time.Now())) } - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, txs[0].ID, &types.TransactionResult{Layer: lid, Block: types.BlockID{1, 1}}) })) @@ -428,7 +429,7 @@ func TestAppliedLayer(t *testing.T) { _, err = transactions.GetAppliedLayer(db, txs[1].ID) require.ErrorIs(t, err, sql.ErrNotFound) - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, lid) })) _, err = transactions.GetAppliedLayer(db, txs[0].ID) @@ -455,7 +456,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { TxHeader: &types.TxHeader{Principal: principals[1], Nonce: 0}, }, } - db := sql.InMemory() + db := statesql.InMemory() for _, tx := range txs { require.NoError(t, transactions.Add(db, &tx, time.Time{})) } @@ -465,7 +466,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[0].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[0].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -474,7 +475,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[1].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[2].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -520,7 +521,7 @@ func TestTransactionInProposal(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToProposal(db, tid, lids[i], pids[i])) } @@ -546,7 +547,7 @@ func TestTransactionInBlock(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToBlock(db, tid, lids[i], bids[i])) } diff --git a/sql/vacuum_test.go b/sql/vacuum_test.go index b994516279e..1a89158c644 100644 --- a/sql/vacuum_test.go +++ b/sql/vacuum_test.go @@ -7,6 +7,6 @@ import ( ) func TestVacuumDB(t *testing.T) { - db := InMemory() + db := InMemory(WithIgnoreSchemaDrift()) require.NoError(t, Vacuum(db)) } diff --git a/syncer/atxsync/atxsync.go b/syncer/atxsync/atxsync.go index a9637ba4eb5..1cee0fa194c 100644 --- a/syncer/atxsync/atxsync.go +++ b/syncer/atxsync/atxsync.go @@ -15,7 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/system" ) -func getMissing(db *sql.Database, set []types.ATXID) ([]types.ATXID, error) { +func getMissing(db sql.StateDatabase, set []types.ATXID) ([]types.ATXID, error) { missing := []types.ATXID{} for _, atx := range set { exist, err := atxs.Has(db, atx) @@ -36,7 +36,7 @@ func Download( ctx context.Context, retryInterval time.Duration, logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, fetcher system.AtxFetcher, set []types.ATXID, ) error { diff --git a/syncer/atxsync/atxsync_test.go b/syncer/atxsync/atxsync_test.go index 5713ee6aba4..126f5ae130a 100644 --- a/syncer/atxsync/atxsync_test.go +++ b/syncer/atxsync/atxsync_test.go @@ -11,8 +11,8 @@ import ( "go.uber.org/zap/zaptest" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -97,7 +97,7 @@ func TestDownload(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { logger := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) fetcher := mocks.NewMockAtxFetcher(ctrl) for _, atx := range tc.existing { diff --git a/syncer/atxsync/syncer.go b/syncer/atxsync/syncer.go index 3fc77c805e0..8e276d2e437 100644 --- a/syncer/atxsync/syncer.go +++ b/syncer/atxsync/syncer.go @@ -17,7 +17,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -76,7 +75,7 @@ func WithConfig(cfg Config) Opt { } } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -95,7 +94,7 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase } func (s *Syncer) Download(parent context.Context, publish types.EpochID, downloadUntil time.Time) error { @@ -324,7 +323,7 @@ func (s *Syncer) downloadAtxs( } } - if err := s.localdb.WithTx(context.Background(), func(tx *sql.Tx) error { + if err := s.localdb.WithTx(context.Background(), func(tx sql.Transaction) error { err := atxsync.SaveRequest(tx, publish, lastSuccess, int64(len(state)), int64(len(downloaded))) if err != nil { return fmt.Errorf("failed to save request time: %w", err) diff --git a/syncer/atxsync/syncer_test.go b/syncer/atxsync/syncer_test.go index 86009cc8a6b..fb75373b1b9 100644 --- a/syncer/atxsync/syncer_test.go +++ b/syncer/atxsync/syncer_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/atxsync/mocks" "github.com/spacemeshos/go-spacemesh/system" ) @@ -42,7 +43,7 @@ func edata(ids ...string) *fetch.EpochData { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) syncer := New(fetcher, db, localdb, WithConfig(cfg), WithLogger(zaptest.NewLogger(tb))) @@ -60,8 +61,8 @@ func newTester(tb testing.TB, cfg Config) *tester { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *sql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher diff --git a/syncer/find_fork_test.go b/syncer/find_fork_test.go index 2759b6145b4..5fe14fb1ee1 100644 --- a/syncer/find_fork_test.go +++ b/syncer/find_fork_test.go @@ -19,19 +19,20 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/mocks" ) type testForkFinder struct { *syncer.ForkFinder - db *sql.Database + db sql.StateDatabase mFetcher *mocks.Mockfetcher } func newTestForkFinderWithDuration(t *testing.T, d time.Duration, lg *zap.Logger) *testForkFinder { mf := mocks.NewMockfetcher(gomock.NewController(t)) - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, layers.SetMeshHash(db, types.GetEffectiveGenesis(), types.RandomHash())) return &testForkFinder{ ForkFinder: syncer.NewForkFinder(lg, db, mf, d), @@ -88,7 +89,7 @@ func layerHash(layer int, good bool) types.Hash32 { return h2 } -func storeNodeHashes(t *testing.T, db *sql.Database, diverge, max int) { +func storeNodeHashes(t *testing.T, db sql.StateDatabase, diverge, max int) { for lid := 0; lid <= max; lid++ { if lid < diverge { require.NoError(t, layers.SetMeshHash(db, types.LayerID(uint32(lid)), layerHash(lid, true))) diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 160f9bcbfc5..173756002f0 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -18,7 +18,6 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/system" ) @@ -218,12 +217,12 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase clock clockwork.Clock peerErrMetric counter } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -341,8 +340,15 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } } -func (s *Syncer) updateState() error { - if err := malsync.UpdateSyncState(s.localdb, s.clock.Now()); err != nil { +func (s *Syncer) updateState(ctx context.Context) error { + if err := s.localdb.WithTx(ctx, func(tx sql.Transaction) error { + return malsync.UpdateSyncState(tx, s.clock.Now()) + }); err != nil { + if ctx.Err() != nil { + // FIXME: with crawshaw, canceling the context which has been used to get + // a connection from the pool may cause "database: no free connection" errors + err = ctx.Err() + } return fmt.Errorf("error updating malsync state: %w", err) } @@ -360,13 +366,13 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up if nothingToDownload { sst.done() if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx)) return nil } else if !initial && gotUpdate { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } } diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index 692a7d0173f..035b55b0250 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/malsync/mocks" ) @@ -137,8 +138,8 @@ func malData(ids ...string) []types.NodeID { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *sql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher @@ -151,7 +152,7 @@ type tester struct { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) clock := clockwork.NewFakeClock() diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index 88127db8cb2..f92c4357656 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -20,8 +20,8 @@ import ( "github.com/spacemeshos/go-spacemesh/mesh" mmocks "github.com/spacemeshos/go-spacemesh/mesh/mocks" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/mocks" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -127,7 +127,7 @@ func newTestSyncer(t *testing.T, interval time.Duration) *testSyncer { mCertHdr: mocks.NewMockcertHandler(ctrl), mForkFinder: mocks.NewMockforkFinder(ctrl), } - db := sql.InMemory() + db := statesql.InMemory() ts.cdb = datastore.NewCachedDB(db, lg) var err error atxsdata := atxsdata.New() diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index 7b78ce6d5c5..911b9c9a50d 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -30,9 +30,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/handshake" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/systest/cluster" "github.com/spacemeshos/go-spacemesh/systest/testcontext" "github.com/spacemeshos/go-spacemesh/timesync" @@ -110,7 +110,7 @@ func TestPostMalfeasanceProof(t *testing.T) { postSetupMgr, err := activation.NewPostSetupManager( cfg.POST, logger.Named("post"), - datastore.NewCachedDB(sql.InMemory(), zap.NewNop()), + datastore.NewCachedDB(statesql.InMemory(), zap.NewNop()), atxsdata.New(), cl.GoldenATX(), syncer, @@ -156,7 +156,7 @@ func TestPostMalfeasanceProof(t *testing.T) { require.NoError(t, grpcPrivateServer.Start()) t.Cleanup(func() { assert.NoError(t, grpcPrivateServer.Close()) }) - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certClient := activation.NewCertifierClient(db, localDb, logger.Named("certifier")) certifier := activation.NewCertifier(localDb, logger, certClient) diff --git a/tortoise/model/core.go b/tortoise/model/core.go index 2a352e1eb32..22a1040ffed 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -29,7 +30,7 @@ const ( ) func newCore(rng *rand.Rand, id string, logger *zap.Logger) *core { - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) sig, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) if err != nil { panic(err) @@ -135,7 +136,7 @@ func (c *core) OnMessage(m Messenger, event Message) { if ev.LayerID.After(types.GetEffectiveGenesis()) { tortoise.RecoverLayer(context.Background(), c.tortoise, - c.cdb.Executor, + c.cdb.Database, c.atxdata, ev.LayerID, c.tortoise.OnBallot, diff --git a/tortoise/recover_test.go b/tortoise/recover_test.go index ecc4c6b51b7..be41f02dd64 100644 --- a/tortoise/recover_test.go +++ b/tortoise/recover_test.go @@ -58,7 +58,7 @@ func TestRecoverState(t *testing.T) { tortoise2, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, simState.Atxdata, last, WithLogger(lg), @@ -82,7 +82,7 @@ func TestRecoverEmpty(t *testing.T) { cfg.LayerSize = size tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), 100, WithLogger(zaptest.NewLogger(t)), @@ -108,18 +108,18 @@ func TestRecoverWithOpinion(t *testing.T) { var last result.Layer for _, rst := range trt.Updates() { if rst.Verified { - require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Executor, rst.Layer, rst.Opinion)) + require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Database, rst.Layer, rst.Opinion)) } for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } last = rst } tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last.Layer, WithLogger(lg), @@ -156,14 +156,14 @@ func TestResetPending(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -203,14 +203,14 @@ func TestWindowRecovery(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -239,7 +239,7 @@ func TestRecoverOnlyAtxs(t *testing.T) { trt.TallyVotes(context.Background(), lid) } future := last + 1000 - recovered, err := Recover(context.Background(), s.GetState(0).DB.Executor, s.GetState(0).Atxdata, future, + recovered, err := Recover(context.Background(), s.GetState(0).DB.Database, s.GetState(0).Atxdata, future, WithLogger(zaptest.NewLogger(t)), WithConfig(cfg), ) diff --git a/tortoise/replay/replay_test.go b/tortoise/replay/replay_test.go index 04a5bd7e28b..0393919fd40 100644 --- a/tortoise/replay/replay_test.go +++ b/tortoise/replay/replay_test.go @@ -15,8 +15,8 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/config" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -50,7 +50,7 @@ func TestReplayMainnet(t *testing.T) { ) require.NoError(t, err) - db, err := sql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) + db, err := statesql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) require.NoError(t, err) applied, err := layers.GetLastApplied(db) diff --git a/tortoise/sim/utils.go b/tortoise/sim/utils.go index e7f66f06644..86337c45514 100644 --- a/tortoise/sim/utils.go +++ b/tortoise/sim/utils.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -16,13 +17,13 @@ const ( func newCacheDB(logger *zap.Logger, conf config) *datastore.CachedDB { var ( - db *sql.Database + db sql.StateDatabase err error ) if len(conf.Path) == 0 { - db = sql.InMemory() + db = statesql.InMemory() } else { - db, err = sql.Open(filepath.Join(conf.Path, atxpath)) + db, err = statesql.Open(filepath.Join(conf.Path, atxpath)) if err != nil { panic(err) } diff --git a/tortoise/threshold_test.go b/tortoise/threshold_test.go index 7fd299fa784..bf4f52c367d 100644 --- a/tortoise/threshold_test.go +++ b/tortoise/threshold_test.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestComputeThreshold(t *testing.T) { @@ -165,7 +165,7 @@ func TestReferenceHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, height := range tc.heights { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(tc.epoch) - 1, diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index 93187318247..9700cada28c 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise/opinionhash" "github.com/spacemeshos/go-spacemesh/tortoise/sim" ) @@ -336,7 +337,7 @@ func tortoiseFromSimState(tb testing.TB, state sim.State, opts ...Opt) *recovery return &recoveryAdapter{ TB: tb, Tortoise: trtl, - db: state.DB.Executor, + db: state.DB.Database, atxdata: state.Atxdata, } } @@ -467,7 +468,7 @@ func TestComputeExpectedWeight(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { var ( - db = sql.InMemory() + db = statesql.InMemory() epochs = map[types.EpochID]*epochInfo{} first = tc.target.Add(1).GetEpoch() ) diff --git a/tortoise/tracer_test.go b/tortoise/tracer_test.go index b40cd7b26cd..21e217ee5ad 100644 --- a/tortoise/tracer_test.go +++ b/tortoise/tracer_test.go @@ -46,7 +46,7 @@ func TestTracer(t *testing.T) { path := filepath.Join(t.TempDir(), "tortoise.trace") trt, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, s.GetState(0).Atxdata, last, WithTracer(WithOutput(path)), diff --git a/txs/cache.go b/txs/cache.go index 85eedc0c721..404079c6980 100644 --- a/txs/cache.go +++ b/txs/cache.go @@ -295,7 +295,7 @@ func (ac *accountCache) add(logger *zap.Logger, tx *types.Transaction, received func (ac *accountCache) addPendingFromNonce( logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, nonce uint64, applied types.LayerID, ) error { @@ -382,7 +382,7 @@ func (ac *accountCache) getMempool(logger *zap.Logger) []*NanoTX { // because applying a layer changes the conservative balance in the cache. func (ac *accountCache) resetAfterApply( logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, nextNonce, newBalance uint64, applied types.LayerID, ) error { @@ -448,7 +448,7 @@ func groupTXsByPrincipal(logger *zap.Logger, mtxs []*types.MeshTransaction) map[ } // buildFromScratch builds the cache from database. -func (c *Cache) buildFromScratch(db *sql.Database) error { +func (c *Cache) buildFromScratch(db sql.StateDatabase) error { applied, err := layers.GetLastApplied(db) if err != nil { return fmt.Errorf("cache: get pending %w", err) @@ -560,7 +560,7 @@ func acceptable(err error) bool { func (c *Cache) Add( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, tx *types.Transaction, received time.Time, mustPersist bool, @@ -607,7 +607,7 @@ func (c *Cache) has(tid types.TransactionID) bool { // LinkTXsWithProposal associates the transactions to a proposal. func (c *Cache) LinkTXsWithProposal( - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID, @@ -624,7 +624,7 @@ func (c *Cache) LinkTXsWithProposal( // LinkTXsWithBlock associates the transactions to a block. func (c *Cache) LinkTXsWithBlock( - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID, @@ -656,7 +656,7 @@ func (c *Cache) updateLayer(lid types.LayerID, bid types.BlockID, tids []types.T } } -func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { +func (c *Cache) applyEmptyLayer(db sql.StateDatabase, lid types.LayerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -675,7 +675,7 @@ func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { // ApplyLayer retires the applied transactions from the cache and updates the balances. func (c *Cache) ApplyLayer( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, results []types.TransactionWithResult, @@ -703,7 +703,7 @@ func (c *Cache) ApplyLayer( // commit results before reporting them // TODO(dshulyak) save results in vm - if err := db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err := db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, rst := range results { err := transactions.AddResult(dbtx, rst.ID, &rst.TransactionResult) if err != nil { @@ -792,7 +792,7 @@ func (c *Cache) ApplyLayer( return nil } -func (c *Cache) RevertToLayer(db *sql.Database, revertTo types.LayerID) error { +func (c *Cache) RevertToLayer(db sql.StateDatabase, revertTo types.LayerID) error { if err := undoLayers(db, revertTo.Add(1)); err != nil { return err } @@ -832,7 +832,7 @@ func (c *Cache) GetMempool() map[types.Address][]*NanoTX { } // checkApplyOrder returns an error if layers were not applied in order. -func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID) error { +func checkApplyOrder(logger *zap.Logger, db sql.StateDatabase, toApply types.LayerID) error { lastApplied, err := layers.GetLastApplied(db) if err != nil { return fmt.Errorf("cache get last applied %w", err) @@ -847,8 +847,8 @@ func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID return nil } -func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToProposal(db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToProposal(dbtx, tid, lid, pid); err != nil { return fmt.Errorf("add2prop %w", err) @@ -858,8 +858,8 @@ func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, ti }) } -func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToBlock(db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToBlock(dbtx, tid, lid, bid); err != nil { return fmt.Errorf("add2block %w", err) @@ -869,8 +869,8 @@ func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []t }) } -func undoLayers(db *sql.Database, from types.LayerID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func undoLayers(db sql.StateDatabase, from types.LayerID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { err := transactions.UndoLayers(dbtx, from) if err != nil { return fmt.Errorf("undo %w", err) diff --git a/txs/cache_test.go b/txs/cache_test.go index 352ba5d4663..e172e11c5e6 100644 --- a/txs/cache_test.go +++ b/txs/cache_test.go @@ -14,12 +14,13 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) type testCache struct { *Cache - db *sql.Database + db sql.StateDatabase } type testAcct struct { @@ -68,7 +69,7 @@ func newMeshTX( func genAndSaveTXs( t *testing.T, - db *sql.Database, + db sql.StateDatabase, signer *signing.EdSigner, from, to uint64, startTime time.Time, @@ -89,14 +90,14 @@ func genTXs(t *testing.T, signer *signing.EdSigner, from, to uint64, startTime t return mtxs } -func saveTXs(t *testing.T, db *sql.Database, mtxs []*types.MeshTransaction) { +func saveTXs(t *testing.T, db sql.StateDatabase, mtxs []*types.MeshTransaction) { t.Helper() for _, mtx := range mtxs { require.NoError(t, transactions.Add(db, &mtx.Transaction, mtx.Received)) } } -func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransaction, state types.TXState) { +func checkTXStateFromDB(t *testing.T, db sql.StateDatabase, txs []*types.MeshTransaction, state types.TXState) { t.Helper() for _, mtx := range txs { got, err := transactions.Get(db, mtx.ID) @@ -105,7 +106,7 @@ func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransac } } -func checkTXNotInDB(t *testing.T, db *sql.Database, tid types.TransactionID) { +func checkTXNotInDB(t *testing.T, db sql.StateDatabase, tid types.TransactionID) { t.Helper() _, err := transactions.Get(db, tid) require.ErrorIs(t, err, sql.ErrNotFound) @@ -172,7 +173,7 @@ func createState(tb testing.TB, numAccounts int) map[types.Address]*testAcct { func createCache(tb testing.TB, numAccounts int) (*testCache, map[types.Address]*testAcct) { tb.Helper() accounts := createState(tb, numAccounts) - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(accounts), zaptest.NewLogger(tb)), db: db, @@ -186,7 +187,7 @@ func createSingleAccountTestCache(tb testing.TB) (*testCache, *testAcct) { principal := types.GenerateAddress(signer.PublicKey().Bytes()) ta := &testAcct{signer: signer, principal: principal, nonce: rand.Uint64()%1000 + 1, balance: defaultBalance} states := map[types.Address]*testAcct{principal: ta} - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(states), zap.NewNop()), db: db, diff --git a/txs/conservative_state.go b/txs/conservative_state.go index c7e0e0fc9e1..afe2599a4b7 100644 --- a/txs/conservative_state.go +++ b/txs/conservative_state.go @@ -55,12 +55,12 @@ type ConservativeState struct { logger *zap.Logger cfg CSConfig - db *sql.Database + db sql.StateDatabase cache *Cache } // NewConservativeState returns a ConservativeState. -func NewConservativeState(state vmState, db *sql.Database, opts ...ConservativeStateOpt) *ConservativeState { +func NewConservativeState(state vmState, db sql.StateDatabase, opts ...ConservativeStateOpt) *ConservativeState { cs := &ConservativeState{ vmState: state, cfg: defaultCSConfig(), diff --git a/txs/conservative_state_test.go b/txs/conservative_state_test.go index c2e78592e8d..2660a282964 100644 --- a/txs/conservative_state_test.go +++ b/txs/conservative_state_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -72,7 +73,7 @@ func newTxWthRecipient( type testConState struct { *ConservativeState logger *zap.Logger - db *sql.Database + db sql.StateDatabase mvm *MockvmState id peer.ID @@ -85,7 +86,7 @@ func (t *testConState) handler() *TxHandler { func createTestState(t *testing.T, gasLimit uint64) *testConState { ctrl := gomock.NewController(t) mvm := NewMockvmState(ctrl) - db := sql.InMemory() + db := statesql.InMemory() cfg := CSConfig{ BlockGasLimit: gasLimit, NumTXsPerProposal: numTXsInProposal,