diff --git a/docstore/drivertest/drivertest.go b/docstore/drivertest/drivertest.go index 940a57ac30..588eed13b8 100644 --- a/docstore/drivertest/drivertest.go +++ b/docstore/drivertest/drivertest.go @@ -1938,11 +1938,15 @@ func testExampleInDoc(t *testing.T, _ Harness, coll *docstore.Collection) { type Name struct { First, Last string } + type Publication struct { + Year int `docstore:"year"` + Publisher string `docstore:"publisher"` + } type Book struct { - Title string `docstore:"name"` - Author Name `docstore:"author"` - PublicationYears []int `docstore:"pub_years,omitempty"` - NumPublications int `docstore:"-"` + Title string `docstore:"name"` + Author Name `docstore:"author"` + Publications []Publication `docstore:"pubs,omitempty"` + Reviews int `docstore:"-"` } must := func(err error) { @@ -1968,8 +1972,11 @@ func testExampleInDoc(t *testing.T, _ Harness, coll *docstore.Collection) { First: "Mikhail", Last: "Bulgakov", }, - PublicationYears: []int{1967, 1973}, - NumPublications: 2, + Publications: []Publication{ + {Year: 1967, Publisher: "Moscow magazine "}, + {Year: 1973, Publisher: "YMCA Press"}, + }, + Reviews: 22950, } doc2 := map[string]interface{}{ @@ -1978,7 +1985,15 @@ func testExampleInDoc(t *testing.T, _ Harness, coll *docstore.Collection) { "First": "Mikhail", "Last": "Bulgakov", }, - "pub_years": []int{1968, 1987}, + "pubs": []map[string]interface{}{ + { + "year": 1968, + "publisher": "Harcourt Brace", + }, + { + "year": 1987, + }, + }, } ctx := context.Background() @@ -1987,16 +2002,22 @@ func testExampleInDoc(t *testing.T, _ Harness, coll *docstore.Collection) { got2 := &Book{Title: doc2[KeyField].(string)} must(coll.Actions().Get(got1).Get(got2).Do(ctx)) - if got1.NumPublications != 0 { + if got1.Reviews != 0 { t.Errorf("docstore:\"-\" tagged field isn't ignored") } checkFieldEqual(got1, doc1, "author") - checkFieldEqual(got2, doc2, "pub_years") + if len(got2.Publications) != len(doc2["pubs"].([]map[string]interface{})) { + t.Errorf("docstore: unexpected amount of pubs: %d/%d", len(got2.Publications), len(doc2["pubs"].([]map[string]interface{}))) + } gots := mustCollect(ctx, t, coll.Query().Where("author.Last", "=", "Bulgakov").Get(ctx)) if len(gots) != 2 { t.Errorf("got %v want all two results", gots) } + gots = mustCollect(ctx, t, coll.Query().Where("pubs.year", "=", 1967).Get(ctx)) + if len(gots) != 1 { + t.Errorf("got %v want The Heart of a Dog", gots) + } must(coll.Actions().Delete(doc1).Delete(doc2).Do(ctx)) } diff --git a/docstore/memdocstore/mem.go b/docstore/memdocstore/mem.go index a5470261b6..802e310b19 100644 --- a/docstore/memdocstore/mem.go +++ b/docstore/memdocstore/mem.go @@ -399,16 +399,31 @@ func (c *collection) checkRevision(arg driver.Document, current storedDoc) error // getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid // (see getParentMap). -func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) { - m2, err := getParentMap(m, fp, false) - if err != nil { - return nil, err +func getAtFieldPath(m map[string]interface{}, fp []string) (result interface{}, err error) { + + var get func(m interface{}, name string) interface{} + get = func(m interface{}, name string) interface{} { + switch concrete := m.(type) { + case map[string]interface{}: + return concrete[name] + case []interface{}: + result := []interface{}{} + for _, e := range concrete { + result = append(result, get(e, name)) + } + return result + } + return nil } - v, ok := m2[fp[len(fp)-1]] - if ok { - return v, nil + result = m + for _, k := range fp { + next := get(result, k) + if next == nil { + return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", strings.Join(fp, ".")) + } + result = next } - return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp) + return result, nil } // setAtFieldPath sets m's value at fp to val. It creates intermediate maps as @@ -422,14 +437,6 @@ func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) erro return nil } -// Delete the value from m at the given field path, if it exists. -func deleteAtFieldPath(m map[string]interface{}, fp []string) { - m2, _ := getParentMap(m, fp, false) // ignore error - if m2 != nil { - delete(m2, fp[len(fp)-1]) - } -} - // getParentMap returns the map that directly contains the given field path; // that is, the value of m at the field path that excludes the last component // of fp. If a non-map is encountered along the way, an InvalidArgument error is diff --git a/docstore/memdocstore/mem_test.go b/docstore/memdocstore/mem_test.go index 2129f62ef7..467ba945f7 100644 --- a/docstore/memdocstore/mem_test.go +++ b/docstore/memdocstore/mem_test.go @@ -16,6 +16,7 @@ package memdocstore import ( "context" + "io" "os" "path/filepath" "testing" @@ -129,6 +130,49 @@ func TestUpdateAtomic(t *testing.T) { } } +func TestQueryNested(t *testing.T) { + ctx := context.Background() + + count := func(iter *docstore.DocumentIterator) (c int) { + doc := docmap{} + for { + if err := iter.Next(ctx, doc); err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + c++ + } + return c + } + + dc, err := newCollection(drivertest.KeyField, nil, nil) + if err != nil { + t.Fatal(err) + } + coll := docstore.NewCollection(dc) + defer coll.Close() + + doc := docmap{drivertest.KeyField: "TestQueryNested", + "list": []any{docmap{"a": "A"}}, + "map": docmap{"b": "B"}, + dc.RevisionField(): nil, + } + if err := coll.Put(ctx, doc); err != nil { + t.Fatal(err) + } + + got := count(coll.Query().Where("list.a", "=", "A").Get(ctx)) + if got != 1 { + t.Errorf("got %v docs when filtering by list.a, want 1", got) + } + got = count(coll.Query().Where("map.b", "=", "B").Get(ctx)) + if got != 1 { + t.Errorf("got %v docs when filtering by map.b, want 1", got) + } +} + func TestSortDocs(t *testing.T) { newDocs := func() []storedDoc { return []storedDoc{ diff --git a/docstore/memdocstore/query.go b/docstore/memdocstore/query.go index 419017b993..e1407731a6 100644 --- a/docstore/memdocstore/query.go +++ b/docstore/memdocstore/query.go @@ -138,6 +138,19 @@ func compare(x1, x2 interface{}) (int, bool) { } return -1, true } + if v1.Kind() == reflect.Slice { + for i := 0; i < v1.Len(); i++ { + if c, ok := compare(x2, v1.Index(i).Interface()); ok { + if !ok { + return 0, false + } + if c == 0 { + return 0, true + } + } + } + return -1, true + } if v1.Kind() == reflect.String && v2.Kind() == reflect.String { return strings.Compare(v1.String(), v2.String()), true }