From 9402957674f1b61863f8ae488ee1caeedf1f649a Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Sun, 26 May 2024 17:36:16 -0400 Subject: [PATCH] Implement support for metadata associated with read-write transactions This will allow callers of APIs such as WriteRelationships and DeleteRelationships to assign metadata to the transaction that will be mirrored back out in the Watch API, to provide a means for correlating updates --- go.mod | 1 + go.sum | 2 + internal/datastore/common/changes.go | 51 +++++++++-- internal/datastore/common/changes_test.go | 40 +++++++-- internal/datastore/crdb/crdb.go | 22 +++++ ...ion.0008_add_transaction_metadata_table.go | 88 ++++++++++++++++++ internal/datastore/crdb/watch.go | 31 ++++++- internal/datastore/memdb/memdb.go | 12 ++- internal/datastore/mysql/datastore.go | 17 ++-- internal/datastore/mysql/datastore_test.go | 2 +- ....0009_add_metadata_to_transaction_table.go | 18 ++++ internal/datastore/mysql/query_builder.go | 10 +-- internal/datastore/mysql/readwrite.go | 14 +-- internal/datastore/mysql/revisions.go | 18 ++-- internal/datastore/mysql/watch.go | 55 ++++++++++-- ....0019_add_metadata_to_transaction_table.go | 23 +++++ internal/datastore/postgres/postgres.go | 15 ++-- .../postgres/postgres_shared_test.go | 2 +- internal/datastore/postgres/revisions.go | 16 +++- internal/datastore/postgres/watch.go | 21 +++-- ...ion.0010_add_transaction_metadata_table.go | 36 ++++++++ internal/datastore/spanner/schema.go | 4 + internal/datastore/spanner/spanner.go | 39 +++++++- internal/datastore/spanner/watch.go | 30 ++++++- pkg/datastore/datastore.go | 4 + pkg/datastore/options/options.go | 5 +- .../options/zz_generated.query_options.go | 10 +++ pkg/datastore/test/datastore.go | 1 + pkg/datastore/test/watch.go | 90 +++++++++++++++++++ 29 files changed, 606 insertions(+), 71 deletions(-) create mode 100644 internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go create mode 100644 internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go create mode 100644 internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go create mode 100644 internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go diff --git a/go.mod b/go.mod index a442172273..eae1720168 100644 --- a/go.mod +++ b/go.mod @@ -110,6 +110,7 @@ require ( ) require ( + github.com/Masterminds/semver v1.5.0 github.com/Yiling-J/theine-go v0.4.1 github.com/ccoveille/go-safecast v1.1.0 github.com/gosimple/slug v1.14.0 diff --git a/go.sum b/go.sum index d542545727..aa54e412ef 100644 --- a/go.sum +++ b/go.sum @@ -654,6 +654,8 @@ github.com/IBM/pgxpoolprometheus v1.1.1/go.mod h1:GFJDkHbidFfB2APbhBTSy2X4PKH3bL github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= diff --git a/internal/datastore/common/changes.go b/internal/datastore/common/changes.go index 78d7636580..1cc18663fb 100644 --- a/internal/datastore/common/changes.go +++ b/internal/datastore/common/changes.go @@ -5,6 +5,7 @@ import ( "sort" "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" "github.com/ccoveille/go-safecast" @@ -37,6 +38,7 @@ type changeRecord[R datastore.Revision] struct { definitionsChanged map[string]datastore.SchemaDefinition namespacesDeleted map[string]struct{} caveatsDeleted map[string]struct{} + metadata map[string]any } // NewChanges creates a new Changes object for change tracking and de-duplication. @@ -132,6 +134,25 @@ func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { return nil } +// SetRevisionMetadata sets the metadata for the given revision. +func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error { + if len(metadata) == 0 { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + if len(record.metadata) > 0 { + return spiceerrors.MustBugf("metadata already set for revision") + } + + maps.Copy(record.metadata, metadata) + return nil +} + func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { k := ch.keyFunc(rev) revisionChanges, ok := ch.records[k] @@ -143,6 +164,7 @@ func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { make(map[string]datastore.SchemaDefinition), make(map[string]struct{}), make(map[string]struct{}), + make(map[string]any), } ch.records[k] = revisionChanges } @@ -248,21 +270,25 @@ func (ch *Changes[R, K]) AddChangedDefinition( // AsRevisionChanges returns the list of changes processed so far as a datastore watch // compatible, ordered, changelist. -func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) { return ch.revisionChanges(lessThanFunc, *new(R), false) } // FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them // and returning the filtered changes. -func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) []datastore.RevisionChanges { - changes := ch.revisionChanges(lessThanFunc, boundRev, true) +func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) { + changes, err := ch.revisionChanges(lessThanFunc, boundRev, true) + if err != nil { + return nil, err + } + ch.removeAllChangesBefore(boundRev) - return changes + return changes, nil } -func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) []datastore.RevisionChanges { +func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) { if ch.IsEmpty() { - return nil + return nil, nil } revisionsWithChanges := make([]K, 0, len(ch.records)) @@ -273,7 +299,7 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou } if len(revisionsWithChanges) == 0 { - return nil + return nil, nil } sort.Slice(revisionsWithChanges, func(i int, j int) bool { @@ -299,9 +325,18 @@ func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, bou changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged) changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted) changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted) + + if len(revisionChangeRecord.metadata) > 0 { + metadata, err := structpb.NewStruct(revisionChangeRecord.metadata) + if err != nil { + return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err) + } + + changes[i].Metadata = metadata + } } - return changes + return changes, nil } func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) { diff --git a/internal/datastore/common/changes_test.go b/internal/datastore/common/changes_test.go index db496a4d77..7b16ab3e46 100644 --- a/internal/datastore/common/changes_test.go +++ b/internal/datastore/common/changes_test.go @@ -330,9 +330,12 @@ func TestChanges(t *testing.T) { } } + actual, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + require.NoError(err) + require.Equal( canonicalize(tc.expected), - canonicalize(ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc)), + canonicalize(actual), ) }) } @@ -347,6 +350,23 @@ func TestFilteredSchemaChanges(t *testing.T) { require.True(t, ch.IsEmpty()) } +func TestSetMetadata(t *testing.T) { + ctx := context.Background() + ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) + require.True(t, ch.IsEmpty()) + + err := ch.SetRevisionMetadata(ctx, rev1, map[string]any{"foo": "bar"}) + require.NoError(t, err) + require.False(t, ch.IsEmpty()) + + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev2) + require.NoError(t, err) + require.Equal(t, 1, len(results)) + require.True(t, ch.IsEmpty()) + + require.Equal(t, map[string]any{"foo": "bar"}, results[0].Metadata.AsMap()) +} + func TestFilteredRelationshipChanges(t *testing.T) { ctx := context.Background() ch := NewChanges(revisions.TransactionIDKeyFunc, datastore.WatchRelationships, 0) @@ -374,7 +394,8 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { require.False(t, ch.IsEmpty()) - results := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + results, err := ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, rev3) + require.NoError(t, err) require.Equal(t, 2, len(results)) require.False(t, ch.IsEmpty()) @@ -393,8 +414,9 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, results) - remaining := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) require.Equal(t, 1, len(remaining)) + require.NoError(t, err) require.Equal(t, []datastore.RevisionChanges{ { @@ -405,11 +427,13 @@ func TestFilterAndRemoveRevisionChanges(t *testing.T) { }, }, remaining) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillion) + require.NoError(t, err) require.Equal(t, 1, len(results)) require.True(t, ch.IsEmpty()) - results = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + results, err = ch.FilterAndRemoveRevisionChanges(revisions.TransactionIDKeyLessThanFunc, revOneMillionOne) + require.NoError(t, err) require.Equal(t, 0, len(results)) require.True(t, ch.IsEmpty()) } @@ -432,7 +456,8 @@ func TestHLCOrdering(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0, tuple.MustParse("document:foo#viewer@user:tom"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 2, len(remaining)) require.Equal(t, []datastore.RevisionChanges{ @@ -475,7 +500,8 @@ func TestHLCSameRevision(t *testing.T) { err = ch.AddRelationshipChange(ctx, rev0again, tuple.MustParse("document:foo#viewer@user:sarah"), core.RelationTupleUpdate_TOUCH) require.NoError(t, err) - remaining := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + remaining, err := ch.AsRevisionChanges(revisions.HLCKeyLessThanFunc) + require.NoError(t, err) require.Equal(t, 1, len(remaining)) expected := []*core.RelationTupleUpdate{ diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 6d58c8b102..16f67030f1 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -54,6 +54,7 @@ const ( tableTransactions = "transactions" tableCaveat = "caveat" tableRelationshipCounter = "relationship_counter" + tableTransactionMetadata = "transaction_metadata" colNamespace = "namespace" colConfig = "serialized_config" @@ -79,6 +80,8 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAt = "updated_at_timestamp" + colExpiresAt = "expires_at" + colMetadata = "metadata" errUnableToInstantiate = "unable to instantiate datastore" errRevision = "unable to find revision: %w" @@ -207,6 +210,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas analyzeBeforeStatistics: config.analyzeBeforeStatistics, filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, + gcWindow: config.gcWindow, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -289,6 +293,7 @@ type crdbDatastore struct { writeOverlapKeyer overlapKeyer overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool + gcWindow time.Duration beginChangefeedQuery string transactionNowQuery string @@ -332,6 +337,23 @@ func (cds *crdbDatastore) ReadWriteTx( Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), } + // If metadata is to be attached, write that row now. + if config.Metadata != nil { + expiresAt := time.Now().Add(cds.gcWindow).Add(1 * time.Minute) + insertTransactionMetadata := psql.Insert(tableTransactionMetadata). + Columns(colExpiresAt, colMetadata). + Values(expiresAt, config.Metadata.AsMap()) + + sql, args, err := insertTransactionMetadata.ToSql() + if err != nil { + return fmt.Errorf("error building metadata insert: %w", err) + } + + if _, err := tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("error writing metadata: %w", err) + } + } + rwt := &crdbReadWriteTXN{ &crdbReader{ querier, diff --git a/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go new file mode 100644 index 0000000000..9eda5e34fe --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0008_add_transaction_metadata_table.go @@ -0,0 +1,88 @@ +package migrations + +import ( + "context" + "fmt" + "regexp" + + "github.com/Masterminds/semver" + "github.com/jackc/pgx/v5" +) + +const ( + // ttl_expiration_expression support was added in CRDB v22.2, but the E2E tests + // use v21.2. + addTransactionMetadataTableQueryWithBasicTTL = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expire_after = '1d'); + ` + + addTransactionMetadataTableQuery = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily'); + ` + + // See: https://www.cockroachlabs.com/docs/stable/changefeed-messages#prevent-changefeeds-from-emitting-row-level-ttl-deletes + // for why we set ttl_disable_changefeed_replication = 'true'. This isn't stricly necessary as the Watch API will ignore the + // deletions of these metadata rows, but no reason to even have it in the changefeed. + // NOTE: This only applies on CRDB v24 and later. + addTransactionMetadataTableQueryWithTTLIgnore = ` + CREATE TABLE transaction_metadata ( + key UUID PRIMARY KEY DEFAULT gen_random_uuid(), + expires_at TIMESTAMPTZ, + metadata JSONB + ) WITH (ttl_expiration_expression = 'expires_at', ttl_job_cron = '@daily', ttl_disable_changefeed_replication = 'true'); + ` +) + +func init() { + err := CRDBMigrations.Register("add-transaction-metadata-table", "add-integrity-relationtuple-table", addTransactionMetadataTable, noAtomicMigration) + if err != nil { + panic("failed to register migration: " + err.Error()) + } +} + +func addTransactionMetadataTable(ctx context.Context, conn *pgx.Conn) error { + row := conn.QueryRow(ctx, "select version()") + var fullVersionString string + if err := row.Scan(&fullVersionString); err != nil { + return err + } + + re, err := regexp.Compile(semver.SemVerRegex) + if err != nil { + return fmt.Errorf("failed to compile regex: %w", err) + } + + version := re.FindString(fullVersionString) + v, err := semver.NewVersion(version) + if err != nil { + return fmt.Errorf("failed to parse version %q: %w", version, err) + } + + if v.Major() < 22 { + return fmt.Errorf("unsupported version %q", version) + } + + // v22.1 doesn't support `ttl_expiration_expression`; it was added in v22.2. + if v.Major() == 22 && v.Minor() == 1 { + _, err := conn.Exec(ctx, addTransactionMetadataTableQueryWithBasicTTL) + return err + } + + // `ttl_disable_changefeed_replication` was added in v24. + if v.Major() < 24 { + _, err := conn.Exec(ctx, addTransactionMetadataTableQuery) + return err + } + + // v24 and later + _, err = conn.Exec(ctx, addTransactionMetadataTableQueryWithTTLIgnore) + return err +} diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index 2241ae33f7..d03989c6fe 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -55,6 +55,8 @@ type changeDetails struct { IntegrityKeyID *string `json:"integrity_key_id"` IntegrityHashAsHex *string `json:"integrity_hash"` TimestampAsString *string `json:"timestamp"` + + Metadata map[string]any `json:"metadata"` } } @@ -110,7 +112,8 @@ func (cds *crdbDatastore) watch( } defer func() { _ = conn.Close(ctx) }() - tableNames := make([]string, 0, 3) + tableNames := make([]string, 0, 4) + tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { tableNames = append(tableNames, cds.tableTupleName()) } @@ -217,7 +220,13 @@ func (cds *crdbDatastore) watch( return } - for _, revChange := range tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) { + filtered, err := tracked.FilterAndRemoveRevisionChanges(revisions.HLCKeyLessThanFunc, rev) + if err != nil { + sendError(err) + return + } + + for _, revChange := range filtered { revChange := revChange if !sendChange(&revChange) { return @@ -393,6 +402,24 @@ func (cds *crdbDatastore) watch( return } } + + case tableTransactionMetadata: + if details.After != nil { + rev, err := revisions.HLCRevisionFromString(details.Updated) + if err != nil { + sendError(fmt.Errorf("malformed update timestamp: %w", err)) + return + } + + if err := tracked.SetRevisionMetadata(ctx, rev, details.After.Metadata); err != nil { + sendError(err) + return + } + } + + default: + sendError(spiceerrors.MustBugf("unexpected table name in changefeed: %s", tableName)) + return } } diff --git a/internal/datastore/memdb/memdb.go b/internal/datastore/memdb/memdb.go index 04b16a80bf..40354e80d7 100644 --- a/internal/datastore/memdb/memdb.go +++ b/internal/datastore/memdb/memdb.go @@ -204,6 +204,12 @@ func (mdb *memdbDatastore) ReadWriteTx( tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) if tx != nil { + if config.Metadata != nil { + if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil { + return datastore.NoRevision, err + } + } + for _, change := range tx.Changes() { switch change.Table { case tableRelationship: @@ -270,7 +276,11 @@ func (mdb *memdbDatastore) ReadWriteTx( } var rc datastore.RevisionChanges - changes := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return datastore.NoRevision, err + } + if len(changes) > 1 { return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes") } else if len(changes) == 1 { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 40e15ef2de..3309c1b468 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -37,6 +37,7 @@ const ( colID = "id" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedTxn = "created_transaction" @@ -207,10 +208,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option driver := migrations.NewMySQLDriverFromDB(db, config.tablePrefix) queryBuilder := NewQueryBuilder(driver) - createTxn, _, err := sb.Insert(driver.RelationTupleTransaction()).Values().ToSql() - if err != nil { - return nil, fmt.Errorf("NewMySQLDatastore: %w", err) - } + createTxn := sb.Insert(driver.RelationTupleTransaction()).Columns(colMetadata) // used for seeding the initial relation_tuple_transaction. using INSERT IGNORE on a known // ID value makes this idempotent (i.e. safe to execute concurrently). @@ -339,7 +337,12 @@ func (mds *Datastore) ReadWriteTx( for i := uint8(0); i <= mds.maxRetries; i++ { var newTxnID uint64 if err = migrations.BeginTxFunc(ctx, mds.db, &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { - newTxnID, err = mds.createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newTxnID, err = mds.createNewTransaction(ctx, tx, metadata) if err != nil { return fmt.Errorf("unable to create new txn ID: %w", err) } @@ -442,7 +445,7 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err := rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -497,7 +500,7 @@ type Datastore struct { cancelGc context.CancelFunc gcHasRun atomic.Bool - createTxn string + createTxn sq.InsertBuilder createBaseTxn string *QueryBuilder diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 23390ba1eb..c9eeeec066 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -640,7 +640,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { // Transaction timestamp should not be stored in system time zone tx, err := db.BeginTx(ctx, nil) req.NoError(err) - txID, err := ds.(*Datastore).createNewTransaction(ctx, tx) + txID, err := ds.(*Datastore).createNewTransaction(ctx, tx, nil) req.NoError(err) err = tx.Commit() req.NoError(err) diff --git a/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..64e66a83e8 --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0009_add_metadata_to_transaction_table.go @@ -0,0 +1,18 @@ +package migrations + +import "fmt" + +func addMetadataToTransactionTable(t *tables) string { + return fmt.Sprintf(`ALTER TABLE %s + ADD COLUMN metadata BLOB NULL DEFAULT NULL;`, + t.RelationTupleTransaction(), + ) +} + +func init() { + mustRegisterMigration("add_metadata_to_transaction_table", "add_relationship_counters_table", noNonatomicMigration, + newStatementBatch( + addMetadataToTransactionTable, + ).execute, + ) +} diff --git a/internal/datastore/mysql/query_builder.go b/internal/datastore/mysql/query_builder.go index 3e1b700811..2356dc605a 100644 --- a/internal/datastore/mysql/query_builder.go +++ b/internal/datastore/mysql/query_builder.go @@ -9,8 +9,8 @@ import ( // QueryBuilder captures all parameterizable queries used // by the MySQL datastore implementation type QueryBuilder struct { - GetLastRevision sq.SelectBuilder - GetRevisionRange sq.SelectBuilder + GetLastRevision sq.SelectBuilder + LoadRevisionRange sq.SelectBuilder WriteNamespaceQuery sq.InsertBuilder ReadNamespaceQuery sq.SelectBuilder @@ -43,7 +43,7 @@ func NewQueryBuilder(driver *migrations.MySQLDriver) *QueryBuilder { // transaction builders builder.GetLastRevision = getLastRevision(driver.RelationTupleTransaction()) - builder.GetRevisionRange = getRevisionRange(driver.RelationTupleTransaction()) + builder.LoadRevisionRange = loadRevisionRange(driver.RelationTupleTransaction()) // namespace builders builder.WriteNamespaceQuery = writeNamespace(driver.Namespace()) @@ -99,8 +99,8 @@ func getLastRevision(tableTransaction string) sq.SelectBuilder { return sb.Select("MAX(id)").From(tableTransaction).Limit(1) } -func getRevisionRange(tableTransaction string) sq.SelectBuilder { - return sb.Select("MIN(id)", "MAX(id)").From(tableTransaction) +func loadRevisionRange(tableTransaction string) sq.SelectBuilder { + return sb.Select(colID, colMetadata).From(tableTransaction) } func readCounter(tableRelationshipCounters string) sq.SelectBuilder { diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index 7a283d8d5a..6f536a7754 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -50,10 +50,10 @@ type mysqlReadWriteTXN struct { newTxnID uint64 } -// caveatContextWrapper is used to marshall maps into MySQLs JSON data type -type caveatContextWrapper map[string]any +// structpbWrapper is used to marshall maps into MySQLs JSON data type +type structpbWrapper map[string]any -func (cc *caveatContextWrapper) Scan(val any) error { +func (cc *structpbWrapper) Scan(val any) error { v, ok := val.([]byte) if !ok { return fmt.Errorf("unsupported type: %T", v) @@ -61,7 +61,7 @@ func (cc *caveatContextWrapper) Scan(val any) error { return json.Unmarshal(v, &cc) } -func (cc *caveatContextWrapper) Value() (driver.Value, error) { +func (cc *structpbWrapper) Value() (driver.Value, error) { return json.Marshal(&cc) } @@ -221,7 +221,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper tupleIdsToDelete := make([]int64, 0, len(clauses)) for rows.Next() { @@ -282,7 +282,7 @@ func (rwt *mysqlReadWriteTXN) WriteRelationships(ctx context.Context, mutations tpl := mut.Tuple var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() @@ -503,7 +503,7 @@ func (rwt *mysqlReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkW } var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper if tpl.Caveat != nil { caveatName = tpl.Caveat.CaveatName caveatContext = tpl.Caveat.Context.AsMap() diff --git a/internal/datastore/mysql/revisions.go b/internal/datastore/mysql/revisions.go index 1128622322..1302f8883a 100644 --- a/internal/datastore/mysql/revisions.go +++ b/internal/datastore/mysql/revisions.go @@ -77,8 +77,6 @@ func (mds *Datastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revi } func (mds *Datastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { - // implementation deviates slightly from PSQL implementation in order to support - // database seeding in runtime, instead of through migrate command revision, err := mds.loadRevision(ctx) if err != nil { return datastore.NoRevision, err @@ -159,16 +157,26 @@ func (mds *Datastore) checkValidTransaction(ctx context.Context, revisionTx uint return freshEnough.Bool, unknown.Bool, nil } -func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx) (newTxnID uint64, err error) { +func (mds *Datastore) createNewTransaction(ctx context.Context, tx *sql.Tx, metadata map[string]any) (newTxnID uint64, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - createQuery := mds.createTxn + var wrappedMetadata structpbWrapper + if len(metadata) > 0 { + wrappedMetadata = metadata + } + + createQuery := mds.createTxn.Values(&wrappedMetadata) + if err != nil { + return 0, fmt.Errorf("createNewTransaction: %w", err) + } + + sql, args, err := createQuery.ToSql() if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } - result, err := tx.ExecContext(ctx, createQuery) + result, err := tx.ExecContext(ctx, sql, args...) if err != nil { return 0, fmt.Errorf("createNewTransaction: %w", err) } diff --git a/internal/datastore/mysql/watch.go b/internal/datastore/mysql/watch.go index f3c2dffc84..14e356796f 100644 --- a/internal/datastore/mysql/watch.go +++ b/internal/datastore/mysql/watch.go @@ -125,7 +125,52 @@ func (mds *Datastore) loadChanges( return } - sql, args, err := mds.QueryChangedQuery.Where(sq.Or{ + stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + + // Load any metadata for the revision range. + sql, args, err := mds.LoadRevisionRange.Where(sq.Or{ + sq.And{ + sq.Gt{colID: afterRevision}, + sq.LtOrEq{colID: newRevision}, + }, + }).ToSql() + if err != nil { + return + } + + rows, err := mds.db.QueryContext(ctx, sql, args...) + if err != nil { + if errors.Is(err, context.Canceled) { + err = datastore.NewWatchCanceledErr() + } + return + } + defer common.LogOnError(ctx, rows.Close) + + for rows.Next() { + var txnID uint64 + var metadata structpbWrapper + err = rows.Scan( + &txnID, + &metadata, + ) + if err != nil { + return nil, 0, err + } + + if len(metadata) > 0 { + if err := stagedChanges.SetRevisionMetadata(ctx, revisions.NewForTransactionID(txnID), metadata); err != nil { + return nil, 0, err + } + } + } + rows.Close() + if err = rows.Err(); err != nil { + return + } + + // Load the changes relationships for the revision range. + sql, args, err = mds.QueryChangedQuery.Where(sq.Or{ sq.And{ sq.Gt{colCreatedTxn: afterRevision}, sq.LtOrEq{colCreatedTxn: newRevision}, @@ -139,7 +184,7 @@ func (mds *Datastore) loadChanges( return } - rows, err := mds.db.QueryContext(ctx, sql, args...) + rows, err = mds.db.QueryContext(ctx, sql, args...) if err != nil { if errors.Is(err, context.Canceled) { err = datastore.NewWatchCanceledErr() @@ -148,8 +193,6 @@ func (mds *Datastore) loadChanges( } defer common.LogOnError(ctx, rows.Close) - stagedChanges := common.NewChanges(revisions.TransactionIDKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) - for rows.Next() { nextTuple := &core.RelationTuple{ ResourceAndRelation: &core.ObjectAndRelation{}, @@ -159,7 +202,7 @@ func (mds *Datastore) loadChanges( var createdTxn uint64 var deletedTxn uint64 var caveatName string - var caveatContext caveatContextWrapper + var caveatContext structpbWrapper err = rows.Scan( &nextTuple.ResourceAndRelation.Namespace, &nextTuple.ResourceAndRelation.ObjectId, @@ -196,6 +239,6 @@ func (mds *Datastore) loadChanges( return } - changes = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) + changes, err = stagedChanges.AsRevisionChanges(revisions.TransactionIDKeyLessThanFunc) return } diff --git a/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go new file mode 100644 index 0000000000..162ba71f18 --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0019_add_metadata_to_transaction_table.go @@ -0,0 +1,23 @@ +package migrations + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +const addMetadataToTransactionTable = `ALTER TABLE relation_tuple_transaction ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}'` + +func init() { + if err := DatabaseMigrations.Register("add-metadata-to-transaction-table", "create-relationships-counters-table", + func(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, addMetadataToTransactionTable); err != nil { + return fmt.Errorf("failed to add metadata to transaction table: %w", err) + } + return nil + }, + noTxMigration); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 00e0cefb4e..81648075fc 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -49,6 +49,7 @@ const ( colXID = "xid" colTimestamp = "timestamp" + colMetadata = "metadata" colNamespace = "namespace" colConfig = "serialized_config" colCreatedXid = "created_xid" @@ -102,12 +103,7 @@ var ( OrderByClause(fmt.Sprintf("%s DESC", colXID)). Limit(1) - createTxn = fmt.Sprintf( - "INSERT INTO %s DEFAULT VALUES RETURNING %s, %s", - tableTransaction, - colXID, - colSnapshot, - ) + createTxn = psql.Insert(tableTransaction).Columns(colMetadata) getNow = psql.Select("NOW()") @@ -441,7 +437,12 @@ func (pgd *pgDatastore) ReadWriteTx( var newSnapshot pgSnapshot err = wrapError(pgx.BeginTxFunc(ctx, pgd.writePool, pgx.TxOptions{IsoLevel: pgx.Serializable}, func(tx pgx.Tx) error { var err error - newXID, newSnapshot, err = createNewTransaction(ctx, tx) + var metadata map[string]any + if config.Metadata != nil { + metadata = config.Metadata.AsMap() + } + + newXID, newSnapshot, err = createNewTransaction(ctx, tx, metadata) if err != nil { return err } diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index d89b379506..6bd77b8271 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -440,7 +440,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { tx, err := pgd.writePool.Begin(ctx) require.NoError(err) - txXID, _, err := createNewTransaction(ctx, tx) + txXID, _, err := createNewTransaction(ctx, tx, nil) require.NoError(err) err = tx.Commit(ctx) diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index e35a268d1a..e811971afb 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -250,11 +250,22 @@ func parseRevisionDecimal(revisionStr string) (datastore.Revision, error) { }}, nil } -func createNewTransaction(ctx context.Context, tx pgx.Tx) (newXID xid8, newSnapshot pgSnapshot, err error) { +var emptyMetadata = map[string]any{} + +func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata map[string]any) (newXID xid8, newSnapshot pgSnapshot, err error) { ctx, span := tracer.Start(ctx, "createNewTransaction") defer span.End() - cterr := tx.QueryRow(ctx, createTxn).Scan(&newXID, &newSnapshot) + if metadata == nil { + metadata = emptyMetadata + } + + sql, args, err := createTxn.Values(metadata).Suffix("RETURNING " + colXID + ", " + colSnapshot).ToSql() + if err != nil { + return + } + + cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot) if cterr != nil { err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } @@ -265,6 +276,7 @@ type postgresRevision struct { snapshot pgSnapshot optionalTxID xid8 optionalNanosTimestamp uint64 + optionalMetadata map[string]any } func (pr postgresRevision) Equal(rhsRaw datastore.Revision) bool { diff --git a/internal/datastore/postgres/watch.go b/internal/datastore/postgres/watch.go index ab1a353c2c..d28b0c0c9d 100644 --- a/internal/datastore/postgres/watch.go +++ b/internal/datastore/postgres/watch.go @@ -27,10 +27,10 @@ var ( // xid8 is one of the last ~2 billion transaction IDs generated. We should be garbage // collecting these transactions long before we get to that point. newRevisionsQuery = fmt.Sprintf(` - SELECT %[1]s, %[2]s, %[3]s FROM %[4]s + SELECT %[1]s, %[2]s, %[3]s, %[4]s FROM %[5]s WHERE %[1]s >= pg_snapshot_xmax($1) OR ( %[1]s >= pg_snapshot_xmin($1) AND NOT pg_visible_in_snapshot(%[1]s, $1) - ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colTimestamp, tableTransaction) + ) ORDER BY pg_xact_commit_timestamp(%[1]s::xid), %[1]s;`, colXID, colSnapshot, colMetadata, colTimestamp, tableTransaction) queryChangedTuples = psql.Select( colNamespace, @@ -201,8 +201,9 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev for rows.Next() { var nextXID xid8 var nextSnapshot pgSnapshot + var metadata map[string]any var timestamp time.Time - if err := rows.Scan(&nextXID, &nextSnapshot, ×tamp); err != nil { + if err := rows.Scan(&nextXID, &nextSnapshot, &metadata, ×tamp); err != nil { return fmt.Errorf("unable to decode new revision: %w", err) } @@ -215,6 +216,7 @@ func (pgd *pgDatastore) getNewRevisions(ctx context.Context, afterTX postgresRev snapshot: nextSnapshot.markComplete(nextXID.Uint64), optionalTxID: nextXID, optionalNanosTimestamp: nanosTimestamp, + optionalMetadata: metadata, }) } if rows.Err() != nil { @@ -234,6 +236,8 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev filter := make(map[uint64]int, len(revisions)) txidToRevision := make(map[uint64]postgresRevision, len(revisions)) + tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + for i, rev := range revisions { if rev.optionalTxID.Uint64 < xmin { xmin = rev.optionalTxID.Uint64 @@ -243,9 +247,13 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } filter[rev.optionalTxID.Uint64] = i txidToRevision[rev.optionalTxID.Uint64] = rev - } - tracked := common.NewChanges(revisionKeyFunc, options.Content, options.MaximumBufferedChangesByteSize) + if len(rev.optionalMetadata) > 0 { + if err := tracked.SetRevisionMetadata(ctx, rev, rev.optionalMetadata); err != nil { + return nil, err + } + } + } // Load relationship changes. if options.Content&datastore.WatchRelationships == datastore.WatchRelationships { @@ -272,10 +280,9 @@ func (pgd *pgDatastore) loadChanges(ctx context.Context, revisions []postgresRev } // Reconcile the changes. - reconciledChanges := tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { + return tracked.AsRevisionChanges(func(lhs, rhs uint64) bool { return filter[lhs] < filter[rhs] }) - return reconciledChanges, nil } func (pgd *pgDatastore) loadRelationshipChanges(ctx context.Context, xmin uint64, xmax uint64, txidToRevision map[uint64]postgresRevision, filter map[uint64]int, tracked *common.Changes[postgresRevision, uint64]) error { diff --git a/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go new file mode 100644 index 0000000000..e23cfaf412 --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0010_add_transaction_metadata_table.go @@ -0,0 +1,36 @@ +package migrations + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +const ( + // NOTE: We use 2 days here because Spanner only supports deletion policies at intervals of days. + // See: https://cloud.google.com/spanner/docs/ttl/working-with-ttl#create + addTransactionMetadataTable = `CREATE TABLE transaction_metadata ( + transaction_tag STRING(36) NOT NULL, + created_at TIMESTAMP DEFAULT (CURRENT_TIMESTAMP()), + metadata JSON + ) PRIMARY KEY (transaction_tag), + ROW DELETION POLICY (OLDER_THAN(created_at, INTERVAL 2 DAY)) + ` +) + +func init() { + if err := SpannerMigrations.Register("add-transaction-metadata-table", "add-relationship-counter-table", func(ctx context.Context, w Wrapper) error { + updateOp, err := w.adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ + Database: w.client.DatabaseName(), + Statements: []string{ + addTransactionMetadataTable, + }, + }) + if err != nil { + return err + } + return updateOp.Wait(ctx) + }, nil); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/schema.go b/internal/datastore/spanner/schema.go index 364fc549fc..3cde345a2a 100644 --- a/internal/datastore/spanner/schema.go +++ b/internal/datastore/spanner/schema.go @@ -30,6 +30,10 @@ const ( colCounterSerializedFilter = "serialized_filter" colCounterCurrentCount = "current_count" colCounterUpdatedAtTimestamp = "updated_at_timestamp" + + tableTransactionMetadata = "transaction_metadata" + colTransactionTag = "transaction_tag" + colMetadata = "metadata" ) var allRelationshipCols = []string{ diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index 19cc8537b4..76ec854b2a 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -12,6 +12,7 @@ import ( "cloud.google.com/go/spanner" ocprom "contrib.go.opencensus.io/exporter/prometheus" sq "github.com/Masterminds/squirrel" + "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opencensus.io/plugin/ocgrpc" "go.opencensus.io/stats/view" @@ -239,18 +240,50 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas return spannerReader{executor, txSource, sd.filterMaximumIDCount} } +func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transactionTag string) (map[string]any, error) { + row, err := sd.client.Single().ReadRow(ctx, tableTransactionMetadata, spanner.Key{transactionTag}, []string{colMetadata}) + if err != nil { + if spanner.ErrCode(err) == codes.NotFound { + return map[string]any{}, nil + } + + return nil, err + } + + var metadata map[string]any + if err := row.Columns(&metadata); err != nil { + return nil, err + } + + return metadata, nil +} + func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { config := options.NewRWTOptionsWithOptions(opts...) ctx, span := tracer.Start(ctx, "ReadWriteTx") defer span.End() + transactionTag := "sdb-rwt-" + uuid.NewString() + ctx, cancel := context.WithCancel(ctx) - ts, err := sd.client.ReadWriteTransaction(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { + rs, err := sd.client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, spannerRWT *spanner.ReadWriteTransaction) error { txSource := func() readTX { return &traceableRTX{delegate: spannerRWT} } + if config.Metadata != nil { + // Insert the metadata into the transaction metadata table. + mutation := spanner.Insert(tableTransactionMetadata, + []string{colTransactionTag, colMetadata}, + []any{transactionTag, config.Metadata.AsMap()}, + ) + + if err := spannerRWT.BufferWrite([]*spanner.Mutation{mutation}); err != nil { + return fmt.Errorf("unable to write metadata: %w", err) + } + } + executor := common.QueryExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ spannerReader{executor, txSource, sd.filterMaximumIDCount}, @@ -270,7 +303,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } return nil - }) + }, spanner.TransactionOptions{TransactionTag: transactionTag}) if err != nil { if cerr := convertToWriteConstraintError(err); cerr != nil { return datastore.NoRevision, cerr @@ -278,7 +311,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser return datastore.NoRevision, err } - return revisions.NewForTime(ts), nil + return revisions.NewForTime(rs.CommitTs), nil } func (sd *spannerDatastore) ReadyState(ctx context.Context) (datastore.ReadyState, error) { diff --git a/internal/datastore/spanner/watch.go b/internal/datastore/spanner/watch.go index edd4eab37e..5a533e4b8a 100644 --- a/internal/datastore/spanner/watch.go +++ b/internal/datastore/spanner/watch.go @@ -156,6 +156,23 @@ func (sd *spannerDatastore) watch( } defer reader.Close() + metadataForTransactionTag := map[string]map[string]any{} + + addMetadataForTransactionTag := func(ctx context.Context, tracked *common.Changes[revisions.TimestampRevision, int64], revision revisions.TimestampRevision, transactionTag string) error { + if metadata, ok := metadataForTransactionTag[transactionTag]; ok { + return tracked.SetRevisionMetadata(ctx, revision, metadata) + } + + // Otherwise, load the metadata from the transactions metadata table. + transactionMetadata, err := sd.readTransactionMetadata(ctx, transactionTag) + if err != nil { + return err + } + + metadataForTransactionTag[transactionTag] = transactionMetadata + return tracked.SetRevisionMetadata(ctx, revision, transactionMetadata) + } + err = reader.Read(ctx, func(result *changestreams.ReadResult) error { // See: https://cloud.google.com/spanner/docs/change-streams/details for _, record := range result.ChangeRecords { @@ -165,6 +182,12 @@ func (sd *spannerDatastore) watch( changeRevision := revisions.NewForTime(dcr.CommitTimestamp) modType := dcr.ModType // options are INSERT, UPDATE, DELETE + if len(dcr.TransactionTag) > 0 { + if err := addMetadataForTransactionTag(ctx, tracked, changeRevision, dcr.TransactionTag); err != nil { + return err + } + } + for _, mod := range dcr.Mods { primaryKeyColumnValues, ok := mod.Keys.Value.(map[string]any) if !ok { @@ -312,7 +335,12 @@ func (sd *spannerDatastore) watch( } if !tracked.IsEmpty() { - for _, revChange := range tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) { + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return err + } + + for _, revChange := range changes { revChange := revChange if !sendChange(&revChange) { return datastore.NewWatchDisconnectedErr() diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index e372a94944..dbca36ee68 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/pkg/tuple" @@ -61,6 +62,9 @@ type RevisionChanges struct { // up until and including the Revision and that no additional schema updates can // have occurred before this point. IsCheckpoint bool + + // Metadata is the metadata associated with the revision, if any. + Metadata *structpb.Struct } func (rc *RevisionChanges) MarshalZerologObject(e *zerolog.Event) { diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 34daf2e58a..1d13f2ca11 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -1,6 +1,8 @@ package options import ( + "google.golang.org/protobuf/types/known/structpb" + core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -51,7 +53,8 @@ type ResourceRelation struct { // RWTOptions are options that can affect the way a read-write transaction is // executed. type RWTOptions struct { - DisableRetries bool `debugmap:"visible"` + DisableRetries bool `debugmap:"visible"` + Metadata *structpb.Struct `debugmap:"visible"` } // DeleteOptions are the options that can affect the results of a delete relationships diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index f87f310b7c..f761b06b66 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -4,6 +4,7 @@ package options import ( defaults "github.com/creasty/defaults" helpers "github.com/ecordell/optgen/helpers" + structpb "google.golang.org/protobuf/types/known/structpb" ) type QueryOptionsOption func(q *QueryOptions) @@ -192,6 +193,7 @@ func NewRWTOptionsWithOptionsAndDefaults(opts ...RWTOptionsOption) *RWTOptions { func (r *RWTOptions) ToOption() RWTOptionsOption { return func(to *RWTOptions) { to.DisableRetries = r.DisableRetries + to.Metadata = r.Metadata } } @@ -199,6 +201,7 @@ func (r *RWTOptions) ToOption() RWTOptionsOption { func (r RWTOptions) DebugMap() map[string]any { debugMap := map[string]any{} debugMap["DisableRetries"] = helpers.DebugValue(r.DisableRetries, false) + debugMap["Metadata"] = helpers.DebugValue(r.Metadata, false) return debugMap } @@ -224,3 +227,10 @@ func WithDisableRetries(disableRetries bool) RWTOptionsOption { r.DisableRetries = disableRetries } } + +// WithMetadata returns an option that can set Metadata on a RWTOptions +func WithMetadata(metadata *structpb.Struct) RWTOptionsOption { + return func(r *RWTOptions) { + r.Metadata = metadata + } +} diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 05f034c6ba..cc8a0999f5 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -162,6 +162,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories) t.Run("TestCaveatedRelationshipWatch", func(t *testing.T) { CaveatedRelationshipWatchTest(t, tester) }) t.Run("TestWatchWithTouch", func(t *testing.T) { WatchWithTouchTest(t, tester) }) t.Run("TestWatchWithDelete", func(t *testing.T) { WatchWithDeleteTest(t, tester) }) + t.Run("TestWatchWithMetadata", func(t *testing.T) { WatchWithMetadataTest(t, tester) }) } if !except.Watch() && !except.WatchSchema() { diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index 0191796bbc..a91ec8eefc 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -13,9 +13,11 @@ import ( "github.com/scylladb/go-set/strset" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -176,6 +178,50 @@ func VerifyUpdates( require.False(expectDisconnect, "all changes verified without expected disconnect") } +func VerifyUpdatesWithMetadata( + require *require.Assertions, + testUpdates []updateWithMetadata, + changes <-chan *datastore.RevisionChanges, + errchan <-chan error, + expectDisconnect bool, +) { + for _, expected := range testUpdates { + changeWait := time.NewTimer(waitForChangesTimeout) + select { + case change, ok := <-changes: + if !ok { + require.True(expectDisconnect, "unexpected disconnect") + errWait := time.NewTimer(waitForChangesTimeout) + select { + case err := <-errchan: + require.True(errors.As(err, &datastore.ErrWatchDisconnected{})) + return + case <-errWait.C: + require.Fail("Timed out waiting for ErrWatchDisconnected") + } + return + } + + expectedChangeSet := setOfChanges(expected.updates) + actualChangeSet := setOfChanges(change.RelationshipChanges) + + missingExpected := strset.Difference(expectedChangeSet, actualChangeSet) + unexpected := strset.Difference(actualChangeSet, expectedChangeSet) + + require.True(missingExpected.IsEmpty(), "expected changes missing: %s", missingExpected) + require.True(unexpected.IsEmpty(), "unexpected changes: %s", unexpected) + + require.Equal(expected.metadata, change.Metadata.AsMap(), "metadata mismatch") + + time.Sleep(1 * time.Millisecond) + case <-changeWait.C: + require.Fail("Timed out", "waited for changes: %s", expected) + } + } + + require.False(expectDisconnect, "all changes verified without expected disconnect") +} + func setOfChanges(changes []*core.RelationTupleUpdate) *strset.Set { changeSet := strset.NewWithSize(len(changes)) for _, change := range changes { @@ -337,6 +383,50 @@ func WatchWithTouchTest(t *testing.T, tester DatastoreTester) { ) } +type updateWithMetadata struct { + updates []*core.RelationTupleUpdate + metadata map[string]any +} + +func WatchWithMetadataTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 16) + require.NoError(err) + + setupDatastore(ds, require) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lowestRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustRelationships()) + require.Zero(len(errchan)) + + metadata, err := structpb.NewStruct(map[string]any{"somekey": "somevalue"}) + require.NoError(err) + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []*core.RelationTupleUpdate{ + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:tom")), + }) + }, options.WithMetadata(metadata)) + require.NoError(err) + + VerifyUpdatesWithMetadata(require, []updateWithMetadata{ + { + updates: []*core.RelationTupleUpdate{tuple.Touch(tuple.Parse("document:firstdoc#viewer@user:tom"))}, + metadata: map[string]any{"somekey": "somevalue"}, + }, + }, + changes, + errchan, + false, + ) +} + func WatchWithDeleteTest(t *testing.T, tester DatastoreTester) { require := require.New(t)