Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2914 bsoncodec/bsonrw: eliminate encoding allocations #1323

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bson/bsoncodec/slice_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re
}

// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
if val.Type() == tD || val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)

dw, err := vw.WriteDocument()
Expand Down
73 changes: 24 additions & 49 deletions bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,14 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val
encoder := desc.encoder

var zero bool
rvInterface := rv.Interface()
if cz, ok := encoder.(CodecZeroer); ok {
zero = cz.IsTypeZero(rvInterface)
zero = cz.IsTypeZero(rv.Interface())
} else if rv.Kind() == reflect.Interface {
// isZero will not treat an interface rv as an interface, so we need to check for the
// zero interface separately.
zero = rv.IsNil()
} else {
zero = isZero(rvInterface, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
}
if desc.omitEmpty && zero {
continue
Expand Down Expand Up @@ -392,56 +391,32 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val
return nil
}

func isZero(i interface{}, omitZeroStruct bool) bool {
v := reflect.ValueOf(i)

// check the value validity
if !v.IsValid() {
return true
func isZero(v reflect.Value, omitZeroStruct bool) bool {
kind := v.Kind()
if (kind != reflect.Pointer || !v.IsNil()) && v.Type().Implements(tZeroer) {
return v.Interface().(Zeroer).IsZero()
}

if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
return z.IsZero()
}

switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
if kind == reflect.Struct {
if !omitZeroStruct {
return false
}

// TODO(GODRIVER-2820): Update the logic to be able to handle private struct fields.
// TODO Use condition "reflect.Zero(v.Type()).Equal(v)" instead.

vt := v.Type()
if vt == tTime {
return v.Interface().(time.Time).IsZero()
}
for i := 0; i < v.NumField(); i++ {
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
numField := vt.NumField()
for i := 0; i < numField; i++ {
ff := vt.Field(i)
if ff.PkgPath != "" && !ff.Anonymous {
continue // Private field
}
fld := v.Field(i)
if !isZero(fld.Interface(), omitZeroStruct) {
if !isZero(v.Field(i), omitZeroStruct) {
return false
}
}
return true
}

return false
return !v.IsValid() || v.IsZero()
}

type structDescription struct {
Expand Down Expand Up @@ -708,21 +683,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {

// DeepZero returns recursive zero object
func deepZero(st reflect.Type) (result reflect.Value) {
result = reflect.Indirect(reflect.New(st))

if result.Kind() == reflect.Struct {
for i := 0; i < result.NumField(); i++ {
if f := result.Field(i); f.Kind() == reflect.Ptr {
if f.CanInterface() {
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
}
if st.Kind() == reflect.Struct {
numField := st.NumField()
for i := 0; i < numField; i++ {
if result == emptyValue {
result = reflect.Indirect(reflect.New(st))
}
f := result.Field(i)
if f.CanInterface() {
if f.Type().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
}
}
}
}

return
return result
}

// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
Expand Down
3 changes: 2 additions & 1 deletion bson/bsoncodec/struct_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsoncodec

import (
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -147,7 +148,7 @@ func TestIsZero(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
t.Parallel()

got := isZero(tc.value, tc.omitZeroStruct)
got := isZero(reflect.ValueOf(tc.value), tc.omitZeroStruct)
assert.Equal(t, tc.want, got, "expected and actual isZero return are different")
})
}
Expand Down
1 change: 1 addition & 0 deletions bson/bsoncodec/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()
var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem()

var tBinary = reflect.TypeOf(primitive.Binary{})
var tUndefined = reflect.TypeOf(primitive.Undefined{})
Expand Down
6 changes: 3 additions & 3 deletions bson/bsonrw/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error)
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

vw.reset(dst)

Expand All @@ -213,7 +213,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

vw.reset(dst)

Expand Down Expand Up @@ -258,7 +258,7 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []
}

vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
defer putValueWriter(vw)

start := len(dst)

Expand Down
12 changes: 10 additions & 2 deletions bson/bsonrw/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
return nil, ErrEOA
}

_, err = vr.readCString()
if err != nil {
if err := vr.consumeCString(); err != nil {
return nil, err
}

Expand Down Expand Up @@ -794,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) {
return vr.d[vr.offset-1], nil
}

func (vr *valueReader) consumeCString() error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider the slightly more descriptive name skipCString.

idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return io.EOF
}
vr.offset += int64(idx) + 1
return nil
}

func (vr *valueReader) readCString() (string, error) {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
Expand Down
63 changes: 39 additions & 24 deletions bson/bsonrw/value_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ var vwPool = sync.Pool{
},
}

func putValueWriter(vw *valueWriter) {
if vw != nil {
vw.w = nil // don't leak the writer
vwPool.Put(vw)
}
}

// BSONValueWriterPool is a pool for BSON ValueWriters.
//
// Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0.
Expand Down Expand Up @@ -149,32 +156,21 @@ type valueWriter struct {
}

func (vw *valueWriter) advanceFrame() {
if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
length := len(vw.stack)
if length+1 >= cap(vw.stack) {
// double it
buf := make([]vwState, 2*cap(vw.stack)+1)
copy(buf, vw.stack)
vw.stack = buf
}
vw.stack = vw.stack[:length+1]
}
vw.frame++
if vw.frame >= int64(len(vw.stack)) {
vw.stack = append(vw.stack, vwState{})
}
}

func (vw *valueWriter) push(m mode) {
vw.advanceFrame()

// Clean the stack
vw.stack[vw.frame].mode = m
vw.stack[vw.frame].key = ""
vw.stack[vw.frame].arrkey = 0
vw.stack[vw.frame].start = 0
vw.stack[vw.frame] = vwState{mode: m}

vw.stack[vw.frame].mode = m
switch m {
case mDocument, mArray, mCodeWithScope:
vw.reserveLength()
vw.reserveLength() // WARN: this is not needed
}
}

Expand Down Expand Up @@ -213,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter {
return vw
}

// TODO: only used in tests
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
Expand Down Expand Up @@ -249,17 +246,17 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod
}

