diff --git a/CHANGELOG.md b/CHANGELOG.md index 87f1876ac28..de0dd6a230f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,6 +86,9 @@ operating your own PoET and want to use certificate authentication please refer ATXs. This vulnerability allows an attacker to claim rewards for a full tick amount although they should not be eligible for them. +* [#6003](https://github.com/spacemeshos/go-spacemesh/pull/6003) Improve database schema handling. + This includes schema drift detection which may happen after running unreleased versions. + * [#6031](https://github.com/spacemeshos/go-spacemesh/pull/6031) Fixed an edge case where the storage units might have changed after the initial PoST was generated but before the first ATX has been emitted, invalidating the initial PoST. The node will now try to verify the initial PoST and regenerate it if necessary. diff --git a/CODING.md b/CODING.md index fbb458d693d..0e373b3746a 100644 --- a/CODING.md +++ b/CODING.md @@ -105,3 +105,79 @@ Some useful logging recommendations for cleaner output: ## Commit Messages For commit messages, follow [this guideline](https://www.conventionalcommits.org/en/v1.0.0/). Use reasonable length for the subject and body, ideally no longer than 72 characters. Use the imperative mood for subject lines. + +## Handling database schema changes + +go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. + +When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: +* `sql/statesql/schema/schema.sql` for `state.sql` +* `sql/localsql/schema/schema.sql` for `local.sql` +The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). + +For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: +* `sql/statesql/schema/migrations` for `state.sql` +* `sql/localsql/schema/migrations` for `local.sql` + +Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. + +After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, fails with an error message: +``` +Error: open sqlite db schema drift detected (uri file:data/state.sql): + ( + """ + ... // 82 identical lines + PRIMARY KEY (layer, block) + ); ++ CREATE TABLE foo(id int); + CREATE TABLE identities + ( + ... // 66 identical lines + """ + ) +``` + +In this case, a table named `foo` has somehow appeared in the database, causing go-spacemesh to fail due to the schema drift. The possible reasons for schema drift can be the following: +* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens +* manual changes in the database +* external SQLite tooling used on the database that adds some tables, indices etc. + +In case if you want to run go-spacemesh with schema drift anyway, you can set `main.db-allow-schema-drift` to true. In this case, a warning with schema diff will be logged instead of failing. + +The schema changes in go-spacemesh code should be always done by means of adding migrations. Let's for example create a new migration (use zero-padded N+1 instead of 0010 with N being the number of the last migration for the local db): + +```console +$ echo 'CREATE TABLE foo(id int);' >sql/localsql/schema/migrations/0010_foo.sql +``` + +After that, we update the schema files +```console +$ make generate +$ # alternative: cd sql/localsql && go generate +$ git diff sql/localsql/schema/schema.sql +diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql +index 02c44d3cc..ebcdf4278 100755 +--- a/sql/localsql/schema/schema.sql ++++ b/sql/localsql/schema/schema.sql +@@ -1,4 +1,4 @@ +-PRAGMA user_version = 9; ++PRAGMA user_version = 10; + CREATE TABLE atx_sync_requests + ( + epoch INT NOT NULL, +@@ -24,6 +24,7 @@ CREATE TABLE "challenge" + post_indices VARCHAR, + post_pow UNSIGNED LONG INT + , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; ++CREATE TABLE foo(id int); + CREATE TABLE malfeasance_sync_state + ( + id INT NOT NULL PRIMARY KEY, +``` + +Note that the changes include both the new table and an updated `PRAGMA user_version` line. +The changes in the schema file must be committed along with the migration we added. +```console +$ git add sql/localsql/schema/migrations/0010_foo.sql sql/localsql/schema.sql +$ git commit -m "sql: add a test migration" +``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 13c79edce51..d5e4d4f1346 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. We welcome contributions large and small and we actively accept contributions. - go-spacemesh is part of [The Spacemesh open source project](https://spacemesh.io), and is MIT licensed free open source software. -- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). +- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). - Search the closed ones before reporting things, and help us with the open ones. - Make sure that you are able to contribute to MIT licensed free open software (no legal issues please). - Introduce yourself, ask questions about issues or talk about things on our [discord server](https://chat.spacemesh.io/). @@ -39,7 +39,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. # Code Guidelines Please follow these guidelines for your PR to be reviewed and be considered for merging into the project. -1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). +1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). 2. Provide at least one unit test for each function and method. 3. Provide at least one integration test for each feature with a flow which involves more than one function call. Your tests should reflect the main ways that your code should be used. 4. Run `go mod tidy`, `go fmt ./...` and `make lint` to format and lint your code before submitting your PR. @@ -49,7 +49,7 @@ Please follow these guidelines for your PR to be reviewed and be considered for - Check for existing 3rd-party packages in the vendor folder `./vendor` before adding a new dependency. - Use [govendor](https://github.com/kardianos/govendor) to add a new dependency. -# Working on a funded issue +# Working on a funded issue ## Step 1 - Discover :sunrise_over_mountains: - Browse the [open funded issues](https://github.com/spacemeshos/go-spacemesh/labels/funded%20%3Amoneybag%3A) in our github repo, or on our [gitcoin.io funded issues page](https://gitcoin.co/profile/spacemeshos). @@ -68,6 +68,6 @@ Please follow these guidelines for your PR to be reviewed and be considered for ## Step 3 - Get paid :moneybag: - When ready, submit your PR for review and go through the code review process with one of our maintainers. - Expect a review process that ensures that you have followed our code guidelines at that your design and implementation are solid. You are expected to revision your code based on reviewers comments. -- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. +- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. Please review our funded issues program [legal notes](https://github.com/spacemeshos/go-spacemesh/blob/master/legal.md). diff --git a/activation/activation.go b/activation/activation.go index dfd58f455f9..ed8cf60afdd 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -28,7 +28,6 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -76,7 +75,7 @@ type Builder struct { conf Config db sql.Executor atxsdata *atxsdata.Data - localDB *localsql.Database + localDB sql.LocalDatabase publisher pubsub.Publisher nipostBuilder nipostBuilder validator nipostValidator @@ -172,7 +171,7 @@ func NewBuilder( conf Config, db sql.Executor, atxsdata *atxsdata.Data, - localDB *localsql.Database, + localDB sql.LocalDatabase, publisher pubsub.Publisher, nipostBuilder nipostBuilder, layerClock layerClock, diff --git a/activation/activation_test.go b/activation/activation_test.go index a3c696cb24f..039218d06cf 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -33,6 +33,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" sqlmocks "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) // ========== Vars / Consts ========== @@ -54,7 +55,7 @@ func TestMain(m *testing.M) { type testAtxBuilder struct { *Builder db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase goldenATXID types.ATXID observedLogs *observer.ObservedLogs @@ -77,7 +78,7 @@ func newTestBuilder(tb testing.TB, numSigners int, opts ...BuilderOption) *testA ctrl := gomock.NewController(tb) tab := &testAtxBuilder{ - db: sql.InMemory(), + db: statesql.InMemory(), localDb: localsql.InMemory(sql.WithConnections(numSigners)), goldenATXID: types.ATXID(types.HexToHash32("77777")), diff --git a/activation/certifier.go b/activation/certifier.go index 236df2827d8..af63f9c49ac 100644 --- a/activation/certifier.go +++ b/activation/certifier.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" certifierdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -80,14 +79,14 @@ type CertifyResponse struct { type Certifier struct { logger *zap.Logger - db *localsql.Database + db sql.LocalDatabase client certifierClient certifications singleflight.Group } func NewCertifier( - db *localsql.Database, + db sql.LocalDatabase, logger *zap.Logger, client certifierClient, ) *Certifier { @@ -147,7 +146,7 @@ type CertifierClient struct { client *retryablehttp.Client logger *zap.Logger db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase } type certifierClientOpts func(*CertifierClient) @@ -162,7 +161,7 @@ func WithCertifierClientConfig(cfg CertifierClientConfig) certifierClientOpts { func NewCertifierClient( db sql.Executor, - localDb *localsql.Database, + localDb sql.LocalDatabase, logger *zap.Logger, opts ...certifierClientOpts, ) *CertifierClient { diff --git a/activation/certifier_test.go b/activation/certifier_test.go index fa0db876c01..cb2ad5d3067 100644 --- a/activation/certifier_test.go +++ b/activation/certifier_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" certdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPersistsCerts(t *testing.T) { @@ -113,7 +114,7 @@ func TestObtainingPost(t *testing.T) { id := types.RandomNodeID() t.Run("no POST or ATX", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certifier := NewCertifierClient(db, localDb, zaptest.NewLogger(t)) @@ -121,7 +122,7 @@ func TestObtainingPost(t *testing.T) { require.ErrorContains(t, err, "PoST not found") }) t.Run("initial POST available", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() post := nipost.Post{ @@ -142,7 +143,7 @@ func TestObtainingPost(t *testing.T) { require.Equal(t, post, *got) }) t.Run("initial POST unavailable but ATX exists", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() atx := newInitialATXv1(t, types.RandomATXID()) diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index c8a447aef44..be8aeb249a5 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -27,9 +27,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -61,7 +61,7 @@ func Test_BuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() svc := grpcserver.NewPostService(logger, grpcserver.PostServiceQueryInterval(100*time.Millisecond)) diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go index 9818edee6f8..e9099dd4cc7 100644 --- a/activation/e2e/atx_merge_test.go +++ b/activation/e2e/atx_merge_test.go @@ -27,11 +27,11 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" @@ -205,7 +205,7 @@ func Test_MarryAndMerge(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) localDB := localsql.InMemory() diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 2a37a1bafba..5a7fd4fa0ca 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -24,9 +24,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -52,7 +52,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) opts := testPostSetupOpts(t) diff --git a/activation/e2e/certifier_client_test.go b/activation/e2e/certifier_client_test.go index 9ff1a358870..2f09e1cd958 100644 --- a/activation/e2e/certifier_client_test.go +++ b/activation/e2e/certifier_client_test.go @@ -23,9 +23,9 @@ import ( "github.com/spacemeshos/go-spacemesh/api/grpcserver" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCertification(t *testing.T) { @@ -34,7 +34,7 @@ func TestCertification(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() opts := testPostSetupOpts(t) diff --git a/activation/e2e/checkpoint_merged_test.go b/activation/e2e/checkpoint_merged_test.go index 545082f2684..91f40dd7b70 100644 --- a/activation/e2e/checkpoint_merged_test.go +++ b/activation/e2e/checkpoint_merged_test.go @@ -25,11 +25,11 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -43,7 +43,7 @@ func Test_CheckpointAfterMerge(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) localDB := localsql.InMemory() @@ -261,7 +261,7 @@ func Test_CheckpointAfterMerge(t *testing.T) { require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + recoveryCfg.DbPath()) + newDB, err := statesql.Open("file:" + recoveryCfg.DbPath()) require.NoError(t, err) defer newDB.Close() diff --git a/activation/e2e/checkpoint_test.go b/activation/e2e/checkpoint_test.go index 048469b2ff2..c7788733e1b 100644 --- a/activation/e2e/checkpoint_test.go +++ b/activation/e2e/checkpoint_test.go @@ -25,9 +25,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -46,7 +46,7 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { require.NoError(t, err) cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) opts := testPostSetupOpts(t) @@ -181,7 +181,7 @@ func TestCheckpoint_PublishingSoloATXs(t *testing.T) { require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + recoveryCfg.DbPath()) + newDB, err := statesql.Open("file:" + recoveryCfg.DbPath()) require.NoError(t, err) defer newDB.Close() diff --git a/activation/e2e/nipost_test.go b/activation/e2e/nipost_test.go index 2a135e3256c..5ecb004af1a 100644 --- a/activation/e2e/nipost_test.go +++ b/activation/e2e/nipost_test.go @@ -23,9 +23,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -136,7 +136,7 @@ func initPost( logger := zaptest.NewLogger(tb) syncer := syncedSyncer(tb) - db := sql.InMemory() + db := statesql.InMemory() mgr, err := activation.NewPostSetupManager(cfg, logger, db, atxsdata.New(), golden, syncer, nil) require.NoError(tb, err) @@ -157,7 +157,7 @@ func TestNIPostBuilderWithClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() opts := testPostSetupOpts(t) @@ -243,7 +243,7 @@ func Test_NIPostBuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() opts := testPostSetupOpts(t) svc := grpcserver.NewPostService(logger, grpcserver.PostServiceQueryInterval(100*time.Millisecond)) diff --git a/activation/e2e/validation_test.go b/activation/e2e/validation_test.go index 2ea75217770..4b63dd2174b 100644 --- a/activation/e2e/validation_test.go +++ b/activation/e2e/validation_test.go @@ -16,8 +16,8 @@ import ( "github.com/spacemeshos/go-spacemesh/api/grpcserver" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestValidator_Validate(t *testing.T) { @@ -29,7 +29,7 @@ func TestValidator_Validate(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := testPostConfig() - db := sql.InMemory() + db := statesql.InMemory() validator := activation.NewMocknipostValidator(gomock.NewController(t)) @@ -50,7 +50,7 @@ func TestValidator_Validate(t *testing.T) { GracePeriod: epoch / 4, } - poetDb := activation.NewPoetDb(sql.InMemory(), logger.Named("poetDb")) + poetDb := activation.NewPoetDb(statesql.InMemory(), logger.Named("poetDb")) client := ae2e.NewTestPoetClient(1) poetService := activation.NewPoetServiceWithClient(poetDb, client, poetCfg, logger) diff --git a/activation/handler_test.go b/activation/handler_test.go index 9b71a9f81f9..7b364f37840 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -193,7 +194,7 @@ func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOption) *testHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) edVerifier := signing.NewEdVerifier() mocks := newTestHandlerMocks(tb, goldenATXID) diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 1ba7c90d2de..d1f28ea0f17 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -519,7 +519,7 @@ func (h *HandlerV1) checkWrongPrevAtx( func (h *HandlerV1) checkMalicious( ctx context.Context, - tx *sql.Tx, + tx sql.Transaction, watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { malicious, err := identities.IsMalicious(tx, watx.SmesherID) @@ -543,7 +543,7 @@ func (h *HandlerV1) storeAtx( watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { var proof *mwire.MalfeasanceProof - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { var err error proof, err = h.checkMalicious(ctx, tx, watx) if err != nil { diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index 97ce849cebc..1cf8ada3355 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v1TestHandler struct { @@ -31,7 +32,7 @@ type v1TestHandler struct { func newV1TestHandler(tb testing.TB, goldenATXID types.ATXID) *v1TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, goldenATXID) return &v1TestHandler{ HandlerV1: &HandlerV1{ diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 73aca060632..1aed2810487 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -654,7 +654,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( } func (h *HandlerV2) checkMalicious( - tx *sql.Tx, + tx sql.Transaction, watx *wire.ActivationTxV2, marrying []marriage, ) (bool, *mwire.MalfeasanceProof, error) { @@ -684,7 +684,7 @@ func (h *HandlerV2) checkMalicious( return false, nil, nil } -func (h *HandlerV2) checkDoubleMarry(tx *sql.Tx, marrying []marriage) (*mwire.MalfeasanceProof, error) { +func (h *HandlerV2) checkDoubleMarry(tx sql.Transaction, marrying []marriage) (*mwire.MalfeasanceProof, error) { for _, m := range marrying { married, err := identities.Married(tx, m.id) if err != nil { @@ -716,7 +716,7 @@ func (h *HandlerV2) storeAtx( malicious bool proof *mwire.MalfeasanceProof ) - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { var err error malicious, proof, err = h.checkMalicious(tx, watx, marrying) if err != nil { diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 4f80f9e9fbb..4d3c844810c 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v2TestHandler struct { @@ -44,7 +45,7 @@ const ( func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, golden) return &v2TestHandler{ HandlerV2: &HandlerV2{ diff --git a/activation/malfeasance_test.go b/activation/malfeasance_test.go index a90ac6c5dee..cc2dc07b893 100644 --- a/activation/malfeasance_test.go +++ b/activation/malfeasance_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { @@ -40,11 +41,11 @@ type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -237,12 +238,12 @@ type testInvalidPostIndexHandler struct { *InvalidPostIndexHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase mockPostVerifier *MockPostVerifier } func newTestInvalidPostIndexHandler(tb testing.TB) *testInvalidPostIndexHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -428,11 +429,11 @@ type testInvalidPrevATXHandler struct { *InvalidPrevATXHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestInvalidPrevATXHandler(tb testing.TB) *testInvalidPrevATXHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/activation/nipost.go b/activation/nipost.go index 3ee0b17b655..1a7b13d8b68 100644 --- a/activation/nipost.go +++ b/activation/nipost.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/go-spacemesh/metrics/public" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -47,7 +46,7 @@ var ErrInvalidInitialPost = errors.New("invalid initial post") // NIPostBuilder holds the required state and dependencies to create Non-Interactive Proofs of Space-Time (NIPost). type NIPostBuilder struct { - localDB *localsql.Database + localDB sql.LocalDatabase poetProvers map[string]PoetService postService postService @@ -77,7 +76,7 @@ func NipostbuilderWithPostStates(ps PostStates) NIPostBuilderOption { // NewNIPostBuilder returns a NIPostBuilder. func NewNIPostBuilder( - db *localsql.Database, + db sql.LocalDatabase, postService postService, lg *zap.Logger, poetCfg PoetConfig, diff --git a/activation/nipost_test.go b/activation/nipost_test.go index a351345d9f6..a50a2119880 100644 --- a/activation/nipost_test.go +++ b/activation/nipost_test.go @@ -52,7 +52,7 @@ type testNIPostBuilder struct { observedLogs *observer.ObservedLogs eventSub <-chan events.UserEvent - mDb *localsql.Database + mDb sql.LocalDatabase mLogger *zap.Logger mPoetDb *MockpoetDbAPI mClock *MocklayerClock diff --git a/activation/poet_client_test.go b/activation/poet_client_test.go index 6b8e5126f61..9b9d5910538 100644 --- a/activation/poet_client_test.go +++ b/activation/poet_client_test.go @@ -20,8 +20,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_HTTPPoetClient_ParsesURL(t *testing.T) { @@ -397,7 +397,7 @@ func TestPoetService_CachesCertifierInfo(t *testing.T) { cfg := DefaultPoetConfig() cfg.CertifierInfoCacheTTL = tc.ttl client := NewMockPoetClient(gomock.NewController(t)) - db := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + db := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) poet := NewPoetServiceWithClient(db, client, cfg, zaptest.NewLogger(t)) url := &url.URL{Host: "certifier.hello"} pubkey := []byte("pubkey") diff --git a/activation/poetdb.go b/activation/poetdb.go index a82cda5e282..6ce0f691b54 100644 --- a/activation/poetdb.go +++ b/activation/poetdb.go @@ -23,12 +23,12 @@ var ErrObjectExists = sql.ErrObjectExists // PoetDb is a database for PoET proofs. type PoetDb struct { - sqlDB *sql.Database + sqlDB sql.StateDatabase logger *zap.Logger } // NewPoetDb returns a new PoET handler. -func NewPoetDb(db *sql.Database, log *zap.Logger) *PoetDb { +func NewPoetDb(db sql.StateDatabase, log *zap.Logger) *PoetDb { return &PoetDb{sqlDB: db, logger: log} } diff --git a/activation/poetdb_test.go b/activation/poetdb_test.go index 6e86452507e..10fd77283e3 100644 --- a/activation/poetdb_test.go +++ b/activation/poetdb_test.go @@ -16,7 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ( @@ -63,7 +63,7 @@ func getPoetProof(t *testing.T) types.PoetProofMessage { func TestPoetDbHappyFlow(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) r.NoError(poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature)) ref, err := msg.Ref() @@ -83,7 +83,7 @@ func TestPoetDbHappyFlow(t *testing.T) { func TestPoetDbInvalidPoetProof(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.PoetProof.Root = []byte("some other root") err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -99,7 +99,7 @@ func TestPoetDbInvalidPoetProof(t *testing.T) { func TestPoetDbInvalidPoetStatement(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.Statement = types.CalcHash32([]byte("some other statement")) err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -115,7 +115,7 @@ func TestPoetDbInvalidPoetStatement(t *testing.T) { func TestPoetDbNonExistingKeys(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) _, err := poetDb.GetProofRef(msg.PoetServiceID, "0") r.EqualError( diff --git a/activation/post_supervisor_test.go b/activation/post_supervisor_test.go index f26021a3645..7c21cbf4c37 100644 --- a/activation/post_supervisor_test.go +++ b/activation/post_supervisor_test.go @@ -24,7 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func closedChan() <-chan struct{} { @@ -56,7 +56,7 @@ func newPostManager(t *testing.T, cfg PostConfig, opts PostSetupOpts) *PostSetup close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() mgr, err := NewPostSetupManager(cfg, zaptest.NewLogger(t), db, atxsdata, types.RandomATXID(), syncer, validator) require.NoError(t, err) diff --git a/activation/post_test.go b/activation/post_test.go index b3e24eabda3..58b111fc291 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPostSetupManager(t *testing.T) { @@ -369,7 +369,7 @@ func newTestPostManager(tb testing.TB) *testPostManager { syncer.EXPECT().RegisterForATXSynced().AnyTimes().Return(synced) logger := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) mgr, err := NewPostSetupManager(DefaultPostConfig(), logger, cdb, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/activation/validation_test.go b/activation/validation_test.go index a96c977f6ed..3392594c40f 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -16,8 +16,8 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Validation_VRFNonce(t *testing.T) { @@ -476,7 +476,7 @@ func TestValidateMerkleProof(t *testing.T) { } func TestVerifyChainDeps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() @@ -619,7 +619,7 @@ func TestVerifyChainDeps(t *testing.T) { } func TestVerifyChainDepsAfterCheckpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() diff --git a/activation/verify_state_test.go b/activation/verify_state_test.go index 26446d59265..de119a54159 100644 --- a/activation/verify_state_test.go +++ b/activation/verify_state_test.go @@ -11,13 +11,13 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_CheckPrevATXs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() logger := zaptest.NewLogger(t) // Arrange diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index 2cf5ad0a3c1..f79d12576fb 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -17,6 +17,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Highest_ReturnsGoldenAtxOnError(t *testing.T) { @@ -137,7 +138,7 @@ func TestGet_IdentityCanceled(t *testing.T) { atxProvider := grpcserver.NewMockatxProvider(ctrl) activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) - smesher, proof := grpcserver.BallotMalfeasance(t, sql.InMemory()) + smesher, proof := grpcserver.BallotMalfeasance(t, statesql.InMemory()) id := types.RandomATXID() atx := types.ActivationTx{ Sequence: rand.Uint64(), diff --git a/api/grpcserver/admin_service.go b/api/grpcserver/admin_service.go index 3da2e55d4b5..f6e438ae356 100644 --- a/api/grpcserver/admin_service.go +++ b/api/grpcserver/admin_service.go @@ -34,14 +34,14 @@ const ( // AdminService exposes endpoints for node administration. type AdminService struct { - db *sql.Database + db sql.StateDatabase dataDir string recover func() p peers } // NewAdminService creates a new admin grpc service. -func NewAdminService(db *sql.Database, dataDir string, p peers) *AdminService { +func NewAdminService(db sql.StateDatabase, dataDir string, p peers) *AdminService { return &AdminService{ db: db, dataDir: dataDir, diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index e68222adbcb..be442e5396a 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -19,11 +19,12 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const snapshot uint32 = 15 -func newAtx(tb testing.TB, db *sql.Database) { +func newAtx(tb testing.TB, db sql.StateDatabase) { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(2), Sequence: 0, @@ -41,7 +42,7 @@ func newAtx(tb testing.TB, db *sql.Database) { require.NoError(tb, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits)) } -func createMesh(tb testing.TB, db *sql.Database) { +func createMesh(tb testing.TB, db sql.StateDatabase) { for range 10 { newAtx(tb, db) } @@ -57,7 +58,7 @@ func createMesh(tb testing.TB, db *sql.Database) { } func TestAdminService_Checkpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() createMesh(t, db) svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) @@ -94,7 +95,7 @@ func TestAdminService_Checkpoint(t *testing.T) { } func TestAdminService_CheckpointError(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -111,7 +112,7 @@ func TestAdminService_CheckpointError(t *testing.T) { } func TestAdminService_Recovery(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() recoveryCalled := atomic.Bool{} svc := NewAdminService(db, t.TempDir(), nil) svc.recover = func() { recoveryCalled.Store(true) } @@ -133,7 +134,7 @@ func TestAdminService_PeerInfo(t *testing.T) { ctrl := gomock.NewController(t) p := NewMockpeers(ctrl) - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), p) cfg, cleanup := launchServer(t, svc) diff --git a/api/grpcserver/debug_service.go b/api/grpcserver/debug_service.go index f0e0b6671ec..672886bf27f 100644 --- a/api/grpcserver/debug_service.go +++ b/api/grpcserver/debug_service.go @@ -24,7 +24,7 @@ import ( // DebugService exposes global state data, output from the STF. type DebugService struct { - db *sql.Database + db sql.StateDatabase conState conservativeState netInfo networkInfo oracle oracle @@ -46,7 +46,7 @@ func (d DebugService) String() string { } // NewDebugService creates a new grpc service using config data. -func NewDebugService(db *sql.Database, conState conservativeState, host networkInfo, oracle oracle, +func NewDebugService(db sql.StateDatabase, conState conservativeState, host networkInfo, oracle oracle, loggers map[string]*zap.AtomicLevel, ) *DebugService { return &DebugService{ diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index bf9b7c37696..126ddbf3f47 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -48,11 +48,11 @@ import ( peerinfomocks "github.com/spacemeshos/go-spacemesh/p2p/peerinfo/mocks" pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/txs" ) @@ -725,7 +725,7 @@ func TestMeshService(t *testing.T) { genesis := time.Unix(genTimeUnix, 0) genTime.EXPECT().GenesisTime().Return(genesis) genTime.EXPECT().CurrentLayer().Return(layerCurrent).AnyTimes() - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) svc := NewMeshService( db, meshAPIMock, @@ -1256,7 +1256,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + svc := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -1295,7 +1295,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1328,7 +1328,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1356,7 +1356,7 @@ func TestTransactionService(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1666,7 +1666,7 @@ func TestAccountMeshDataStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) grpcService := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -1848,7 +1848,7 @@ func TestLayerStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) grpcService := NewMeshService( db, @@ -1994,7 +1994,7 @@ func TestMultiService(t *testing.T) { genTime.EXPECT().GenesisTime().Return(genesis) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, "v0.0.0", "cafebabe") svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -2039,7 +2039,7 @@ func TestDebugService(t *testing.T) { ctrl := gomock.NewController(t) netInfo := NewMocknetworkInfo(ctrl) mOracle := NewMockoracle(ctrl) - db := sql.InMemory() + db := statesql.InMemory() testLog := zap.NewAtomicLevel() loggers := map[string]*zap.AtomicLevel{ @@ -2233,7 +2233,7 @@ func TestEventsReceived(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - txService := NewTransactionService(sql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) + txService := NewTransactionService(statesql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) gsService := NewGlobalStateService(meshAPIMock, conStateAPI) cfg, cleanup := launchServer(t, txService, gsService) t.Cleanup(cleanup) @@ -2283,8 +2283,8 @@ func TestEventsReceived(t *testing.T) { time.Sleep(time.Millisecond * 50) lg := logtest.New(t) - svm := vm.New(sql.InMemory(), vm.WithLogger(lg)) - conState := txs.NewConservativeState(svm, sql.InMemory(), txs.WithLogger(lg.Zap().Named("conState"))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(lg)) + conState := txs.NewConservativeState(svm, statesql.InMemory(), txs.WithLogger(lg.Zap().Named("conState"))) conState.AddToCache(context.Background(), globalTx, time.Now()) weight := new(big.Rat).SetFloat64(18.7) @@ -2347,7 +2347,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(logtest.New(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(logtest.New(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2368,7 +2368,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(logtest.New(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(logtest.New(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2387,7 +2387,7 @@ func TestVMAccountUpdates(t *testing.T) { events.InitializeReporter() // in memory database doesn't allow reads while writer locked db - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) require.NoError(t, err) t.Cleanup(func() { db.Close() }) svm := vm.New(db, vm.WithLogger(logtest.New(t))) @@ -2483,7 +2483,7 @@ func createAtxs(tb testing.TB, epoch types.EpochID, atxids []types.ATXID) []*typ func TestMeshService_EpochStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, diff --git a/api/grpcserver/http_server_test.go b/api/grpcserver/http_server_test.go index 46f75a5bd2d..4c5034871b2 100644 --- a/api/grpcserver/http_server_test.go +++ b/api/grpcserver/http_server_test.go @@ -18,7 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchJsonServer(tb testing.TB, services ...ServiceAPI) (Config, func()) { @@ -66,7 +66,7 @@ func TestJsonApi(t *testing.T) { conStateAPI := NewMockconservativeState(ctrl) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, version, build) svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, diff --git a/api/grpcserver/mesh_service_test.go b/api/grpcserver/mesh_service_test.go index 6acb9465aab..9ce543173f8 100644 --- a/api/grpcserver/mesh_service_test.go +++ b/api/grpcserver/mesh_service_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -142,7 +143,7 @@ func HareMalfeasance(tb testing.TB, db sql.Executor) (types.NodeID, *wire.Malfea func TestMeshService_MalfeasanceQuery(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -195,7 +196,7 @@ func TestMeshService_MalfeasanceStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -301,7 +302,7 @@ func (t *ConStateAPIMockInstrumented) GetLayerStateRoot(types.LayerID) (types.Ha func TestReadLayer(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), &MeshAPIMockInstrumented{}, diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index f3fddd506b0..464ef984a16 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -23,7 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchPostSupervisor( @@ -58,7 +58,7 @@ func launchPostSupervisor( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post manager") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) @@ -102,7 +102,7 @@ func launchPostSupervisorTLS( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post supervisor") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/api/grpcserver/transaction_service.go b/api/grpcserver/transaction_service.go index a0713b3b8c8..80f5155c9a6 100644 --- a/api/grpcserver/transaction_service.go +++ b/api/grpcserver/transaction_service.go @@ -28,7 +28,7 @@ import ( // TransactionService exposes transaction data, and a submit tx endpoint. type TransactionService struct { - db *sql.Database + db sql.StateDatabase publisher pubsub.Publisher // P2P Swarm mesh meshAPI // Mesh conState conservativeState @@ -52,7 +52,7 @@ func (s TransactionService) String() string { // NewTransactionService creates a new grpc service using config data. func NewTransactionService( - db *sql.Database, + db sql.StateDatabase, publisher pubsub.Publisher, msh meshAPI, conState conservativeState, diff --git a/api/grpcserver/transaction_service_test.go b/api/grpcserver/transaction_service_test.go index 343e30ed58e..9e97c0d5388 100644 --- a/api/grpcserver/transaction_service_test.go +++ b/api/grpcserver/transaction_service_test.go @@ -23,12 +23,13 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_StreamResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -36,7 +37,7 @@ func TestTransactionService_StreamResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator(). WithAddresses(2) txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -134,7 +135,7 @@ func TestTransactionService_StreamResults(t *testing.T) { } func BenchmarkStreamResults(b *testing.B) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -216,7 +217,7 @@ func parseOk() parseExpectation { } func TestParseTransactions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) t.Cleanup(cancel) vminst := vm.New(db) diff --git a/api/grpcserver/v2alpha1/account_test.go b/api/grpcserver/v2alpha1/account_test.go index 3d8684d5bc3..fb876851e2c 100644 --- a/api/grpcserver/v2alpha1/account_test.go +++ b/api/grpcserver/v2alpha1/account_test.go @@ -14,8 +14,8 @@ import ( "google.golang.org/grpc/status" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testAccount struct { @@ -27,7 +27,7 @@ type testAccount struct { } func TestAccountService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctrl, ctx := gomock.WithContext(context.Background(), t) conState := NewMockaccountConState(ctrl) diff --git a/api/grpcserver/v2alpha1/activation_test.go b/api/grpcserver/v2alpha1/activation_test.go index 70b97330fe5..a482f05b30d 100644 --- a/api/grpcserver/v2alpha1/activation_test.go +++ b/api/grpcserver/v2alpha1/activation_test.go @@ -16,12 +16,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActivationService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -104,7 +104,7 @@ func TestActivationService_List(t *testing.T) { } func TestActivationStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -214,7 +214,7 @@ func TestActivationStreamService_Stream(t *testing.T) { } func TestActivationService_ActivationsCount(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() genEpoch3 := fixture.NewAtxsGenerator().WithEpochs(3, 1) diff --git a/api/grpcserver/v2alpha1/layer_test.go b/api/grpcserver/v2alpha1/layer_test.go index bfc082b0689..2c21967c0d7 100644 --- a/api/grpcserver/v2alpha1/layer_test.go +++ b/api/grpcserver/v2alpha1/layer_test.go @@ -19,10 +19,11 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestLayerService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -98,7 +99,7 @@ func TestLayerConvertEventStatus(t *testing.T) { } func TestLayerStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -225,7 +226,7 @@ func layerGenWithBlock(withBlock bool) layerGenOpt { } } -func generateLayer(db *sql.Database, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { +func generateLayer(db sql.StateDatabase, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { g := &layerGenOpts{} for _, opt := range opts { opt(g) diff --git a/api/grpcserver/v2alpha1/reward_test.go b/api/grpcserver/v2alpha1/reward_test.go index 50952097d14..2e16cab291a 100644 --- a/api/grpcserver/v2alpha1/reward_test.go +++ b/api/grpcserver/v2alpha1/reward_test.go @@ -15,12 +15,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/rewards" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewardService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator().WithAddresses(100).WithUniqueCoinbase() @@ -103,7 +103,7 @@ func TestRewardService_List(t *testing.T) { } func TestRewardStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator() diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index 3128b470d5a..0effec5797b 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -31,18 +31,19 @@ import ( pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_List(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewTransactionResultGenerator().WithAddresses(2) txsList := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txsList { tx := gen.Next() @@ -178,7 +179,7 @@ func TestTransactionService_List(t *testing.T) { func TestTransactionService_EstimateGas(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -241,7 +242,7 @@ func TestTransactionService_EstimateGas(t *testing.T) { func TestTransactionService_ParseTransaction(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -357,7 +358,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -400,7 +401,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -437,7 +438,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) diff --git a/atxsdata/warmup.go b/atxsdata/warmup.go index 84c68507363..557618dcef9 100644 --- a/atxsdata/warmup.go +++ b/atxsdata/warmup.go @@ -10,7 +10,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" ) -func Warm(db *sql.Database, keep types.EpochID) (*Data, error) { +func Warm(db sql.StateDatabase, keep types.EpochID) (*Data, error) { cache := New() tx, err := db.Tx(context.Background()) if err != nil { diff --git a/atxsdata/warmup_test.go b/atxsdata/warmup_test.go index 67fa5981408..c052e87b96a 100644 --- a/atxsdata/warmup_test.go +++ b/atxsdata/warmup_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func gatx( @@ -37,7 +38,7 @@ func gatx( func TestWarmup(t *testing.T) { types.SetLayersPerEpoch(3) t.Run("sanity", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() applied := types.LayerID(10) nonce := types.VRFPostIndex(1) data := []types.ActivationTx{ @@ -60,19 +61,19 @@ func TestWarmup(t *testing.T) { } }) t.Run("no data", func(t *testing.T) { - c, err := Warm(sql.InMemory(), 1) + c, err := Warm(statesql.InMemory(), 1) require.NoError(t, err) require.NotNil(t, c) }) t.Run("closed db", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, db.Close()) c, err := Warm(db, 1) require.Error(t, err) require.Nil(t, c) }) t.Run("db failures", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonce := types.VRFPostIndex(1) data := gatx(types.ATXID{1, 1}, 1, types.NodeID{1}, nonce) require.NoError(t, atxs.Add(db, &data, types.AtxBlob{})) diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index 5e3943d32df..2854269ffc9 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -30,6 +30,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -91,7 +92,7 @@ func newTestDriver(tb testing.TB, cfg Config, p pubsub.Publisher, miners int, id tpd.mVerifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(true) - tpd.cdb = datastore.NewCachedDB(sql.InMemory(), lg) + tpd.cdb = datastore.NewCachedDB(statesql.InMemory(), lg) tpd.ProtocolDriver = New(p, signing.NewEdVerifier(), tpd.mVerifier, tpd.cdb, tpd.mClock, WithConfig(cfg), WithLogger(lg), @@ -493,7 +494,7 @@ func TestBeacon_NoRaceOnClose(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, closed: make(chan struct{}), results: make(chan result.Beacon, 100), @@ -528,7 +529,7 @@ func TestBeacon_BeaconsWithDatabase(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch3 := types.EpochID(3) @@ -581,7 +582,7 @@ func TestBeacon_BeaconsWithDatabaseFailure(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch := types.EpochID(3) @@ -599,7 +600,7 @@ func TestBeacon_BeaconsCleanupOldEpoch(t *testing.T) { mclock := NewMocklayerClock(gomock.NewController(t)) pd := &ProtocolDriver{ logger: lg.Named("Beacon"), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -704,7 +705,7 @@ func TestBeacon_ReportBeaconFromBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -740,7 +741,7 @@ func TestBeacon_ReportBeaconFromBallot_SameBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, diff --git a/blocks/certifier.go b/blocks/certifier.go index d5045149e93..b3461a98fc0 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -81,7 +81,7 @@ type Certifier struct { stop func() stopped atomic.Bool - db *sql.Database + db sql.StateDatabase oracle eligibility.Rolacle signers map[types.NodeID]*signing.EdSigner edVerifier *signing.EdVerifier @@ -99,7 +99,7 @@ type Certifier struct { // NewCertifier creates new block certifier. func NewCertifier( - db *sql.Database, + db sql.StateDatabase, o eligibility.Rolacle, v *signing.EdVerifier, @@ -567,7 +567,7 @@ func (c *Certifier) save( if len(valid)+len(invalid) == 0 { return certificates.Add(c.db, lid, cert) } - return c.db.WithTx(ctx, func(dbtx *sql.Tx) error { + return c.db.WithTx(ctx, func(dbtx sql.Transaction) error { if err := certificates.Add(dbtx, lid, cert); err != nil { return err } diff --git a/blocks/certifier_test.go b/blocks/certifier_test.go index 8a3cc721206..d4a0c7faef0 100644 --- a/blocks/certifier_test.go +++ b/blocks/certifier_test.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -27,7 +28,7 @@ const defaultCnt = uint16(2) type testCertifier struct { *Certifier - db *sql.Database + db sql.StateDatabase mOracle *eligibility.MockRolacle mPub *pubsubmock.MockPublisher mClk *mocks.MocklayerClock @@ -38,7 +39,7 @@ type testCertifier struct { func newTestCertifier(t *testing.T, signers int) *testCertifier { t.Helper() types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) mo := eligibility.NewMockRolacle(ctrl) mp := pubsubmock.NewMockPublisher(ctrl) diff --git a/blocks/generator.go b/blocks/generator.go index 0d30d7f40fb..6c63db247f5 100644 --- a/blocks/generator.go +++ b/blocks/generator.go @@ -30,7 +30,7 @@ type Generator struct { eg errgroup.Group stop func() - db *sql.Database + db sql.StateDatabase atxs *atxsdata.Data proposals *store.Store msh meshProvider @@ -84,7 +84,7 @@ func WithHareOutputChan(ch <-chan hare3.ConsensusOutput) GeneratorOpt { // NewGenerator creates new block generator. func NewGenerator( - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, proposals *store.Store, exec executor, diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 4145f3ff29b..211a9855121 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -73,7 +74,7 @@ func createTestGenerator(t *testing.T) *testGenerator { } tg.mockMesh.EXPECT().ProcessedLayer().Return(types.LayerID(1)).AnyTimes() lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() proposals := store.New() tg.Generator = NewGenerator( @@ -266,7 +267,7 @@ func Test_StopBeforeStart(t *testing.T) { func genData( t *testing.T, - db *sql.Database, + db sql.StateDatabase, data *atxsdata.Data, store *store.Store, lid types.LayerID, diff --git a/blocks/handler.go b/blocks/handler.go index 1d6d7b12011..4ff6f9b30ec 100644 --- a/blocks/handler.go +++ b/blocks/handler.go @@ -28,7 +28,7 @@ type Handler struct { logger *zap.Logger fetcher system.Fetcher - db *sql.Database + db sql.StateDatabase tortoise tortoiseProvider mesh meshProvider } @@ -44,7 +44,13 @@ func WithLogger(logger *zap.Logger) Opt { } // NewHandler creates new Handler. -func NewHandler(f system.Fetcher, db *sql.Database, tortoise tortoiseProvider, m meshProvider, opts ...Opt) *Handler { +func NewHandler( + f system.Fetcher, + db sql.StateDatabase, + tortoise tortoiseProvider, + m meshProvider, + opts ...Opt, +) *Handler { h := &Handler{ logger: zap.NewNop(), fetcher: f, diff --git a/blocks/handler_test.go b/blocks/handler_test.go index 8c44db02726..7ccb2078406 100644 --- a/blocks/handler_test.go +++ b/blocks/handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -34,7 +34,7 @@ func createTestHandler(t *testing.T) *testHandler { } th.Handler = NewHandler( th.mockFetcher, - sql.InMemory(), + statesql.InMemory(), th.mockTortoise, th.mockMesh, WithLogger(zaptest.NewLogger(t)), diff --git a/blocks/utils.go b/blocks/utils.go index e6ec1faf507..4dfb25d1a17 100644 --- a/blocks/utils.go +++ b/blocks/utils.go @@ -50,7 +50,7 @@ type proposalMetadata struct { func getProposalMetadata( ctx context.Context, logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, cfg Config, lid types.LayerID, @@ -232,7 +232,7 @@ func toUint64Slice(b []byte) []uint64 { func rewardInfoAndHeight( cfg Config, - db *sql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, props []*types.Proposal, ) (uint64, []types.AnyReward, error) { diff --git a/blocks/utils_test.go b/blocks/utils_test.go index 11776145b5b..d7f9b1f3ca4 100644 --- a/blocks/utils_test.go +++ b/blocks/utils_test.go @@ -14,9 +14,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -159,7 +159,7 @@ func Test_getBlockTXs_expected_order(t *testing.T) { func Test_getProposalMetadata(t *testing.T) { lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() cfg := Config{OptFilterThreshold: 70} lid := types.LayerID(111) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 97a77d54680..8348289c8d8 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -29,6 +29,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const recoveryDir = "recovery" @@ -120,7 +121,7 @@ func Recover( return nil, errors.New("restore layer not set") } logger.Info("recovering from checkpoint", zap.String("url", cfg.Uri), zap.Stringer("restore", cfg.Restore)) - db, err := sql.Open("file:" + cfg.DbPath()) + db, err := statesql.Open("file:" + cfg.DbPath()) if err != nil { return nil, fmt.Errorf("open old database: %w", err) } @@ -131,7 +132,7 @@ func Recover( } defer localDB.Close() logger.Info("clearing atx and malfeasance sync metadata from local database") - if err := localDB.WithTx(ctx, func(tx *sql.Tx) error { + if err := localDB.WithTx(ctx, func(tx sql.Transaction) error { if err := atxsync.Clear(tx); err != nil { return err } @@ -153,8 +154,8 @@ func Recover( func RecoverWithDb( ctx context.Context, logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, ) (*PreservedData, error) { @@ -187,8 +188,8 @@ type recoveryData struct { func RecoverFromLocalFile( ctx context.Context, logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, file string, @@ -257,8 +258,7 @@ func RecoverFromLocalFile( } logger.Info("backed up old database", log.ZContext(ctx), zap.String("backup dir", backupDir)) - var newDB *sql.Database - newDB, err = sql.Open("file:" + cfg.DbPath()) + newDB, err := statesql.Open("file:" + cfg.DbPath()) if err != nil { return nil, fmt.Errorf("creating new DB: %w", err) } @@ -268,7 +268,7 @@ func RecoverFromLocalFile( zap.Int("num accounts", len(data.accounts)), zap.Int("num atxs", len(data.atxs)), ) - if err = newDB.WithTx(ctx, func(tx *sql.Tx) error { + if err = newDB.WithTx(ctx, func(tx sql.Transaction) error { for _, acct := range data.accounts { if err = accounts.Update(tx, acct); err != nil { return fmt.Errorf("restore account snapshot: %w", err) @@ -388,8 +388,8 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove func collectOwnAtxDeps( logger *zap.Logger, - db *sql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, nodeID types.NodeID, goldenATX types.ATXID, data *recoveryData, @@ -450,7 +450,7 @@ func collectOwnAtxDeps( } func collectDeps( - db *sql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, ) (map[types.ATXID]*AtxDep, map[types.PoetProofRef]*types.PoetProofMessage, error) { @@ -466,7 +466,7 @@ func collectDeps( } func collect( - db *sql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, deps map[types.ATXID]*AtxDep, @@ -531,7 +531,7 @@ func collect( } func poetProofs( - db *sql.Database, + db sql.StateDatabase, atxIds map[types.ATXID]*AtxDep, ) (map[types.PoetProofRef]*types.PoetProofMessage, error) { proofs := make(map[types.PoetProofRef]*types.PoetProofMessage, len(atxIds)) diff --git a/checkpoint/recovery_collecting_deps_test.go b/checkpoint/recovery_collecting_deps_test.go index 1df6ca0a9b7..4bbc892c105 100644 --- a/checkpoint/recovery_collecting_deps_test.go +++ b/checkpoint/recovery_collecting_deps_test.go @@ -10,15 +10,15 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCollectingDeps(t *testing.T) { golden := types.RandomATXID() t.Run("collect marriage ATXs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() marriageATX := &wire.ActivationTxV1{ InnerActivationTxV1: wire.InnerActivationTxV1{ diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index b21f9e9d065..0fa8e530d35 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -35,6 +35,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -77,7 +78,7 @@ func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun } } -func verifyDbContent(tb testing.TB, db *sql.Database) { +func verifyDbContent(tb testing.TB, db sql.StateDatabase) { tb.Helper() var expected types.Checkpoint require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected)) @@ -168,7 +169,7 @@ func TestRecover(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() data, err := checkpoint.RecoverWithDb(context.Background(), zaptest.NewLogger(t), db, localDB, fs, cfg) if tc.expErr != nil { @@ -177,7 +178,7 @@ func TestRecover(t *testing.T) { } require.NoError(t, err) require.Nil(t, data) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) defer newDB.Close() @@ -209,7 +210,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() types.SetEffectiveGenesis(0) require.NoError(t, recovery.SetCheckpoint(db, types.LayerID(recoverLayer))) @@ -241,7 +242,7 @@ func TestRecover_RestoreLayerCannotBeZero(t *testing.T) { func validateAndPreserveData( tb testing.TB, - db *sql.Database, + db sql.StateDatabase, deps []*checkpoint.AtxDep, ) { tb.Helper() @@ -496,7 +497,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -541,7 +542,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -581,7 +582,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -653,7 +654,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -691,7 +692,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -748,7 +749,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -781,7 +782,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -824,7 +825,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -857,7 +858,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -883,7 +884,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -920,7 +921,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) require.NoError(t, atxs.Add(oldDB, atx, types.AtxBlob{})) @@ -930,7 +931,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 2b72e236b03..29b32ec156e 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -28,7 +28,7 @@ const ( func checkpointDB( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, snapshot types.LayerID, numAtxs int, ) (*types.Checkpoint, error) { @@ -166,7 +166,7 @@ func checkpointDB( func Generate( ctx context.Context, fs afero.Fs, - db *sql.Database, + db sql.StateDatabase, dataDir string, snapshot types.LayerID, numAtxs int, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index f7009c24ec2..e01d9efe782 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -253,7 +254,7 @@ func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { } } -func createMesh(t testing.TB, db *sql.Database, miners []miner, accts []*types.Account) { +func createMesh(t testing.TB, db sql.StateDatabase, miners []miner, accts []*types.Account) { t.Helper() for _, miner := range miners { for _, atx := range miner.atxs { @@ -300,7 +301,7 @@ func TestRunner_Generate(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, tc.miners, tc.accts) @@ -336,7 +337,7 @@ func TestRunner_Generate_Error(t *testing.T) { t.Parallel() t.Run("no commitment atx", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) atx := newAtx(types.ATXID{13}, types.EmptyATXID, nil, 2, 1, 11, types.RandomNodeID().Bytes()) @@ -350,7 +351,7 @@ func TestRunner_Generate_Error(t *testing.T) { }) t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, nil, allAccounts) @@ -363,7 +364,7 @@ func TestRunner_Generate_Error(t *testing.T) { }) t.Run("no accounts", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, allMiners, nil) diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 1046916b024..1415972e6b5 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -9,8 +9,8 @@ import ( "strconv" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func main() { @@ -30,7 +30,7 @@ Example: if len(dbpath) == 0 { must(errors.New("dbpath is empty"), "dbpath is empty\n") } - db, err := sql.Open("file:" + dbpath) + db, err := statesql.Open("file:" + dbpath) must(err, "can't open db at dbpath=%v. err=%s\n", dbpath, err) ids, err := atxs.GetIDsByEpoch(context.Background(), db, types.EpochID(publish)) diff --git a/cmd/bootstrapper/generator_test.go b/cmd/bootstrapper/generator_test.go index d84cf2a2888..e76451d50f9 100644 --- a/cmd/bootstrapper/generator_test.go +++ b/cmd/bootstrapper/generator_test.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -103,7 +104,7 @@ func verifyUpdate(tb testing.TB, data []byte, epoch types.EpochID, expBeacon str func TestGenerator_Generate(t *testing.T) { t.Parallel() targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) @@ -169,7 +170,7 @@ func TestGenerator_Generate(t *testing.T) { func TestGenerator_CheckAPI(t *testing.T) { t.Parallel() targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, lg.Zap())) diff --git a/cmd/bootstrapper/server_test.go b/cmd/bootstrapper/server_test.go index 910d448c79e..14230486029 100644 --- a/cmd/bootstrapper/server_test.go +++ b/cmd/bootstrapper/server_test.go @@ -20,7 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/log/logtest" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) //go:embed checkpointdata.json @@ -57,7 +57,7 @@ func updateCheckpoint(t *testing.T, ctx context.Context, data string) { } func TestServer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) diff --git a/cmd/merge-nodes/internal/errors.go b/cmd/merge-nodes/internal/errors.go index 53af669de08..b3e3449b0b5 100644 --- a/cmd/merge-nodes/internal/errors.go +++ b/cmd/merge-nodes/internal/errors.go @@ -4,7 +4,4 @@ import ( "errors" ) -var ( - ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") - ErrInvalidSchema = errors.New("database has an invalid schema version") -) +var ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index 43b9929fdc6..1ebc90b00ef 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -26,7 +26,7 @@ const ( func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { // Open the target database - var dstDB *localsql.Database + var dstDB sql.LocalDatabase var err error dstDB, err = openDB(dbLog, to) switch { @@ -150,7 +150,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { } dbLog.Info("merging databases", zap.String("from", from), zap.String("to", to)) - err = dstDB.WithTx(ctx, func(tx *sql.Tx) error { + err = dstDB.WithTx(ctx, func(tx sql.Transaction) error { enc := func(stmt *sql.Statement) { stmt.BindText(1, filepath.Join(from, localDbFile)) } @@ -183,38 +183,20 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { return nil } -func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) { +func openDB(dbLog *zap.Logger, path string) (sql.LocalDatabase, error) { dbPath := filepath.Join(path, localDbFile) if _, err := os.Stat(dbPath); err != nil { - return nil, fmt.Errorf("open database %s: %w", dbPath, err) - } - - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, fmt.Errorf("get local migrations: %w", err) + return nil, fmt.Errorf("stat source database %s: %w", dbPath, err) } db, err := localsql.Open("file:"+dbPath, sql.WithLogger(dbLog), - sql.WithMigrations(nil), // do not migrate database when opening + sql.WithMigrationsDisabled(), ) if err != nil { return nil, fmt.Errorf("open source database %s: %w", dbPath, err) } - // check if the source database has the right schema - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - if err != nil { - return nil, fmt.Errorf("get source database schema for %s: %w", dbPath, err) - } - if version != len(migrations) { - db.Close() - return nil, ErrInvalidSchema - } return db, nil } diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index f2fdc111359..a323a16b9e9 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -24,20 +24,26 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) -func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { - tmpDst := t.TempDir() - - migrations, err := sql.LocalMigrations() +func oldSchema(t *testing.T) *sql.Schema { + schema, err := localsql.Schema() require.NoError(t, err) + schema.Migrations = schema.Migrations[:2] + return schema +} + +func Test_MergeDBs_InvalidTargetSchema(t *testing.T) { + tmpDst := t.TempDir() db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), "", tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "target database") } @@ -79,25 +85,24 @@ func Test_MergeDBs_InvalidSourcePath(t *testing.T) { require.ErrorIs(t, err, fs.ErrNotExist) } -func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { +func Test_MergeDBs_InvalidSourceSchema(t *testing.T) { tmpDst := t.TempDir() - migrations, err := sql.LocalMigrations() - require.NoError(t, err) - db, err := localsql.Open("file:" + filepath.Join(tmpDst, localDbFile)) require.NoError(t, err) require.NoError(t, db.Close()) tmpSrc := t.TempDir() db, err = localsql.Open("file:"+filepath.Join(tmpSrc, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), tmpSrc, tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "source database") } diff --git a/cmd/node/main.go b/cmd/node/main.go index 99ef2265e49..0fb230794db 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -3,7 +3,6 @@ package main import ( - "fmt" _ "net/http/pprof" "os" @@ -24,7 +23,8 @@ func main() { // run the app cmd.Branch = branch cmd.NoMainNet = noMainNet == "true" if err := node.GetCommand().Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) + // Do not print error as cmd.SilenceErrors is false + // and the error was already printed os.Exit(1) } } diff --git a/config/config.go b/config/config.go index 230d9a9a8a5..78ef09c66a2 100644 --- a/config/config.go +++ b/config/config.go @@ -117,6 +117,7 @@ type BaseConfig struct { DatabaseSkipMigrations []int `mapstructure:"db-skip-migrations"` DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` + DatabaseSchemaAllowDrift bool `mapstructure:"db-allow-schema-drift"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` diff --git a/datastore/mocks/mocks.go b/datastore/mocks/mocks.go deleted file mode 100644 index bbb150d389f..00000000000 --- a/datastore/mocks/mocks.go +++ /dev/null @@ -1,156 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./store.go -// -// Generated by this command: -// -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - sql "github.com/spacemeshos/go-spacemesh/sql" - gomock "go.uber.org/mock/gomock" -) - -// MockExecutor is a mock of Executor interface. -type MockExecutor struct { - ctrl *gomock.Controller - recorder *MockExecutorMockRecorder -} - -// MockExecutorMockRecorder is the mock recorder for MockExecutor. -type MockExecutorMockRecorder struct { - mock *MockExecutor -} - -// NewMockExecutor creates a new mock instance. -func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { - mock := &MockExecutor{ctrl: ctrl} - mock.recorder = &MockExecutorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { - return m.recorder -} - -// Exec mocks base method. -func (m *MockExecutor) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockExecutorMockRecorder) Exec(arg0, arg1, arg2 any) *MockExecutorExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockExecutor)(nil).Exec), arg0, arg1, arg2) - return &MockExecutorExecCall{Call: call} -} - -// MockExecutorExecCall wrap *gomock.Call -type MockExecutorExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorExecCall) Return(arg0 int, arg1 error) *MockExecutorExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockExecutor) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockExecutorMockRecorder) QueryCache() *MockExecutorQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockExecutor)(nil).QueryCache)) - return &MockExecutorQueryCacheCall{Call: call} -} - -// MockExecutorQueryCacheCall wrap *gomock.Call -type MockExecutorQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorQueryCacheCall) Return(arg0 sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorQueryCacheCall) Do(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockExecutor) WithTx(arg0 context.Context, arg1 func(*sql.Tx) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockExecutorMockRecorder) WithTx(arg0, arg1 any) *MockExecutorWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockExecutor)(nil).WithTx), arg0, arg1) - return &MockExecutorWithTxCall{Call: call} -} - -// MockExecutorWithTxCall wrap *gomock.Call -type MockExecutorWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorWithTxCall) Return(arg0 error) *MockExecutorWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorWithTxCall) Do(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorWithTxCall) DoAndReturn(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/datastore/store.go b/datastore/store.go index 1d5180bfc5a..014f23a7e70 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -30,18 +30,9 @@ type VrfNonceKey struct { Epoch types.EpochID } -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go - -type Executor interface { - sql.Executor - WithTx(context.Context, func(*sql.Tx) error) error - QueryCache() sql.QueryCache -} - // CachedDB is simply a database injected with cache. type CachedDB struct { - Executor - sql.QueryCache + sql.Database logger *zap.Logger // cache is optional in tests. It MUST be set for the 'App' @@ -91,7 +82,7 @@ func WithConsensusCache(c *atxsdata.Data) Opt { } // NewCachedDB create an instance of a CachedDB. -func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { +func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { o := cacheOpts{cfg: DefaultConfig()} for _, opt := range opts { opt(&o) @@ -114,8 +105,7 @@ func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { } return &CachedDB{ - Executor: db, - QueryCache: db.QueryCache(), + Database: db, logger: lg, atxsdata: o.atxsdata, atxCache: atxHdrCache, @@ -169,7 +159,7 @@ func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof return proof, nil } - proof, err := identities.GetMalfeasanceProof(db.Executor, id) + proof, err := identities.GetMalfeasanceProof(db.Database, id) if err != nil && err != sql.ErrNotFound { return nil, err } diff --git a/datastore/store_test.go b/datastore/store_test.go index c89f01dd097..80c73abafb2 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -26,6 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -54,7 +55,7 @@ func getBytes( } func TestMalfeasanceProof_Honest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) @@ -115,7 +116,7 @@ func TestMalfeasanceProof_Honest(t *testing.T) { } func TestMalfeasanceProof_Dishonest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) @@ -143,7 +144,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) { } func TestBlobStore_GetATXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -186,7 +187,7 @@ func TestBlobStore_GetATXBlob(t *testing.T) { } func TestBlobStore_GetBallotBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -221,7 +222,7 @@ func TestBlobStore_GetBallotBlob(t *testing.T) { } func TestBlobStore_GetBlockBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -256,7 +257,7 @@ func TestBlobStore_GetBlockBlob(t *testing.T) { } func TestBlobStore_GetPoetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -285,7 +286,7 @@ func TestBlobStore_GetPoetBlob(t *testing.T) { } func TestBlobStore_GetProposalBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := store.New() bs := datastore.NewBlobStore(db, proposals) ctx := context.Background() @@ -323,7 +324,7 @@ func TestBlobStore_GetProposalBlob(t *testing.T) { } func TestBlobStore_GetTXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -351,7 +352,7 @@ func TestBlobStore_GetTXBlob(t *testing.T) { } func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -385,7 +386,7 @@ func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { } func TestBlobStore_GetActiveSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -409,7 +410,7 @@ func TestBlobStore_GetActiveSet(t *testing.T) { } func Test_MarkingMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() store := atxsdata.New() id := types.RandomNodeID() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t), datastore.WithConsensusCache(store)) diff --git a/fetch/fetch_test.go b/fetch/fetch_test.go index 00599bd5d4d..b93327e7da2 100644 --- a/fetch/fetch_test.go +++ b/fetch/fetch_test.go @@ -21,7 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testFetch struct { @@ -80,7 +80,7 @@ func createFetch(tb testing.TB) *testFetch { } lg := zaptest.NewLogger(tb) - tf.Fetch = NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), nil, + tf.Fetch = NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), nil, WithContext(context.TODO()), WithConfig(cfg), WithLogger(lg), @@ -117,7 +117,7 @@ func badReceiver(context.Context, types.Hash32, p2p.Peer, []byte) error { func TestFetch_Start(t *testing.T) { lg := zaptest.NewLogger(t) - f := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), nil, + f := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), nil, WithContext(context.TODO()), WithConfig(DefaultConfig()), WithLogger(lg), @@ -384,7 +384,7 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { }) defer eg.Wait() - fetcher := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg), store.New(), h, + fetcher := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg), store.New(), h, WithContext(ctx), WithConfig(cfg), WithLogger(lg), diff --git a/fetch/handler_test.go b/fetch/handler_test.go index d42900a28a5..45905c1dc23 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -24,17 +24,18 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testHandler struct { *handler - db *sql.Database + db sql.StateDatabase cdb *datastore.CachedDB } func createTestHandler(t testing.TB, opts ...sql.Opt) *testHandler { lg := zaptest.NewLogger(t) - db := sql.InMemory(opts...) + db := statesql.InMemory(opts...) cdb := datastore.NewCachedDB(db, lg) return &testHandler{ handler: newHandler(cdb, datastore.NewBlobStore(cdb, store.New()), lg), @@ -349,6 +350,8 @@ func testHandleEpochInfoReqWithQueryCache( getInfo func(th *testHandler, req []byte, ed *EpochData), ) { th := createTestHandler(t, sql.WithQueryCache(true)) + require.True(t, th.cdb.Database.IsCached()) + require.True(t, sql.IsCached(th.cdb)) epoch := types.EpochID(11) var expected EpochData @@ -359,8 +362,7 @@ func testHandleEpochInfoReqWithQueryCache( expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } - qc := th.cdb.Executor.(interface{ QueryCount() int }) - require.Equal(t, 20, qc.QueryCount()) + require.Equal(t, 20, th.cdb.Database.QueryCount()) epochBytes, err := codec.Encode(epoch) require.NoError(t, err) @@ -368,7 +370,7 @@ func testHandleEpochInfoReqWithQueryCache( for i := 0; i < 3; i++ { getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) - require.Equal(t, 21, qc.QueryCount()) + require.Equal(t, 21, th.cdb.Database.QueryCount(), "query count @ i = %d", i) } // Add another ATX which should be appended to the cached slice @@ -376,14 +378,14 @@ func testHandleEpochInfoReqWithQueryCache( require.NoError(t, atxs.Add(th.cdb, vatx, blob)) atxs.AtxAdded(th.cdb, vatx) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) // The query count is not incremented as the slice is still // cached and the new atx is just appended to it, even though // the response is re-serialized. - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) } func TestHandleEpochInfoReqWithQueryCache(t *testing.T) { diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index a4084653b19..31205448458 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -26,7 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -999,7 +999,7 @@ func Test_GetAtxsLimiting(t *testing.T) { cfg.QueueSize = 1000 cfg.GetAtxsConcurrency = getAtxConcurrency - cdb := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + cdb := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) client := server.New(wrapHost(mesh.Hosts()[0]), hashProtocol, nil) host, err := p2p.Upgrade(mesh.Hosts()[0]) require.NoError(t, err) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index a2de466a6d4..25fdad3cdac 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -37,13 +38,13 @@ type blobKey struct { type testP2PFetch struct { t *testing.T - clientDB *sql.Database + clientDB sql.StateDatabase // client proposals clientPDB *store.Store clientCDB *datastore.CachedDB clientFetch *Fetch serverID peer.ID - serverDB *sql.Database + serverDB sql.StateDatabase // server proposals serverPDB *store.Store serverCDB *datastore.CachedDB @@ -104,8 +105,8 @@ func createP2PFetch( if sqlCache { sqlOpts = []sql.Opt{sql.WithQueryCache(true)} } - clientDB := sql.InMemory(sqlOpts...) - serverDB := sql.InMemory(sqlOpts...) + clientDB := statesql.InMemory(sqlOpts...) + serverDB := statesql.InMemory(sqlOpts...) tpf := &testP2PFetch{ t: t, clientDB: clientDB, diff --git a/genvm/core/context_test.go b/genvm/core/context_test.go index e13a8287ff5..92d8cb627db 100644 --- a/genvm/core/context_test.go +++ b/genvm/core/context_test.go @@ -10,23 +10,23 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/core" "github.com/spacemeshos/go-spacemesh/genvm/core/mocks" "github.com/spacemeshos/go-spacemesh/genvm/registry" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestTransfer(t *testing.T) { t.Run("NoBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} require.ErrorIs(t, ctx.Transfer(core.Address{}, 100), core.ErrNoBalance) }) t.Run("MaxSpend", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 100 require.NoError(t, ctx.Transfer(core.Address{1}, 50)) require.ErrorIs(t, ctx.Transfer(core.Address{2}, 100), core.ErrMaxSpend) }) t.Run("ReducesBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 for _, amount := range []uint64{50, 100, 200, 255} { @@ -67,7 +67,7 @@ func TestConsume(t *testing.T) { func TestApply(t *testing.T) { t.Run("UpdatesNonce", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Address = core.Address{1} ctx.Header.Nonce = 10 @@ -80,7 +80,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.PrincipalAccount.NextNonce, account.NextNonce) }) t.Run("ConsumeMaxGas", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Balance = 1000 @@ -97,7 +97,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.Fee(), ctx.Header.MaxGas*ctx.Header.GasPrice) }) t.Run("PreserveTransferOrder", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Address = core.Address{1} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 @@ -129,7 +129,7 @@ func TestRelay(t *testing.T) { remote = core.Address{'r', 'e', 'm'} ) t.Run("not spawned", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: cache} call := func(remote core.Host) error { require.Fail(t, "not expected to be called") @@ -138,7 +138,7 @@ func TestRelay(t *testing.T) { require.ErrorIs(t, ctx.Relay(template, remote, call), core.ErrNotSpawned) }) t.Run("mismatched template", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) require.NoError(t, cache.Update(core.Account{ Address: remote, TemplateAddress: &core.Address{'m', 'i', 's'}, @@ -166,7 +166,7 @@ func TestRelay(t *testing.T) { reg := registry.New() reg.Register(template, handler) - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) receiver2 := core.Address{'f'} const ( total = 1000 diff --git a/genvm/core/staged_cache_test.go b/genvm/core/staged_cache_test.go index 8a6b29519ea..a0189309555 100644 --- a/genvm/core/staged_cache_test.go +++ b/genvm/core/staged_cache_test.go @@ -6,11 +6,11 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCacheGetCopies(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) address := core.Address{1} account, err := ss.Get(address) @@ -23,7 +23,7 @@ func TestCacheGetCopies(t *testing.T) { } func TestCacheUpdatePreserveOrder(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) order := []core.Address{{3}, {1}, {2}} for _, address := range order { diff --git a/genvm/templates/vault/vault_test.go b/genvm/templates/vault/vault_test.go index 28fef09d14d..e1f461fd9c8 100644 --- a/genvm/templates/vault/vault_test.go +++ b/genvm/templates/vault/vault_test.go @@ -9,7 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestVested(t *testing.T) { @@ -392,7 +392,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(tc.lid), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, @@ -419,7 +419,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(2), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, diff --git a/genvm/vm.go b/genvm/vm.go index aaf18917cdc..61d82c182de 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -58,7 +58,7 @@ func WithConfig(cfg Config) Opt { } // New returns VM instance. -func New(db *sql.Database, opts ...Opt) *VM { +func New(db sql.StateDatabase, opts ...Opt) *VM { vm := &VM{ logger: log.NewNop(), db: db, @@ -78,7 +78,7 @@ func New(db *sql.Database, opts ...Opt) *VM { // VM handles modifications to the account state. type VM struct { logger log.Log - db *sql.Database + db sql.StateDatabase cfg Config registry *registry.Registry } diff --git a/genvm/vm_test.go b/genvm/vm_test.go index 24943b2e0aa..2a81553df78 100644 --- a/genvm/vm_test.go +++ b/genvm/vm_test.go @@ -34,9 +34,9 @@ import ( "github.com/spacemeshos/go-spacemesh/hash" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func testContext(lid types.LayerID) ApplyContext { @@ -48,7 +48,7 @@ func testContext(lid types.LayerID) ApplyContext { func newTester(tb testing.TB) *tester { return &tester{ TB: tb, - VM: New(sql.InMemory(), + VM: New(statesql.InMemory(), WithLogger(logtest.New(tb)), WithConfig(Config{GasLimit: math.MaxUint64}), ), @@ -279,7 +279,7 @@ type tester struct { } func (t *tester) persistent() *tester { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) t.VM = New(db, WithLogger(logtest.New(t)), @@ -2572,7 +2572,7 @@ func TestVestingData(t *testing.T) { spendAccountNonce := t2.nonces[0] spendAmount := uint64(1_000_000) - vm := New(sql.InMemory(), WithLogger(logtest.New(t))) + vm := New(statesql.InMemory(), WithLogger(logtest.New(t))) require.NoError(t, vm.ApplyGenesis( []core.Account{ { diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index ff1e13bf22e..94addaf5098 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -46,14 +47,14 @@ func TestMain(m *testing.M) { type testOracle struct { *Oracle tb testing.TB - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mBeacon *mocks.MockBeaconGetter mVerifier *MockvrfVerifier } func defaultOracle(tb testing.TB) *testOracle { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(tb) diff --git a/hare3/hare.go b/hare3/hare.go index fcd049a926c..4e8a952da12 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -163,7 +163,7 @@ type nodeclock interface { func New( nodeclock nodeclock, pubsub pubsub.PublishSubsciber, - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals *store.Store, verifier *signing.EdVerifier, @@ -225,7 +225,7 @@ type Hare struct { // dependencies nodeclock nodeclock pubsub pubsub.PublishSubsciber - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store verifier *signing.EdVerifier diff --git a/hare3/hare_test.go b/hare3/hare_test.go index dd8139ab12c..c9ea3634c26 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -32,6 +32,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -117,7 +118,7 @@ type node struct { vrfsigner *signing.VRFSigner atx *types.ActivationTx oracle *eligibility.Oracle - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store @@ -149,7 +150,7 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { } func (n *node) withDb() *node { - n.db = sql.InMemory() + n.db = statesql.InMemory() n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -892,7 +893,7 @@ func TestProposals(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() proposals := store.New() hare := New( diff --git a/hare3/malfeasance_test.go b/hare3/malfeasance_test.go index 0f2a0f14910..c8b1ee37b3e 100644 --- a/hare3/malfeasance_test.go +++ b/hare3/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 0476dc14644..a2c8a61c2c5 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -168,7 +168,7 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip h.countInvalidProof(&p.MalfeasanceProof) return types.EmptyNodeID, err } - if err := h.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { malicious, err := identities.IsMalicious(dbtx, nodeID) if err != nil { return fmt.Errorf("check known malicious: %w", err) diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 2384a4e8413..e902712687d 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -21,18 +21,19 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *Handler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase mockTrt *Mocktortoise } func newHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/mesh/executor_test.go b/mesh/executor_test.go index a4937f19a2d..2e5732ed081 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -33,7 +34,7 @@ func TestMain(m *testing.M) { type testExecutor struct { tb testing.TB exec *mesh.Executor - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mcs *mocks.MockconservativeState mvm *mocks.MockvmState @@ -43,7 +44,7 @@ func newTestExecutor(t *testing.T) *testExecutor { ctrl := gomock.NewController(t) te := &testExecutor{ tb: t, - db: sql.InMemory(), + db: statesql.InMemory(), atxsdata: atxsdata.New(), mvm: mocks.NewMockvmState(ctrl), mcs: mocks.NewMockconservativeState(ctrl), diff --git a/mesh/malfeasance_test.go b/mesh/malfeasance_test.go index 8e4c607bf89..6d473c517a5 100644 --- a/mesh/malfeasance_test.go +++ b/mesh/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/mesh/mesh.go b/mesh/mesh.go index eee3391bb12..94b8b2dc8c9 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -35,7 +35,7 @@ import ( // Mesh is the logic layer above our mesh.DB database. type Mesh struct { logger log.Log - cdb *sql.Database + cdb sql.StateDatabase atxsdata *atxsdata.Data clock layerClock @@ -58,7 +58,7 @@ type Mesh struct { // NewMesh creates a new instant of a mesh. func NewMesh( - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, c layerClock, trtl system.Tortoise, @@ -92,7 +92,7 @@ func NewMesh( } genesis := types.GetEffectiveGenesis() - if err = db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err = db.WithTx(context.Background(), func(dbtx sql.Transaction) error { if err = layers.SetProcessed(dbtx, genesis); err != nil { return fmt.Errorf("mesh init: %w", err) } @@ -370,7 +370,7 @@ func (msh *Mesh) applyResults(ctx context.Context, results []result.Layer) error return fmt.Errorf("execute block %v/%v: %w", layer.Layer, target, err) } } - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if err := layers.SetApplied(dbtx, layer.Layer, target); err != nil { return fmt.Errorf("set applied for %v/%v: %w", layer.Layer, target, err) } @@ -420,7 +420,7 @@ func (msh *Mesh) saveHareOutput(ctx context.Context, lid types.LayerID, bid type certs []certificates.CertValidity err error ) - if err = msh.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err = msh.cdb.WithTx(ctx, func(tx sql.Transaction) error { // check if a certificate has been generated or sync'ed. // - node generated the certificate when it collected enough certify messages // - hare outputs are processed in layer order. i.e. when hare fails for a previous layer N, @@ -542,7 +542,7 @@ func (msh *Mesh) AddBallot( var proof *wire.MalfeasanceProof // ballots.LayerBallotByNodeID and ballots.Add should be atomic // otherwise concurrent ballots.Add from the same smesher may not be noticed - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if !malicious { prev, err := ballots.LayerBallotByNodeID(dbtx, ballot.Layer, ballot.SmesherID) if err != nil && !errors.Is(err, sql.ErrNotFound) { diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index c41bc3c52bc..795564a4aaa 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -36,7 +37,7 @@ const ( type testMesh struct { *Mesh - db *sql.Database + db sql.StateDatabase // it is used in malfeasence.Validate, which is called in the tests cdb *datastore.CachedDB atxsdata *atxsdata.Data @@ -50,7 +51,7 @@ func createTestMesh(t *testing.T) *testMesh { t.Helper() types.SetLayersPerEpoch(3) lg := logtest.New(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(t) tm := &testMesh{ diff --git a/miner/active_set_generator_test.go b/miner/active_set_generator_test.go index 660952d03c0..c1e381c2ce9 100644 --- a/miner/active_set_generator_test.go +++ b/miner/active_set_generator_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/activeset" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type expect struct { @@ -65,7 +66,7 @@ func unixPtr(sec, nsec int64) *time.Time { func newTesterActiveSetGenerator(tb testing.TB, cfg config) *testerActiveSetGenerator { var ( - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ctrl = gomock.NewController(tb) @@ -97,8 +98,8 @@ type testerActiveSetGenerator struct { tb testing.TB gen *activeSetGenerator - db *sql.Database - localdb *localsql.Database + db sql.StateDatabase + localdb sql.LocalDatabase atxsdata *atxsdata.Data ctrl *gomock.Controller clock *mocks.MocklayerClock diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 6df1c625924..1e9925bdf71 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -25,7 +25,6 @@ import ( pmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/proposals" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" @@ -35,6 +34,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -749,7 +749,7 @@ func TestBuild(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) @@ -905,7 +905,7 @@ func TestStartStop(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) diff --git a/node/node.go b/node/node.go index 9d7e8ee9b19..ca68c12af4a 100644 --- a/node/node.go +++ b/node/node.go @@ -76,7 +76,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" dbmetrics "github.com/spacemeshos/go-spacemesh/sql/metrics" - "github.com/spacemeshos/go-spacemesh/sql/migrations" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/blockssync" @@ -379,10 +379,10 @@ type App struct { fileLock *flock.Flock signers []*signing.EdSigner Config *config.Config - db *sql.Database + db sql.StateDatabase cachedDB *datastore.CachedDB dbMetrics *dbmetrics.DBMetricsCollector - localDB *localsql.Database + localDB sql.LocalDatabase grpcPublicServer *grpcserver.Server grpcPrivateServer *grpcserver.Server grpcPostServer *grpcserver.Server @@ -1897,18 +1897,20 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { return fmt.Errorf("failed to create %s: %w", dbPath, err) } dbLog := app.addLogger(StateDbLogger, lg) - m21 := migrations.New0021Migration(dbLog.Zap(), 1_000_000) - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() if err != nil { - return fmt.Errorf("failed to load migrations: %w", err) + return fmt.Errorf("error loading db schema: %w", err) + } + if len(app.Config.DatabaseSkipMigrations) > 0 { + schema.SkipMigrations(app.Config.DatabaseSkipMigrations...) } dbopts := []sql.Opt{ sql.WithLogger(dbLog.Zap()), - sql.WithMigrations(migrations), - sql.WithMigration(m21), + sql.WithDatabaseSchema(schema), sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindEpochATXs: app.Config.DatabaseQueryCacheSizes.EpochATXs, @@ -1916,10 +1918,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { activesets.CacheKindActiveSetBlob: app.Config.DatabaseQueryCacheSizes.ActiveSetBlob, }), } - if len(app.Config.DatabaseSkipMigrations) > 0 { - dbopts = append(dbopts, sql.WithSkipMigrations(app.Config.DatabaseSkipMigrations...)) - } - sqlDB, err := sql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) + sqlDB, err := statesql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) if err != nil { return fmt.Errorf("open sqlite db %w", err) } @@ -1961,14 +1960,10 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { app.log.With().Info("malicious ATX check completed", log.Duration("duration", time.Since(start))) } - migrations, err = sql.LocalMigrations() - if err != nil { - return fmt.Errorf("load local migrations: %w", err) - } localDB, err := localsql.Open("file:"+filepath.Join(dbPath, localDbFile), sql.WithLogger(dbLog.Zap()), - sql.WithMigrations(migrations), sql.WithConnections(app.Config.DatabaseConnections), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), ) if err != nil { return fmt.Errorf("open sqlite db %w", err) diff --git a/node/node_version_check_test.go b/node/node_version_check_test.go index 8e8d8a6d766..affc503f86e 100644 --- a/node/node_version_check_test.go +++ b/node/node_version_check_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/config" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestUpgradeToV15(t *testing.T) { @@ -37,10 +38,15 @@ func TestUpgradeToV15(t *testing.T) { uri := path.Join(cfg.DataDir(), localDbFile) - migrations, err := sql.LocalMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - db, err := sql.Open(uri, sql.WithMigrations(migrations[:2])) + schema.Migrations = schema.Migrations[:2] + + db, err := statesql.Open(uri, + sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift()) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/proposals/handler.go b/proposals/handler.go index 8109d8d9ebe..adcb14f27b0 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -49,7 +49,7 @@ type Handler struct { logger log.Log cfg Config - db *sql.Database + db sql.StateDatabase atxsdata *atxsdata.Data activeSets *lru.Cache[types.Hash32, uint64] edVerifier *signing.EdVerifier @@ -108,7 +108,7 @@ func WithConfig(cfg Config) Opt { // NewHandler creates new Handler. func NewHandler( - db *sql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals proposalsConsumer, edVerifier *signing.EdVerifier, diff --git a/proposals/handler_test.go b/proposals/handler_test.go index d2d52e62930..60ed498a5ae 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -93,7 +94,7 @@ func fullMockSet(tb testing.TB) *mockSet { func createTestHandler(t *testing.T) *testHandler { types.SetLayersPerEpoch(layersPerEpoch) ms := fullMockSet(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ms.md.EXPECT().GetBallot(gomock.Any()).AnyTimes().DoAndReturn(func(id types.BallotID) *tortoise.BallotData { ballot, err := ballots.Get(db, id) @@ -236,7 +237,7 @@ func createProposal(t *testing.T, opts ...any) *types.Proposal { return p } -func createAtx(t *testing.T, db *sql.Database, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { +func createAtx(t *testing.T, db sql.StateDatabase, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { atx := &types.ActivationTx{ PublishEpoch: epoch, NumUnits: 1, diff --git a/prune/prune.go b/prune/prune.go index 0f7ddb4218c..2fe9ef20c41 100644 --- a/prune/prune.go +++ b/prune/prune.go @@ -22,7 +22,7 @@ func WithLogger(logger *zap.Logger) Opt { } } -func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { +func New(db sql.StateDatabase, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { p := &Pruner{ logger: zap.NewNop(), db: db, @@ -37,7 +37,7 @@ func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts . type Pruner struct { logger *zap.Logger - db *sql.Database + db sql.StateDatabase safeDist uint32 activesetEpoch types.EpochID } diff --git a/prune/prune_test.go b/prune/prune_test.go index 91011f98f9b..0eb4ee5e7ee 100644 --- a/prune/prune_test.go +++ b/prune/prune_test.go @@ -11,13 +11,14 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) func TestPrune(t *testing.T) { types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() current := types.LayerID(10) lyrProps := make([]*types.Proposal, 0, current) diff --git a/sql/accounts/accounts_test.go b/sql/accounts/accounts_test.go index 21d34dbea2c..4556ad332a6 100644 --- a/sql/accounts/accounts_test.go +++ b/sql/accounts/accounts_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func genSeq(address types.Address, n int) []*types.Account { @@ -21,7 +22,7 @@ func genSeq(address types.Address, n int) []*types.Account { func TestUpdate(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() seq := genSeq(address, 2) for _, update := range seq { require.NoError(t, Update(db, update)) @@ -34,7 +35,7 @@ func TestUpdate(t *testing.T) { func TestHas(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() has, err := Has(db, address) require.NoError(t, err) require.False(t, has) @@ -50,7 +51,7 @@ func TestHas(t *testing.T) { func TestRevert(t *testing.T) { address := types.Address{1, 1} seq := genSeq(address, 10) - db := sql.InMemory() + db := statesql.InMemory() for _, update := range seq { require.NoError(t, Update(db, update)) } @@ -62,7 +63,7 @@ func TestRevert(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() addresses := []types.Address{{1, 1}, {2, 2}, {3, 3}} n := []int{10, 7, 20} for i, address := range addresses { @@ -81,7 +82,7 @@ func TestAll(t *testing.T) { } func TestSnapshot(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Snapshot(db, types.LayerID(1)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -108,7 +109,7 @@ func TestSnapshot(t *testing.T) { } func TestIterateAccountsOps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i := 0; i < 100; i++ { addr := types.Address{} diff --git a/sql/activesets/activesets_test.go b/sql/activesets/activesets_test.go index 8828fa417e6..acb8fe53ce9 100644 --- a/sql/activesets/activesets_test.go +++ b/sql/activesets/activesets_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActiveSet(t *testing.T) { @@ -19,7 +20,7 @@ func TestActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{1}, {2}}, } - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, ids[0], set)) require.ErrorIs(t, Add(db, ids[0], set), sql.ErrObjectExists) @@ -68,7 +69,7 @@ func TestCachedActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{3}, {4}}, } - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) require.NoError(t, Add(db, ids[0], set0)) require.NoError(t, Add(db, ids[1], set1)) @@ -78,12 +79,12 @@ func TestCachedActiveSet(t *testing.T) { for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[0].Bytes(), &b)) require.Equal(t, codec.MustEncode(set0), b.Bytes) - require.Equal(t, 3, db.QueryCount()) + require.Equal(t, 3, db.QueryCount(), "ids[0]: QueryCount at i=%d", i) } for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[1].Bytes(), &b)) require.Equal(t, codec.MustEncode(set1), b.Bytes) - require.Equal(t, 4, db.QueryCount()) + require.Equal(t, 4, db.QueryCount(), "ids[1]: QueryCount at i=%d", i) } } diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index d38e7836014..99d2c6cc212 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -28,7 +29,7 @@ func TestMain(m *testing.M) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -50,7 +51,7 @@ func TestGet(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var expected []types.ATXID for i := 0; i < 3; i++ { @@ -67,7 +68,7 @@ func TestAll(t *testing.T) { } func TestHasID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -90,7 +91,7 @@ func TestHasID(t *testing.T) { } func Test_IdentityExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -108,7 +109,7 @@ func Test_IdentityExists(t *testing.T) { } func TestGetFirstIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -141,7 +142,7 @@ func TestGetFirstIDByNodeID(t *testing.T) { } func TestLatestN(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -233,7 +234,7 @@ func TestLatestN(t *testing.T) { } func TestGetByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -265,7 +266,7 @@ func TestGetByEpochAndNodeID(t *testing.T) { } func TestGetLastIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -299,7 +300,7 @@ func TestGetLastIDByNodeID(t *testing.T) { } func TestGetIDByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -343,7 +344,7 @@ func TestGetIDByEpochAndNodeID(t *testing.T) { } func TestGetIDsByEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -379,7 +380,7 @@ func TestGetIDsByEpoch(t *testing.T) { } func TestGetIDsByEpochCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -428,7 +429,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 11, db.QueryCount()) } - require.NoError(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx5, types.AtxBlob{}) return nil })) @@ -440,7 +441,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.ElementsMatch(t, []types.ATXID{atx4.ID(), atx5.ID()}, ids3) require.Equal(t, 13, db.QueryCount()) // not incremented after Add - require.Error(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx6, types.AtxBlob{}) return errors.New("fail") // rollback })) @@ -453,7 +454,7 @@ func TestGetIDsByEpochCached(t *testing.T) { } func Test_IterateAtxsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -483,7 +484,7 @@ func Test_IterateAtxsWithMalfeasance(t *testing.T) { } func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -514,7 +515,7 @@ func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { func TestVRFNonce(t *testing.T) { // Arrange - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -547,7 +548,7 @@ func TestVRFNonce(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -599,7 +600,7 @@ func TestLoadBlob(t *testing.T) { } func TestLoadBlob_DefaultsToV1(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -616,7 +617,7 @@ func TestLoadBlob_DefaultsToV1(t *testing.T) { } func TestGetBlobCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig, err := signing.NewEdSigner() @@ -638,7 +639,7 @@ func TestGetBlobCached(t *testing.T) { // Test that we don't put in the cache a reference to the blob that was passed to LoadBlob. // Each cache entry must use a unique slice for the blob. func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) @@ -670,7 +671,7 @@ func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { // Test that the cached blob is not shared with the caller // but copied into the provided blob. func TestGetBlobCached_OverwriteSafety(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) blob := types.AtxBlob{Blob: []byte("original blob")} @@ -688,7 +689,7 @@ func TestGetBlobCached_OverwriteSafety(t *testing.T) { } func TestCachedBlobEviction(t *testing.T) { - db := sql.InMemory( + db := statesql.InMemory( sql.WithQueryCache(true), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindATXBlob: 10, @@ -729,7 +730,7 @@ func TestCachedBlobEviction(t *testing.T) { } func TestCheckpointATX(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -776,7 +777,7 @@ func TestCheckpointATX(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonExistingATXID := types.ATXID(types.CalcHash32([]byte("0"))) _, err := atxs.Get(db, nonExistingATXID) @@ -854,7 +855,7 @@ type header struct { filteredOut bool } -func createAtx(tb testing.TB, db *sql.Database, hdr header) (types.ATXID, *signing.EdSigner) { +func createAtx(tb testing.TB, db sql.StateDatabase, hdr header) (types.ATXID, *signing.EdSigner) { sig, err := signing.NewEdSigner() require.NoError(tb, err) @@ -968,7 +969,7 @@ func TestGetIDWithMaxHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var sigs []*signing.EdSigner var ids []types.ATXID filtered := make(map[types.ATXID]struct{}) @@ -1009,7 +1010,7 @@ func TestLatest(t *testing.T) { {"out of order", []uint32{3, 4, 1, 2}, 4}, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, epoch := range tc.epochs { full := &types.ActivationTx{ PublishEpoch: types.EpochID(epoch), @@ -1028,7 +1029,7 @@ func TestLatest(t *testing.T) { } func Test_PrevATXCollisions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1079,13 +1080,13 @@ func TestCoinbase(t *testing.T) { t.Parallel() t.Run("not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Coinbase(db, types.NodeID{}) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx, blob := newAtx(t, sig, withCoinbase(types.Address{1, 2, 3})) @@ -1096,7 +1097,7 @@ func TestCoinbase(t *testing.T) { }) t.Run("picks last", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx1, blob1 := newAtx(t, sig, withPublishEpoch(1), withCoinbase(types.Address{1, 2, 3})) @@ -1113,13 +1114,13 @@ func TestUnits(t *testing.T) { t.Parallel() t.Run("ATX not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Units(db, types.RandomATXID(), types.RandomNodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("smesher has no units in ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() require.NoError(t, atxs.SetUnits(db, atxID, types.RandomNodeID(), 10)) _, err := atxs.Units(db, atxID, types.RandomNodeID()) @@ -1127,7 +1128,7 @@ func TestUnits(t *testing.T) { }) t.Run("returns units for given smesher in given ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() units := map[types.NodeID]uint32{ {1, 2, 3}: 10, diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 901161d0f47..2f39a01fb38 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 3 @@ -26,7 +27,7 @@ func TestMain(m *testing.M) { } func TestLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) pub := types.BytesToNodeID([]byte{1, 1, 1}) ballots := []types.Ballot{ @@ -65,7 +66,7 @@ func TestLayer(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.RandomNodeID() ballot := types.NewExistingBallot(types.BallotID{1}, types.RandomEdSignature(), nodeID, types.LayerID(0)) _, err := Get(db, ballot.ID()) @@ -85,7 +86,7 @@ func TestAdd(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ballot := types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, types.EmptyNodeID, types.LayerID(0)) exists, err := Has(db, ballot.ID()) @@ -99,7 +100,7 @@ func TestHas(t *testing.T) { } func TestLatest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() latest, err := LatestLayer(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), latest) @@ -123,7 +124,7 @@ func TestLatest(t *testing.T) { } func TestLayerBallotBySmesher(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) nodeID1 := types.RandomNodeID() nodeID2 := types.RandomNodeID() @@ -158,7 +159,7 @@ func newAtx(signer *signing.EdSigner, layerID types.LayerID) *types.ActivationTx } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(layersPerEpoch * 2) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -285,7 +286,7 @@ func TestAllFirstInEpoch(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() for _, ballot := range tc.ballots { require.NoError(t, Add(db, &ballot)) } @@ -301,7 +302,7 @@ func TestAllFirstInEpoch(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() ballot1 := types.NewExistingBallot( diff --git a/sql/beacons/beacons_test.go b/sql/beacons/beacons_test.go index 1d5648d30d6..8a012e54ba9 100644 --- a/sql/beacons/beacons_test.go +++ b/sql/beacons/beacons_test.go @@ -7,12 +7,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const baseEpoch = 3 func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() beacons := []types.Beacon{ types.HexToBeacon("0x1"), @@ -35,7 +36,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -50,7 +51,7 @@ func TestAdd(t *testing.T) { } func TestSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) diff --git a/sql/blocks/blocks_test.go b/sql/blocks/blocks_test.go index 8b64d2ab65c..038b542caab 100644 --- a/sql/blocks/blocks_test.go +++ b/sql/blocks/blocks_test.go @@ -10,10 +10,11 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestAddGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1, 1}, types.InnerBlock{LayerIndex: types.LayerID(1)}, @@ -26,7 +27,7 @@ func TestAddGet(t *testing.T) { } func TestAlreadyExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -36,7 +37,7 @@ func TestAlreadyExists(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -52,7 +53,7 @@ func TestHas(t *testing.T) { } func TestValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -86,7 +87,7 @@ func TestValidity(t *testing.T) { } func TestLayerFilter(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -122,7 +123,7 @@ func TestLayerFilter(t *testing.T) { } func TestLayerOrdered(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -153,7 +154,7 @@ func TestLayerOrdered(t *testing.T) { } func TestContextualValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -197,7 +198,7 @@ func TestContextualValidity(t *testing.T) { } func TestGetLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid1 := types.LayerID(11) block1 := types.NewExistingBlock( types.BlockID{1, 1}, @@ -222,12 +223,12 @@ func TestGetLayer(t *testing.T) { func TestLastValid(t *testing.T) { t.Run("empty", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := LastValid(db) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("all valid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { lid types.LayerID }{ @@ -248,7 +249,7 @@ func TestLastValid(t *testing.T) { require.Equal(t, 33, int(last)) }) t.Run("last is invalid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { invalid bool lid types.LayerID @@ -274,7 +275,7 @@ func TestLastValid(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lid1 := types.LayerID(11) @@ -315,7 +316,7 @@ func TestLoadBlob(t *testing.T) { } func TestLayerForMangledBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := db.Exec("insert into blocks (id, layer, block) values (?1, ?2, ?3);", func(stmt *sql.Statement) { stmt.BindBytes(1, []byte(`mangled-block-id`)) diff --git a/sql/certificates/certs_test.go b/sql/certificates/certs_test.go index cfe4ff9104e..2056f6b0919 100644 --- a/sql/certificates/certs_test.go +++ b/sql/certificates/certs_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -44,7 +45,7 @@ func makeCert(lid types.LayerID, bid types.BlockID) *types.Certificate { } func TestCertificates(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) got, err := Get(db, lid) @@ -94,7 +95,7 @@ func TestCertificates(t *testing.T) { } func TestHareOutput(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) ho, err := GetHareOutput(db, lid) @@ -132,7 +133,7 @@ func TestHareOutput(t *testing.T) { } func TestCertifiedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 @@ -161,7 +162,7 @@ func TestCertifiedBlock(t *testing.T) { } func TestDeleteCert(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, types.LayerID(2), &types.Certificate{BlockID: types.BlockID{2}})) require.NoError(t, Add(db, types.LayerID(3), &types.Certificate{BlockID: types.BlockID{3}})) require.NoError(t, Add(db, types.LayerID(4), &types.Certificate{BlockID: types.BlockID{4}})) @@ -177,7 +178,7 @@ func TestDeleteCert(t *testing.T) { } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 diff --git a/sql/database.go b/sql/database.go index 907f8d1dde6..1a75ab86b83 100644 --- a/sql/database.go +++ b/sql/database.go @@ -5,8 +5,6 @@ import ( "errors" "fmt" "maps" - "slices" - "sort" "strings" "sync" "sync/atomic" @@ -30,6 +28,9 @@ var ( ErrObjectExists = errors.New("database: object exists") // ErrTooNew is returned if database version is newer than expected. ErrTooNew = errors.New("database version is too new") + // ErrOldSchema is returned when the database version differs from the expected one + // and migrations are disabled. + ErrOldSchema = errors.New("old database version") ) const ( @@ -37,7 +38,7 @@ const ( beginImmediate = "BEGIN IMMEDIATE;" ) -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // Executor is an interface for executing raw statement. type Executor interface { @@ -61,29 +62,27 @@ type Encoder func(*Statement) type Decoder func(*Statement) bool func defaultConf() *conf { - migrations, err := StateMigrations() - if err != nil { - panic(err) - } - return &conf{ - connections: 16, - migrations: migrations, - skipMigration: map[int]struct{}{}, - logger: zap.NewNop(), + enableMigrations: true, + connections: 16, + logger: zap.NewNop(), + schema: &Schema{}, } } type conf struct { - flags sqlite.OpenFlags - connections int - skipMigration map[int]struct{} - vacuumState int - migrations []Migration - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + allowSchemaDrift bool + ignoreSchemaDrift bool } // WithConnections overwrites number of pooled connections. @@ -93,48 +92,18 @@ func WithConnections(n int) Opt { } } +// WithLogger specifies logger for the database. func WithLogger(logger *zap.Logger) Opt { return func(c *conf) { c.logger = logger } } -// WithMigrations overwrites embedded migrations. -// Migrations are sorted by order before applying. -func WithMigrations(migrations []Migration) Opt { +// WithMigrationsDisabled disables migrations for the database. +// The migrations are enabled by default. +func WithMigrationsDisabled() Opt { return func(c *conf) { - sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Order() < migrations[j].Order() - }) - c.migrations = migrations - } -} - -// WithMigration adds migration to the list of migrations. -// It will overwrite an existing migration with the same order. -func WithMigration(migration Migration) Opt { - return func(c *conf) { - for i, m := range c.migrations { - if m.Order() == migration.Order() { - c.migrations[i] = migration - return - } - if m.Order() > migration.Order() { - c.migrations = slices.Insert(c.migrations, i, migration) - return - } - } - c.migrations = append(c.migrations, migration) - } -} - -// WithSkipMigrations will update database version with executing associated migrations. -// It should be used at your own risk. -func WithSkipMigrations(i ...int) Opt { - return func(c *conf) { - for _, index := range i { - c.skipMigration[index] = struct{}{} - } + c.enableMigrations = false } } @@ -172,13 +141,54 @@ func WithQueryCacheSizes(sizes map[QueryCacheKind]int) Opt { } } +// WithForceMigrations forces database to run all the migrations instead +// of using a schema snapshot in case of a fresh database. +func WithForceMigrations(force bool) Opt { + return func(c *conf) { + c.forceMigrations = true + } +} + +// WithSchema specifies database schema script. +func WithDatabaseSchema(schema *Schema) Opt { + return func(c *conf) { + c.schema = schema + } +} + +// WithAllowSchemaDrift prevents Open from failing upon schema +// drift. A warning is printed instead. +func WithAllowSchemaDrift(allow bool) Opt { + return func(c *conf) { + c.allowSchemaDrift = allow + } +} + +func WithIgnoreSchemaDrift() Opt { + return func(c *conf) { + c.ignoreSchemaDrift = true + } +} + +func withForceFresh() Opt { + return func(c *conf) { + c.forceFresh = true + } +} + // Opt for configuring database. type Opt func(c *conf) -// InMemory database for testing. -func InMemory(opts ...Opt) *Database { - opts = append(opts, WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) +// OpenInMemory creates an in-memory database. +func OpenInMemory(opts ...Opt) (*sqliteDatabase, error) { + opts = append(opts, WithConnections(1), withForceFresh()) + return Open("file::memory:?mode=memory", opts...) +} + +// InMemory creates an in-memory database for testing and panics if +// there's an error. +func InMemory(opts ...Opt) *sqliteDatabase { + db, err := OpenInMemory(opts...) if err != nil { panic(err) } @@ -190,82 +200,92 @@ func InMemory(opts ...Opt) *Database { // Database is opened in WAL mode and pragma synchronous=normal. // https://sqlite.org/wal.html // https://www.sqlite.org/pragma.html#pragma_synchronous -func Open(uri string, opts ...Opt) (*Database, error) { +func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { config := defaultConf() for _, opt := range opts { opt(config) } - pool, err := sqlitex.Open(uri, config.flags, config.connections) + logger := config.logger.With(zap.String("uri", uri)) + var flags sqlite.OpenFlags + if !config.forceFresh { + flags = sqlite.SQLITE_OPEN_READWRITE | + sqlite.SQLITE_OPEN_WAL | + sqlite.SQLITE_OPEN_URI | + sqlite.SQLITE_OPEN_NOMUTEX + } + freshDB := config.forceFresh + pool, err := sqlitex.Open(uri, flags, config.connections) if err != nil { - return nil, fmt.Errorf("open db %s: %w", uri, err) + if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { + return nil, fmt.Errorf("open db %s: %w", uri, err) + } + flags |= sqlite.SQLITE_OPEN_CREATE + freshDB = true + pool, err = sqlitex.Open(uri, flags, config.connections) + if err != nil { + return nil, fmt.Errorf("create db %s: %w", uri, err) + } } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} if config.enableLatency { db.latency = newQueryLatency() } - //nolint:nestif - if config.migrations != nil { - before, err := version(db) - if err != nil { - return nil, err - } - after := 0 - if len(config.migrations) > 0 { - after = config.migrations[len(config.migrations)-1].Order() + if freshDB && !config.forceMigrations { + if err := config.schema.Apply(db); err != nil { + return nil, errors.Join( + fmt.Errorf("error running schema script: %w", err), + db.Close()) } - if before > after { - pool.Close() - config.logger.Error("database version is newer than expected - downgrade is not supported", - zap.String("uri", uri), + } else { + before, after, err := config.schema.CheckDBVersion(logger, db) + switch { + case err != nil: + return nil, errors.Join(err, db.Close()) + case before != after && config.enableMigrations: + logger.Info("running migrations", zap.Int("current version", before), zap.Int("target version", after), ) - return nil, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) - } - config.logger.Info("running migrations", - zap.String("uri", uri), - zap.Int("current version", before), - zap.Int("target version", after), - ) - for i, m := range config.migrations { - if m.Order() <= before { - continue + if err := config.schema.Migrate( + logger, db, before, config.vacuumState, + ); err != nil { + return nil, errors.Join(err, db.Close()) } - if err := db.WithTx(context.Background(), func(tx *Tx) error { - if _, ok := config.skipMigration[m.Order()]; !ok { - if err := m.Apply(tx); err != nil { - for j := i; j >= 0 && config.migrations[j].Order() > before; j-- { - if e := config.migrations[j].Rollback(); e != nil { - err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) - break - } - } - - return fmt.Errorf("apply %s: %w", m.Name(), err) - } - } - // version is set intentionally even if actual migration was skipped - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.Order(), err) - } - return nil - }); err != nil { - err = errors.Join(err, db.Close()) - return nil, err - } - - if config.vacuumState != 0 && before <= config.vacuumState { - if err := Vacuum(db); err != nil { - err = errors.Join(err, db.Close()) - return nil, err - } - } - before = m.Order() + case before != after: + logger.Error("database version is too old", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return nil, errors.Join( + fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after), + db.Close()) } + } + if !config.ignoreSchemaDrift { + loaded, err := LoadDBSchemaScript(db) + if err != nil { + return nil, errors.Join( + fmt.Errorf("error loading database schema: %w", err), + db.Close()) + } + diff := config.schema.Diff(loaded) + switch { + case diff == "": // ok + case config.allowSchemaDrift: + logger.Warn("database schema drift detected", + zap.String("uri", uri), + zap.String("diff", diff), + ) + default: + return nil, errors.Join( + fmt.Errorf("schema drift detected (uri %s):\n%s", uri, diff), + db.Close()) + } } + if config.cache { - config.logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) + logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} } db.queryCount.Store(0) @@ -277,13 +297,32 @@ func Version(uri string) (int, error) { if err != nil { return 0, fmt.Errorf("open db %s: %w", uri, err) } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} defer db.Close() return version(db) } -// Database is an instance of sqlite database. -type Database struct { +// Database represents a database. +type Database interface { + Executor + QueryCache + Close() error + QueryCount() int + QueryCache() QueryCache + Tx(ctx context.Context) (Transaction, error) + WithTx(ctx context.Context, exec func(Transaction) error) error + TxImmediate(ctx context.Context) (Transaction, error) + WithTxImmediate(ctx context.Context, exec func(Transaction) error) error +} + +// Transaction represents a transaction. +type Transaction interface { + Executor + Commit() error + Release() error +} + +type sqliteDatabase struct { *queryCache pool *sqlitex.Pool @@ -294,7 +333,9 @@ type Database struct { queryCount atomic.Int64 } -func (db *Database) getConn(ctx context.Context) *sqlite.Conn { +var _ Database = &sqliteDatabase{} + +func (db *sqliteDatabase) getConn(ctx context.Context) *sqlite.Conn { start := time.Now() conn := db.pool.Get(ctx) if conn != nil { @@ -303,19 +344,19 @@ func (db *Database) getConn(ctx context.Context) *sqlite.Conn { return conn } -func (db *Database) getTx(ctx context.Context, initstmt string) (*Tx, error) { +func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx, error) { conn := db.getConn(ctx) if conn == nil { return nil, ErrNoConnection } - tx := &Tx{queryCache: db.queryCache, db: db, conn: conn} + tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn} if err := tx.begin(initstmt); err != nil { return nil, err } return tx, nil } -func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) error) error { +func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) error { tx, err := db.getTx(ctx, initstmt) if err != nil { return err @@ -335,13 +376,13 @@ func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) // after one of the write statements. // // https://www.sqlite.org/lang_transaction.html -func (db *Database) Tx(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) Tx(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginDefault) } // WithTx will pass initialized deferred transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTx(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginImmediate, exec) } @@ -350,13 +391,16 @@ func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { // IMMEDIATE cause the database connection to start a new write immediately, without waiting // for a write statement. The BEGIN IMMEDIATE might fail with SQLITE_BUSY if another write // transaction is already active on another database connection. -func (db *Database) TxImmediate(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) TxImmediate(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginImmediate) } // WithTxImmediate will pass initialized immediate transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTxImmediate( + ctx context.Context, + exec func(Transaction) error, +) error { return db.withTx(ctx, beginImmediate, exec) } @@ -368,7 +412,7 @@ func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) e // // Note that Exec will block until database is closed or statement has finished. // If application needs to control statement execution lifetime use one of the transaction. -func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { db.queryCount.Add(1) conn := db.getConn(context.Background()) if conn == nil { @@ -385,7 +429,7 @@ func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, e } // Close closes all pooled connections. -func (db *Database) Close() error { +func (db *sqliteDatabase) Close() error { db.closeMux.Lock() defer db.closeMux.Unlock() if db.closed { @@ -400,12 +444,12 @@ func (db *Database) Close() error { // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. -func (db *Database) QueryCount() int { +func (db *sqliteDatabase) QueryCount() int { return int(db.queryCount.Load()) } // Return database's QueryCache. -func (db *Database) QueryCache() QueryCache { +func (db *sqliteDatabase) QueryCache() QueryCache { return db.queryCache } @@ -423,11 +467,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in for { row, err := stmt.Step() if err != nil { - code := sqlite.ErrCode(err) - if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { - return 0, ErrObjectExists - } - return 0, fmt.Errorf("step %d: %w", rows, err) + return 0, fmt.Errorf("step %d: %w", rows, mapSqliteError(err)) } if !row { return rows, nil @@ -446,48 +486,48 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in } } -// Tx is wrapper for database transaction. -type Tx struct { +// sqliteTx is wrapper for database transaction. +type sqliteTx struct { *queryCache - db *Database + db *sqliteDatabase conn *sqlite.Conn committed bool err error } -func (tx *Tx) begin(initstmt string) error { +func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() if err != nil { - return fmt.Errorf("begin: %w", err) + return fmt.Errorf("begin: %w", mapSqliteError(err)) } return nil } // Commit transaction. -func (tx *Tx) Commit() error { +func (tx *sqliteTx) Commit() error { stmt := tx.conn.Prep("COMMIT;") _, tx.err = stmt.Step() if tx.err != nil { - return tx.err + return mapSqliteError(tx.err) } tx.committed = true return nil } // Release transaction. Every transaction that was created must be released. -func (tx *Tx) Release() error { +func (tx *sqliteTx) Release() error { defer tx.db.pool.Put(tx.conn) if tx.committed { return nil } stmt := tx.conn.Prep("ROLLBACK") _, tx.err = stmt.Step() - return tx.err + return mapSqliteError(tx.err) } // Exec query. -func (tx *Tx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { tx.db.queryCount.Add(1) if tx.db.latency != nil { start := time.Now() @@ -498,6 +538,20 @@ func (tx *Tx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) return exec(tx.conn, query, encoder, decoder) } +func mapSqliteError(err error) error { + code := sqlite.ErrCode(err) + if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { + return ErrObjectExists + } + if code == sqlite.SQLITE_INTERRUPT { + // TODO: we probably should check if there was indeed a context + // that was canceled. But we're likely to replace crawshaw library + // in future so this part should be rewritten anyway + return context.Canceled + } + return err +} + // Blob represents a binary blob data. It can be reused efficiently // across multiple data retrieval operations, minimizing reallocations // of the underlying byte slice. @@ -582,3 +636,15 @@ func LoadBlob(db Executor, cmd string, id []byte, blob *Blob) error { func IsNull(stmt *Statement, col int) bool { return stmt.ColumnType(col) == sqlite.SQLITE_NULL } + +// StateDatabase is a Database used for Spacemesh state. +type StateDatabase interface { + Database + IsStateDatabase() +} + +// LocalDatabase is a Database used for local node data. +type LocalDatabase interface { + Database + IsLocalDatabase() +} diff --git a/sql/database_test.go b/sql/database_test.go index ab7d71438a7..e4f222143ea 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -9,29 +9,25 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" ) func Test_Transaction_Isolation(t *testing.T) { - ctrl := gomock.NewController(t) - testMigration := NewMockMigration(ctrl) - testMigration.EXPECT().Name().Return("test").AnyTimes() - testMigration.EXPECT().Order().Return(1).AnyTimes() - testMigration.EXPECT().Apply(gomock.Any()).DoAndReturn(func(e Executor) error { - if _, err := e.Exec(`create table testing1 ( - id varchar primary key, - field int - )`, nil, nil); err != nil { - return err - } - return nil - }) - db := InMemory( - WithMigrations([]Migration{testMigration}), + WithLogger(zaptest.NewLogger(t)), WithConnections(10), WithLatencyMetering(true), + WithDatabaseSchema(&Schema{ + Script: `create table testing1 ( + id varchar primary key, + field int + );`, + }), + WithIgnoreSchemaDrift(), ) - tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -67,28 +63,37 @@ func Test_Migration_Rollback(t *testing.T) { migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil) - migration2.EXPECT().Apply(gomock.Any()).Return(errors.New("migration 2 failed")) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) migration2.EXPECT().Rollback().Return(nil) dbFile := filepath.Join(t.TempDir(), "test.sql") _, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithForceMigrations(true), ) require.ErrorContains(t, err, "migration 2 failed") } func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) migration1.EXPECT().Name().Return("test").AnyTimes() migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -96,15 +101,48 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2 := NewMockMigration(ctrl) migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(errors.New("migration 2 failed")) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), ) require.ErrorContains(t, err, "migration 2 failed") } +func Test_Migration_Disabled(t *testing.T) { + ctrl := gomock.NewController(t) + migration1 := NewMockMigration(ctrl) + migration1.EXPECT().Name().Return("test").AnyTimes() + migration1.EXPECT().Order().Return(1).AnyTimes() + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + + dbFile := filepath.Join(t.TempDir(), "test.sql") + db, err := Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + migration2 := NewMockMigration(ctrl) + migration2.EXPECT().Order().Return(2).AnyTimes() + + _, err = Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithMigrationsDisabled(), + ) + require.ErrorIs(t, err, ErrOldSchema) +} + func TestDatabaseSkipMigrations(t *testing.T) { ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -113,12 +151,17 @@ func TestDatabaseSkipMigrations(t *testing.T) { migration2 := NewMockMigration(ctrl) migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(nil) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) + schema := &Schema{ + Migrations: MigrationList{migration1, migration2}, + } + schema.SkipMigrations(1) dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), - WithSkipMigrations(1), + WithDatabaseSchema(schema), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -126,26 +169,36 @@ func TestDatabaseSkipMigrations(t *testing.T) { func TestDatabaseVacuumState(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) migration2 := NewMockMigration(ctrl) migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil).Times(1) dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), WithVacuumState(2), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -156,7 +209,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory() + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -171,6 +224,7 @@ func TestQueryCount(t *testing.T) { func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -180,14 +234,71 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:" + dbFile) + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithForceMigrations(true), + WithIgnoreSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) require.NoError(t, db.Close()) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithIgnoreSchemaDrift(), ) require.ErrorIs(t, err, ErrTooNew) } + +func TestSchemaDrift(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + dbFile := filepath.Join(t.TempDir(), "test.sql") + schema := &Schema{ + // Not using ` here to avoid schema drift warnings due to whitespace + // TODO: ignore whitespace and comments during schema comparison + Script: "PRAGMA user_version = 0;\n" + + "CREATE TABLE testing1 (\n" + + " id varchar primary key,\n" + + " field int\n" + + ");\n", + } + db, err := Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = db.Exec("create table newtbl (id int)", nil, nil) + require.NoError(t, err) + + require.NoError(t, db.Close()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + _, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.Error(t, err) + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, err.Error()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + db, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + WithAllowSchemaDrift(true), + ) + require.NoError(t, db.Close()) + require.NoError(t, err) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + require.Equal(t, "database schema drift detected", observedLogs.All()[0].Message) + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, + observedLogs.All()[0].ContextMap()["diff"]) +} diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index ad87b4838de..b1b6b30ee94 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -11,10 +11,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.NodeID{1, 1, 1, 1} mal, err := IsMalicious(db, nodeID) @@ -60,7 +61,7 @@ func TestMalicious(t *testing.T) { } func Test_GetMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() got, err := GetMalicious(db) require.NoError(t, err) require.Nil(t, got) @@ -78,7 +79,7 @@ func Test_GetMalicious(t *testing.T) { } func TestLoadMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() nid1 := types.RandomNodeID() @@ -123,7 +124,7 @@ func TestMarried(t *testing.T) { t.Parallel() t.Run("identity not in DB", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() married, err := Married(db, id) @@ -138,7 +139,7 @@ func TestMarried(t *testing.T) { }) t.Run("identity in DB", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() // add ID in the DB @@ -160,7 +161,7 @@ func TestMarriageATX(t *testing.T) { t.Parallel() t.Run("not married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() _, err := MarriageATX(db, id) @@ -168,7 +169,7 @@ func TestMarriageATX(t *testing.T) { }) t.Run("married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -187,7 +188,7 @@ func TestMarriageATX(t *testing.T) { func TestMarriage(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -206,7 +207,7 @@ func TestEquivocationSet(t *testing.T) { t.Parallel() t.Run("equivocation set of married IDs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ @@ -233,7 +234,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("equivocation set for unmarried ID contains itself only", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() set, err := EquivocationSet(db, id) require.NoError(t, err) @@ -241,7 +242,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("can't escape the marriage", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -274,7 +275,7 @@ func TestEquivocationSet(t *testing.T) { } }) t.Run("married doesn't become malicious immediately", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() id := types.RandomNodeID() require.NoError(t, SetMarriage(db, id, &MarriageData{ATX: atx})) @@ -293,7 +294,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("all IDs in equivocation set are malicious if one is", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -317,7 +318,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { t.Parallel() t.Run("married IDs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ids := []types.NodeID{ types.RandomNodeID(), types.RandomNodeID(), @@ -333,7 +334,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { require.Equal(t, ids, set) }) t.Run("empty set", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() set, err := EquivocationSetByMarriageATX(db, types.RandomATXID()) require.NoError(t, err) require.Empty(t, set) diff --git a/sql/interface.go b/sql/interface.go index 3d3931b60e1..23859f5202b 100644 --- a/sql/interface.go +++ b/sql/interface.go @@ -1,10 +1,12 @@ package sql +import "go.uber.org/zap" + //go:generate mockgen -typed -package=sql -destination=./mocks.go -source=./interface.go // Migration is interface for migrations provider. type Migration interface { - Apply(db Executor) error + Apply(db Executor, logger *zap.Logger) error Rollback() error Name() string Order() int diff --git a/sql/layers/layers_test.go b/sql/layers/layers_test.go index f16a5e40d71..70f1f57bf15 100644 --- a/sql/layers/layers_test.go +++ b/sql/layers/layers_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 4 @@ -20,7 +21,7 @@ func TestMain(m *testing.M) { } func TestWeakCoin(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetWeakCoin(db, lid) @@ -38,7 +39,7 @@ func TestWeakCoin(t *testing.T) { } func TestAppliedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetApplied(db, lid) @@ -67,7 +68,7 @@ func TestAppliedBlock(t *testing.T) { } func TestFirstAppliedInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blks := map[types.LayerID]types.BlockID{ types.EpochID(1).FirstLayer(): {1}, types.EpochID(2).FirstLayer(): types.EmptyBlockID, @@ -107,7 +108,7 @@ func TestFirstAppliedInEpoch(t *testing.T) { } func TestUnsetAppliedFrom(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) last := lid.Add(99) for i := lid; !i.After(last); i = i.Add(1) { @@ -123,7 +124,7 @@ func TestUnsetAppliedFrom(t *testing.T) { } func TestStateHash(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() layers := []uint32{9, 10, 8, 7} hashes := []types.Hash32{{1}, {2}, {3}, {4}} for i := range layers { @@ -147,7 +148,7 @@ func TestStateHash(t *testing.T) { } func TestSetHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := GetAggregatedHash(db, types.LayerID(11)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -178,7 +179,7 @@ func TestSetHashes(t *testing.T) { } func TestProcessed(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid, err := GetProcessed(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), lid) @@ -193,7 +194,7 @@ func TestProcessed(t *testing.T) { } func TestGetAggHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() hashes := make(map[types.LayerID]types.Hash32) diff --git a/sql/localsql/local.go b/sql/localsql/local.go deleted file mode 100644 index cbf1d9f2f4e..00000000000 --- a/sql/localsql/local.go +++ /dev/null @@ -1,38 +0,0 @@ -package localsql - -import "github.com/spacemeshos/go-spacemesh/sql" - -type Database struct { - *sql.Database -} - -func Open(uri string, opts ...sql.Opt) (*Database, error) { - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, err - } - defaultOpts := []sql.Opt{ - sql.WithConnections(16), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db, err := sql.Open(uri, opts...) - if err != nil { - return nil, err - } - return &Database{Database: db}, nil -} - -func InMemory(opts ...sql.Opt) *Database { - migrations, err := sql.LocalMigrations() - if err != nil { - panic(err) - } - defaultOpts := []sql.Opt{ - sql.WithConnections(1), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db := sql.InMemory(opts...) - return &Database{Database: db} -} diff --git a/sql/localsql/local_test.go b/sql/localsql/local_test.go deleted file mode 100644 index 1fed142b1af..00000000000 --- a/sql/localsql/local_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package localsql - -import ( - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" -) - -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := Open("file:" + file) - require.NoError(t, err) - - var sqls1 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls1 = append(sqls1, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - db, err = Open("file:" + file) - require.NoError(t, err) - var sqls2 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls2 = append(sqls2, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - require.Equal(t, sqls1, sqls2) -} diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go new file mode 100644 index 00000000000..0b79bc546f5 --- /dev/null +++ b/sql/localsql/localsql.go @@ -0,0 +1,71 @@ +package localsql + +import ( + "embed" + "strings" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate go run ../schemagen -dbtype local -output schema/schema.sql + +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS + +type database struct { + sql.Database +} + +var _ sql.LocalDatabase = &database{} + +func (d *database) IsLocalDatabase() {} + +// Schema returns the schema for the local database. +func Schema() (*sql.Schema, error) { + sqlMigrations, err := sql.LoadSQLMigrations(migrations) + if err != nil { + return nil, err + } + // NOTE: coded state migrations can be added here + // They can be a part of this localsql package + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil +} + +// Open opens a local database. +func Open(uri string, opts ...sql.Opt) (*database, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + defaultOpts := []sql.Opt{ + sql.WithConnections(16), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &database{Database: db}, nil +} + +// Open opens an in-memory local database. +func InMemory(opts ...sql.Opt) *database { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithConnections(1), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &database{Database: db} +} diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go new file mode 100644 index 00000000000..320702d7ebd --- /dev/null +++ b/sql/localsql/localsql_test.go @@ -0,0 +1,73 @@ +package localsql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) + + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") +} diff --git a/sql/migrations/local/0001_initial.sql b/sql/localsql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/local/0001_initial.sql rename to sql/localsql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/local/0002_extend_initial_post.sql b/sql/localsql/schema/migrations/0002_extend_initial_post.sql similarity index 100% rename from sql/migrations/local/0002_extend_initial_post.sql rename to sql/localsql/schema/migrations/0002_extend_initial_post.sql diff --git a/sql/migrations/local/0003_add_nipost_builder_state.sql b/sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql similarity index 100% rename from sql/migrations/local/0003_add_nipost_builder_state.sql rename to sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql diff --git a/sql/migrations/local/0004_atx_sync.sql b/sql/localsql/schema/migrations/0004_atx_sync.sql similarity index 100% rename from sql/migrations/local/0004_atx_sync.sql rename to sql/localsql/schema/migrations/0004_atx_sync.sql diff --git a/sql/migrations/local/0005_fast_startup.sql b/sql/localsql/schema/migrations/0005_fast_startup.sql similarity index 100% rename from sql/migrations/local/0005_fast_startup.sql rename to sql/localsql/schema/migrations/0005_fast_startup.sql diff --git a/sql/migrations/local/0006_prepared_activeset.sql b/sql/localsql/schema/migrations/0006_prepared_activeset.sql similarity index 100% rename from sql/migrations/local/0006_prepared_activeset.sql rename to sql/localsql/schema/migrations/0006_prepared_activeset.sql diff --git a/sql/migrations/local/0007_malfeasance_sync.sql b/sql/localsql/schema/migrations/0007_malfeasance_sync.sql similarity index 100% rename from sql/migrations/local/0007_malfeasance_sync.sql rename to sql/localsql/schema/migrations/0007_malfeasance_sync.sql diff --git a/sql/migrations/local/0008_next.sql b/sql/localsql/schema/migrations/0008_next.sql similarity index 100% rename from sql/migrations/local/0008_next.sql rename to sql/localsql/schema/migrations/0008_next.sql diff --git a/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql new file mode 100644 index 00000000000..d0ade1f535a --- /dev/null +++ b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql @@ -0,0 +1,12 @@ +ALTER TABLE malfeasance_sync_state RENAME TO malfeasance_sync_state_old; + +CREATE TABLE malfeasance_sync_state +( + id INT NOT NULL PRIMARY KEY, + timestamp INT NOT NULL +); + +INSERT INTO malfeasance_sync_state (id, timestamp) +SELECT 1, timestamp FROM malfeasance_sync_state_old LIMIT 1; + +DROP TABLE malfeasance_sync_state_old; diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql new file mode 100755 index 00000000000..02c44d3ccbd --- /dev/null +++ b/sql/localsql/schema/schema.sql @@ -0,0 +1,83 @@ +PRAGMA user_version = 9; +CREATE TABLE atx_sync_requests +( + epoch INT NOT NULL, + timestamp INT NOT NULL, total INTEGER, downloaded INTEGER, + PRIMARY KEY (epoch) +) WITHOUT ROWID; +CREATE TABLE atx_sync_state +( + epoch INT NOT NULL, + id CHAR(32) NOT NULL, + requests INT NOT NULL DEFAULT 0, + PRIMARY KEY (epoch, id) +) WITHOUT ROWID; +CREATE TABLE "challenge" +( + id CHAR(32) PRIMARY KEY, + epoch UNSIGNED INT NOT NULL, + sequence UNSIGNED INT NOT NULL, + prev_atx CHAR(32) NOT NULL, + pos_atx CHAR(32) NOT NULL, + commit_atx CHAR(32), + post_nonce UNSIGNED INT, + post_indices VARCHAR, + post_pow UNSIGNED LONG INT +, poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; +CREATE TABLE malfeasance_sync_state +( + id INT NOT NULL PRIMARY KEY, + timestamp INT NOT NULL +); +CREATE TABLE nipost +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL, + + poet_proof_membership VARCHAR NOT NULL, + poet_proof_ref CHAR(32) NOT NULL, + labels_per_unit UNSIGNED INT NOT NULL +) WITHOUT ROWID; +CREATE TABLE poet_certificates +( + node_id BLOB NOT NULL, + certifier_id BLOB NOT NULL, + certificate BLOB NOT NULL, + signature BLOB NOT NULL +); +CREATE UNIQUE INDEX idx_poet_certificates ON poet_certificates (node_id, certifier_id); +CREATE TABLE poet_registration +( + id CHAR(32) NOT NULL, + hash CHAR(32) NOT NULL, + address VARCHAR NOT NULL, + round_id VARCHAR NOT NULL, + round_end INT NOT NULL, + + PRIMARY KEY (id, address) +) WITHOUT ROWID; +CREATE TABLE "post" +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + commit_atx CHAR(32) NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL +, challenge BLOB NOT NULL DEFAULT x'0000000000000000000000000000000000000000000000000000000000000000'); +CREATE TABLE prepared_activeset +( + kind UNSIGNED INT NOT NULL, + epoch UNSIGNED INT NOT NULL, + id CHAR(32) NOT NULL, + weight UNSIGNED INT NOT NULL, + data BLOB NOT NULL, + PRIMARY KEY (kind, epoch) +) WITHOUT ROWID; diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go index 9ae572b0ade..fac8b50fbaa 100644 --- a/sql/malsync/malsync.go +++ b/sql/malsync/malsync.go @@ -9,7 +9,7 @@ import ( func GetSyncState(db sql.Executor) (time.Time, error) { var timestamp time.Time - rows, err := db.Exec("select timestamp from malfeasance_sync_state", + rows, err := db.Exec("select timestamp from malfeasance_sync_state where id = 1", nil, func(stmt *sql.Statement) bool { v := stmt.ColumnInt64(0) if v > 0 { @@ -17,21 +17,25 @@ func GetSyncState(db sql.Executor) (time.Time, error) { } return true }) - if err != nil { + switch { + case err != nil: return time.Time{}, fmt.Errorf("error getting malfeasance sync state: %w", err) - } else if rows != 1 { + case rows <= 1: + return timestamp, nil + default: return time.Time{}, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) } - return timestamp, nil } func updateSyncState(db sql.Executor, ts int64) error { - _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", + if _, err := db.Exec( + `insert into malfeasance_sync_state (id, timestamp) values(1, ?1) + on conflict (id) do update set timestamp = ?1`, func(stmt *sql.Statement) { stmt.BindInt64(1, ts) - }, nil) - if err != nil { - return fmt.Errorf("error updating malfeasance sync state: %w", err) + }, nil, + ); err != nil { + return fmt.Errorf("error initializing malfeasance sync state: %w", err) } return nil } diff --git a/sql/metrics/prometheus.go b/sql/metrics/prometheus.go index 6a7ba2ef5f5..dcaa3206ffb 100644 --- a/sql/metrics/prometheus.go +++ b/sql/metrics/prometheus.go @@ -22,7 +22,7 @@ const ( type DBMetricsCollector struct { logger log.Logger checkInterval time.Duration - db *sql.Database + db sql.StateDatabase tablesList map[string]struct{} eg errgroup.Group cancel context.CancelFunc @@ -35,7 +35,7 @@ type DBMetricsCollector struct { // NewDBMetricsCollector creates new DBMetricsCollector. func NewDBMetricsCollector( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, logger log.Logger, checkInterval time.Duration, ) *DBMetricsCollector { diff --git a/sql/migrations.go b/sql/migrations.go index 3b9a019e537..3a4d37844dc 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -1,25 +1,53 @@ package sql import ( - "bufio" - "bytes" - "embed" "fmt" "io/fs" + "regexp" + "slices" "strconv" "strings" + + "go.uber.org/zap" ) -//go:embed migrations/**/*.sql -var embedded embed.FS +// MigrationList denotes a list of migrations. +type MigrationList []Migration + +// AddMigration adds a Migration to the MigrationList, overriding the migration with the +// same order number if it already exists. The function returns updated migration list. +// The state of the original migration list is undefined after calling this function. +func (l MigrationList) AddMigration(migration Migration) MigrationList { + for i, m := range l { + if m.Order() == migration.Order() { + l[i] = migration + return l + } + if m.Order() > migration.Order() { + l = slices.Insert(l, i, migration) + return l + } + } + return append(l, migration) +} + +// Version returns database version for the specified migration list. +func (l MigrationList) Version() int { + if len(l) == 0 { + return 0 + } + return l[len(l)-1].Order() +} type sqlMigration struct { order int name string - content *bufio.Scanner + content string } -func (m *sqlMigration) Apply(db Executor) error { +var sqlCommentRx = regexp.MustCompile(`(?m)--.*$`) + +func (m *sqlMigration) Apply(db Executor, logger *zap.Logger) error { current, err := version(db) if err != nil { return err @@ -28,9 +56,14 @@ func (m *sqlMigration) Apply(db Executor) error { if m.order <= current { return nil } - for m.content.Scan() { - if _, err := db.Exec(m.content.Text(), nil, nil); err != nil { - return fmt.Errorf("exec %s: %w", m.content.Text(), err) + // TODO: use more advanced approach to split the SQL script + // into commands + for _, cmd := range strings.Split(m.content, ";") { + cmd = sqlCommentRx.ReplaceAllString(cmd, "") + if strings.TrimSpace(cmd) != "" { + if _, err := db.Exec(cmd, nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", cmd, err) + } } } // binding values in pragma statement is not allowed @@ -65,18 +98,9 @@ func version(db Executor) (int, error) { return current, nil } -func StateMigrations() ([]Migration, error) { - return sqlMigrations("state") -} - -func LocalMigrations() ([]Migration, error) { - return sqlMigrations("local") -} - -func sqlMigrations(dbname string) ([]Migration, error) { - var migrations []Migration - root := fmt.Sprintf("migrations/%s", dbname) - err := fs.WalkDir(embedded, root, func(path string, d fs.DirEntry, err error) error { +func LoadSQLMigrations(fsys fs.FS) (MigrationList, error) { + var migrations MigrationList + err := fs.WalkDir(fsys, "schema/migrations", func(path string, d fs.DirEntry, err error) error { if err != nil { return fmt.Errorf("walkdir %s: %w", path, err) } @@ -91,21 +115,14 @@ func sqlMigrations(dbname string) ([]Migration, error) { if err != nil { return fmt.Errorf("invalid migration %s: %w", d.Name(), err) } - f, err := embedded.Open(path) + script, err := fs.ReadFile(fsys, path) if err != nil { return fmt.Errorf("read file %s: %w", path, err) } - scanner := bufio.NewScanner(f) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if i := bytes.Index(data, []byte(";")); i >= 0 { - return i + 1, data[0 : i+1], nil - } - return 0, nil, nil - }) migrations = append(migrations, &sqlMigration{ order: order, name: d.Name(), - content: scanner, + content: string(script), }) return nil }) diff --git a/sql/migrations_test.go b/sql/migrations_test.go deleted file mode 100644 index 13bfb86b806..00000000000 --- a/sql/migrations_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package sql - -import ( - "slices" - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_MigrationsAppliedOnce(t *testing.T) { - db := InMemory() - - var version int - _, err := db.Exec("PRAGMA user_version;", nil, func(stmt *Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - migrations, err := StateMigrations() - require.NoError(t, err) - expectedVersion := slices.MaxFunc(migrations, func(a, b Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) -} diff --git a/sql/mocks.go b/sql/mocks.go index bbd869f3d49..975de979115 100644 --- a/sql/mocks.go +++ b/sql/mocks.go @@ -13,6 +13,7 @@ import ( reflect "reflect" gomock "go.uber.org/mock/gomock" + zap "go.uber.org/zap" ) // MockMigration is a mock of Migration interface. @@ -39,17 +40,17 @@ func (m *MockMigration) EXPECT() *MockMigrationMockRecorder { } // Apply mocks base method. -func (m *MockMigration) Apply(db Executor) error { +func (m *MockMigration) Apply(db Executor, logger *zap.Logger) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Apply", db) + ret := m.ctrl.Call(m, "Apply", db, logger) ret0, _ := ret[0].(error) return ret0 } // Apply indicates an expected call of Apply. -func (mr *MockMigrationMockRecorder) Apply(db any) *MockMigrationApplyCall { +func (mr *MockMigrationMockRecorder) Apply(db, logger any) *MockMigrationApplyCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockMigration)(nil).Apply), db) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockMigration)(nil).Apply), db, logger) return &MockMigrationApplyCall{Call: call} } @@ -65,13 +66,13 @@ func (c *MockMigrationApplyCall) Return(arg0 error) *MockMigrationApplyCall { } // Do rewrite *gomock.Call.Do -func (c *MockMigrationApplyCall) Do(f func(Executor) error) *MockMigrationApplyCall { +func (c *MockMigrationApplyCall) Do(f func(Executor, *zap.Logger) error) *MockMigrationApplyCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMigrationApplyCall) DoAndReturn(f func(Executor) error) *MockMigrationApplyCall { +func (c *MockMigrationApplyCall) DoAndReturn(f func(Executor, *zap.Logger) error) *MockMigrationApplyCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index c5133b30640..8d93a4a118b 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./database.go +// Source: github.com/spacemeshos/go-spacemesh/sql (interfaces: Executor) // // Generated by this command: // -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +// mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // // Package mocks is a generated GoMock package. diff --git a/sql/poets/poets_test.go b/sql/poets/poets_test.go index 9545f0075be..be7d88e1926 100644 --- a/sql/poets/poets_test.go +++ b/sql/poets/poets_test.go @@ -8,10 +8,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -51,7 +52,7 @@ func TestHas(t *testing.T) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -102,7 +103,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ref := types.PoetProofRef{0xca, 0xfe} poet := []byte("proof0") @@ -121,7 +122,7 @@ func TestAdd(t *testing.T) { } func TestGetRef(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sids := [][]byte{ []byte("sid1"), diff --git a/sql/querycache.go b/sql/querycache.go index e9840a29599..38db15f8803 100644 --- a/sql/querycache.go +++ b/sql/querycache.go @@ -17,14 +17,14 @@ type ( var NullQueryCache QueryCache = (*queryCache)(nil) -type queryCacheKey struct { +type QueryCacheItemKey struct { Kind QueryCacheKind Key string } // QueryCacheKey creates a key for QueryCache. -func QueryCacheKey(kind QueryCacheKind, key string) queryCacheKey { - return queryCacheKey{Kind: kind, Key: key} +func QueryCacheKey(kind QueryCacheKind, key string) QueryCacheItemKey { + return QueryCacheItemKey{Kind: kind, Key: key} } // QueryCacheSubKey denotes a cache subkey. The empty subkey refers to the main @@ -60,14 +60,14 @@ type QueryCache interface { // called for this cache. GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) // UpdateSlice updates the slice stored in the cache by invoking the // specified SliceAppender. If the entry is not cached, the method does // nothing. - UpdateSlice(key queryCacheKey, update SliceAppender) + UpdateSlice(key QueryCacheItemKey, update SliceAppender) // ClearCache empties the cache. ClearCache() } @@ -87,7 +87,7 @@ func IsCached(db any) bool { func WithCachedValue[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { return WithCachedSubKey(ctx, db, key, mainSubKey, retrieve) @@ -100,7 +100,7 @@ func WithCachedValue[T any]( func WithCachedSubKey[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { @@ -124,7 +124,7 @@ func WithCachedSubKey[T any]( // AppendToCachedSlice adds a value to the slice stored in the cache by invoking // the specified SliceAppender. If the entry is not cached, the function does // nothing. -func AppendToCachedSlice[T any](db any, key queryCacheKey, v T) { +func AppendToCachedSlice[T any](db any, key QueryCacheItemKey, v T) { if cache, ok := db.(QueryCache); ok { cache.UpdateSlice(key, func(s any) any { if s == nil { @@ -145,7 +145,7 @@ type lru = simplelru.LRU[lruCacheKey, any] type queryCache struct { sync.Mutex updateMtx sync.RWMutex - subKeyMap map[queryCacheKey][]QueryCacheSubKey + subKeyMap map[QueryCacheItemKey][]QueryCacheSubKey cacheSizesByKind map[QueryCacheKind]int caches map[QueryCacheKind]*lru } @@ -162,7 +162,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { } lruForKind, err := simplelru.NewLRU[lruCacheKey, any](size, func(k lruCacheKey, v any) { if k.subKey == mainSubKey { - c.clearSubKeys(queryCacheKey{Kind: kind, Key: k.key}) + c.clearSubKeys(QueryCacheItemKey{Kind: kind, Key: k.key}) } }) if err != nil { @@ -175,7 +175,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { return lruForKind } -func (c *queryCache) clearSubKeys(key queryCacheKey) { +func (c *queryCache) clearSubKeys(key QueryCacheItemKey) { lru, found := c.caches[key.Kind] if !found { return @@ -188,7 +188,7 @@ func (c *queryCache) clearSubKeys(key queryCacheKey) { } } -func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) { +func (c *queryCache) get(key QueryCacheItemKey, subKey QueryCacheSubKey) (any, bool) { c.Lock() defer c.Unlock() lru, found := c.caches[key.Kind] @@ -202,14 +202,14 @@ func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) }) } -func (c *queryCache) set(key queryCacheKey, subKey QueryCacheSubKey, v any) { +func (c *queryCache) set(key QueryCacheItemKey, subKey QueryCacheSubKey, v any) { c.Lock() defer c.Unlock() if subKey != mainSubKey { sks := c.subKeyMap[key] if slices.Index(sks, subKey) < 0 { if c.subKeyMap == nil { - c.subKeyMap = make(map[queryCacheKey][]QueryCacheSubKey) + c.subKeyMap = make(map[QueryCacheItemKey][]QueryCacheSubKey) } c.subKeyMap[key] = append(sks, subKey) } @@ -224,7 +224,7 @@ func (c *queryCache) IsCached() bool { func (c *queryCache) GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) { @@ -251,7 +251,7 @@ func (c *queryCache) GetValue( return v, err } -func (c *queryCache) UpdateSlice(key queryCacheKey, update SliceAppender) { +func (c *queryCache) UpdateSlice(key QueryCacheItemKey, update SliceAppender) { if c == nil { return } diff --git a/sql/recovery/recovery_test.go b/sql/recovery/recovery_test.go index 976091e0777..7f4e396925b 100644 --- a/sql/recovery/recovery_test.go +++ b/sql/recovery/recovery_test.go @@ -6,12 +6,12 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRecoveryInfo(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() restore := types.LayerID(12) got, err := recovery.CheckpointInfo(db) diff --git a/sql/rewards/rewards_test.go b/sql/rewards/rewards_test.go index 5cd5715a530..71fd23e090a 100644 --- a/sql/rewards/rewards_test.go +++ b/sql/rewards/rewards_test.go @@ -6,13 +6,15 @@ import ( "testing" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewards(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var part uint64 = math.MaxUint64 / 2 lyrReward := part / 2 @@ -199,13 +201,17 @@ func TestRewards(t *testing.T) { } func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), ) // verify that the DB is empty @@ -217,7 +223,7 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db, zaptest.NewLogger(t)) require.NoError(t, err) // verify that db is still empty @@ -230,13 +236,19 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { } func Test_0008Migration(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true), + sql.WithAllowSchemaDrift(true), ) // verify that the DB is empty @@ -279,7 +291,7 @@ func Test_0008Migration(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db, zaptest.NewLogger(t)) require.NoError(t, err) // verify that one row is still present diff --git a/sql/schema.go b/sql/schema.go new file mode 100644 index 00000000000..cf125f162ea --- /dev/null +++ b/sql/schema.go @@ -0,0 +1,215 @@ +package sql + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" +) + +const ( + SchemaPath = "schema/schema.sql" + UpdatedSchemaPath = "schema/schema.sql.updated" +) + +// LoadDBSchemaScript retrieves the database schema as text. +func LoadDBSchemaScript(db Executor) (string, error) { + var ( + err error + sb strings.Builder + ) + version, err := version(db) + if err != nil { + return "", err + } + fmt.Fprintf(&sb, "PRAGMA user_version = %d;\n", version) + if _, err = db.Exec( + // Type is either 'index' or 'table', we want tables to go first + `select tbl_name, sql || ';' from sqlite_master + where sql is not null + order by tbl_name, type desc, name`, + nil, func(st *Statement) bool { + fmt.Fprintln(&sb, st.ColumnText(1)) + return true + }); err != nil { + return "", fmt.Errorf("error retrieving DB schema: %w", err) + } + // On Windows, the result contains extra carriage returns + return strings.ReplaceAll(sb.String(), "\r", ""), nil +} + +// Schema represents database schema. +type Schema struct { + Script string + Migrations MigrationList + skipMigration map[int]struct{} +} + +// Diff diffs the database schema against the actual schema. +// If there's no differences, it returns an empty string. +func (s *Schema) Diff(actualScript string) string { + return cmp.Diff(s.Script, actualScript) +} + +// WriteToFile writes the schema to the corresponding updated schema file. +func (s *Schema) WriteToFile(basedir string) error { + path := filepath.Join(basedir, UpdatedSchemaPath) + if err := os.WriteFile(path, []byte(s.Script), 0o777); err != nil { + return fmt.Errorf("error writing schema file %s: %w", path, err) + } + return nil +} + +// SkipMigrations skips the specified migrations. +func (s *Schema) SkipMigrations(i ...int) { + if s.skipMigration == nil { + s.skipMigration = make(map[int]struct{}) + } + for _, index := range i { + s.skipMigration[index] = struct{}{} + } +} + +// Apply applies the schema to the database. +func (s *Schema) Apply(db Database) error { + return db.WithTx(context.Background(), func(tx Transaction) error { + scanner := bufio.NewScanner(strings.NewReader(s.Script)) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if i := bytes.Index(data, []byte(";")); i >= 0 { + return i + 1, data[0 : i+1], nil + } + return 0, nil, nil + }) + for scanner.Scan() { + if _, err := tx.Exec(scanner.Text(), nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", scanner.Text(), err) + } + } + return nil + }) +} + +func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after int, err error) { + if len(s.Migrations) == 0 { + return 0, 0, nil + } + before, err = version(db) + if err != nil { + return 0, 0, err + } + after = s.Migrations.Version() + if before > after { + logger.Error("database version is newer than expected - downgrade is not supported", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return before, after, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + } + + return before, after, nil +} + +// Migrate performs database migration. In case if migrations are disabled, the database +// version is checked but no migrations are run, and if the database is too old and +// migrations are disabled, an error is returned. +func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState int) error { + for i, m := range s.Migrations { + if m.Order() <= before { + continue + } + if err := db.WithTx(context.Background(), func(tx Transaction) error { + if _, ok := s.skipMigration[m.Order()]; !ok { + if err := m.Apply(tx, logger); err != nil { + for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { + if e := s.Migrations[j].Rollback(); e != nil { + err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) + break + } + } + + return fmt.Errorf("apply %s: %w", m.Name(), err) + } + } + // version is set intentionally even if actual migration was skipped + if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { + return fmt.Errorf("update user_version to %d: %w", m.Order(), err) + } + return nil + }); err != nil { + err = errors.Join(err, db.Close()) + return err + } + + if vacuumState != 0 && before <= vacuumState { + if err := Vacuum(db); err != nil { + err = errors.Join(err, db.Close()) + return err + } + } + before = m.Order() + } + return nil +} + +// SchemaGenOpt represents a schema generator option. +type SchemaGenOpt func(g *SchemaGen) + +func withDefaultOut(w io.Writer) SchemaGenOpt { + return func(g *SchemaGen) { + g.defaultOut = w + } +} + +// SchemaGen generates database schema files. +type SchemaGen struct { + logger *zap.Logger + schema *Schema + defaultOut io.Writer +} + +// NewSchemaGen creates a new SchemaGen instance. +func NewSchemaGen(logger *zap.Logger, schema *Schema, opts ...SchemaGenOpt) *SchemaGen { + g := &SchemaGen{logger: logger, schema: schema, defaultOut: os.Stdout} + for _, opt := range opts { + opt(g) + } + return g +} + +// Generate generates database schema and writes it to the specified file. +// If an empty string is specified as outputFile, os.Stdout is used for output. +func (g *SchemaGen) Generate(outputFile string) error { + db, err := OpenInMemory( + WithLogger(g.logger), + WithDatabaseSchema(g.schema), + WithForceMigrations(true), + WithIgnoreSchemaDrift()) + if err != nil { + return fmt.Errorf("error opening in-memory db: %w", err) + } + defer func() { + if err := db.Close(); err != nil { + g.logger.Error("error closing in-memory db: %w", zap.Error(err)) + } + }() + loadedScript, err := LoadDBSchemaScript(db) + if err != nil { + return fmt.Errorf("error loading DB schema script: %w", err) + } + if outputFile == "" { + if _, err := io.WriteString(g.defaultOut, loadedScript); err != nil { + return fmt.Errorf("error writing schema file: %w", err) + } + } else if err := os.WriteFile(outputFile, []byte(loadedScript), 0o777); err != nil { + return fmt.Errorf("error writing schema file %q: %w", outputFile, err) + } + return nil +} diff --git a/sql/schema_test.go b/sql/schema_test.go new file mode 100644 index 00000000000..f9ccc083fe1 --- /dev/null +++ b/sql/schema_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "os" + "path/filepath" + "strings" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" +) + +func TestSchemaGen(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + fs := fstest.MapFS{ + "schema/migrations/0001_first.sql": &fstest.MapFile{ + Data: []byte("create table foo(id int);"), + }, + "schema/migrations/0002_second.sql": &fstest.MapFile{ + Data: []byte("create table bar(id int);"), + }, + } + migrations, err := LoadSQLMigrations(fs) + require.NoError(t, err) + require.Len(t, migrations, 2) + schema := &Schema{ + Script: "this should not be run", + Migrations: migrations, + } + var sb strings.Builder + g := NewSchemaGen(logger, schema, withDefaultOut(&sb)) + tempDir := t.TempDir() + schemaPath := filepath.Join(tempDir, "schema.sql") + require.NoError(t, g.Generate(schemaPath)) + contents, err := os.ReadFile(schemaPath) + require.NoError(t, err) + require.Equal(t, + "PRAGMA user_version = 2;\nCREATE TABLE bar(id int);\nCREATE TABLE foo(id int);\n", + string(contents)) + require.NoError(t, g.Generate("")) + require.Equal(t, string(contents), sb.String()) + + require.Equal(t, 0, observedLogs.Len(), + "expected 0 warning messages in the log (schema drift warnings?)") +} diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go new file mode 100644 index 00000000000..c55d2b3ad46 --- /dev/null +++ b/sql/schemagen/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +var ( + level = zap.LevelFlag("level", zapcore.ErrorLevel, "set log verbosity level") + dbType = flag.String("dbtype", "state", "database type (state, local, default state)") + output = flag.String("output", "", "output file (defaults to stdout)") +) + +func main() { + var ( + err error + schema *sql.Schema + ) + flag.Parse() + core := zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()), + os.Stderr, + zap.NewAtomicLevelAt(*level), + ) + logger := zap.New(core).With(zap.String("dbType", *dbType)) + switch *dbType { + case "state": + schema, err = statesql.Schema() + case "local": + schema, err = localsql.Schema() + default: + logger.Fatal("unknown database type, must be state or local") + } + if err != nil { + logger.Fatal("error loading db schema", zap.Error(err)) + } + g := sql.NewSchemaGen(logger, schema) + if err := g.Generate(*output); err != nil { + logger.Fatal("error generating schema", zap.Error(err), zap.String("output", *output)) + } +} diff --git a/sql/migrations/state/0001_initial.sql b/sql/statesql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/state/0001_initial.sql rename to sql/statesql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/state/0002_v1.0.3.sql b/sql/statesql/schema/migrations/0002_v1.0.3.sql similarity index 100% rename from sql/migrations/state/0002_v1.0.3.sql rename to sql/statesql/schema/migrations/0002_v1.0.3.sql diff --git a/sql/migrations/state/0003_v1.1.5.sql b/sql/statesql/schema/migrations/0003_v1.1.5.sql similarity index 100% rename from sql/migrations/state/0003_v1.1.5.sql rename to sql/statesql/schema/migrations/0003_v1.1.5.sql diff --git a/sql/migrations/state/0004_v1.1.7.sql b/sql/statesql/schema/migrations/0004_v1.1.7.sql similarity index 100% rename from sql/migrations/state/0004_v1.1.7.sql rename to sql/statesql/schema/migrations/0004_v1.1.7.sql diff --git a/sql/migrations/state/0005_v1.2.0.sql b/sql/statesql/schema/migrations/0005_v1.2.0.sql similarity index 100% rename from sql/migrations/state/0005_v1.2.0.sql rename to sql/statesql/schema/migrations/0005_v1.2.0.sql diff --git a/sql/migrations/state/0006_v1.2.2.sql b/sql/statesql/schema/migrations/0006_v1.2.2.sql similarity index 100% rename from sql/migrations/state/0006_v1.2.2.sql rename to sql/statesql/schema/migrations/0006_v1.2.2.sql diff --git a/sql/migrations/state/0007_v1.3.0.sql b/sql/statesql/schema/migrations/0007_v1.3.0.sql similarity index 100% rename from sql/migrations/state/0007_v1.3.0.sql rename to sql/statesql/schema/migrations/0007_v1.3.0.sql diff --git a/sql/migrations/state/0008_rewards.sql b/sql/statesql/schema/migrations/0008_rewards.sql similarity index 100% rename from sql/migrations/state/0008_rewards.sql rename to sql/statesql/schema/migrations/0008_rewards.sql diff --git a/sql/migrations/state/0009_prune_activesets.sql b/sql/statesql/schema/migrations/0009_prune_activesets.sql similarity index 100% rename from sql/migrations/state/0009_prune_activesets.sql rename to sql/statesql/schema/migrations/0009_prune_activesets.sql diff --git a/sql/migrations/state/0010_rowid.sql b/sql/statesql/schema/migrations/0010_rowid.sql similarity index 100% rename from sql/migrations/state/0010_rowid.sql rename to sql/statesql/schema/migrations/0010_rowid.sql diff --git a/sql/migrations/state/0011_atxs_extra_index.sql b/sql/statesql/schema/migrations/0011_atxs_extra_index.sql similarity index 100% rename from sql/migrations/state/0011_atxs_extra_index.sql rename to sql/statesql/schema/migrations/0011_atxs_extra_index.sql diff --git a/sql/migrations/state/0012_atx_validity.sql b/sql/statesql/schema/migrations/0012_atx_validity.sql similarity index 100% rename from sql/migrations/state/0012_atx_validity.sql rename to sql/statesql/schema/migrations/0012_atx_validity.sql diff --git a/sql/migrations/state/0013_atx_coinbase_index.sql b/sql/statesql/schema/migrations/0013_atx_coinbase_index.sql similarity index 100% rename from sql/migrations/state/0013_atx_coinbase_index.sql rename to sql/statesql/schema/migrations/0013_atx_coinbase_index.sql diff --git a/sql/migrations/state/0014_remove_proposals.sql b/sql/statesql/schema/migrations/0014_remove_proposals.sql similarity index 100% rename from sql/migrations/state/0014_remove_proposals.sql rename to sql/statesql/schema/migrations/0014_remove_proposals.sql diff --git a/sql/migrations/state/0015_nonce_index.sql b/sql/statesql/schema/migrations/0015_nonce_index.sql similarity index 100% rename from sql/migrations/state/0015_nonce_index.sql rename to sql/statesql/schema/migrations/0015_nonce_index.sql diff --git a/sql/migrations/state/0016_atx_blob.sql b/sql/statesql/schema/migrations/0016_atx_blob.sql similarity index 100% rename from sql/migrations/state/0016_atx_blob.sql rename to sql/statesql/schema/migrations/0016_atx_blob.sql diff --git a/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql b/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql similarity index 100% rename from sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql rename to sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql diff --git a/sql/migrations/state/0018_atx_blob_version.sql b/sql/statesql/schema/migrations/0018_atx_blob_version.sql similarity index 100% rename from sql/migrations/state/0018_atx_blob_version.sql rename to sql/statesql/schema/migrations/0018_atx_blob_version.sql diff --git a/sql/migrations/state/0019_marriages.sql b/sql/statesql/schema/migrations/0019_marriages.sql similarity index 100% rename from sql/migrations/state/0019_marriages.sql rename to sql/statesql/schema/migrations/0019_marriages.sql diff --git a/sql/migrations/state/0020_atx_merge.sql b/sql/statesql/schema/migrations/0020_atx_merge.sql similarity index 100% rename from sql/migrations/state/0020_atx_merge.sql rename to sql/statesql/schema/migrations/0020_atx_merge.sql diff --git a/sql/migrations/state/0021_atx_posts.sql b/sql/statesql/schema/migrations/0021_atx_posts.sql similarity index 100% rename from sql/migrations/state/0021_atx_posts.sql rename to sql/statesql/schema/migrations/0021_atx_posts.sql diff --git a/sql/statesql/schema/migrations/0022_schema_cleanup.sql b/sql/statesql/schema/migrations/0022_schema_cleanup.sql new file mode 100644 index 00000000000..215bca7adcd --- /dev/null +++ b/sql/statesql/schema/migrations/0022_schema_cleanup.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS _litestream_seq; +DROP TABLE IF EXISTS _litestream_lock; diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql new file mode 100755 index 00000000000..b20bc7b9b0e --- /dev/null +++ b/sql/statesql/schema/schema.sql @@ -0,0 +1,158 @@ +PRAGMA user_version = 22; +CREATE TABLE accounts +( + address CHAR(24), + balance UNSIGNED LONG INT, + next_nonce UNSIGNED LONG INT, + layer_updated UNSIGNED LONG INT, + template CHAR(24), + state BLOB, + PRIMARY KEY (address, layer_updated DESC) +); +CREATE INDEX accounts_by_layer_updated ON accounts (layer_updated); +CREATE TABLE activesets +( + id CHAR(32) PRIMARY KEY, + active_set BLOB +, epoch INT DEFAULT 0 NOT NULL) WITHOUT ROWID; +CREATE INDEX activesets_by_epoch ON activesets (epoch asc); +CREATE TABLE atx_blobs +( + id CHAR(32), + atx BLOB +, version INTEGER); +CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); +CREATE TABLE atxs +( + id CHAR(32), + prev_id CHAR(32), + epoch INT NOT NULL, + effective_num_units INT NOT NULL, + commitment_atx CHAR(32), + nonce UNSIGNED LONG INT, + base_tick_height UNSIGNED LONG INT, + tick_count UNSIGNED LONG INT, + sequence UNSIGNED LONG INT, + pubkey CHAR(32), + coinbase CHAR(24), + received INT NOT NULL, + validity INTEGER DEFAULT false +, weight INTEGER); +CREATE INDEX atxs_by_coinbase ON atxs (coinbase); +CREATE INDEX atxs_by_epoch_by_pubkey ON atxs (epoch, pubkey); +CREATE INDEX atxs_by_epoch_by_pubkey_nonce ON atxs (pubkey, epoch desc, nonce) WHERE nonce IS NOT NULL; +CREATE INDEX atxs_by_epoch_id on atxs (epoch, id); +CREATE INDEX atxs_by_pubkey_by_epoch_desc ON atxs (pubkey, epoch desc); +CREATE UNIQUE INDEX atxs_id ON atxs (id); +CREATE TABLE ballots +( + id CHAR(20) PRIMARY KEY, + atx CHAR(32) NOT NULL, + layer INT NOT NULL, + pubkey VARCHAR, + ballot BLOB +); +CREATE INDEX ballots_by_atx_by_layer ON ballots (atx, layer asc); +CREATE INDEX ballots_by_layer_by_pubkey ON ballots (layer asc, pubkey); +CREATE TABLE beacons +( + epoch INT NOT NULL PRIMARY KEY, + beacon CHAR(4) +) WITHOUT ROWID; +CREATE TABLE block_transactions +( + tid CHAR(32), + bid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, bid) +) WITHOUT ROWID; +CREATE TABLE blocks +( + id CHAR(20) PRIMARY KEY, + layer INT NOT NULL, + validity SMALL INT, + block BLOB +); +CREATE INDEX blocks_by_layer ON blocks (layer, id asc); +CREATE TABLE certificates +( + layer INT NOT NULL, + block VARCHAR NOT NULL, + cert BLOB, + valid bool NOT NULL, + PRIMARY KEY (layer, block) +); +CREATE TABLE identities +( + pubkey VARCHAR PRIMARY KEY, + proof BLOB +, received INT DEFAULT 0 NOT NULL, marriage_atx CHAR(32), marriage_idx INTEGER, marriage_target CHAR(32), marriage_signature CHAR(64)) WITHOUT ROWID; +CREATE TABLE layers +( + id INT PRIMARY KEY DESC, + weak_coin SMALL INT, + processed SMALL INT, + applied_block VARCHAR, + state_hash CHAR(32), + aggregated_hash CHAR(32) +) WITHOUT ROWID; +CREATE INDEX layers_by_processed ON layers (processed); +CREATE TABLE poets +( + ref VARCHAR PRIMARY KEY, + poet BLOB, + service_id VARCHAR, + round_id VARCHAR +); +CREATE INDEX poets_by_service_id_by_round_id ON poets (service_id, round_id); +CREATE TABLE posts ( + atxid CHAR(32) NOT NULL, + pubkey CHAR(32) NOT NULL, + units INT NOT NULL, + UNIQUE (atxid, pubkey) + ); +CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey); +CREATE TABLE proposal_transactions +( + tid CHAR(32), + pid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, pid) +) WITHOUT ROWID; +CREATE TABLE recovery +( + id INTEGER PRIMARY KEY CHECK (id = 1), + restore INT NOT NULL +); +CREATE TABLE rewards +( + pubkey CHAR(32), + coinbase CHAR(24) NOT NULL, + layer INT NOT NULL, + total_reward UNSIGNED LONG INT, + layer_reward UNSIGNED LONG INT, + PRIMARY KEY (pubkey, layer) +); +CREATE INDEX rewards_by_coinbase ON rewards (coinbase, layer); +CREATE INDEX rewards_by_layer ON rewards (layer asc); +CREATE TABLE transactions +( + id CHAR(32) PRIMARY KEY, + tx BLOB, + header BLOB, + result BLOB, + layer INT, + block CHAR(20), + principal CHAR(24), + nonce BLOB, + timestamp INT NOT NULL +) WITHOUT ROWID; +CREATE INDEX transaction_by_layer_principal ON transactions (layer asc, principal); +CREATE INDEX transaction_by_principal_nonce ON transactions (principal, nonce); +CREATE TABLE transactions_results_addresses +( + address CHAR(24), + tid CHAR(32), + PRIMARY KEY (tid, address) +) WITHOUT ROWID; +CREATE INDEX transactions_results_addresses_by_address ON transactions_results_addresses(address); diff --git a/sql/migrations/state_0021_migration.go b/sql/statesql/state_0021_migration.go similarity index 91% rename from sql/migrations/state_0021_migration.go rename to sql/statesql/state_0021_migration.go index 8a98145e571..ffa423a0128 100644 --- a/sql/migrations/state_0021_migration.go +++ b/sql/statesql/state_0021_migration.go @@ -1,4 +1,4 @@ -package migrations +package statesql import ( "errors" @@ -14,14 +14,12 @@ import ( ) type migration0021 struct { - batch int - logger *zap.Logger + batch int } -func New0021Migration(log *zap.Logger, batch int) *migration0021 { +func New0021Migration(batch int) *migration0021 { return &migration0021{ - logger: log, - batch: batch, + batch: batch, } } @@ -37,7 +35,7 @@ func (*migration0021) Rollback() error { return nil } -func (m *migration0021) Apply(db sql.Executor) error { +func (m *migration0021) Apply(db sql.Executor, logger *zap.Logger) error { if err := m.createTable(db); err != nil { return err } @@ -49,7 +47,7 @@ func (m *migration0021) Apply(db sql.Executor) error { if err != nil { return fmt.Errorf("counting all ATXs %w", err) } - m.logger.Info("applying migration 21", zap.Int("total", total)) + logger.Info("applying migration 21", zap.Int("total", total)) for offset := 0; ; offset += m.batch { n, err := m.processBatch(db, offset, m.batch) @@ -59,7 +57,7 @@ func (m *migration0021) Apply(db sql.Executor) error { processed := offset + n progress := float64(processed) * 100.0 / float64(total) - m.logger.Info("processed ATXs", zap.Float64("progress [%]", progress)) + logger.Info("processed ATXs", zap.Float64("progress [%]", progress)) if processed >= total { return nil } diff --git a/sql/migrations/state_0021_migration_test.go b/sql/statesql/state_0021_migration_test.go similarity index 57% rename from sql/migrations/state_0021_migration_test.go rename to sql/statesql/state_0021_migration_test.go index 76f50d6d1c1..65edcbd3c62 100644 --- a/sql/migrations/state_0021_migration_test.go +++ b/sql/statesql/state_0021_migration_test.go @@ -1,7 +1,7 @@ -package migrations +package statesql import ( - "strings" + "slices" "testing" "github.com/stretchr/testify/require" @@ -15,42 +15,21 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" ) -// Test that in-code migration results in the same schema as the .sql one. -func Test0021Migration_CompatibleSchema(t *testing.T) { - db := sql.InMemory( - sql.WithLogger(zaptest.NewLogger(t)), - sql.WithMigration(New0021Migration(zaptest.NewLogger(t), 1000)), - ) - - var schemasInCode []string - _, err := db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - schemasInCode = append(schemasInCode, sql) - return true - }) +func Test0021Migration(t *testing.T) { + schema, err := Schema() require.NoError(t, err) - require.NoError(t, db.Close()) - - db = sql.InMemory() - - var schemasInFile []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - schemasInFile = append(schemasInFile, sql) - return true + schema.Migrations = slices.DeleteFunc(schema.Migrations, func(m sql.Migration) bool { + if m.Order() != 21 { + t.Logf("QQQQQ: include migration %d -- %s", m.Order(), m.Name()) + } + return m.Order() == 21 }) - require.NoError(t, err) - require.NoError(t, db.Close()) - require.Equal(t, schemasInFile, schemasInCode) -} - -func Test0021Migration(t *testing.T) { db := sql.InMemory( sql.WithLogger(zaptest.NewLogger(t)), - sql.WithSkipMigrations(21), + sql.WithDatabaseSchema(schema), + sql.WithIgnoreSchemaDrift(), + sql.WithForceMigrations(true), ) var signers [177]*signing.EdSigner @@ -82,9 +61,9 @@ func Test0021Migration(t *testing.T) { } } - m := New0021Migration(zaptest.NewLogger(t), 1000) + m := New0021Migration(1000) require.Equal(t, 21, m.Order()) - require.NoError(t, m.Apply(db)) + require.NoError(t, m.Apply(db, zaptest.NewLogger(t))) for _, posts := range allPosts { for atx, post := range posts { diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go new file mode 100644 index 00000000000..252eb5099cc --- /dev/null +++ b/sql/statesql/statesql.go @@ -0,0 +1,67 @@ +package statesql + +import ( + "embed" + "strings" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate go run ../schemagen -dbtype state -output schema/schema.sql + +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS + +type database struct { + sql.Database +} + +var _ sql.StateDatabase = &database{} + +func (db *database) IsStateDatabase() {} + +// Schema returns the schema for the state database. +func Schema() (*sql.Schema, error) { + sqlMigrations, err := sql.LoadSQLMigrations(migrations) + if err != nil { + return nil, err + } + sqlMigrations = sqlMigrations.AddMigration(New0021Migration(1_000_000)) + // NOTE: coded state migrations can be added here + // They can be a part of this localsql package + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil +} + +// Open opens a state database. +func Open(uri string, opts ...sql.Opt) (sql.StateDatabase, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + opts = append([]sql.Opt{sql.WithDatabaseSchema(schema)}, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &database{Database: db}, nil +} + +// Open opens an in-memory state database. +func InMemory(opts ...sql.Opt) sql.StateDatabase { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &database{Database: db} +} diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go new file mode 100644 index 00000000000..207d7f2a5a2 --- /dev/null +++ b/sql/statesql/statesql_test.go @@ -0,0 +1,74 @@ +package statesql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + // "running migrations" + "applying migration 21" + "processed ATXs" + require.Equal(t, 3, observedLogs.Len(), "expected count of log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) + + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 3, observedLogs.Len(), "expected 1 log message") +} diff --git a/sql/transactions/iterator_test.go b/sql/transactions/iterator_test.go index fd3cad3a8e0..334cef2a97e 100644 --- a/sql/transactions/iterator_test.go +++ b/sql/transactions/iterator_test.go @@ -15,6 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func matchTx(tx types.TransactionWithResult, filter ResultsFilter) bool { @@ -59,11 +60,11 @@ func filterTxs(txs []types.TransactionWithResult, filter ResultsFilter) []types. } func TestIterateResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() gen := fixture.NewTransactionResultGenerator() txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -142,12 +143,12 @@ func TestIterateResults(t *testing.T) { } func TestIterateSnapshot(t *testing.T) { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) gen := fixture.NewTransactionResultGenerator() expect := 10 - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { for i := 0; i < expect; i++ { tx := gen.Next() @@ -175,7 +176,7 @@ func TestIterateSnapshot(t *testing.T) { }() <-initialized - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := 0; i < 10; i++ { tx := gen.Next() diff --git a/sql/transactions/transactions.go b/sql/transactions/transactions.go index f0830e8e3ad..b1e23879708 100644 --- a/sql/transactions/transactions.go +++ b/sql/transactions/transactions.go @@ -28,8 +28,8 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { if _, err = db.Exec(` insert into transactions (id, tx, header, principal, nonce, timestamp) values (?1, ?2, ?3, ?4, ?5, ?6) - on conflict(id) do update set - header=?3, principal=?4, nonce=?5 + on conflict(id) do update set + header=?3, principal=?4, nonce=?5 where header is null;`, func(stmt *sql.Statement) { stmt.BindBytes(1, tx.ID.Bytes()) @@ -49,7 +49,7 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { // AddToProposal associates a transaction with a proposal. func AddToProposal(db sql.Executor, tid types.TransactionID, lid types.LayerID, pid types.ProposalID) error { if _, err := db.Exec(` - insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) + insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) on conflict(tid, pid) do nothing;`, func(stmt *sql.Statement) { stmt.BindBytes(1, pid.Bytes()) @@ -136,8 +136,8 @@ func GetAppliedLayer(db sql.Executor, tid types.TransactionID) (types.LayerID, e } // UndoLayers unset all transactions to `statePending` from `from` layer to the max layer with applied transactions. -func UndoLayers(db *sql.Tx, from types.LayerID) error { - _, err := db.Exec(`delete from transactions_results_addresses +func UndoLayers(tx sql.Transaction, from types.LayerID) error { + _, err := tx.Exec(`delete from transactions_results_addresses where tid in (select id from transactions where layer >= ?1);`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -145,8 +145,8 @@ func UndoLayers(db *sql.Tx, from types.LayerID) error { if err != nil { return fmt.Errorf("delete addresses mapping %w", err) } - _, err = db.Exec(`update transactions - set layer = null, block = null, result = null + _, err = tx.Exec(`update transactions + set layer = null, block = null, result = null where layer >= ?1`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -287,7 +287,7 @@ func AddressesWithPendingTransactions(db sql.Executor) ([]types.AddressNonce, er // GetAcctPendingFromNonce get all pending transactions with nonce after `from` for the given address. func GetAcctPendingFromNonce(db sql.Executor, address types.Address, from uint64) ([]*types.MeshTransaction, error) { return queryPending(db, `select tx, header, layer, block, timestamp, id from transactions - where principal = ?1 and nonce >= ?2 and result is null + where principal = ?1 and nonce >= ?2 and result is null order by nonce asc, timestamp asc`, func(stmt *sql.Statement) { stmt.BindBytes(1, address.Bytes()) @@ -321,14 +321,14 @@ func queryPending( } // AddResult adds result for the transaction. -func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) error { +func AddResult(tx sql.Transaction, id types.TransactionID, rst *types.TransactionResult) error { buf, err := codec.Encode(rst) if err != nil { return fmt.Errorf("encode %w", err) } - if rows, err := db.Exec(`update transactions - set result = ?2, layer = ?3, block = ?4 + if rows, err := tx.Exec(`update transactions + set result = ?2, layer = ?3, block = ?4 where id = ?1 and result is null returning id;`, func(stmt *sql.Statement) { stmt.BindBytes(1, id[:]) @@ -345,7 +345,7 @@ func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) return fmt.Errorf("invalid state for %s", id) } for i := range rst.Addresses { - if _, err := db.Exec(`insert into transactions_results_addresses + if _, err := tx.Exec(`insert into transactions_results_addresses (address, tid) values (?1, ?2);`, func(stmt *sql.Statement) { stmt.BindBytes(1, rst.Addresses[i][:]) @@ -418,7 +418,7 @@ func IterateTransactionsOps( fn func(tx *types.MeshTransaction, result *types.TransactionResult) bool, ) error { var derr error - _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result + _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result from transactions left join transactions_results_addresses on id=tid`+builder.FilterFrom(operations), builder.BindingsFrom(operations), diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index ed2eef3cc95..4711e9fba2e 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -70,7 +71,7 @@ func checkMeshTXEqual(t *testing.T, expected, got types.MeshTransaction) { } func TestAddGetHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -109,7 +110,7 @@ func TestAddGetHas(t *testing.T) { } func TestAddUpdatesHeader(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() txs := []*types.Transaction{ { RawTx: types.NewRawTx([]byte{1, 2, 3}), @@ -142,7 +143,7 @@ func TestAddUpdatesHeader(t *testing.T) { } func TestAddToProposal(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -166,7 +167,7 @@ func TestAddToProposal(t *testing.T) { } func TestDeleteProposalTxs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := map[types.LayerID][]types.ProposalID{ types.LayerID(10): {{1, 1}, {1, 2}}, types.LayerID(11): {{2, 1}, {2, 2}}, @@ -197,7 +198,7 @@ func TestDeleteProposalTxs(t *testing.T) { } func TestAddToBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -221,7 +222,7 @@ func TestAddToBlock(t *testing.T) { } func TestApply_AlreadyApplied(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) lid := types.LayerID(10) @@ -231,17 +232,17 @@ func TestApply_AlreadyApplied(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // same block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // different block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult( dtx, tx.ID, @@ -251,15 +252,15 @@ func TestApply_AlreadyApplied(t *testing.T) { } func TestUndoLayers_Empty(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, types.LayerID(199)) })) } func TestApplyAndUndoLayers(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) firstLayer := types.LayerID(10) @@ -272,7 +273,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) applied = append(applied, tx.ID) @@ -284,7 +285,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.Equal(t, types.APPLIED, mtx.State) } // revert to firstLayer - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, firstLayer.Add(1)) })) @@ -300,7 +301,7 @@ func TestApplyAndUndoLayers(t *testing.T) { } func TestGetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() rng := rand.New(rand.NewSource(1001)) @@ -333,7 +334,7 @@ func TestGetBlob(t *testing.T) { } func TestGetByAddress(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -348,7 +349,7 @@ func TestGetByAddress(t *testing.T) { createTX(t, signer1, signer2Address, 1, 191, 1), } received := time.Now() - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tx := range txs { require.NoError(t, transactions.Add(dbtx, tx, received)) require.NoError(t, transactions.AddResult(dbtx, tx.ID, &types.TransactionResult{Layer: lid})) @@ -369,7 +370,7 @@ func TestGetByAddress(t *testing.T) { } func TestGetAcctPendingFromNonce(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -404,7 +405,7 @@ func TestGetAcctPendingFromNonce(t *testing.T) { } func TestAppliedLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) require.NoError(t, err) @@ -417,7 +418,7 @@ func TestAppliedLayer(t *testing.T) { for _, tx := range txs { require.NoError(t, transactions.Add(db, tx, time.Now())) } - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, txs[0].ID, &types.TransactionResult{Layer: lid, Block: types.BlockID{1, 1}}) })) @@ -428,7 +429,7 @@ func TestAppliedLayer(t *testing.T) { _, err = transactions.GetAppliedLayer(db, txs[1].ID) require.ErrorIs(t, err, sql.ErrNotFound) - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, lid) })) _, err = transactions.GetAppliedLayer(db, txs[0].ID) @@ -455,7 +456,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { TxHeader: &types.TxHeader{Principal: principals[1], Nonce: 0}, }, } - db := sql.InMemory() + db := statesql.InMemory() for _, tx := range txs { require.NoError(t, transactions.Add(db, &tx, time.Time{})) } @@ -465,7 +466,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[0].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[0].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -474,7 +475,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[1].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[2].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -520,7 +521,7 @@ func TestTransactionInProposal(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToProposal(db, tid, lids[i], pids[i])) } @@ -546,7 +547,7 @@ func TestTransactionInBlock(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToBlock(db, tid, lids[i], bids[i])) } diff --git a/sql/vacuum_test.go b/sql/vacuum_test.go index b994516279e..1a89158c644 100644 --- a/sql/vacuum_test.go +++ b/sql/vacuum_test.go @@ -7,6 +7,6 @@ import ( ) func TestVacuumDB(t *testing.T) { - db := InMemory() + db := InMemory(WithIgnoreSchemaDrift()) require.NoError(t, Vacuum(db)) } diff --git a/syncer/atxsync/atxsync.go b/syncer/atxsync/atxsync.go index 41ac7ec7c0a..ab93cb62bf6 100644 --- a/syncer/atxsync/atxsync.go +++ b/syncer/atxsync/atxsync.go @@ -14,7 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/system" ) -func getMissing(db *sql.Database, set []types.ATXID) ([]types.ATXID, error) { +func getMissing(db sql.StateDatabase, set []types.ATXID) ([]types.ATXID, error) { missing := []types.ATXID{} for _, atx := range set { exist, err := atxs.Has(db, atx) @@ -35,7 +35,7 @@ func Download( ctx context.Context, retryInterval time.Duration, logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, fetcher system.AtxFetcher, set []types.ATXID, ) error { diff --git a/syncer/atxsync/atxsync_test.go b/syncer/atxsync/atxsync_test.go index 632721910e7..ffc05564581 100644 --- a/syncer/atxsync/atxsync_test.go +++ b/syncer/atxsync/atxsync_test.go @@ -11,8 +11,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/log/logtest" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -97,7 +97,7 @@ func TestDownload(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { logger := logtest.New(t) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) fetcher := mocks.NewMockAtxFetcher(ctrl) for _, atx := range tc.existing { diff --git a/syncer/atxsync/syncer.go b/syncer/atxsync/syncer.go index cea9d825f7e..4d14043a4e5 100644 --- a/syncer/atxsync/syncer.go +++ b/syncer/atxsync/syncer.go @@ -17,7 +17,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -76,7 +75,7 @@ func WithConfig(cfg Config) Opt { } } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -95,7 +94,7 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase } func (s *Syncer) Download(parent context.Context, publish types.EpochID, downloadUntil time.Time) error { @@ -324,7 +323,7 @@ func (s *Syncer) downloadAtxs( } } - if err := s.localdb.WithTx(context.Background(), func(tx *sql.Tx) error { + if err := s.localdb.WithTx(context.Background(), func(tx sql.Transaction) error { err := atxsync.SaveRequest(tx, publish, lastSuccess, int64(len(state)), int64(len(downloaded))) if err != nil { return fmt.Errorf("failed to save request time: %w", err) diff --git a/syncer/atxsync/syncer_test.go b/syncer/atxsync/syncer_test.go index 8d9e3dc99b6..caa751eba58 100644 --- a/syncer/atxsync/syncer_test.go +++ b/syncer/atxsync/syncer_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/atxsync/mocks" "github.com/spacemeshos/go-spacemesh/system" ) @@ -42,7 +43,7 @@ func edata(ids ...string) *fetch.EpochData { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) syncer := New(fetcher, db, localdb, WithConfig(cfg), WithLogger(logtest.New(tb).Zap())) @@ -60,8 +61,8 @@ func newTester(tb testing.TB, cfg Config) *tester { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *sql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher diff --git a/syncer/find_fork_test.go b/syncer/find_fork_test.go index 2759b6145b4..5fe14fb1ee1 100644 --- a/syncer/find_fork_test.go +++ b/syncer/find_fork_test.go @@ -19,19 +19,20 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/mocks" ) type testForkFinder struct { *syncer.ForkFinder - db *sql.Database + db sql.StateDatabase mFetcher *mocks.Mockfetcher } func newTestForkFinderWithDuration(t *testing.T, d time.Duration, lg *zap.Logger) *testForkFinder { mf := mocks.NewMockfetcher(gomock.NewController(t)) - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, layers.SetMeshHash(db, types.GetEffectiveGenesis(), types.RandomHash())) return &testForkFinder{ ForkFinder: syncer.NewForkFinder(lg, db, mf, d), @@ -88,7 +89,7 @@ func layerHash(layer int, good bool) types.Hash32 { return h2 } -func storeNodeHashes(t *testing.T, db *sql.Database, diverge, max int) { +func storeNodeHashes(t *testing.T, db sql.StateDatabase, diverge, max int) { for lid := 0; lid <= max; lid++ { if lid < diverge { require.NoError(t, layers.SetMeshHash(db, types.LayerID(uint32(lid)), layerHash(lid, true))) diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 67a96ddc50c..3a143888525 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -18,7 +18,6 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/system" ) @@ -218,12 +217,12 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase clock clockwork.Clock peerErrMetric counter } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -341,8 +340,15 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } } -func (s *Syncer) updateState() error { - if err := malsync.UpdateSyncState(s.localdb, s.clock.Now()); err != nil { +func (s *Syncer) updateState(ctx context.Context) error { + if err := s.localdb.WithTx(ctx, func(tx sql.Transaction) error { + return malsync.UpdateSyncState(tx, s.clock.Now()) + }); err != nil { + if ctx.Err() != nil { + // FIXME: with crawshaw, canceling the context which has been used to get + // a connection from the pool may cause "database: no free connection" errors + err = ctx.Err() + } return fmt.Errorf("error updating malsync state: %w", err) } @@ -360,13 +366,13 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up if nothingToDownload { sst.done() if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx)) return nil } else if !initial && gotUpdate { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } } diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index f4297a7e878..3bca66cc170 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/malsync/mocks" ) @@ -137,8 +138,8 @@ func malData(ids ...string) []types.NodeID { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *sql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher @@ -151,7 +152,7 @@ type tester struct { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) clock := clockwork.NewFakeClock() diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index 9a3b1c41488..0327524807a 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -20,8 +20,8 @@ import ( "github.com/spacemeshos/go-spacemesh/mesh" mmocks "github.com/spacemeshos/go-spacemesh/mesh/mocks" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/mocks" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -127,7 +127,7 @@ func newTestSyncer(t *testing.T, interval time.Duration) *testSyncer { mCertHdr: mocks.NewMockcertHandler(ctrl), mForkFinder: mocks.NewMockforkFinder(ctrl), } - db := sql.InMemory() + db := statesql.InMemory() ts.cdb = datastore.NewCachedDB(db, lg.Zap()) var err error atxsdata := atxsdata.New() diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index a719fc61456..ef402bbd6f8 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -30,9 +30,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/handshake" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/systest/cluster" "github.com/spacemeshos/go-spacemesh/systest/testcontext" "github.com/spacemeshos/go-spacemesh/timesync" @@ -110,7 +110,7 @@ func TestPostMalfeasanceProof(t *testing.T) { postSetupMgr, err := activation.NewPostSetupManager( cfg.POST, logger.Named("post"), - datastore.NewCachedDB(sql.InMemory(), zap.NewNop()), + datastore.NewCachedDB(statesql.InMemory(), zap.NewNop()), atxsdata.New(), cl.GoldenATX(), syncer, @@ -156,7 +156,7 @@ func TestPostMalfeasanceProof(t *testing.T) { require.NoError(t, grpcPrivateServer.Start()) t.Cleanup(func() { assert.NoError(t, grpcPrivateServer.Close()) }) - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certClient := activation.NewCertifierClient(db, localDb, logger.Named("certifier")) certifier := activation.NewCertifier(localDb, logger, certClient) diff --git a/tortoise/model/core.go b/tortoise/model/core.go index 88aa0cfa9dd..0dc78d43da5 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -19,6 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -28,7 +29,7 @@ const ( ) func newCore(rng *rand.Rand, id string, logger *zap.Logger) *core { - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) sig, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) if err != nil { panic(err) @@ -134,7 +135,7 @@ func (c *core) OnMessage(m Messenger, event Message) { if ev.LayerID.After(types.GetEffectiveGenesis()) { tortoise.RecoverLayer(context.Background(), c.tortoise, - c.cdb.Executor, + c.cdb.Database, c.atxdata, ev.LayerID, c.tortoise.OnBallot, diff --git a/tortoise/recover_test.go b/tortoise/recover_test.go index ecc4c6b51b7..be41f02dd64 100644 --- a/tortoise/recover_test.go +++ b/tortoise/recover_test.go @@ -58,7 +58,7 @@ func TestRecoverState(t *testing.T) { tortoise2, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, simState.Atxdata, last, WithLogger(lg), @@ -82,7 +82,7 @@ func TestRecoverEmpty(t *testing.T) { cfg.LayerSize = size tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), 100, WithLogger(zaptest.NewLogger(t)), @@ -108,18 +108,18 @@ func TestRecoverWithOpinion(t *testing.T) { var last result.Layer for _, rst := range trt.Updates() { if rst.Verified { - require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Executor, rst.Layer, rst.Opinion)) + require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Database, rst.Layer, rst.Opinion)) } for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } last = rst } tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last.Layer, WithLogger(lg), @@ -156,14 +156,14 @@ func TestResetPending(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -203,14 +203,14 @@ func TestWindowRecovery(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -239,7 +239,7 @@ func TestRecoverOnlyAtxs(t *testing.T) { trt.TallyVotes(context.Background(), lid) } future := last + 1000 - recovered, err := Recover(context.Background(), s.GetState(0).DB.Executor, s.GetState(0).Atxdata, future, + recovered, err := Recover(context.Background(), s.GetState(0).DB.Database, s.GetState(0).Atxdata, future, WithLogger(zaptest.NewLogger(t)), WithConfig(cfg), ) diff --git a/tortoise/replay/replay_test.go b/tortoise/replay/replay_test.go index 77e71079ee0..9f6dd399fa3 100644 --- a/tortoise/replay/replay_test.go +++ b/tortoise/replay/replay_test.go @@ -15,8 +15,8 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/config" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -50,7 +50,7 @@ func TestReplayMainnet(t *testing.T) { ) require.NoError(t, err) - db, err := sql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) + db, err := statesql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) require.NoError(t, err) applied, err := layers.GetLastApplied(db) diff --git a/tortoise/sim/utils.go b/tortoise/sim/utils.go index e7f66f06644..86337c45514 100644 --- a/tortoise/sim/utils.go +++ b/tortoise/sim/utils.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -16,13 +17,13 @@ const ( func newCacheDB(logger *zap.Logger, conf config) *datastore.CachedDB { var ( - db *sql.Database + db sql.StateDatabase err error ) if len(conf.Path) == 0 { - db = sql.InMemory() + db = statesql.InMemory() } else { - db, err = sql.Open(filepath.Join(conf.Path, atxpath)) + db, err = statesql.Open(filepath.Join(conf.Path, atxpath)) if err != nil { panic(err) } diff --git a/tortoise/threshold_test.go b/tortoise/threshold_test.go index 7fd299fa784..bf4f52c367d 100644 --- a/tortoise/threshold_test.go +++ b/tortoise/threshold_test.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestComputeThreshold(t *testing.T) { @@ -165,7 +165,7 @@ func TestReferenceHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, height := range tc.heights { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(tc.epoch) - 1, diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index 93187318247..9700cada28c 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise/opinionhash" "github.com/spacemeshos/go-spacemesh/tortoise/sim" ) @@ -336,7 +337,7 @@ func tortoiseFromSimState(tb testing.TB, state sim.State, opts ...Opt) *recovery return &recoveryAdapter{ TB: tb, Tortoise: trtl, - db: state.DB.Executor, + db: state.DB.Database, atxdata: state.Atxdata, } } @@ -467,7 +468,7 @@ func TestComputeExpectedWeight(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { var ( - db = sql.InMemory() + db = statesql.InMemory() epochs = map[types.EpochID]*epochInfo{} first = tc.target.Add(1).GetEpoch() ) diff --git a/tortoise/tracer_test.go b/tortoise/tracer_test.go index b40cd7b26cd..21e217ee5ad 100644 --- a/tortoise/tracer_test.go +++ b/tortoise/tracer_test.go @@ -46,7 +46,7 @@ func TestTracer(t *testing.T) { path := filepath.Join(t.TempDir(), "tortoise.trace") trt, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, s.GetState(0).Atxdata, last, WithTracer(WithOutput(path)), diff --git a/txs/cache.go b/txs/cache.go index f15c10d3bb6..d3d35a605c7 100644 --- a/txs/cache.go +++ b/txs/cache.go @@ -332,7 +332,7 @@ func (ac *accountCache) add(logger *zap.Logger, tx *types.Transaction, received func (ac *accountCache) addPendingFromNonce( logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, nonce uint64, applied types.LayerID, ) error { @@ -423,7 +423,7 @@ func (ac *accountCache) getMempool(logger *zap.Logger) []*NanoTX { // because applying a layer changes the conservative balance in the cache. func (ac *accountCache) resetAfterApply( logger *zap.Logger, - db *sql.Database, + db sql.StateDatabase, nextNonce, newBalance uint64, applied types.LayerID, ) error { @@ -489,7 +489,7 @@ func groupTXsByPrincipal(logger *zap.Logger, mtxs []*types.MeshTransaction) map[ } // buildFromScratch builds the cache from database. -func (c *Cache) buildFromScratch(db *sql.Database) error { +func (c *Cache) buildFromScratch(db sql.StateDatabase) error { applied, err := layers.GetLastApplied(db) if err != nil { return fmt.Errorf("cache: get pending %w", err) @@ -606,7 +606,7 @@ func acceptable(err error) bool { func (c *Cache) Add( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, tx *types.Transaction, received time.Time, mustPersist bool, @@ -653,7 +653,7 @@ func (c *Cache) has(tid types.TransactionID) bool { // LinkTXsWithProposal associates the transactions to a proposal. func (c *Cache) LinkTXsWithProposal( - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID, @@ -670,7 +670,7 @@ func (c *Cache) LinkTXsWithProposal( // LinkTXsWithBlock associates the transactions to a block. func (c *Cache) LinkTXsWithBlock( - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID, @@ -702,7 +702,7 @@ func (c *Cache) updateLayer(lid types.LayerID, bid types.BlockID, tids []types.T return nil } -func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { +func (c *Cache) applyEmptyLayer(db sql.StateDatabase, lid types.LayerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -721,7 +721,7 @@ func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { // ApplyLayer retires the applied transactions from the cache and updates the balances. func (c *Cache) ApplyLayer( ctx context.Context, - db *sql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, results []types.TransactionWithResult, @@ -749,7 +749,7 @@ func (c *Cache) ApplyLayer( // commit results before reporting them // TODO(dshulyak) save results in vm - if err := db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err := db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, rst := range results { err := transactions.AddResult(dbtx, rst.ID, &rst.TransactionResult) if err != nil { @@ -838,7 +838,7 @@ func (c *Cache) ApplyLayer( return nil } -func (c *Cache) RevertToLayer(db *sql.Database, revertTo types.LayerID) error { +func (c *Cache) RevertToLayer(db sql.StateDatabase, revertTo types.LayerID) error { if err := undoLayers(db, revertTo.Add(1)); err != nil { return err } @@ -879,7 +879,7 @@ func (c *Cache) GetMempool(logger *zap.Logger) map[types.Address][]*NanoTX { } // checkApplyOrder returns an error if layers were not applied in order. -func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID) error { +func checkApplyOrder(logger *zap.Logger, db sql.StateDatabase, toApply types.LayerID) error { lastApplied, err := layers.GetLastApplied(db) if err != nil { logger.Error("failed to get last applied layer", zap.Error(err)) @@ -895,8 +895,8 @@ func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID return nil } -func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToProposal(db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToProposal(dbtx, tid, lid, pid); err != nil { return fmt.Errorf("add2prop %w", err) @@ -906,8 +906,8 @@ func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, ti }) } -func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToBlock(db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToBlock(dbtx, tid, lid, bid); err != nil { return fmt.Errorf("add2block %w", err) @@ -917,8 +917,8 @@ func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []t }) } -func undoLayers(db *sql.Database, from types.LayerID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func undoLayers(db sql.StateDatabase, from types.LayerID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { err := transactions.UndoLayers(dbtx, from) if err != nil { return fmt.Errorf("undo %w", err) diff --git a/txs/cache_test.go b/txs/cache_test.go index 96ab805074d..2658cb1dfc1 100644 --- a/txs/cache_test.go +++ b/txs/cache_test.go @@ -13,12 +13,13 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) type testCache struct { *Cache - db *sql.Database + db sql.StateDatabase } type testAcct struct { @@ -67,7 +68,7 @@ func newMeshTX( func genAndSaveTXs( t *testing.T, - db *sql.Database, + db sql.StateDatabase, signer *signing.EdSigner, from, to uint64, startTime time.Time, @@ -88,14 +89,14 @@ func genTXs(t *testing.T, signer *signing.EdSigner, from, to uint64, startTime t return mtxs } -func saveTXs(t *testing.T, db *sql.Database, mtxs []*types.MeshTransaction) { +func saveTXs(t *testing.T, db sql.StateDatabase, mtxs []*types.MeshTransaction) { t.Helper() for _, mtx := range mtxs { require.NoError(t, transactions.Add(db, &mtx.Transaction, mtx.Received)) } } -func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransaction, state types.TXState) { +func checkTXStateFromDB(t *testing.T, db sql.StateDatabase, txs []*types.MeshTransaction, state types.TXState) { for _, mtx := range txs { got, err := transactions.Get(db, mtx.ID) require.NoError(t, err) @@ -103,7 +104,7 @@ func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransac } } -func checkTXNotInDB(t *testing.T, db *sql.Database, tid types.TransactionID) { +func checkTXNotInDB(t *testing.T, db sql.StateDatabase, tid types.TransactionID) { _, err := transactions.Get(db, tid) require.ErrorIs(t, err, sql.ErrNotFound) } @@ -169,7 +170,7 @@ func createState(tb testing.TB, numAccounts int) map[types.Address]*testAcct { func createCache(tb testing.TB, numAccounts int) (*testCache, map[types.Address]*testAcct) { tb.Helper() accounts := createState(tb, numAccounts) - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(accounts), zaptest.NewLogger(tb)), db: db, @@ -183,7 +184,7 @@ func createSingleAccountTestCache(tb testing.TB) (*testCache, *testAcct) { principal := types.GenerateAddress(signer.PublicKey().Bytes()) ta := &testAcct{signer: signer, principal: principal, nonce: rand.Uint64()%1000 + 1, balance: defaultBalance} states := map[types.Address]*testAcct{principal: ta} - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(states), zaptest.NewLogger(tb)), db: db, diff --git a/txs/conservative_state.go b/txs/conservative_state.go index 88912319b39..def31171bf4 100644 --- a/txs/conservative_state.go +++ b/txs/conservative_state.go @@ -54,12 +54,12 @@ type ConservativeState struct { logger *zap.Logger cfg CSConfig - db *sql.Database + db sql.StateDatabase cache *Cache } // NewConservativeState returns a ConservativeState. -func NewConservativeState(state vmState, db *sql.Database, opts ...ConservativeStateOpt) *ConservativeState { +func NewConservativeState(state vmState, db sql.StateDatabase, opts ...ConservativeStateOpt) *ConservativeState { cs := &ConservativeState{ vmState: state, cfg: defaultCSConfig(), diff --git a/txs/conservative_state_test.go b/txs/conservative_state_test.go index 46bad3a5a28..f9dd30478be 100644 --- a/txs/conservative_state_test.go +++ b/txs/conservative_state_test.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -73,7 +74,7 @@ func newTxWthRecipient( type testConState struct { *ConservativeState logger *zap.Logger - db *sql.Database + db sql.StateDatabase mvm *MockvmState id peer.ID @@ -86,7 +87,7 @@ func (t *testConState) handler() *TxHandler { func createTestState(t *testing.T, gasLimit uint64) *testConState { ctrl := gomock.NewController(t) mvm := NewMockvmState(ctrl) - db := sql.InMemory() + db := statesql.InMemory() cfg := CSConfig{ BlockGasLimit: gasLimit, NumTXsPerProposal: numTXsInProposal,