Skip to content

Commit

Permalink
Refactor the websocket client and add fixes
Browse files Browse the repository at this point in the history
The websocket client and hub interaction has been simplified a bit.
The hub now acts only as a tee writer to the various clients that
register. Clients must register and unregister explicitly. The hub
is no longer passed in to the client.

Websocket clients now watch for password changes or jwt token expiration
times. Clients are disconnected if auth token expires or if the password
is changed.

Various aditional safety checks have been added.

Signed-off-by: Gabriel Adrian Samfira <gsamfira@cloudbasesolutions.com>
  • Loading branch information
gabriel-samfira committed Jul 2, 2024
1 parent dcee092 commit cf35997
Show file tree
Hide file tree
Showing 17 changed files with 423 additions and 140 deletions.
18 changes: 10 additions & 8 deletions apiserver/controllers/controllers.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
return
}
defer conn.Close()

// nolint:golangci-lint,godox
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
// a valid token to authenticate, and keeps the websocket connection
// open, it will allow that client to stream logs via websockets
// until the connection is broken. We need to forcefully disconnect
// the client once the token expires.
client, err := wsWriter.NewClient(conn, a.hub)
client, err := wsWriter.NewClient(ctx, conn)
if err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
return
Expand All @@ -199,7 +194,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
return
}
client.Go()
defer a.hub.Unregister(client)

if err := client.Start(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
return
}
<-client.Done()
slog.Info("client disconnected", "client_id", client.ID())
}

// NotFoundHandler is returned when an invalid URL is acccessed
Expand Down
12 changes: 7 additions & 5 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
expires := &jwt.NumericDate{
Time: expireToken,
}
generation := PasswordGeneration(ctx)
claims := JWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: expires,
// nolint:golangci-lint,godox
// TODO: make this configurable
Issuer: "garm",
},
UserID: UserID(ctx),
TokenID: tokenID,
IsAdmin: IsAdmin(ctx),
FullName: FullName(ctx),
UserID: UserID(ctx),
TokenID: tokenID,
IsAdmin: IsAdmin(ctx),
FullName: FullName(ctx),
Generation: generation,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.cfg.Secret))
Expand Down Expand Up @@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
return ctx, runnerErrors.ErrUnauthorized
}

return PopulateContext(ctx, user), nil
return PopulateContext(ctx, user, nil), nil
}
40 changes: 36 additions & 4 deletions auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package auth

