diff --git a/_generated/clearomitted.go b/_generated/clearomitted.go new file mode 100644 index 00000000..a6a5e10b --- /dev/null +++ b/_generated/clearomitted.go @@ -0,0 +1,97 @@ +package _generated + +import ( + "encoding/json" + "time" +) + +//go:generate msgp + +//msgp:clearomitted + +// check some specific cases for omitzero + +type ClearOmitted0 struct { + AStruct ClearOmittedA `msg:"astruct,omitempty"` // leave this one omitempty + BStruct ClearOmittedA `msg:"bstruct,omitzero"` // and compare to this + AStructPtr *ClearOmittedA `msg:"astructptr,omitempty"` // a pointer case omitempty + BStructPtr *ClearOmittedA `msg:"bstructptr,omitzero"` // a pointer case omitzero + AExt OmitZeroExt `msg:"aext,omitzero"` // external type case + + // more + APtrNamedStr *NamedStringCO `msg:"aptrnamedstr,omitzero"` + ANamedStruct NamedStructCO `msg:"anamedstruct,omitzero"` + APtrNamedStruct *NamedStructCO `msg:"aptrnamedstruct,omitzero"` + EmbeddableStructCO `msg:",flatten,omitzero"` // embed flat + EmbeddableStructCO2 `msg:"embeddablestruct2,omitzero"` // embed non-flat + ATime time.Time `msg:"atime,omitzero"` + ASlice []int `msg:"aslice,omitempty"` + AMap map[string]int `msg:"amap,omitempty"` + ABin []byte `msg:"abin,omitempty"` + AInt int `msg:"aint,omitempty"` + AString string `msg:"atring,omitempty"` + Adur time.Duration `msg:"adur,omitempty"` + AJSON json.Number `msg:"ajson,omitempty"` + + ClearOmittedTuple ClearOmittedTuple `msg:"ozt"` // the inside of a tuple should ignore both omitempty and omitzero +} + +type ClearOmittedA struct { + A string `msg:"a,omitempty"` + B NamedStringCO `msg:"b,omitzero"` + C NamedStringCO `msg:"c,omitzero"` +} + +func (o *ClearOmittedA) IsZero() bool { + if o == nil { + return true + } + return *o == (ClearOmittedA{}) +} + +type NamedStructCO struct { + A string `msg:"a,omitempty"` + B string `msg:"b,omitempty"` +} + +func (ns *NamedStructCO) IsZero() bool { + if ns == nil { + return true + } + return *ns == (NamedStructCO{}) +} + +type NamedStringCO string + +func (ns *NamedStringCO) IsZero() bool { + if ns == nil { + return true + } + return *ns == "" +} + +type EmbeddableStructCO struct { + SomeEmbed string `msg:"someembed2,omitempty"` +} + +func (es EmbeddableStructCO) IsZero() bool { return es == (EmbeddableStructCO{}) } + +type EmbeddableStructCO2 struct { + SomeEmbed2 string `msg:"someembed2,omitempty"` +} + +func (es EmbeddableStructCO2) IsZero() bool { return es == (EmbeddableStructCO2{}) } + +//msgp:tuple ClearOmittedTuple + +// ClearOmittedTuple is flagged for tuple output, it should ignore all omitempty and omitzero functionality +// since it's fundamentally incompatible. +type ClearOmittedTuple struct { + FieldA string `msg:"fielda,omitempty"` + FieldB NamedStringCO `msg:"fieldb,omitzero"` + FieldC NamedStringCO `msg:"fieldc,omitzero"` +} + +type ClearOmitted1 struct { + T1 ClearOmittedTuple `msg:"t1"` +} diff --git a/_generated/clearomitted_test.go b/_generated/clearomitted_test.go new file mode 100644 index 00000000..e40d98ca --- /dev/null +++ b/_generated/clearomitted_test.go @@ -0,0 +1,93 @@ +package _generated + +import ( + "bytes" + "encoding/json" + "reflect" + "testing" + "time" + + "github.com/tinylib/msgp/msgp" +) + +func TestClearOmitted(t *testing.T) { + cleared := ClearOmitted0{} + encoded, err := cleared.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + vPtr := NamedStringCO("value") + filled := ClearOmitted0{ + AStruct: ClearOmittedA{A: "something"}, + BStruct: ClearOmittedA{A: "somthing"}, + AStructPtr: &ClearOmittedA{A: "something"}, + AExt: OmitZeroExt{25}, + APtrNamedStr: &vPtr, + ANamedStruct: NamedStructCO{A: "value"}, + APtrNamedStruct: &NamedStructCO{A: "sdf"}, + EmbeddableStructCO: EmbeddableStructCO{"value"}, + EmbeddableStructCO2: EmbeddableStructCO2{"value"}, + ATime: time.Now(), + ASlice: []int{1, 2, 3}, + AMap: map[string]int{"1": 1}, + ABin: []byte{1, 2, 3}, + ClearOmittedTuple: ClearOmittedTuple{FieldA: "value"}, + AInt: 42, + AString: "value", + Adur: time.Second, + AJSON: json.Number(`43.0000000000002`), + } + dst := filled + _, err = dst.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(dst, cleared) { + t.Errorf("\n got=%#v\nwant=%#v", dst, cleared) + } + // Reset + dst = filled + err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded))) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(dst, cleared) { + t.Errorf("\n got=%#v\nwant=%#v", dst, cleared) + } + + // Check that fields aren't accidentally zeroing fields. + wantJson, err := json.Marshal(filled) + if err != nil { + t.Fatal(err) + } + encoded, err = filled.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + dst = ClearOmitted0{} + _, err = dst.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(dst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, wantJson) { + t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson)) + } + // Reset + dst = ClearOmitted0{} + err = dst.DecodeMsg(msgp.NewReader(bytes.NewReader(encoded))) + if err != nil { + t.Fatal(err) + } + got, err = json.Marshal(dst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, wantJson) { + t.Errorf("\n got=%#v\nwant=%#v", string(got), string(wantJson)) + } + t.Log("OK - got", string(got)) +} diff --git a/gen/decode.go b/gen/decode.go index cb78c5e3..daca62be 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -107,6 +107,22 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.declare(sz, u32) d.assignAndCheck(sz, mapHeader) + oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") + if !d.ctx.clearOmitted { + oeCount = 0 + } + bm := bmask{ + bitlen: oeCount, + varname: sz + "Mask", + } + if oeCount > 0 { + // Declare mask + d.p.printf("\n%s", bm.typeDecl()) + d.p.printf("\n_ = %s", bm.varname) + } + // Index to field idx of each emitted + oeEmittedIdx := []int{} + d.p.printf("\nfor %s > 0 {\n%s--", sz, sz) d.assignAndCheck("field", mapKey) d.p.print("\nswitch msgp.UnsafeString(field) {") @@ -123,6 +139,10 @@ func (d *decodeGen) structAsMap(s *Struct) { } SetIsAllowNil(fieldElem, anField) next(d, fieldElem) + if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { + d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) + oeEmittedIdx = append(oeEmittedIdx, i) + } d.ctx.Pop() if !d.p.ok() { return @@ -136,6 +156,24 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.closeblock() // close switch d.p.closeblock() // close for loop + + if oeCount > 0 { + d.p.printf("\n// Clear omitted fields.\n") + d.p.printf("if %s {\n", bm.notAllSet()) + for bitIdx, fieldIdx := range oeEmittedIdx { + fieldElem := s.Fields[fieldIdx].FieldElem + + d.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx)) + fze := fieldElem.ZeroExpr() + if fze != "" { + d.p.printf("%s = %s\n", fieldElem.Varname(), fze) + } else { + d.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName()) + } + d.p.printf("}\n") + } + d.p.printf("}") + } } func (d *decodeGen) gBase(b *BaseElem) { diff --git a/gen/elem.go b/gen/elem.go index 1fd26ff7..1455170f 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -516,6 +516,17 @@ func (s *Struct) AnyHasTagPart(pname string) bool { return false } +// CountFieldTagPart the count of HasTagPart(p) is true for any field. +func (s *Struct) CountFieldTagPart(pname string) int { + var n int + for _, sf := range s.Fields { + if sf.HasTagPart(pname) { + n++ + } + } + return n +} + type StructField struct { FieldTag string // the string inside the `msg:""` tag up to the first comma FieldTagParts []string // the string inside the `msg:""` tag split by commas diff --git a/gen/spec.go b/gen/spec.go index c0bccfd2..18cbec39 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "strings" ) const ( @@ -77,6 +78,7 @@ const ( type Printer struct { gens []generator CompactFloats bool + ClearOmitted bool } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -145,7 +147,7 @@ func (p *Printer) Print(e Elem) error { // collisions between idents created during SetVarname and idents created during Print, // hence the separate prefixes. resetIdent("zb") - err := g.Execute(e, Context{compFloats: p.CompactFloats}) + err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted}) resetIdent("za") if err != nil { @@ -172,8 +174,9 @@ func (c contextVar) Arg() string { } type Context struct { - path []contextItem - compFloats bool + path []contextItem + compFloats bool + clearOmitted bool } func (c *Context) PushString(s string) { @@ -501,3 +504,24 @@ func (b *bmask) setStmt(bitoffset int) string { return buf.String() } + +// notAllSet returns a check against all fields having been set in set. +func (b *bmask) notAllSet() string { + var buf bytes.Buffer + buf.Grow(len(b.varname) + 16) + buf.WriteString(b.varname) + if b.bitlen > 64 { + var bytes []string + remain := b.bitlen + for remain >= 8 { + bytes = append(bytes, "0xff") + } + if remain > 0 { + bytes = append(bytes, fmt.Sprintf("0x%X", remain)) + } + fmt.Fprintf(&buf, " != [%d]byte{%s}\n", (b.bitlen+63)/64, strings.Join(bytes, ",")) + } + fmt.Fprintf(&buf, " != 0x%x", uint64(1< 0 { + // Declare mask + u.p.printf("\n%s", bm.typeDecl()) + u.p.printf("\n_ = %s", bm.varname) + } + // Index to field idx of each emitted + oeEmittedIdx := []int{} + u.p.printf("\nfor %s > 0 {", sz) u.p.printf("\n%s--; field, bts, err = msgp.ReadMapKeyZC(bts)", sz) u.p.wrapErrCheck(u.ctx.ArgsStr()) @@ -122,6 +138,10 @@ func (u *unmarshalGen) mapstruct(s *Struct) { SetIsAllowNil(fieldElem, anField) next(u, fieldElem) u.ctx.Pop() + if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { + u.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) + oeEmittedIdx = append(oeEmittedIdx, i) + } if anField { u.p.printf("\n}") } @@ -129,6 +149,23 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.p.print("\ndefault:\nbts, err = msgp.Skip(bts)") u.p.wrapErrCheck(u.ctx.ArgsStr()) u.p.print("\n}\n}") // close switch and for loop + if oeCount > 0 { + u.p.printf("\n// Clear omitted fields.\n") + u.p.printf("if %s {\n", bm.notAllSet()) + for bitIdx, fieldIdx := range oeEmittedIdx { + fieldElem := s.Fields[fieldIdx].FieldElem + + u.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx)) + fze := fieldElem.ZeroExpr() + if fze != "" { + u.p.printf("%s = %s\n", fieldElem.Varname(), fze) + } else { + u.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName()) + } + u.p.printf("}\n") + } + u.p.printf("}") + } } func (u *unmarshalGen) gBase(b *BaseElem) { diff --git a/parse/directives.go b/parse/directives.go index 2ea1be09..1a50a98a 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -27,6 +27,7 @@ var directives = map[string]directive{ "ignore": ignore, "tuple": astuple, "compactfloats": compactfloats, + "clearomitted": clearomitted, } // map of all recognized directives which will be applied @@ -193,3 +194,9 @@ func compactfloats(text []string, f *FileSet) error { f.CompactFloats = true return nil } + +//msgp:clearomitted +func clearomitted(text []string, f *FileSet) error { + f.ClearOmitted = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index 7d6cebe9..35c2bd8f 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -21,7 +21,8 @@ type FileSet struct { Identities map[string]gen.Elem // processed from specs Directives []string // raw preprocessor directives Imports []*ast.ImportSpec // imports - CompactFloats bool // Use smaller floats when feasible. + CompactFloats bool // Use smaller floats when feasible + ClearOmitted bool // Set omitted fields to zero value tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. } @@ -271,6 +272,7 @@ loop: } } p.CompactFloats = f.CompactFloats + p.ClearOmitted = f.ClearOmitted } func (f *FileSet) PrintTo(p *gen.Printer) error {