diff --git a/server/handlers.go b/server/handlers.go index 6521bf6a93..5954820caa 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -534,23 +534,6 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, "connector_id", authReq.ConnectorID, "username", claims.Username, "preferred_username", claims.PreferredUsername, "email", email, "groups", claims.Groups) - // we can skip the redirect to /approval and go ahead and send code if it's not required - if s.skipApproval && !authReq.ForceApprovalPrompt { - return "", true, nil - } - - // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original - // flow would be unable to poll for the result at the /approval endpoint - h := hmac.New(sha256.New, authReq.HMACKey) - h.Write([]byte(authReq.ID)) - mac := h.Sum(nil) - - returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac) - _, ok := conn.(connector.RefreshConnector) - if !ok { - return returnURL, false, nil - } - offlineAccessRequested := false for _, scope := range authReq.Scopes { if scope == scopeOfflineAccess { @@ -558,45 +541,55 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, break } } - if !offlineAccessRequested { - return returnURL, false, nil - } + _, canRefresh := conn.(connector.RefreshConnector) - // Try to retrieve an existing OfflineSession object for the corresponding user. - session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) - if err != nil { - if err != storage.ErrNotFound { - s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) - return "", false, err - } - offlineSessions := storage.OfflineSessions{ - UserID: identity.UserID, - ConnID: authReq.ConnectorID, - Refresh: make(map[string]*storage.RefreshTokenRef), - ConnectorData: identity.ConnectorData, - } + if offlineAccessRequested && canRefresh { + // Try to retrieve an existing OfflineSession object for the corresponding user. + session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) + switch { + case err != nil && err == storage.ErrNotFound: + offlineSessions := storage.OfflineSessions{ + UserID: identity.UserID, + ConnID: authReq.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: identity.ConnectorData, + } - // Create a new OfflineSession object for the user and add a reference object for - // the newly received refreshtoken. - if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { - s.logger.ErrorContext(ctx, "failed to create offline session", "err", err) + // Create a new OfflineSession object for the user and add a reference object for + // the newly received refreshtoken. + if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { + s.logger.ErrorContext(ctx, "failed to create offline session", "err", err) + return "", false, err + } + case err == nil: + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if len(identity.ConnectorData) > 0 { + old.ConnectorData = identity.ConnectorData + } + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) + return "", false, err + } + default: + s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) return "", false, err } - - return returnURL, false, nil } - // Update existing OfflineSession obj with new RefreshTokenRef. - if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if len(identity.ConnectorData) > 0 { - old.ConnectorData = identity.ConnectorData - } - return old, nil - }); err != nil { - s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) - return "", false, err + // we can skip the redirect to /approval and go ahead and send code if it's not required + if s.skipApproval && !authReq.ForceApprovalPrompt { + return "", true, nil } + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original + // flow would be unable to poll for the result at the /approval endpoint + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + mac := h.Sum(nil) + + returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac) return returnURL, false, nil } diff --git a/server/handlers_test.go b/server/handlers_test.go index d32101b1cf..08b02f75c9 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -519,7 +519,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { Scopes: []string{"offline_access"}, }, expectedRes: "/auth/mockPw/cb", - offlineSessionCreated: false, + offlineSessionCreated: true, }, }