import (
"context"
"time"

runnerErrors "github.com/cloudbase/garm-provider-common/errors"
"github.com/cloudbase/garm/params"
Expand All @@ -28,9 +29,11 @@ const (
fullNameKey contextFlags = "full_name"
readMetricsKey contextFlags = "read_metrics"
// UserIDFlag is the User ID flag we set in the context
UserIDFlag contextFlags = "user_id"
isEnabledFlag contextFlags = "is_enabled"
jwtTokenFlag contextFlags = "jwt_token"
UserIDFlag contextFlags = "user_id"
isEnabledFlag contextFlags = "is_enabled"
jwtTokenFlag contextFlags = "jwt_token"
authExpiresFlag contextFlags = "auth_expires"
passwordGenerationFlag contextFlags = "password_generation"

instanceIDKey contextFlags = "id"
instanceNameKey contextFlags = "name"
Expand Down Expand Up @@ -169,14 +172,43 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont

// PopulateContext sets the appropriate fields in the context, based on
// the user object
func PopulateContext(ctx context.Context, user params.User) context.Context {
func PopulateContext(ctx context.Context, user params.User, authExpires *time.Time) context.Context {
ctx = SetUserID(ctx, user.ID)
ctx = SetAdmin(ctx, user.IsAdmin)
ctx = SetIsEnabled(ctx, user.Enabled)
ctx = SetFullName(ctx, user.FullName)
ctx = SetExpires(ctx, authExpires)
ctx = SetPasswordGeneration(ctx, user.Generation)
return ctx
}

func SetExpires(ctx context.Context, expires *time.Time) context.Context {
if expires == nil {
return ctx
}
return context.WithValue(ctx, authExpiresFlag, expires)
}

func Expires(ctx context.Context) *time.Time {
elem := ctx.Value(authExpiresFlag)
if elem == nil {
return nil
}
return elem.(*time.Time)
}

func SetPasswordGeneration(ctx context.Context, val uint) context.Context {
return context.WithValue(ctx, passwordGenerationFlag, val)
}

func PasswordGeneration(ctx context.Context) uint {
elem := ctx.Value(passwordGenerationFlag)
if elem == nil {
return 0
}
return elem.(uint)
}

// SetFullName sets the user full name in the context
func SetFullName(ctx context.Context, fullName string) context.Context {
return context.WithValue(ctx, fullNameKey, fullName)
Expand Down
15 changes: 14 additions & 1 deletion auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"log/slog"
"net/http"
"strings"
"time"

jwt "github.com/golang-jwt/jwt/v5"

Expand All @@ -37,6 +38,7 @@ type JWTClaims struct {
FullName string `json:"full_name"`
IsAdmin bool `json:"is_admin"`
ReadMetrics bool `json:"read_metrics"`
Generation uint `json:"generation"`
jwt.RegisteredClaims
}

Expand Down Expand Up @@ -69,7 +71,18 @@ func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims
return ctx, runnerErrors.ErrUnauthorized
}

ctx = PopulateContext(ctx, userInfo)
var expiresAt *time.Time
if claims.ExpiresAt != nil {
expires := claims.ExpiresAt.Time.UTC()
expiresAt = &expires
}

if userInfo.Generation != claims.Generation {
// Password was reset since token was issued. Invalidate.
return ctx, runnerErrors.ErrUnauthorized
}

ctx = PopulateContext(ctx, userInfo, expiresAt)
return ctx, nil
}

Expand Down
5 changes: 4 additions & 1 deletion cmd/garm-cli/cmd/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/cloudbase/garm-provider-common/util"
apiParams "github.com/cloudbase/garm/apiserver/params"
garmWs "github.com/cloudbase/garm/websocket"
)

var logCmd = &cobra.Command{
Expand Down Expand Up @@ -66,7 +67,9 @@ var logCmd = &cobra.Command{
for {
_, message, err := c.ReadMessage()
if err != nil {
slog.With(slog.Any("error", err)).Error("reading log message")
if garmWs.IsErrorOfInterest(err) {
slog.With(slog.Any("error", err)).Error("reading log message")
}
return
}
fmt.Println(util.SanitizeLogEntry(string(message)))
Expand Down
2 changes: 1 addition & 1 deletion cmd/garm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func main() {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
}

slog.With(slog.Any("error", err)).ErrorContext(ctx, "waiting for runner to stop")
slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
if err := runner.Wait(); err != nil {
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")
os.Exit(1)
Expand Down
14 changes: 7 additions & 7 deletions database/sql/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (s *GithubTestSuite) TestCreateCredentials() {
func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down Expand Up @@ -313,8 +313,8 @@ func (s *GithubTestSuite) TestNormalUsersCanOnlySeeTheirOwnCredentialsAdminCanSe
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T())
testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down Expand Up @@ -370,7 +370,7 @@ func (s *GithubTestSuite) TestGetGithubCredentialsFailsWhenCredentialsDontExist(
func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down Expand Up @@ -472,7 +472,7 @@ func (s *GithubTestSuite) TestDeleteGithubCredentials() {
func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down Expand Up @@ -682,7 +682,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsForNonExistingCredentials()
func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down Expand Up @@ -711,7 +711,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAd
func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() {
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
testUserCtx := auth.PopulateContext(context.Background(), testUser)
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)

credParams := params.CreateGithubCredentialsParams{
Name: testCredsName,
Expand Down
13 changes: 7 additions & 6 deletions database/sql/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,13 @@ type Instance struct {
type User struct {
Base

Username string `gorm:"uniqueIndex;varchar(64)"`
FullName string `gorm:"type:varchar(254)"`
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
Password string `gorm:"type:varchar(60)"`
IsAdmin bool
Enabled bool
Username string `gorm:"uniqueIndex;varchar(64)"`
FullName string `gorm:"type:varchar(254)"`
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
Password string `gorm:"type:varchar(60)"`
Generation uint
IsAdmin bool
Enabled bool
}

type ControllerInfo struct {
Expand Down
2 changes: 1 addition & 1 deletion database/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (s *sqlDatabase) migrateCredentialsToDB() (err error) {
// user. GARM is not yet multi-user, so it's safe to assume we only have this
// one user.
adminCtx := context.Background()
adminCtx = auth.PopulateContext(adminCtx, adminUser)
adminCtx = auth.PopulateContext(adminCtx, adminUser, nil)

slog.Info("migrating credentials to DB")
slog.Info("creating github endpoints table")
Expand Down
Loading

0 comments on commit cf35997

Please sign in to comment.