Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new consistency middleware for full-consistency-only callers #2109

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions internal/middleware/consistency/consistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (
"github.com/authzed/spicedb/pkg/zedtoken"
)

var ConsistentyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
var ConsistencyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "spicedb",
Subsystem: "middleware",
Name: "consistency_assigned_total",
Help: "Count of the consistencies used per request",
}, []string{"method", "source"})
}, []string{"method", "source", "service"})

type hasConsistency interface{ GetConsistency() *v1.Consistency }

Expand Down Expand Up @@ -64,18 +64,18 @@ func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken,

// AddRevisionToContext adds a revision to the given context, based on the consistency block found
// in the given request (if applicable).
func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore) error {
func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error {
switch req := req.(type) {
case hasConsistency:
return addRevisionToContextFromConsistency(ctx, req, ds)
return addRevisionToContextFromConsistency(ctx, req, ds, serviceLabel)
default:
return nil
}
}

// addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found
// in the given request (if applicable).
func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore) error {
func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore, serviceLabel string) error {
handle := ctx.Value(revisionKey)
if handle == nil {
return nil
Expand All @@ -89,7 +89,9 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency
switch {
case hasOptionalCursor && withOptionalCursor.GetOptionalCursor() != nil:
// Always use the revision encoded in the cursor.
ConsistentyCounter.WithLabelValues("snapshot", "cursor").Inc()
if serviceLabel != "" {
ConsistencyCounter.WithLabelValues("snapshot", "cursor", serviceLabel).Inc()
}

requestedRev, err := cursor.DecodeToDispatchRevision(withOptionalCursor.GetOptionalCursor(), ds)
if err != nil {
Expand All @@ -109,7 +111,10 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency
if consistency == nil {
source = "server"
}
ConsistentyCounter.WithLabelValues("minlatency", source).Inc()

if serviceLabel != "" {
ConsistencyCounter.WithLabelValues("minlatency", source, serviceLabel).Inc()
}

databaseRev, err := ds.OptimizedRevision(ctx)
if err != nil {
Expand All @@ -119,7 +124,9 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency

case consistency.GetFullyConsistent():
// Fully Consistent: Use the datastore's synchronized revision.
ConsistentyCounter.WithLabelValues("full", "request").Inc()
if serviceLabel != "" {
ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc()
}

databaseRev, err := ds.HeadRevision(ctx)
if err != nil {
Expand All @@ -139,13 +146,15 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency
if pickedRequest {
source = "request"
}
ConsistentyCounter.WithLabelValues("atleast", source).Inc()
ConsistencyCounter.WithLabelValues("atleast", source, serviceLabel).Inc()

revision = picked

case consistency.GetAtExactSnapshot() != nil:
// Exact snapshot: Use the revision as encoded in the zed token.
ConsistentyCounter.WithLabelValues("snapshot", "request").Inc()
if serviceLabel != "" {
ConsistencyCounter.WithLabelValues("snapshot", "request", serviceLabel).Inc()
}

requestedRev, err := zedtoken.DecodeRevision(consistency.GetAtExactSnapshot(), ds)
if err != nil {
Expand Down Expand Up @@ -175,7 +184,7 @@ var bypassServiceWhitelist = map[string]struct{}{

// UnaryServerInterceptor returns a new unary server interceptor that performs per-request exchange of
// the specified consistency configuration for the revision at which to perform the request.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
func UnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
Expand All @@ -184,7 +193,7 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
}
ds := datastoremw.MustFromContext(ctx)
newCtx := ContextWithHandle(ctx)
if err := AddRevisionToContext(newCtx, req, ds); err != nil {
if err := AddRevisionToContext(newCtx, req, ds, serviceLabel); err != nil {
return nil, err
}

Expand All @@ -194,21 +203,23 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor {

// StreamServerInterceptor returns a new stream server interceptor that performs per-request exchange of
// the specified consistency configuration for the revision at which to perform the request.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
func StreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
return handler(srv, stream)
}
}
wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context())}
wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, AddRevisionToContext}
return handler(srv, wrapper)
}
}

type recvWrapper struct {
grpc.ServerStream
ctx context.Context
ctx context.Context
serviceLabel string
handler func(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error
}

func (s *recvWrapper) Context() context.Context { return s.ctx }
Expand All @@ -218,8 +229,7 @@ func (s *recvWrapper) RecvMsg(m interface{}) error {
return err
}
ds := datastoremw.MustFromContext(s.ctx)

return AddRevisionToContext(s.ctx, m, ds)
return s.handler(s.ctx, m, ds, s.serviceLabel)
}

// pickBestRevision compares the provided ZedToken with the optimized revision of the datastore, and returns the most
Expand Down
14 changes: 7 additions & 7 deletions internal/middleware/consistency/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestAddRevisionToContextNoneSupplied(t *testing.T) {
ds.On("OptimizedRevision").Return(optimized, nil).Once()

updated := ContextWithHandle(context.Background())
err := AddRevisionToContext(updated, &v1.ReadRelationshipsRequest{}, ds)
err := AddRevisionToContext(updated, &v1.ReadRelationshipsRequest{}, ds, "somelabel")
require.NoError(err)

rev, _, err := RevisionFromContext(updated)
Expand All @@ -52,7 +52,7 @@ func TestAddRevisionToContextMinimizeLatency(t *testing.T) {
MinimizeLatency: true,
},
},
}, ds)
}, ds, "somelabel")
require.NoError(err)

