diff --git a/mergevalues.go b/mergevalues.go new file mode 100644 index 0000000..d3d0f53 --- /dev/null +++ b/mergevalues.go @@ -0,0 +1,123 @@ +package astjson + +import ( + "bytes" + "errors" +) + +var ( + ErrMergeDifferentTypes = errors.New("cannot merge different types") + ErrMergeDifferingArrayLengths = errors.New("cannot merge arrays of differing lengths") + ErrMergeUnknownType = errors.New("cannot merge unknown type") +) + +func MergeValues(a, b *Value) (v *Value, changed bool, err error) { + if a == nil { + return b, true, nil + } + if b == nil { + return a, false, nil + } + aBool, bBool := a.Type() == TypeTrue || a.Type() == TypeFalse, b.Type() == TypeTrue || b.Type() == TypeFalse + booleans := aBool && bBool + oneIsNull := a.Type() == TypeNull || b.Type() == TypeNull + if a.Type() != b.Type() && !booleans && !oneIsNull { + return nil, false, ErrMergeDifferentTypes + } + if b.Type() == TypeNull && a.Type() != TypeNull { + return b, true, nil + } + switch a.Type() { + case TypeObject: + ao, _ := a.Object() + bo, _ := b.Object() + ao.unescapeKeys() + bo.unescapeKeys() + for i := range bo.kvs { + k := bo.kvs[i].k + r := bo.kvs[i].v + l := ao.Get(k) + if l == nil { + ao.Set(k, r) + continue + } + n, changed, err := MergeValues(l, r) + if err != nil { + return nil, false, err + } + if changed { + ao.Set(k, n) + } + } + return a, false, nil + case TypeArray: + aa, _ := a.Array() + ba, _ := b.Array() + if len(aa) == 0 { + return b, true, nil + } + if len(ba) == 0 { + return a, false, nil + } + if len(aa) != len(ba) { + return nil, false, ErrMergeDifferingArrayLengths + } + for i := range aa { + n, changed, err := MergeValues(aa[i], ba[i]) + if err != nil { + return nil, false, err + } + if changed { + aa[i] = n + } + } + return a, false, nil + case TypeFalse: + if b.Type() == TypeTrue { + return b, true, nil + } + return a, false, nil + case TypeTrue: + if b.Type() == TypeFalse { + return b, true, nil + } + return a, false, nil + case TypeNull: + if b.Type() != TypeNull { + return b, true, nil + } + return a, false, nil + case TypeNumber: + af, _ := a.Float64() + bf, _ := b.Float64() + if af != bf { + return b, true, nil + } + return a, false, nil + case TypeString: + as, _ := a.StringBytes() + bs, _ := b.StringBytes() + if !bytes.Equal(as, bs) { + return b, true, nil + } + return a, false, nil + default: + return nil, false, ErrMergeUnknownType + } +} + +func MergeValuesWithPath(a, b *Value, path ...string) (v *Value, changed bool, err error) { + if len(path) == 0 { + return MergeValues(a, b) + } + root := &Value{ + t: TypeObject, + } + current := root + for i := 0; i < len(path)-1; i++ { + current.Set(path[i], &Value{t: TypeObject}) + current = current.Get(path[i]) + } + current.Set(path[len(path)-1], b) + return MergeValues(a, root) +} diff --git a/mergevalues_test.go b/mergevalues_test.go new file mode 100644 index 0000000..445bbc2 --- /dev/null +++ b/mergevalues_test.go @@ -0,0 +1,303 @@ +package astjson + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMergeValues(t *testing.T) { + t.Parallel() + t.Run("left nil", func(t *testing.T) { + t.Parallel() + b := MustParse(`{"b":2}`) + merged, changed, err := MergeValues(nil, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"b":2}`, string(out)) + out = merged.Get("b").MarshalTo(out[:0]) + require.Equal(t, `2`, string(out)) + }) + t.Run("right nil", func(t *testing.T) { + t.Parallel() + a := MustParse(`{"a":1}`) + merged, changed, err := MergeValues(a, nil) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":1}`, string(out)) + out = merged.Get("a").MarshalTo(out[:0]) + require.Equal(t, `1`, string(out)) + }) + t.Run("type mismatch err", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`{"a":1}`), MustParse(`{"a":true}`) + merged, changed, err := MergeValues(a, b) + require.Equal(t, ErrMergeDifferentTypes, err) + require.Nil(t, merged) + require.Equal(t, false, changed) + }) + t.Run("bool type mismatch ok", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`true`), MustParse(`false`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `false`, string(out)) + }) + t.Run("bool type mismatch ok reverse", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`false`), MustParse(`true`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `true`, string(out)) + }) + t.Run("integers", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`1`), MustParse(`2`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `2`, string(out)) + }) + t.Run("integers reverse", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`2`), MustParse(`1`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `1`, string(out)) + }) + t.Run("integers equal", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`1`), MustParse(`1`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `1`, string(out)) + }) + t.Run("floats", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`1.1`), MustParse(`2.2`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `2.2`, string(out)) + }) + t.Run("floats reverse", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`2.2`), MustParse(`1.1`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `1.1`, string(out)) + }) + t.Run("floats equal", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`1.1`), MustParse(`1.1`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `1.1`, string(out)) + }) + t.Run("arrays", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[1,2]`), MustParse(`[3,4]`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `[3,4]`, string(out)) + }) + t.Run("left array empty", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[]`), MustParse(`[1,2]`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `[1,2]`, string(out)) + }) + t.Run("right array empty", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[1,2]`), MustParse(`[]`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `[1,2]`, string(out)) + }) + t.Run("err differing array lengths", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[1,2]`), MustParse(`[3]`) + merged, changed, err := MergeValues(a, b) + require.Equal(t, ErrMergeDifferingArrayLengths, err) + require.Nil(t, merged) + require.Equal(t, false, changed) + }) + t.Run("err merging array item", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[1,2]`), MustParse(`[3,"a"]`) + merged, changed, err := MergeValues(a, b) + require.Error(t, err) + require.Nil(t, merged) + require.Equal(t, false, changed) + }) + t.Run("false false", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`false`), MustParse(`false`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `false`, string(out)) + }) + t.Run("true true", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`true`), MustParse(`true`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `true`, string(out)) + }) + t.Run("null null", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`null`), MustParse(`null`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `null`, string(out)) + }) + t.Run("null not null", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`null`), MustParse(`1`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `1`, string(out)) + }) + t.Run("null not null reverse", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`1`), MustParse(`null`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `null`, string(out)) + }) + t.Run("array objects", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`[{"a":1,"b":2},{"x":1}]`), MustParse(`[{"a":2,"b":3,"c":4},{"y":1}]`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `[{"a":2,"b":3,"c":4},{"x":1,"y":1}]`, string(out)) + }) + t.Run("objects", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`{"a":{"b":1}}`), MustParse(`{"a":{"c":2}}`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) + }) + t.Run("strings", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`"a"`), MustParse(`"b"`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, true, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `"b"`, string(out)) + }) + t.Run("strings equal", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`"a"`), MustParse(`"a"`) + merged, changed, err := MergeValues(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `"a"`, string(out)) + }) + t.Run("with path", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`{"a":{"b":1}}`), MustParse(`{"c":2}`) + merged, changed, err := MergeValuesWithPath(a, b, "a") + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) + e := MustParse(`{"e":true}`) + merged, changed, err = MergeValuesWithPath(merged, e, "a", "d") + require.NoError(t, err) + require.Equal(t, false, changed) + out = merged.MarshalTo(out[:0]) + require.Equal(t, `{"a":{"b":1,"c":2,"d":{"e":true}}}`, string(out)) + }) + t.Run("with empty path", func(t *testing.T) { + t.Parallel() + a, b := MustParse(`{"a":1}`), MustParse(`{"b":2}`) + merged, changed, err := MergeValuesWithPath(a, b) + require.NoError(t, err) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":1,"b":2}`, string(out)) + out = merged.Get("b").MarshalTo(out[:0]) + require.Equal(t, `2`, string(out)) + }) + t.Run("merge with swap", func(t *testing.T) { + t.Parallel() + left := MustParse(`{"a":{"b":1,"c":2,"e":[],"f":[1],"h":[1,2,3]}}`) + right := MustParse(`{"a":{"b":2,"d":3,"e":[1,2,3],"g":[1],"h":[4,5,6]}}`) + out, _, err := MergeValues(left, right) + require.NoError(t, err) + require.Equal(t, `{"a":{"b":2,"c":2,"e":[1,2,3],"f":[1],"h":[4,5,6],"d":3,"g":[1]}}`, out.String()) + }) + t.Run("null true", func(t *testing.T) { + t.Parallel() + left := MustParse(`null`) + right := MustParse(`true`) + out, _, err := MergeValues(left, right) + require.NoError(t, err) + require.Equal(t, `true`, out.String()) + }) + t.Run("true null", func(t *testing.T) { + t.Parallel() + left := MustParse(`true`) + right := MustParse(`null`) + out, _, err := MergeValues(left, right) + require.NoError(t, err) + require.Equal(t, `null`, out.String()) + }) + t.Run("nested null true", func(t *testing.T) { + t.Parallel() + left := MustParse(`{"a":null}`) + right := MustParse(`{"a":true}`) + out, _, err := MergeValues(left, right) + require.NoError(t, err) + require.Equal(t, `{"a":true}`, out.String()) + }) + t.Run("nested true null", func(t *testing.T) { + t.Parallel() + left := MustParse(`{"a":true}`) + right := MustParse(`{"a":null}`) + out, _, err := MergeValues(left, right) + require.NoError(t, err) + require.Equal(t, `{"a":null}`, out.String()) + }) +} diff --git a/update_test.go b/update_test.go index ec5ee84..94be4ae 100644 --- a/update_test.go +++ b/update_test.go @@ -1,12 +1,7 @@ package astjson import ( - "encoding/json" - "fmt" - "strings" "testing" - - "github.com/stretchr/testify/assert" ) func TestObjectDelSet(t *testing.T) { @@ -119,96 +114,3 @@ func TestValue_AppendArrayItems(t *testing.T) { t.Fatalf("unexpected output; got %q; want %q", out, `[1,2,3,4,5,6]`) } } - -func TestMergeWithSwap(t *testing.T) { - left := MustParse(`{"a":{"b":1,"c":2,"e":[],"f":[1],"h":[1,2,3]}}`) - right := MustParse(`{"a":{"b":2,"d":3,"e":[1,2,3],"g":[1],"h":[4,5,6]}}`) - out, _ := MergeValues(left, right) - assert.Equal(t, `{"a":{"b":2,"c":2,"e":[1,2,3],"f":[1],"h":[4,5,6],"d":3,"g":[1]}}`, out.String()) -} - -type RootObject struct { - Child *ChildObject `json:"child"` -} - -type ChildObject struct { - GrandChild *GrandChildObject `json:"grand_child"` -} - -type GrandChildObject struct { - Items []string `json:"items"` -} - -func BenchmarkValue_SetArrayItem(b *testing.B) { - - root := &RootObject{ - Child: &ChildObject{ - GrandChild: &GrandChildObject{ - Items: make([]string, 0), - }, - }, - } - - l, err := json.Marshal(root) - assert.NoError(b, err) - - root.Child.GrandChild.Items = make([]string, 1024*1024) - - for i := 0; i < 1024*1024; i++ { - root.Child.GrandChild.Items[i] = strings.Repeat("a", 1024) - } - - r, err := json.Marshal(root) - assert.NoError(b, err) - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - l, _ := ParseBytesWithoutCache(l) - r, _ := ParseBytesWithoutCache(r) - out, _ := MergeValues(l, r) - arr := out.GetArray("child", "grand_child", "items") - assert.Len(b, arr, 1024*1024) - } -} - -func BenchmarkMergeValuesWithATonOfRecursion(b *testing.B) { - b.ReportAllocs() - - left := MustParse(`{"a":{}}`) - str := fmt.Sprintf( - `{"ba":{"bb":{"bc":{"bd":[%s, %s, %s], "be": {"bf": %s}}}}}`, - objectWithRecursion(10), - objectWithRecursion(20), - objectWithRecursion(2), - objectWithRecursion(3)) - - expected := fmt.Sprintf( - `{"a":{},"ba":{"bb":{"bc":{"bd":[%s,%s,%s],"be":{"bf":%s}}}}}`, - objectWithRecursion(10), - objectWithRecursion(20), - objectWithRecursion(2), - objectWithRecursion(3)) - - _ = expected - - right := MustParse(str) - - for i := 0; i < b.N; i++ { - _, _ = MergeValues(left, right) - /*v, _ := MergeValues(left, right) - if v.String() != expected { - assert.Equal(b, expected, v.String()) - }*/ - } - - fmt.Printf("left: %s\n", left.String()) -} - -func objectWithRecursion(depth int) string { - if depth == 0 { - return `{}` - } - return `{"a":` + objectWithRecursion(depth-1) + `}` -} diff --git a/util.go b/util.go index b1ea06e..a30ddef 100644 --- a/util.go +++ b/util.go @@ -1,7 +1,6 @@ package astjson import ( - "bytes" "unsafe" ) @@ -28,103 +27,6 @@ var ( NullValue = MustParse(`null`) ) -func MergeValues(a, b *Value) (*Value, bool) { - if a == nil { - return b, true - } - if b == nil { - return a, false - } - if a.Type() != b.Type() { - return a, false - } - switch a.Type() { - case TypeObject: - ao, _ := a.Object() - bo, _ := b.Object() - ao.unescapeKeys() - bo.unescapeKeys() - for i := range bo.kvs { - k := bo.kvs[i].k - r := bo.kvs[i].v - l := ao.Get(k) - if l == nil { - ao.Set(k, r) - continue - } - n, changed := MergeValues(l, r) - if changed { - ao.Set(k, n) - } - } - return a, false - case TypeArray: - aa, _ := a.Array() - ba, _ := b.Array() - if len(aa) == 0 { - return b, true - } - if len(ba) == 0 { - return a, false - } - if len(aa) != len(ba) { - return b, true - } - for i := range aa { - n, changed := MergeValues(aa[i], ba[i]) - if changed { - aa[i] = n - } - } - return a, false - case TypeFalse: - if b.Type() == TypeTrue { - return b, true - } - return a, false - case TypeTrue: - if b.Type() == TypeFalse { - return b, true - } - return a, false - case TypeNull: - if b.Type() != TypeNull { - return b, true - } - return a, false - case TypeNumber: - af, _ := a.Float64() - bf, _ := b.Float64() - if af != bf { - return b, true - } - return a, false - case TypeString: - as, _ := a.StringBytes() - bs, _ := b.StringBytes() - if !bytes.Equal(as, bs) { - return b, true - } - return a, false - default: - return b, true - } -} - -func MergeValuesWithPath(a, b *Value, path ...string) (*Value, bool) { - if len(path) == 0 { - return MergeValues(a, b) - } - root := MustParseBytes([]byte(`{}`)) - current := root - for i := 0; i < len(path)-1; i++ { - current.Set(path[i], MustParseBytes([]byte(`{}`))) - current = current.Get(path[i]) - } - current.Set(path[len(path)-1], b) - return MergeValues(a, root) -} - func AppendToArray(array, value *Value) { if array.Type() != TypeArray { return diff --git a/util_test.go b/util_test.go index 04ef947..a7fdec9 100644 --- a/util_test.go +++ b/util_test.go @@ -31,53 +31,6 @@ func TestStartEndString(t *testing.T) { f(getString(100*maxStartEndStringLen), "abcdefghijklmnopqrstuvwxyzabcdefghijklmn...efghijklmnopqrstuvwxyzabcdefghijklmnopqr") } -func TestMergeValues(t *testing.T) { - a, b := MustParse(`{"a":1}`), MustParse(`{"b":2}`) - merged, changed := MergeValues(a, b) - require.Equal(t, false, changed) - out := merged.MarshalTo(nil) - require.Equal(t, `{"a":1,"b":2}`, string(out)) - out = merged.Get("b").MarshalTo(out[:0]) - require.Equal(t, `2`, string(out)) -} - -func TestMergeValuesArray(t *testing.T) { - a, b := MustParse(`[1,2]`), MustParse(`[3,4]`) - merged, changed := MergeValues(a, b) - require.Equal(t, false, changed) - out := merged.MarshalTo(nil) - require.Equal(t, `[3,4]`, string(out)) -} - -func TestMergeObjectValuesArray(t *testing.T) { - a, b := MustParse(`[{"a":1,"b":2},{"x":1}]`), MustParse(`[{"a":2,"b":3,"c":4},{"y":1}]`) - merged, changed := MergeValues(a, b) - require.Equal(t, false, changed) - out := merged.MarshalTo(nil) - require.Equal(t, `[{"a":2,"b":3,"c":4},{"x":1,"y":1}]`, string(out)) -} - -func TestMergeValuesNestedObjects(t *testing.T) { - a, b := MustParse(`{"a":{"b":1}}`), MustParse(`{"a":{"c":2}}`) - merged, changed := MergeValues(a, b) - require.Equal(t, false, changed) - out := merged.MarshalTo(nil) - require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) -} - -func TestMergeValuesWithPath(t *testing.T) { - a, b := MustParse(`{"a":{"b":1}}`), MustParse(`{"c":2}`) - merged, changed := MergeValuesWithPath(a, b, "a") - require.Equal(t, false, changed) - out := merged.MarshalTo(nil) - require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) - e := MustParse(`{"e":true}`) - merged, changed = MergeValuesWithPath(merged, e, "a", "d") - require.Equal(t, false, changed) - out = merged.MarshalTo(out[:0]) - require.Equal(t, `{"a":{"b":1,"c":2,"d":{"e":true}}}`, string(out)) -} - func TestGetArray(t *testing.T) { a := MustParse(`[{"name":"Jens"},{"name":"Jannik"}]`) arr, err := a.Array()