Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serde options refactor #506

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/sr/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (cl *Client) SubjectsByID(ctx context.Context, id int) ([]string, error) {
return subjects, err
}

// SchemaVersion is a subject version pair.
// SubjectVersion is a subject version pair.
type SubjectVersion struct {
Subject string `json:"subject"`
Version int `json:"version"`
Expand Down Expand Up @@ -602,7 +602,7 @@ type SetCompatibility struct {
OverrideRuleSet *SchemaRuleSet `json:"overrideRuleSet,omitempty"` // Override rule set used for schema registration.
}

// SetCompatibilitysets the compatibility for each requested subject. The
// SetCompatibility sets the compatibility for each requested subject. The
// global compatibility can be set by either using an empty subject or by
// specifying no subjects. If specifying no subjects, this returns one element.
func (cl *Client) SetCompatibility(ctx context.Context, compat SetCompatibility, subjects ...string) []CompatibilityResult {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sr/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Client struct {
}

// NewClient returns a new schema registry client.
func NewClient(opts ...Opt) (*Client, error) {
func NewClient(opts ...ClientOpt) (*Client, error) {
cl := &Client{
urls: []string{"http://localhost:8081"},
httpcl: &http.Client{Timeout: 5 * time.Second},
Expand Down
32 changes: 16 additions & 16 deletions pkg/sr/config.go → pkg/sr/clientopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@ import (
)

type (
// Opt is an option to configure a client.
Opt interface{ apply(*Client) }
opt struct{ fn func(*Client) }
// ClientOpt is an option to configure a client.
ClientOpt interface{ apply(*Client) }
clientOpt struct{ fn func(*Client) }
)

func (o opt) apply(cl *Client) { o.fn(cl) }
func (o clientOpt) apply(cl *Client) { o.fn(cl) }

// HTTPClient sets the http client that the schema registry client uses,
// overriding the default client that speaks plaintext with a timeout of 5s.
func HTTPClient(httpcl *http.Client) Opt {
return opt{func(cl *Client) { cl.httpcl = httpcl }}
func HTTPClient(httpcl *http.Client) ClientOpt {
return clientOpt{func(cl *Client) { cl.httpcl = httpcl }}
}

// UserAgent sets the User-Agent to use in requests, overriding the default
// "franz-go".
func UserAgent(ua string) Opt {
return opt{func(cl *Client) { cl.ua = ua }}
func UserAgent(ua string) ClientOpt {
return clientOpt{func(cl *Client) { cl.ua = ua }}
}

// URLs sets the URLs that the client speaks to, overriding the default
// http://localhost:8081. This option automatically prefixes any URL that is
// missing an http:// or https:// prefix with http://.
func URLs(urls ...string) Opt {
return opt{func(cl *Client) {
func URLs(urls ...string) ClientOpt {
return clientOpt{func(cl *Client) {
for i, u := range urls {
if strings.HasPrefix(u, "http://") || strings.HasPrefix(u, "https://") {
continue
Expand All @@ -45,8 +45,8 @@ func URLs(urls ...string) Opt {
}

// DialTLSConfig sets a tls.Config to use in the default http client.
func DialTLSConfig(c *tls.Config) Opt {
return opt{func(cl *Client) {
func DialTLSConfig(c *tls.Config) ClientOpt {
return clientOpt{func(cl *Client) {
cl.httpcl = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
Expand All @@ -68,8 +68,8 @@ func DialTLSConfig(c *tls.Config) Opt {
}

// BasicAuth sets basic authorization to use for every request.
func BasicAuth(user, pass string) Opt {
return opt{func(cl *Client) {
func BasicAuth(user, pass string) ClientOpt {
return clientOpt{func(cl *Client) {
cl.basicAuth = &struct {
user string
pass string
Expand All @@ -78,8 +78,8 @@ func BasicAuth(user, pass string) Opt {
}

// DefaultParams sets default parameters to apply to every request.
func DefaultParams(ps ...Param) Opt {
return opt{func(cl *Client) {
func DefaultParams(ps ...Param) ClientOpt {
return clientOpt{func(cl *Client) {
cl.defParams = mergeParams(ps...)
}}
}
2 changes: 1 addition & 1 deletion pkg/sr/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (k *SchemaRuleKind) UnmarshalText(text []byte) error {
return nil
}

// Mode specifies a schema rule's mode.
// SchemaRuleMode specifies a schema rule's mode.
//
// Migration rules can be specified for an UPGRADE, DOWNGRADE, or both
// (UPDOWN). Migration rules are used during complex schema evolution.
Expand Down
148 changes: 70 additions & 78 deletions pkg/sr/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,6 @@ var (
ErrBadHeader = errors.New("5 byte header for value is missing or does not have 0 magic byte")
)

type (
// SerdeOpt is an option to configure a Serde.
SerdeOpt interface{ apply(*tserde) }
serdeOpt struct{ fn func(*tserde) }
)

func (o serdeOpt) apply(t *tserde) { o.fn(t) }

// EncodeFn allows Serde to encode a value.
func EncodeFn(fn func(any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.encode = fn }}
}

// AppendEncodeFn allows Serde to encode a value to an existing slice. This
// can be more efficient than EncodeFn; this function is used if it exists.
func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.appendEncode = fn }}
}

// DecodeFn allows Serde to decode into a value.
func DecodeFn(fn func([]byte, any) error) SerdeOpt {
return serdeOpt{func(t *tserde) { t.decode = fn }}
}

// GenerateFn returns a new(Value) that can be decoded into. This function can
// be used to control the instantiation of a new type for DecodeNew.
func GenerateFn(fn func() any) SerdeOpt {
return serdeOpt{func(t *tserde) { t.gen = fn }}
}

// Index attaches a message index to a value. A single schema ID can be
// registered multiple times with different indices.
//
// This option supports schemas that encode many different values from the same
// schema (namely, protobuf). The index into the schema to encode a
// particular message is specified with `index`.
//
// NOTE: this option must be used for protobuf schemas.
//
// For more information, see where `message-indexes` are described in:
//
// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#wire-format
func Index(index ...int) SerdeOpt {
return serdeOpt{func(t *tserde) { t.index = index }}
}

type tserde struct {
id uint32
exists bool
Expand Down Expand Up @@ -96,7 +50,7 @@ type Serde struct {
types atomic.Value // map[reflect.Type]tserde
mu sync.Mutex

defaults []SerdeOpt
defaults []EncodingOpt
h SerdeHeader
}

Expand All @@ -105,6 +59,25 @@ var (
noTypes = make(map[reflect.Type]tserde)
)

// NewSerde returns a new Serde using the supplied default options, which are
// applied to every registered type. These options are always applied first, so
// you can override them as necessary when registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func NewSerde(opts ...SerdeOrEncodingOpt) *Serde {
var s Serde
for _, opt := range opts {
switch opt := opt.(type) {
case SerdeOpt:
opt.apply(&s)
case EncodingOpt:
s.defaults = append(s.defaults, opt)
}
}
return &s
}

func (s *Serde) loadIDs() map[int]tserde {
ids := s.ids.Load()
if ids == nil {
Expand All @@ -121,16 +94,6 @@ func (s *Serde) loadTypes() map[reflect.Type]tserde {
return types.(map[reflect.Type]tserde)
}

// SetDefaults sets default options to apply to every registered type. These
// options are always applied first, so you can override them as necessary when
// registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func (s *Serde) SetDefaults(opts ...SerdeOpt) {
s.defaults = opts
}

// DecodeID decodes an ID from b, returning the ID and the remaining bytes,
// or an error.
func (s *Serde) DecodeID(b []byte) (id int, out []byte, err error) {
Expand All @@ -154,7 +117,7 @@ func (s *Serde) header() SerdeHeader {
// Register registers a schema ID and the value it corresponds to, as well as
// the encoding or decoding functions. You need to register functions depending
// on whether you are only encoding, only decoding, or both.
func (s *Serde) Register(id int, v any, opts ...SerdeOpt) {
func (s *Serde) Register(id int, v any, opts ...EncodingOpt) {
var t tserde
for _, opt := range s.defaults {
opt.apply(&t)
Expand Down Expand Up @@ -258,28 +221,31 @@ func (s *Serde) Encode(v any) ([]byte, error) {
return s.AppendEncode(nil, v)
}

// AppendEncode appends an encoded value to b according to the schema registry
// wire format and returns it. If EncodeFn was not used, this returns
// ErrNotRegistered.
// AppendEncode encodes a value and prepends the header according to the
// configured SerdeHeader, appends it to b and returns b. If EncodeFn was not
// registered, this returns ErrNotRegistered.
func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) {
t, ok := s.loadTypes()[reflect.TypeOf(v)]
if !ok || (t.encode == nil && t.appendEncode == nil) {
return b, ErrNotRegistered
}
// Load tserde based on the registered type.
t := s.loadTypes()[reflect.TypeOf(v)]

b, err := s.header().AppendEncode(b, int(t.id), t.index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this was lost (and not added in one of the new top level functions)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, it's used in the new call to AppendEncode below.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might rather have a little bit of duplication, not sure yet.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep leaning duplication; I rebased on master and squashed all of your commits and did a few touchups in #742.

if err != nil {
return nil, err
// Check if we loaded a valid tserde.
if !t.exists || (t.encode == nil && t.appendEncode == nil) {
return nil, ErrNotRegistered
}

if t.appendEncode != nil {
return t.appendEncode(b, v)
}
encoded, err := t.encode(v)
if err != nil {
return nil, err
appendEncode := t.appendEncode
if appendEncode == nil {
// Fallback to t.encode.
appendEncode = func(b []byte, v any) ([]byte, error) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this causes an extra alloc in the encode path.

encoded, err := t.encode(v)
if err != nil {
return nil, err
}
return append(b, encoded...), nil
}
}
return append(b, encoded...), nil

return AppendEncode(b, v, int(t.id), t.index, s.header(), appendEncode)
}

// MustEncode returns the value of Encode, panicking on error. This is a
Expand Down Expand Up @@ -328,8 +294,10 @@ func (s *Serde) DecodeNew(b []byte) (any, error) {
var v any
if t.gen != nil {
v = t.gen()
} else {
} else if t.typeof != nil {
v = reflect.New(t.typeof).Interface()
} else {
return nil, ErrNotRegistered
}
return v, t.decode(b, v)
}
Expand All @@ -339,7 +307,6 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
if err != nil {
return nil, tserde{}, err
}

t := s.loadIDs()[id]
if len(t.subindex) > 0 {
var index []int
Expand All @@ -354,12 +321,37 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
t = t.subindex[idx]
}
}
if !t.exists {

// Check if we loaded a valid tserde.
if !t.exists || t.decode == nil {
return nil, tserde{}, ErrNotRegistered
}

return b, t, nil
}

// Encode encodes a value and prepends the header. If the encoding function
// fails, this returns an error.
func Encode(v any, id int, index []int, h SerdeHeader, enc func(any) ([]byte, error)) ([]byte, error) {
return AppendEncode(nil, v, id, index, h, func(b []byte, val any) ([]byte, error) {
encoded, err := enc(val)
if err != nil {
return nil, err
}
return append(b, encoded...), nil
})
}

// AppendEncode encodes a value and prepends the header, appends it to b and
// returns b. If the encoding function fails, this returns an error.
func AppendEncode(b []byte, v any, id int, index []int, h SerdeHeader, enc func([]byte, any) ([]byte, error)) ([]byte, error) {
b, err := h.AppendEncode(b, id, index)
if err != nil {
return nil, err
}
return enc(b, v)
}

// SerdeHeader encodes and decodes a message header.
type SerdeHeader interface {
// AppendEncode encodes a schema ID and optional index to b, returning the
Expand Down
31 changes: 29 additions & 2 deletions pkg/sr/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ func TestSerde(t *testing.T) {
}
)

var serde Serde
serde.SetDefaults(
serde := NewSerde(
EncodeFn(json.Marshal),
DecodeFn(json.Unmarshal),
)
Expand Down Expand Up @@ -113,6 +112,19 @@ func TestSerde(t *testing.T) {
t.Errorf("#%d got MustAppendEncode(%v) != Encode(foo%v)", i, b2, b)
}

bIndented, err := Encode(test.enc, 100, []int{0}, serde.header(), func(v any) ([]byte, error) {
return json.MarshalIndent(v, "", " ")
})
if err != nil {
t.Errorf("#%d Encode[ID=100]: got err? %v, exp err? %v", i, gotErr, test.expErr)
continue
}
if i := bytes.IndexByte(bIndented, '{'); !bytes.Equal(bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[:i], []byte{0, 0, 0, 0, 100, 0})
} else if expIndented := extractIndentedJSON(b); !bytes.Equal(bIndented[i:], expIndented) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[i:], expIndented)
}

v, err := serde.DecodeNew(b)
if err != nil {
t.Errorf("#%d DecodeNew: got unexpected err %v", i, err)
Expand All @@ -126,6 +138,7 @@ func TestSerde(t *testing.T) {
}
if !reflect.DeepEqual(v, exp) {
t.Errorf("#%d round trip: got %v != exp %v", i, v, exp)
continue
}
}

Expand All @@ -141,6 +154,20 @@ func TestSerde(t *testing.T) {
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 99}); err != ErrNotRegistered {
t.Errorf("got %v != exp ErrNotRegistered", err)
}
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 100, 0}); err != ErrNotRegistered {
// schema is registered but type is unknown
t.Errorf("got %v != exp ErrNotRegistered", err)
}
}

func extractIndentedJSON(in []byte) []byte {
i := bytes.IndexByte(in, '{') // skip header
var out bytes.Buffer
err := json.Indent(&out, in[i:], "", " ")
if err != nil {
panic(err)
}
return out.Bytes()
}

func TestConfluentHeader(t *testing.T) {
Expand Down
Loading
Loading