From 7ee2919bb330629c5c4dd423976d17ff139a4ce9 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Fri, 14 Jul 2023 22:04:48 -0400 Subject: [PATCH] bson: improve marshal/unmarshal performance by ~58% and ~29% MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit improves BSON marshaling performance by ~58% and unmarshaling performance by ~29% by replacing the mutex-based decoder/encoder caches with sync.Map, which can often avoid locking and is ideally suited for caches that only grow. The commit also adds the BenchmarkCodeMarshal and BenchmarkCodeUnmarshal benchmarks from the Go standard library's encoding/json package since they do an excellent job of stress testing parallel encoding/decoding (a common use case in a database driver) and are how the lock contention that led to this commit was discovered. ``` goos: darwin goarch: arm64 pkg: go.mongodb.org/mongo-driver/bson │ base.20.txt │ new.20.txt │ │ sec/op │ sec/op vs base │ CodeUnmarshal/BSON-10 3.192m ± 1% 2.246m ± 1% -29.64% (p=0.000 n=20) CodeUnmarshal/JSON-10 2.735m ± 1% 2.737m ± 0% ~ (p=0.640 n=20) CodeMarshal/BSON-10 2.972m ± 0% 1.221m ± 3% -58.93% (p=0.000 n=20) CodeMarshal/JSON-10 471.0µ ± 1% 464.6µ ± 0% -1.36% (p=0.000 n=20) geomean 1.870m 1.366m -26.92% │ base.20.txt │ new.20.txt │ │ B/s │ B/s vs base │ CodeUnmarshal/BSON-10 579.7Mi ± 1% 823.9Mi ± 1% +42.13% (p=0.000 n=20) CodeUnmarshal/JSON-10 676.6Mi ± 1% 676.2Mi ± 0% ~ (p=0.640 n=20) CodeMarshal/BSON-10 622.7Mi ± 0% 1516.2Mi ± 3% +143.46% (p=0.000 n=20) CodeMarshal/JSON-10 3.837Gi ± 1% 3.890Gi ± 0% +1.38% (p=0.000 n=20) geomean 989.8Mi 1.323Gi +36.84% │ base.20.txt │ new.20.txt │ │ B/op │ B/op vs base │ CodeUnmarshal/BSON-10 4.219Mi ± 0% 4.219Mi ± 0% ~ (p=0.077 n=20) CodeUnmarshal/JSON-10 2.904Mi ± 0% 2.904Mi ± 0% ~ (p=0.672 n=20) CodeMarshal/BSON-10 2.821Mi ± 1% 2.776Mi ± 2% -1.59% (p=0.023 n=20) CodeMarshal/JSON-10 1.857Mi ± 0% 1.859Mi ± 0% ~ (p=0.331 n=20) geomean 2.830Mi 2.820Mi -0.37% │ 
base.20.txt │ new.20.txt │ │ allocs/op │ allocs/op vs base │ CodeUnmarshal/BSON-10 230.4k ± 0% 230.4k ± 0% ~ (p=1.000 n=20) CodeUnmarshal/JSON-10 92.67k ± 0% 92.67k ± 0% ~ (p=1.000 n=20) ¹ CodeMarshal/BSON-10 94.07k ± 0% 94.07k ± 0% ~ (p=0.112 n=20) CodeMarshal/JSON-10 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹ geomean 6.694k 6.694k +0.00% ¹ all samples are equal ``` --- bson/bsoncodec/codec_cache.go | 157 ++++++++++++++++++++++++++ bson/bsoncodec/codec_cache_test.go | 174 +++++++++++++++++++++++++++++ bson/bsoncodec/pointer_codec.go | 59 ++++------ bson/bsoncodec/registry.go | 160 +++++++++----------------- bson/bsoncodec/registry_test.go | 118 +++++++++++++------ bson/bsoncodec/struct_codec.go | 32 ++++-- 6 files changed, 509 insertions(+), 191 deletions(-) create mode 100644 bson/bsoncodec/codec_cache.go create mode 100644 bson/bsoncodec/codec_cache_test.go diff --git a/bson/bsoncodec/codec_cache.go b/bson/bsoncodec/codec_cache.go new file mode 100644 index 0000000000..eb05e0b40d --- /dev/null +++ b/bson/bsoncodec/codec_cache.go @@ -0,0 +1,157 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + "sync" + "sync/atomic" +) + +// statically assert array size +var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer] +var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer] + +type typeEncoderCache struct { + cache sync.Map // map[reflect.Type]ValueEncoder +} + +func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueEncoder), true + } + return nil, false +} + +func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(ValueEncoder) + } + return enc +} + +func (c *typeEncoderCache) Clone() *typeEncoderCache { + cc := new(typeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +type typeDecoderCache struct { + cache sync.Map // map[reflect.Type]ValueDecoder +} + +func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) { + c.cache.Store(rt, dec) +} + +func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(ValueDecoder), true + } + return nil, false +} + +func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder { + if v, loaded := c.cache.LoadOrStore(rt, dec); loaded { + dec = v.(ValueDecoder) + } + return dec +} + +func (c *typeDecoderCache) Clone() *typeDecoderCache { + cc := new(typeDecoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueEncoder with a 
kindEncoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueEncoder interface). +type kindEncoderCacheEntry struct { + enc ValueEncoder +} + +type kindEncoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry +} + +func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) { + if enc != nil && rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc}) + } +} + +func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok { + return ent.enc, ent.enc != nil + } + } + return nil, false +} + +func (c *kindEncoderCache) Clone() *kindEncoderCache { + cc := new(kindEncoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} + +// atomic.Value requires that all calls to Store() have the same concrete type +// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type +// is always the same (since different concrete types may implement the +// ValueDecoder interface). 
+type kindDecoderCacheEntry struct { + dec ValueDecoder +} + +type kindDecoderCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry +} + +func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) { + if rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec}) + } +} + +func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok { + return ent.dec, ent.dec != nil + } + } + return nil, false +} + +func (c *kindDecoderCache) Clone() *kindDecoderCache { + cc := new(kindDecoderCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} diff --git a/bson/bsoncodec/codec_cache_test.go b/bson/bsoncodec/codec_cache_test.go new file mode 100644 index 0000000000..89d52bb961 --- /dev/null +++ b/bson/bsoncodec/codec_cache_test.go @@ -0,0 +1,174 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + "strconv" + "strings" + "testing" +) + +var codecCacheTestTypes = [16]reflect.Type{ + reflect.TypeOf(uint8(0)), + reflect.TypeOf(uint16(0)), + reflect.TypeOf(uint32(0)), + reflect.TypeOf(uint64(0)), + reflect.TypeOf(uint(0)), + reflect.TypeOf(uintptr(0)), + reflect.TypeOf(int8(0)), + reflect.TypeOf(int16(0)), + reflect.TypeOf(int32(0)), + reflect.TypeOf(int64(0)), + reflect.TypeOf(int(0)), + reflect.TypeOf(float32(0)), + reflect.TypeOf(float64(0)), + reflect.TypeOf(true), + reflect.TypeOf(struct{ A int }{}), + reflect.TypeOf(map[int]int{}), +} + +func TestTypeCache(t *testing.T) { + rt := reflect.TypeOf(int(0)) + ec := new(typeEncoderCache) + dc := new(typeDecoderCache) + + codec := new(fakeCodec) + ec.Store(rt, codec) + dc.Store(rt, codec) + if v, ok := ec.Load(rt); !ok || !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true) + } + if v, ok := dc.Load(rt); !ok || !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true) + } + + // Make sure we overwrite the stored value with nil + ec.Store(rt, nil) + dc.Store(rt, nil) + if v, ok := ec.Load(rt); ok || v != nil { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false) + } + if v, ok := dc.Load(rt); ok || v != nil { + t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false) + } +} + +func TestTypeCacheClone(t *testing.T) { + codec := new(fakeCodec) + ec1 := new(typeEncoderCache) + dc1 := new(typeDecoderCache) + for _, rt := range codecCacheTestTypes { + ec1.Store(rt, codec) + dc1.Store(rt, codec) + } + ec2 := ec1.Clone() + dc2 := dc1.Clone() + for _, rt := range codecCacheTestTypes { + if v, _ := ec2.Load(rt); !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec) + } + if v, _ := dc2.Load(rt); !reflect.DeepEqual(v, codec) { + t.Errorf("Load(%s) 
= %#v; want: %#v", rt, v, codec) + } + } +} + +func TestKindCacheArray(t *testing.T) { + // Check array bounds + var c kindEncoderCache + codec := new(fakeCodec) + c.Store(reflect.UnsafePointer, codec) // valid + c.Store(reflect.UnsafePointer+1, codec) // ignored + if v, ok := c.Load(reflect.UnsafePointer); !ok || v != codec { + t.Errorf("Load(reflect.UnsafePointer) = %v, %t; want: %v, %t", v, ok, codec, true) + } + if v, ok := c.Load(reflect.UnsafePointer + 1); ok || v != nil { + t.Errorf("Load(reflect.UnsafePointer + 1) = %v, %t; want: %v, %t", v, ok, nil, false) + } + + // Make sure that reflect.UnsafePointer is the last/largest reflect.Kind. + // + // The String() method of an invalid reflect.Kind returns a string of the format + // "kind{NUMBER}". + for rt := reflect.UnsafePointer + 1; rt < reflect.UnsafePointer+16; rt++ { + s := rt.String() + if !strings.Contains(s, strconv.Itoa(int(rt))) { + t.Errorf("reflect.Type(%d) appears to be valid: %q", rt, s) + } + } +} + +func TestKindCacheClone(t *testing.T) { + e1 := new(kindEncoderCache) + d1 := new(kindDecoderCache) + codec := new(fakeCodec) + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + e1.Store(k, codec) + d1.Store(k, codec) + } + e2 := e1.Clone() + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + v1, ok1 := e1.Load(k) + v2, ok2 := e2.Load(k) + if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil { + t.Errorf("Encoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2) + } + } + d2 := d1.Clone() + for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ { + v1, ok1 := d1.Load(k) + v2, ok2 := d2.Load(k) + if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil { + t.Errorf("Decoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2) + } + } +} + +func TestKindCacheEncoderNilEncoder(t *testing.T) { + t.Run("Encoder", func(t *testing.T) { + c := new(kindEncoderCache) + c.Store(reflect.Invalid, ValueEncoder(nil)) + v, ok := c.Load(reflect.Invalid) + if v != nil || ok { + 
t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok) + } + }) + t.Run("Decoder", func(t *testing.T) { + c := new(kindDecoderCache) + c.Store(reflect.Invalid, ValueDecoder(nil)) + v, ok := c.Load(reflect.Invalid) + if v != nil || ok { + t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok) + } + }) +} + +func BenchmarkEncoderCacheLoad(b *testing.B) { + c := new(typeEncoderCache) + codec := new(fakeCodec) + typs := codecCacheTestTypes + for _, t := range typs { + c.Store(t, codec) + } + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + c.Load(typs[i%len(typs)]) + } + }) +} + +func BenchmarkEncoderCacheStore(b *testing.B) { + c := new(typeEncoderCache) + codec := new(fakeCodec) + b.RunParallel(func(pb *testing.PB) { + typs := codecCacheTestTypes + for i := 0; pb.Next(); i++ { + c.Store(typs[i%len(typs)], codec) + } + }) +} diff --git a/bson/bsoncodec/pointer_codec.go b/bson/bsoncodec/pointer_codec.go index a1bf9c3e2b..e5923230b0 100644 --- a/bson/bsoncodec/pointer_codec.go +++ b/bson/bsoncodec/pointer_codec.go @@ -8,7 +8,6 @@ package bsoncodec import ( "reflect" - "sync" "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" @@ -22,9 +21,8 @@ var _ ValueDecoder = &PointerCodec{} // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // PointerCodec registered. type PointerCodec struct { - ecache map[reflect.Type]ValueEncoder - dcache map[reflect.Type]ValueDecoder - l sync.RWMutex + ecache typeEncoderCache + dcache typeDecoderCache } // NewPointerCodec returns a PointerCodec that has been initialized. @@ -32,10 +30,7 @@ type PointerCodec struct { // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // PointerCodec registered. 
func NewPointerCodec() *PointerCodec { - return &PointerCodec{ - ecache: make(map[reflect.Type]ValueEncoder), - dcache: make(map[reflect.Type]ValueDecoder), - } + return &PointerCodec{} } // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil @@ -52,24 +47,19 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val return vw.WriteNull() } - pc.l.RLock() - enc, ok := pc.ecache[val.Type()] - pc.l.RUnlock() - if ok { - if enc == nil { - return ErrNoEncoder{Type: val.Type()} + typ := val.Type() + if v, ok := pc.ecache.Load(typ); ok { + if v == nil { + return ErrNoEncoder{Type: typ} } - return enc.EncodeValue(ec, vw, val.Elem()) + return v.EncodeValue(ec, vw, val.Elem()) } - - enc, err := ec.LookupEncoder(val.Type().Elem()) - pc.l.Lock() - pc.ecache[val.Type()] = enc - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + enc, err := ec.LookupEncoder(typ.Elem()) + enc = pc.ecache.LoadOrStore(typ, enc) if err != nil { return err } - return enc.EncodeValue(ec, vw, val.Elem()) } @@ -80,36 +70,31 @@ func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } + typ := val.Type() if vr.Type() == bsontype.Null { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadNull() } if vr.Type() == bsontype.Undefined { - val.Set(reflect.Zero(val.Type())) + val.Set(reflect.Zero(typ)) return vr.ReadUndefined() } if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) + val.Set(reflect.New(typ.Elem())) } - pc.l.RLock() - dec, ok := pc.dcache[val.Type()] - pc.l.RUnlock() - if ok { - if dec == nil { - return ErrNoDecoder{Type: val.Type()} + if v, ok := pc.dcache.Load(typ); ok { + if v == nil { + return ErrNoDecoder{Type: typ} } - return dec.DecodeValue(dc, vr, val.Elem()) + return v.DecodeValue(dc, vr, val.Elem()) } - - dec, err := 
dc.LookupDecoder(val.Type().Elem()) - pc.l.Lock() - pc.dcache[val.Type()] = dec - pc.l.Unlock() + // TODO(charlie): handle concurrent requests for the same type + dec, err := dc.LookupDecoder(typ.Elem()) + dec = pc.dcache.LoadOrStore(typ, dec) if err != nil { return err } - return dec.DecodeValue(dc, vr, val.Elem()) } diff --git a/bson/bsoncodec/registry.go b/bson/bsoncodec/registry.go index 930de28490..f309ee2b39 100644 --- a/bson/bsoncodec/registry.go +++ b/bson/bsoncodec/registry.go @@ -216,72 +216,42 @@ func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Typ // // Deprecated: Use NewRegistry instead. func (rb *RegistryBuilder) Build() *Registry { - registry := new(Registry) - - registry.typeEncoders = make(map[reflect.Type]ValueEncoder, len(rb.registry.typeEncoders)) - for t, enc := range rb.registry.typeEncoders { - registry.typeEncoders[t] = enc - } - - registry.typeDecoders = make(map[reflect.Type]ValueDecoder, len(rb.registry.typeDecoders)) - for t, dec := range rb.registry.typeDecoders { - registry.typeDecoders[t] = dec - } - - registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.registry.interfaceEncoders)) - copy(registry.interfaceEncoders, rb.registry.interfaceEncoders) - - registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.registry.interfaceDecoders)) - copy(registry.interfaceDecoders, rb.registry.interfaceDecoders) - - registry.kindEncoders = make(map[reflect.Kind]ValueEncoder) - for kind, enc := range rb.registry.kindEncoders { - registry.kindEncoders[kind] = enc + r := &Registry{ + interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), + interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), + typeEncoders: rb.registry.typeEncoders.Clone(), + typeDecoders: rb.registry.typeDecoders.Clone(), + kindEncoders: rb.registry.kindEncoders.Clone(), + kindDecoders: rb.registry.kindDecoders.Clone(), } - - registry.kindDecoders = 
make(map[reflect.Kind]ValueDecoder) - for kind, dec := range rb.registry.kindDecoders { - registry.kindDecoders[kind] = dec - } - - registry.typeMap = make(map[bsontype.Type]reflect.Type) - for bt, rt := range rb.registry.typeMap { - registry.typeMap[bt] = rt - } - - return registry + rb.registry.typeMap.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + r.typeMap.Store(k, v) + } + return true + }) + return r } // A Registry is used to store and retrieve codecs for types and interfaces. This type is the main // typed passed around and Encoders and Decoders are constructed from it. type Registry struct { - typeEncoders map[reflect.Type]ValueEncoder - typeDecoders map[reflect.Type]ValueDecoder - interfaceEncoders []interfaceValueEncoder interfaceDecoders []interfaceValueDecoder - - kindEncoders map[reflect.Kind]ValueEncoder - kindDecoders map[reflect.Kind]ValueDecoder - - typeMap map[bsontype.Type]reflect.Type - - mu sync.RWMutex + typeEncoders *typeEncoderCache + typeDecoders *typeDecoderCache + kindEncoders *kindEncoderCache + kindDecoders *kindDecoderCache + typeMap sync.Map // map[bsontype.Type]reflect.Type } // NewRegistry creates a new empty Registry. func NewRegistry() *Registry { return &Registry{ - typeEncoders: make(map[reflect.Type]ValueEncoder), - typeDecoders: make(map[reflect.Type]ValueDecoder), - - interfaceEncoders: make([]interfaceValueEncoder, 0), - interfaceDecoders: make([]interfaceValueDecoder, 0), - - kindEncoders: make(map[reflect.Kind]ValueEncoder), - kindDecoders: make(map[reflect.Kind]ValueDecoder), - - typeMap: make(map[bsontype.Type]reflect.Type), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), } } @@ -296,7 +266,7 @@ func NewRegistry() *Registry { // // RegisterTypeEncoder should not be called concurrently with any other Registry method. 
func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { - r.typeEncoders[valueType] = enc + r.typeEncoders.Store(valueType, enc) } // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. @@ -310,7 +280,7 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) // // RegisterTypeDecoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { - r.typeDecoders[valueType] = dec + r.typeDecoders.Store(valueType, dec) } // RegisterKindEncoder registers the provided ValueEncoder for the provided kind. @@ -326,7 +296,7 @@ func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { - r.kindEncoders[kind] = enc + r.kindEncoders.Store(kind, enc) } // RegisterKindDecoder registers the provided ValueDecoder for the provided kind. @@ -342,7 +312,7 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { // // RegisterKindDecoder should not be called concurrently with any other Registry method. func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { - r.kindDecoders[kind] = dec + r.kindDecoders.Store(kind, dec) } // RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will @@ -401,7 +371,7 @@ func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder // // reg.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{})) func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { - r.typeMap[bt] = rt + r.typeMap.Store(bt, rt) } // LookupEncoder returns the first matching encoder in the Registry. 
It uses the following lookup @@ -418,9 +388,7 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) { // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for // concurrent use by multiple goroutines after all codecs and encoders are registered. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { - r.mu.RLock() enc, found := r.lookupTypeEncoder(valueType) - r.mu.RUnlock() if found { if enc == nil { return nil, ErrNoEncoder{Type: valueType} @@ -430,36 +398,26 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { enc, found = r.lookupInterfaceEncoder(valueType, true) if found { - r.mu.Lock() - r.typeEncoders[valueType] = enc - r.mu.Unlock() - return enc, nil + return r.typeEncoders.LoadOrStore(valueType, enc), nil } - if valueType == nil { - r.mu.Lock() - r.typeEncoders[valueType] = nil - r.mu.Unlock() + r.storeTypeEncoder(valueType, nil) return nil, ErrNoEncoder{Type: valueType} } - enc, found = r.kindEncoders[valueType.Kind()] - if !found { - r.mu.Lock() - r.typeEncoders[valueType] = nil - r.mu.Unlock() - return nil, ErrNoEncoder{Type: valueType} + if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { + return r.storeTypeEncoder(valueType, v), nil } + r.storeTypeEncoder(valueType, nil) + return nil, ErrNoEncoder{Type: valueType} +} - r.mu.Lock() - r.typeEncoders[valueType] = enc - r.mu.Unlock() - return enc, nil +func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { + return r.typeEncoders.LoadOrStore(rt, enc) } -func (r *Registry) lookupTypeEncoder(valueType reflect.Type) (ValueEncoder, bool) { - enc, found := r.typeEncoders[valueType] - return enc, found +func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { + return r.typeEncoders.Load(rt) } func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { @@ -475,7 +433,7 @@ func (r *Registry) 
lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // ahead in interfaceEncoders defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc = r.kindEncoders[valueType.Kind()] + defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) } return newCondAddrEncoder(ienc.ve, defaultEnc), true } @@ -500,10 +458,7 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNilType } - decodererr := ErrNoDecoder{Type: valueType} - r.mu.RLock() dec, found := r.lookupTypeDecoder(valueType) - r.mu.RUnlock() if found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} @@ -513,29 +468,22 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { dec, found = r.lookupInterfaceDecoder(valueType, true) if found { - r.mu.Lock() - r.typeDecoders[valueType] = dec - r.mu.Unlock() - return dec, nil + return r.storeTypeDecoder(valueType, dec), nil } - dec, found = r.kindDecoders[valueType.Kind()] - if !found { - r.mu.Lock() - r.typeDecoders[valueType] = nil - r.mu.Unlock() - return nil, decodererr + if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { + return r.storeTypeDecoder(valueType, v), nil } - - r.mu.Lock() - r.typeDecoders[valueType] = dec - r.mu.Unlock() - return dec, nil + r.storeTypeDecoder(valueType, nil) + return nil, ErrNoDecoder{Type: valueType} } func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { - dec, found := r.typeDecoders[valueType] - return dec, found + return r.typeDecoders.Load(valueType) +} + +func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { + return r.typeDecoders.LoadOrStore(typ, dec) } func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { @@ -548,7 +496,7 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // ahead in interfaceDecoders defaultDec, found := 
r.lookupInterfaceDecoder(valueType, false) if !found { - defaultDec = r.kindDecoders[valueType.Kind()] + defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) } return newCondAddrDecoder(idec.vd, defaultDec), true } @@ -561,11 +509,11 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // // LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) { - t, ok := r.typeMap[bt] - if !ok || t == nil { + v, ok := r.typeMap.Load(bt) + if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return t, nil + return v.(reflect.Type), nil } type interfaceValueEncoder struct { diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index 9ed68ce566..d09f32be5e 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -65,7 +65,7 @@ func TestRegistryBuilder(t *testing.T) { got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c - gotC, exists := got[wantT] + gotC, exists := got.Load(wantT) if !exists { t.Errorf("Did not find type in the type registry: %v", wantT) } @@ -94,7 +94,7 @@ func TestRegistryBuilder(t *testing.T) { got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c - gotC, exists := got[wantK] + gotC, exists := got.Load(wantK) if !exists { t.Errorf("Did not find kind in the kind registry: %v", wantK) } @@ -111,14 +111,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Map, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Map] != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec) + if reg.kindEncoders.get(reflect.Map) != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } rb.RegisterDefaultEncoder(reflect.Map, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Map] != codec2 { - 
t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec2) + if reg.kindEncoders.get(reflect.Map) != codec2 { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -128,14 +128,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Struct, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Struct] != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec) + if reg.kindEncoders.get(reflect.Struct) != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } rb.RegisterDefaultEncoder(reflect.Struct, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Struct] != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec2) + if reg.kindEncoders.get(reflect.Struct) != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -145,14 +145,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Slice, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Slice] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec) + if reg.kindEncoders.get(reflect.Slice) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } rb.RegisterDefaultEncoder(reflect.Slice, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Slice] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec2) + if reg.kindEncoders.get(reflect.Slice) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } }) 
t.Run("ArrayCodec", func(t *testing.T) { @@ -162,14 +162,14 @@ func TestRegistryBuilder(t *testing.T) { rb.RegisterDefaultEncoder(reflect.Array, codec) reg := rb.Build() - if reg.kindEncoders[reflect.Array] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec) + if reg.kindEncoders.get(reflect.Array) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } rb.RegisterDefaultEncoder(reflect.Array, codec2) reg = rb.Build() - if reg.kindEncoders[reflect.Array] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec2) + if reg.kindEncoders.get(reflect.Array) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } }) }) @@ -485,7 +485,7 @@ func TestRegistry(t *testing.T) { got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c - gotC, exists := got[wantT] + gotC, exists := got.Load(wantT) if !exists { t.Errorf("type missing in registry: %v", wantT) } @@ -515,7 +515,7 @@ func TestRegistry(t *testing.T) { got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c - gotC, exists := got[wantK] + gotC, exists := got.Load(wantK) if !exists { t.Errorf("type missing in registry: %v", wantK) } @@ -534,12 +534,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders[reflect.Map] != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Map], codec) + if reg.kindEncoders.get(reflect.Map) != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders[reflect.Map] != codec2 { - t.Errorf("map codec properly set: got %#v, want %#v", 
reg.kindEncoders[reflect.Map], codec2) + if reg.kindEncoders.get(reflect.Map) != codec2 { + t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -549,12 +549,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders[reflect.Struct] != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec) + if reg.kindEncoders.get(reflect.Struct) != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders[reflect.Struct] != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Struct], codec2) + if reg.kindEncoders.get(reflect.Struct) != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -564,12 +564,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders[reflect.Slice] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec) + if reg.kindEncoders.get(reflect.Slice) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders[reflect.Slice] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Slice], codec2) + if reg.kindEncoders.get(reflect.Slice) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -579,12 
+579,12 @@ func TestRegistry(t *testing.T) { codec2 := &fakeCodec{num: 2} reg := NewRegistry() reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders[reflect.Array] != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec) + if reg.kindEncoders.get(reflect.Array) != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders[reflect.Array] != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders[reflect.Array], codec2) + if reg.kindEncoders.get(reflect.Array) != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } }) }) @@ -860,6 +860,52 @@ func TestRegistry(t *testing.T) { }) } +// get is only for testing as it does not return whether the value was found +func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { + e, _ := c.Load(rt) + return e +} + +func BenchmarkLookupEncoder(b *testing.B) { + type childStruct struct { + V1, V2, V3, V4 int + } + type nestedStruct struct { + childStruct + A struct{ C1, C2, C3, C4 childStruct } + B struct{ C1, C2, C3, C4 childStruct } + C struct{ M1, M2, M3, M4 map[int]int } + } + types := [...]reflect.Type{ + reflect.TypeOf(int64(1)), + reflect.TypeOf(&fakeCodec{}), + reflect.TypeOf(&testInterface1Impl{}), + reflect.TypeOf(&nestedStruct{}), + } + r := NewRegistry() + for _, typ := range types { + r.RegisterTypeEncoder(typ, &fakeCodec{}) + } + b.Run("Serial", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := r.LookupEncoder(types[i%len(types)]) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run("Parallel", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + _, err := r.LookupEncoder(types[i%len(types)]) + if err != nil { + b.Fatal(err) + } + } + }) + }) +} + type fakeType1 struct{} type
fakeType2 struct{} type fakeType4 struct{} diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index 1dfdd98865..29ea76d19c 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -63,8 +63,7 @@ type Zeroer interface { // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the // StructCodec registered. type StructCodec struct { - cache map[reflect.Type]*structDescription - l sync.RWMutex + cache sync.Map // map[reflect.Type]*structDescription parser StructTagParser // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the @@ -115,7 +114,6 @@ func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) structOpt := bsonoptions.MergeStructCodecOptions(opts...) codec := &StructCodec{ - cache: make(map[reflect.Type]*structDescription), parser: p, } @@ -502,13 +500,27 @@ func (sc *StructCodec) describeStruct( ) (*structDescription, error) { // We need to analyze the struct, including getting the tags, collecting // information about inlining, and create a map of the field name to the field. - sc.l.RLock() - ds, exists := sc.cache[t] - sc.l.RUnlock() - if exists { - return ds, nil + if v, ok := sc.cache.Load(t); ok { + return v.(*structDescription), nil } + // TODO(charlie): Only describe the struct once when called + // concurrently with the same type. 
+ ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + if err != nil { + return nil, err + } + if v, loaded := sc.cache.LoadOrStore(t, ds); loaded { + ds = v.(*structDescription) + } + return ds, nil +} +func (sc *StructCodec) describeStructSlow( + r *Registry, + t reflect.Type, + useJSONStructTags bool, + errorOnDuplicates bool, +) (*structDescription, error) { numFields := t.NumField() sd := &structDescription{ fm: make(map[string]fieldDescription, numFields), @@ -639,10 +651,6 @@ func (sc *StructCodec) describeStruct( sort.Sort(byIndex(sd.fl)) - sc.l.Lock() - sc.cache[t] = sd - sc.l.Unlock() - return sd, nil }