Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add event stream #275

Merged
merged 8 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions apiserver/controllers/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

gErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm-provider-common/util"
"github.com/cloudbase/garm/apiserver/events"
"github.com/cloudbase/garm/apiserver/params"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/metrics"
Expand Down Expand Up @@ -163,6 +164,43 @@ func (a *APIController) WebhookHandler(w http.ResponseWriter, r *http.Request) {
}
}

func (a *APIController) EventsHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !auth.IsAdmin(ctx) {
w.WriteHeader(http.StatusForbidden)
if _, err := w.Write([]byte("events are available to admin users")); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to encode response")
}
return
}

conn, err := a.upgrader.Upgrade(w, r, nil)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
return
}
defer conn.Close()

wsClient, err := wsWriter.NewClient(ctx, conn)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
return
}
defer wsClient.Stop()

eventHandler, err := events.NewHandler(ctx, wsClient)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new event handler")
return
}

if err := eventHandler.Start(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start event handler")
return
}
<-eventHandler.Done()
}

func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request) {
ctx := req.Context()
if !auth.IsAdmin(ctx) {
Expand All @@ -183,14 +221,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
return
}
defer conn.Close()

// nolint:golangci-lint,godox
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
// a valid token to authenticate, and keeps the websocket connection
// open, it will allow that client to stream logs via websockets
// until the connection is broken. We need to forcefully disconnect
// the client once the token expires.
client, err := wsWriter.NewClient(conn, a.hub)
client, err := wsWriter.NewClient(ctx, conn)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
return
Expand All @@ -199,7 +232,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
return
}
client.Go()
defer a.hub.Unregister(client)

if err := client.Start(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
return
}
<-client.Done()
slog.Info("client disconnected", "client_id", client.ID())
}

// NotFoundHandler is returned when an invalid URL is acccessed
Expand Down
181 changes: 181 additions & 0 deletions apiserver/events/events.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package events

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"sync"

runnerErrors "github.com/cloudbase/garm-provider-common/errors"
commonUtil "github.com/cloudbase/garm-provider-common/util"
"github.com/cloudbase/garm/auth"
"github.com/cloudbase/garm/database/common"
"github.com/cloudbase/garm/database/watcher"
"github.com/cloudbase/garm/websocket"
)

func NewHandler(ctx context.Context, client *websocket.Client) (*EventHandler, error) {
if client == nil {
return nil, runnerErrors.ErrUnauthorized
}

newID := commonUtil.NewID()
userID := auth.UserID(ctx)
if userID == "" {
return nil, runnerErrors.ErrUnauthorized
}
consumerID := fmt.Sprintf("ws-event-watcher-%s-%s", userID, newID)
consumer, err := watcher.RegisterConsumer(
// Filter everything by default. Users should set up filters
// after registration.
ctx, consumerID, watcher.WithNone())
if err != nil {
return nil, err
}

handler := &EventHandler{
client: client,
ctx: ctx,
consumer: consumer,
done: make(chan struct{}),
}
client.SetMessageHandler(handler.HandleClientMessages)

return handler, nil
}

type EventHandler struct {
client *websocket.Client
consumer common.Consumer

ctx context.Context
done chan struct{}
running bool

mux sync.Mutex
}

func (e *EventHandler) loop() {
defer e.Stop()

for {
select {
case <-e.ctx.Done():
slog.DebugContext(e.ctx, "context done, stopping event handler")
return
case <-e.client.Done():
slog.DebugContext(e.ctx, "client done, stopping event handler")
return
case <-e.Done():
slog.DebugContext(e.ctx, "done channel closed, stopping event handler")
case event, ok := <-e.consumer.Watch():
if !ok {
slog.DebugContext(e.ctx, "watcher closed, stopping event handler")
return
}
asJs, err := json.Marshal(event)
if err != nil {
slog.ErrorContext(e.ctx, "failed to marshal event", "error", err)
continue
}
if _, err := e.client.Write(asJs); err != nil {
slog.ErrorContext(e.ctx, "failed to write event", "error", err)
}
}
}
}

func (e *EventHandler) Start() error {
e.mux.Lock()
defer e.mux.Unlock()

if e.running {
return nil
}

if err := e.client.Start(); err != nil {
return err
}
e.running = true
go e.loop()
return nil
}

func (e *EventHandler) Stop() {
e.mux.Lock()
defer e.mux.Unlock()

if !e.running {
return
}
e.running = false
e.consumer.Close()
e.client.Stop()
close(e.done)
}

func (e *EventHandler) Done() <-chan struct{} {
return e.done
}

