Skip to content

Commit

Permalink
GODRIVER-3326 Add ObjectIDAsHexString to BSONOptions. (#1791)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu authored Sep 9, 2024
1 parent e80251d commit ec59e09
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 20 deletions.
9 changes: 3 additions & 6 deletions internal/codecutil/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (e MarshalError) Error() string {
}

// EncoderFn is used to functionally construct an encoder for marshaling values.
type EncoderFn func(io.Writer) (*bson.Encoder, error)
type EncoderFn func(io.Writer) *bson.Encoder

// MarshalValue will attempt to encode the value with the encoder returned by
// the encoder function.
Expand All @@ -49,14 +49,11 @@ func MarshalValue(val interface{}, encFn EncoderFn) (bsoncore.Value, error) {

buf := new(bytes.Buffer)

enc, err := encFn(buf)
if err != nil {
return bsoncore.Value{}, err
}
enc := encFn(buf)

// Encode the value in a single-element document with an empty key. Use
// bsoncore to extract the first element and return the BSON value.
err = enc.Encode(bson.D{{Key: "", Value: val}})
err := enc.Encode(bson.D{{Key: "", Value: val}})
if err != nil {
return bsoncore.Value{}, MarshalError{Value: val, Err: err}
}
Expand Down
6 changes: 2 additions & 4 deletions internal/codecutil/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ import (
func testEncFn(t *testing.T) EncoderFn {
t.Helper()

return func(w io.Writer) (*bson.Encoder, error) {
return func(w io.Writer) *bson.Encoder {
rw := bson.NewDocumentWriter(w)
enc := bson.NewEncoder(rw)

return enc, nil
return bson.NewEncoder(rw)
}
}

Expand Down
175 changes: 175 additions & 0 deletions internal/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ func (e *negateCodec) DecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val
return nil
}

type intKey int

func (i intKey) MarshalKey() (string, error) {
return fmt.Sprintf("key_%d", i), nil
}

var _ options.ContextDialer = &slowConnDialer{}

// A slowConnDialer dials connections that delay network round trips by the given delay duration.
Expand Down Expand Up @@ -724,6 +730,20 @@ func TestClient_BSONOptions(t *testing.T) {
C string `json:"y" bson:"3"`
}

type omitemptyTest struct {
X jsonTagsTest `bson:"x,omitempty"`
}

type truncatingDoublesTest struct {
X int
}

type timeZoneTest struct {
X time.Time
}

timestamp, _ := time.Parse(time.RFC3339, "2006-01-02T15:04:05+07:00")

testCases := []struct {
name string
bsonOpts *options.BSONOptions
Expand Down Expand Up @@ -766,6 +786,88 @@ func TestClient_BSONOptions(t *testing.T) {
AppendInt32("x", 1).
Build()),
},
{
name: "NilMapAsEmpty",
bsonOpts: &options.BSONOptions{
NilMapAsEmpty: true,
},
doc: bson.D{{Key: "x", Value: map[string]string(nil)}},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{Key: "x", Value: bson.D{}}},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendDocument("x", bsoncore.NewDocumentBuilder().Build()).
Build()),
},
{
name: "NilSliceAsEmpty",
bsonOpts: &options.BSONOptions{
NilSliceAsEmpty: true,
},
doc: bson.D{{Key: "x", Value: []int(nil)}},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{Key: "x", Value: bson.A{}}},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendArray("x", bsoncore.NewDocumentBuilder().Build()).
Build()),
},
{
name: "NilByteSliceAsEmpty",
bsonOpts: &options.BSONOptions{
NilByteSliceAsEmpty: true,
},
doc: bson.D{{Key: "x", Value: []byte(nil)}},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{Key: "x", Value: bson.Binary{Data: []byte{}}}},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendBinary("x", 0, nil).
Build()),
},
{
name: "OmitZeroStruct",
bsonOpts: &options.BSONOptions{
OmitZeroStruct: true,
},
doc: omitemptyTest{},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().Build()),
},
{
name: "StringifyMapKeysWithFmt",
bsonOpts: &options.BSONOptions{
StringifyMapKeysWithFmt: true,
},
doc: map[intKey]string{intKey(42): "foo"},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{"42", "foo"}},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendString("42", "foo").
Build()),
},
{
name: "AllowTruncatingDoubles",
bsonOpts: &options.BSONOptions{
AllowTruncatingDoubles: true,
},
doc: bson.D{{Key: "x", Value: 3.14}},
decodeInto: func() interface{} { return &truncatingDoublesTest{} },
want: &truncatingDoublesTest{3},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendDouble("x", 3.14).
Build()),
},
{
name: "BinaryAsSlice",
bsonOpts: &options.BSONOptions{
BinaryAsSlice: true,
},
doc: bson.D{{Key: "x", Value: []byte{42}}},
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{Key: "x", Value: []byte{42}}},
wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
AppendBinary("x", 0, []byte{42}).
Build()),
},
{
name: "DefaultDocumentM",
bsonOpts: &options.BSONOptions{
Expand All @@ -775,6 +877,50 @@ func TestClient_BSONOptions(t *testing.T) {
decodeInto: func() interface{} { return &bson.D{} },
want: &bson.D{{Key: "doc", Value: bson.M{"a": int64(1)}}},
},
{
name: "UseLocalTimeZone",
bsonOpts: &options.BSONOptions{
UseLocalTimeZone: true,
},
doc: bson.D{{Key: "x", Value: timestamp}},
decodeInto: func() interface{} { return &timeZoneTest{} },
want: &timeZoneTest{timestamp.In(time.Local)},
},
{
name: "ZeroMaps",
bsonOpts: &options.BSONOptions{
ZeroMaps: true,
},
doc: bson.D{{"a", "apple"}, {"b", "banana"}},
decodeInto: func() interface{} {
return &map[string]string{
"b": "berry",
"c": "carrot",
}
},
want: &map[string]string{
"a": "apple",
"b": "banana",
},
},
{
name: "ZeroStructs",
bsonOpts: &options.BSONOptions{
ZeroStructs: true,
},
doc: bson.D{{"a", "apple"}, {"x", "broccoli"}},
decodeInto: func() interface{} {
return &jsonTagsTest{
B: "banana",
C: "carrot",
}
},
want: &jsonTagsTest{
A: "apple",
B: "broccoli",
C: "",
},
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -809,6 +955,35 @@ func TestClient_BSONOptions(t *testing.T) {
}

opts := mtest.NewOptions().ClientOptions(
options.Client().SetBSONOptions(&options.BSONOptions{
ObjectIDAsHexString: true,
}))
mt.RunOpts("ObjectIDAsHexString", opts, func(mt *mtest.T) {
res, err := mt.Coll.InsertOne(context.Background(), bson.D{{"x", 42}})
require.NoError(mt, err, "InsertOne error")

sr := mt.Coll.FindOne(
context.Background(),
bson.D{{Key: "_id", Value: res.InsertedID}},
)

type data struct {
ID string `bson:"_id"`
X int `bson:"x"`
}
var got data

err = sr.Decode(&got)
require.NoError(mt, err, "Decode error")

want := data{
ID: res.InsertedID.(bson.ObjectID).Hex(),
X: 42,
}
assert.Equal(mt, want, got, "expected and actual decoded result are different")
})

opts = mtest.NewOptions().ClientOptions(
options.Client().SetBSONOptions(&options.BSONOptions{
ErrorOnInlineDuplicates: true,
}))
Expand Down
3 changes: 3 additions & 0 deletions mongo/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ func getDecoder(
if opts.DefaultDocumentM {
dec.DefaultDocumentM()
}
if opts.ObjectIDAsHexString {
dec.ObjectIDAsHexString()
}
if opts.UseJSONStructTags {
dec.UseJSONStructTags()
}
Expand Down
33 changes: 33 additions & 0 deletions mongo/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package mongo
import (
"context"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -264,6 +265,38 @@ func TestNewCursorFromDocuments(t *testing.T) {
})
}

func TestGetDecoder(t *testing.T) {
t.Parallel()

decT := reflect.TypeOf((*bson.Decoder)(nil))
ctxT := reflect.TypeOf(bson.DecodeContext{})
for i := 0; i < decT.NumMethod(); i++ {
m := decT.Method(i)
// Test methods with no input/output parameter.
if m.Type.NumIn() != 1 || m.Type.NumOut() != 0 {
continue
}
t.Run(m.Name, func(t *testing.T) {
var opts options.BSONOptions
optsV := reflect.ValueOf(&opts).Elem()
f, ok := optsV.Type().FieldByName(m.Name)
require.True(t, ok, "expected %s field in %s", m.Name, optsV.Type())

wantDec := reflect.ValueOf(bson.NewDecoder(nil))
_ = wantDec.Method(i).Call(nil)
wantCtx := wantDec.Elem().Field(0)
require.Equal(t, ctxT, wantCtx.Type())

optsV.FieldByIndex(f.Index).SetBool(true)
gotDec := getDecoder(nil, &opts, nil)
gotCtx := reflect.ValueOf(gotDec).Elem().Field(0)
require.Equal(t, ctxT, gotCtx.Type())

assert.True(t, gotCtx.Equal(wantCtx), "expected %v: %v, got: %v", ctxT, wantCtx, gotCtx)
})
}
}

func BenchmarkNewCursorFromDocuments(b *testing.B) {
// Prepare sample data
documents := []interface{}{
Expand Down
14 changes: 5 additions & 9 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func getEncoder(
w io.Writer,
opts *options.BSONOptions,
reg *bson.Registry,
) (*bson.Encoder, error) {
) *bson.Encoder {
vw := bson.NewDocumentWriter(w)
enc := bson.NewEncoder(vw)

Expand Down Expand Up @@ -95,13 +95,13 @@ func getEncoder(
enc.SetRegistry(reg)
}

return enc, nil
return enc
}

// newEncoderFn will return a function for constructing an encoder based on the
// provided codec options.
func newEncoderFn(opts *options.BSONOptions, registry *bson.Registry) codecutil.EncoderFn {
return func(w io.Writer) (*bson.Encoder, error) {
return func(w io.Writer) *bson.Encoder {
return getEncoder(w, opts, registry)
}
}
Expand All @@ -128,12 +128,8 @@ func marshal(
}

buf := new(bytes.Buffer)
enc, err := getEncoder(buf, bsonOpts, registry)
if err != nil {
return nil, fmt.Errorf("error configuring BSON encoder: %w", err)
}

err = enc.Encode(val)
enc := getEncoder(buf, bsonOpts, registry)
err := enc.Encode(val)
if err != nil {
return nil, MarshalError{Value: val, Err: err}
}
Expand Down
33 changes: 33 additions & 0 deletions mongo/mongo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package mongo
import (
"errors"
"fmt"
"reflect"
"testing"

"go.mongodb.org/mongo-driver/v2/bson"
Expand Down Expand Up @@ -608,6 +609,38 @@ func TestMarshalValue(t *testing.T) {
}
}

func TestGetEncoder(t *testing.T) {
t.Parallel()

encT := reflect.TypeOf((*bson.Encoder)(nil))
ctxT := reflect.TypeOf(bson.EncodeContext{})
for i := 0; i < encT.NumMethod(); i++ {
m := encT.Method(i)
// Test methods with no input/output parameter.
if m.Type.NumIn() != 1 || m.Type.NumOut() != 0 {
continue
}
t.Run(m.Name, func(t *testing.T) {
var opts options.BSONOptions
optsV := reflect.ValueOf(&opts).Elem()
f, ok := optsV.Type().FieldByName(m.Name)
require.True(t, ok, "expected %s field in %s", m.Name, optsV.Type())

wantEnc := reflect.ValueOf(bson.NewEncoder(nil))
_ = wantEnc.Method(i).Call(nil)
wantCtx := wantEnc.Elem().Field(0)
require.Equal(t, ctxT, wantCtx.Type())

optsV.FieldByIndex(f.Index).SetBool(true)
gotEnc := getEncoder(nil, &opts, nil)
gotCtx := reflect.ValueOf(gotEnc).Elem().Field(0)
require.Equal(t, ctxT, gotCtx.Type())

assert.True(t, gotCtx.Equal(wantCtx), "expected %v: %v, got: %v", ctxT, wantCtx, gotCtx)
})
}
}

var _ bson.ValueMarshaler = bvMarsh{}

type bvMarsh struct {
Expand Down
Loading

0 comments on commit ec59e09

Please sign in to comment.