Skip to content

Commit

Permalink
- add auth token support for grpc and rest
Browse files Browse the repository at this point in the history
  • Loading branch information
kubemq committed Dec 25, 2019
1 parent 2726f03 commit bf812fc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 31 deletions.
8 changes: 4 additions & 4 deletions grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,17 @@ func newGRPCTransport(ctx context.Context, opts *Options) (Transport, *ServerInf

func (g *gRPCTransport) SetUnaryInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if g.opts.token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, kubeMQTokenHeader, g.opts.token)
if g.opts.authToken != "" {
ctx = metadata.AppendToOutgoingContext(ctx, kubeMQAuthTokenHeader, g.opts.authToken)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

func (g *gRPCTransport) SetStreamInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if g.opts.token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, kubeMQTokenHeader, g.opts.token)
if g.opts.authToken != "" {
ctx = metadata.AppendToOutgoingContext(ctx, kubeMQAuthTokenHeader, g.opts.authToken)
}
return streamer(ctx, desc, cc, method, opts...)
}
Expand Down
12 changes: 6 additions & 6 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const kubeMQTokenHeader = "X-Kubemq-Server-Token"
const kubeMQAuthTokenHeader = "authorization"

type Option interface {
apply(*Options)
Expand All @@ -25,7 +25,7 @@ type Options struct {
certFile string
certData string
serverOverrideDomain string
token string
authToken string
clientId string
receiveBufferSize int
defaultChannel string
Expand Down Expand Up @@ -88,10 +88,10 @@ func WithCertificate(certData, serverOverrideDomain string) Option {
})
}