func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
switch vw.stack[vw.frame].mode {
frame := &vw.stack[vw.frame]
switch frame.mode {
case mElement:
key := vw.stack[vw.frame].key
key := frame.key
if !isValidCString(key) {
return errors.New("BSON element key cannot contain null bytes")
}

vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
vw.appendHeader(t, key)
case mValue:
// TODO: Do this with a cache of the first 1000 or so array keys.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: This TODO is unnecessary because strconv already does this for us. Consider removing it.

Suggested change
// TODO: Do this with a cache of the first 1000 or so array keys.

vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
vw.appendIntHeader(t, frame.arrkey)
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
Expand Down Expand Up @@ -601,9 +598,11 @@ func (vw *valueWriter) writeLength() error {
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
length = length - int(vw.stack[vw.frame].start)
start := vw.stack[vw.frame].start
frame := &vw.stack[vw.frame]
length = length - int(frame.start)
start := frame.start

_ = vw.buf[start+3] // BCE
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
Expand All @@ -612,5 +611,21 @@ func (vw *valueWriter) writeLength() error {
}

func isValidCString(cs string) bool {
return !strings.ContainsRune(cs, '\x00')
return strings.IndexByte(cs, 0) == -1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider adding a comment about why we can check individual bytes rather than runes.

Suggested change
return strings.IndexByte(cs, 0) == -1
// Disallow the zero byte in a cstring because the zero byte is used as the
// terminating character. It's safe to check bytes instead of runes because
// all multibyte UTF-8 code points start with "11xxxxxx" or "10xxxxxx", so
// "00000000" will never be part of a multibyte UTF-8 code point.
return strings.IndexByte(cs, 0) == -1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Real reason is that '\x00' is a byte (it equals 0 not 00000000) and since it is less that MaxRune (0x80) it just calls IndexByte - this change just omits an unnecessary function call (and one that can't currently be inlined because its complexity score exceeds what is allowed for inlining).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're both saying the same thing. 0x80 (hex) == 10000000 (binary). I can update the comment with a reference to the corresponding code in strings.ContainsRune for additional clarity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I thought you meant 10000000 as hex or a 32bit rune here since that's what IndexRune takes.

}

// appendHeader is the same as bsoncore.AppendHeader but does not check if the
// key is a valid C string since the caller has already checked for that.
//
// The caller of this function must check if key is a valid C string.
func (vw *valueWriter) appendHeader(t bsontype.Type, key string) {
vw.buf = bsoncore.AppendType(vw.buf, t)
vw.buf = append(vw.buf, key...)
vw.buf = append(vw.buf, 0x00)
}

func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) {
vw.buf = bsoncore.AppendType(vw.buf, t)
vw.buf = strconv.AppendInt(vw.buf, int64(key), 10)
vw.buf = append(vw.buf, 0x00)
}
15 changes: 12 additions & 3 deletions bson/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package bson
import (
"bytes"
"encoding/json"
"sync"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
Expand Down Expand Up @@ -141,6 +142,13 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
}

// Pool of buffers for marshalling BSON.
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}

// MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the
// bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be
// transformed into a document, MarshalValueAppendWithContext should be used instead.
Expand All @@ -162,8 +170,9 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{
//
// See [Encoder] for more examples.
func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) {
sw := new(bsonrw.SliceWriter)
*sw = dst
sw := bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(sw)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check the size and utilization of of sw and only recycle it if it's not too large and not underutilized.

E.g.

defer func() {
	if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
		bufPool.Put(sw)
	}
}()

See the similar code in Operation.Execute for details:

defer func() {
// Proper usage of a sync.Pool requires each entry to have approximately the same memory
// cost. To obtain this property when the stored type contains a variably-sized buffer,
// we add a hard limit on the maximum buffer to place back in the pool. We limit the
// size to 16MiB because that's the maximum wire message size supported by MongoDB.
//
// Comment copied from https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
//
// Recycle byte slices that are smaller than 16MiB and at least half occupied.
if c := cap(*wm); c < 16*1024*1024 && c/2 < len(*wm) {
memoryPool.Put(wm)
}
}()

sw.Reset()
vw := bvwPool.Get(sw)
defer bvwPool.Put(vw)

Expand All @@ -184,7 +193,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf
return nil, err
}

return *sw, nil
return append(dst, sw.Bytes()...), nil
}

// MarshalValue returns the BSON encoding of val.
Expand Down
Loading