diff --git a/internal/persistence/sql/query_test.go b/internal/persistence/sql/query_test.go new file mode 100644 index 000000000..f27acba18 --- /dev/null +++ b/internal/persistence/sql/query_test.go @@ -0,0 +1,49 @@ +package sql + +import ( + "testing" + + "github.com/ory/x/uuidx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/relationtuple" +) + +func TestBuildDelete(t *testing.T) { + t.Parallel() + nid := uuidx.NewV4() + + q, args, err := buildDelete(nid, nil) + assert.Error(t, err) + assert.Empty(t, q) + assert.Empty(t, args) + + obj1, obj2, sub1, obj3 := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + + q, args, err = buildDelete(nid, []*relationtuple.RelationTuple{ + { + Namespace: "ns1", + Object: obj1, + Relation: "rel1", + Subject: &relationtuple.SubjectID{ + ID: sub1, + }, + }, + { + Namespace: "ns2", + Object: obj2, + Relation: "rel2", + Subject: &relationtuple.SubjectSet{ + Namespace: "ns3", + Object: obj3, + Relation: "rel3", + }, + }, + }) + require.NoError(t, err) + + // parentheses are important here + assert.Equal(t, q, "DELETE FROM keto_relation_tuples WHERE ((namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL) OR (namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)) AND nid = ?") + assert.Equal(t, []any{"ns1", obj1, "rel1", sub1, "ns2", obj2, "rel2", "ns3", obj3, "rel3", nid}, args) +} diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 9d82bee31..4b92bb22c 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -6,18 +6,22 @@ package sql import ( "context" "database/sql" + "fmt" + "slices" + "strings" "time" - "github.com/ory/keto/ketoapi" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" "github.com/pkg/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) type ( @@ -165,25 +169,53 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu return nil } +func buildDelete(nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) { + if len(rs) == 0 { + return "", nil, errors.WithStack(ketoapi.ErrMalformedInput) + } + + args = make([]any, 0, 4*len(rs)+1) + ors := make([]string, 0, len(rs)) + for _, rt := range rs { + switch s := rt.Subject.(type) { + case *relationtuple.SubjectID: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.ID) + case *relationtuple.SubjectSet: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.Namespace, s.Object, s.Relation) + case nil: + return "", nil, errors.WithStack(ketoapi.ErrNilSubject) + } + } + + query = fmt.Sprintf("DELETE FROM %s WHERE (%s) AND nid = ?", (&RelationTuple{}).TableName(), strings.Join(ors, " OR ")) + args = append(args, nid) + return query, args, nil +} + func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples") + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples", + trace.WithAttributes(attribute.Int("count", len(rs)))) defer otelx.End(span, &err) + if len(rs) == 0 { + return nil + } + return p.Transaction(ctx, func(ctx context.Context) error { - for _, r := range rs { - q := p.queryWithNetwork(ctx). - Where("namespace = ?", r.Namespace). - Where("object = ?", r.Object). - Where("relation = ?", r.Relation) - if err := p.whereSubject(ctx, q, r.Subject); err != nil { + for chunk := range slices.Chunk(rs, 500) { + q, args, err := buildDelete(p.NetworkID(ctx), chunk) + if err != nil { return err } - - if err := q.Delete(&RelationTuple{}); err != nil { - return err + if q == "" { + continue + } + if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil { + return sqlcon.HandleError(err) } } - return nil }) }