// optionsToWatcherFilters converts the Options struct to a PayloadFilterFunc.
// The client will send an array of filters that indicates which entities and which
// operations the client is interested in. The behavior is that of "any" filter.
// Which means that if any of the elements in the array match an event, it will be
// sent to the websocket.
// Alternatively, clients can choose to get everything.
func (e *EventHandler) optionsToWatcherFilters(opt Options) common.PayloadFilterFunc {
if opt.SendEverything {
return watcher.WithEverything()
}

var funcs []common.PayloadFilterFunc
for _, filter := range opt.Filters {
var filterFunc []common.PayloadFilterFunc
if filter.EntityType == "" {
return watcher.WithNone()
}
filterFunc = append(filterFunc, watcher.WithEntityTypeFilter(filter.EntityType))
if len(filter.Operations) > 0 {
var opFunc []common.PayloadFilterFunc
for _, op := range filter.Operations {
opFunc = append(opFunc, watcher.WithOperationTypeFilter(op))
}
filterFunc = append(filterFunc, watcher.WithAny(opFunc...))
}
funcs = append(funcs, watcher.WithAll(filterFunc...))
}
return watcher.WithAny(funcs...)
}

func (e *EventHandler) HandleClientMessages(message []byte) error {
if e.consumer == nil {
return fmt.Errorf("consumer not initialized")
}

var opt Options
if err := json.Unmarshal(message, &opt); err != nil {
slog.ErrorContext(e.ctx, "failed to unmarshal message from client", "error", err, "message", string(message))
// Client is in error. Disconnect.
e.client.Write([]byte("failed to unmarshal filter"))
e.Stop()
return nil
}

if err := opt.Validate(); err != nil {
if errors.Is(err, common.ErrNoFiltersProvided) {
slog.DebugContext(e.ctx, "no filters provided; ignoring")
return nil
}
slog.ErrorContext(e.ctx, "invalid filter", "error", err)
e.client.Write([]byte("invalid filter"))
e.Stop()
return nil
}

watcherFilters := e.optionsToWatcherFilters(opt)
e.consumer.SetFilters(watcherFilters)
return nil
}
50 changes: 50 additions & 0 deletions apiserver/events/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package events

import (
"github.com/cloudbase/garm/database/common"
)

type Filter struct {
Operations []common.OperationType `json:"operations"`
EntityType common.DatabaseEntityType `json:"entity_type"`
}

func (f Filter) Validate() error {
switch f.EntityType {
case common.RepositoryEntityType, common.OrganizationEntityType, common.EnterpriseEntityType,
common.PoolEntityType, common.UserEntityType, common.InstanceEntityType,
common.JobEntityType, common.ControllerEntityType, common.GithubCredentialsEntityType,
common.GithubEndpointEntityType:
default:
return common.ErrInvalidEntityType
}

for _, op := range f.Operations {
switch op {
case common.CreateOperation, common.UpdateOperation, common.DeleteOperation:
default:
return common.ErrInvalidOperationType
}
}
return nil
}

type Options struct {
SendEverything bool `json:"send_everything"`
Filters []Filter `json:"filters"`
}

func (o Options) Validate() error {
if o.SendEverything {
return nil
}
if len(o.Filters) == 0 {
return common.ErrNoFiltersProvided
}
for _, f := range o.Filters {
if err := f.Validate(); err != nil {
return err
}
}
return nil
}
1 change: 1 addition & 0 deletions apiserver/routers/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ func NewAPIRouter(han *controllers.APIController, authMiddleware, initMiddleware

// Websocket log writer
apiRouter.Handle("/{ws:ws\\/?}", http.HandlerFunc(han.WSHandler)).Methods("GET")
apiRouter.Handle("/{events:events\\/?}", http.HandlerFunc(han.EventsHandler)).Methods("GET")

// NotFound handler
apiRouter.PathPrefix("/").HandlerFunc(han.NotFoundHandler).Methods("GET", "POST", "PUT", "DELETE", "OPTIONS")
Expand Down
12 changes: 7 additions & 5 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
expires := &jwt.NumericDate{
Time: expireToken,
}
generation := PasswordGeneration(ctx)
claims := JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: expires,
// nolint:golangci-lint,godox
// TODO: make this configurable
Issuer: "garm",
},
UserID: UserID(ctx),
TokenID: tokenID,
IsAdmin: IsAdmin(ctx),
FullName: FullName(ctx),
UserID: UserID(ctx),
TokenID: tokenID,
IsAdmin: IsAdmin(ctx),
FullName: FullName(ctx),
Generation: generation,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.cfg.Secret))
Expand Down Expand Up @@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
return ctx, runnerErrors.ErrUnauthorized
}

return PopulateContext(ctx, user), nil
return PopulateContext(ctx, user, nil), nil
}
Loading
Loading