From 4d75ba0dfb99fb7cde83ec5fa71439bd16f2626a Mon Sep 17 00:00:00 2001 From: Islam Aleiv Date: Thu, 19 Sep 2024 16:55:13 +0200 Subject: [PATCH] Fix query with filter on 2 relation fields of 1 composite index --- internal/planner/multi.go | 14 ++- internal/planner/scan.go | 29 +++-- internal/planner/type_join.go | 58 +++++----- ...th_unique_index_on_relation_filter_test.go | 103 ++++++++++++++++++ 4 files changed, 162 insertions(+), 42 deletions(-) diff --git a/internal/planner/multi.go b/internal/planner/multi.go index ac564c4ed1..170239a833 100644 --- a/internal/planner/multi.go +++ b/internal/planner/multi.go @@ -136,12 +136,22 @@ func (p *parallelNode) nextMerge(_ int, plan planNode) (bool, error) { return false, err } - doc := plan.Value() - copy(p.currentValue.Fields, doc.Fields) + p.currentValue = p.mergeDoc(p.currentValue, plan.Value().Fields) return true, nil } +func (p *parallelNode) mergeDoc(doc core.Doc, newFields core.DocFields) core.Doc { + for i := range newFields { + if doc.Fields[i] == nil { + doc.Fields[i] = newFields[i] + } else if newSubDoc, ok := newFields[i].(core.Doc); ok { + doc.Fields[i] = p.mergeDoc(doc.Fields[i].(core.Doc), newSubDoc.Fields) + } + } + return doc +} + func (p *parallelNode) nextAppend(index int, plan planNode) (bool, error) { key := p.currentValue.GetID() if key == "" { diff --git a/internal/planner/scan.go b/internal/planner/scan.go index 151705a698..019cd1dee2 100644 --- a/internal/planner/scan.go +++ b/internal/planner/scan.go @@ -92,10 +92,10 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { switch requestable := r.(type) { // field is simple as its just a base level field case *mapper.Field: - n.tryAddField(requestable.GetName()) + n.tryAddFieldWithName(requestable.GetName()) // select might have its own select fields and filters fields case *mapper.Select: - n.tryAddField(requestable.Field.Name + request.RelatedObjectID) // foreign key for type joins + n.tryAddFieldWithName(requestable.Field.Name + request.RelatedObjectID) // foreign key for type joins err := n.initFields(requestable.Fields) if err != nil { return err @@ -112,13 +112,13 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { return err } for _, fd := range fieldDescs { - n.tryAddField(fd.Name) + n.tryAddFieldWithName(fd.Name) } } if target.ChildTarget.HasValue { - n.tryAddField(target.ChildTarget.Name) + n.tryAddFieldWithName(target.ChildTarget.Name) } else { - n.tryAddField(target.Field.Name) + n.tryAddFieldWithName(target.Field.Name) } } } @@ -126,7 +126,7 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { return nil } -func (n *scanNode) tryAddField(fieldName string) bool { +func (n *scanNode) tryAddFieldWithName(fieldName string) bool { fd, ok := n.col.Definition().GetFieldByName(fieldName) if !ok { // skip fields that are not part of the @@ -134,10 +134,25 @@ func (n *scanNode) tryAddField(fieldName string) bool { // is only responsible for basic fields return false } - n.fields = append(n.fields, fd) + n.addField(fd) return true } +// addField adds a field to the list of fields to be fetched. +// It will not add the field if it is already in the list. +func (n *scanNode) addField(field client.FieldDefinition) { + found := false + for i := range n.fields { + if n.fields[i].Name == field.Name { + found = true + break + } + } + if !found { + n.fields = append(n.fields, field) + } +} + func (scan *scanNode) initFetcher( cid immutable.Option[string], index immutable.Option[client.IndexDescription], diff --git a/internal/planner/type_join.go b/internal/planner/type_join.go index 2102c74479..fc5eb9bbaf 100644 --- a/internal/planner/type_join.go +++ b/internal/planner/type_join.go @@ -531,43 +531,35 @@ func newPrimaryObjectsRetriever( return j } -func (j *primaryObjectsRetriever) retrievePrimaryDocsReferencingSecondaryDoc() error { - relIDFieldDef, ok := j.primarySide.col.Definition().GetFieldByName( - j.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) +func (r *primaryObjectsRetriever) retrievePrimaryDocsReferencingSecondaryDoc() error { + relIDFieldDef, ok := r.primarySide.col.Definition().GetFieldByName( + r.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) if !ok { - return client.NewErrFieldNotExist(j.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) + return client.NewErrFieldNotExist(r.primarySide.relFieldDef.Value().Name + request.RelatedObjectID) } - j.primaryScan = getScanNode(j.primarySide.plan) + r.primaryScan = getScanNode(r.primarySide.plan) - j.relIDFieldDef = relIDFieldDef + r.relIDFieldDef = relIDFieldDef - primaryDocs, err := j.retrievePrimaryDocs() + primaryDocs, err := r.retrievePrimaryDocs() if err != nil { return err } - j.resultPrimaryDocs, j.resultSecondaryDoc = joinPrimaryDocs(primaryDocs, j.secondarySide, j.primarySide) + r.resultPrimaryDocs, r.resultSecondaryDoc = joinPrimaryDocs(primaryDocs, r.secondarySide, r.primarySide) return nil } -func (j *primaryObjectsRetriever) addIDFieldToScanner() { - found := false - for i := range j.primaryScan.fields { - if j.primaryScan.fields[i].Name == j.relIDFieldDef.Name { - found = true - break - } - } - if !found { - j.primaryScan.fields = append(j.primaryScan.fields, j.relIDFieldDef) +func (r *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { + p := r.primarySide.plan + // If the primary side is a multiScanNode, we need to get the source node, as we are the only + // consumer (one, not multiple) of it. + if multiScan, ok := p.(*multiScanNode); ok { + p = multiScan.Source() } -} - -func (j *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { - p := j.primarySide.plan if err := p.Init(); err != nil { return nil, NewErrSubTypeInit(err) } @@ -591,28 +583,28 @@ func (j *primaryObjectsRetriever) collectDocs(numDocs int) ([]core.Doc, error) { return docs, nil } -func (j *primaryObjectsRetriever) retrievePrimaryDocs() ([]core.Doc, error) { - j.addIDFieldToScanner() +func (r *primaryObjectsRetriever) retrievePrimaryDocs() ([]core.Doc, error) { + r.primaryScan.addField(r.relIDFieldDef) - secondaryDoc := j.secondarySide.plan.Value() - addFilterOnIDField(j.primaryScan, j.primarySide.relIDFieldMapIndex.Value(), secondaryDoc.GetID()) + secondaryDoc := r.secondarySide.plan.Value() + addFilterOnIDField(r.primaryScan, r.primarySide.relIDFieldMapIndex.Value(), secondaryDoc.GetID()) - oldFetcher := j.primaryScan.fetcher + oldFetcher := r.primaryScan.fetcher - indexOnRelation := findIndexByFieldName(j.primaryScan.col, j.relIDFieldDef.Name) - j.primaryScan.initFetcher(immutable.None[string](), indexOnRelation) + indexOnRelation := findIndexByFieldName(r.primaryScan.col, r.relIDFieldDef.Name) + r.primaryScan.initFetcher(immutable.None[string](), indexOnRelation) - docs, err := j.collectDocs(0) + docs, err := r.collectDocs(0) if err != nil { return nil, err } - err = j.primaryScan.fetcher.Close() + err = r.primaryScan.fetcher.Close() if err != nil { return nil, err } - j.primaryScan.fetcher = oldFetcher + r.primaryScan.fetcher = oldFetcher return docs, nil } @@ -780,7 +772,7 @@ func (join *invertibleTypeJoin) invertJoinDirectionWithIndex( ) error { p := join.childSide.plan s := getScanNode(p) - s.tryAddField(join.childSide.relFieldDef.Value().Name + request.RelatedObjectID) + s.tryAddFieldWithName(join.childSide.relFieldDef.Value().Name + request.RelatedObjectID) s.filter = fieldFilter s.initFetcher(immutable.Option[string]{}, immutable.Some(index)) diff --git a/tests/integration/index/query_with_unique_index_on_relation_filter_test.go b/tests/integration/index/query_with_unique_index_on_relation_filter_test.go index 05c4b05395..f9bc98577a 100644 --- a/tests/integration/index/query_with_unique_index_on_relation_filter_test.go +++ b/tests/integration/index/query_with_unique_index_on_relation_filter_test.go @@ -70,3 +70,106 @@ func TestQueryWithUniqueCompositeIndex_WithFilterOnIndexedRelation_ShouldFilter( testUtils.ExecuteTestCase(t, test) } + +func TestQueryWithUniqueCompositeIndex_WithIndexComprising2RelationsAndFilterOnIt_ShouldFilter(t *testing.T) { + test := testUtils.TestCase{ + Description: "Test if we can filter on a unique composite index comprising at least 2 relations and filter on them", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type User { + name: String @index(unique: true) + devices: [Device] + } + + type Manufacturer { + name: String + devices: [Device] + } + + type Device @index(unique: true, includes: [{name: "owner_id"}, {name: "manufacturer_id"}, {name: "model"}]) { + owner: User + manufacturer: Manufacturer + model: String + } + `, + }, + testUtils.CreateDoc{ + CollectionID: 0, + DocMap: map[string]any{ + "name": "John", + }, + }, + testUtils.CreateDoc{ + CollectionID: 1, + DocMap: map[string]any{ + "name": "Apple", + }, + }, + testUtils.CreateDoc{ + CollectionID: 2, + DocMap: map[string]any{ + "model": "iPhone", + "owner_id": testUtils.NewDocIndex(0, 0), + "manufacturer_id": testUtils.NewDocIndex(1, 0), + }, + }, + testUtils.CreateDoc{ + CollectionID: 2, + DocMap: map[string]any{ + "model": "MacBook", + "owner_id": testUtils.NewDocIndex(0, 0), + "manufacturer_id": testUtils.NewDocIndex(1, 0), + }, + }, + testUtils.Request{ + Request: `query { + byUserId: Device (filter: { + manufacturer_id: {_eq: "bae-18c7d707-c44d-552f-b6d6-9e3d05bbf9c1"}, + owner_id: {_eq: "bae-1ef746f8-821e-586f-99b2-4cb1fb9b782f"} + }) { + owner { + name + } + } + byUserName: Device (filter: { + manufacturer_id: {_eq: "bae-18c7d707-c44d-552f-b6d6-9e3d05bbf9c1"}, + owner: {name: {_eq: "John"}} + }) { + owner { + name + } + } + }`, + Results: map[string]any{ + "byUserId": []map[string]any{ + { + "owner": map[string]any{ + "name": "John", + }, + }, + { + "owner": map[string]any{ + "name": "John", + }, + }, + }, + "byUserName": []map[string]any{ + { + "owner": map[string]any{ + "name": "John", + }, + }, + { + "owner": map[string]any{ + "name": "John", + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +}