Skip to content

Commit

Permalink
GODRIVER-2914 bsoncodec/bsonrw: eliminate encoding allocations (#1323)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
Co-authored-by: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 8, 2023
1 parent 7f8b1d0 commit cbe8aa4
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 84 deletions.
2 changes: 1 addition & 1 deletion bson/bsoncodec/slice_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re
}

// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
if val.Type() == tD || val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)

dw, err := vw.WriteDocument()
Expand Down
73 changes: 24 additions & 49 deletions bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,14 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val
encoder := desc.encoder

var zero bool
rvInterface := rv.Interface()
if cz, ok := encoder.(CodecZeroer); ok {
zero = cz.IsTypeZero(rvInterface)
zero = cz.IsTypeZero(rv.Interface())
} else if rv.Kind() == reflect.Interface {
// isZero will not treat an interface rv as an interface, so we need to check for the
// zero interface separately.
zero = rv.IsNil()
} else {
zero = isZero(rvInterface, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
}
if desc.omitEmpty && zero {
continue
Expand Down Expand Up @@ -392,56 +391,32 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val
return nil
}

func isZero(i interface{}, omitZeroStruct bool) bool {
v := reflect.ValueOf(i)

// check the value validity
if !v.IsValid() {
return true
func isZero(v reflect.Value, omitZeroStruct bool) bool {
kind := v.Kind()
if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
return v.Interface().(Zeroer).IsZero()
}

if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
return z.IsZero()
}

switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
if kind == reflect.Struct {
if !omitZeroStruct {
return false
}

// TODO(GODRIVER-2820): Update the logic to be able to handle private struct fields.
// TODO Use condition "reflect.Zero(v.Type()).Equal(v)" instead.

vt := v.Type()
if vt == tTime {
return v.Interface().(time.Time).IsZero()
}
for i := 0; i < v.NumField(); i++ {
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
numField := vt.NumField()
for i := 0; i < numField; i++ {
ff := vt.Field(i)
if ff.PkgPath != "" && !ff.Anonymous {
continue // Private field
}
fld := v.Field(i)
if !isZero(fld.Interface(), omitZeroStruct) {
if !isZero(v.Field(i), omitZeroStruct) {
return false
}
}
return true
}

return false
return !v.IsValid() || v.IsZero()
}

type structDescription struct {
Expand Down Expand Up @@ -708,21 +683,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {

// DeepZero returns recursive zero object
func deepZero(st reflect.Type) (result reflect.Value) {
result = reflect.Indirect(reflect.New(st))

if result.Kind() == reflect.Struct {
for i := 0; i < result.NumField(); i++ {
if f := result.Field(i); f.Kind() == reflect.Ptr {
if f.CanInterface() {
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
}
if st.Kind() == reflect.Struct {
numField := st.NumField()
for i := 0; i < numField; i++ {
if result == emptyValue {
result = reflect.Indirect(reflect.New(st))
}
f := result.Field(i)
if f.CanInterface() {
if f.Type().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
}
}
}
}

return
return result
}

// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
Expand Down
3 changes: 2 additions & 1 deletion bson/bsoncodec/struct_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsoncodec

import (
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -147,7 +148,7 @@ func TestIsZero(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

got := isZero(tc.value, tc.omitZeroStruct)
got := isZero(reflect.ValueOf(tc.value), tc.omitZeroStruct)
assert.Equal(t, tc.want, got, "expected and actual isZero return are different")
})
}
Expand Down
1 change: 1 addition & 0 deletions bson/bsoncodec/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()
var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem()

var tBinary = reflect.TypeOf(primitive.Binary{})
var tUndefined = reflect.TypeOf(primitive.Undefined{})
Expand Down
6 changes: 3 additions & 3 deletions bson/bsonrw/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error)
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

vw.reset(dst)

Expand All @@ -213,7 +213,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

vw.reset(dst)

Expand Down Expand Up @@ -258,7 +258,7 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

start := len(dst)

Expand Down
12 changes: 10 additions & 2 deletions bson/bsonrw/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
return nil, ErrEOA
}

_, err = vr.readCString()
if err != nil {
if err := vr.skipCString(); err != nil {
return nil, err
}

Expand Down Expand Up @@ -794,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) {
return vr.d[vr.offset-1], nil
}

// skipCString moves the read offset just past the next null-terminated C
// string without building a Go string for it. It returns io.EOF when the
// unread data contains no terminating null byte, leaving the offset unchanged.
func (vr *valueReader) skipCString() error {
	end := bytes.IndexByte(vr.d[vr.offset:], 0x00)
	if end >= 0 {
		// The extra byte consumes the null terminator itself.
		vr.offset += int64(end) + 1
		return nil
	}
	return io.EOF
}

func (vr *valueReader) readCString() (string, error) {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
Expand Down
74 changes: 49 additions & 25 deletions bson/bsonrw/value_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ var vwPool = sync.Pool{
},
}

// putValueWriter hands vw back to vwPool. The writer reference held in vw.w is
// cleared first so the pooled value does not keep it reachable. A nil vw is
// ignored.
func putValueWriter(vw *valueWriter) {
	if vw == nil {
		return
	}
	vw.w = nil // don't leak the writer
	vwPool.Put(vw)
}

// BSONValueWriterPool is a pool for BSON ValueWriters.
//
// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
Expand Down Expand Up @@ -149,32 +156,21 @@ type valueWriter struct {
}

// advanceFrame moves the writer to the next stack frame. When the new frame
// index is past the end of the stack, one zero-valued vwState is appended; the
// built-in append takes care of growing the backing array, so no manual
// doubling is needed.
func (vw *valueWriter) advanceFrame() {
	vw.frame++
	if vw.frame >= int64(len(vw.stack)) {
		vw.stack = append(vw.stack, vwState{})
	}
}

// push advances to a new stack frame and initializes it for mode m. For
// document, array and code-with-scope modes it also calls reserveLength once.
func (vw *valueWriter) push(m mode) {
	vw.advanceFrame()

	// Clean the stack: a single assignment sets the mode and returns every
	// other field of the frame to its zero value, so nothing from an earlier
	// use of this slot survives.
	vw.stack[vw.frame] = vwState{mode: m}

	switch m {
	case mDocument, mArray, mCodeWithScope:
		// NOTE(review): the existing comment below marks this call as
		// unnecessary; confirm against reserveLength's callers before removing.
		vw.reserveLength() // WARN: this is not needed
	}
}

Expand Down Expand Up @@ -213,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter {
return vw
}

// TODO: only used in tests
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
Expand Down Expand Up @@ -249,17 +246,16 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod
}

func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
switch vw.stack[vw.frame].mode {
frame := &vw.stack[vw.frame]
switch frame.mode {
case mElement:
key := vw.stack[vw.frame].key
key := frame.key
if !isValidCString(key) {
return errors.New("BSON element key cannot contain null bytes")
}

vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
vw.appendHeader(t, key)
case mValue:
// TODO: Do this with a cache of the first 1000 or so array keys.
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
vw.appendIntHeader(t, frame.arrkey)
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
Expand Down Expand Up @@ -601,9 +597,11 @@ func (vw *valueWriter) writeLength() error {
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
length = length - int(vw.stack[vw.frame].start)
start := vw.stack[vw.frame].start
frame := &vw.stack[vw.frame]
length = length - int(frame.start)
start := frame.start

_ = vw.buf[start+3] // BCE
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
Expand All @@ -612,5 +610,31 @@ func (vw *valueWriter) writeLength() error {
}

// isValidCString reports whether cs can be used as a cstring, that is, whether
// it contains no zero byte.
func isValidCString(cs string) bool {
	// Disallow the zero byte in a cstring because the zero byte is used as the
	// terminating character.
	//
	// It's safe to check bytes instead of runes because all multibyte UTF-8
	// code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e.
	// 0) will never be part of a multibyte UTF-8 code point. This logic is the
	// same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be
	// inlined.
	//
	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127
	return strings.IndexByte(cs, 0) == -1
}

// appendHeader appends an element header to vw.buf: the element type t, the
// bytes of key, and a terminating null byte. It mirrors bsoncore.AppendHeader
// but skips the C-string validity check on key.
//
// Callers must already have verified that key is a valid C string.
func (vw *valueWriter) appendHeader(t bsontype.Type, key string) {
	buf := bsoncore.AppendType(vw.buf, t)
	buf = append(buf, key...)
	vw.buf = append(buf, 0x00)
}

// appendIntHeader appends an element header whose key is the base-10 text of
// key: the element type t, the decimal digits, and a terminating null byte.
// The digits are written straight into vw.buf with strconv.AppendInt, so no
// intermediate string is built.
func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) {
	buf := bsoncore.AppendType(vw.buf, t)
	buf = strconv.AppendInt(buf, int64(key), 10)
	vw.buf = append(buf, 0x00)
}
32 changes: 29 additions & 3 deletions bson/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package bson
import (
"bytes"
"encoding/json"
"sync"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
Expand Down Expand Up @@ -141,6 +142,13 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
}

// bufPool recycles *bytes.Buffer values used while marshalling BSON. New
// supplies an empty buffer whenever the pool has none to hand out.
var bufPool = sync.Pool{
	New: func() interface{} { return &bytes.Buffer{} },
}

// MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the
// bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be
// transformed into a document, MarshalValueAppendWithContext should be used instead.
Expand All @@ -162,8 +170,26 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
//
// See [Encoder] for more examples.
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
sw := new(bsonrw.SliceWriter)
*sw = dst
sw := bufPool.Get().(*bytes.Buffer)
defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum
// buffer to place back in the pool. We limit the size to 16MiB because
// that's the maximum wire message size supported by any current MongoDB
// server.
//
// Comment based on
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
//
// Recycle byte slices that are smaller than 16MiB and at least half
// occupied.
if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
bufPool.Put(sw)
}
}()

sw.Reset()
vw := bvwPool.Get(sw)
defer bvwPool.Put(vw)

Expand All @@ -184,7 +210,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf
return nil, err
}

return *sw, nil
return append(dst, sw.Bytes()...), nil
}

// MarshalValue returns the BSON encoding of val.
Expand Down

0 comments on commit cbe8aa4

Please sign in to comment.