From 0949de123aeeb0e638ad7e65614ebb3807e7baa0 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Mon, 18 Nov 2024 18:15:18 +0100 Subject: [PATCH] feat: batched and chunked insertion+deletion of relation tuples --- .github/workflows/ci.yaml | 4 +- go.mod | 2 +- go.sum | 4 +- internal/e2e/cli_client_test.go | 58 +++----- internal/e2e/full_suit_test.go | 48 +++---- internal/e2e/grpc_client_test.go | 139 ++++++++---------- internal/e2e/rest_client_test.go | 44 +++--- internal/e2e/sdk_client_test.go | 58 ++++---- internal/e2e/testcases_test.go | 8 +- internal/e2e/transaction_cases_test.go | 88 +++++++++++- internal/persistence/sql/persister.go | 20 --- internal/persistence/sql/query_test.go | 128 +++++++++++++++++ internal/persistence/sql/relationtuples.go | 156 +++++++++++++++------ internal/persistence/sql/uuid_mapping.go | 124 +++++++++------- internal/x/dbx/dsn_testutils.go | 28 ++-- 15 files changed, 582 insertions(+), 327 deletions(-) create mode 100644 internal/persistence/sql/query_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e59a41b14..4a0178139 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -49,7 +49,7 @@ jobs: runs-on: ubuntu-latest services: postgres: - image: postgres:11.8 + image: postgres:16 env: POSTGRES_DB: keto POSTGRES_PASSWORD: test @@ -69,7 +69,7 @@ jobs: steps: - run: | docker create --name cockroach -p 26257:26257 \ - cockroachdb/cockroach:latest-v23.2 start-single-node --insecure + cockroachdb/cockroach:latest-v24.2 start-single-node --insecure docker start cockroach name: Start CockroachDB - uses: ory/ci/checkout@master diff --git a/go.mod b/go.mod index 42baac8b5..951d2a5aa 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 github.com/ory/jsonschema/v3 v3.0.8 github.com/ory/keto/proto v0.13.0-alpha.0 - github.com/ory/x v0.0.675 + github.com/ory/x v0.0.677 github.com/pelletier/go-toml v1.9.5 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index bf4ea25f2..f1f40f477 100644 --- a/go.sum +++ b/go.sum @@ -419,8 +419,8 @@ github.com/ory/jsonschema/v3 v3.0.8 h1:Ssdb3eJ4lDZ/+XnGkvQS/te0p+EkolqwTsDOCxr/F github.com/ory/jsonschema/v3 v3.0.8/go.mod h1:ZPzqjDkwd3QTnb2Z6PAS+OTvBE2x5i6m25wCGx54W/0= github.com/ory/pop/v6 v6.2.1-0.20241121111754-e5dfc0f3344b h1:BIzoOe2/wynZBQak1po0tzgvARseIKsR2bF6b+SZoKE= github.com/ory/pop/v6 v6.2.1-0.20241121111754-e5dfc0f3344b/go.mod h1:okVAYKGtgunD/wbW3NGhZTndJCS+6FqO+cA89rQ4doc= -github.com/ory/x v0.0.675 h1:K6GpVo99BXBFv2UiwMjySNNNqCFKGswynrt7vWQJFU8= -github.com/ory/x v0.0.675/go.mod h1:zJmnDtKje2FCP4EeFvRsKk94XXiqKCSGJMZcirAfhUs= +github.com/ory/x v0.0.677 h1:ZulzE4EBhNBXNotWmGSmGsVNbgbZpIr4snMURRkski0= +github.com/ory/x v0.0.677/go.mod h1:zJmnDtKje2FCP4EeFvRsKk94XXiqKCSGJMZcirAfhUs= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= diff --git a/internal/e2e/cli_client_test.go b/internal/e2e/cli_client_test.go index 28a908413..9abe696b0 100644 --- a/internal/e2e/cli_client_test.go +++ b/internal/e2e/cli_client_test.go @@ -12,47 +12,37 @@ import ( "testing" "time" - "github.com/ory/keto/ketoapi" - "github.com/ory/herodot" - - "github.com/ory/keto/internal/check" - - grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/ory/keto/internal/x" - + "github.com/ory/x/cmdx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1" gprclient "github.com/ory/keto/cmd/client" cliexpand "github.com/ory/keto/cmd/expand" clirelationtuple "github.com/ory/keto/cmd/relationtuple" - - "github.com/ory/x/cmdx" + "github.com/ory/keto/internal/check" + "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) type cliClient struct { c *cmdx.CommandExecuter } -func (g *cliClient) queryNamespaces(t require.TestingT) (res ketoapi.GetNamespacesResponse) { - if t, ok := t.(*testing.T); ok { - t.Skip("not implemented for the CLI") - } +func (g *cliClient) queryNamespaces(t *testing.T) (res ketoapi.GetNamespacesResponse) { + t.Skip("not implemented for the CLI") return } var _ client = (*cliClient)(nil) -func (g *cliClient) oplCheckSyntax(t require.TestingT, _ []byte) []*ketoapi.ParseError { - if t, ok := t.(*testing.T); ok { - t.Skip("not implemented as a command yet") - } +func (g *cliClient) oplCheckSyntax(t *testing.T, _ []byte) []*ketoapi.ParseError { + t.Skip("not implemented as a command yet") return []*ketoapi.ParseError{} } -func (g *cliClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (g *cliClient) createTuple(t *testing.T, r *ketoapi.RelationTuple) { tupleEnc, err := json.Marshal(r) require.NoError(t, err) @@ -88,7 +78,7 @@ func (g *cliClient) assembleQueryFlags(q *ketoapi.RelationQuery, opts []x.Pagina return flags } -func (g *cliClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { +func (g *cliClient) queryTuple(t *testing.T, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { out := g.c.ExecNoErr(t, append(g.assembleQueryFlags(q, opts), "relation-tuple", "get")...) var resp ketoapi.GetResponse @@ -97,13 +87,13 @@ func (g *cliClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opt return &resp } -func (g *cliClient) queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { +func (g *cliClient) queryTupleErr(t *testing.T, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { stdErr := g.c.ExecExpectedErr(t, append(g.assembleQueryFlags(q, opts), "relation-tuple", "get")...) assert.Contains(t, stdErr, expected.GRPCCodeField.String()) assert.Contains(t, stdErr, expected.Error()) } -func (g *cliClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { +func (g *cliClient) check(t *testing.T, r *ketoapi.RelationTuple) bool { var sub string if r.SubjectID != nil { sub = *r.SubjectID @@ -116,27 +106,23 @@ func (g *cliClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { return res.Allowed } -func (g *cliClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, - expected herodot.DefaultError) { - if t, ok := t.(*testing.T); ok { - t.Skip("not implemented for the CLI") - } +func (g *cliClient) batchCheckErr(t *testing.T, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) { + t.Skip("not implemented for the CLI") } -func (g *cliClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse { - if t, ok := t.(*testing.T); ok { - t.Skip("not implemented for the CLI") - } + +func (g *cliClient) batchCheck(t *testing.T, requestTuples []*ketoapi.RelationTuple) []checkResponse { + t.Skip("not implemented for the CLI") return nil } -func (g *cliClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { +func (g *cliClient) expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { out := g.c.ExecNoErr(t, "expand", r.Relation, r.Namespace, r.Object, "--"+cliexpand.FlagMaxDepth, fmt.Sprintf("%d", depth), "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) res := ketoapi.Tree[*ketoapi.RelationTuple]{} require.NoError(t, json.Unmarshal([]byte(out), &res)) return &res } -func (g *cliClient) waitUntilLive(t require.TestingT) { +func (g *cliClient) waitUntilLive(t *testing.T) { flags := make([]string, len(g.c.PersistentArgs)) copy(flags, g.c.PersistentArgs) @@ -154,7 +140,7 @@ func (g *cliClient) waitUntilLive(t require.TestingT) { require.Equal(t, grpcHealthV1.HealthCheckResponse_SERVING.String()+"\n", out) } -func (g *cliClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (g *cliClient) deleteTuple(t *testing.T, r *ketoapi.RelationTuple) { tupleEnc, err := json.Marshal(r) require.NoError(t, err) @@ -163,6 +149,6 @@ func (g *cliClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { assert.Len(t, stderr, 0, stdout) } -func (g *cliClient) deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuery) { +func (g *cliClient) deleteAllTuples(t *testing.T, q *ketoapi.RelationQuery) { _ = g.c.ExecNoErr(t, append(g.assembleQueryFlags(q, nil), "relation-tuple", "delete-all", "--force")...) } diff --git a/internal/e2e/full_suit_test.go b/internal/e2e/full_suit_test.go index e00b34387..6afb7cf78 100644 --- a/internal/e2e/full_suit_test.go +++ b/internal/e2e/full_suit_test.go @@ -28,21 +28,21 @@ import ( type ( transactClient interface { client - transactTuples(t require.TestingT, ins []*ketoapi.RelationTuple, del []*ketoapi.RelationTuple) + transactTuples(t *testing.T, ins []*ketoapi.RelationTuple, del []*ketoapi.RelationTuple) } client interface { - createTuple(t require.TestingT, r *ketoapi.RelationTuple) - deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) - deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuery) - queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse - queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) - check(t require.TestingT, r *ketoapi.RelationTuple) bool - batchCheck(t require.TestingT, r []*ketoapi.RelationTuple) []checkResponse - batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) - expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] - oplCheckSyntax(t require.TestingT, content []byte) []*ketoapi.ParseError - waitUntilLive(t require.TestingT) - queryNamespaces(t require.TestingT) ketoapi.GetNamespacesResponse + createTuple(t *testing.T, r *ketoapi.RelationTuple) + deleteTuple(t *testing.T, r *ketoapi.RelationTuple) + deleteAllTuples(t *testing.T, q *ketoapi.RelationQuery) + queryTuple(t *testing.T, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse + queryTupleErr(t *testing.T, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) + check(t *testing.T, r *ketoapi.RelationTuple) bool + batchCheck(t *testing.T, r []*ketoapi.RelationTuple) []checkResponse + batchCheckErr(t *testing.T, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) + expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] + oplCheckSyntax(t *testing.T, content []byte) []*ketoapi.ParseError + waitUntilLive(t *testing.T) + queryNamespaces(t *testing.T) ketoapi.GetNamespacesResponse } ) @@ -65,12 +65,11 @@ func Test(t *testing.T) { // The test cases start here // We execute every test with all clients available for _, cl := range []client{ - &grpcClient{ - readRemote: reg.Config(ctx).ReadAPIListenOn(), - writeRemote: reg.Config(ctx).WriteAPIListenOn(), - oplSyntaxRemote: reg.Config(ctx).OPLSyntaxAPIListenOn(), - ctx: ctx, - }, + newGrpcClient(t, ctx, + reg.Config(ctx).ReadAPIListenOn(), + reg.Config(ctx).WriteAPIListenOn(), + reg.Config(ctx).OPLSyntaxAPIListenOn(), + ), &restClient{ readURL: "http://" + reg.Config(ctx).ReadAPIListenOn(), writeURL: "http://" + reg.Config(ctx).WriteAPIListenOn(), @@ -104,11 +103,11 @@ func Test(t *testing.T) { t.Run("case=metrics are served", func(t *testing.T) { t.Parallel() - (&grpcClient{ - readRemote: reg.Config(ctx).ReadAPIListenOn(), - writeRemote: reg.Config(ctx).WriteAPIListenOn(), - ctx: ctx, - }).waitUntilLive(t) + newGrpcClient(t, ctx, + reg.Config(ctx).ReadAPIListenOn(), + reg.Config(ctx).WriteAPIListenOn(), + reg.Config(ctx).OPLSyntaxAPIListenOn(), + ).waitUntilLive(t) t.Run("case=on "+prometheus.MetricsPrometheusPath, func(t *testing.T) { t.Parallel() @@ -148,6 +147,7 @@ func TestServeConfig(t *testing.T) { t.Log("Waiting for health check to be ready") time.Sleep(10 * time.Millisecond) } + t.Log("Health check is ready") req, err := http.NewRequest(http.MethodOptions, "http://"+reg.Config(ctx).ReadAPIListenOn()+relationtuple.ReadRouteBase, nil) require.NoError(t, err) diff --git a/internal/e2e/grpc_client_test.go b/internal/e2e/grpc_client_test.go index 885d3f262..536881bc6 100644 --- a/internal/e2e/grpc_client_test.go +++ b/internal/e2e/grpc_client_test.go @@ -6,74 +6,65 @@ package e2e import ( "context" "encoding/json" + "errors" + "testing" "time" - "github.com/ory/keto/ketoapi" - opl "github.com/ory/keto/proto/ory/keto/opl/v1alpha1" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "github.com/ory/herodot" "github.com/stretchr/testify/assert" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" - "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" + opl "github.com/ory/keto/proto/ory/keto/opl/v1alpha1" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type grpcClient struct { - readRemote, writeRemote, oplSyntaxRemote string - wc, rc, oc *grpc.ClientConn - ctx context.Context + read, write, oplSyntax *grpc.ClientConn + ctx context.Context } -func (g *grpcClient) queryNamespaces(t require.TestingT) (apiResponse ketoapi.GetNamespacesResponse) { - client := rts.NewNamespacesServiceClient(g.readConn(t)) - res, err := client.ListNamespaces(g.ctx, &rts.ListNamespacesRequest{}) +func newGrpcClient(t *testing.T, ctx context.Context, readRemote, writeRemote, oplSyntaxRemote string) *grpcClient { + read, err := grpc.NewClient(readRemote, grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) - require.NoError(t, convert(res, &apiResponse)) + t.Cleanup(func() { read.Close() }) - return -} - -var _ transactClient = (*grpcClient)(nil) - -func (g *grpcClient) conn(t require.TestingT, remote string) *grpc.ClientConn { - ctx, cancel := context.WithTimeout(g.ctx, 3*time.Second) - defer cancel() + write, err := grpc.NewClient(writeRemote, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { write.Close() }) - conn, err := grpc.DialContext(ctx, remote, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), grpc.WithDisableHealthCheck()) + oplSyntax, err := grpc.NewClient(oplSyntaxRemote, grpc.WithTransportCredentials(insecure.NewCredentials())) require.NoError(t, err) + t.Cleanup(func() { oplSyntax.Close() }) - return conn -} + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) -func (g *grpcClient) readConn(t require.TestingT) *grpc.ClientConn { - if g.rc == nil { - g.rc = g.conn(t, g.readRemote) + return &grpcClient{ + read: read, + write: write, + oplSyntax: oplSyntax, + ctx: ctx, } - return g.rc } -func (g *grpcClient) writeConn(t require.TestingT) *grpc.ClientConn { - if g.wc == nil { - g.wc = g.conn(t, g.writeRemote) - } - return g.wc -} +func (g *grpcClient) queryNamespaces(t *testing.T) (apiResponse ketoapi.GetNamespacesResponse) { + client := rts.NewNamespacesServiceClient(g.read) + res, err := client.ListNamespaces(g.ctx, &rts.ListNamespacesRequest{}) + require.NoError(t, err) + require.NoError(t, convert(res, &apiResponse)) -func (g *grpcClient) oplSyntaxConn(t require.TestingT) *grpc.ClientConn { - if g.oc == nil { - g.oc = g.conn(t, g.oplSyntaxRemote) - } - return g.oc + return } -func (g *grpcClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) { +var _ transactClient = (*grpcClient)(nil) + +func (g *grpcClient) createTuple(t *testing.T, r *ketoapi.RelationTuple) { g.transactTuples(t, []*ketoapi.RelationTuple{r}, nil) } @@ -91,8 +82,8 @@ func (*grpcClient) createQuery(q *ketoapi.RelationQuery) *rts.RelationQuery { return query } -func (g *grpcClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { - c := rts.NewReadServiceClient(g.readConn(t)) +func (g *grpcClient) queryTuple(t *testing.T, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { + c := rts.NewReadServiceClient(g.read) pagination := x.GetPaginationOptions(opts...) resp, err := c.ListRelationTuples(g.ctx, &rts.ListRelationTuplesRequest{ @@ -114,8 +105,8 @@ func (g *grpcClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, op } } -func (g *grpcClient) queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { - c := rts.NewReadServiceClient(g.readConn(t)) +func (g *grpcClient) queryTupleErr(t *testing.T, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { + c := rts.NewReadServiceClient(g.read) pagination := x.GetPaginationOptions(opts...) _, err := c.ListRelationTuples(g.ctx, &rts.ListRelationTuplesRequest{ @@ -129,8 +120,8 @@ func (g *grpcClient) queryTupleErr(t require.TestingT, expected herodot.DefaultE assert.Equal(t, expected.GRPCCodeField, s.Code(), "%+v", err) } -func (g *grpcClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { - c := rts.NewCheckServiceClient(g.readConn(t)) +func (g *grpcClient) check(t *testing.T, r *ketoapi.RelationTuple) bool { + c := rts.NewCheckServiceClient(g.read) req := &rts.CheckRequest{ Tuple: &rts.RelationTuple{ @@ -155,7 +146,7 @@ type checkResponse struct { errorMessage string } -func (g *grpcClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, +func (g *grpcClient) batchCheckErr(t *testing.T, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) { _, err := g.doBatchCheck(t, requestTuples) @@ -166,7 +157,7 @@ func (g *grpcClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi. assert.Contains(t, s.Message(), expected.Reason()) } -func (g *grpcClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse { +func (g *grpcClient) batchCheck(t *testing.T, requestTuples []*ketoapi.RelationTuple) []checkResponse { resp, err := g.doBatchCheck(t, requestTuples) require.NoError(t, err) @@ -181,9 +172,9 @@ func (g *grpcClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.Rel return checkResponses } -func (g *grpcClient) doBatchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) (*rts.BatchCheckResponse, error) { +func (g *grpcClient) doBatchCheck(_ *testing.T, requestTuples []*ketoapi.RelationTuple) (*rts.BatchCheckResponse, error) { - c := rts.NewCheckServiceClient(g.readConn(t)) + c := rts.NewCheckServiceClient(g.read) tuples := make([]*rts.RelationTuple, len(requestTuples)) for i, tuple := range requestTuples { @@ -207,8 +198,8 @@ func (g *grpcClient) doBatchCheck(t require.TestingT, requestTuples []*ketoapi.R return c.BatchCheck(g.ctx, req) } -func (g *grpcClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { - c := rts.NewExpandServiceClient(g.readConn(t)) +func (g *grpcClient) expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { + c := rts.NewExpandServiceClient(g.read) resp, err := c.Expand(g.ctx, &rts.ExpandRequest{ Subject: rts.NewSubjectSet(r.Namespace, r.Object, r.Relation), @@ -219,37 +210,29 @@ func (g *grpcClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int return ketoapi.TreeFromProto[*ketoapi.RelationTuple](resp.Tree) } -func (g *grpcClient) waitUntilLive(t require.TestingT) { - c := grpcHealthV1.NewHealthClient(g.readConn(t)) - - ctx, cancel := context.WithCancel(g.ctx) - defer cancel() - - cl, err := c.Watch(ctx, &grpcHealthV1.HealthCheckRequest{}) - require.NoError(t, err) - require.NoError(t, cl.CloseSend()) +func (g *grpcClient) waitUntilLive(t *testing.T) { + c := grpcHealthV1.NewHealthClient(g.read) for { - select { - case <-g.ctx.Done(): - return - default: + res, err := c.Check(g.ctx, &grpcHealthV1.HealthCheckRequest{}) + if errors.Is(err, context.Canceled) { + t.Fatalf("timed out waiting for service to be live: %s", err) } - resp, err := cl.Recv() - require.NoError(t, err) - - if resp.Status == grpcHealthV1.HealthCheckResponse_SERVING { + if err == nil { + require.Equal(t, grpcHealthV1.HealthCheckResponse_SERVING, res.Status) return } + t.Logf("waiting for service to be live: %s", err) + time.Sleep(10 * time.Millisecond) } } -func (g *grpcClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (g *grpcClient) deleteTuple(t *testing.T, r *ketoapi.RelationTuple) { g.transactTuples(t, nil, []*ketoapi.RelationTuple{r}) } -func (g *grpcClient) deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuery) { - c := rts.NewWriteServiceClient(g.writeConn(t)) +func (g *grpcClient) deleteAllTuples(t *testing.T, q *ketoapi.RelationQuery) { + c := rts.NewWriteServiceClient(g.write) query := &rts.RelationQuery{ Namespace: q.Namespace, Object: q.Object, @@ -267,8 +250,8 @@ func (g *grpcClient) deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuer require.NoError(t, err) } -func (g *grpcClient) transactTuples(t require.TestingT, ins []*ketoapi.RelationTuple, del []*ketoapi.RelationTuple) { - c := rts.NewWriteServiceClient(g.writeConn(t)) +func (g *grpcClient) transactTuples(t *testing.T, ins []*ketoapi.RelationTuple, del []*ketoapi.RelationTuple) { + c := rts.NewWriteServiceClient(g.write) deltas := make([]*rts.RelationTupleDelta, len(ins)+len(del)) for i := range ins { @@ -291,8 +274,8 @@ func (g *grpcClient) transactTuples(t require.TestingT, ins []*ketoapi.RelationT require.NoError(t, err) } -func (g *grpcClient) oplCheckSyntax(t require.TestingT, content []byte) (parseErrors []*ketoapi.ParseError) { - c := opl.NewSyntaxServiceClient(g.oplSyntaxConn(t)) +func (g *grpcClient) oplCheckSyntax(t *testing.T, content []byte) (parseErrors []*ketoapi.ParseError) { + c := opl.NewSyntaxServiceClient(g.oplSyntax) res, err := c.Check(g.ctx, &opl.CheckRequest{Content: content}) require.NoError(t, err) diff --git a/internal/e2e/rest_client_test.go b/internal/e2e/rest_client_test.go index 7e3ff92e1..d92f47cd9 100644 --- a/internal/e2e/rest_client_test.go +++ b/internal/e2e/rest_client_test.go @@ -10,25 +10,22 @@ import ( "io" "net/http" "strconv" + "testing" "time" - client2 "github.com/ory/keto/internal/httpclient" - "github.com/ory/keto/internal/schema" - "github.com/ory/keto/ketoapi" - "github.com/ory/herodot" - "github.com/tidwall/gjson" - "github.com/ory/x/healthx" - - "github.com/ory/keto/internal/x" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/ory/keto/internal/check" "github.com/ory/keto/internal/expand" + client2 "github.com/ory/keto/internal/httpclient" "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/schema" + "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) var _ client = &restClient{} @@ -37,7 +34,7 @@ type restClient struct { readURL, writeURL, oplSyntaxURL string } -func (rc *restClient) queryNamespaces(t require.TestingT) (res ketoapi.GetNamespacesResponse) { +func (rc *restClient) queryNamespaces(t *testing.T) (res ketoapi.GetNamespacesResponse) { body, code := rc.makeRequest(t, http.MethodGet, "/namespaces", "", rc.readURL) assert.Equal(t, http.StatusOK, code, body) require.NoError(t, json.Unmarshal([]byte(body), &res)) @@ -45,7 +42,7 @@ func (rc *restClient) queryNamespaces(t require.TestingT) (res ketoapi.GetNamesp return } -func (rc *restClient) oplCheckSyntax(t require.TestingT, content []byte) []*ketoapi.ParseError { +func (rc *restClient) oplCheckSyntax(t *testing.T, content []byte) []*ketoapi.ParseError { body, code := rc.makeRequest(t, http.MethodPost, schema.RouteBase, string(content), rc.oplSyntaxURL) assert.Equal(t, http.StatusOK, code, body) var response ketoapi.CheckOPLSyntaxResponse @@ -54,7 +51,7 @@ func (rc *restClient) oplCheckSyntax(t require.TestingT, content []byte) []*keto return response.Errors } -func (rc *restClient) makeRequest(t require.TestingT, method, path, body string, baseURL string) (string, int) { +func (rc *restClient) makeRequest(t *testing.T, method, path, body string, baseURL string) (string, int) { var b io.Reader if body != "" { b = bytes.NewBufferString(body) @@ -72,7 +69,7 @@ func (rc *restClient) makeRequest(t require.TestingT, method, path, body string, return string(respBody), resp.StatusCode } -func (rc *restClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (rc *restClient) createTuple(t *testing.T, r *ketoapi.RelationTuple) { tEnc, err := json.Marshal(r) require.NoError(t, err) @@ -80,17 +77,17 @@ func (rc *restClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) assert.Equal(t, http.StatusCreated, code, body) } -func (rc *restClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (rc *restClient) deleteTuple(t *testing.T, r *ketoapi.RelationTuple) { body, code := rc.makeRequest(t, http.MethodDelete, relationtuple.WriteRouteBase+"?"+r.ToURLQuery().Encode(), "", rc.writeURL) require.Equal(t, http.StatusNoContent, code, body) } -func (rc *restClient) deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuery) { +func (rc *restClient) deleteAllTuples(t *testing.T, q *ketoapi.RelationQuery) { body, code := rc.makeRequest(t, http.MethodDelete, relationtuple.WriteRouteBase+"?"+q.ToURLQuery().Encode(), "", rc.writeURL) require.Equal(t, http.StatusNoContent, code, body) } -func (rc *restClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { +func (rc *restClient) queryTuple(t *testing.T, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { urlQuery := q.ToURLQuery() pagination := x.GetPaginationOptions(opts...) @@ -110,7 +107,7 @@ func (rc *restClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, o return &dec } -func (rc *restClient) queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { +func (rc *restClient) queryTupleErr(t *testing.T, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { urlQuery := q.ToURLQuery() pagination := x.GetPaginationOptions(opts...) @@ -129,7 +126,7 @@ func (rc *restClient) queryTupleErr(t require.TestingT, expected herodot.Default assert.Equal(t, expected.Error(), gjson.Get(body, "error.message").String(), body) } -func (rc *restClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { +func (rc *restClient) check(t *testing.T, r *ketoapi.RelationTuple) bool { q := r.ToURLQuery() bodyGet, codeGet := rc.makeRequest(t, http.MethodGet, fmt.Sprintf("%s?%s", check.RouteBase, q.Encode()), "", rc.readURL) @@ -156,7 +153,7 @@ func (rc *restClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { return false } -func (rc *restClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, +func (rc *restClient) batchCheckErr(t *testing.T, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) { req := client2.BatchCheckPermissionBody{ @@ -169,7 +166,7 @@ func (rc *restClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi assert.Contains(t, body, expected.Reason()) } -func (rc *restClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse { +func (rc *restClient) batchCheck(t *testing.T, requestTuples []*ketoapi.RelationTuple) []checkResponse { req := client2.BatchCheckPermissionBody{ Tuples: tuplesToRelationships(requestTuples), } @@ -191,7 +188,7 @@ func (rc *restClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.Re return responseChecks } -func (rc *restClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { +func (rc *restClient) expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { query := r.ToURLQuery() query.Set("max-depth", fmt.Sprintf("%d", depth)) @@ -204,17 +201,18 @@ func (rc *restClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth in return tree } -func healthReady(t require.TestingT, readURL string) bool { +func healthReady(t *testing.T, readURL string) bool { req, err := http.NewRequest("GET", readURL+healthx.ReadyCheckPath, nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) if err != nil { return false } + defer resp.Body.Close() return resp.StatusCode == http.StatusOK } -func (rc *restClient) waitUntilLive(t require.TestingT) { +func (rc *restClient) waitUntilLive(t *testing.T) { // wait for /health/ready for !healthReady(t, rc.readURL) { time.Sleep(10 * time.Millisecond) diff --git a/internal/e2e/sdk_client_test.go b/internal/e2e/sdk_client_test.go index 21da87cb6..4274ad67b 100644 --- a/internal/e2e/sdk_client_test.go +++ b/internal/e2e/sdk_client_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "net/http" + "testing" "time" "github.com/ory/herodot" @@ -30,15 +31,16 @@ var _ client = (*sdkClient)(nil) var requestTimeout = 5 * time.Second -func (c *sdkClient) requestCtx() context.Context { - ctx, _ := context.WithTimeout(context.Background(), requestTimeout) +func (c *sdkClient) requestCtx(t *testing.T) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + t.Cleanup(cancel) return ctx } -func (c *sdkClient) oplCheckSyntax(t require.TestingT, content []byte) (parseErrors []*ketoapi.ParseError) { +func (c *sdkClient) oplCheckSyntax(t *testing.T, content []byte) (parseErrors []*ketoapi.ParseError) { res, _, err := c.getOPLSyntaxClient(). RelationshipApi. - CheckOplSyntax(c.requestCtx()). + CheckOplSyntax(c.requestCtx(t)). Body(string(content)). Execute() require.NoError(t, err) @@ -81,7 +83,7 @@ func (c *sdkClient) getOPLSyntaxClient() *httpclient.APIClient { return c.sc } -func (c *sdkClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (c *sdkClient) createTuple(t *testing.T, r *ketoapi.RelationTuple) { payload := httpclient.CreateRelationshipBody{ Namespace: pointerx.Ptr(r.Namespace), Object: pointerx.Ptr(r.Object), @@ -97,7 +99,7 @@ func (c *sdkClient) createTuple(t require.TestingT, r *ketoapi.RelationTuple) { } _, _, err := c.getWriteClient().RelationshipApi. - CreateRelationship(c.requestCtx()). + CreateRelationship(c.requestCtx(t)). CreateRelationshipBody(payload). Execute() require.NoError(t, err) @@ -121,9 +123,9 @@ func withSubject[P interface { return params } -func (c *sdkClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { +func (c *sdkClient) deleteTuple(t *testing.T, r *ketoapi.RelationTuple) { request := c.getWriteClient().RelationshipApi. - DeleteRelationships(c.requestCtx()). + DeleteRelationships(c.requestCtx(t)). Namespace(r.Namespace). Object(r.Object). Relation(r.Relation) @@ -133,8 +135,8 @@ func (c *sdkClient) deleteTuple(t require.TestingT, r *ketoapi.RelationTuple) { require.NoError(t, err) } -func (c *sdkClient) deleteAllTuples(t require.TestingT, q *ketoapi.RelationQuery) { - request := c.getWriteClient().RelationshipApi.DeleteRelationships(c.requestCtx()) +func (c *sdkClient) deleteAllTuples(t *testing.T, q *ketoapi.RelationQuery) { + request := c.getWriteClient().RelationshipApi.DeleteRelationships(c.requestCtx(t)) if q.Namespace != nil { request = request.Namespace(*q.Namespace) } @@ -180,8 +182,8 @@ func compileParams(req httpclient.RelationshipApiApiGetRelationshipsRequest, q * return req } -func (c *sdkClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { - request := c.getReadClient().RelationshipApi.GetRelationships(c.requestCtx()) +func (c *sdkClient) queryTuple(t *testing.T, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse { + request := c.getReadClient().RelationshipApi.GetRelationships(c.requestCtx(t)) request = compileParams(request, q, opts) resp, _, err := request.Execute() @@ -212,8 +214,8 @@ func (c *sdkClient) queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opt return getResp } -func (c *sdkClient) queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { - request := c.getReadClient().RelationshipApi.GetRelationships(c.requestCtx()) +func (c *sdkClient) queryTupleErr(t *testing.T, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) { + request := c.getReadClient().RelationshipApi.GetRelationships(c.requestCtx(t)) request = compileParams(request, q, opts) _, _, err := request.Execute() @@ -227,8 +229,8 @@ func (c *sdkClient) queryTupleErr(t require.TestingT, expected herodot.DefaultEr } } -func (c *sdkClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { - request := c.getReadClient().PermissionApi.CheckPermission(c.requestCtx()). +func (c *sdkClient) check(t *testing.T, r *ketoapi.RelationTuple) bool { + request := c.getReadClient().PermissionApi.CheckPermission(c.requestCtx(t)). Namespace(r.Namespace). Object(r.Object). Relation(r.Relation) @@ -240,9 +242,9 @@ func (c *sdkClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool { return resp.GetAllowed() } -func (c *sdkClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) { +func (c *sdkClient) batchCheckErr(t *testing.T, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError) { - request := c.getReadClient().PermissionApi.BatchCheckPermission(c.requestCtx()). + request := c.getReadClient().PermissionApi.BatchCheckPermission(c.requestCtx(t)). BatchCheckPermissionBody(httpclient.BatchCheckPermissionBody{ Tuples: tuplesToRelationships(requestTuples), }) @@ -258,8 +260,8 @@ func (c *sdkClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.R } } -func (c *sdkClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse { - request := c.getReadClient().PermissionApi.BatchCheckPermission(c.requestCtx()). +func (c *sdkClient) batchCheck(t *testing.T, requestTuples []*ketoapi.RelationTuple) []checkResponse { + request := c.getReadClient().PermissionApi.BatchCheckPermission(c.requestCtx(t)). BatchCheckPermissionBody(httpclient.BatchCheckPermissionBody{ Tuples: tuplesToRelationships(requestTuples), }) @@ -302,7 +304,7 @@ func tuplesToRelationships(tuples []*ketoapi.RelationTuple) []httpclient.Relatio return relationships } -func buildTree(t require.TestingT, mt *httpclient.ExpandedPermissionTree) *ketoapi.Tree[*ketoapi.RelationTuple] { +func buildTree(t *testing.T, mt *httpclient.ExpandedPermissionTree) *ketoapi.Tree[*ketoapi.RelationTuple] { result := &ketoapi.Tree[*ketoapi.RelationTuple]{ Type: ketoapi.TreeNodeType(mt.Type), } @@ -330,8 +332,8 @@ func buildTree(t require.TestingT, mt *httpclient.ExpandedPermissionTree) *ketoa return result } -func (c *sdkClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { - request := c.getReadClient().PermissionApi.ExpandPermissions(c.requestCtx()). +func (c *sdkClient) expand(t *testing.T, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple] { + request := c.getReadClient().PermissionApi.ExpandPermissions(c.requestCtx(t)). Namespace(r.Namespace). Object(r.Object). Relation(r.Relation). @@ -343,16 +345,16 @@ func (c *sdkClient) expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) return buildTree(t, resp) } -func (c *sdkClient) waitUntilLive(t require.TestingT) { - resp, _, err := c.getReadClient().MetadataApi.IsReady(c.requestCtx()).Execute() +func (c *sdkClient) waitUntilLive(t *testing.T) { + resp, _, err := c.getReadClient().MetadataApi.IsReady(c.requestCtx(t)).Execute() for err != nil { - resp, _, err = c.getReadClient().MetadataApi.IsReady(c.requestCtx()).Execute() + resp, _, err = c.getReadClient().MetadataApi.IsReady(c.requestCtx(t)).Execute() } require.Equal(t, "ok", resp.Status) } -func (c *sdkClient) queryNamespaces(t require.TestingT) (response ketoapi.GetNamespacesResponse) { - res, _, err := c.getReadClient().RelationshipApi.ListRelationshipNamespaces(c.requestCtx()).Execute() +func (c *sdkClient) queryNamespaces(t *testing.T) (response ketoapi.GetNamespacesResponse) { + res, _, err := c.getReadClient().RelationshipApi.ListRelationshipNamespaces(c.requestCtx(t)).Execute() require.NoError(t, err) require.NoError(t, convert(res, &response)) diff --git a/internal/e2e/testcases_test.go b/internal/e2e/testcases_test.go index a75d1c842..0ee52859f 100644 --- a/internal/e2e/testcases_test.go +++ b/internal/e2e/testcases_test.go @@ -8,17 +8,15 @@ import ( "strconv" "testing" - "github.com/ory/x/pointerx" - - "github.com/ory/keto/internal/expand" - "github.com/ory/keto/ketoapi" - "github.com/ory/herodot" + "github.com/ory/x/pointerx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) func runCases(c client, m *namespaceTestManager) func(*testing.T) { diff --git a/internal/e2e/transaction_cases_test.go b/internal/e2e/transaction_cases_test.go index 4e4e7e5b0..9f925a902 100644 --- a/internal/e2e/transaction_cases_test.go +++ b/internal/e2e/transaction_cases_test.go @@ -4,15 +4,19 @@ package e2e import ( + "cmp" + "slices" + "strconv" "testing" + "time" "github.com/ory/x/pointerx" - - "github.com/ory/keto/ketoapi" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/keto/internal/namespace" + "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) func runTransactionCases(c transactClient, m *namespaceTestManager) func(*testing.T) { @@ -58,6 +62,84 @@ func runTransactionCases(c transactClient, m *namespaceTestManager) func(*testin assert.Len(t, resp.RelationTuples, 0) }) + t.Run("case=large inserts and deletes", func(t *testing.T) { + if !testing.Short() { + t.Skip("This test is fairly expensive, especially the deletion.") + } + + ns := []*namespace.Namespace{ + {Name: t.Name() + "1"}, + {Name: t.Name() + "2"}, + } + m.add(t, ns...) + + var tuples []*ketoapi.RelationTuple + for i := range 12001 { + tuples = append(tuples, + &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o" + strconv.Itoa(i), + Relation: "rela", + SubjectSet: &ketoapi.SubjectSet{ + Namespace: ns[1].Name, + Object: "o" + strconv.Itoa(i), + Relation: "relx", + }, + }, + &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o" + strconv.Itoa(i), + Relation: "relb", + SubjectID: pointerx.Ptr("sid"), + }, + ) + } + + t0 := time.Now() + c.transactTuples(t, tuples, nil) + t.Log("insert", time.Since(t0)) + + t0 = time.Now() + var resp []*ketoapi.RelationTuple + var pt string + for { + r := c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }, x.WithSize(1000), x.WithToken(pt)) + resp = append(resp, r.RelationTuples...) + pt = r.NextPageToken + if pt == "" { + break + } + } + t.Log("query", time.Since(t0)) + + sort := func(a, b *ketoapi.RelationTuple) int { + return cmp.Or( + cmp.Compare(a.Namespace, b.Namespace), + cmp.Compare(a.Object, b.Object), + cmp.Compare(a.Relation, b.Relation), + ) + } + t0 = time.Now() + slices.SortFunc(resp, sort) + slices.SortFunc(tuples, sort) + t.Log("sort", time.Since(t0)) + + t0 = time.Now() + require.Equal(t, tuples, resp) + t.Log("equal", time.Since(t0)) + + t0 = time.Now() + c.transactTuples(t, nil, tuples) + t.Log(t.Name(), "delete", time.Since(t0)) + + resp = c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }).RelationTuples + assert.Len(t, resp, 0) + }) + t.Run("case=expand-api-display-access docs code sample", func(t *testing.T) { files := &namespace.Namespace{Name: t.Name() + "files"} directories := &namespace.Namespace{Name: t.Name() + "directories"} diff --git a/internal/persistence/sql/persister.go b/internal/persistence/sql/persister.go index 7fd5d94f6..b7fc83098 100644 --- a/internal/persistence/sql/persister.go +++ b/internal/persistence/sql/persister.go @@ -6,11 +6,9 @@ package sql import ( "context" "embed" - "reflect" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/ory/x/otelx" "github.com/ory/x/popx" "github.com/pkg/errors" @@ -70,24 +68,6 @@ func (p *Persister) Connection(ctx context.Context) *pop.Connection { return popx.GetConnection(ctx, p.conn.WithContext(ctx)) } -func (p *Persister) createWithNetwork(ctx context.Context, v interface{}) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createWithNetwork") - defer otelx.End(span, &err) - - rv := reflect.ValueOf(v) - - if rv.Kind() != reflect.Ptr && rv.Elem().Kind() != reflect.Struct { - panic("expected to get *struct in create") - } - nID := rv.Elem().FieldByName("NetworkID") - if !nID.IsValid() || !nID.CanSet() { - panic("expected struct to have a 'NetworkID uuid.UUID' field") - } - nID.Set(reflect.ValueOf(p.NetworkID(ctx))) - - return p.Connection(ctx).Create(v) -} - func (p *Persister) queryWithNetwork(ctx context.Context) *pop.Query { return p.Connection(ctx).Where("nid = ?", p.NetworkID(ctx)) } diff --git a/internal/persistence/sql/query_test.go b/internal/persistence/sql/query_test.go new file mode 100644 index 000000000..cad24e6ab --- /dev/null +++ b/internal/persistence/sql/query_test.go @@ -0,0 +1,128 @@ +package sql + +import ( + "database/sql" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/ory/x/uuidx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/relationtuple" +) + +func TestBuildDelete(t *testing.T) { + t.Parallel() + nid := uuidx.NewV4() + + q, args, err := buildDelete(nid, nil) + assert.Error(t, err) + assert.Empty(t, q) + assert.Empty(t, args) + + obj1, obj2, sub1, obj3 := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + + q, args, err = buildDelete(nid, []*relationtuple.RelationTuple{ + { + Namespace: "ns1", + Object: obj1, + Relation: "rel1", + Subject: &relationtuple.SubjectID{ + ID: sub1, + }, + }, + { + Namespace: "ns2", + Object: obj2, + Relation: "rel2", + Subject: &relationtuple.SubjectSet{ + Namespace: "ns3", + Object: obj3, + Relation: "rel3", + }, + }, + }) + require.NoError(t, err) + + // parentheses are important here + assert.Equal(t, q, "DELETE FROM keto_relation_tuples WHERE ((namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL) OR (namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)) AND nid = ?") + assert.Equal(t, []any{"ns1", obj1, "rel1", sub1, "ns2", obj2, "rel2", "ns3", obj3, "rel3", nid}, args) +} + +func TestBuildInsert(t *testing.T) { + t.Parallel() + nid := uuidx.NewV4() + + q, args, err := buildInsert(time.Now(), nid, nil) + assert.Error(t, err) + assert.Empty(t, q) + assert.Empty(t, args) + + obj1, obj2, sub1, obj3 := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + + now := time.Now() + + q, args, err = buildInsert(now, nid, []*relationtuple.RelationTuple{ + { + Namespace: "ns1", + Object: obj1, + Relation: "rel1", + Subject: &relationtuple.SubjectID{ + ID: sub1, + }, + }, + { + Namespace: "ns2", + Object: obj2, + Relation: "rel2", + Subject: &relationtuple.SubjectSet{ + Namespace: "ns3", + Object: obj3, + Relation: "rel3", + }, + }, + }) + require.NoError(t, err) + + assert.Equal(t, q, "INSERT INTO keto_relation_tuples (shard_id, nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + assert.Equal(t, []any{ + args[0], // this is kind of cheating but we generate the shard id in the buildInsert function + nid, + "ns1", + obj1, + "rel1", + uuid.NullUUID{sub1, true}, + sql.NullString{}, uuid.NullUUID{}, sql.NullString{}, + now, + + args[10], // again, cheating + nid, + "ns2", + obj2, + "rel2", + uuid.NullUUID{}, + sql.NullString{"ns3", true}, uuid.NullUUID{obj3, true}, sql.NullString{"rel3", true}, + now, + }, args) +} + +func TestBuildInsertUUIDs(t *testing.T) { + t.Parallel() + + foo, bar, baz := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + uuids := []UUIDMapping{ + {foo, "foo"}, + {bar, "bar"}, + {baz, "baz"}, + } + + q, args := buildInsertUUIDs(uuids, "mysql") + assert.Equal(t, "INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?)", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) + + q, args = buildInsertUUIDs(uuids, "anything else") + assert.Equal(t, "INSERT INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?) ON CONFLICT (id) DO NOTHING", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) +} diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 9d82bee31..b6f40f781 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -6,18 +6,29 @@ package sql import ( "context" "database/sql" + "fmt" + "slices" + "strings" "time" - "github.com/ory/keto/ketoapi" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" "github.com/pkg/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" +) + +// Typical database limits for placeholders/bind vars are 1<<15 (32k, MySQL, SQLite) and 1<<16 (64k, PostgreSQL, CockroachDB). +const ( + chunkSizeInsertUUIDMappings = 15000 // two placeholders per mapping + chunkSizeInsertTuple = 3000 // ten placeholders per tuple + chunkSizeDeleteTuple = 100 // the database must build an expression tree for each chunk, so we must limit more aggressively ) type ( @@ -71,7 +82,7 @@ func (r *RelationTuple) ToInternal() (*relationtuple.RelationTuple, error) { return rt, nil } -func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject) error { +func (r *RelationTuple) insertSubject(s relationtuple.Subject) error { switch st := s.(type) { case *relationtuple.SubjectID: r.SubjectID = uuid.NullUUID{ @@ -90,39 +101,12 @@ func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject return nil } -func (r *RelationTuple) FromInternal(ctx context.Context, p *Persister, rt *relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FromInternal") - defer otelx.End(span, &err) - +func (r *RelationTuple) FromInternal(rt *relationtuple.RelationTuple) (err error) { r.Namespace = rt.Namespace r.Object = rt.Object r.Relation = rt.Relation - return r.insertSubject(ctx, rt.Subject) -} - -func (p *Persister) InsertRelationTuple(ctx context.Context, rel *relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InsertRelationTuple") - defer otelx.End(span, &err) - - if rel.Subject == nil { - return errors.WithStack(ketoapi.ErrNilSubject) - } - - rt := &RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - CommitTime: time.Now(), - } - if err := rt.FromInternal(ctx, p, rel); err != nil { - return err - } - - if err := sqlcon.HandleError( - p.createWithNetwork(ctx, rt), - ); err != nil { - return err - } - return nil + return r.insertSubject(rt.Subject) } func (p *Persister) whereSubject(_ context.Context, q *pop.Query, sub relationtuple.Subject) error { @@ -165,25 +149,53 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu return nil } +func buildDelete(nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) { + if len(rs) == 0 { + return "", nil, errors.WithStack(ketoapi.ErrMalformedInput) + } + + args = make([]any, 0, 6*len(rs)+1) + ors := make([]string, 0, len(rs)) + for _, rt := range rs { + switch s := rt.Subject.(type) { + case *relationtuple.SubjectID: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.ID) + case *relationtuple.SubjectSet: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.Namespace, s.Object, s.Relation) + case nil: + return "", nil, errors.WithStack(ketoapi.ErrNilSubject) + } + } + + query = fmt.Sprintf("DELETE FROM %s WHERE (%s) AND nid = ?", (&RelationTuple{}).TableName(), strings.Join(ors, " OR ")) + args = append(args, nid) + return query, args, nil +} + func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples") + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples", + trace.WithAttributes(attribute.Int("count", len(rs)))) defer otelx.End(span, &err) + if len(rs) == 0 { + return nil + } + return p.Transaction(ctx, func(ctx context.Context) error { - for _, r := range rs { - q := p.queryWithNetwork(ctx). - Where("namespace = ?", r.Namespace). - Where("object = ?", r.Object). - Where("relation = ?", r.Relation) - if err := p.whereSubject(ctx, q, r.Subject); err != nil { + for chunk := range slices.Chunk(rs, chunkSizeDeleteTuple) { + q, args, err := buildDelete(p.NetworkID(ctx), chunk) + if err != nil { return err } - - if err := q.Delete(&RelationTuple{}); err != nil { - return err + if q == "" { + continue + } + if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil { + return sqlcon.HandleError(err) } } - return nil }) } @@ -260,15 +272,63 @@ func (p *Persister) ExistsRelationTuples(ctx context.Context, query *relationtup return exists, sqlcon.HandleError(err) } +func buildInsert(commitTime time.Time, nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) { + if len(rs) == 0 { + return "", nil, errors.WithStack(ketoapi.ErrMalformedInput) + } + + var q strings.Builder + fmt.Fprintf(&q, "INSERT INTO %s (shard_id, nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time) VALUES ", (&RelationTuple{}).TableName()) + const placeholders = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + const separator = ", " + q.Grow(len(rs) * (len(placeholders) + len(separator))) + args = make([]any, 0, 10*len(rs)) + + for i, r := range rs { + if r.Subject == nil { + return "", nil, errors.WithStack(ketoapi.ErrNilSubject) + } + + rt := &RelationTuple{ + ID: uuid.Must(uuid.NewV4()), + NetworkID: nid, + CommitTime: commitTime, + } + if err := rt.FromInternal(r); err != nil { + return "", nil, err + } + + if i > 0 { + q.WriteString(separator) + } + q.WriteString(placeholders) + args = append(args, rt.ID, rt.NetworkID, rt.Namespace, rt.Object, rt.Relation, rt.SubjectID, rt.SubjectSetNamespace, rt.SubjectSetObject, rt.SubjectSetRelation, rt.CommitTime) + } + + query = q.String() + return query, args, nil +} + func (p *Persister) WriteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples") + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples", + trace.WithAttributes(attribute.Int("count", len(rs)))) defer otelx.End(span, &err) + if len(rs) == 0 { + return nil + } + + commitTime := time.Now() + return p.Transaction(ctx, func(ctx context.Context) error { - for _, r := range rs { - if err := p.InsertRelationTuple(ctx, r); err != nil { + for chunk := range slices.Chunk(rs, chunkSizeInsertTuple) { + q, args, err := buildInsert(commitTime, p.NetworkID(ctx), chunk) + if err != nil { return err } + if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil { + return sqlcon.HandleError(err) + } } return nil }) @@ -278,6 +338,10 @@ func (p *Persister) TransactRelationTuples(ctx context.Context, ins []*relationt ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.TransactRelationTuples") defer otelx.End(span, &err) + if len(ins)+len(del) == 0 { + return nil + } + return p.Transaction(ctx, func(ctx context.Context) error { if err := p.WriteRelationTuples(ctx, ins...); err != nil { return err diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 9b48b1259..4cd429a41 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -7,6 +7,7 @@ import ( "context" "iter" "maps" + "slices" "strings" "github.com/gofrs/uuid" @@ -32,47 +33,6 @@ func (UUIDMapping) TableName() string { return "keto_uuid_mappings" } -func (p *Persister) batchToUUIDs(ctx context.Context, values []string, readOnly bool) (uuids []uuid.UUID, err error) { - if len(values) == 0 { - return - } - - uuids = make([]uuid.UUID, len(values)) - placeholderArray := make([]string, len(values)) - args := make([]interface{}, 0, len(values)*2) - for i, val := range values { - uuids[i] = uuid.NewV5(p.NetworkID(ctx), val) - placeholderArray[i] = "(?, ?)" - args = append(args, uuids[i], val) - } - placeholders := strings.Join(placeholderArray, ", ") - - p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") - - if !readOnly { - // We need to write manual SQL here because the INSERT should not fail if - // the UUID already exists, but we still want to return an error if anything - // else goes wrong. - var query string - switch d := p.Connection(ctx).Dialect.Name(); d { - case "mysql": - query = ` - INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ` + placeholders - default: - query = ` - INSERT INTO keto_uuid_mappings (id, string_representation) - VALUES ` + placeholders + ` - ON CONFLICT (id) DO NOTHING` - } - - return uuids, sqlcon.HandleError( - p.Connection(ctx).RawQuery(query, args...).Exec(), - ) - } else { - return uuids, nil - } -} - func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { if len(ids) == 0 { return @@ -128,18 +88,52 @@ func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts .. return } -func (p *Persister) MapStringsToUUIDs(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { +func (p *Persister) MapStringsToUUIDs(ctx context.Context, values ...string) (uuids []uuid.UUID, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDs") defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, false) + if len(values) == 0 { + return + } + + uuids, err = p.MapStringsToUUIDsReadOnly(ctx, values...) + if err != nil { + return nil, err + } + + p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") + + mappings := make([]UUIDMapping, len(values)) + for i := range values { + mappings[i] = UUIDMapping{ + ID: uuids[i], + StringRepresentation: values[i], + } + } + + err = p.Transaction(ctx, func(ctx context.Context) error { + for chunk := range slices.Chunk(mappings, chunkSizeInsertUUIDMappings) { + query, args := buildInsertUUIDs(chunk, p.conn.Dialect.Name()) + if err := p.Connection(ctx).RawQuery(query, args...).Exec(); err != nil { + return sqlcon.HandleError(err) + } + } + return nil + }) + + return uuids, err } -func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") - defer otelx.End(span, &err) +func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, ss ...string) (uuids []uuid.UUID, err error) { + // This function doesn't talk to the database or do anything interesting, so we don't need to trace it. + // ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") + // defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, true) + uuids = make([]uuid.UUID, len(ss)) + for i := range ss { + uuids[i] = uuid.NewV5(p.NetworkID(ctx), ss[i]) + } + return uuids, nil } func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ []string, err error) { @@ -148,3 +142,39 @@ func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ [] return p.batchFromUUIDs(ctx, u) } + +func buildInsertUUIDs(values []UUIDMapping, dialect string) (query string, args []any) { + if len(values) == 0 { + return "", nil + } + + const placeholder = "(?,?)" + const separator = "," + + var q strings.Builder + args = make([]any, 0, len(values)*2) + + if dialect == "mysql" { + q.WriteString("INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ") + } else { + q.WriteString("INSERT INTO keto_uuid_mappings (id, string_representation) VALUES ") + } + + q.Grow(len(values)*(len(placeholder)+len(separator)) + 100) + + for i, val := range values { + if i > 0 { + q.WriteString(separator) + } + q.WriteString(placeholder) + args = append(args, val.ID, val.StringRepresentation) + } + + if dialect == "mysql" { + // nothing + } else { + q.WriteString(" ON CONFLICT (id) DO NOTHING") + } + + return q.String(), args +} diff --git a/internal/x/dbx/dsn_testutils.go b/internal/x/dbx/dsn_testutils.go index f13ba8dd1..af74f5a73 100644 --- a/internal/x/dbx/dsn_testutils.go +++ b/internal/x/dbx/dsn_testutils.go @@ -10,12 +10,12 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" "github.com/go-sql-driver/mysql" "github.com/gobuffalo/pop/v6" - "github.com/ory/x/sqlcon/dockertest" "github.com/stretchr/testify/require" "github.com/tidwall/sjson" ) @@ -110,17 +110,21 @@ func GetDSNs(t testing.TB, debugSqliteOnDisk bool) []*DsnT { var mysql, postgres, cockroach string testDB := dbName(t.Name()) - dockertest.Parallel([]func(){ - func() { - mysql = RunMySQL(t, testDB) - }, - func() { - postgres = RunPostgres(t, testDB) - }, - func() { - cockroach = RunCockroach(t, testDB) - }, - }) + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + postgres = RunPostgres(t, testDB) + }() + go func() { + defer wg.Done() + mysql = RunMySQL(t, testDB) + }() + go func() { + defer wg.Done() + cockroach = RunCockroach(t, testDB) + }() + wg.Wait() if mysql != "" { dsns = append(dsns, &DsnT{