Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pkg/sr: touchups #742

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/sr/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (cl *Client) SubjectsByID(ctx context.Context, id int) ([]string, error) {
return subjects, err
}

// SchemaVersion is a subject version pair.
// SubjectVersion is a subject version pair.
type SubjectVersion struct {
Subject string `json:"subject"`
Version int `json:"version"`
Expand Down Expand Up @@ -602,7 +602,7 @@ type SetCompatibility struct {
OverrideRuleSet *SchemaRuleSet `json:"overrideRuleSet,omitempty"` // Override rule set used for schema registration.
}

// SetCompatibilitysets the compatibility for each requested subject. The
// SetCompatibility sets the compatibility for each requested subject. The
// global compatibility can be set by either using an empty subject or by
// specifying no subjects. If specifying no subjects, this returns one element.
func (cl *Client) SetCompatibility(ctx context.Context, compat SetCompatibility, subjects ...string) []CompatibilityResult {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sr/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Client struct {
}

// NewClient returns a new schema registry client.
func NewClient(opts ...Opt) (*Client, error) {
func NewClient(opts ...ClientOpt) (*Client, error) {
cl := &Client{
urls: []string{"http://localhost:8081"},
httpcl: &http.Client{Timeout: 5 * time.Second},
Expand Down
32 changes: 16 additions & 16 deletions pkg/sr/config.go → pkg/sr/clientopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@ import (
)

type (
// Opt is an option to configure a client.
Opt interface{ apply(*Client) }
opt struct{ fn func(*Client) }
// ClientOpt is an option to configure a client.
ClientOpt interface{ apply(*Client) }
clientOpt struct{ fn func(*Client) }
)

func (o opt) apply(cl *Client) { o.fn(cl) }
func (o clientOpt) apply(cl *Client) { o.fn(cl) }

// HTTPClient sets the http client that the schema registry client uses,
// overriding the default client that speaks plaintext with a timeout of 5s.
func HTTPClient(httpcl *http.Client) Opt {
return opt{func(cl *Client) { cl.httpcl = httpcl }}
func HTTPClient(httpcl *http.Client) ClientOpt {
return clientOpt{func(cl *Client) { cl.httpcl = httpcl }}
}

// UserAgent sets the User-Agent to use in requests, overriding the default
// "franz-go".
func UserAgent(ua string) Opt {
return opt{func(cl *Client) { cl.ua = ua }}
func UserAgent(ua string) ClientOpt {
return clientOpt{func(cl *Client) { cl.ua = ua }}
}

// URLs sets the URLs that the client speaks to, overriding the default
// http://localhost:8081. This option automatically prefixes any URL that is
// missing an http:// or https:// prefix with http://.
func URLs(urls ...string) Opt {
return opt{func(cl *Client) {
func URLs(urls ...string) ClientOpt {
return clientOpt{func(cl *Client) {
for i, u := range urls {
if strings.HasPrefix(u, "http://") || strings.HasPrefix(u, "https://") {
continue
Expand All @@ -45,8 +45,8 @@ func URLs(urls ...string) Opt {
}

// DialTLSConfig sets a tls.Config to use in the default http client.
func DialTLSConfig(c *tls.Config) Opt {
return opt{func(cl *Client) {
func DialTLSConfig(c *tls.Config) ClientOpt {
return clientOpt{func(cl *Client) {
cl.httpcl = &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
Expand All @@ -68,8 +68,8 @@ func DialTLSConfig(c *tls.Config) Opt {
}

// BasicAuth sets basic authorization to use for every request.
func BasicAuth(user, pass string) Opt {
return opt{func(cl *Client) {
func BasicAuth(user, pass string) ClientOpt {
return clientOpt{func(cl *Client) {
cl.basicAuth = &struct {
user string
pass string
Expand All @@ -78,8 +78,8 @@ func BasicAuth(user, pass string) Opt {
}

// DefaultParams sets default parameters to apply to every request.
func DefaultParams(ps ...Param) Opt {
return opt{func(cl *Client) {
func DefaultParams(ps ...Param) ClientOpt {
return clientOpt{func(cl *Client) {
cl.defParams = mergeParams(ps...)
}}
}
2 changes: 1 addition & 1 deletion pkg/sr/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (k *SchemaRuleKind) UnmarshalText(text []byte) error {
return nil
}

// Mode specifies a schema rule's mode.
// SchemaRuleMode specifies a schema rule's mode.
//
// Migration rules can be specified for an UPGRADE, DOWNGRADE, or both
// (UPDOWN). Migration rules are used during complex schema evolution.
Expand Down
121 changes: 88 additions & 33 deletions pkg/sr/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,53 @@ var (
)

type (
// SerdeOpt is an option to configure a Serde.
SerdeOpt interface{ apply(*tserde) }
serdeOpt struct{ fn func(*tserde) }
// EncodingOpt is an option to configure the behavior of Serde.Encode and
// Serde.Decode.
EncodingOpt interface {
serdeOrEncodingOpt()
apply(*tserde)
}
encodingOpt struct{ fn func(*tserde) }

// SerdeOpt is an option to configure Serde.
SerdeOpt interface {
serdeOrEncodingOpt()
apply(*Serde)
}
serdeOpt struct{ fn func(serde *Serde) }

// SerdeOrEncodingOpt is either a SerdeOpt or EncodingOpt.
SerdeOrEncodingOpt interface {
serdeOrEncodingOpt()
}
)

func (o serdeOpt) apply(t *tserde) { o.fn(t) }
func (o serdeOpt) serdeOrEncodingOpt() { /* satisfy interface */ }
func (o serdeOpt) apply(s *Serde) { o.fn(s) }

func (o encodingOpt) serdeOrEncodingOpt() { /* satisfy interface */ }
func (o encodingOpt) apply(t *tserde) { o.fn(t) }

// EncodeFn allows Serde to encode a value.
func EncodeFn(fn func(any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.encode = fn }}
func EncodeFn(fn func(any) ([]byte, error)) EncodingOpt {
return encodingOpt{func(t *tserde) { t.encode = fn }}
}

// AppendEncodeFn allows Serde to encode a value to an existing slice. This
// can be more efficient than EncodeFn; this function is used if it exists.
func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) SerdeOpt {
return serdeOpt{func(t *tserde) { t.appendEncode = fn }}
func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) EncodingOpt {
return encodingOpt{func(t *tserde) { t.appendEncode = fn }}
}

// DecodeFn allows Serde to decode into a value.
func DecodeFn(fn func([]byte, any) error) SerdeOpt {
return serdeOpt{func(t *tserde) { t.decode = fn }}
func DecodeFn(fn func([]byte, any) error) EncodingOpt {
return encodingOpt{func(t *tserde) { t.decode = fn }}
}

// GenerateFn returns a new(Value) that can be decoded into. This function can
// be used to control the instantiation of a new type for DecodeNew.
func GenerateFn(fn func() any) SerdeOpt {
return serdeOpt{func(t *tserde) { t.gen = fn }}
func GenerateFn(fn func() any) EncodingOpt {
return encodingOpt{func(t *tserde) { t.gen = fn }}
}

// Index attaches a message index to a value. A single schema ID can be
Expand All @@ -62,8 +82,13 @@ func GenerateFn(fn func() any) SerdeOpt {
// For more information, see where `message-indexes` are described in:
//
// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#wire-format
func Index(index ...int) SerdeOpt {
return serdeOpt{func(t *tserde) { t.index = index }}
func Index(index ...int) EncodingOpt {
return encodingOpt{func(t *tserde) { t.index = index }}
}

// Header defines the SerdeHeader used to encode and decode the message header.
func Header(header SerdeHeader) SerdeOpt {
return serdeOpt{func(s *Serde) { s.h = header }}
}

type tserde struct {
Expand Down Expand Up @@ -96,7 +121,7 @@ type Serde struct {
types atomic.Value // map[reflect.Type]tserde
mu sync.Mutex

defaults []SerdeOpt
defaults []EncodingOpt
h SerdeHeader
}

Expand All @@ -105,6 +130,25 @@ var (
noTypes = make(map[reflect.Type]tserde)
)

// NewSerde returns a new Serde using the supplied default options, which are
// applied to every registered type. These options are always applied first, so
// you can override them as necessary when registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func NewSerde(opts ...SerdeOrEncodingOpt) *Serde {
var s Serde
for _, opt := range opts {
switch opt := opt.(type) {
case SerdeOpt:
opt.apply(&s)
case EncodingOpt:
s.defaults = append(s.defaults, opt)
}
}
return &s
}

func (s *Serde) loadIDs() map[int]tserde {
ids := s.ids.Load()
if ids == nil {
Expand All @@ -121,16 +165,6 @@ func (s *Serde) loadTypes() map[reflect.Type]tserde {
return types.(map[reflect.Type]tserde)
}

// SetDefaults sets default options to apply to every registered type. These
// options are always applied first, so you can override them as necessary when
// registering.
//
// This can be useful if you always want to use the same encoding or decoding
// functions.
func (s *Serde) SetDefaults(opts ...SerdeOpt) {
s.defaults = opts
}

// DecodeID decodes an ID from b, returning the ID and the remaining bytes,
// or an error.
func (s *Serde) DecodeID(b []byte) (id int, out []byte, err error) {
Expand All @@ -154,7 +188,7 @@ func (s *Serde) header() SerdeHeader {
// Register registers a schema ID and the value it corresponds to, as well as
// the encoding or decoding functions. You need to register functions depending
// on whether you are only encoding, only decoding, or both.
func (s *Serde) Register(id int, v any, opts ...SerdeOpt) {
func (s *Serde) Register(id int, v any, opts ...EncodingOpt) {
var t tserde
for _, opt := range s.defaults {
opt.apply(&t)
Expand Down Expand Up @@ -258,20 +292,18 @@ func (s *Serde) Encode(v any) ([]byte, error) {
return s.AppendEncode(nil, v)
}

// AppendEncode appends an encoded value to b according to the schema registry
// wire format and returns it. If EncodeFn was not used, this returns
// ErrNotRegistered.
// AppendEncode encodes a value and prepends the header according to the
// configured SerdeHeader, appends it to b and returns b. If EncodeFn was not
// registered, this returns ErrNotRegistered.
func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) {
t, ok := s.loadTypes()[reflect.TypeOf(v)]
if !ok || (t.encode == nil && t.appendEncode == nil) {
return b, ErrNotRegistered
}

b, err := s.header().AppendEncode(b, int(t.id), t.index)
if err != nil {
return nil, err
}

if t.appendEncode != nil {
return t.appendEncode(b, v)
}
Expand Down Expand Up @@ -328,8 +360,10 @@ func (s *Serde) DecodeNew(b []byte) (any, error) {
var v any
if t.gen != nil {
v = t.gen()
} else {
} else if t.typeof != nil {
v = reflect.New(t.typeof).Interface()
} else {
return nil, ErrNotRegistered
}
return v, t.decode(b, v)
}
Expand All @@ -339,7 +373,6 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
if err != nil {
return nil, tserde{}, err
}

t := s.loadIDs()[id]
if len(t.subindex) > 0 {
var index []int
Expand All @@ -360,6 +393,28 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
return b, t, nil
}

// Encode encodes a value and prepends the header. If the encoding function
// fails, this returns an error.
func Encode(v any, id int, index []int, h SerdeHeader, enc func(any) ([]byte, error)) ([]byte, error) {
return AppendEncode(nil, v, id, index, h, func(b []byte, val any) ([]byte, error) {
encoded, err := enc(val)
if err != nil {
return nil, err
}
return append(b, encoded...), nil
})
}

// AppendEncode encodes a value and prepends the header, appends it to b and
// returns b. If the encoding function fails, this returns an error.
func AppendEncode(b []byte, v any, id int, index []int, h SerdeHeader, enc func([]byte, any) ([]byte, error)) ([]byte, error) {
b, err := h.AppendEncode(b, id, index)
if err != nil {
return nil, err
}
return enc(b, v)
}

// SerdeHeader encodes and decodes a message header.
type SerdeHeader interface {
// AppendEncode encodes a schema ID and optional index to b, returning the
Expand Down
31 changes: 29 additions & 2 deletions pkg/sr/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ func TestSerde(t *testing.T) {
}
)

var serde Serde
serde.SetDefaults(
serde := NewSerde(
EncodeFn(json.Marshal),
DecodeFn(json.Unmarshal),
)
Expand Down Expand Up @@ -113,6 +112,19 @@ func TestSerde(t *testing.T) {
t.Errorf("#%d got MustAppendEncode(%v) != Encode(foo%v)", i, b2, b)
}

bIndented, err := Encode(test.enc, 100, []int{0}, serde.header(), func(v any) ([]byte, error) {
return json.MarshalIndent(v, "", " ")
})
if err != nil {
t.Errorf("#%d Encode[ID=100]: got err? %v, exp err? %v", i, gotErr, test.expErr)
continue
}
if i := bytes.IndexByte(bIndented, '{'); !bytes.Equal(bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[:i], []byte{0, 0, 0, 0, 100, 0})
} else if expIndented := extractIndentedJSON(b); !bytes.Equal(bIndented[i:], expIndented) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[i:], expIndented)
}

v, err := serde.DecodeNew(b)
if err != nil {
t.Errorf("#%d DecodeNew: got unexpected err %v", i, err)
Expand All @@ -126,6 +138,7 @@ func TestSerde(t *testing.T) {
}
if !reflect.DeepEqual(v, exp) {
t.Errorf("#%d round trip: got %v != exp %v", i, v, exp)
continue
}
}

Expand All @@ -141,6 +154,20 @@ func TestSerde(t *testing.T) {
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 99}); err != ErrNotRegistered {
t.Errorf("got %v != exp ErrNotRegistered", err)
}
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 100, 0}); err != ErrNotRegistered {
// schema is registered but type is unknown
t.Errorf("got %v != exp ErrNotRegistered", err)
}
}

func extractIndentedJSON(in []byte) []byte {
i := bytes.IndexByte(in, '{') // skip header
var out bytes.Buffer
err := json.Indent(&out, in[i:], "", " ")
if err != nil {
panic(err)
}
return out.Bytes()
}

func TestConfluentHeader(t *testing.T) {
Expand Down
Loading