diff --git a/bson/bson_test.go b/bson/bson_test.go index dcfc1037d9..78fd4986c5 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -297,6 +297,178 @@ func TestD(t *testing.T) { }) } +func TestD_MarshalJSON(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + test D + expected interface{} + }{ + { + "nil", + nil, + nil, + }, + { + "empty", + D{}, + struct{}{}, + }, + { + "non-empty", + D{ + {"a", 42}, + {"b", true}, + {"c", "answer"}, + {"d", nil}, + {"e", 2.71828}, + {"f", A{42, true, "answer", nil, 2.71828}}, + {"g", D{{"foo", "bar"}}}, + }, + struct { + A int `json:"a"` + B bool `json:"b"` + C string `json:"c"` + D interface{} `json:"d"` + E float32 `json:"e"` + F []interface{} `json:"f"` + G map[string]interface{} `json:"g"` + }{ + A: 42, + B: true, + C: "answer", + D: nil, + E: 2.71828, + F: []interface{}{42, true, "answer", nil, 2.71828}, + G: map[string]interface{}{"foo": "bar"}, + }, + }, + } + for _, tc := range testcases { + tc := tc + t.Run("json.Marshal "+tc.name, func(t *testing.T) { + t.Parallel() + + got, err := json.Marshal(tc.test) + assert.NoError(t, err) + want, _ := json.Marshal(tc.expected) + assert.Equal(t, want, got) + }) + } + for _, tc := range testcases { + tc := tc + t.Run("json.MarshalIndent "+tc.name, func(t *testing.T) { + t.Parallel() + + got, err := json.MarshalIndent(tc.test, "", "") + assert.NoError(t, err) + want, _ := json.MarshalIndent(tc.expected, "", "") + assert.Equal(t, want, got) + }) + } +} + +func TestD_UnmarshalJSON(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + test []byte + expected D + }{ + { + "nil", + []byte(`null`), + nil, + }, + { + "empty", + []byte(`{}`), + D{}, + }, + { + "non-empty", + []byte(`{"hello":"world","pi":3.142,"boolean":true,"nothing":null,"list":["hello world",3.142,false,null,{"Lorem":"ipsum"}],"document":{"foo":"bar"}}`), + D{ + {"hello", "world"}, + {"pi", 3.142}, + {"boolean", true}, + {"nothing", nil}, + {"list", []interface{}{"hello world", 3.142, false, nil, D{{"Lorem", "ipsum"}}}}, + {"document", D{{"foo", "bar"}}}, + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var got D + err := json.Unmarshal(tc.test, &got) + assert.NoError(t, err) + assert.Equal(t, tc.expected, got) + }) + } + }) + + t.Run("failure", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + test string + }{ + { + "illegal", + `nil`, + }, + { + "invalid", + `{"pi": 3.142ipsum}`, + }, + { + "malformatted", + `{"pi", 3.142}`, + }, + { + "truncated", + `{"pi": 3.142`, + }, + { + "array type", + `["pi", 3.142]`, + }, + { + "boolean type", + `true`, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var a map[string]interface{} + want := json.Unmarshal([]byte(tc.test), &a) + var b D + got := json.Unmarshal([]byte(tc.test), &b) + switch w := want.(type) { + case *json.UnmarshalTypeError: + w.Type = reflect.TypeOf(b) + require.IsType(t, want, got) + g := got.(*json.UnmarshalTypeError) + assert.Equal(t, w, g) + default: + assert.Equal(t, want, got) + } + }) + } + }) +} + type stringerString string func (ss stringerString) String() string { diff --git a/bson/primitive.go b/bson/primitive.go index ef57fc26d1..281d233553 100644 --- a/bson/primitive.go +++ b/bson/primitive.go @@ -13,6 +13,7 @@ import ( "bytes" "encoding/json" "fmt" + "reflect" "time" ) @@ -216,6 +217,55 @@ func (d D) Map() M { return m } +// MarshalJSON encodes D into JSON. +func (d D) MarshalJSON() ([]byte, error) { + if d == nil { + return json.Marshal(nil) + } + var err error + var buf bytes.Buffer + buf.Write([]byte("{")) + enc := json.NewEncoder(&buf) + for i, e := range d { + err = enc.Encode(e.Key) + if err != nil { + return nil, err + } + buf.Write([]byte(":")) + err = enc.Encode(e.Value) + if err != nil { + return nil, err + } + if i < len(d)-1 { + buf.Write([]byte(",")) + } + } + buf.Write([]byte("}")) + return json.RawMessage(buf.Bytes()).MarshalJSON() +} + +// UnmarshalJSON decodes D from JSON. +func (d *D) UnmarshalJSON(b []byte) error { + dec := json.NewDecoder(bytes.NewReader(b)) + t, err := dec.Token() + if err != nil { + return err + } + if t == nil { + *d = nil + return nil + } + if v, ok := t.(json.Delim); !ok || v != '{' { + return &json.UnmarshalTypeError{ + Value: tokenString(t), + Type: reflect.TypeOf(D(nil)), + Offset: dec.InputOffset(), + } + } + *d, err = jsonDecodeD(dec) + return err +} + // E represents a BSON element for a D. It is usually used inside a D. type E struct { Key string @@ -237,3 +287,97 @@ type M map[string]interface{} // // bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} type A []interface{} + +func jsonDecodeD(dec *json.Decoder) (D, error) { + res := D{} + for { + var e E + + t, err := dec.Token() + if err != nil { + return nil, err + } + key, ok := t.(string) + if !ok { + break + } + e.Key = key + + t, err = dec.Token() + if err != nil { + return nil, err + } + switch v := t.(type) { + case json.Delim: + switch v { + case '[': + e.Value, err = jsonDecodeSlice(dec) + if err != nil { + return nil, err + } + case '{': + e.Value, err = jsonDecodeD(dec) + if err != nil { + return nil, err + } + } + default: + e.Value = t + } + + res = append(res, e) + } + return res, nil +} + +func jsonDecodeSlice(dec *json.Decoder) ([]interface{}, error) { + var res []interface{} + done := false + for !done { + t, err := dec.Token() + if err != nil { + return nil, err + } + switch v := t.(type) { + case json.Delim: + switch v { + case '[': + a, err := jsonDecodeSlice(dec) + if err != nil { + return nil, err + } + res = append(res, a) + case '{': + d, err := jsonDecodeD(dec) + if err != nil { + return nil, err + } + res = append(res, d) + default: + done = true + } + default: + res = append(res, t) + } + } + return res, nil +} + +func tokenString(t json.Token) string { + switch v := t.(type) { + case json.Delim: + switch v { + case '{': + return "object" + case '[': + return "array" + } + case bool: + return "bool" + case float64: + return "number" + case json.Number, string: + return "string" + } + return "unknown" +}