Skip to content

Commit

Permalink
feat: refactor topo locking and use it for routing rules locking
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 committed May 24, 2024
1 parent e978473 commit 95ac32b
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 176 deletions.
89 changes: 36 additions & 53 deletions go/vt/topo/topo_lock.go → go/vt/topo/locking.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,44 @@ package topo

import (
"context"
"fmt"

"vitess.io/vitess/go/trace"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// ITopoLock is the interface for a lock that can be used to lock a key in the topology server.
// The lock is associated with a context and can be unlocked by calling the returned function.
// Note that we don't need an Unlock method on the interface, as the Lock() function
// returns a function that can be used to unlock the lock.
type ITopoLock interface {
Lock(ctx context.Context) (context.Context, func(*error), error)
}

type TopoLock struct {
Path string // topo path to lock
Name string // name, for logging purposes

ts *Server
}

var _ ITopoLock = (*TopoLock)(nil)

func (ts *Server) NewTopoLock(path, name string) *TopoLock {
return &TopoLock{
ts: ts,
Path: path,
Name: name,
}
}

func (tl *TopoLock) String() string {
return fmt.Sprintf("TopoLock{Path: %v, Name: %v}", tl.Path, tl.Name)
// lockType is the interface for knowing the resource that is being locked.
// It allows for better controlling nuances for different lock types and log messages.
type lockType interface {
Type() string
ResourceName() string
Path() string
}

// perform the topo lock operation
func (l *Lock) lock(ctx context.Context, ts *Server, path string) (LockDescriptor, error) {
func (l *Lock) lock(ctx context.Context, ts *Server, lt lockType, isBlocking bool) (LockDescriptor, error) {
log.Infof("Locking %v %v for action %v", lt.Type(), lt.ResourceName(), l.Action)

ctx, cancel := context.WithTimeout(ctx, LockTimeout)
defer cancel()
span, ctx := trace.NewSpan(ctx, "TopoServer.Lock")
span.Annotate("action", l.Action)
span.Annotate("path", path)
span.Annotate("path", lt.Path())
defer span.Finish()

j, err := l.ToJSON()
if err != nil {
return nil, err
}
return ts.globalCell.Lock(ctx, path, j)
if isBlocking {
return ts.globalCell.Lock(ctx, lt.Path(), j)
}
return ts.globalCell.TryLock(ctx, lt.Path(), j)
}

// unlock unlocks a previously locked key.
func (l *Lock) unlock(ctx context.Context, path string, lockDescriptor LockDescriptor, actionError error) error {
func (l *Lock) unlock(ctx context.Context, lt lockType, lockDescriptor LockDescriptor, actionError error) error {
// Detach from the parent timeout, but copy the trace span.
// We need to still release the lock even if the parent
// context timed out.
Expand All @@ -82,21 +65,21 @@ func (l *Lock) unlock(ctx context.Context, path string, lockDescriptor LockDescr

span, ctx := trace.NewSpan(ctx, "TopoServer.Unlock")
span.Annotate("action", l.Action)
span.Annotate("path", path)
span.Annotate("path", lt.Path())
defer span.Finish()

// first update the actionNode
if actionError != nil {
log.Infof("Unlocking %v %v for action %v with error %v", lt.Type(), lt.ResourceName(), l.Action, actionError)
l.Status = "Error: " + actionError.Error()
} else {
log.Infof("Unlocking %v %v for successful action %v", lt.Type(), lt.ResourceName(), l.Action)
l.Status = "Done"
}
return lockDescriptor.Unlock(ctx)
}

// Lock adds lock information to the context, checks that the lock is not already held, and locks it.
// It returns a new context with the lock information and a function to unlock the lock.
func (tl TopoLock) Lock(ctx context.Context) (context.Context, func(*error), error) {
func (ts *Server) internalLock(ctx context.Context, lt lockType, action string, isBlocking bool) (context.Context, func(*error), error) {
i, ok := ctx.Value(locksKey).(*locksInfo)
if !ok {
i = &locksInfo{
Expand All @@ -107,63 +90,63 @@ func (tl TopoLock) Lock(ctx context.Context) (context.Context, func(*error), err
i.mu.Lock()
defer i.mu.Unlock()
// check that we are not already locked
if _, ok := i.info[tl.Path]; ok {
return nil, nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "lock for %v is already held", tl.Path)
if _, ok := i.info[lt.ResourceName()]; ok {
return nil, nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "lock for %v %v is already held", lt.Type(), lt.ResourceName())
}

// lock it
l := newLock(fmt.Sprintf("lock for %s", tl.Name))
lockDescriptor, err := l.lock(ctx, tl.ts, tl.Path)
l := newLock(action)
lockDescriptor, err := l.lock(ctx, ts, lt, isBlocking)
if err != nil {
return nil, nil, err
}
// and update our structure
i.info[tl.Path] = &lockInfo{
i.info[lt.ResourceName()] = &lockInfo{
lockDescriptor: lockDescriptor,
actionNode: l,
}
return ctx, func(finalErr *error) {
i.mu.Lock()
defer i.mu.Unlock()

if _, ok := i.info[tl.Path]; !ok {
if _, ok := i.info[lt.ResourceName()]; !ok {
if *finalErr != nil {
log.Errorf("trying to unlock %v multiple times", tl.Path)
log.Errorf("trying to unlock %v %v multiple times", lt.Type(), lt.ResourceName())
} else {
*finalErr = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "trying to unlock %v multiple times", tl.Path)
*finalErr = vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "trying to unlock %v %v multiple times", lt.Type(), lt.ResourceName())
}
return
}

err := l.unlock(ctx, tl.Path, lockDescriptor, *finalErr)
err := l.unlock(ctx, lt, lockDescriptor, *finalErr)
// if we have an error, we log it, but we still want to delete the lock
if *finalErr != nil {
if err != nil {
// both error are set, just log the unlock error
log.Errorf("unlock(%v) failed: %v", tl.Path, err)
log.Errorf("unlock %v %v failed: %v", lt.Type(), lt.ResourceName(), err)
}
} else {
*finalErr = err
}
delete(i.info, tl.Path)
delete(i.info, lt.ResourceName())
}, nil
}

func CheckLocked(ctx context.Context, keyPath string) error {
func checkLocked(ctx context.Context, lt lockType) error {
// extract the locksInfo pointer
i, ok := ctx.Value(locksKey).(*locksInfo)
if !ok {
return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%s is not locked (no locksInfo)", keyPath)
return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%v %v is not locked (no locksInfo)", lt.Type(), lt.ResourceName())
}
i.mu.Lock()
defer i.mu.Unlock()

// find the individual entry
_, ok = i.info[keyPath]
li, ok := i.info[lt.ResourceName()]
if !ok {
return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%s is not locked (no lockInfo in map)", keyPath)
return vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "%v %v is not locked (no lockInfo in map)", lt.Type(), lt.ResourceName())
}

// and we're good for now.
return nil
// Check the lock server implementation still holds the lock.
return li.lockDescriptor.Check(ctx)
}
18 changes: 1 addition & 17 deletions go/vt/topo/locks.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,7 @@ func (ts *Server) internalLockShard(ctx context.Context, keyspace, shard, action
l := newLock(action)
var lockDescriptor LockDescriptor
var err error
if isBlocking {
lockDescriptor, err = l.lockShard(ctx, ts, keyspace, shard)
} else {
lockDescriptor, err = l.tryLockShard(ctx, ts, keyspace, shard)
}
lockDescriptor, err = l.internalLockShard(ctx, ts, keyspace, shard, isBlocking)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -401,18 +397,6 @@ func CheckShardLocked(ctx context.Context, keyspace, shard string) error {
return li.lockDescriptor.Check(ctx)
}

// lockShard will lock the shard in the topology server.
// UnlockShard should be called if this returns no error.
func (l *Lock) lockShard(ctx context.Context, ts *Server, keyspace, shard string) (LockDescriptor, error) {
return l.internalLockShard(ctx, ts, keyspace, shard, true)
}

// tryLockShard will lock the shard in the topology server but unlike `lockShard` it fail-fast if not able to get lock
// UnlockShard should be called if this returns no error.
func (l *Lock) tryLockShard(ctx context.Context, ts *Server, keyspace, shard string) (LockDescriptor, error) {
return l.internalLockShard(ctx, ts, keyspace, shard, false)
}

func (l *Lock) internalLockShard(ctx context.Context, ts *Server, keyspace, shard string, isBlocking bool) (LockDescriptor, error) {
log.Infof("Locking shard %v/%v for action %v", keyspace, shard, l.Action)

Expand Down
35 changes: 23 additions & 12 deletions go/vt/topo/routing_rules_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,31 @@ limitations under the License.
package topo

import (
"fmt"
"context"
)

// RoutingRulesLock is a wrapper over TopoLock, to serialize updates to routing rules.
type RoutingRulesLock struct {
*TopoLock
type routingRulesType struct{}

var _ lockType = (*routingRulesType)(nil)

func (s *routingRulesType) Type() string {
return "routing_rules"
}

func (s *routingRulesType) ResourceName() string {
return RoutingRulesPath
}

func (s *routingRulesType) Path() string {
return RoutingRulesPath
}

// LockRoutingRules acquires a lock for routing rules.
func (ts *Server) LockRoutingRules(ctx context.Context, action string) (context.Context, func(*error), error) {
return ts.internalLock(ctx, &routingRulesType{}, action, true)
}

func NewRoutingRulesLock(ts *Server, name string) *RoutingRulesLock {
return &RoutingRulesLock{
TopoLock: &TopoLock{
Path: RoutingRulesPath,
Name: fmt.Sprintf("RoutingRules::%s", name),
ts: ts,
},
}
// CheckRoutingRulesLocked checks if a lock for routing rules is still possessed.
func CheckRoutingRulesLocked(ctx context.Context) error {
return checkLocked(ctx, &routingRulesType{})
}
63 changes: 59 additions & 4 deletions go/vt/topo/routing_rules_lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package topo_test
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand All @@ -28,6 +29,61 @@ import (
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
)

// lower the lock timeout for testing
const testLockTimeout = 3 * time.Second

// TestTopoLockTimeout tests that the lock times out after the specified duration.
func TestTopoLockTimeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := memorytopo.NewServer(ctx, "zone1")
defer ts.Close()

err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{})
require.NoError(t, err)

currentTopoLockTimeout := topo.LockTimeout
topo.LockTimeout = testLockTimeout
defer func() {
topo.LockTimeout = currentTopoLockTimeout
}()

// acquire the lock
origCtx := ctx
_, unlock, err := ts.LockRoutingRules(origCtx, "ks1")
require.NoError(t, err)
defer unlock(&err)

// re-acquiring the lock should fail
_, _, err2 := ts.LockRoutingRules(origCtx, "ks1")
require.Errorf(t, err2, "deadline exceeded")
}

// TestTopoLockBasic tests basic lock operations.
func TestTopoLockBasic(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := memorytopo.NewServer(ctx, "zone1")
defer ts.Close()

err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{})
require.NoError(t, err)

origCtx := ctx
ctx, unlock, err := ts.LockRoutingRules(origCtx, "ks1")
require.NoError(t, err)

// locking the same key again, without unlocking, should return an error
_, _, err2 := ts.LockRoutingRules(ctx, "ks1")
require.ErrorContains(t, err2, "already held")

// confirm that the lock can be re-acquired after unlocking
unlock(&err)
_, unlock, err = ts.LockRoutingRules(origCtx, "ks1")
require.NoError(t, err)
defer unlock(&err)
}

// TestKeyspaceRoutingRulesLock tests that the lock is acquired and released correctly.
func TestKeyspaceRoutingRulesLock(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -44,17 +100,16 @@ func TestKeyspaceRoutingRulesLock(t *testing.T) {
err := ts.CreateKeyspaceRoutingRules(ctx, &vschemapb.KeyspaceRoutingRules{})
require.NoError(t, err)

lock := topo.NewRoutingRulesLock(ts, "ks1")
_, unlock, err := lock.Lock(ctx)
_, unlock, err := ts.LockRoutingRules(ctx, "ks1")
require.NoError(t, err)

// re-acquiring the lock should fail
_, _, err = lock.Lock(ctx)
_, _, err = ts.LockRoutingRules(ctx, "ks1")
require.Error(t, err)

unlock(&err)

// re-acquiring the lock should succeed
_, _, err = lock.Lock(ctx)
_, _, err = ts.LockRoutingRules(ctx, "ks1")
require.NoError(t, err)
}
Loading

0 comments on commit 95ac32b

Please sign in to comment.