Skip to content

Commit

Permalink
Add "clearomitted" directive (#373)
Browse files Browse the repository at this point in the history
* Add "clearomitted" directive

Adds `//msgp:clearomitted` directive.

This will cause all `omitempty` and `omitzero` fields to be set to the zero value on Unmarshal and Decode when the field wasn't present in the marshalled data.

This can be useful when de-serializing into reused objects that may already have values set for these fields: it avoids having to clear every field up front, while still ensuring that stale values do not leak through.

Fields are tracked through a bit mask, and zeroed after the unmarshal loop, if no value has been written.

This does not affect marshaling.
  • Loading branch information
klauspost authored Oct 22, 2024
1 parent 9279415 commit 6da1c88
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 4 deletions.
97 changes: 97 additions & 0 deletions _generated/clearomitted.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package _generated

import (
"encoding/json"
"time"
)

//go:generate msgp

//msgp:clearomitted

// check some specific cases for clearomitted (applied to omitempty and omitzero fields)

// ClearOmitted0 is the main fixture for the clearomitted directive. It mixes
// omitempty and omitzero fields of several kinds (struct values, pointers, an
// external type, embedded structs, time values, slices, maps, byte slices and
// scalars) with one tuple field that carries no omit tag.
type ClearOmitted0 struct {
AStruct ClearOmittedA `msg:"astruct,omitempty"` // leave this one omitempty
BStruct ClearOmittedA `msg:"bstruct,omitzero"` // and compare to this
AStructPtr *ClearOmittedA `msg:"astructptr,omitempty"` // a pointer case omitempty
BStructPtr *ClearOmittedA `msg:"bstructptr,omitzero"` // a pointer case omitzero
AExt OmitZeroExt `msg:"aext,omitzero"` // external type case

// more
APtrNamedStr *NamedStringCO `msg:"aptrnamedstr,omitzero"`
ANamedStruct NamedStructCO `msg:"anamedstruct,omitzero"`
APtrNamedStruct *NamedStructCO `msg:"aptrnamedstruct,omitzero"`
EmbeddableStructCO `msg:",flatten,omitzero"` // embed flat
EmbeddableStructCO2 `msg:"embeddablestruct2,omitzero"` // embed non-flat
ATime time.Time `msg:"atime,omitzero"`
ASlice []int `msg:"aslice,omitempty"`
AMap map[string]int `msg:"amap,omitempty"`
ABin []byte `msg:"abin,omitempty"`
AInt int `msg:"aint,omitempty"`
// NOTE(review): the key below is spelled "atring"; possibly meant "astring" — confirm before changing, it is part of the wire format.
AString string `msg:"atring,omitempty"`
Adur time.Duration `msg:"adur,omitempty"`
AJSON json.Number `msg:"ajson,omitempty"`

ClearOmittedTuple ClearOmittedTuple `msg:"ozt"` // the inside of a tuple should ignore both omitempty and omitzero
}

// ClearOmittedA is a small comparable struct that ClearOmitted0 uses both by
// value and by pointer.
type ClearOmittedA struct {
	A string        `msg:"a,omitempty"`
	B NamedStringCO `msg:"b,omitzero"`
	C NamedStringCO `msg:"c,omitzero"`
}

// IsZero reports whether o is nil or equal to the zero ClearOmittedA.
func (o *ClearOmittedA) IsZero() bool {
	var zero ClearOmittedA
	return o == nil || *o == zero
}

// NamedStructCO is a two-string struct that ClearOmitted0 uses both by value
// and by pointer.
type NamedStructCO struct {
	A string `msg:"a,omitempty"`
	B string `msg:"b,omitempty"`
}

// IsZero reports whether ns is nil or equal to the zero NamedStructCO.
func (ns *NamedStructCO) IsZero() bool {
	var zero NamedStructCO
	return ns == nil || *ns == zero
}

// NamedStringCO is a named string type with a pointer-receiver IsZero method.
type NamedStringCO string

// IsZero reports whether ns is nil or points at the empty string.
func (ns *NamedStringCO) IsZero() bool {
	return ns == nil || len(*ns) == 0
}

// EmbeddableStructCO is embedded into ClearOmitted0 with the flatten tag.
type EmbeddableStructCO struct {
	SomeEmbed string `msg:"someembed2,omitempty"`
}

// IsZero reports whether es equals the zero EmbeddableStructCO.
func (es EmbeddableStructCO) IsZero() bool {
	var zero EmbeddableStructCO
	return es == zero
}

// EmbeddableStructCO2 is embedded into ClearOmitted0 without flattening, under
// its own key.
type EmbeddableStructCO2 struct {
	SomeEmbed2 string `msg:"someembed2,omitempty"`
}

// IsZero reports whether es equals the zero EmbeddableStructCO2.
func (es EmbeddableStructCO2) IsZero() bool {
	var zero EmbeddableStructCO2
	return es == zero
}

//msgp:tuple ClearOmittedTuple

// ClearOmittedTuple is flagged for tuple output (see the msgp:tuple directive
// above). It should ignore all omitempty and omitzero functionality, since
// omitting fields is fundamentally incompatible with tuple encoding.
type ClearOmittedTuple struct {
FieldA string `msg:"fielda,omitempty"`
FieldB NamedStringCO `msg:"fieldb,omitzero"`
FieldC NamedStringCO `msg:"fieldc,omitzero"`
}

// ClearOmitted1 holds a single ClearOmittedTuple field that carries no
// omitempty or omitzero tag of its own.
type ClearOmitted1 struct {
T1 ClearOmittedTuple `msg:"t1"`
}
93 changes: 93 additions & 0 deletions _generated/clearomitted_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package _generated

import (
"bytes"
"encoding/json"
"reflect"
"testing"
"time"

"github.com/tinylib/msgp/msgp"
)

// TestClearOmitted exercises the clearomitted directive in two stages:
//
//  1. Decoding the encoding of a zero value into an already-populated value
//     must leave the destination equal to the zero value, for both
//     UnmarshalMsg and DecodeMsg.
//  2. Decoding the encoding of a populated value into a fresh value must
//     reproduce every field (compared through JSON), i.e. fields that were
//     present on the wire are not cleared by mistake.
func TestClearOmitted(t *testing.T) {
	// fatalIf aborts the test on any unexpected error.
	fatalIf := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	var want ClearOmitted0
	emptyMsg, err := want.MarshalMsg(nil)
	fatalIf(err)

	named := NamedStringCO("value")
	populated := ClearOmitted0{
		AStruct:             ClearOmittedA{A: "something"},
		BStruct:             ClearOmittedA{A: "somthing"},
		AStructPtr:          &ClearOmittedA{A: "something"},
		AExt:                OmitZeroExt{25},
		APtrNamedStr:        &named,
		ANamedStruct:        NamedStructCO{A: "value"},
		APtrNamedStruct:     &NamedStructCO{A: "sdf"},
		EmbeddableStructCO:  EmbeddableStructCO{"value"},
		EmbeddableStructCO2: EmbeddableStructCO2{"value"},
		ATime:               time.Now(),
		ASlice:              []int{1, 2, 3},
		AMap:                map[string]int{"1": 1},
		ABin:                []byte{1, 2, 3},
		ClearOmittedTuple:   ClearOmittedTuple{FieldA: "value"},
		AInt:                42,
		AString:             "value",
		Adur:                time.Second,
		AJSON:               json.Number(`43.0000000000002`),
	}

	// Stage 1a: UnmarshalMsg of the empty encoding over a populated copy.
	viaUnmarshal := populated
	_, err = viaUnmarshal.UnmarshalMsg(emptyMsg)
	fatalIf(err)
	if !reflect.DeepEqual(viaUnmarshal, want) {
		t.Errorf("\n got=%#v\nwant=%#v", viaUnmarshal, want)
	}

	// Stage 1b: same check through the streaming decoder, from a fresh copy.
	viaDecode := populated
	fatalIf(viaDecode.DecodeMsg(msgp.NewReader(bytes.NewReader(emptyMsg))))
	if !reflect.DeepEqual(viaDecode, want) {
		t.Errorf("\n got=%#v\nwant=%#v", viaDecode, want)
	}

	// Stage 2: fields that ARE present on the wire must survive decoding.
	wantJSON, err := json.Marshal(populated)
	fatalIf(err)
	fullMsg, err := populated.MarshalMsg(nil)
	fatalIf(err)

	// Stage 2a: UnmarshalMsg into a zero value.
	var fresh ClearOmitted0
	_, err = fresh.UnmarshalMsg(fullMsg)
	fatalIf(err)
	gotJSON, err := json.Marshal(fresh)
	fatalIf(err)
	if !bytes.Equal(gotJSON, wantJSON) {
		t.Errorf("\n got=%#v\nwant=%#v", string(gotJSON), string(wantJSON))
	}

	// Stage 2b: DecodeMsg into a zero value.
	fresh = ClearOmitted0{}
	fatalIf(fresh.DecodeMsg(msgp.NewReader(bytes.NewReader(fullMsg))))
	gotJSON, err = json.Marshal(fresh)
	fatalIf(err)
	if !bytes.Equal(gotJSON, wantJSON) {
		t.Errorf("\n got=%#v\nwant=%#v", string(gotJSON), string(wantJSON))
	}
	t.Log("OK - got", string(gotJSON))
}
38 changes: 38 additions & 0 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ func (d *decodeGen) structAsMap(s *Struct) {
d.p.declare(sz, u32)
d.assignAndCheck(sz, mapHeader)

oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero")
if !d.ctx.clearOmitted {
oeCount = 0
}
bm := bmask{
bitlen: oeCount,
varname: sz + "Mask",
}
if oeCount > 0 {
// Declare mask
d.p.printf("\n%s", bm.typeDecl())
d.p.printf("\n_ = %s", bm.varname)
}
// Index to field idx of each emitted
oeEmittedIdx := []int{}

d.p.printf("\nfor %s > 0 {\n%s--", sz, sz)
d.assignAndCheck("field", mapKey)
d.p.print("\nswitch msgp.UnsafeString(field) {")
Expand All @@ -123,6 +139,10 @@ func (d *decodeGen) structAsMap(s *Struct) {
}
SetIsAllowNil(fieldElem, anField)
next(d, fieldElem)
if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) {
d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx)))
oeEmittedIdx = append(oeEmittedIdx, i)
}
d.ctx.Pop()
if !d.p.ok() {
return
Expand All @@ -136,6 +156,24 @@ func (d *decodeGen) structAsMap(s *Struct) {

d.p.closeblock() // close switch
d.p.closeblock() // close for loop

if oeCount > 0 {
d.p.printf("\n// Clear omitted fields.\n")
d.p.printf("if %s {\n", bm.notAllSet())
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

d.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx))
fze := fieldElem.ZeroExpr()
if fze != "" {
d.p.printf("%s = %s\n", fieldElem.Varname(), fze)
} else {
d.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName())
}
d.p.printf("}\n")
}
d.p.printf("}")
}
}

