diff --git a/grpc.go b/grpc.go index 3018a00..127152b 100644 --- a/grpc.go +++ b/grpc.go @@ -79,8 +79,8 @@ func newGRPCTransport(ctx context.Context, opts *Options) (Transport, *ServerInf func (g *gRPCTransport) SetUnaryInterceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if g.opts.token != "" { - ctx = metadata.AppendToOutgoingContext(ctx, kubeMQTokenHeader, g.opts.token) + if g.opts.authToken != "" { + ctx = metadata.AppendToOutgoingContext(ctx, kubeMQAuthTokenHeader, g.opts.authToken) } return invoker(ctx, method, req, reply, cc, opts...) } @@ -88,8 +88,8 @@ func (g *gRPCTransport) SetUnaryInterceptor() grpc.UnaryClientInterceptor { func (g *gRPCTransport) SetStreamInterceptor() grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - if g.opts.token != "" { - ctx = metadata.AppendToOutgoingContext(ctx, kubeMQTokenHeader, g.opts.token) + if g.opts.authToken != "" { + ctx = metadata.AppendToOutgoingContext(ctx, kubeMQAuthTokenHeader, g.opts.authToken) } return streamer(ctx, desc, cc, method, opts...) } diff --git a/options.go b/options.go index 3545bef..9d6b875 100644 --- a/options.go +++ b/options.go @@ -6,7 +6,7 @@ import ( "time" ) -const kubeMQTokenHeader = "X-Kubemq-Server-Token" +const kubeMQAuthTokenHeader = "authorization" type Option interface { apply(*Options) @@ -25,7 +25,7 @@ type Options struct { certFile string certData string serverOverrideDomain string - token string + authToken string clientId string receiveBufferSize int defaultChannel string @@ -88,10 +88,10 @@ func WithCertificate(certData, serverOverrideDomain string) Option { }) } -// WithToken - set KubeMQ token to be used for KubeMQ connection - not mandatory, only if enforced by the KubeMQ server -func WithToken(token string) Option { +// WithAuthToken - set KubeMQ JWT Auth token to be used for KubeMQ connection +func WithAuthToken(token string) Option { return newFuncOption(func(o *Options) { - o.token = token + o.authToken = token }) } @@ -137,7 +137,7 @@ func GetDefaultOptions() *Options { isSecured: false, certFile: "", serverOverrideDomain: "", - token: "", + authToken: "", clientId: "ClientId", receiveBufferSize: 10, defaultChannel: "", diff --git a/rest.go b/rest.go index 9d3aaf2..257e714 100644 --- a/rest.go +++ b/rest.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "github.com/google/uuid" pb "github.com/kubemq-io/protobuf/go" @@ -32,9 +33,13 @@ func (res *restResponse) unmarshal(v interface{}) error { return json.Unmarshal(res.Data, v) } -func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready chan struct{}, errCh chan error) (*websocket.Conn, error) { +func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready chan struct{}, errCh chan error, authToken string) (*websocket.Conn, error) { var c *websocket.Conn - conn, res, err := websocket.DefaultDialer.Dial(uri, nil) + header := http.Header{} + if authToken != "" { + header.Add("Authorization", fmt.Sprintf("Bearer %s", authToken)) + } + conn, res, err := websocket.DefaultDialer.Dial(uri, header) if err != nil { buf := make([]byte, 1024) if res != nil { @@ -65,9 +70,13 @@ func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready return c, nil } -func newBiDirectionalWebsocketConn(ctx context.Context, uri string, readCh chan string, writeCh chan []byte, ready chan struct{}, errCh chan error) (*websocket.Conn, error) { +func newBiDirectionalWebsocketConn(ctx context.Context, uri string, readCh chan string, writeCh chan []byte, ready chan struct{}, errCh chan error, authToken string) (*websocket.Conn, error) { var c *websocket.Conn - conn, res, err := websocket.DefaultDialer.Dial(uri, nil) + header := http.Header{} + if authToken != "" { + header.Add("Authorization", fmt.Sprintf("Bearer %s", authToken)) + } + conn, res, err := websocket.DefaultDialer.Dial(uri, header) if err != nil { buf := make([]byte, 1024) if res != nil { @@ -136,11 +145,18 @@ func newRestTransport(ctx context.Context, opts *Options) (Transport, *ServerInf } return rt, si, nil } +func (rt *restTransport) newRequest() *resty.Request { + r := resty.New().R() + if rt.opts.authToken != "" { + r.SetAuthToken(rt.opts.authToken) + } + return r +} func (rt *restTransport) Ping(ctx context.Context) (*ServerInfo, error) { resp := &restResponse{} uri := fmt.Sprintf("%s/ping", rt.restAddress) - _, err := resty.New().R().SetResult(resp).SetError(resp).Get(uri) + _, err := rt.newRequest().SetResult(resp).SetError(resp).Get(uri) if err != nil { return nil, err } @@ -163,7 +179,7 @@ func (rt *restTransport) SendEvent(ctx context.Context, event *Event) error { Tags: event.Tags, } uri := fmt.Sprintf("%s/send/event", rt.restAddress) - _, err := resty.New().R().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri) if err != nil { return err } @@ -181,7 +197,7 @@ func (rt *restTransport) StreamEvents(ctx context.Context, eventsCh chan *Event, wsErrCh := make(chan error, 1) newCtx, cancel := context.WithCancel(ctx) defer cancel() - conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh) + conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken) if err != nil { errCh <- err return @@ -223,7 +239,7 @@ func (rt *restTransport) SubscribeToEvents(ctx context.Context, channel, group s rxChan := make(chan string) ready := make(chan struct{}, 1) wsErrCh := make(chan error, 1) - conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh) + conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken) if err != nil { return nil, err } @@ -271,7 +287,7 @@ func (rt *restTransport) SendEventStore(ctx context.Context, eventStore *EventSt Tags: eventStore.Tags, } uri := fmt.Sprintf("%s/send/event", rt.restAddress) - _, err := resty.New().R().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -290,7 +306,7 @@ func (rt *restTransport) StreamEventsStore(ctx context.Context, eventsCh chan *E wsErrCh := make(chan error, 1) newCtx, cancel := context.WithCancel(ctx) defer cancel() - conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh) + conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken) if err != nil { errCh <- err return @@ -339,7 +355,7 @@ func (rt *restTransport) SubscribeToEventsStore(ctx context.Context, channel, gr rxChan := make(chan string) ready := make(chan struct{}, 1) wsErrCh := make(chan error, 1) - conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh) + conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken) if err != nil { return nil, err } @@ -395,7 +411,7 @@ func (rt *restTransport) SendCommand(ctx context.Context, command *Command) (*Co XXX_sizecache: 0, } uri := fmt.Sprintf("%s/send/request", rt.restAddress) - _, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -412,7 +428,7 @@ func (rt *restTransport) SubscribeToCommands(ctx context.Context, channel, group rxChan := make(chan string) ready := make(chan struct{}, 1) wsErrCh := make(chan error, 1) - conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh) + conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken) if err != nil { return nil, err } @@ -465,7 +481,7 @@ func (rt *restTransport) SendQuery(ctx context.Context, query *Query) (*QueryRes Tags: query.Tags, } uri := fmt.Sprintf("%s/send/request", rt.restAddress) - _, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -482,7 +498,7 @@ func (rt *restTransport) SubscribeToQueries(ctx context.Context, channel, group rxChan := make(chan string) ready := make(chan struct{}, 1) wsErrCh := make(chan error, 1) - conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh) + conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken) if err != nil { return nil, err } @@ -533,7 +549,7 @@ func (rt *restTransport) SendResponse(ctx context.Context, response *Response) e Tags: response.Tags, } uri := fmt.Sprintf("%s/send/response", rt.restAddress) - _, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri) if err != nil { return err } @@ -561,7 +577,7 @@ func (rt *restTransport) SendQueueMessage(ctx context.Context, msg *QueueMessage }, } uri := fmt.Sprintf("%s/queue/send", rt.restAddress) - _, err := resty.New().R().SetBody(msgSend).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(msgSend).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -596,7 +612,7 @@ func (rt *restTransport) SendQueueMessages(ctx context.Context, msgs []*QueueMes }) } uri := fmt.Sprintf("%s/queue/send_batch", rt.restAddress) - _, err := resty.New().R().SetBody(br).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(br).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -630,7 +646,7 @@ func (rt *restTransport) ReceiveQueueMessages(ctx context.Context, req *ReceiveQ IsPeak: req.IsPeak, } uri := fmt.Sprintf("%s/queue/receive", rt.restAddress) - _, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -650,7 +666,7 @@ func (rt *restTransport) AckAllQueueMessages(ctx context.Context, req *AckAllQue WaitTimeSeconds: req.WaitTimeSeconds, } uri := fmt.Sprintf("%s/queue/ack_all", rt.restAddress) - _, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri) + _, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri) if err != nil { return nil, err } @@ -669,7 +685,7 @@ func (rt *restTransport) StreamQueueMessage(ctx context.Context, reqCh chan *pb. ready := make(chan struct{}, 1) wsErrCh := make(chan error, 1) newCtx, cancel := context.WithCancel(ctx) - conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh) + conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken) if err != nil { errCh <- err return