Skip to content

Commit

Permalink
Merge pull request #742 from twmb/lovro
Browse files Browse the repository at this point in the history
pkg/sr: touchups
  • Loading branch information
twmb authored May 26, 2024
2 parents d1514f9 + f8e5acf commit d70b761
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 55 deletions.
4 changes: 2 additions & 2 deletions pkg/sr/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (cl *Client) SubjectsByID(ctx context.Context, id int) ([]string, error) {
return subjects, err
}

// SchemaVersion is a subject version pair.
// SubjectVersion is a subject version pair.
type SubjectVersion struct {
Subject string `json:"subject"`
Version int `json:"version"`
Expand Down Expand Up @@ -602,7 +602,7 @@ type SetCompatibility struct {
OverrideRuleSet *SchemaRuleSet `json:"overrideRuleSet,omitempty"` // Override rule set used for schema registration.
}

// SetCompatibilitysets the compatibility for each requested subject. The
// SetCompatibility sets the compatibility for each requested subject. The
// global compatibility can be set by either using an empty subject or by
// specifying no subjects. If specifying no subjects, this returns one element.
func (cl *Client) SetCompatibility(ctx context.Context, compat SetCompatibility, subjects ...string) []CompatibilityResult {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sr/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Client struct {
}

// NewClient returns a new schema registry client.
func NewClient(opts ...Opt) (*Client, error) {
func NewClient(opts ...ClientOpt) (*Client, error) {
cl := &Client{
urls: []string{"http://localhost:8081"},
httpcl: &http.Client{Timeout: 5 * time.Second},
Expand Down
32 changes: 16 additions & 16 deletions pkg/sr/config.go → pkg/sr/clientopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@ import (
)

type (
// Opt is an option to configure a client.
Opt interface{ apply(*Client) }
opt struct{ fn func(*Client) }
// ClientOpt is an option to configure a client.
ClientOpt interface{ apply(*Client) }
clientOpt struct{ fn func(*Client) }
)

func (o opt) apply(cl *Client) { o.fn(cl) }
func (o clientOpt) apply(cl *Client) { o.fn(cl) }

// HTTPClient sets the http client that the schema registry client uses,
// overriding the default client that speaks plaintext with a timeout of 5s.
func HTTPClient(httpcl *http.Client) Opt {
return opt{func(cl *Client) { cl.httpcl = httpcl }}
func HTTPClient(httpcl *http.Client) ClientOpt {
return clientOpt{func(cl *Client) { cl.httpcl = httpcl }}
}

// UserAgent sets the User-Agent to use in requests, overriding the default
// "franz-go".
func UserAgent(ua string) Opt {
return opt{func(cl *Client) { cl.ua = ua }}
func UserAgent(ua string) ClientOpt {
return clientOpt{func(cl *Client) { cl.ua = ua }}
}

// URLs sets the URLs that the client speaks to, overriding the default
// http://localhost:8081. This option automatically prefixes any URL that is
// missing an http:// or https:// prefix with http://.
func URLs(urls ...string) Opt {
return opt{func(cl *Client) {
func URLs(urls ...string) ClientOpt {
return clientOpt{func(cl *Client) {
for i, u := range urls {
if strings.HasPrefix(u, "http://") || strings.HasPrefix(u, "https://") {
continue
Expand All @@ -45,8 +45,8 @@ func URLs(urls ...string) Opt {
}

// DialTLSConfig sets a tls.Config to use in the default http client.
func DialTLSConfig(c *tls.Config) Opt {
return opt{func(cl *Client) {
func DialTLSConfig(c *tls.Config) ClientOpt {
return clientOpt{func(cl *Client) {
cl.httpcl = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
Expand All @@ -68,8 +68,8 @@ func DialTLSConfig(c *tls.Config) Opt {
}

// BasicAuth sets basic authorization to use for every request.
func BasicAuth(user, pass string) Opt {
return opt{func(cl *Client) {
func BasicAuth(user, pass string) ClientOpt {
return clientOpt{func(cl *Client) {
cl.basicAuth = &struct {
user string
pass string
Expand All @@ -78,8 +78,8 @@ func BasicAuth(user, pass string) Opt {
}

// DefaultParams sets default parameters to apply to every request.
func DefaultParams(ps ...Param) Opt {
return opt{func(cl *Client) {
func DefaultParams(ps ...Param) ClientOpt {
return clientOpt{func(cl *Client) {
cl.defParams = mergeParams(ps...)
}}
}
2 changes: 1 addition & 1 deletion pkg/sr/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (k *SchemaRuleKind) UnmarshalText(text []byte) error {
return nil
}

// Mode specifies a schema rule's mode.
// SchemaRuleMode specifies a schema rule's mode.
//
// Migration rules can be specified for an UPGRADE, DOWNGRADE, or both
// (UPDOWN). Migration rules are used during complex schema evolution.
Expand Down
121 changes: 88 additions & 33 deletions pkg/sr/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,53 @@ var (
)

type (
// SerdeOpt is an option to configure a Serde.
SerdeOpt interface{ apply(*tserde) }
serdeOpt struct{ fn func(*tserde) }
// EncodingOpt is an option to configure the behavior of Serde.Encode and
// Serde.Decode.
EncodingOpt interface {
serdeOrEncodingOpt()
apply(*tserde)
}
encodingOpt struct{ fn func(*tserde) }

// SerdeOpt is an option to configure Serde.
SerdeOpt interface {
serdeOrEncodingOpt()
apply(*Serde)
}
serdeOpt struct{ fn func(serde *Serde) }

// SerdeOrEncodingOpt is either a SerdeOpt or EncodingOpt.
SerdeOrEncodingOpt interface {
serdeOrEncodingOpt()
}
)

func (o serdeOpt) apply(t *tserde) { o.fn(t) }
func (o serdeOpt) serdeOrEncodingOpt() { /* satisfy interface */ }
func (o serdeOpt) apply(s *Serde) { o.fn(s) }

func (o encodingOpt) serdeOrEncodingOpt() { /* satisfy interface */ }
func (o encodingOpt) apply(t *tserde) { o.fn(t) }

// EncodeFn allows Serde to encode a value.
func EncodeFn(fn func(any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.encode = fn }}
func EncodeFn(fn func(any) ([]byte, error)) EncodingOpt {
return encodingOpt{func(t *tserde) { t.encode = fn }}
}

// AppendEncodeFn allows Serde to encode a value to an existing slice. This
// can be more efficient than EncodeFn; this function is used if it exists.
func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.appendEncode = fn }}
func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) EncodingOpt {
return encodingOpt{func(t *tserde) { t.appendEncode = fn }}
}

// DecodeFn allows Serde to decode into a value.
func DecodeFn(fn func([]byte, any) error) SerdeOpt {
return serdeOpt{func(t *tserde) { t.decode = fn }}
func DecodeFn(fn func([]byte, any) error) EncodingOpt {
return encodingOpt{func(t *tserde) { t.decode = fn }}
}

// GenerateFn returns a new(Value) that can be decoded into. This function can
// be used to control the instantiation of a new type for DecodeNew.
func GenerateFn(fn func() any) SerdeOpt {
return serdeOpt{func(t *tserde) { t.gen = fn }}
func GenerateFn(fn func() any) EncodingOpt {
return encodingOpt{func(t *tserde) { t.gen = fn }}
}

// Index attaches a message index to a value. A single schema ID can be
Expand All @@ -62,8 +82,13 @@ func GenerateFn(fn func() any) SerdeOpt {
// For more information, see where `message-indexes` are described in:
//
// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#wire-format
func Index(index ...int) SerdeOpt {
return serdeOpt{func(t *tserde) { t.index = index }}
func Index(index ...int) EncodingOpt {
return encodingOpt{func(t *tserde) { t.index = index }}
}

// Header defines the SerdeHeader used to encode and decode the message header.
func Header(header SerdeHeader) SerdeOpt {
return serdeOpt{func(s *Serde) { s.h = header }}
}

type tserde struct {
Expand Down Expand Up @@ -96,7 +121,7 @@ type Serde struct {
types atomic.Value // map[reflect.Type]tserde
mu sync.Mutex

defaults []SerdeOpt
defaults []EncodingOpt
h SerdeHeader
}

Expand All @@ -105,6 +130,25 @@ var (
noTypes = make(map[reflect.Type]tserde)
)

// NewSerde returns a new Serde using the supplied default options, which are
// applied to every registered type. These options are always applied first, so
// you can override them as necessary when registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func NewSerde(opts ...SerdeOrEncodingOpt) *Serde {
var s Serde
for _, opt := range opts {
switch opt := opt.(type) {
case SerdeOpt:
opt.apply(&s)
case EncodingOpt:
s.defaults = append(s.defaults, opt)
}
}
return &s
}

func (s *Serde) loadIDs() map[int]tserde {
ids := s.ids.Load()
if ids == nil {
Expand All @@ -121,16 +165,6 @@ func (s *Serde) loadTypes() map[reflect.Type]tserde {
return types.(map[reflect.Type]tserde)
}

// SetDefaults sets default options to apply to every registered type. These
// options are always applied first, so you can override them as necessary when
// registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func (s *Serde) SetDefaults(opts ...SerdeOpt) {
s.defaults = opts
}

// DecodeID decodes an ID from b, returning the ID and the remaining bytes,
// or an error.
func (s *Serde) DecodeID(b []byte) (id int, out []byte, err error) {
Expand All @@ -154,7 +188,7 @@ func (s *Serde) header() SerdeHeader {
// Register registers a schema ID and the value it corresponds to, as well as
// the encoding or decoding functions. You need to register functions depending
// on whether you are only encoding, only decoding, or both.
func (s *Serde) Register(id int, v any, opts ...SerdeOpt) {
func (s *Serde) Register(id int, v any, opts ...EncodingOpt) {
var t tserde
for _, opt := range s.defaults {
opt.apply(&t)
Expand Down Expand Up @@ -258,20 +292,18 @@ func (s *Serde) Encode(v any) ([]byte, error) {
return s.AppendEncode(nil, v)
}

// AppendEncode appends an encoded value to b according to the schema registry
// wire format and returns it. If EncodeFn was not used, this returns
// ErrNotRegistered.
// AppendEncode encodes a value and prepends the header according to the
// configured SerdeHeader, appends it to b and returns b. If EncodeFn was not
// registered, this returns ErrNotRegistered.
func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) {
t, ok := s.loadTypes()[reflect.TypeOf(v)]
if !ok || (t.encode == nil && t.appendEncode == nil) {
return b, ErrNotRegistered
}

b, err := s.header().AppendEncode(b, int(t.id), t.index)
if err != nil {
return nil, err
}

if t.appendEncode != nil {
return t.appendEncode(b, v)
}
Expand Down Expand Up @@ -328,8 +360,10 @@ func (s *Serde) DecodeNew(b []byte) (any, error) {
var v any
if t.gen != nil {
v = t.gen()
} else {
} else if t.typeof != nil {
v = reflect.New(t.typeof).Interface()
} else {
return nil, ErrNotRegistered
}
return v, t.decode(b, v)
}
Expand All @@ -339,7 +373,6 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
if err != nil {
return nil, tserde{}, err
}

t := s.loadIDs()[id]
if len(t.subindex) > 0 {
var index []int
Expand All @@ -360,6 +393,28 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
return b, t, nil
}

// Encode encodes a value and prepends the header. If the encoding function
// fails, this returns an error.
func Encode(v any, id int, index []int, h SerdeHeader, enc func(any) ([]byte, error)) ([]byte, error) {
return AppendEncode(nil, v, id, index, h, func(b []byte, val any) ([]byte, error) {
encoded, err := enc(val)
if err != nil {
return nil, err
}
return append(b, encoded...), nil
})
}

// AppendEncode encodes a value and prepends the header, appends it to b and
// returns b. If the encoding function fails, this returns an error.
func AppendEncode(b []byte, v any, id int, index []int, h SerdeHeader, enc func([]byte, any) ([]byte, error)) ([]byte, error) {
b, err := h.AppendEncode(b, id, index)
if err != nil {
return nil, err
}
return enc(b, v)
}

// SerdeHeader encodes and decodes a message header.
type SerdeHeader interface {
// AppendEncode encodes a schema ID and optional index to b, returning the
Expand Down
31 changes: 29 additions & 2 deletions pkg/sr/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ func TestSerde(t *testing.T) {
}
)

var serde Serde
serde.SetDefaults(
serde := NewSerde(
EncodeFn(json.Marshal),
DecodeFn(json.Unmarshal),
)
Expand Down Expand Up @@ -113,6 +112,19 @@ func TestSerde(t *testing.T) {
t.Errorf("#%d got MustAppendEncode(%v) != Encode(foo%v)", i, b2, b)
}

bIndented, err := Encode(test.enc, 100, []int{0}, serde.header(), func(v any) ([]byte, error) {
return json.MarshalIndent(v, "", " ")
})
if err != nil {
t.Errorf("#%d Encode[ID=100]: got err? %v, exp err? %v", i, gotErr, test.expErr)
continue
}
if i := bytes.IndexByte(bIndented, '{'); !bytes.Equal(bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[:i], []byte{0, 0, 0, 0, 100, 0})
} else if expIndented := extractIndentedJSON(b); !bytes.Equal(bIndented[i:], expIndented) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[i:], expIndented)
}

v, err := serde.DecodeNew(b)
if err != nil {
t.Errorf("#%d DecodeNew: got unexpected err %v", i, err)
Expand All @@ -126,6 +138,7 @@ func TestSerde(t *testing.T) {
}
if !reflect.DeepEqual(v, exp) {
t.Errorf("#%d round trip: got %v != exp %v", i, v, exp)
continue
}
}

Expand All @@ -141,6 +154,20 @@ func TestSerde(t *testing.T) {
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 99}); err != ErrNotRegistered {
t.Errorf("got %v != exp ErrNotRegistered", err)
}
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 100, 0}); err != ErrNotRegistered {
// schema is registered but type is unknown
t.Errorf("got %v != exp ErrNotRegistered", err)
}
}

func extractIndentedJSON(in []byte) []byte {
i := bytes.IndexByte(in, '{') // skip header
var out bytes.Buffer
err := json.Indent(&out, in[i:], "", " ")
if err != nil {
panic(err)
}
return out.Bytes()
}

func TestConfluentHeader(t *testing.T) {
Expand Down

0 comments on commit d70b761

Please sign in to comment.