rev, _, err := RevisionFromContext(updated)
Expand All @@ -75,7 +75,7 @@ func TestAddRevisionToContextFullyConsistent(t *testing.T) {
FullyConsistent: true,
},
},
}, ds)
}, ds, "somelabel")
require.NoError(err)

rev, _, err := RevisionFromContext(updated)
Expand All @@ -99,7 +99,7 @@ func TestAddRevisionToContextAtLeastAsFresh(t *testing.T) {
AtLeastAsFresh: zedtoken.MustNewFromRevision(exact),
},
},
}, ds)
}, ds, "somelabel")
require.NoError(err)

rev, _, err := RevisionFromContext(updated)
Expand All @@ -123,7 +123,7 @@ func TestAddRevisionToContextAtValidExactSnapshot(t *testing.T) {
AtExactSnapshot: zedtoken.MustNewFromRevision(exact),
},
},
}, ds)
}, ds, "somelabel")
require.NoError(err)

rev, _, err := RevisionFromContext(updated)
Expand All @@ -147,7 +147,7 @@ func TestAddRevisionToContextAtInvalidExactSnapshot(t *testing.T) {
AtExactSnapshot: zedtoken.MustNewFromRevision(zero),
},
},
}, ds)
}, ds, "somelabel")
require.Error(err)
ds.AssertExpectations(t)
}
Expand Down Expand Up @@ -181,7 +181,7 @@ func TestAddRevisionToContextWithCursor(t *testing.T) {
},
},
OptionalCursor: cursor,
}, ds)
}, ds, "somelabel")
require.NoError(err)

// ensure we get back `optimized` from the cursor
Expand Down
66 changes: 66 additions & 0 deletions internal/middleware/consistency/forcefull.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package consistency

import (
"context"
"strings"

"google.golang.org/grpc"

datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/pkg/datastore"
)

// ForceFullConsistencyUnaryServerInterceptor returns a new unary server interceptor that enforces full consistency
// for all requests, except for those in the bypassServiceWhitelist.
func ForceFullConsistencyUnaryServerInterceptor(serviceLabel string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use any instead of interface{}

for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
return handler(ctx, req)
}
}
ds := datastoremw.MustFromContext(ctx)
newCtx := ContextWithHandle(ctx)
if err := setFullConsistencyRevisionToContext(newCtx, req, ds, serviceLabel); err != nil {
return nil, err
}

return handler(newCtx, req)
}
}

// ForceFullConsistencyStreamServerInterceptor returns a new stream server interceptor that enforces full consistency
// for all requests, except for those in the bypassServiceWhitelist.
func ForceFullConsistencyStreamServerInterceptor(serviceLabel string) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
for bypass := range bypassServiceWhitelist {
if strings.HasPrefix(info.FullMethod, bypass) {
return handler(srv, stream)
}
}
wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context()), serviceLabel, setFullConsistencyRevisionToContext}
return handler(srv, wrapper)
}
}

func setFullConsistencyRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore, serviceLabel string) error {
handle := ctx.Value(revisionKey)
if handle == nil {
return nil
}

switch req.(type) {
case hasConsistency:
if serviceLabel != "" {
ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc()
}

databaseRev, err := ds.HeadRevision(ctx)
if err != nil {
return rewriteDatastoreError(ctx, err)
}
handle.(*revisionHandle).revision = databaseRev
}

return nil
}
4 changes: 2 additions & 2 deletions internal/services/integrationtesting/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestCertRotation(t *testing.T) {
},
{
Name: "consistency",
Middleware: consistency.UnaryServerInterceptor(),
Middleware: consistency.UnaryServerInterceptor("testing"),
},
{
Name: "servicespecific",
Expand All @@ -167,7 +167,7 @@ func TestCertRotation(t *testing.T) {
},
{
Name: "consistency",
Middleware: consistency.StreamServerInterceptor(),
Middleware: consistency.StreamServerInterceptor("testing"),
},
{
Name: "servicespecific",
Expand Down
4 changes: 2 additions & 2 deletions internal/testserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func NewTestServerWithConfigAndDatastore(require *require.Assertions,
},
{
Name: "consistency",
Middleware: consistency.UnaryServerInterceptor(),
Middleware: consistency.UnaryServerInterceptor("testserver"),
},
{
Name: "servicespecific",
Expand All @@ -127,7 +127,7 @@ func NewTestServerWithConfigAndDatastore(require *require.Assertions,
},
{
Name: "consistency",
Middleware: consistency.StreamServerInterceptor(),
Middleware: consistency.StreamServerInterceptor("testserver"),
},
{
Name: "servicespecific",
Expand Down
5 changes: 3 additions & 2 deletions pkg/cmd/server/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ type MiddlewareOption struct {
EnableRequestLog bool `debugmap:"visible"`
EnableResponseLog bool `debugmap:"visible"`
DisableGRPCHistogram bool `debugmap:"visible"`
MiddlewareServiceLabel string `debugmap:"visible"`

unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"`
streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"`
Expand Down Expand Up @@ -341,7 +342,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS
NewUnaryMiddleware().
WithName(DefaultInternalMiddlewareConsistency).
WithInternal(true).
WithInterceptor(consistencymw.UnaryServerInterceptor()).
WithInterceptor(consistencymw.UnaryServerInterceptor(opts.MiddlewareServiceLabel)).
Done(),

NewUnaryMiddleware().
Expand Down Expand Up @@ -415,7 +416,7 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St
NewStreamMiddleware().
WithName(DefaultInternalMiddlewareConsistency).
WithInternal(true).
WithInterceptor(consistencymw.StreamServerInterceptor()).
WithInterceptor(consistencymw.StreamServerInterceptor(opts.MiddlewareServiceLabel)).
Done(),

NewStreamMiddleware().
Expand Down
7 changes: 7 additions & 0 deletions pkg/cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type Config struct {
PresharedSecureKey []string `debugmap:"sensitive"`
ShutdownGracePeriod time.Duration `debugmap:"visible"`
DisableVersionResponse bool `debugmap:"visible"`
ServerName string `debugmap:"visible"`

// GRPC Gateway config
HTTPGateway util.HTTPServerConfig `debugmap:"visible"`
Expand Down Expand Up @@ -376,6 +377,11 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) {
watchServiceOption = services.WatchServiceDisabled
}

serverName := c.ServerName
if serverName == "" {
serverName = "spicedb"
}

opts := MiddlewareOption{
log.Logger,
c.GRPCAuthFunc,
Expand All @@ -384,6 +390,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) {
c.EnableRequestLogs,
c.EnableResponseLogs,
c.DisableGRPCLatencyHistogram,
serverName,
nil,
nil,
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestModifyUnaryMiddleware(t *testing.T) {
},
}}

opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", nil, nil}
opt = opt.WithDatastore(nil)

defaultMw, err := DefaultUnaryMiddleware(opt)
Expand Down Expand Up @@ -259,7 +259,7 @@ func TestModifyStreamingMiddleware(t *testing.T) {
},
}}

opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", nil, nil}
opt = opt.WithDatastore(nil)

defaultMw, err := DefaultStreamingMiddleware(opt)
Expand Down
Loading
Loading