// WithToken - set KubeMQ token to be used for KubeMQ connection - not mandatory, only if enforced by the KubeMQ server
func WithToken(token string) Option {
// WithAuthToken - set KubeMQ JWT Auth token to be used for KubeMQ connection
func WithAuthToken(token string) Option {
return newFuncOption(func(o *Options) {
o.token = token
o.authToken = token
})
}

Expand Down Expand Up @@ -137,7 +137,7 @@ func GetDefaultOptions() *Options {
isSecured: false,
certFile: "",
serverOverrideDomain: "",
token: "",
authToken: "",
clientId: "ClientId",
receiveBufferSize: 10,
defaultChannel: "",
Expand Down
58 changes: 37 additions & 21 deletions rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"

"github.com/google/uuid"
pb "github.com/kubemq-io/protobuf/go"
Expand Down Expand Up @@ -32,9 +33,13 @@ func (res *restResponse) unmarshal(v interface{}) error {
return json.Unmarshal(res.Data, v)
}

func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready chan struct{}, errCh chan error) (*websocket.Conn, error) {
func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready chan struct{}, errCh chan error, authToken string) (*websocket.Conn, error) {
var c *websocket.Conn
conn, res, err := websocket.DefaultDialer.Dial(uri, nil)
header := http.Header{}
if authToken != "" {
header.Add("Authorization", fmt.Sprintf("Bearer %s", authToken))
}
conn, res, err := websocket.DefaultDialer.Dial(uri, header)
if err != nil {
buf := make([]byte, 1024)
if res != nil {
Expand Down Expand Up @@ -65,9 +70,13 @@ func newWebsocketConn(ctx context.Context, uri string, readCh chan string, ready
return c, nil
}

func newBiDirectionalWebsocketConn(ctx context.Context, uri string, readCh chan string, writeCh chan []byte, ready chan struct{}, errCh chan error) (*websocket.Conn, error) {
func newBiDirectionalWebsocketConn(ctx context.Context, uri string, readCh chan string, writeCh chan []byte, ready chan struct{}, errCh chan error, authToken string) (*websocket.Conn, error) {
var c *websocket.Conn
conn, res, err := websocket.DefaultDialer.Dial(uri, nil)
header := http.Header{}
if authToken != "" {
header.Add("Authorization", fmt.Sprintf("Bearer %s", authToken))
}
conn, res, err := websocket.DefaultDialer.Dial(uri, header)
if err != nil {
buf := make([]byte, 1024)
if res != nil {
Expand Down Expand Up @@ -136,11 +145,18 @@ func newRestTransport(ctx context.Context, opts *Options) (Transport, *ServerInf
}
return rt, si, nil
}
func (rt *restTransport) newRequest() *resty.Request {
r := resty.New().R()
if rt.opts.authToken != "" {
r.SetAuthToken(rt.opts.authToken)
}
return r

}
func (rt *restTransport) Ping(ctx context.Context) (*ServerInfo, error) {
resp := &restResponse{}
uri := fmt.Sprintf("%s/ping", rt.restAddress)
_, err := resty.New().R().SetResult(resp).SetError(resp).Get(uri)
_, err := rt.newRequest().SetResult(resp).SetError(resp).Get(uri)
if err != nil {
return nil, err
}
Expand All @@ -163,7 +179,7 @@ func (rt *restTransport) SendEvent(ctx context.Context, event *Event) error {
Tags: event.Tags,
}
uri := fmt.Sprintf("%s/send/event", rt.restAddress)
_, err := resty.New().R().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return err
}
Expand All @@ -181,7 +197,7 @@ func (rt *restTransport) StreamEvents(ctx context.Context, eventsCh chan *Event,
wsErrCh := make(chan error, 1)
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh)
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken)
if err != nil {
errCh <- err
return
Expand Down Expand Up @@ -223,7 +239,7 @@ func (rt *restTransport) SubscribeToEvents(ctx context.Context, channel, group s
rxChan := make(chan string)
ready := make(chan struct{}, 1)
wsErrCh := make(chan error, 1)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -271,7 +287,7 @@ func (rt *restTransport) SendEventStore(ctx context.Context, eventStore *EventSt
Tags: eventStore.Tags,
}
uri := fmt.Sprintf("%s/send/event", rt.restAddress)
_, err := resty.New().R().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(eventPb).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand All @@ -290,7 +306,7 @@ func (rt *restTransport) StreamEventsStore(ctx context.Context, eventsCh chan *E
wsErrCh := make(chan error, 1)
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh)
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken)
if err != nil {
errCh <- err
return
Expand Down Expand Up @@ -339,7 +355,7 @@ func (rt *restTransport) SubscribeToEventsStore(ctx context.Context, channel, gr
rxChan := make(chan string)
ready := make(chan struct{}, 1)
wsErrCh := make(chan error, 1)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -395,7 +411,7 @@ func (rt *restTransport) SendCommand(ctx context.Context, command *Command) (*Co
XXX_sizecache: 0,
}
uri := fmt.Sprintf("%s/send/request", rt.restAddress)
_, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand All @@ -412,7 +428,7 @@ func (rt *restTransport) SubscribeToCommands(ctx context.Context, channel, group
rxChan := make(chan string)
ready := make(chan struct{}, 1)
wsErrCh := make(chan error, 1)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -465,7 +481,7 @@ func (rt *restTransport) SendQuery(ctx context.Context, query *Query) (*QueryRes
Tags: query.Tags,
}
uri := fmt.Sprintf("%s/send/request", rt.restAddress)
_, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand All @@ -482,7 +498,7 @@ func (rt *restTransport) SubscribeToQueries(ctx context.Context, channel, group
rxChan := make(chan string)
ready := make(chan struct{}, 1)
wsErrCh := make(chan error, 1)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh)
conn, err := newWebsocketConn(ctx, uri, rxChan, ready, wsErrCh, rt.opts.authToken)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -533,7 +549,7 @@ func (rt *restTransport) SendResponse(ctx context.Context, response *Response) e
Tags: response.Tags,
}
uri := fmt.Sprintf("%s/send/response", rt.restAddress)
_, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return err
}
Expand Down Expand Up @@ -561,7 +577,7 @@ func (rt *restTransport) SendQueueMessage(ctx context.Context, msg *QueueMessage
},
}
uri := fmt.Sprintf("%s/queue/send", rt.restAddress)
_, err := resty.New().R().SetBody(msgSend).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(msgSend).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -596,7 +612,7 @@ func (rt *restTransport) SendQueueMessages(ctx context.Context, msgs []*QueueMes
})
}
uri := fmt.Sprintf("%s/queue/send_batch", rt.restAddress)
_, err := resty.New().R().SetBody(br).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(br).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -630,7 +646,7 @@ func (rt *restTransport) ReceiveQueueMessages(ctx context.Context, req *ReceiveQ
IsPeak: req.IsPeak,
}
uri := fmt.Sprintf("%s/queue/receive", rt.restAddress)
_, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand All @@ -650,7 +666,7 @@ func (rt *restTransport) AckAllQueueMessages(ctx context.Context, req *AckAllQue
WaitTimeSeconds: req.WaitTimeSeconds,
}
uri := fmt.Sprintf("%s/queue/ack_all", rt.restAddress)
_, err := resty.New().R().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
_, err := rt.newRequest().SetBody(request).SetResult(resp).SetError(resp).Post(uri)
if err != nil {
return nil, err
}
Expand All @@ -669,7 +685,7 @@ func (rt *restTransport) StreamQueueMessage(ctx context.Context, reqCh chan *pb.
ready := make(chan struct{}, 1)
wsErrCh := make(chan error, 1)
newCtx, cancel := context.WithCancel(ctx)
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh)
conn, err := newBiDirectionalWebsocketConn(newCtx, uri, readCh, writeCh, ready, wsErrCh, rt.opts.authToken)
if err != nil {
errCh <- err
return
Expand Down

0 comments on commit bf812fc

Please sign in to comment.