func (d *decodeGen) gBase(b *BaseElem) {
Expand Down
11 changes: 11 additions & 0 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,17 @@ func (s *Struct) AnyHasTagPart(pname string) bool {
return false
}

// CountFieldTagPart returns the number of fields in s for which
// HasTagPart(pname) is true.
func (s *Struct) CountFieldTagPart(pname string) int {
	count := 0
	for i := range s.Fields {
		if !s.Fields[i].HasTagPart(pname) {
			continue
		}
		count++
	}
	return count
}

type StructField struct {
FieldTag string // the string inside the `msg:""` tag up to the first comma
FieldTagParts []string // the string inside the `msg:""` tag split by commas
Expand Down
30 changes: 27 additions & 3 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"strings"
)

const (
Expand Down Expand Up @@ -77,6 +78,7 @@ const (
type Printer struct {
gens []generator
CompactFloats bool
ClearOmitted bool
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
Expand Down Expand Up @@ -145,7 +147,7 @@ func (p *Printer) Print(e Elem) error {
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb")
err := g.Execute(e, Context{compFloats: p.CompactFloats})
err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted})
resetIdent("za")

if err != nil {
Expand All @@ -172,8 +174,9 @@ func (c contextVar) Arg() string {
}

type Context struct {
path []contextItem
compFloats bool
path []contextItem
compFloats bool
clearOmitted bool
}

func (c *Context) PushString(s string) {
Expand Down Expand Up @@ -501,3 +504,24 @@ func (b *bmask) setStmt(bitoffset int) string {

return buf.String()
}

// notAllSet returns a Go boolean expression (as source text) that is true
// when at least one of the b.bitlen tracked bits in the mask variable is
// still unset, i.e. when the mask differs from the "all bits set" value.
//
// For masks of up to 64 bits the variable is compared against a single
// integer constant. For wider masks the variable is compared against an array
// literal with one element per 64 bits: every full word is all ones and the
// final partial word, if any, has only its low (bitlen % 64) bits set.
// This assumes the mask is declared as [N]uint64 for bitlen > 64, matching
// the (bitlen+63)/64 element count — confirm against typeDecl.
func (b *bmask) notAllSet() string {
	var buf bytes.Buffer
	buf.Grow(len(b.varname) + 16)
	buf.WriteString(b.varname)

	if b.bitlen > 64 {
		words := make([]string, 0, (b.bitlen+63)/64)
		remain := b.bitlen
		// Consume the bit count one full word at a time so the loop terminates.
		for ; remain >= 64; remain -= 64 {
			words = append(words, "0xffffffffffffffff")
		}
		if remain > 0 {
			// Only the low `remain` bits of the last word are in use.
			words = append(words, fmt.Sprintf("0x%x", uint64(1)<<uint(remain)-1))
		}
		fmt.Fprintf(&buf, " != [%d]uint64{%s}", len(words), strings.Join(words, ", "))
		return buf.String()
	}

	// For bitlen == 64 the shift wraps to 0 and the subtraction yields all ones.
	fmt.Fprintf(&buf, " != 0x%x", uint64(1<<b.bitlen)-1)
	return buf.String()
}
37 changes: 37 additions & 0 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.p.declare(sz, u32)
u.assignAndCheck(sz, mapHeader)

oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero")
if !u.ctx.clearOmitted {
oeCount = 0
}
bm := bmask{
bitlen: oeCount,
varname: sz + "Mask",
}
if oeCount > 0 {
// Declare mask
u.p.printf("\n%s", bm.typeDecl())
u.p.printf("\n_ = %s", bm.varname)
}
// Index to field idx of each emitted
oeEmittedIdx := []int{}

u.p.printf("\nfor %s > 0 {", sz)
u.p.printf("\n%s--; field, bts, err = msgp.ReadMapKeyZC(bts)", sz)
u.p.wrapErrCheck(u.ctx.ArgsStr())
Expand All @@ -122,13 +138,34 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
SetIsAllowNil(fieldElem, anField)
next(u, fieldElem)
u.ctx.Pop()
if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) {
u.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx)))
oeEmittedIdx = append(oeEmittedIdx, i)
}
if anField {
u.p.printf("\n}")
}
}
u.p.print("\ndefault:\nbts, err = msgp.Skip(bts)")
u.p.wrapErrCheck(u.ctx.ArgsStr())
u.p.print("\n}\n}") // close switch and for loop
if oeCount > 0 {
u.p.printf("\n// Clear omitted fields.\n")
u.p.printf("if %s {\n", bm.notAllSet())
for bitIdx, fieldIdx := range oeEmittedIdx {
fieldElem := s.Fields[fieldIdx].FieldElem

u.p.printf("if %s == 0 {\n", bm.readExpr(bitIdx))
fze := fieldElem.ZeroExpr()
if fze != "" {
u.p.printf("%s = %s\n", fieldElem.Varname(), fze)
} else {
u.p.printf("%s = %s{}\n", fieldElem.Varname(), fieldElem.TypeName())
}
u.p.printf("}\n")
}
u.p.printf("}")
}
}

func (u *unmarshalGen) gBase(b *BaseElem) {
Expand Down
Loading

0 comments on commit 6da1c88

Please sign in to comment.