diff --git a/apiserver/controllers/controllers.go b/apiserver/controllers/controllers.go index 556892f3..1d72d9b7 100644 --- a/apiserver/controllers/controllers.go +++ b/apiserver/controllers/controllers.go @@ -183,14 +183,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request) slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets") return } + defer conn.Close() - // nolint:golangci-lint,godox - // TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses - // a valid token to authenticate, and keeps the websocket connection - // open, it will allow that client to stream logs via websockets - // until the connection is broken. We need to forcefully disconnect - // the client once the token expires. - client, err := wsWriter.NewClient(conn, a.hub) + client, err := wsWriter.NewClient(ctx, conn) if err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client") return @@ -199,7 +194,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request) slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client") return } - client.Go() + defer a.hub.Unregister(client) + + if err := client.Start(); err != nil { + slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client") + return + } + <-client.Done() + slog.Info("client disconnected", "client_id", client.ID()) } // NotFoundHandler is returned when an invalid URL is acccessed diff --git a/auth/auth.go b/auth/auth.go index 4a4f957a..7dfabcf0 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -55,6 +55,7 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) { expires := &jwt.NumericDate{ Time: expireToken, } + generation := PasswordGeneration(ctx) claims := JWTClaims{ RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: expires, @@ -62,10 +63,11 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) { // TODO: make this configurable Issuer: "garm", }, - UserID: UserID(ctx), - TokenID: tokenID, - IsAdmin: IsAdmin(ctx), - FullName: FullName(ctx), + UserID: UserID(ctx), + TokenID: tokenID, + IsAdmin: IsAdmin(ctx), + FullName: FullName(ctx), + Generation: generation, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(a.cfg.Secret)) @@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo return ctx, runnerErrors.ErrUnauthorized } - return PopulateContext(ctx, user), nil + return PopulateContext(ctx, user, nil), nil } diff --git a/auth/context.go b/auth/context.go index 71d29d27..0d95be56 100644 --- a/auth/context.go +++ b/auth/context.go @@ -16,6 +16,7 @@ package auth import ( "context" + "time" runnerErrors "github.com/cloudbase/garm-provider-common/errors" "github.com/cloudbase/garm/params" @@ -28,9 +29,11 @@ const ( fullNameKey contextFlags = "full_name" readMetricsKey contextFlags = "read_metrics" // UserIDFlag is the User ID flag we set in the context - UserIDFlag contextFlags = "user_id" - isEnabledFlag contextFlags = "is_enabled" - jwtTokenFlag contextFlags = "jwt_token" + UserIDFlag contextFlags = "user_id" + isEnabledFlag contextFlags = "is_enabled" + jwtTokenFlag contextFlags = "jwt_token" + authExpiresFlag contextFlags = "auth_expires" + passwordGenerationFlag contextFlags = "password_generation" instanceIDKey contextFlags = "id" instanceNameKey contextFlags = "name" @@ -169,14 +172,43 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont // PopulateContext sets the appropriate fields in the context, based on // the user object -func PopulateContext(ctx context.Context, user params.User) context.Context { +func PopulateContext(ctx context.Context, user params.User, authExpires *time.Time) context.Context { ctx = SetUserID(ctx, user.ID) ctx = SetAdmin(ctx, user.IsAdmin) ctx = SetIsEnabled(ctx, user.Enabled) ctx = SetFullName(ctx, user.FullName) + ctx = SetExpires(ctx, authExpires) + ctx = SetPasswordGeneration(ctx, user.Generation) return ctx } +func SetExpires(ctx context.Context, expires *time.Time) context.Context { + if expires == nil { + return ctx + } + return context.WithValue(ctx, authExpiresFlag, expires) +} + +func Expires(ctx context.Context) *time.Time { + elem := ctx.Value(authExpiresFlag) + if elem == nil { + return nil + } + return elem.(*time.Time) +} + +func SetPasswordGeneration(ctx context.Context, val uint) context.Context { + return context.WithValue(ctx, passwordGenerationFlag, val) +} + +func PasswordGeneration(ctx context.Context) uint { + elem := ctx.Value(passwordGenerationFlag) + if elem == nil { + return 0 + } + return elem.(uint) +} + // SetFullName sets the user full name in the context func SetFullName(ctx context.Context, fullName string) context.Context { return context.WithValue(ctx, fullNameKey, fullName) diff --git a/auth/jwt.go b/auth/jwt.go index d463df2c..e9b5745f 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -21,6 +21,7 @@ import ( "log/slog" "net/http" "strings" + "time" jwt "github.com/golang-jwt/jwt/v5" @@ -37,6 +38,7 @@ type JWTClaims struct { FullName string `json:"full_name"` IsAdmin bool `json:"is_admin"` ReadMetrics bool `json:"read_metrics"` + Generation uint `json:"generation"` jwt.RegisteredClaims } @@ -69,7 +71,18 @@ func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims return ctx, runnerErrors.ErrUnauthorized } - ctx = PopulateContext(ctx, userInfo) + var expiresAt *time.Time + if claims.ExpiresAt != nil { + expires := claims.ExpiresAt.Time.UTC() + expiresAt = &expires + } + + if userInfo.Generation != claims.Generation { + // Password was reset since token was issued. Invalidate. + return ctx, runnerErrors.ErrUnauthorized + } + + ctx = PopulateContext(ctx, userInfo, expiresAt) return ctx, nil } diff --git a/cmd/garm-cli/cmd/log.go b/cmd/garm-cli/cmd/log.go index ccd55ca6..b0862db9 100644 --- a/cmd/garm-cli/cmd/log.go +++ b/cmd/garm-cli/cmd/log.go @@ -16,6 +16,7 @@ import ( "github.com/cloudbase/garm-provider-common/util" apiParams "github.com/cloudbase/garm/apiserver/params" + garmWs "github.com/cloudbase/garm/websocket" ) var logCmd = &cobra.Command{ @@ -66,7 +67,9 @@ var logCmd = &cobra.Command{ for { _, message, err := c.ReadMessage() if err != nil { - slog.With(slog.Any("error", err)).Error("reading log message") + if garmWs.IsErrorOfInterest(err) { + slog.With(slog.Any("error", err)).Error("reading log message") + } return } fmt.Println(util.SanitizeLogEntry(string(message))) diff --git a/cmd/garm/main.go b/cmd/garm/main.go index d8eed80c..ec5d74d6 100644 --- a/cmd/garm/main.go +++ b/cmd/garm/main.go @@ -322,7 +322,7 @@ func main() { slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed") } - slog.With(slog.Any("error", err)).ErrorContext(ctx, "waiting for runner to stop") + slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop") if err := runner.Wait(); err != nil { slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers") os.Exit(1) diff --git a/database/sql/github_test.go b/database/sql/github_test.go index 101ac411..b0399a68 100644 --- a/database/sql/github_test.go +++ b/database/sql/github_test.go @@ -284,7 +284,7 @@ func (s *GithubTestSuite) TestCreateCredentials() { func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() { ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, @@ -313,8 +313,8 @@ func (s *GithubTestSuite) TestNormalUsersCanOnlySeeTheirOwnCredentialsAdminCanSe ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T()) testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) - testUser2Ctx := auth.PopulateContext(context.Background(), testUser2) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) + testUser2Ctx := auth.PopulateContext(context.Background(), testUser2, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, @@ -370,7 +370,7 @@ func (s *GithubTestSuite) TestGetGithubCredentialsFailsWhenCredentialsDontExist( func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() { ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, @@ -472,7 +472,7 @@ func (s *GithubTestSuite) TestDeleteGithubCredentials() { func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() { ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, @@ -682,7 +682,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsForNonExistingCredentials() func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() { ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, @@ -711,7 +711,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAd func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() { ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T()) testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T()) - testUserCtx := auth.PopulateContext(context.Background(), testUser) + testUserCtx := auth.PopulateContext(context.Background(), testUser, nil) credParams := params.CreateGithubCredentialsParams{ Name: testCredsName, diff --git a/database/sql/models.go b/database/sql/models.go index 5486dd5b..7c62ea97 100644 --- a/database/sql/models.go +++ b/database/sql/models.go @@ -195,12 +195,13 @@ type Instance struct { type User struct { Base - Username string `gorm:"uniqueIndex;varchar(64)"` - FullName string `gorm:"type:varchar(254)"` - Email string `gorm:"type:varchar(254);unique;index:idx_email"` - Password string `gorm:"type:varchar(60)"` - IsAdmin bool - Enabled bool + Username string `gorm:"uniqueIndex;varchar(64)"` + FullName string `gorm:"type:varchar(254)"` + Email string `gorm:"type:varchar(254);unique;index:idx_email"` + Password string `gorm:"type:varchar(60)"` + Generation uint + IsAdmin bool + Enabled bool } type ControllerInfo struct { diff --git a/database/sql/sql.go b/database/sql/sql.go index 680b9115..1a024516 100644 --- a/database/sql/sql.go +++ b/database/sql/sql.go @@ -239,7 +239,7 @@ func (s *sqlDatabase) migrateCredentialsToDB() (err error) { // user. GARM is not yet multi-user, so it's safe to assume we only have this // one user. adminCtx := context.Background() - adminCtx = auth.PopulateContext(adminCtx, adminUser) + adminCtx = auth.PopulateContext(adminCtx, adminUser, nil) slog.Info("migrating credentials to DB") slog.Info("creating github endpoints table") diff --git a/database/sql/users.go b/database/sql/users.go index 5fc47564..7d604a83 100644 --- a/database/sql/users.go +++ b/database/sql/users.go @@ -26,7 +26,7 @@ import ( "github.com/cloudbase/garm/params" ) -func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) { +func (s *sqlDatabase) getUserByUsernameOrEmail(tx *gorm.DB, user string) (User, error) { field := "username" if util.IsValidEmail(user) { field = "email" @@ -34,7 +34,7 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) { query := fmt.Sprintf("%s = ?", field) var dbUser User - q := s.conn.Model(&User{}).Where(query, user).First(&dbUser) + q := tx.Model(&User{}).Where(query, user).First(&dbUser) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return User{}, runnerErrors.ErrNotFound @@ -44,9 +44,9 @@ func (s *sqlDatabase) getUserByUsernameOrEmail(user string) (User, error) { return dbUser, nil } -func (s *sqlDatabase) getUserByID(userID string) (User, error) { +func (s *sqlDatabase) getUserByID(tx *gorm.DB, userID string) (User, error) { var dbUser User - q := s.conn.Model(&User{}).Where("id = ?", userID).First(&dbUser) + q := tx.Model(&User{}).Where("id = ?", userID).First(&dbUser) if q.Error != nil { if errors.Is(q.Error, gorm.ErrRecordNotFound) { return User{}, runnerErrors.ErrNotFound @@ -57,20 +57,9 @@ func (s *sqlDatabase) getUserByID(userID string) (User, error) { } func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) (params.User, error) { - if user.Username == "" || user.Email == "" { - return params.User{}, runnerErrors.NewBadRequestError("missing username or email") + if user.Username == "" || user.Email == "" || user.Password == "" { + return params.User{}, runnerErrors.NewBadRequestError("missing username, password or email") } - if _, err := s.getUserByUsernameOrEmail(user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { - return params.User{}, runnerErrors.NewConflictError("username already exists") - } - if _, err := s.getUserByUsernameOrEmail(user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { - return params.User{}, runnerErrors.NewConflictError("email already exists") - } - - if s.HasAdminUser(context.Background()) && user.IsAdmin { - return params.User{}, runnerErrors.NewBadRequestError("admin user already exists") - } - newUser := User{ Username: user.Username, Password: user.Password, @@ -79,22 +68,42 @@ func (s *sqlDatabase) CreateUser(_ context.Context, user params.NewUserParams) ( Email: user.Email, IsAdmin: user.IsAdmin, } + err := s.conn.Transaction(func(tx *gorm.DB) error { + if _, err := s.getUserByUsernameOrEmail(tx, user.Username); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { + return runnerErrors.NewConflictError("username already exists") + } + if _, err := s.getUserByUsernameOrEmail(tx, user.Email); err == nil || !errors.Is(err, runnerErrors.ErrNotFound) { + return runnerErrors.NewConflictError("email already exists") + } - q := s.conn.Save(&newUser) - if q.Error != nil { - return params.User{}, errors.Wrap(q.Error, "creating user") + if s.hasAdmin(tx) && user.IsAdmin { + return runnerErrors.NewBadRequestError("admin user already exists") + } + + q := tx.Save(&newUser) + if q.Error != nil { + return errors.Wrap(q.Error, "creating user") + } + return nil + }) + if err != nil { + return params.User{}, errors.Wrap(err, "creating user") } return s.sqlToParamsUser(newUser), nil } -func (s *sqlDatabase) HasAdminUser(_ context.Context) bool { +func (s *sqlDatabase) hasAdmin(tx *gorm.DB) bool { var user User - q := s.conn.Model(&User{}).Where("is_admin = ?", true).First(&user) + q := tx.Model(&User{}).Where("is_admin = ?", true).First(&user) return q.Error == nil } +func (s *sqlDatabase) HasAdminUser(_ context.Context) bool { + return s.hasAdmin(s.conn) +} + func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, error) { - dbUser, err := s.getUserByUsernameOrEmail(user) + dbUser, err := s.getUserByUsernameOrEmail(s.conn, user) if err != nil { return params.User{}, errors.Wrap(err, "fetching user") } @@ -102,7 +111,7 @@ func (s *sqlDatabase) GetUser(_ context.Context, user string) (params.User, erro } func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User, error) { - dbUser, err := s.getUserByID(userID) + dbUser, err := s.getUserByID(s.conn, userID) if err != nil { return params.User{}, errors.Wrap(err, "fetching user") } @@ -110,27 +119,35 @@ func (s *sqlDatabase) GetUserByID(_ context.Context, userID string) (params.User } func (s *sqlDatabase) UpdateUser(_ context.Context, user string, param params.UpdateUserParams) (params.User, error) { - dbUser, err := s.getUserByUsernameOrEmail(user) - if err != nil { - return params.User{}, errors.Wrap(err, "fetching user") - } + var err error + var dbUser User + err = s.conn.Transaction(func(tx *gorm.DB) error { + dbUser, err = s.getUserByUsernameOrEmail(tx, user) + if err != nil { + return errors.Wrap(err, "fetching user") + } - if param.FullName != "" { - dbUser.FullName = param.FullName - } + if param.FullName != "" { + dbUser.FullName = param.FullName + } - if param.Enabled != nil { - dbUser.Enabled = *param.Enabled - } + if param.Enabled != nil { + dbUser.Enabled = *param.Enabled + } - if param.Password != "" { - dbUser.Password = param.Password - } + if param.Password != "" { + dbUser.Password = param.Password + dbUser.Generation++ + } - if q := s.conn.Save(&dbUser); q.Error != nil { - return params.User{}, errors.Wrap(q.Error, "saving user") + if q := tx.Save(&dbUser); q.Error != nil { + return errors.Wrap(q.Error, "saving user") + } + return nil + }) + if err != nil { + return params.User{}, errors.Wrap(err, "updating user") } - return s.sqlToParamsUser(dbUser), nil } diff --git a/database/sql/users_test.go b/database/sql/users_test.go index ec7c0889..627c4b93 100644 --- a/database/sql/users_test.go +++ b/database/sql/users_test.go @@ -145,7 +145,7 @@ func (s *UserTestSuite) TestCreateUserMissingUsernameEmail() { _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) s.Require().NotNil(err) - s.Require().Equal(("missing username or email"), err.Error()) + s.Require().Equal(("missing username, password or email"), err.Error()) } func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() { @@ -154,7 +154,7 @@ func (s *UserTestSuite) TestCreateUserUsernameAlreadyExist() { _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) s.Require().NotNil(err) - s.Require().Equal(("username already exists"), err.Error()) + s.Require().Equal(("creating user: username already exists"), err.Error()) } func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() { @@ -163,10 +163,11 @@ func (s *UserTestSuite) TestCreateUserEmailAlreadyExist() { _, err := s.Store.CreateUser(context.Background(), s.Fixtures.NewUserParams) s.Require().NotNil(err) - s.Require().Equal(("email already exists"), err.Error()) + s.Require().Equal(("creating user: email already exists"), err.Error()) } func (s *UserTestSuite) TestCreateUserDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). WithArgs(s.Fixtures.NewUserParams.Username, 1). @@ -175,7 +176,6 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE email = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). WithArgs(s.Fixtures.NewUserParams.Email, 1). WillReturnRows(sqlmock.NewRows([]string{"id"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec("INSERT INTO `users`"). WillReturnError(fmt.Errorf("creating user mock error")) @@ -183,9 +183,9 @@ func (s *UserTestSuite) TestCreateUserDBCreateErr() { _, err := s.StoreSQLMocked.CreateUser(context.Background(), s.Fixtures.NewUserParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating user: creating user mock error", err.Error()) + s.Require().Equal("creating user: creating user: creating user mock error", err.Error()) + s.assertSQLMockExpectations() } func (s *UserTestSuite) TestHasAdminUserNoAdmin() { @@ -253,15 +253,15 @@ func (s *UserTestSuite) TestUpdateUserNotFound() { _, err := s.Store.UpdateUser(context.Background(), "dummy-user", s.Fixtures.UpdateUserParams) s.Require().NotNil(err) - s.Require().Equal("fetching user: not found", err.Error()) + s.Require().Equal("updating user: fetching user: not found", err.Error()) } func (s *UserTestSuite) TestUpdateUserDBSaveErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` WHERE username = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?")). WithArgs(s.Fixtures.Users[0].ID, 1). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Users[0].ID)) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(("UPDATE `users` SET")). WillReturnError(fmt.Errorf("saving user mock error")) @@ -271,7 +271,7 @@ func (s *UserTestSuite) TestUpdateUserDBSaveErr() { s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving user: saving user mock error", err.Error()) + s.Require().Equal("updating user: saving user: saving user mock error", err.Error()) } func TestUserTestSuite(t *testing.T) { diff --git a/database/sql/util.go b/database/sql/util.go index dd861197..063c68a6 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -316,15 +316,16 @@ func (s *sqlDatabase) sqlToCommonRepository(repo Repository, detailed bool) (par func (s *sqlDatabase) sqlToParamsUser(user User) params.User { return params.User{ - ID: user.ID.String(), - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, - Email: user.Email, - Username: user.Username, - FullName: user.FullName, - Password: user.Password, - Enabled: user.Enabled, - IsAdmin: user.IsAdmin, + ID: user.ID.String(), + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + Email: user.Email, + Username: user.Username, + FullName: user.FullName, + Password: user.Password, + Enabled: user.Enabled, + IsAdmin: user.IsAdmin, + Generation: user.Generation, } } diff --git a/database/watcher/filters.go b/database/watcher/filters.go index ffff9320..3838d04d 100644 --- a/database/watcher/filters.go +++ b/database/watcher/filters.go @@ -104,10 +104,19 @@ func WithEntityFilter(entity params.GithubEntity) dbCommon.PayloadFilterFunc { var ok bool switch payload.EntityType { case dbCommon.RepositoryEntityType: + if entity.EntityType != params.GithubEntityTypeRepository { + return false + } ent, ok = payload.Payload.(params.Repository) case dbCommon.OrganizationEntityType: + if entity.EntityType != params.GithubEntityTypeOrganization { + return false + } ent, ok = payload.Payload.(params.Organization) case dbCommon.EnterpriseEntityType: + if entity.EntityType != params.GithubEntityTypeEnterprise { + return false + } ent, ok = payload.Payload.(params.Enterprise) default: return false @@ -165,3 +174,17 @@ func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.Payloa return credsPayload.ID == creds.ID } } + +// WithUserIDFilter returns a filter function that filters payloads by user ID. +func WithUserIDFilter(userID string) dbCommon.PayloadFilterFunc { + return func(payload dbCommon.ChangePayload) bool { + if payload.EntityType != dbCommon.UserEntityType { + return false + } + userPayload, ok := payload.Payload.(params.User) + if !ok { + return false + } + return userPayload.ID == userID + } +} diff --git a/internal/testing/testing.go b/internal/testing/testing.go index 6e76956f..1b937b6c 100644 --- a/internal/testing/testing.go +++ b/internal/testing/testing.go @@ -57,7 +57,7 @@ func ImpersonateAdminContext(ctx context.Context, db common.Store, s *testing.T) s.Fatalf("failed to create admin user: %v", err) } } - ctx = auth.PopulateContext(ctx, adminUser) + ctx = auth.PopulateContext(ctx, adminUser, nil) return ctx } diff --git a/params/params.go b/params/params.go index e7bdf869..a0d01ce9 100644 --- a/params/params.go +++ b/params/params.go @@ -543,9 +543,11 @@ type User struct { Email string `json:"email"` Username string `json:"username"` FullName string `json:"full_name"` - Password string `json:"-"` Enabled bool `json:"enabled"` IsAdmin bool `json:"is_admin"` + // Do not serialize sensitive info. + Password string `json:"-"` + Generation uint `json:"-"` } // JWTResponse holds the JWT token returned as a result of a diff --git a/websocket/client.go b/websocket/client.go index 69812a0d..d7bb9f6a 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -1,11 +1,20 @@ package websocket import ( + "context" + "fmt" "log/slog" + "sync" "time" "github.com/google/uuid" "github.com/gorilla/websocket" + "github.com/pkg/errors" + + "github.com/cloudbase/garm/auth" + "github.com/cloudbase/garm/database/common" + "github.com/cloudbase/garm/database/watcher" + "github.com/cloudbase/garm/params" ) const ( @@ -22,13 +31,34 @@ const ( maxMessageSize = 1024 ) -func NewClient(conn *websocket.Conn, hub *Hub) (*Client, error) { +type HandleWebsocketMessage func([]byte) error + +func NewClient(ctx context.Context, conn *websocket.Conn) (*Client, error) { clientID := uuid.New() + consumerID := fmt.Sprintf("ws-client-watcher-%s", clientID.String()) + + user := auth.UserID(ctx) + if user == "" { + return nil, fmt.Errorf("user not found in context") + } + generation := auth.PasswordGeneration(ctx) + + consumer, err := watcher.RegisterConsumer( + ctx, consumerID, + watcher.WithUserIDFilter(user), + ) + if err != nil { + return nil, errors.Wrap(err, "registering consumer") + } return &Client{ - id: clientID.String(), - conn: conn, - hub: hub, - send: make(chan []byte, 100), + id: clientID.String(), + conn: conn, + ctx: ctx, + userID: user, + passwordGeneration: generation, + consumer: consumer, + done: make(chan struct{}), + send: make(chan []byte, 100), }, nil } @@ -37,21 +67,84 @@ type Client struct { conn *websocket.Conn // Buffered channel of outbound messages. send chan []byte + mux sync.Mutex + ctx context.Context + + userID string + passwordGeneration uint + consumer common.Consumer + + messageHandler HandleWebsocketMessage + + running bool + done chan struct{} +} + +func (c *Client) ID() string { + return c.id +} + +func (c *Client) Stop() { + c.mux.Lock() + defer c.mux.Unlock() + + if !c.running { + return + } + + c.running = false + c.conn.Close() + close(c.send) + close(c.done) +} - hub *Hub +func (c *Client) Done() <-chan struct{} { + return c.done } -func (c *Client) Go() { +func (c *Client) SetMessageHandler(handler HandleWebsocketMessage) { + c.mux.Lock() + defer c.mux.Unlock() + c.messageHandler = handler +} + +func (c *Client) Start() error { + c.mux.Lock() + defer c.mux.Unlock() + + c.running = true + + go c.runWatcher() go c.clientReader() go c.clientWriter() + + return nil +} + +func (c *Client) Write(msg []byte) (int, error) { + c.mux.Lock() + defer c.mux.Unlock() + + if !c.running { + return 0, fmt.Errorf("client is stopped") + } + + tmp := make([]byte, len(msg)) + copy(tmp, msg) + + select { + case <-time.After(5 * time.Second): + return 0, fmt.Errorf("timed out sending message to client") + case c.send <- tmp: + } + return len(tmp), nil } // clientReader waits for options changes from the client. The client can at any time // change the log level and binary name it watches. func (c *Client) clientReader() { defer func() { - c.hub.unregister <- c - c.conn.Close() + c.Stop() }() c.conn.SetReadLimit(maxMessageSize) if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { @@ -64,10 +157,19 @@ func (c *Client) clientReader() { return nil }) for { - mt, _, err := c.conn.ReadMessage() + mt, data, err := c.conn.ReadMessage() if err != nil { + if IsErrorOfInterest(err) { + slog.ErrorContext(c.ctx, "error reading websocket message", slog.Any("error", err)) + } break } + + if c.messageHandler != nil { + if err := c.messageHandler(data); err != nil { + slog.ErrorContext(c.ctx, "error handling message", slog.Any("error", err)) + } + } if mt == websocket.CloseMessage { break } @@ -78,9 +180,14 @@ func (c *Client) clientReader() { func (c *Client) clientWriter() { ticker := time.NewTicker(pingPeriod) defer func() { + c.Stop() ticker.Stop() - c.conn.Close() }() + var authExpires time.Time + expires := auth.Expires(c.ctx) + if expires != nil { + authExpires = *expires + } for { select { case message, ok := <-c.send: @@ -90,13 +197,17 @@ func (c *Client) clientWriter() { if !ok { // The hub closed the channel. if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { - slog.With(slog.Any("error", err)).Error("failed to write message") + if IsErrorOfInterest(err) { + slog.With(slog.Any("error", err)).Error("failed to write message") + } } return } if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { - slog.With(slog.Any("error", err)).Error("error sending message") + if IsErrorOfInterest(err) { + slog.With(slog.Any("error", err)).Error("error sending message") + } return } case <-ticker.C: @@ -104,8 +215,81 @@ func (c *Client) clientWriter() { slog.With(slog.Any("error", err)).Error("failed to set write deadline") } if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + if IsErrorOfInterest(err) { + slog.With(slog.Any("error", err)).Error("failed to write ping message") + } + return + } + case <-c.ctx.Done(): + return + case <-time.After(time.Until(authExpires)): + // Auth has expired + slog.DebugContext(c.ctx, "auth expired, closing connection") + return + } + } +} + +func (c *Client) runWatcher() { + defer func() { + c.Stop() + }() + for { + select { + case <-c.Done(): + return + case <-c.ctx.Done(): + return + case event, ok := <-c.consumer.Watch(): + if !ok { + slog.InfoContext(c.ctx, "watcher closed") return } + go func(event common.ChangePayload) { + if event.EntityType != common.UserEntityType { + return + } + + user, ok := event.Payload.(params.User) + if !ok { + slog.ErrorContext(c.ctx, "failed to cast payload to user") + return + } + + if user.ID != c.userID { + return + } + + if user.Generation != c.passwordGeneration { + slog.InfoContext(c.ctx, "password generation mismatch; closing connection") + c.Stop() + } + }(event) + } + } +} + +func IsErrorOfInterest(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, websocket.ErrCloseSent) { + return false + } + + if errors.Is(err, websocket.ErrBadHandshake) { + return false + } + + asCloseErr, ok := err.(*websocket.CloseError) + if ok { + switch asCloseErr.Code { + case websocket.CloseNormalClosure, websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, websocket.CloseAbnormalClosure: + return false } } + + return true } diff --git a/websocket/websocket.go b/websocket/websocket.go index 1286222f..18b56585 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -3,19 +3,18 @@ package websocket import ( "context" "fmt" + "log/slog" "sync" "time" ) func NewHub(ctx context.Context) *Hub { return &Hub{ - clients: map[string]*Client{}, - broadcast: make(chan []byte, 100), - register: make(chan *Client, 100), - unregister: make(chan *Client, 100), - ctx: ctx, - closed: make(chan struct{}), - quit: make(chan struct{}), + clients: map[string]*Client{}, + broadcast: make(chan []byte, 100), + ctx: ctx, + closed: make(chan struct{}), + quit: make(chan struct{}), } } @@ -29,12 +28,6 @@ type Hub struct { // Inbound messages from the clients. broadcast chan []byte - // Register requests from the clients. - register chan *Client - - // Unregister requests from clients. - unregister chan *Client - mux sync.Mutex once sync.Once } @@ -49,22 +42,6 @@ func (h *Hub) run() { return case <-h.ctx.Done(): return - case client := <-h.register: - if client != nil { - h.mux.Lock() - h.clients[client.id] = client - h.mux.Unlock() - } - case client := <-h.unregister: - if client != nil { - h.mux.Lock() - if _, ok := h.clients[client.id]; ok { - client.conn.Close() - close(client.send) - delete(h.clients, client.id) - } - h.mux.Unlock() - } case message := <-h.broadcast: staleClients := []string{} for id, client := range h.clients { @@ -73,9 +50,7 @@ func (h *Hub) run() { continue } - select { - case client.send <- message: - case <-time.After(5 * time.Second): + if _, err := client.Write(message); err != nil { staleClients = append(staleClients, id) } } @@ -97,7 +72,35 @@ func (h *Hub) run() { } func (h *Hub) Register(client *Client) error { - h.register <- client + if client == nil { + return nil + } + h.mux.Lock() + defer h.mux.Unlock() + cli, ok := h.clients[client.ID()] + if ok { + if cli != nil { + return fmt.Errorf("client already registered") + } + } + slog.DebugContext(h.ctx, "registering client", "client_id", client.ID()) + h.clients[client.id] = client + return nil +} + +func (h *Hub) Unregister(client *Client) error { + if client == nil { + return nil + } + h.mux.Lock() + defer h.mux.Unlock() + cli, ok := h.clients[client.ID()] + if ok { + cli.Stop() + slog.DebugContext(h.ctx, "unregistering client", "client_id", cli.ID()) + delete(h.clients, cli.ID()) + slog.DebugContext(h.ctx, "current client count", "count", len(h.clients)) + } return nil }