From 6da1c88b62d9e30eebac9d553cb5e81e8c2b3a1a Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 22 Oct 2024 01:43:35 -0700 Subject: [PATCH] Add "clearomitted" directive (#373) * Add "clearomitted" directive Adds `//msgp:clearomitted` directive. This will cause all `omitempty` and `omitzero` fields to be set to the zero value on Unmarshal and Decode if the field wasn't present in the marshalled data. This can be useful when de-serializing into reused objects that can have values set for these, and we want to avoid clearing all fields, but also don't want existing fields to leak through. Fields are tracked through a bit mask and zeroed after the unmarshal loop if no value has been written. This does not affect marshaling. --- _generated/clearomitted.go | 97 +++++++++++++++++++++++++++++++++ _generated/clearomitted_test.go | 93 +++++++++++++++++++++++++++++++ gen/decode.go | 38 +++++++++++++ gen/elem.go | 11 ++++ gen/spec.go | 30 +++++++++- gen/unmarshal.go | 37 +++++++++++++ parse/directives.go | 7 +++ parse/getast.go | 4 +- 8 files changed, 313 insertions(+), 4 deletions(-) create mode 100644 _generated/clearomitted.go create mode 100644 _generated/clearomitted_test.go diff --git a/_generated/clearomitted.go b/_generated/clearomitted.go new file mode 100644 index 00000000..a6a5e10b --- /dev/null +++ b/_generated/clearomitted.go @@ -0,0 +1,97 @@ +package _generated + +import ( + "encoding/json" + "time" +) + +//go:generate msgp + +//msgp:clearomitted + +// check some specific cases for omitzero + +type ClearOmitted0 struct { + AStruct ClearOmittedA `msg:"astruct,omitempty"` // leave this one omitempty + BStruct ClearOmittedA `msg:"bstruct,omitzero"` // and compare to this + AStructPtr *ClearOmittedA `msg:"astructptr,omitempty"` // a pointer case omitempty + BStructPtr *ClearOmittedA `msg:"bstructptr,omitzero"` // a pointer case omitzero + AExt OmitZeroExt `msg:"aext,omitzero"` // external type case + + // more + APtrNamedStr *NamedStringCO 
`msg:"aptrnamedstr,omitzero"` + ANamedStruct NamedStructCO `msg:"anamedstruct,omitzero"` + APtrNamedStruct *NamedStructCO `msg:"aptrnamedstruct,omitzero"` + EmbeddableStructCO `msg:",flatten,omitzero"` // embed flat + EmbeddableStructCO2 `msg:"embeddablestruct2,omitzero"` // embed non-flat + ATime time.Time `msg:"atime,omitzero"` + ASlice []int `msg:"aslice,omitempty"` + AMap map[string]int `msg:"amap,omitempty"` + ABin []byte `msg:"abin,omitempty"` + AInt int `msg:"aint,omitempty"` + AString string `msg:"atring,omitempty"` + Adur time.Duration `msg:"adur,omitempty"` + AJSON json.Number `msg:"ajson,omitempty"` + + ClearOmittedTuple ClearOmittedTuple `msg:"ozt"` // the inside of a tuple should ignore both omitempty and omitzero +} + +type ClearOmittedA struct { + A string `msg:"a,omitempty"` + B NamedStringCO `msg:"b,omitzero"` + C NamedStringCO `msg:"c,omitzero"` +} + +func (o *ClearOmittedA) IsZero() bool { + if o == nil { + return true + } + return *o == (ClearOmittedA{}) +} + +type NamedStructCO struct { + A string `msg:"a,omitempty"` + B string `msg:"b,omitempty"` +} + +func (ns *NamedStructCO) IsZero() bool { + if ns == nil { + return true + } + return *ns == (NamedStructCO{}) +} + +type NamedStringCO string + +func (ns *NamedStringCO) IsZero() bool { + if ns == nil { + return true + } + return *ns == "" +} + +type EmbeddableStructCO struct { + SomeEmbed string `msg:"someembed2,omitempty"` +} + +func (es EmbeddableStructCO) IsZero() bool { return es == (EmbeddableStructCO{}) } + +type EmbeddableStructCO2 struct { + SomeEmbed2 string `msg:"someembed2,omitempty"` +} + +func (es EmbeddableStructCO2) IsZero() bool { return es == (EmbeddableStructCO2{}) } + +//msgp:tuple ClearOmittedTuple + +// ClearOmittedTuple is flagged for tuple output, it should ignore all omitempty and omitzero functionality +// since it's fundamentally incompatible. 
+type ClearOmittedTuple struct { + FieldA string `msg:"fielda,omitempty"` + FieldB NamedStringCO `msg:"fieldb,omitzero"` + FieldC NamedStringCO `msg:"fieldc,omitzero"` +} + +type ClearOmitted1 struct { + T1 ClearOmittedTuple `msg:"t1"` +} diff --git a/_generated/clearomitted_test.go b/_generated/clearomitted_test.go new file mode 100644 index 00000000..e40d98ca --- /dev/null +++ b/_generated/clearomitted_test.go @@ -0,0 +1,93 @@ +package _generated + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/tinylib/msgp/msgp" +) + +func TestClearOmitted(t *testing.T) { + cleared := ClearOmitted0{} + encoded, err := cleared.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + vPtr := NamedStringCO("value") + filled := ClearOmitted0{ + AStruct: ClearOmittedA{A: "something"}, + BStruct: ClearOmittedA{A: "somthing"}, + AStructPtr: &ClearOmittedA{A: "something"}, + AExt: OmitZeroExt{25}, + APtrNamedStr: &vPtr, + ANamedStruct: NamedStructCO{A: "value"}, + APtrNamedStruct: &NamedStructCO{A: "sdf"}, + EmbeddableStructCO: EmbeddableStructCO{"value"}, + EmbeddableStructCO2: EmbeddableStructCO2{"value"}, + ATime: time.Now(), + ASlice: []int{1, 2, 3}, + AMap: map[string]int{"1": 1}, + ABin: []byte{1, 2, 3}, + ClearOmittedTuple: ClearOmittedTuple{FieldA: "value"}, + AInt: 42, + AString: "value", + Adur: time.Second, + AJSON: json.Number(`43.0000000000002`), + } + dst := filled + _, err = dst.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(dst, cleared) { + t.Errorf("\n got=%#v\nwant=%#v", dst, cleared) + } + // Reset + dst = filled + err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded))) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(dst, cleared) { + t.Errorf("\n got=%#v\nwant=%#v", dst, cleared) + } + + // Check that fields present in the encoded data aren't accidentally zeroed. 
+ wantJson, err := json.Marshal(filled) + if err != nil { + t.Fatal(err) + } + encoded, err = filled.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + dst = ClearOmitted0{} + _, err = dst.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(dst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, wantJson) { + t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson)) + } + // Reset + dst = ClearOmitted0{} + err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded))) + if err != nil { + t.Fatal(err) + } + got, err = json.Marshal(dst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, wantJson) { + t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson)) + } + t.Log("OK - got", string(got)) +} diff --git a/gen/decode.go b/gen/decode.go index cb78c5e3..daca62be 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -107,6 +107,22 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.declare(sz, u32) d.assignAndCheck(sz, mapHeader) + oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") + if !d.ctx.clearOmitted { + oeCount = 0 + } + bm := bmask{ + bitlen: oeCount, + varname: sz + "Mask", + } + if oeCount > 0 { + // Declare mask + d.p.printf("\n%s", bm.typeDecl()) + d.p.printf("\n_ = %s", bm.varname) + } + // Index to field idx of each emitted + oeEmittedIdx := []int{} + d.p.printf("\nfor %s > 0 {\n%s--", sz, sz) d.assignAndCheck("field", mapKey) d.p.print("\nswitch msgp.UnsafeString(field) {") @@ -123,6 +139,10 @@ func (d *decodeGen) structAsMap(s *Struct) { } SetIsAllowNil(fieldElem, anField) next(d, fieldElem) + if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { + d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) + oeEmittedIdx = append(oeEmittedIdx, i) + } d.ctx.Pop() if !d.p.ok() { return @@ -136,6 +156,24 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.closeblock() // close switch d.p.closeblock() // close for loop + + if oeCount 
> 0 { + d.p.printf("\n// Clear omitted fields.\n") + d.p.printf("if %s {\n", bm.notAllSet()) + for bitIdx, fieldIdx := range oeEmittedIdx { + fieldElem := s.Fields[fieldIdx].FieldElem + + d.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx)) + fze := fieldElem.ZeroExpr() + if fze != "" { + d.p.printf("%s = %s\n", fieldElem.Varname(), fze) + } else { + d.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName()) + } + d.p.printf("}\n") + } + d.p.printf("}") + } } func (d *decodeGen) gBase(b *BaseElem) { diff --git a/gen/elem.go b/gen/elem.go index 1fd26ff7..1455170f 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -516,6 +516,17 @@ func (s *Struct) AnyHasTagPart(pname string) bool { return false } +// CountFieldTagPart the count of HasTagPart(p) is true for any field. +func (s *Struct) CountFieldTagPart(pname string) int { + var n int + for _, sf := range s.Fields { + if sf.HasTagPart(pname) { + n++ + } + } + return n +} + type StructField struct { FieldTag string // the string inside the `msg:""` tag up to the first comma FieldTagParts []string // the string inside the `msg:""` tag split by commas diff --git a/gen/spec.go b/gen/spec.go index c0bccfd2..18cbec39 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "strings" ) const ( @@ -77,6 +78,7 @@ const ( type Printer struct { gens []generator CompactFloats bool + ClearOmitted bool } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -145,7 +147,7 @@ func (p *Printer) Print(e Elem) error { // collisions between idents created during SetVarname and idents created during Print, // hence the separate prefixes. 
resetIdent("zb") - err := g.Execute(e, Context{compFloats: p.CompactFloats}) + err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted}) resetIdent("za") if err != nil { @@ -172,8 +174,9 @@ func (c contextVar) Arg() string { } type Context struct { - path []contextItem - compFloats bool + path []contextItem + compFloats bool + clearOmitted bool } func (c *Context) PushString(s string) { @@ -501,3 +504,24 @@ func (b *bmask) setStmt(bitoffset int) string { return buf.String() } + +// notAllSet returns a check against all fields having been set in set. +func (b *bmask) notAllSet() string { + var buf bytes.Buffer + buf.Grow(len(b.varname) + 16) + buf.WriteString(b.varname) + if b.bitlen > 64 { + var bytes []string + remain := b.bitlen + for remain >= 8 { + bytes = append(bytes, "0xff") + } + if remain > 0 { + bytes = append(bytes, fmt.Sprintf("0x%X", remain)) + } + fmt.Fprintf(&buf, " != [%d]byte{%s}\n", (b.bitlen+63)/64, strings.Join(bytes, ",")) + } + fmt.Fprintf(&buf, " != 0x%x", uint64(1< 0 { + // Declare mask + u.p.printf("\n%s", bm.typeDecl()) + u.p.printf("\n_ = %s", bm.varname) + } + // Index to field idx of each emitted + oeEmittedIdx := []int{} + u.p.printf("\nfor %s > 0 {", sz) u.p.printf("\n%s--; field, bts, err = msgp.ReadMapKeyZC(bts)", sz) u.p.wrapErrCheck(u.ctx.ArgsStr()) @@ -122,6 +138,10 @@ func (u *unmarshalGen) mapstruct(s *Struct) { SetIsAllowNil(fieldElem, anField) next(u, fieldElem) u.ctx.Pop() + if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { + u.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) + oeEmittedIdx = append(oeEmittedIdx, i) + } if anField { u.p.printf("\n}") } @@ -129,6 +149,23 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.p.print("\ndefault:\nbts, err = msgp.Skip(bts)") u.p.wrapErrCheck(u.ctx.ArgsStr()) u.p.print("\n}\n}") // close switch and for loop + if oeCount > 0 { + u.p.printf("\n// Clear omitted fields.\n") + u.p.printf("if %s {\n", 
bm.notAllSet()) + for bitIdx, fieldIdx := range oeEmittedIdx { + fieldElem := s.Fields[fieldIdx].FieldElem + + u.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx)) + fze := fieldElem.ZeroExpr() + if fze != "" { + u.p.printf("%s = %s\n", fieldElem.Varname(), fze) + } else { + u.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName()) + } + u.p.printf("}\n") + } + u.p.printf("}") + } } func (u *unmarshalGen) gBase(b *BaseElem) { diff --git a/parse/directives.go b/parse/directives.go index 2ea1be09..1a50a98a 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -27,6 +27,7 @@ var directives = map[string]directive{ "ignore": ignore, "tuple": astuple, "compactfloats": compactfloats, + "clearomitted": clearomitted, } // map of all recognized directives which will be applied @@ -193,3 +194,9 @@ func compactfloats(text []string, f *FileSet) error { f.CompactFloats = true return nil } + +//msgp:clearomitted +func clearomitted(text []string, f *FileSet) error { + f.ClearOmitted = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index 7d6cebe9..35c2bd8f 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -21,7 +21,8 @@ type FileSet struct { Identities map[string]gen.Elem // processed from specs Directives []string // raw preprocessor directives Imports []*ast.ImportSpec // imports - CompactFloats bool // Use smaller floats when feasible. + CompactFloats bool // Use smaller floats when feasible + ClearOmitted bool // Set omitted fields to zero value tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. } @@ -271,6 +272,7 @@ loop: } } p.CompactFloats = f.CompactFloats + p.ClearOmitted = f.ClearOmitted } func (f *FileSet) PrintTo(p *gen.Printer) error {